diff --git a/Dockerfile b/Dockerfile index 1b46ccdc4..c4c4df3e4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,6 +23,7 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential \ + gcc-8 g++-8 \ ca-certificates \ curl \ ffmpeg \ @@ -44,6 +45,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ apt-get clean && \ rm -rf /var/lib/apt/lists/* +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /usr/bin/g++ g++ /usr/bin/g++-8 RUN pip3 install --upgrade setuptools RUN pip3 install wheel RUN pip3 install future diff --git a/WORKSPACE b/WORKSPACE index 48632fe2f..e797410a7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -337,6 +337,8 @@ maven_install( "androidx.test.espresso:espresso-core:3.1.1", "com.github.bumptech.glide:glide:4.11.0", "com.google.android.material:material:aar:1.0.0-rc01", + "com.google.auto.value:auto-value:1.6.4", + "com.google.auto.value:auto-value-annotations:1.6.4", "com.google.code.findbugs:jsr305:3.0.2", "com.google.flogger:flogger-system-backend:0.3.1", "com.google.flogger:flogger:0.3.1", @@ -367,9 +369,9 @@ http_archive( ) # Tensorflow repo should always go after the other external dependencies. -# 2021-03-25 -_TENSORFLOW_GIT_COMMIT = "c67f68021824410ebe9f18513b8856ac1c6d4887" -_TENSORFLOW_SHA256= "fd07d0b39422dc435e268c5e53b2646a8b4b1e3151b87837b43f86068faae87f" +# 2021-04-30 +_TENSORFLOW_GIT_COMMIT = "5bd3c57ef184543d22e34e36cff9d9bea608e06d" +_TENSORFLOW_SHA256= "9a45862834221aafacf6fb275f92b3876bc89443cbecc51be93f13839a6609f0" http_archive( name = "org_tensorflow", urls = [ diff --git a/build_desktop_examples.sh b/build_desktop_examples.sh index 5e493e79c..a35556cf0 100644 --- a/build_desktop_examples.sh +++ b/build_desktop_examples.sh @@ -17,15 +17,15 @@ # Script to build/run all MediaPipe desktop example apps (with webcam input). # # To build and run all apps and store them in out_dir: -# $ ./build_ios_examples.sh -d out_dir +# $ ./build_desktop_examples.sh -d out_dir # Omitting -d and the associated directory saves all generated apps in the # current directory. # To build all apps and store them in out_dir: -# $ ./build_ios_examples.sh -d out_dir -b +# $ ./build_desktop_examples.sh -d out_dir -b # Omitting -d and the associated directory saves all generated apps in the # current directory. # To run all apps already stored in out_dir: -# $ ./build_ios_examples.sh -d out_dir -r +# $ ./build_desktop_examples.sh -d out_dir -r # Omitting -d and the associated directory assumes all apps are in the current # directory. diff --git a/docs/framework_concepts/calculators.md b/docs/framework_concepts/calculators.md index 3e1236aaa..98bf1def4 100644 --- a/docs/framework_concepts/calculators.md +++ b/docs/framework_concepts/calculators.md @@ -187,7 +187,7 @@ node { ``` In the calculator implementation, inputs and outputs are also identified by tag -name and index number. In the function below input are output are identified: +name and index number. In the function below input and output are identified: * By index number: The combined input stream is identified simply by index `0`. 
@@ -355,7 +355,6 @@ class PacketClonerCalculator : public CalculatorBase { current_[i].At(cc->InputTimestamp())); // Add a packet to output stream of index i a packet from inputstream i // with timestamp common to all present inputs - // } else { cc->Outputs().Index(i).SetNextTimestampBound( cc->InputTimestamp().NextAllowedInStream()); @@ -382,7 +381,7 @@ defined your calculator class, register it with a macro invocation REGISTER_CALCULATOR(calculator_class_name). Below is a trivial MediaPipe graph that has 3 input streams, 1 node -(PacketClonerCalculator) and 3 output streams. +(PacketClonerCalculator) and 2 output streams. ```proto input_stream: "room_mic_signal" diff --git a/docs/framework_concepts/graphs.md b/docs/framework_concepts/graphs.md index 2bbdd6856..d7d972be5 100644 --- a/docs/framework_concepts/graphs.md +++ b/docs/framework_concepts/graphs.md @@ -83,12 +83,12 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`. output_stream: "out3" node { - calculator: "PassThroughculator" + calculator: "PassThroughCalculator" input_stream: "out1" output_stream: "out2" } node { - calculator: "PassThroughculator" + calculator: "PassThroughCalculator" input_stream: "out2" output_stream: "out3" } diff --git a/docs/getting_started/android.md b/docs/getting_started/android.md index ee83116dd..71224a258 100644 --- a/docs/getting_started/android.md +++ b/docs/getting_started/android.md @@ -57,7 +57,7 @@ Please verify all the necessary packages are installed. * Android SDK Build-Tools 28 or 29 * Android SDK Platform-Tools 28 or 29 * Android SDK Tools 26.1.1 -* Android NDK 17c or above +* Android NDK 19c or above ### Option 1: Build with Bazel in Command Line @@ -111,7 +111,7 @@ app: * Verify that Android SDK Build-Tools 28 or 29 is installed. * Verify that Android SDK Platform-Tools 28 or 29 is installed. * Verify that Android SDK Tools 26.1.1 is installed. - * Verify that Android NDK 17c or above is installed. + * Verify that Android NDK 19c or above is installed. * Take note of the Android NDK Location, e.g., `/usr/local/home/Android/Sdk/ndk-bundle` or `/usr/local/home/Android/Sdk/ndk/20.0.5594570`. diff --git a/docs/getting_started/android_archive_library.md b/docs/getting_started/android_archive_library.md index 735bd7a39..2c2ca99f3 100644 --- a/docs/getting_started/android_archive_library.md +++ b/docs/getting_started/android_archive_library.md @@ -37,7 +37,7 @@ each project. load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar") mediapipe_aar( - name = "mp_face_detection_aar", + name = "mediapipe_face_detection", calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"], ) ``` @@ -45,26 +45,29 @@ each project. 2. Run the Bazel build command to generate the AAR. 
```bash - bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ - --fat_apk_cpu=arm64-v8a,armeabi-v7a --strip=ALWAYS \ - //path/to/the/aar/build/file:aar_name + bazel build -c opt --strip=ALWAYS \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --fat_apk_cpu=arm64-v8a,armeabi-v7a \ + //path/to/the/aar/build/file:aar_name.aar ``` - For the face detection AAR target we made in the step 1, run: + For the face detection AAR target we made in step 1, run: ```bash - bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --fat_apk_cpu=arm64-v8a,armeabi-v7a \ - //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar + bazel build -c opt --strip=ALWAYS \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --fat_apk_cpu=arm64-v8a,armeabi-v7a \ + //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar # It should print: - # Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar up-to-date: - # bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar + # Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar up-to-date: + # bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar ``` 3. (Optional) Save the AAR to your preferred location. ```bash - cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar + cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar /absolute/path/to/your/preferred/location ``` @@ -75,7 +78,7 @@ each project. 2. Copy the AAR into app/libs. ```bash - cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar + cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar /path/to/your/app/libs/ ``` @@ -92,29 +95,14 @@ each project. [the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite). ```bash - bazel build -c opt mediapipe/mediapipe/graphs/face_detection:mobile_gpu_binary_graph - cp bazel-bin/mediapipe/graphs/face_detection/mobile_gpu.binarypb /path/to/your/app/src/main/assets/ + bazel build -c opt mediapipe/graphs/face_detection:face_detection_mobile_gpu_binary_graph + cp bazel-bin/mediapipe/graphs/face_detection/face_detection_mobile_gpu.binarypb /path/to/your/app/src/main/assets/ cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/ ``` ![Screenshot](../images/mobile/assets_location.png) -4. Make app/src/main/jniLibs and copy OpenCV JNI libraries into - app/src/main/jniLibs. - - MediaPipe depends on OpenCV, you will need to copy the precompiled OpenCV so - files into app/src/main/jniLibs. You can download the official OpenCV - Android SDK from - [here](https://github.com/opencv/opencv/releases/download/3.4.3/opencv-3.4.3-android-sdk.zip) - and run: - - ```bash - cp -R ~/Downloads/OpenCV-android-sdk/sdk/native/libs/arm* /path/to/your/app/src/main/jniLibs/ - ``` - - ![Screenshot](../images/mobile/android_studio_opencv_location.png) - -5. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR. +4. 
Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR. ``` dependencies { @@ -136,10 +124,14 @@ each project. implementation "androidx.camera:camera-core:$camerax_version" implementation "androidx.camera:camera-camera2:$camerax_version" implementation "androidx.camera:camera-lifecycle:$camerax_version" + // AutoValue + def auto_value_version = "1.6.4" + implementation "com.google.auto.value:auto-value-annotations:$auto_value_version" + annotationProcessor "com.google.auto.value:auto-value:$auto_value_version" } ``` -6. Follow our Android app examples to use MediaPipe in Android Studio for your +5. Follow our Android app examples to use MediaPipe in Android Studio for your use case. If you are looking for an example, a face detection example can be found [here](https://github.com/jiuqiant/mediapipe_face_detection_aar_example) and diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index c0a240ae8..95dce1d17 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -471,7 +471,7 @@ next section. 4. Install Visual C++ Build Tools 2019 and WinSDK Go to - [the VisualStudio website](ttps://visualstudio.microsoft.com/visual-cpp-build-tools), + [the VisualStudio website](https://visualstudio.microsoft.com/visual-cpp-build-tools), download build tools, and install Microsoft Visual C++ 2019 Redistributable and Microsoft Build Tools 2019. @@ -738,7 +738,7 @@ common build issues. root@bca08b91ff63:/mediapipe# bash ./setup_android_sdk_and_ndk.sh # Should print: - # Android NDK is now installed. Consider setting $ANDROID_NDK_HOME environment variable to be /root/Android/Sdk/ndk-bundle/android-ndk-r18b + # Android NDK is now installed. Consider setting $ANDROID_NDK_HOME environment variable to be /root/Android/Sdk/ndk-bundle/android-ndk-r19c # Set android_ndk_repository and android_sdk_repository in WORKSPACE # Done diff --git a/docs/getting_started/python.md b/docs/getting_started/python.md index 5d4bc2fb9..d59f35bbf 100644 --- a/docs/getting_started/python.md +++ b/docs/getting_started/python.md @@ -26,7 +26,7 @@ You can, for instance, activate a Python virtual environment: $ python3 -m venv mp_env && source mp_env/bin/activate ``` -Install MediaPipe Python package and start Python intepreter: +Install MediaPipe Python package and start Python interpreter: ```bash (mp_env)$ pip install mediapipe diff --git a/docs/getting_started/troubleshooting.md b/docs/getting_started/troubleshooting.md index 76b4de3c8..7cb87d524 100644 --- a/docs/getting_started/troubleshooting.md +++ b/docs/getting_started/troubleshooting.md @@ -97,6 +97,49 @@ linux_opencv/macos_opencv/windows_opencv.BUILD files for your local opencv libraries. [This GitHub issue](https://github.com/google/mediapipe/issues/666) may also help. +## Python pip install failure + +The error message: + +``` +ERROR: Could not find a version that satisfies the requirement mediapipe +ERROR: No matching distribution found for mediapipe +``` + +after running `pip install mediapipe` usually indicates that there is no qualified MediaPipe Python for your system. +Please note that MediaPipe Python PyPI officially supports the **64-bit** +version of Python 3.7 and above on the following OS: + +- x86_64 Linux +- x86_64 macOS 10.15+ +- amd64 Windows + +If the OS is currently supported and you still see this error, please make sure +that both the Python and pip binary are for Python 3.7 and above. 
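For example, the interpreter that `pip` targets can be checked directly (a minimal sketch; the printed values should be `True` and `64` on a supported setup):

```python
import struct
import sys

# A 64-bit interpreter reports an 8-byte (64-bit) pointer size.
print(sys.version_info >= (3, 7), struct.calcsize("P") * 8)
```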
Otherwise, +please consider building the MediaPipe Python package locally by following the +instructions [here](python.md#building-mediapipe-python-package). + +## Python DLL load failure on Windows + +The error message: + +``` +ImportError: DLL load failed: The specified module could not be found +``` + +usually indicates that the local Windows system is missing Visual C++ +redistributable packages and/or Visual C++ runtime DLLs. This can be solved by +either installing the official +[vc_redist.x64.exe](https://support.microsoft.com/en-us/topic/the-latest-supported-visual-c-downloads-2647da03-1eea-4433-9aff-95f26a218cc0) +or installing the "msvc-runtime" Python package by running + +```bash +$ python -m pip install msvc-runtime +``` + +Please note that the "msvc-runtime" Python package is not released or maintained +by Microsoft. + ## Native method not found The error message: diff --git a/docs/images/mobile/aar_location.png b/docs/images/mobile/aar_location.png index f85e8219e..3dde1fa18 100644 Binary files a/docs/images/mobile/aar_location.png and b/docs/images/mobile/aar_location.png differ diff --git a/docs/images/mobile/android_studio_opencv_location.png b/docs/images/mobile/android_studio_opencv_location.png deleted file mode 100644 index dbb26af1a..000000000 Binary files a/docs/images/mobile/android_studio_opencv_location.png and /dev/null differ diff --git a/docs/images/mobile/assets_location.png b/docs/images/mobile/assets_location.png index 9d8ab0469..d22dbfaa5 100644 Binary files a/docs/images/mobile/assets_location.png and b/docs/images/mobile/assets_location.png differ diff --git a/docs/images/mobile/pose_tracking_example.gif b/docs/images/mobile/pose_tracking_example.gif new file mode 100644 index 000000000..e88f12f11 Binary files /dev/null and b/docs/images/mobile/pose_tracking_example.gif differ diff --git a/docs/images/mobile/pose_tracking_upper_body_example.gif b/docs/images/mobile/pose_tracking_upper_body_example.gif deleted file mode 100644 index f9c0c5c6f..000000000 Binary files a/docs/images/mobile/pose_tracking_upper_body_example.gif and /dev/null differ diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 8bc4fdc13..8d5de36eb 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -77,7 +77,7 @@ Supported configuration options: ```python import cv2 import mediapipe as mp -mp_face_detction = mp.solutions.face_detection +mp_face_detection = mp.solutions.face_detection mp_drawing = mp.solutions.drawing_utils # For static images: diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 8ee0f8ff6..7c02c8d75 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -135,12 +135,11 @@ another detection until it loses track, on reducing computation and latency. If set to `true`, person detection runs every input image, ideal for processing a batch of static, possibly unrelated, images. Default to `false`. -#### upper_body_only +#### model_complexity -If set to `true`, the solution outputs only the 25 upper-body pose landmarks -(535 in total) instead of the full set of 33 pose landmarks (543 in total). Note -that upper-body-only prediction may be more accurate for use cases where the -lower-body parts are mostly out of view. Default to `false`. +Complexity of the pose landmark model: `0`, `1` or `2`. Landmark accuracy as +well as inference latency generally go up with the model complexity. Default to +`1`. 
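As a quick illustration of this option (a minimal sketch, assuming the `mediapipe` and `numpy` packages are installed; a blank frame stands in for a real RGB image):

```python
import numpy as np
import mediapipe as mp

mp_holistic = mp.solutions.holistic

# model_complexity=0 favors latency, model_complexity=2 favors accuracy.
frame_rgb = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real RGB frame
with mp_holistic.Holistic(static_image_mode=True, model_complexity=0) as holistic:
    results = holistic.process(frame_rgb)
    print(results.pose_landmarks)  # None when no person is detected
```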
#### smooth_landmarks @@ -207,7 +206,7 @@ install MediaPipe Python package, then learn more in the companion Supported configuration options: * [static_image_mode](#static_image_mode) -* [upper_body_only](#upper_body_only) +* [model_complexity](#model_complexity) * [smooth_landmarks](#smooth_landmarks) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -219,7 +218,9 @@ mp_drawing = mp.solutions.drawing_utils mp_holistic = mp.solutions.holistic # For static images: -with mp_holistic.Holistic(static_image_mode=True) as holistic: +with mp_holistic.Holistic( + static_image_mode=True, + model_complexity=2) as holistic: for idx, file in enumerate(file_list): image = cv2.imread(file) image_height, image_width, _ = image.shape @@ -240,8 +241,6 @@ with mp_holistic.Holistic(static_image_mode=True) as holistic: annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) mp_drawing.draw_landmarks( annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - # Use mp_holistic.UPPER_BODY_POSE_CONNECTIONS for drawing below when - # upper_body_only is set to True. mp_drawing.draw_landmarks( annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) @@ -291,7 +290,7 @@ and the following usage example. Supported configuration options: -* [upperBodyOnly](#upper_body_only) +* [modelComplexity](#model_complexity) * [smoothLandmarks](#smooth_landmarks) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -348,7 +347,7 @@ const holistic = new Holistic({locateFile: (file) => { return `https://cdn.jsdelivr.net/npm/@mediapipe/holistic/${file}`; }}); holistic.setOptions({ - upperBodyOnly: false, + modelComplexity: 1, smoothLandmarks: true, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 diff --git a/docs/solutions/models.md b/docs/solutions/models.md index b0f1fad7a..e0ff4d14a 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -15,10 +15,10 @@ nav_order: 30 ### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection) * Face detection model for front-facing/selfie camera: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite), + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite), [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite) * Face detection model for back-facing camera: - [TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_back.tflite) + [TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_back.tflite) * [Model card](https://mediapipe.page.link/blazeface-mc) ### [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) @@ -49,10 +49,10 @@ nav_order: 30 * Pose detection model: [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection.tflite) -* Full-body pose landmark model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite) -* Upper-body pose landmark model: - [TFLite 
model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite) +* Pose landmark model: + [TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_lite.tflite), + [TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full.tflite), + [TFLite model (heavy)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite) * [Model card](https://mediapipe.page.link/blazepose-mc) ### [Holistic](https://google.github.io/mediapipe/solutions/holistic) diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 064e2eb19..96e10c81e 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -30,8 +30,7 @@ overlay of digital content and information on top of the physical world in augmented reality. MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring -33 3D landmarks on the whole body (or 25 upper-body landmarks) from RGB video -frames utilizing our +33 3D landmarks on the whole body from RGB video frames utilizing our [BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) research that also powers the [ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection). @@ -40,9 +39,9 @@ environments for inference, whereas our method achieves real-time performance on most modern [mobile phones](#mobile), [desktops/laptops](#desktop), in [python](#python-solution-api) and even on the [web](#javascript-solution-api). -![pose_tracking_upper_body_example.gif](../images/mobile/pose_tracking_upper_body_example.gif) | -:--------------------------------------------------------------------------------------------: | -*Fig 1. Example of MediaPipe Pose for upper-body pose tracking.* | +![pose_tracking_example.gif](../images/mobile/pose_tracking_example.gif) | +:----------------------------------------------------------------------: | +*Fig 1. Example of MediaPipe Pose for pose tracking.* | ## ML Pipeline @@ -77,6 +76,23 @@ Note: To visualize a graph, copy the graph and paste it into to visualize its associated subgraphs, please see [visualizer documentation](../tools/visualizer.md). +## Pose Estimation Quality + +To evaluate the quality of our [models](./models.md#pose) against other +well-performing publicly available solutions, we use a validation dataset, +consisting of 1k images with diverse Yoga, HIIT, and Dance postures. Each image +contains only a single person located 2-4 meters from the camera. To be +consistent with other solutions, we perform evaluation only for 17 keypoints +from [COCO topology](https://cocodataset.org/#keypoints-2020). 
+ +Method | [mAP](https://cocodataset.org/#keypoints-eval) | [PCK@0.2](https://github.com/cbsudux/Human-Pose-Estimation-101) | [FPS](https://en.wikipedia.org/wiki/Frame_rate), Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | [FPS](https://en.wikipedia.org/wiki/Frame_rate), MacBook Pro (15-inch, 2017) +----------------------------------------------------------------------------------------------------- | ---------------------------------------------: | --------------------------------------------------------------: | ------------------------------------------------------------------------------------------------------------------------------: | ---------------------------------------------------------------------------: +BlazePose.Lite | 49.1 | 91.7 | 49 | 40 +BlazePose.Full | 64.5 | 95.8 | 40 | 37 +BlazePose.Heavy | 70.9 | 97.0 | 19 | 26 +[AlphaPose.ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 57.6 | 93.1 | N/A | N/A +[Apple Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 37.0 | 85.3 | N/A | N/A + ## Models ### Person/pose Detection Model (BlazePose Detector) @@ -97,11 +113,8 @@ hip midpoints. ### Pose Landmark Model (BlazePose GHUM 3D) -The landmark model in MediaPipe Pose comes in two versions: a full-body model -that predicts the location of 33 pose landmarks (see figure below), and an -upper-body version that only predicts the first 25. The latter may be more -accurate than the former in scenarios where the lower-body parts are mostly out -of view. +The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks +(see figure below). Please find more detail in the [BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), @@ -129,12 +142,11 @@ until it loses track, on reducing computation and latency. If set to `true`, person detection runs every input image, ideal for processing a batch of static, possibly unrelated, images. Default to `false`. -#### upper_body_only +#### model_complexity -If set to `true`, the solution outputs only the 25 upper-body pose landmarks. -Otherwise, it outputs the full set of 33 pose landmarks. Note that -upper-body-only prediction may be more accurate for use cases where the -lower-body parts are mostly out of view. Default to `false`. +Complexity of the pose landmark model: `0`, `1` or `2`. Landmark accuracy as +well as inference latency generally go up with the model complexity. Default to +`1`. #### smooth_landmarks @@ -170,9 +182,6 @@ A list of pose landmarks. Each lanmark consists of the following: being the origin, and the smaller the value the closer the landmark is to the camera. The magnitude of `z` uses roughly the same scale as `x`. - Note: `z` is predicted only in full-body mode, and should be discarded when - [upper_body_only](#upper_body_only) is `true`. - * `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the landmark being visible (present and not occluded) in the image. 
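For reference, the fields above can be read from the Python Solution API roughly as follows (a minimal sketch, assuming the `mediapipe` and `opencv-python` packages are installed; `/tmp/example.jpg` is a hypothetical input path):

```python
import cv2
import mediapipe as mp

mp_pose = mp.solutions.pose

with mp_pose.Pose(static_image_mode=True) as pose:
    image = cv2.imread('/tmp/example.jpg')  # hypothetical input image
    results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    if results.pose_landmarks:
        nose = results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE]
        # x and y are normalized to [0.0, 1.0]; z uses roughly the same scale as x.
        print(nose.x, nose.y, nose.z, nose.visibility)
```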
@@ -185,7 +194,7 @@ install MediaPipe Python package, then learn more in the companion Supported configuration options: * [static_image_mode](#static_image_mode) -* [upper_body_only](#upper_body_only) +* [model_complexity](#model_complexity) * [smooth_landmarks](#smooth_landmarks) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -198,7 +207,9 @@ mp_pose = mp.solutions.pose # For static images: with mp_pose.Pose( - static_image_mode=True, min_detection_confidence=0.5) as pose: + static_image_mode=True, + model_complexity=2, + min_detection_confidence=0.5) as pose: for idx, file in enumerate(file_list): image = cv2.imread(file) image_height, image_width, _ = image.shape @@ -214,8 +225,6 @@ with mp_pose.Pose( ) # Draw pose landmarks on the image. annotated_image = image.copy() - # Use mp_pose.UPPER_BODY_POSE_CONNECTIONS for drawing below when - # upper_body_only is set to True. mp_drawing.draw_landmarks( annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) @@ -259,7 +268,7 @@ and the following usage example. Supported configuration options: -* [upperBodyOnly](#upper_body_only) +* [modelComplexity](#model_complexity) * [smoothLandmarks](#smooth_landmarks) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -306,7 +315,7 @@ const pose = new Pose({locateFile: (file) => { return `https://cdn.jsdelivr.net/npm/@mediapipe/pose/${file}`; }}); pose.setOptions({ - upperBodyOnly: false, + modelComplexity: 1, smoothLandmarks: true, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 @@ -347,16 +356,6 @@ to visualize its associated subgraphs, please see * iOS target: [`mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp`](http:/mediapipe/examples/ios/posetrackinggpu/BUILD) -#### Upper-body Only - -* Graph: - [`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt) -* Android target: - [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1uKc6T7KSuA0Mlq2URi5YookHu0U3yoh_/view?usp=sharing) - [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu:upperbodyposetrackinggpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD) -* iOS target: - [`mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp`](http:/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD) - ### Desktop Please first see general instructions for [desktop](../getting_started/cpp.md) @@ -375,19 +374,6 @@ on how to build MediaPipe examples. 
* Target: [`mediapipe/examples/desktop/pose_tracking:pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/pose_tracking/BUILD) -#### Upper-body Only - -* Running on CPU - * Graph: - [`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt) - * Target: - [`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD) -* Running on GPU - * Graph: - [`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt) - * Target: - [`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD) - ## Resources * Google AI Blog: diff --git a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen index d3cd4971a..11daafdcb 100644 --- a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen +++ b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen @@ -16,7 +16,6 @@ "mediapipe/examples/ios/objectdetectiongpu/BUILD", "mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD", "mediapipe/examples/ios/posetrackinggpu/BUILD", - "mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD", "mediapipe/framework/BUILD", "mediapipe/gpu/BUILD", "mediapipe/objc/BUILD", @@ -36,7 +35,6 @@ "//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp", "//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp", "//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp", - "//mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp", "//mediapipe/objc:mediapipe_framework_ios" ], "optionSet" : { @@ -105,7 +103,6 @@ "mediapipe/examples/ios/objectdetectioncpu", "mediapipe/examples/ios/objectdetectiongpu", "mediapipe/examples/ios/posetrackinggpu", - "mediapipe/examples/ios/upperbodyposetrackinggpu", "mediapipe/framework", "mediapipe/framework/deps", "mediapipe/framework/formats", diff --git a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf index 7303828ad..33498e8c1 100644 --- a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf +++ b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf @@ -22,7 +22,6 @@ "mediapipe/examples/ios/objectdetectiongpu", "mediapipe/examples/ios/objectdetectiontrackinggpu", "mediapipe/examples/ios/posetrackinggpu", - "mediapipe/examples/ios/upperbodyposetrackinggpu", "mediapipe/objc" ], "projectName" : "Mediapipe", diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index f319aef5b..425c349dc 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -451,8 +451,8 @@ cc_library( ) cc_library( - name = "nonzero_calculator", - srcs = ["nonzero_calculator.cc"], + name = "non_zero_calculator", + srcs = ["non_zero_calculator.cc"], visibility = [ "//visibility:public", ], @@ -464,6 +464,21 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "non_zero_calculator_test", + size = "small", + srcs = ["non_zero_calculator_test.cc"], + deps = [ + ":non_zero_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", 
+ "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:validate_type", + ], +) + cc_test( name = "mux_calculator_test", srcs = ["mux_calculator_test.cc"], @@ -665,6 +680,18 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "default_side_packet_calculator", + srcs = ["default_side_packet_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + cc_library( name = "side_packet_to_stream_calculator", srcs = ["side_packet_to_stream_calculator.cc"], diff --git a/mediapipe/calculators/core/default_side_packet_calculator.cc b/mediapipe/calculators/core/default_side_packet_calculator.cc new file mode 100644 index 000000000..6485d9bff --- /dev/null +++ b/mediapipe/calculators/core/default_side_packet_calculator.cc @@ -0,0 +1,103 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +namespace { + +constexpr char kOptionalValueTag[] = "OPTIONAL_VALUE"; +constexpr char kDefaultValueTag[] = "DEFAULT_VALUE"; +constexpr char kValueTag[] = "VALUE"; + +} // namespace + +// Outputs side packet default value if optional value is not provided. +// +// This calculator utilizes the fact that MediaPipe automatically removes +// optional side packets of the calculator configuration (i.e. OPTIONAL_VALUE). +// And if it happens - returns default value, otherwise - returns optional +// value. +// +// Input: +// OPTIONAL_VALUE (optional) - AnyType (but same type as DEFAULT_VALUE) +// Optional side packet value that is outputted by the calculator as is if +// provided. +// +// DEFAULT_VALUE - AnyType +// Default side pack value that is outputted by the calculator if +// OPTIONAL_VALUE is not provided. +// +// Output: +// VALUE - AnyType (but same type as DEFAULT_VALUE) +// Either OPTIONAL_VALUE (if provided) or DEFAULT_VALUE (otherwise). 
+//
+// Usage example:
+// node {
+//   calculator: "DefaultSidePacketCalculator"
+//   input_side_packet: "OPTIONAL_VALUE:segmentation_mask_enabled_optional"
+//   input_side_packet: "DEFAULT_VALUE:segmentation_mask_enabled_default"
+//   output_side_packet: "VALUE:segmentation_mask_enabled"
+// }
+class DefaultSidePacketCalculator : public CalculatorBase {
+ public:
+  static absl::Status GetContract(CalculatorContract* cc);
+  absl::Status Open(CalculatorContext* cc) override;
+  absl::Status Process(CalculatorContext* cc) override;
+};
+REGISTER_CALCULATOR(DefaultSidePacketCalculator);
+
+absl::Status DefaultSidePacketCalculator::GetContract(CalculatorContract* cc) {
+  RET_CHECK(cc->InputSidePackets().HasTag(kDefaultValueTag))
+      << "Default value must be provided";
+  cc->InputSidePackets().Tag(kDefaultValueTag).SetAny();
+
+  // Optional input side packet can be unspecified. In this case MediaPipe will
+  // remove it from the calculator config.
+  if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
+    cc->InputSidePackets()
+        .Tag(kOptionalValueTag)
+        .SetSameAs(&cc->InputSidePackets().Tag(kDefaultValueTag));
+  }
+
+  RET_CHECK(cc->OutputSidePackets().HasTag(kValueTag));
+  cc->OutputSidePackets().Tag(kValueTag).SetSameAs(
+      &cc->InputSidePackets().Tag(kDefaultValueTag));
+
+  return absl::OkStatus();
+}
+
+absl::Status DefaultSidePacketCalculator::Open(CalculatorContext* cc) {
+  // If optional value is provided it is returned as the calculator output.
+  if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
+    auto& packet = cc->InputSidePackets().Tag(kOptionalValueTag);
+    cc->OutputSidePackets().Tag(kValueTag).Set(packet);
+    return absl::OkStatus();
+  }
+
+  // If no optional value is provided, the default value is returned instead.
+  auto& packet = cc->InputSidePackets().Tag(kDefaultValueTag);
+  cc->OutputSidePackets().Tag(kValueTag).Set(packet);
+
+  return absl::OkStatus();
+}
+
+absl::Status DefaultSidePacketCalculator::Process(CalculatorContext* cc) {
+  return absl::OkStatus();
+}
+
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/core/nonzero_calculator.cc b/mediapipe/calculators/core/non_zero_calculator.cc
similarity index 64%
rename from mediapipe/calculators/core/nonzero_calculator.cc
rename to mediapipe/calculators/core/non_zero_calculator.cc
index 9a5928231..6555fbf2f 100644
--- a/mediapipe/calculators/core/nonzero_calculator.cc
+++ b/mediapipe/calculators/core/non_zero_calculator.cc
@@ -23,14 +23,26 @@ namespace api2 {
 class NonZeroCalculator : public Node {
  public:
   static constexpr Input<int>::SideFallback kIn{"INPUT"};
-  static constexpr Output<int> kOut{"OUTPUT"};
+  static constexpr Output<int>::Optional kOut{"OUTPUT"};
+  static constexpr Output<bool>::Optional kBooleanOut{"OUTPUT_BOOL"};
 
-  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kBooleanOut);
+
+  absl::Status UpdateContract(CalculatorContract* cc) {
+    RET_CHECK(kOut(cc).IsConnected() || kBooleanOut(cc).IsConnected())
+        << "At least one output stream is expected.";
+    return absl::OkStatus();
+  }
 
   absl::Status Process(CalculatorContext* cc) final {
     if (!kIn(cc).IsEmpty()) {
-      auto output = std::make_unique<int>((*kIn(cc) != 0) ? 1 : 0);
-      kOut(cc).Send(std::move(output));
+      bool isNonZero = *kIn(cc) != 0;
+      if (kOut(cc).IsConnected()) {
+        kOut(cc).Send(std::make_unique<int>(isNonZero ? 1 : 0));
+      }
+      if (kBooleanOut(cc).IsConnected()) {
+        kBooleanOut(cc).Send(std::make_unique<bool>(isNonZero));
+      }
     }
     return absl::OkStatus();
   }
diff --git a/mediapipe/calculators/core/non_zero_calculator_test.cc b/mediapipe/calculators/core/non_zero_calculator_test.cc
new file mode 100644
index 000000000..72f5fa1d7
--- /dev/null
+++ b/mediapipe/calculators/core/non_zero_calculator_test.cc
@@ -0,0 +1,93 @@
+// Copyright 2021 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/port/canonical_errors.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/timestamp.h"
+#include "mediapipe/framework/tool/validate_type.h"
+
+namespace mediapipe {
+
+class NonZeroCalculatorTest : public ::testing::Test {
+ protected:
+  NonZeroCalculatorTest()
+      : runner_(
+            R"pb(
+              calculator: "NonZeroCalculator"
+              input_stream: "INPUT:input"
+              output_stream: "OUTPUT:output"
+              output_stream: "OUTPUT_BOOL:output_bool"
+            )pb") {}
+
+  void SetInput(const std::vector<int>& inputs) {
+    int timestamp = 0;
+    for (const auto input : inputs) {
+      runner_.MutableInputs()
+          ->Get("INPUT", 0)
+          .packets.push_back(MakePacket<int>(input).At(Timestamp(timestamp++)));
+    }
+  }
+
+  std::vector<int> GetOutput() {
+    std::vector<int> result;
+    for (const auto output : runner_.Outputs().Get("OUTPUT", 0).packets) {
+      result.push_back(output.Get<int>());
+    }
+    return result;
+  }
+
+  std::vector<bool> GetOutputBool() {
+    std::vector<bool> result;
+    for (const auto output : runner_.Outputs().Get("OUTPUT_BOOL", 0).packets) {
+      result.push_back(output.Get<bool>());
+    }
+    return result;
+  }
+
+  CalculatorRunner runner_;
+};
+
+TEST_F(NonZeroCalculatorTest, ProducesZeroOutputForZeroInput) {
+  SetInput({0});
+
+  MP_ASSERT_OK(runner_.Run());
+
+  EXPECT_THAT(GetOutput(), ::testing::ElementsAre(0));
+  EXPECT_THAT(GetOutputBool(), ::testing::ElementsAre(false));
+}
+
+TEST_F(NonZeroCalculatorTest, ProducesNonZeroOutputForNonZeroInput) {
+  SetInput({1, 2, 3, -4, 5});
+
+  MP_ASSERT_OK(runner_.Run());
+
+  EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 1, 1, 1, 1));
+  EXPECT_THAT(GetOutputBool(),
+              ::testing::ElementsAre(true, true, true, true, true));
+}
+
+TEST_F(NonZeroCalculatorTest, SwitchesBetweenNonZeroAndZeroOutput) {
+  SetInput({1, 0, 3, 0, 5});
+  MP_ASSERT_OK(runner_.Run());
+  EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 0, 1, 0, 1));
+  EXPECT_THAT(GetOutputBool(),
+              ::testing::ElementsAre(true, false, true, false, true));
+}
+
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc
index e4b0b7218..07f7d5f46 100644
--- a/mediapipe/calculators/image/image_cropping_calculator.cc
+++ b/mediapipe/calculators/image/image_cropping_calculator.cc
@@ -285,7
+285,7 @@ absl::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { // Run cropping shader on GPU. { - gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0 + gpu_helper_.BindFramebuffer(dst_tex); glActiveTexture(GL_TEXTURE1); glBindTexture(src_tex.target(), src_tex.name()); diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index bb98f14e0..60873ae9f 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -546,7 +546,7 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) { auto dst = gpu_helper_.CreateDestinationTexture(output_width, output_height, input.format()); - gpu_helper_.BindFramebuffer(dst); // GL_TEXTURE0 + gpu_helper_.BindFramebuffer(dst); glActiveTexture(GL_TEXTURE1); glBindTexture(src1.target(), src1.name()); diff --git a/mediapipe/calculators/image/recolor_calculator.cc b/mediapipe/calculators/image/recolor_calculator.cc index 6a12025f6..03d0c3c7a 100644 --- a/mediapipe/calculators/image/recolor_calculator.cc +++ b/mediapipe/calculators/image/recolor_calculator.cc @@ -209,6 +209,9 @@ absl::Status RecolorCalculator::Close(CalculatorContext* cc) { absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kMaskCpuTag).IsEmpty()) { + cc->Outputs() + .Tag(kImageFrameTag) + .AddPacket(cc->Inputs().Tag(kImageFrameTag).Value()); return absl::OkStatus(); } // Get inputs and setup output. @@ -270,6 +273,9 @@ absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kMaskGpuTag).IsEmpty()) { + cc->Outputs() + .Tag(kGpuBufferTag) + .AddPacket(cc->Inputs().Tag(kGpuBufferTag).Value()); return absl::OkStatus(); } #if !MEDIAPIPE_DISABLE_GPU @@ -287,7 +293,7 @@ absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { // Run recolor shader on GPU. 
{ - gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0 + gpu_helper_.BindFramebuffer(dst_tex); glActiveTexture(GL_TEXTURE1); glBindTexture(img_tex.target(), img_tex.name()); diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index 08c150d21..87a661be6 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -323,7 +323,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTagGpu).Get(); auto alpha_texture = gpu_helper_.CreateSourceTexture(alpha_mask); - gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 + gpu_helper_.BindFramebuffer(output_texture); glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, input_texture.name()); glActiveTexture(GL_TEXTURE2); @@ -335,7 +335,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { glBindTexture(GL_TEXTURE_2D, 0); alpha_texture.Release(); } else { - gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 + gpu_helper_.BindFramebuffer(output_texture); glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, input_texture.name()); GlRender(cc); // use value from options diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index e27347a7e..59e4646ea 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -490,6 +490,7 @@ cc_library( "//mediapipe/framework/port:statusor", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", + "//mediapipe/gpu:gpu_origin_cc_proto", ] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [":image_to_tensor_calculator_gpu_deps"], @@ -526,6 +527,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", + "//mediapipe/gpu:gpu_origin_proto", ], ) diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index 91eba2de5..f681ab661 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -31,6 +31,7 @@ #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" +#include "mediapipe/gpu/gpu_origin.pb.h" #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" @@ -236,7 +237,7 @@ class ImageToTensorCalculator : public Node { } private: - bool DoesInputStartAtBottom() { + bool DoesGpuInputStartAtBottom() { return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; } @@ -290,11 +291,11 @@ class ImageToTensorCalculator : public Node { #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 ASSIGN_OR_RETURN(gpu_converter_, CreateImageToGlBufferTensorConverter( - cc, DoesInputStartAtBottom(), GetBorderMode())); + cc, DoesGpuInputStartAtBottom(), GetBorderMode())); #else ASSIGN_OR_RETURN(gpu_converter_, CreateImageToGlTextureTensorConverter( - cc, DoesInputStartAtBottom(), GetBorderMode())); + cc, DoesGpuInputStartAtBottom(), GetBorderMode())); #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU } diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.proto b/mediapipe/calculators/tensor/image_to_tensor_calculator.proto index 77fb1eb46..0451dc51f 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.proto +++ 
b/mediapipe/calculators/tensor/image_to_tensor_calculator.proto @@ -17,20 +17,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; - -message GpuOrigin { - enum Mode { - DEFAULT = 0; - - // OpenGL: bottom-left origin - // Metal : top-left origin - CONVENTIONAL = 1; - - // OpenGL: top-left origin - // Metal : top-left origin - TOP_LEFT = 2; - } -} +import "mediapipe/gpu/gpu_origin.proto"; message ImageToTensorCalculatorOptions { extend mediapipe.CalculatorOptions { diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index d2cabfcac..d7c0e6138 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -317,7 +317,8 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { // Configure and create the delegate. TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); - options.compile_options.precision_loss_allowed = 1; + options.compile_options.precision_loss_allowed = + allow_precision_loss_ ? 1 : 0; options.compile_options.preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST; options.compile_options.dynamic_batch_enabled = 0; diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index a81d0d460..d86a45c07 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -97,6 +97,7 @@ class InferenceCalculatorMetalImpl Packet model_packet_; std::unique_ptr interpreter_; TfLiteDelegatePtr delegate_; + bool allow_precision_loss_ = false; #if MEDIAPIPE_TFLITE_METAL_INFERENCE MPPMetalHelper* gpu_helper_ = nullptr; @@ -122,6 +123,9 @@ absl::Status InferenceCalculatorMetalImpl::UpdateContract( } absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); + MP_RETURN_IF_ERROR(LoadModel(cc)); gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; @@ -222,7 +226,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { // Configure and create the delegate. TFLGpuDelegateOptions options; - options.allow_precision_loss = true; + options.allow_precision_loss = allow_precision_loss_; options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait; delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); @@ -239,7 +243,9 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { tensor->dims->data + tensor->dims->size}; dims.back() = RoundUp(dims.back(), 4); gpu_buffers_in_.emplace_back(absl::make_unique( - Tensor::ElementType::kFloat16, Tensor::Shape{dims})); + allow_precision_loss_ ? 
Tensor::ElementType::kFloat16 + : Tensor::ElementType::kFloat32, + Tensor::Shape{dims})); auto buffer_view = gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice); RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( @@ -261,7 +267,9 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { output_shapes_[i] = {dims}; dims.back() = RoundUp(dims.back(), 4); gpu_buffers_out_.emplace_back(absl::make_unique( - Tensor::ElementType::kFloat16, Tensor::Shape{dims})); + allow_precision_loss_ ? Tensor::ElementType::kFloat16 + : Tensor::ElementType::kFloat32, + Tensor::Shape{dims})); RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( delegate_.get(), output_indices[i], gpu_buffers_out_[i] @@ -271,17 +279,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { } // Create converter for GPU input. - converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device - isFloat16:true - convertToPBHWC4:true]; + converter_to_BPHWC4_ = + [[TFLBufferConvert alloc] initWithDevice:device + isFloat16:allow_precision_loss_ + convertToPBHWC4:true]; if (converter_to_BPHWC4_ == nil) { return mediapipe::InternalError( "Error initializating input buffer converter"); } // Create converter for GPU output. - converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device - isFloat16:true - convertToPBHWC4:false]; + converter_from_BPHWC4_ = + [[TFLBufferConvert alloc] initWithDevice:device + isFloat16:allow_precision_loss_ + convertToPBHWC4:false]; if (converter_from_BPHWC4_ == nil) { return absl::InternalError("Error initializating output buffer converter"); } diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc index c3b91de71..87216f4d2 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc @@ -89,7 +89,8 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) { ASSIGN_OR_RETURN(string_path, PathToResourceAsFile(options_.label_map_path())); std::string label_map_string; - MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); + MP_RETURN_IF_ERROR( + mediapipe::GetResourceContents(string_path, &label_map_string)); std::istringstream stream(label_map_string); std::string line; @@ -98,6 +99,14 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) { label_map_[i++] = line; } label_map_loaded_ = true; + } else if (options_.has_label_map()) { + for (int i = 0; i < options_.label_map().entries_size(); ++i) { + const auto& entry = options_.label_map().entries(i); + RET_CHECK(!label_map_.contains(entry.id())) + << "Duplicate id found: " << entry.id(); + label_map_[entry.id()] = entry.label(); + } + label_map_loaded_ = true; } return absl::OkStatus(); diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto index 51f7f3f90..3934a6101 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto @@ -25,6 +25,14 @@ message TensorsToClassificationCalculatorOptions { optional TensorsToClassificationCalculatorOptions ext = 335742638; } + message LabelMap { + message Entry { + optional int32 id = 1; + optional string label = 2; + } + repeated Entry entries = 1; + } + // Score threshold for 
perserving the class.
  optional float min_score_threshold = 1;

  // Number of highest scoring labels to output. If top_k is not positive then
@@ -32,6 +40,10 @@ message TensorsToClassificationCalculatorOptions {
   optional int32 top_k = 2;
   // Path to a label map file for getting the actual name of class ids.
   optional string label_map_path = 3;
+  // Label map. (Can be used instead of label_map_path.)
+  // NOTE: "label_map_path", if specified, takes precedence over "label_map".
+  optional LabelMap label_map = 5;
+
   // Whether the input is a single float for binary classification.
   // When true, only a single float is expected in the input tensor and the
   // label map, if provided, is expected to have exactly two labels.
diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
index 03dd98ab3..92b20629d 100644
--- a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
+++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
@@ -115,6 +115,41 @@ TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMapPath) {
   }
 }
 
+TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMap) {
+  mediapipe::CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "TensorsToClassificationCalculator"
+        input_stream: "TENSORS:tensors"
+        output_stream: "CLASSIFICATIONS:classifications"
+        options {
+          [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
+            label_map {
+              entries { id: 0, label: "ClassA" }
+              entries { id: 1, label: "ClassB" }
+              entries { id: 2, label: "ClassC" }
+            }
+          }
+        }
+      )pb"));
+
+  BuildGraph(&runner, {0, 0.5, 1});
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
+
+  EXPECT_EQ(1, output_packets_.size());
+
+  const auto& classification_list =
+      output_packets_[0].Get<ClassificationList>();
+  EXPECT_EQ(3, classification_list.classification_size());
+
+  // Verify that the label field is set.
+  for (int i = 0; i < classification_list.classification_size(); ++i) {
+    EXPECT_EQ(i, classification_list.classification(i).index());
+    EXPECT_EQ(i * 0.5, classification_list.classification(i).score());
+    ASSERT_TRUE(classification_list.classification(i).has_label());
+  }
+}
+
 TEST_F(TensorsToClassificationCalculatorTest,
        CorrectOutputWithLabelMinScoreThreshold) {
   mediapipe::CalculatorRunner runner(
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc
index d72c75923..622e76850 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc
@@ -34,15 +34,28 @@ constexpr char kTensor[] = "TENSOR";
 }  // namespace
 
 // Input:
-//  Tensor of type DT_FLOAT, with values between 0-255 (SRGB or GRAY8). The
-//  shape can be HxWx{3,1} or simply HxW.
+//  Tensor of type DT_FLOAT or DT_UINT8, with values between 0-255
+//  (SRGB or GRAY8). The shape can be HxWx{3,1} or simply HxW.
 //
-// Optionally supports a scale factor that can scale 0-1 value ranges to 0-255.
+// For DT_FLOAT tensors, optionally supports a scale factor that can scale 0-1
+// value ranges to 0-255.
 //
 // Output:
 //  ImageFrame containing the values of the tensor cast as uint8 (SRGB or GRAY8)
 //
 // Possible extensions: support other input ranges, maybe 4D tensors.
+// +// Example: +// node { +// calculator: "TensorToImageFrameCalculator" +// input_stream: "TENSOR:3d_float_tensor" +// output_stream: "IMAGE:image_frame" +// options { +// [mediapipe.TensorToImageFrameCalculatorOptions.ext] { +// scale_factor: 1.0 # set to 255.0 for [0,1] -> [0,255] scaling +// } +// } +// } class TensorToImageFrameCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc); @@ -57,8 +70,8 @@ class TensorToImageFrameCalculator : public CalculatorBase { REGISTER_CALCULATOR(TensorToImageFrameCalculator); absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) { - RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) - << "Only one input stream is supported."; + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "One input stream must be provided."; RET_CHECK(cc->Inputs().HasTag(kTensor)) @@ -91,29 +104,44 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) { RET_CHECK_EQ(depth, 3) << "Output tensor depth must be 3 or 1."; } } - const int32 total_size = - input_tensor.dim_size(0) * input_tensor.dim_size(1) * depth; - std::unique_ptr buffer(new uint8[total_size]); - auto data = input_tensor.flat().data(); - for (int i = 0; i < total_size; ++i) { - float d = scale_factor_ * data[i]; - if (d < 0) d = 0; - if (d > 255) d = 255; - buffer[i] = d; + int32 height = input_tensor.dim_size(0); + int32 width = input_tensor.dim_size(1); + auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8); + const int32 total_size = height * width * depth; + + ::std::unique_ptr output; + if (input_tensor.dtype() == tensorflow::DT_FLOAT) { + // Allocate buffer with alignments. + std::unique_ptr buffer( + new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]); + auto data = input_tensor.flat().data(); + for (int i = 0; i < total_size; ++i) { + float d = scale_factor_ * data[i]; + if (d < 0) d = 0; + if (d > 255) d = 255; + buffer[i] = d; + } + output = ::absl::make_unique(format, width, height, + width * depth, buffer.release()); + } else if (input_tensor.dtype() == tensorflow::DT_UINT8) { + if (scale_factor_ != 1.0) { + return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor"); + } + // tf::Tensor has internally ref-counted buffer. The following code make the + // ImageFrame own the copied Tensor through the deleter, which increases + // the refcount of the buffer and allow us to use the shared buffer as the + // image. This allows us to create an ImageFrame object without copying + // buffer. const ImageFrame prevents the buffer from being modified later. 
+ auto copy = new tf::Tensor(input_tensor); + output = ::absl::make_unique( + format, width, height, width * depth, copy->flat().data(), + [copy](uint8*) { delete copy; }); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Expected float or uint8 tensor, received ", + DataTypeString(input_tensor.dtype()))); } - ::std::unique_ptr output; - if (depth == 3) { - output = ::absl::make_unique( - ImageFormat::SRGB, input_tensor.dim_size(1), input_tensor.dim_size(0), - input_tensor.dim_size(1) * 3, buffer.release()); - } else if (depth == 1) { - output = ::absl::make_unique( - ImageFormat::GRAY8, input_tensor.dim_size(1), input_tensor.dim_size(0), - input_tensor.dim_size(1), buffer.release()); - } else { - return absl::InvalidArgumentError("Unrecognized image depth."); - } cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp()); return absl::OkStatus(); diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc index 54e989a20..88e268907 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc @@ -29,6 +29,7 @@ constexpr char kImage[] = "IMAGE"; } // namespace +template class TensorToImageFrameCalculatorTest : public ::testing::Test { protected: void SetUpRunner() { @@ -42,14 +43,20 @@ class TensorToImageFrameCalculatorTest : public ::testing::Test { std::unique_ptr runner_; }; -TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) { - SetUpRunner(); +using TensorToImageFrameCalculatorTestTypes = ::testing::Types; +TYPED_TEST_CASE(TensorToImageFrameCalculatorTest, + TensorToImageFrameCalculatorTestTypes); + +TYPED_TEST(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) { + // TYPED_TEST requires explicit "this->" + this->SetUpRunner(); + auto& runner = this->runner_; constexpr int kWidth = 16; constexpr int kHeight = 8; - const tf::TensorShape tensor_shape( - std::vector{kHeight, kWidth, 3}); - auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); - auto tensor_vec = tensor->flat().data(); + const tf::TensorShape tensor_shape{kHeight, kWidth, 3}; + auto tensor = absl::make_unique( + tf::DataTypeToEnum::v(), tensor_shape); + auto tensor_vec = tensor->template flat().data(); // Writing sequence of integers as floats which we want back (as they were // written). 
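// The DT_UINT8 branch added to tensor_to_image_frame_calculator.cc above
// avoids copying pixel data: a heap-allocated tf::Tensor copy (which only
// bumps the refcount of the shared buffer) is handed to the ImageFrame's
// deleter. A minimal standalone sketch of that ownership pattern; the helper
// name and call site are illustrative, not part of the patch.
#include <memory>

#include "absl/memory/memory.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "tensorflow/core/framework/tensor.h"

std::unique_ptr<mediapipe::ImageFrame> WrapUint8TensorAsImage(
    const tensorflow::Tensor& tensor, mediapipe::ImageFormat::Format format,
    int width, int height, int depth) {
  // Copying a tf::Tensor shares its ref-counted buffer instead of cloning it.
  auto* keep_alive = new tensorflow::Tensor(tensor);
  return absl::make_unique<mediapipe::ImageFrame>(
      format, width, height, /*width_step=*/width * depth,
      keep_alive->flat<tensorflow::uint8>().data(),
      // Destroying the ImageFrame deletes the tensor copy, releasing its
      // reference to the pixel buffer.
      [keep_alive](uint8_t*) { delete keep_alive; });
}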
@@ -58,15 +65,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) { } const int64 time = 1234; - runner_->MutableInputs()->Tag(kTensor).packets.push_back( + runner->MutableInputs()->Tag(kTensor).packets.push_back( Adopt(tensor.release()).At(Timestamp(time))); - EXPECT_TRUE(runner_->Run().ok()); + EXPECT_TRUE(runner->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag(kImage).packets; + runner->Outputs().Tag(kImage).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const ImageFrame& output_image = output_packets[0].Get(); + EXPECT_EQ(ImageFormat::SRGB, output_image.Format()); EXPECT_EQ(kWidth, output_image.Width()); EXPECT_EQ(kHeight, output_image.Height()); @@ -76,14 +84,15 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) { } } -TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) { - SetUpRunner(); +TYPED_TEST(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) { + this->SetUpRunner(); + auto& runner = this->runner_; constexpr int kWidth = 16; constexpr int kHeight = 8; - const tf::TensorShape tensor_shape( - std::vector{kHeight, kWidth, 1}); - auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); - auto tensor_vec = tensor->flat().data(); + const tf::TensorShape tensor_shape{kHeight, kWidth, 1}; + auto tensor = absl::make_unique( + tf::DataTypeToEnum::v(), tensor_shape); + auto tensor_vec = tensor->template flat().data(); // Writing sequence of integers as floats which we want back (as they were // written). @@ -92,15 +101,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) { } const int64 time = 1234; - runner_->MutableInputs()->Tag(kTensor).packets.push_back( + runner->MutableInputs()->Tag(kTensor).packets.push_back( Adopt(tensor.release()).At(Timestamp(time))); - EXPECT_TRUE(runner_->Run().ok()); + EXPECT_TRUE(runner->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag(kImage).packets; + runner->Outputs().Tag(kImage).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const ImageFrame& output_image = output_packets[0].Get(); + EXPECT_EQ(ImageFormat::GRAY8, output_image.Format()); EXPECT_EQ(kWidth, output_image.Width()); EXPECT_EQ(kHeight, output_image.Height()); @@ -110,13 +120,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) { } } -TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) { - SetUpRunner(); +TYPED_TEST(TensorToImageFrameCalculatorTest, + Converts3DTensorToImageFrame2DGray) { + this->SetUpRunner(); + auto& runner = this->runner_; constexpr int kWidth = 16; constexpr int kHeight = 8; - const tf::TensorShape tensor_shape(std::vector{kHeight, kWidth}); - auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); - auto tensor_vec = tensor->flat().data(); + const tf::TensorShape tensor_shape{kHeight, kWidth}; + auto tensor = absl::make_unique( + tf::DataTypeToEnum::v(), tensor_shape); + auto tensor_vec = tensor->template flat().data(); // Writing sequence of integers as floats which we want back (as they were // written). 
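// The test hunks above turn the fixture into a typed test so each case runs
// once with float and once with uint8 tensors. A minimal self-contained
// sketch of the mechanism, with an illustrative fixture name; it is not part
// of the patch.
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

namespace {

template <typename T>
class DtypeSketchTest : public ::testing::Test {};

using DtypeSketchTestTypes = ::testing::Types<float, tensorflow::uint8>;
TYPED_TEST_CASE(DtypeSketchTest, DtypeSketchTestTypes);

TYPED_TEST(DtypeSketchTest, BuildsTensorOfParameterType) {
  namespace tf = tensorflow;
  // DataTypeToEnum maps the compile-time type parameter to the runtime dtype,
  // which is how the rewritten tests pick DT_FLOAT or DT_UINT8.
  tf::Tensor tensor(tf::DataTypeToEnum<TypeParam>::v(),
                    tf::TensorShape({2, 3}));
  tensor.flat<TypeParam>().setZero();
  EXPECT_EQ(tf::DataTypeToEnum<TypeParam>::v(), tensor.dtype());
}

}  // namespace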
@@ -125,15 +138,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) { } const int64 time = 1234; - runner_->MutableInputs()->Tag(kTensor).packets.push_back( + runner->MutableInputs()->Tag(kTensor).packets.push_back( Adopt(tensor.release()).At(Timestamp(time))); - EXPECT_TRUE(runner_->Run().ok()); + EXPECT_TRUE(runner->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag(kImage).packets; + runner->Outputs().Tag(kImage).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const ImageFrame& output_image = output_packets[0].Get(); + EXPECT_EQ(ImageFormat::GRAY8, output_image.Format()); EXPECT_EQ(kWidth, output_image.Width()); EXPECT_EQ(kHeight, output_image.Height()); diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc index d52de7404..85955c43b 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc @@ -91,8 +91,6 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, // the input data when it arrives in Process(). In particular, if the header // states that we produce a 1xD column vector, the input tensor must also be 1xD // -// This designed was discussed in http://g/speakeranalysis/4uyx7cNRwJY and -// http://g/daredevil-project/VB26tcseUy8. // Example Config // node: { // calculator: "TensorToMatrixCalculator" @@ -158,22 +156,17 @@ absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) { if (header_status.ok()) { if (cc->Options() .has_time_series_header_overrides()) { - // From design discussions with Daredevil, we only want to support single - // sample per packet for now, so we hardcode the sample_rate based on the - // packet_rate of the REFERENCE and fail noisily if we cannot. An - // alternative would be to calculate the sample_rate from the reference - // sample_rate and the change in num_samples between the reference and - // override headers: - // sample_rate_output = sample_rate_reference / - // (num_samples_override / num_samples_reference) + // This only supports a single sample per packet for now, so we hardcode + // the sample_rate based on the packet_rate of the REFERENCE and fail + // if we cannot. const TimeSeriesHeader& override_header = cc->Options() .time_series_header_overrides(); input_header->MergeFrom(override_header); - CHECK(input_header->has_packet_rate()) + RET_CHECK(input_header->has_packet_rate()) << "The TimeSeriesHeader.packet_rate must be set."; if (!override_header.has_sample_rate()) { - CHECK_EQ(input_header->num_samples(), 1) + RET_CHECK_EQ(input_header->num_samples(), 1) << "Currently the time series can only output single samples."; input_header->set_sample_rate(input_header->packet_rate()); } @@ -186,20 +179,16 @@ absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) { } absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) { - // Daredevil requested CHECK for noisy failures rather than quieter RET_CHECK - // failures. These are absolute conditions of the graph for the graph to be - // valid, and if it is violated by any input anywhere, the graph will be - // invalid for all inputs. A hard CHECK will enable faster debugging by - // immediately exiting and more prominently displaying error messages. - // Do not replace with RET_CHECKs. 
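// With RET_CHECK / RET_CHECK_EQ (used below in place of CHECK / CHECK_EQ), a
// violated condition is returned from Open()/Process() as an absl::Status
// error carrying the streamed message, failing the graph run instead of
// aborting the whole process.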
- // Verify that each reference stream packet corresponds to a tensor packet // otherwise the header information is invalid. If we don't have a reference // stream, Process() is only called when we have an input tensor and this is // always True. - CHECK(cc->Inputs().HasTag(kTensor)) + RET_CHECK(cc->Inputs().HasTag(kTensor)) << "Tensor stream not available at same timestamp as the reference " "stream."; + RET_CHECK(!cc->Inputs().Tag(kTensor).IsEmpty()) << "Tensor stream is empty."; + RET_CHECK_OK(cc->Inputs().Tag(kTensor).Value().ValidateAsType()) + << "Tensor stream packet does not contain a Tensor."; const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get(); CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims()) @@ -207,13 +196,12 @@ absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) { const int32 length = input_tensor.dim_size(input_tensor.dims() - 1); const int32 width = (1 == input_tensor.dims()) ? 1 : input_tensor.dim_size(0); if (header_.has_num_channels()) { - CHECK_EQ(length, header_.num_channels()) + RET_CHECK_EQ(length, header_.num_channels()) << "The number of channels at runtime does not match the header."; } if (header_.has_num_samples()) { - CHECK_EQ(width, header_.num_samples()) + RET_CHECK_EQ(width, header_.num_samples()) << "The number of samples at runtime does not match the header."; - ; } auto output = absl::make_unique(width, length); *output = diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc index d78a53053..625612c17 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -98,388 +98,543 @@ class InferenceState { // This calculator performs inference on a trained TensorFlow model. // -// A mediapipe::TensorFlowSession with a model loaded and ready for use. -// For this calculator it must include a tag_to_tensor_map. -cc->InputSidePackets().Tag("SESSION").Set(); -if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { - cc->InputSidePackets() - .Tag("RECURRENT_INIT_TENSORS") - .Set>>(); -} -return absl::OkStatus(); -} +// TensorFlow Sessions can be created from checkpoint paths, frozen models, or +// the SavedModel system. See the TensorFlowSessionFrom* packet generators for +// details. Each of these methods defines a mapping between MediaPipe streams +// and TensorFlow tensors. All of this information is passed in as an +// input_side_packet. +// +// The input and output streams are TensorFlow tensors labeled by tags. The tags +// for the streams are matched to feeds and fetchs in a TensorFlow session using +// a named_signature.generic_signature in the ModelManifest. The +// generic_signature is used as key-value pairs between the MediaPipe tag and +// the TensorFlow tensor. The signature_name in the options proto determines +// which named_signature is used. The keys in the generic_signature must be +// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of +// the tensors corresponding to tags in the signature for input_streams are fed +// to the model and for output_streams the tensors are fetched from the model. +// +// Other calculators are used to convert data to and from tensors, this op only +// handles the TensorFlow session and batching. Batching occurs by concatenating +// input tensors along the 0th dimension across timestamps. 
If the 0th dimension +// is not a batch dimension, this calculator will add a 0th dimension by +// default. Setting add_batch_dim_to_tensors to false disables the dimension +// addition. Once batch_size inputs have been provided, the batch will be run +// and the output tensors sent out on the output streams with timestamps +// corresponding to the input stream packets. Setting the batch_size to 1 +// completely disables batching, but is indepdent of add_batch_dim_to_tensors. +// +// The TensorFlowInferenceCalculator also support feeding states recurrently for +// RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the +// recurrent tensors. Initializing the recurrent state can be handled by the +// GraphTensorsPacketGenerator. +// +// The calculator updates two Counters to report timing information: +// ---TotalTimeUsecs = Total time spent running inference (in usecs), +// ---TotalProcessedTimestamps = # of instances processed +// (approximately batches processed * batch_size), +// where is replaced with CalculatorGraphConfig::Node::name() if it +// exists, or with TensorFlowInferenceCalculator if the name is not set. The +// name must be set for timing information to be instance-specific in graphs +// with multiple TensorFlowInferenceCalculators. +// +// Example config: +// packet_generator { +// packet_generator: "TensorFlowSessionFromSavedModelGenerator" +// output_side_packet: "tensorflow_session" +// options { +// [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: { +// saved_model_path: "/path/to/saved/model" +// signature_name: "mediapipe" +// } +// } +// } +// node { +// calculator: "TensorFlowInferenceCalculator" +// input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag" +// input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag" +// output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag" +// input_side_packet: "SESSION:tensorflow_session" +// } +// +// Where the input and output streams are treated as Packet and +// the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and +// "LABELS" and their respective tensors exported to /path/to/bundle. For an +// example of how this model was exported, see +// tensorflow_inference_test_graph_generator.py +// +// It is possible to use a GraphDef proto that was not exported by exporter (i.e +// without MetaGraph with bindings). Such GraphDef could contain all of its +// parameters in-lined (for example, it can be the output of freeze_graph.py). +// To instantiate a TensorFlow model from a GraphDef file, replace the +// packet_factory above with TensorFlowSessionFromFrozenGraphGenerator: +// +// packet_generator { +// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator" +// output_side_packet: "SESSION:tensorflow_session" +// options { +// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: { +// graph_proto_path: "[PATH]" +// tag_to_tensor_names { +// key: "JPG_STRING" +// value: "input:0" +// } +// tag_to_tensor_names { +// key: "SOFTMAX" +// value: "softmax:0" +// } +// } +// } +// } +// +// It is also possible to use a GraphDef proto and checkpoint file that have not +// been frozen. This can be used to load graphs directly as they have been +// written from training. However, it is more brittle and you are encouraged to +// use a one of the more perminent formats described above. 
To instantiate a +// TensorFlow model from a GraphDef file and checkpoint, replace the +// packet_factory above with TensorFlowSessionFromModelCheckpointGenerator: +// +// packet_generator { +// packet_generator: "TensorFlowSessionFromModelCheckpointGenerator" +// output_side_packet: "SESSION:tensorflow_session" +// options { +// [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: { +// graph_proto_path: "[PATH]" +// model_options { +// checkpoint_path: "[PATH2]" +// } +// tag_to_tensor_names { +// key: "JPG_STRING" +// value: "input:0" +// } +// tag_to_tensor_names { +// key: "SOFTMAX" +// value: "softmax:0" +// } +// } +// } +// } +class TensorFlowInferenceCalculator : public CalculatorBase { + public: + // Counters for recording timing information. The actual names have the value + // of CalculatorGraphConfig::Node::name() prepended. + static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs"; + static constexpr char kTotalProcessedTimestampsCounterSuffix[] = + "TotalProcessedTimestamps"; + static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] = + "TotalSessionRunsTimeUsecs"; + static constexpr char kTotalNumSessionRunsCounterSuffix[] = + "TotalNumSessionRuns"; -std::unique_ptr CreateInferenceState(CalculatorContext* cc) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { - std::unique_ptr inference_state = - absl::make_unique(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && - !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { - std::map* init_tensor_map; - init_tensor_map = GetFromUniquePtr>( - cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); - for (const auto& p : *init_tensor_map) { - inference_state->input_tensor_batches_[p.first].emplace_back(p.second); + TensorFlowInferenceCalculator() : session_(nullptr) { + clock_ = std::unique_ptr( + mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); + } + + static absl::Status GetContract(CalculatorContract* cc) { + const auto& options = cc->Options(); + RET_CHECK(!cc->Inputs().GetTags().empty()); + for (const std::string& tag : cc->Inputs().GetTags()) { + // The tensorflow::Tensor with the tag equal to the graph node. May + // have a TimeSeriesHeader if all present TimeSeriesHeaders match. 
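// When options.batched_input() is set, each input packet carries a
// std::vector<Packet> holding one tf::Tensor packet per timestamp in the
// batch; otherwise each packet holds a single tf::Tensor. Process() unpacks
// the vector form accordingly.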
+ if (!options.batched_input()) { + cc->Inputs().Tag(tag).Set(); + } else { + cc->Inputs().Tag(tag).Set>(); + } } - } - return inference_state; -} - -absl::Status Open(CalculatorContext* cc) override { - options_ = cc->Options(); - - RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); - session_ = cc->InputSidePackets() - .Tag("SESSION") - .Get() - .session.get(); - tag_to_tensor_map_ = cc->InputSidePackets() - .Tag("SESSION") - .Get() - .tag_to_tensor_map; - - // Validate and store the recurrent tags - RET_CHECK(options_.has_batch_size()); - RET_CHECK(options_.batch_size() == 1 || options_.recurrent_tag_pair().empty()) - << "To use recurrent_tag_pairs, batch_size must be 1."; - for (const auto& tag_pair : options_.recurrent_tag_pair()) { - const std::vector tags = absl::StrSplit(tag_pair, ':'); - RET_CHECK_EQ(tags.size(), 2) - << "recurrent_tag_pair must be a colon " - "separated std::string with two components: " - << tag_pair; - RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) - << "Can't find tag '" << tags[0] << "' in signature " - << options_.signature_name(); - RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1])) - << "Can't find tag '" << tags[1] << "' in signature " - << options_.signature_name(); - recurrent_feed_tags_.insert(tags[0]); - recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0]; - } - - // Check that all tags are present in this signature bound to tensors. - for (const std::string& tag : cc->Inputs().GetTags()) { - RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) - << "Can't find tag '" << tag << "' in signature " - << options_.signature_name(); - } - for (const std::string& tag : cc->Outputs().GetTags()) { - RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) - << "Can't find tag '" << tag << "' in signature " - << options_.signature_name(); - } - - { - absl::WriterMutexLock l(&mutex_); - inference_state_ = std::unique_ptr(); - } - - if (options_.batch_size() == 1 || options_.batched_input()) { - cc->SetOffset(0); - } - - return absl::OkStatus(); -} - -// Adds a batch dimension to the input tensor if specified in the calculator -// options. -absl::Status AddBatchDimension(tf::Tensor* input_tensor) { - if (options_.add_batch_dim_to_tensors()) { - tf::TensorShape new_shape(input_tensor->shape()); - new_shape.InsertDim(0, 1); - RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape)) - << "Could not add 0th dimension to tensor without changing its shape." - << " Current shape: " << input_tensor->shape().DebugString(); - } - return absl::OkStatus(); -} - -absl::Status AggregateTensorPacket( - const std::string& tag_name, const Packet& packet, - std::map>* - input_tensors_by_tag_by_timestamp, - InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { - tf::Tensor input_tensor(packet.Get()); - RET_CHECK_OK(AddBatchDimension(&input_tensor)); - if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) { - // If we receive an input on a recurrent tag, override the state. - // It's OK to override the global state because there is just one - // input stream allowed for recurrent tensors. - inference_state_->input_tensor_batches_[tag_name].clear(); - } - (*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert( - std::make_pair(tag_name, input_tensor)); - return absl::OkStatus(); -} - -// Removes the batch dimension of the output tensor if specified in the -// calculator options. 
-absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) { - if (options_.add_batch_dim_to_tensors()) { - tf::TensorShape new_shape(output_tensor->shape()); - new_shape.RemoveDim(0); - RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape)) - << "Could not remove 0th dimension from tensor without changing its " - << "shape. Current shape: " << output_tensor->shape().DebugString() - << " (The expected first dimension is 1 for a batch element.)"; - } - return absl::OkStatus(); -} - -absl::Status Process(CalculatorContext* cc) override { - std::unique_ptr inference_state_to_process; - { - absl::WriterMutexLock l(&mutex_); - if (inference_state_ == nullptr) { - inference_state_ = CreateInferenceState(cc); + RET_CHECK(!cc->Outputs().GetTags().empty()); + for (const std::string& tag : cc->Outputs().GetTags()) { + // The tensorflow::Tensor with tag equal to the graph node to + // output. Any TimeSeriesHeader from the inputs will be forwarded + // with channels set to 0. + cc->Outputs().Tag(tag).Set(); } - std::map> - input_tensors_by_tag_by_timestamp; - for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) { - if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) { - // Recurrent tensors can be empty. - if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) { - if (options_.skip_on_missing_features()) { - return absl::OkStatus(); - } else { - return absl::InvalidArgumentError(absl::StrCat( - "Tag ", tag_as_node_name, - " not present at timestamp: ", cc->InputTimestamp().Value())); + // A mediapipe::TensorFlowSession with a model loaded and ready for use. + // For this calculator it must include a tag_to_tensor_map. + cc->InputSidePackets().Tag("SESSION").Set(); + if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { + cc->InputSidePackets() + .Tag("RECURRENT_INIT_TENSORS") + .Set>>(); + } + return absl::OkStatus(); + } + + std::unique_ptr CreateInferenceState(CalculatorContext* cc) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + std::unique_ptr inference_state = + absl::make_unique(); + if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && + !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { + std::map* init_tensor_map; + init_tensor_map = GetFromUniquePtr>( + cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); + for (const auto& p : *init_tensor_map) { + inference_state->input_tensor_batches_[p.first].emplace_back(p.second); + } + } + return inference_state; + } + + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + + RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); + session_ = cc->InputSidePackets() + .Tag("SESSION") + .Get() + .session.get(); + tag_to_tensor_map_ = cc->InputSidePackets() + .Tag("SESSION") + .Get() + .tag_to_tensor_map; + + // Validate and store the recurrent tags + RET_CHECK(options_.has_batch_size()); + RET_CHECK(options_.batch_size() == 1 || + options_.recurrent_tag_pair().empty()) + << "To use recurrent_tag_pairs, batch_size must be 1."; + for (const auto& tag_pair : options_.recurrent_tag_pair()) { + const std::vector tags = absl::StrSplit(tag_pair, ':'); + RET_CHECK_EQ(tags.size(), 2) + << "recurrent_tag_pair must be a colon " + "separated std::string with two components: " + << tag_pair; + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) + << "Can't find tag '" << tags[0] << "' in signature " + << options_.signature_name(); + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1])) + << "Can't find tag '" << tags[1] << "' in signature " + << 
options_.signature_name(); + recurrent_feed_tags_.insert(tags[0]); + recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0]; + } + + // Check that all tags are present in this signature bound to tensors. + for (const std::string& tag : cc->Inputs().GetTags()) { + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) + << "Can't find tag '" << tag << "' in signature " + << options_.signature_name(); + } + for (const std::string& tag : cc->Outputs().GetTags()) { + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) + << "Can't find tag '" << tag << "' in signature " + << options_.signature_name(); + } + + { + absl::WriterMutexLock l(&mutex_); + inference_state_ = std::unique_ptr(); + } + + if (options_.batch_size() == 1 || options_.batched_input()) { + cc->SetOffset(0); + } + + return absl::OkStatus(); + } + + // Adds a batch dimension to the input tensor if specified in the calculator + // options. + absl::Status AddBatchDimension(tf::Tensor* input_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(input_tensor->shape()); + new_shape.InsertDim(0, 1); + RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape)) + << "Could not add 0th dimension to tensor without changing its shape." + << " Current shape: " << input_tensor->shape().DebugString(); + } + return absl::OkStatus(); + } + + absl::Status AggregateTensorPacket( + const std::string& tag_name, const Packet& packet, + std::map>* + input_tensors_by_tag_by_timestamp, + InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + tf::Tensor input_tensor(packet.Get()); + RET_CHECK_OK(AddBatchDimension(&input_tensor)); + if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) { + // If we receive an input on a recurrent tag, override the state. + // It's OK to override the global state because there is just one + // input stream allowed for recurrent tensors. + inference_state_->input_tensor_batches_[tag_name].clear(); + } + (*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert( + std::make_pair(tag_name, input_tensor)); + return absl::OkStatus(); + } + + // Removes the batch dimension of the output tensor if specified in the + // calculator options. + absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(output_tensor->shape()); + new_shape.RemoveDim(0); + RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape)) + << "Could not remove 0th dimension from tensor without changing its " + << "shape. Current shape: " << output_tensor->shape().DebugString() + << " (The expected first dimension is 1 for a batch element.)"; + } + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + std::unique_ptr inference_state_to_process; + { + absl::WriterMutexLock l(&mutex_); + if (inference_state_ == nullptr) { + inference_state_ = CreateInferenceState(cc); + } + std::map> + input_tensors_by_tag_by_timestamp; + for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) { + if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) { + // Recurrent tensors can be empty. 
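// Recurrent feed tags are seeded from the RECURRENT_INIT_TENSORS side packet
// (see CreateInferenceState()) or from the state fed back by OutputBatch(),
// so an empty packet on a recurrent stream is expected; only non-recurrent
// tags are treated as missing features.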
+ if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) { + if (options_.skip_on_missing_features()) { + return absl::OkStatus(); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Tag ", tag_as_node_name, + " not present at timestamp: ", cc->InputTimestamp().Value())); + } } + } else if (options_.batched_input()) { + const auto& tensor_packets = + cc->Inputs().Tag(tag_as_node_name).Get>(); + if (tensor_packets.size() > options_.batch_size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Batch for tag ", tag_as_node_name, + " has more packets than batch capacity. batch_size: ", + options_.batch_size(), " packets: ", tensor_packets.size())); + } + for (const auto& packet : tensor_packets) { + RET_CHECK_OK(AggregateTensorPacket( + tag_as_node_name, packet, &input_tensors_by_tag_by_timestamp, + inference_state_.get())); + } + } else { + RET_CHECK_OK(AggregateTensorPacket( + tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(), + &input_tensors_by_tag_by_timestamp, inference_state_.get())); } - } else if (options_.batched_input()) { - const auto& tensor_packets = - cc->Inputs().Tag(tag_as_node_name).Get>(); - if (tensor_packets.size() > options_.batch_size()) { - return absl::InvalidArgumentError(absl::StrCat( - "Batch for tag ", tag_as_node_name, - " has more packets than batch capacity. batch_size: ", - options_.batch_size(), " packets: ", tensor_packets.size())); + } + for (const auto& timestamp_and_input_tensors_by_tag : + input_tensors_by_tag_by_timestamp) { + inference_state_->batch_timestamps_.emplace_back( + timestamp_and_input_tensors_by_tag.first); + for (const auto& input_tensor_and_tag : + timestamp_and_input_tensors_by_tag.second) { + inference_state_->input_tensor_batches_[input_tensor_and_tag.first] + .emplace_back(input_tensor_and_tag.second); } - for (const auto& packet : tensor_packets) { - RET_CHECK_OK(AggregateTensorPacket(tag_as_node_name, packet, - &input_tensors_by_tag_by_timestamp, - inference_state_.get())); + } + if (inference_state_->batch_timestamps_.size() == options_.batch_size() || + options_.batched_input()) { + inference_state_to_process = std::move(inference_state_); + inference_state_ = std::unique_ptr(); + } + } + + if (inference_state_to_process) { + MP_RETURN_IF_ERROR( + OutputBatch(cc, std::move(inference_state_to_process))); + } + + return absl::OkStatus(); + } + + absl::Status Close(CalculatorContext* cc) override { + std::unique_ptr inference_state_to_process = nullptr; + { + absl::WriterMutexLock l(&mutex_); + if (cc->GraphStatus().ok() && inference_state_ != nullptr && + !inference_state_->batch_timestamps_.empty()) { + inference_state_to_process = std::move(inference_state_); + inference_state_ = std::unique_ptr(); + } + } + if (inference_state_to_process) { + MP_RETURN_IF_ERROR( + OutputBatch(cc, std::move(inference_state_to_process))); + } + return absl::OkStatus(); + } + + // When a batch of input tensors is ready to be run, runs TensorFlow and + // outputs the output tensors. The output tensors have timestamps matching + // the input tensor that formed that batch element. Any requested + // batch_dimension is added and removed. This code takes advantage of the fact + // that copying a tensor shares the same reference-counted, heap allocated + // memory buffer. Therefore, copies are cheap and should not cause the memory + // buffer to fall out of scope. In contrast, concat is only used where + // necessary. 
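// A standalone sketch (illustrative only, not part of this calculator) of the
// property the comment above relies on: copying a tf::Tensor shares the same
// ref-counted buffer, while tf::tensor::Concat materializes a new contiguous
// buffer, which is why concatenation is reserved for batch_size > 1.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/platform/logging.h"

inline void CopyVersusConcatSketch() {
  namespace tf = tensorflow;
  tf::Tensor original(tf::DT_FLOAT, tf::TensorShape({1, 4}));
  original.flat<float>().setZero();
  tf::Tensor alias = original;  // Shares the buffer; no element data copied.
  tf::Tensor concated;
  const tf::Status concat_status =
      tf::tensor::Concat({original, alias}, &concated);  // Allocates a {2, 4}.
  CHECK(concat_status.ok()) << concat_status.ToString();
}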
+ absl::Status OutputBatch(CalculatorContext* cc, + std::unique_ptr inference_state) { + const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); + std::vector> input_tensors; + + for (auto& keyed_tensors : inference_state->input_tensor_batches_) { + if (options_.batch_size() == 1) { + // Short circuit to avoid the cost of deep copying tensors in concat. + if (!keyed_tensors.second.empty()) { + input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], + keyed_tensors.second[0]); + } else { + // The input buffer can be empty for recurrent tensors. + RET_CHECK( + mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first)) + << "A non-recurrent tensor does not have an input: " + << keyed_tensors.first; } } else { - RET_CHECK_OK(AggregateTensorPacket( - tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(), - &input_tensors_by_tag_by_timestamp, inference_state_.get())); - } - } - for (const auto& timestamp_and_input_tensors_by_tag : - input_tensors_by_tag_by_timestamp) { - inference_state_->batch_timestamps_.emplace_back( - timestamp_and_input_tensors_by_tag.first); - for (const auto& input_tensor_and_tag : - timestamp_and_input_tensors_by_tag.second) { - inference_state_->input_tensor_batches_[input_tensor_and_tag.first] - .emplace_back(input_tensor_and_tag.second); - } - } - if (inference_state_->batch_timestamps_.size() == options_.batch_size() || - options_.batched_input()) { - inference_state_to_process = std::move(inference_state_); - inference_state_ = std::unique_ptr(); - } - } - - if (inference_state_to_process) { - MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process))); - } - - return absl::OkStatus(); -} - -absl::Status Close(CalculatorContext* cc) override { - std::unique_ptr inference_state_to_process = nullptr; - { - absl::WriterMutexLock l(&mutex_); - if (cc->GraphStatus().ok() && inference_state_ != nullptr && - !inference_state_->batch_timestamps_.empty()) { - inference_state_to_process = std::move(inference_state_); - inference_state_ = std::unique_ptr(); - } - } - if (inference_state_to_process) { - MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process))); - } - return absl::OkStatus(); -} - -// When a batch of input tensors is ready to be run, runs TensorFlow and -// outputs the output tensors. The output tensors have timestamps matching -// the input tensor that formed that batch element. Any requested -// batch_dimension is added and removed. This code takes advantage of the fact -// that copying a tensor shares the same reference-counted, heap allocated -// memory buffer. Therefore, copies are cheap and should not cause the memory -// buffer to fall out of scope. In contrast, concat is only used where -// necessary. -absl::Status OutputBatch(CalculatorContext* cc, - std::unique_ptr inference_state) { - const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); - std::vector> input_tensors; - - for (auto& keyed_tensors : inference_state->input_tensor_batches_) { - if (options_.batch_size() == 1) { - // Short circuit to avoid the cost of deep copying tensors in concat. - if (!keyed_tensors.second.empty()) { + // Pad by replicating the first tens or, then ignore the values. 
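// A partial batch (for example the one flushed from Close()) is padded up to
// batch_size by replicating the first tensor so Concat always sees a full
// batch; the padded slots are never emitted because the output loop below
// iterates only over batch_timestamps_.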
+ keyed_tensors.second.resize(options_.batch_size()); + std::fill(keyed_tensors.second.begin() + + inference_state->batch_timestamps_.size(), + keyed_tensors.second.end(), keyed_tensors.second[0]); + tf::Tensor concated; + const tf::Status concat_status = + tf::tensor::Concat(keyed_tensors.second, &concated); + CHECK(concat_status.ok()) << concat_status.ToString(); input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], - keyed_tensors.second[0]); - } else { - // The input buffer can be empty for recurrent tensors. - RET_CHECK( - mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first)) - << "A non-recurrent tensor does not have an input: " - << keyed_tensors.first; + concated); } - } else { - // Pad by replicating the first tens or, then ignore the values. - keyed_tensors.second.resize(options_.batch_size()); - std::fill(keyed_tensors.second.begin() + - inference_state->batch_timestamps_.size(), - keyed_tensors.second.end(), keyed_tensors.second[0]); - tf::Tensor concated; - const tf::Status concat_status = - tf::tensor::Concat(keyed_tensors.second, &concated); - CHECK(concat_status.ok()) << concat_status.ToString(); - input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], - concated); } - } - inference_state->input_tensor_batches_.clear(); - std::vector output_tensor_names; - std::vector output_name_in_signature; - for (const std::string& tag : cc->Outputs().GetTags()) { - output_tensor_names.emplace_back(tag_to_tensor_map_[tag]); - output_name_in_signature.emplace_back(tag); - } - for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { - // Ensure that we always fetch the recurrent state tensors. - if (std::find(output_name_in_signature.begin(), - output_name_in_signature.end(), - tag_pair.first) == output_name_in_signature.end()) { - output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]); - output_name_in_signature.emplace_back(tag_pair.first); + inference_state->input_tensor_batches_.clear(); + std::vector output_tensor_names; + std::vector output_name_in_signature; + for (const std::string& tag : cc->Outputs().GetTags()) { + output_tensor_names.emplace_back(tag_to_tensor_map_[tag]); + output_name_in_signature.emplace_back(tag); } - } - std::vector outputs; + for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { + // Ensure that we always fetch the recurrent state tensors. 
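// Even if no output stream requests a recurrent fetch tag, its tensor must
// still be fetched so it can be fed back as the matching feed tag's input for
// the next invocation (see "Feed back the recurrent state." below).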
+ if (std::find(output_name_in_signature.begin(), + output_name_in_signature.end(), + tag_pair.first) == output_name_in_signature.end()) { + output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]); + output_name_in_signature.emplace_back(tag_pair.first); + } + } + std::vector outputs; - SimpleSemaphore* session_run_throttle = nullptr; - if (options_.max_concurrent_session_runs() > 0) { - session_run_throttle = - get_session_run_throttle(options_.max_concurrent_session_runs()); - session_run_throttle->Acquire(1); - } - const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); - tf::Status tf_status; - { + SimpleSemaphore* session_run_throttle = nullptr; + if (options_.max_concurrent_session_runs() > 0) { + session_run_throttle = + get_session_run_throttle(options_.max_concurrent_session_runs()); + session_run_throttle->Acquire(1); + } + const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); + tf::Status tf_status; + { #if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__) - tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName())); + tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName())); #endif - tf_status = session_->Run(input_tensors, output_tensor_names, - {} /* target_node_names */, &outputs); - } + tf_status = session_->Run(input_tensors, output_tensor_names, + {} /* target_node_names */, &outputs); + } - if (session_run_throttle != nullptr) { - session_run_throttle->Release(1); - } + if (session_run_throttle != nullptr) { + session_run_throttle->Release(1); + } - // RET_CHECK on the tf::Status object itself in order to print an - // informative error message. - RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); + // RET_CHECK on the tf::Status object itself in order to print an + // informative error message. + RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); - const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); - cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) - ->IncrementBy(run_end_time - run_start_time); - cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); + const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); + cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) + ->IncrementBy(run_end_time - run_start_time); + cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); - // Feed back the recurrent state. - for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { - int pos = std::find(output_name_in_signature.begin(), - output_name_in_signature.end(), tag_pair.first) - - output_name_in_signature.begin(); - inference_state->input_tensor_batches_[tag_pair.second].emplace_back( - outputs[pos]); - } + // Feed back the recurrent state. + for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { + int pos = std::find(output_name_in_signature.begin(), + output_name_in_signature.end(), tag_pair.first) - + output_name_in_signature.begin(); + inference_state->input_tensor_batches_[tag_pair.second].emplace_back( + outputs[pos]); + } - absl::WriterMutexLock l(&mutex_); - // Set that we want to split on each index of the 0th dimension. 
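// split_vector holds batch_size ones, so tf::tensor::Split yields one
// length-1 slice per batch element; each slice has its batch dimension
// removed and is emitted at the timestamp of the input that produced it,
// skipping any padded slices.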
- std::vector split_vector(options_.batch_size(), 1); - for (int i = 0; i < output_tensor_names.size(); ++i) { - if (options_.batch_size() == 1) { - if (cc->Outputs().HasTag(output_name_in_signature[i])) { - tf::Tensor output_tensor(outputs[i]); - RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); - cc->Outputs() - .Tag(output_name_in_signature[i]) - .Add(new tf::Tensor(output_tensor), - inference_state->batch_timestamps_[0]); - } - } else { - std::vector split_tensors; - const tf::Status split_status = - tf::tensor::Split(outputs[i], split_vector, &split_tensors); - CHECK(split_status.ok()) << split_status.ToString(); - // Loop over timestamps so that we don't copy the padding. - for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) { - tf::Tensor output_tensor(split_tensors[j]); - RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); - cc->Outputs() - .Tag(output_name_in_signature[i]) - .Add(new tf::Tensor(output_tensor), - inference_state->batch_timestamps_[j]); + absl::WriterMutexLock l(&mutex_); + // Set that we want to split on each index of the 0th dimension. + std::vector split_vector(options_.batch_size(), 1); + for (int i = 0; i < output_tensor_names.size(); ++i) { + if (options_.batch_size() == 1) { + if (cc->Outputs().HasTag(output_name_in_signature[i])) { + tf::Tensor output_tensor(outputs[i]); + RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); + cc->Outputs() + .Tag(output_name_in_signature[i]) + .Add(new tf::Tensor(output_tensor), + inference_state->batch_timestamps_[0]); + } + } else { + std::vector split_tensors; + const tf::Status split_status = + tf::tensor::Split(outputs[i], split_vector, &split_tensors); + CHECK(split_status.ok()) << split_status.ToString(); + // Loop over timestamps so that we don't copy the padding. + for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) { + tf::Tensor output_tensor(split_tensors[j]); + RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); + cc->Outputs() + .Tag(output_name_in_signature[i]) + .Add(new tf::Tensor(output_tensor), + inference_state->batch_timestamps_[j]); + } } } + + // Get end time and report. + const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); + cc->GetCounter(kTotalUsecsCounterSuffix) + ->IncrementBy(end_time - start_time); + cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) + ->IncrementBy(inference_state->batch_timestamps_.size()); + + // Make sure we hold on to the recursive state. + if (!options_.recurrent_tag_pair().empty()) { + inference_state_ = std::move(inference_state); + inference_state_->batch_timestamps_.clear(); + } + + return absl::OkStatus(); } - // Get end time and report. - const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); - cc->GetCounter(kTotalUsecsCounterSuffix)->IncrementBy(end_time - start_time); - cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) - ->IncrementBy(inference_state->batch_timestamps_.size()); + private: + // The Session object is provided by a packet factory and is owned by the + // MediaPipe framework. Individual calls are thread-safe, but session state + // may be shared across threads. + tf::Session* session_; - // Make sure we hold on to the recursive state. - if (!options_.recurrent_tag_pair().empty()) { - inference_state_ = std::move(inference_state); - inference_state_->batch_timestamps_.clear(); + // A mapping between stream tags and the tensor names they are bound to. 
+ std::map tag_to_tensor_map_; + + absl::Mutex mutex_; + std::unique_ptr inference_state_ ABSL_GUARDED_BY(mutex_); + + // The options for the calculator. + TensorFlowInferenceCalculatorOptions options_; + + // Store the feed and fetch tags for feed/fetch recurrent networks. + std::set recurrent_feed_tags_; + std::map recurrent_fetch_tags_to_feed_tags_; + + // Clock used to measure the computation time in OutputBatch(). + std::unique_ptr clock_; + + // The static singleton semaphore to throttle concurrent session runs. + static SimpleSemaphore* get_session_run_throttle( + int32 max_concurrent_session_runs) { + static SimpleSemaphore* session_run_throttle = + new SimpleSemaphore(max_concurrent_session_runs); + return session_run_throttle; } - - return absl::OkStatus(); -} - -private: -// The Session object is provided by a packet factory and is owned by the -// MediaPipe framework. Individual calls are thread-safe, but session state may -// be shared across threads. -tf::Session* session_; - -// A mapping between stream tags and the tensor names they are bound to. -std::map tag_to_tensor_map_; - -absl::Mutex mutex_; -std::unique_ptr inference_state_ ABSL_GUARDED_BY(mutex_); - -// The options for the calculator. -TensorFlowInferenceCalculatorOptions options_; - -// Store the feed and fetch tags for feed/fetch recurrent networks. -std::set recurrent_feed_tags_; -std::map recurrent_fetch_tags_to_feed_tags_; - -// Clock used to measure the computation time in OutputBatch(). -std::unique_ptr clock_; - -// The static singleton semaphore to throttle concurrent session runs. -static SimpleSemaphore* get_session_run_throttle( - int32 max_concurrent_session_runs) { - static SimpleSemaphore* session_run_throttle = - new SimpleSemaphore(max_concurrent_session_runs); - return session_run_throttle; -} -} -; +}; REGISTER_CALCULATOR(TensorFlowInferenceCalculator); constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[]; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index 6aedb138f..c169c6b1e 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -80,6 +80,7 @@ const std::string MaybeConvertSignatureToTag( // which in turn contains a TensorFlow Session ready for execution and a map // between tags and tensor names. // +// // Example usage: // node { // calculator: "TensorFlowSessionFromSavedModelCalculator" diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc index 6f52f09ef..c4f7d40a9 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc @@ -217,38 +217,41 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { first_timestamp_seen_ = recent_timestamp; } } - if (recent_timestamp > last_timestamp_seen) { + if (recent_timestamp > last_timestamp_seen && + recent_timestamp < Timestamp::PostStream().Value()) { last_timestamp_key_ = map_kv.first; last_timestamp_seen = recent_timestamp; } } } if (!timestamps_.empty()) { - RET_CHECK(!last_timestamp_key_.empty()) - << "Something went wrong because the timestamp key is unset. 
" - "Example: " - << sequence_->DebugString(); - RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value()) - << "Something went wrong because the last timestamp is unset. " - "Example: " - << sequence_->DebugString(); - RET_CHECK_LT(first_timestamp_seen_, - Timestamp::OneOverPostStream().Value()) - << "Something went wrong because the first timestamp is unset. " - "Example: " - << sequence_->DebugString(); + for (const auto& kv : timestamps_) { + if (!kv.second.empty() && + kv.second[0] < Timestamp::PostStream().Value()) { + // These checks only make sense if any values are not PostStream, but + // only need to be made once. + RET_CHECK(!last_timestamp_key_.empty()) + << "Something went wrong because the timestamp key is unset. " + << "Example: " << sequence_->DebugString(); + RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value()) + << "Something went wrong because the last timestamp is unset. " + << "Example: " << sequence_->DebugString(); + RET_CHECK_LT(first_timestamp_seen_, + Timestamp::OneOverPostStream().Value()) + << "Something went wrong because the first timestamp is unset. " + << "Example: " << sequence_->DebugString(); + break; + } + } } current_timestamp_index_ = 0; + process_poststream_ = false; // Determine the data path and output it. const auto& options = cc->Options(); const auto& sequence = cc->InputSidePackets() .Tag(kSequenceExampleTag) .Get(); - if (cc->Outputs().HasTag(kKeypointsTag)) { - keypoint_names_ = absl::StrSplit(options.keypoint_names(), ','); - default_keypoint_location_ = options.default_keypoint_location(); - } if (cc->OutputSidePackets().HasTag(kDataPath)) { std::string root_directory = ""; if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) { @@ -349,19 +352,30 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { // all packets on all streams that have a timestamp between the current // reference timestep and the previous reference timestep. This ensures that // we emit all timestamps in order, but also only emit a limited number in - // any particular call to Process(). - int64 start_timestamp = - timestamps_[last_timestamp_key_][current_timestamp_index_]; - if (current_timestamp_index_ == 0) { - start_timestamp = first_timestamp_seen_; + // any particular call to Process(). At the every end, we output the + // poststream packets. If we only have poststream packets, + // last_timestamp_key_ will be empty. + int64 start_timestamp = 0; + int64 end_timestamp = 0; + if (last_timestamp_key_.empty() || process_poststream_) { + process_poststream_ = true; + start_timestamp = Timestamp::PostStream().Value(); + end_timestamp = Timestamp::OneOverPostStream().Value(); + } else { + start_timestamp = + timestamps_[last_timestamp_key_][current_timestamp_index_]; + if (current_timestamp_index_ == 0) { + start_timestamp = first_timestamp_seen_; + } + + end_timestamp = start_timestamp + 1; // Base case at end of sequence. + if (current_timestamp_index_ < + timestamps_[last_timestamp_key_].size() - 1) { + end_timestamp = + timestamps_[last_timestamp_key_][current_timestamp_index_ + 1]; + } } - int64 end_timestamp = start_timestamp + 1; // Base case at end of sequence. 
- if (current_timestamp_index_ < - timestamps_[last_timestamp_key_].size() - 1) { - end_timestamp = - timestamps_[last_timestamp_key_][current_timestamp_index_ + 1]; - } for (const auto& map_kv : timestamps_) { for (int i = 0; i < map_kv.second.size(); ++i) { if (map_kv.second[i] >= start_timestamp && @@ -438,7 +452,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) { return absl::OkStatus(); } else { - return tool::StatusStop(); + if (process_poststream_) { + // Once we've processed the PostStream timestamp we can stop. + return tool::StatusStop(); + } else { + // Otherwise, we still need to do one more pass to process it. + process_poststream_ = true; + return absl::OkStatus(); + } } } @@ -462,6 +483,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { std::vector keypoint_names_; // Default keypoint location when missing. float default_keypoint_location_; + bool process_poststream_; }; REGISTER_CALCULATOR(UnpackMediaSequenceCalculator); } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index dcbda224e..e8e40bad3 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -412,6 +412,72 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) { ::testing::Eq(Timestamp::PostStream())); } +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) { + SetUpCalculator({"IMAGE:images"}, {}); + auto input_sequence = absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + std::string test_image_string = "test_image_string"; + + int num_images = 1; + for (int i = 0; i < num_images; ++i) { + mpms::AddImageTimestamp(i, input_sequence.get()); + mpms::AddImageEncoded(test_image_string, input_sequence.get()); + } + + mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get()); + mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), + input_sequence.get()); + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("IMAGE").packets; + ASSERT_EQ(num_images, output_packets.size()); + + for (int i = 0; i < num_images; ++i) { + const std::string& output_image = output_packets[i].Get(); + ASSERT_EQ(output_image, test_image_string); + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) { + SetUpCalculator({"FLOAT_FEATURE_FDENSE_MAX:max"}, {}); + auto input_sequence = absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + std::string test_image_string = "test_image_string"; + + int num_images = 1; + for (int i = 0; i < num_images; ++i) { + mpms::AddImageTimestamp(i, input_sequence.get()); + mpms::AddImageEncoded(test_image_string, input_sequence.get()); + } + + mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get()); + mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), + input_sequence.get()); + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const 
std::vector& fdense_max_packets = + runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets; + ASSERT_EQ(fdense_max_packets.size(), 1); + const auto& fdense_max_vector = + fdense_max_packets[0].Get>(); + ASSERT_THAT(fdense_max_vector, ::testing::ElementsAreArray({3.0f, 4.0f})); + ASSERT_THAT(fdense_max_packets[0].Timestamp(), + ::testing::Eq(Timestamp::PostStream())); +} + TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) { SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"}); diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index bc05f51b5..ef46460b1 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -904,7 +904,8 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE // Configure and create the delegate. TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); - options.compile_options.precision_loss_allowed = 1; + options.compile_options.precision_loss_allowed = + allow_precision_loss_ ? 1 : 0; options.compile_options.preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST; options.compile_options.dynamic_batch_enabled = 0; @@ -968,7 +969,7 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { const int kHalfSize = 2; // sizeof(half) // Configure and create the delegate. TFLGpuDelegateOptions options; - options.allow_precision_loss = true; + options.allow_precision_loss = allow_precision_loss_; options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive; if (!delegate_) delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), @@ -1080,9 +1081,10 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { } // Create converter for GPU output. - converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device - isFloat16:true - convertToPBHWC4:false]; + converter_from_BPHWC4_ = + [[TFLBufferConvert alloc] initWithDevice:device + isFloat16:allow_precision_loss_ + convertToPBHWC4:false]; if (converter_from_BPHWC4_ == nil) { return absl::InternalError( "Error initializating output buffer converter"); diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc index ec4945201..22a9a8d70 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc @@ -439,7 +439,7 @@ absl::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu( // Run shader, upsample result. 
{ - gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 + gpu_helper_.BindFramebuffer(output_texture); glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, small_mask_texture.id()); GlRender(); diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index bc24e5994..1ee0fb9cc 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -821,6 +821,25 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "landmark_projection_calculator_test", + srcs = ["landmark_projection_calculator_test.cc"], + deps = [ + ":landmark_projection_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_utils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + mediapipe_proto_library( name = "landmarks_smoothing_calculator_proto", srcs = ["landmarks_smoothing_calculator.proto"], @@ -1252,3 +1271,45 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", ], ) + +mediapipe_proto_library( + name = "refine_landmarks_from_heatmap_calculator_proto", + srcs = ["refine_landmarks_from_heatmap_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "refine_landmarks_from_heatmap_calculator", + srcs = ["refine_landmarks_from_heatmap_calculator.cc"], + hdrs = ["refine_landmarks_from_heatmap_calculator.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":refine_landmarks_from_heatmap_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:statusor", + ], + alwayslink = 1, +) + +cc_test( + name = "refine_landmarks_from_heatmap_calculator_test", + srcs = ["refine_landmarks_from_heatmap_calculator_test.cc"], + deps = [ + ":refine_landmarks_from_heatmap_calculator", + "//mediapipe/framework/port:gtest_main", + ], +) diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index 7c5aadc55..2c0b25397 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -402,7 +402,7 @@ absl::Status AnnotationOverlayCalculator::RenderToGpu(CalculatorContext* cc, // Blend overlay image in GPU shader. 
{ - gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 + gpu_helper_.BindFramebuffer(output_texture); glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, input_texture.name()); diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc index 0de1e53b2..07fe791f6 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc @@ -54,6 +54,7 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase { private: absl::node_hash_map label_map_; + ::mediapipe::DetectionLabelIdToTextCalculatorOptions options_; }; REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); @@ -68,13 +69,13 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract( absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - const auto& options = + options_ = cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>(); - if (options.has_label_map_path()) { + if (options_.has_label_map_path()) { std::string string_path; ASSIGN_OR_RETURN(string_path, - PathToResourceAsFile(options.label_map_path())); + PathToResourceAsFile(options_.label_map_path())); std::string label_map_string; MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); @@ -85,8 +86,8 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { label_map_[i++] = line; } } else { - for (int i = 0; i < options.label_size(); ++i) { - label_map_[i] = options.label(i); + for (int i = 0; i < options_.label_size(); ++i) { + label_map_[i] = options_.label(i); } } return absl::OkStatus(); @@ -106,7 +107,7 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) { } } // Remove label_id field if text labels exist. - if (has_text_label) { + if (has_text_label && !options_.keep_label_id()) { output_detection.clear_label_id(); } } diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto b/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto index b722b41c2..198ca4d65 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto @@ -31,4 +31,9 @@ message DetectionLabelIdToTextCalculatorOptions { // label: "label for id 1" // ... repeated string label = 2; + + // By default, the `label_id` field from the input is stripped if a text label + // could be found. By setting this field to true, it is always copied to the + // output detections. 
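// For example (illustrative node snippet, not part of this change; the label
// map path is hypothetical), a graph could keep the numeric ids alongside the
// text labels with:
//   options: {
//     [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
//       label_map_path: "path/to/labelmap.txt"
//       keep_label_id: true
//     }
//   }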
+ optional bool keep_label_id = 3; } diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index cf448cff1..099bdc7e6 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -120,7 +120,11 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { labels.resize(classifications.classification_size()); scores.resize(classifications.classification_size()); for (int i = 0; i < classifications.classification_size(); ++i) { - labels[i] = classifications.classification(i).label(); + if (options_.use_display_name()) { + labels[i] = classifications.classification(i).display_name(); + } else { + labels[i] = classifications.classification(i).label(); + } scores[i] = classifications.classification(i).score(); } } else { diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.proto b/mediapipe/calculators/util/labels_to_render_data_calculator.proto index cd98934a5..c5012ce85 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.proto @@ -59,4 +59,7 @@ message LabelsToRenderDataCalculatorOptions { BOTTOM_LEFT = 1; } optional Location location = 6 [default = TOP_LEFT]; + + // Uses Classification.display_name field instead of Classification.label. + optional bool use_display_name = 9 [default = false]; } diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc index 59b7c020c..e27edea66 100644 --- a/mediapipe/calculators/util/landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/landmark_projection_calculator.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include "mediapipe/calculators/util/landmark_projection_calculator.pb.h" @@ -27,20 +28,32 @@ namespace { constexpr char kLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kRectTag[] = "NORM_RECT"; +constexpr char kProjectionMatrix[] = "PROJECTION_MATRIX"; } // namespace -// Projects normalized landmarks in a rectangle to its original coordinates. The -// rectangle must also be in normalized coordinates. +// Projects normalized landmarks to its original coordinates. // Input: -// NORM_LANDMARKS: A NormalizedLandmarkList representing landmarks -// in a normalized rectangle. -// NORM_RECT: An NormalizedRect representing a normalized rectangle in image -// coordinates. +// NORM_LANDMARKS - NormalizedLandmarkList +// Represents landmarks in a normalized rectangle if NORM_RECT is specified +// or landmarks that should be projected using PROJECTION_MATRIX if +// specified. (Prefer using PROJECTION_MATRIX as it eliminates need of +// letterbox removal step.) +// NORM_RECT - NormalizedRect +// Represents a normalized rectangle in image coordinates and results in +// landmarks with their locations adjusted to the image. +// PROJECTION_MATRIX - std::array +// A 4x4 row-major-order matrix that maps landmarks' locations from one +// coordinate system to another. In this case from the coordinate system of +// the normalized region of interest to the coordinate system of the image. +// +// Note: either NORM_RECT or PROJECTION_MATRIX has to be specified. +// Note: landmark's Z is projected in a custom way - it's scaled by width of +// the normalized region of interest used during landmarks detection. 
// // Output: -// NORM_LANDMARKS: A NormalizedLandmarkList representing landmarks -// with their locations adjusted to the image. +// NORM_LANDMARKS - NormalizedLandmarkList +// Landmarks with their locations adjusted according to the inputs. // // Usage example: // node { @@ -58,12 +71,27 @@ constexpr char kRectTag[] = "NORM_RECT"; // output_stream: "NORM_LANDMARKS:0:projected_landmarks_0" // output_stream: "NORM_LANDMARKS:1:projected_landmarks_1" // } +// +// node { +// calculator: "LandmarkProjectionCalculator" +// input_stream: "NORM_LANDMARKS:landmarks" +// input_stream: "PROECTION_MATRIX:matrix" +// output_stream: "NORM_LANDMARKS:projected_landmarks" +// } +// +// node { +// calculator: "LandmarkProjectionCalculator" +// input_stream: "NORM_LANDMARKS:0:landmarks_0" +// input_stream: "NORM_LANDMARKS:1:landmarks_1" +// input_stream: "PROECTION_MATRIX:matrix" +// output_stream: "NORM_LANDMARKS:0:projected_landmarks_0" +// output_stream: "NORM_LANDMARKS:1:projected_landmarks_1" +// } class LandmarkProjectionCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) && - cc->Inputs().HasTag(kRectTag)) - << "Missing one or more input streams."; + RET_CHECK(cc->Inputs().HasTag(kLandmarksTag)) + << "Missing NORM_LANDMARKS input."; RET_CHECK_EQ(cc->Inputs().NumEntries(kLandmarksTag), cc->Outputs().NumEntries(kLandmarksTag)) @@ -73,7 +101,14 @@ class LandmarkProjectionCalculator : public CalculatorBase { id != cc->Inputs().EndId(kLandmarksTag); ++id) { cc->Inputs().Get(id).Set(); } - cc->Inputs().Tag(kRectTag).Set(); + RET_CHECK(cc->Inputs().HasTag(kRectTag) ^ + cc->Inputs().HasTag(kProjectionMatrix)) + << "Either NORM_RECT or PROJECTION_MATRIX must be specified."; + if (cc->Inputs().HasTag(kRectTag)) { + cc->Inputs().Tag(kRectTag).Set(); + } else { + cc->Inputs().Tag(kProjectionMatrix).Set>(); + } for (CollectionItemId id = cc->Outputs().BeginId(kLandmarksTag); id != cc->Outputs().EndId(kLandmarksTag); ++id) { @@ -89,31 +124,50 @@ class LandmarkProjectionCalculator : public CalculatorBase { return absl::OkStatus(); } + static void ProjectXY(const NormalizedLandmark& lm, + const std::array& matrix, + NormalizedLandmark* out) { + out->set_x(lm.x() * matrix[0] + lm.y() * matrix[1] + lm.z() * matrix[2] + + matrix[3]); + out->set_y(lm.x() * matrix[4] + lm.y() * matrix[5] + lm.z() * matrix[6] + + matrix[7]); + } + + /** + * Landmark's Z scale is equal to a relative (to image) width of region of + * interest used during detection. To calculate based on matrix: + * 1. Project (0,0) --- (1,0) segment using matrix. + * 2. Calculate length of the projected segment. 
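 *
 *    For illustration (example values, not from the code): if the region of
 *    interest is the axis-aligned left half of the image, the row-major
 *    matrix is {0.5, 0, 0, 0,  0, 1, 0, 0,  0, 0, 1, 0,  0, 0, 0, 1}; the
 *    segment (0,0)-(1,0) projects to (0,0)-(0.5,0), so its length, and hence
 *    the scale applied to landmark Z, is 0.5.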
+ */ + static float CalculateZScale(const std::array& matrix) { + NormalizedLandmark a; + a.set_x(0.0f); + a.set_y(0.0f); + NormalizedLandmark b; + b.set_x(1.0f); + b.set_y(0.0f); + NormalizedLandmark a_projected; + ProjectXY(a, matrix, &a_projected); + NormalizedLandmark b_projected; + ProjectXY(b, matrix, &b_projected); + return std::sqrt(std::pow(b_projected.x() - a_projected.x(), 2) + + std::pow(b_projected.y() - a_projected.y(), 2)); + } + absl::Status Process(CalculatorContext* cc) override { - if (cc->Inputs().Tag(kRectTag).IsEmpty()) { - return absl::OkStatus(); - } - const auto& input_rect = cc->Inputs().Tag(kRectTag).Get(); - - const auto& options = - cc->Options<::mediapipe::LandmarkProjectionCalculatorOptions>(); - - CollectionItemId input_id = cc->Inputs().BeginId(kLandmarksTag); - CollectionItemId output_id = cc->Outputs().BeginId(kLandmarksTag); - // Number of inputs and outpus is the same according to the contract. - for (; input_id != cc->Inputs().EndId(kLandmarksTag); - ++input_id, ++output_id) { - const auto& input_packet = cc->Inputs().Get(input_id); - if (input_packet.IsEmpty()) { - continue; + std::function + project_fn; + if (cc->Inputs().HasTag(kRectTag)) { + if (cc->Inputs().Tag(kRectTag).IsEmpty()) { + return absl::OkStatus(); } - - const auto& input_landmarks = input_packet.Get(); - NormalizedLandmarkList output_landmarks; - for (int i = 0; i < input_landmarks.landmark_size(); ++i) { - const NormalizedLandmark& landmark = input_landmarks.landmark(i); - NormalizedLandmark* new_landmark = output_landmarks.add_landmark(); - + const auto& input_rect = cc->Inputs().Tag(kRectTag).Get(); + const auto& options = + cc->Options(); + project_fn = [&input_rect, &options](const NormalizedLandmark& landmark, + NormalizedLandmark* new_landmark) { + // TODO: fix projection or deprecate (current projection + // calculations are incorrect for general case). const float x = landmark.x() - 0.5f; const float y = landmark.y() - 0.5f; const float angle = @@ -130,10 +184,44 @@ class LandmarkProjectionCalculator : public CalculatorBase { new_landmark->set_x(new_x); new_landmark->set_y(new_y); new_landmark->set_z(new_z); + }; + } else if (cc->Inputs().HasTag(kProjectionMatrix)) { + if (cc->Inputs().Tag(kProjectionMatrix).IsEmpty()) { + return absl::OkStatus(); + } + const auto& project_mat = + cc->Inputs().Tag(kProjectionMatrix).Get>(); + const float z_scale = CalculateZScale(project_mat); + project_fn = [&project_mat, z_scale](const NormalizedLandmark& lm, + NormalizedLandmark* new_landmark) { + *new_landmark = lm; + ProjectXY(lm, project_mat, new_landmark); + new_landmark->set_z(z_scale * lm.z()); + }; + } else { + return absl::InternalError("Either rect or matrix must be specified."); + } + + CollectionItemId input_id = cc->Inputs().BeginId(kLandmarksTag); + CollectionItemId output_id = cc->Outputs().BeginId(kLandmarksTag); + // Number of inputs and outpus is the same according to the contract. 
+ for (; input_id != cc->Inputs().EndId(kLandmarksTag); + ++input_id, ++output_id) { + const auto& input_packet = cc->Inputs().Get(input_id); + if (input_packet.IsEmpty()) { + continue; + } + + const auto& input_landmarks = input_packet.Get(); + NormalizedLandmarkList output_landmarks; + for (int i = 0; i < input_landmarks.landmark_size(); ++i) { + const NormalizedLandmark& landmark = input_landmarks.landmark(i); + NormalizedLandmark* new_landmark = output_landmarks.add_landmark(); + project_fn(landmark, new_landmark); } cc->Outputs().Get(output_id).AddPacket( - MakePacket(output_landmarks) + MakePacket(std::move(output_landmarks)) .At(cc->InputTimestamp())); } return absl::OkStatus(); diff --git a/mediapipe/calculators/util/landmark_projection_calculator_test.cc b/mediapipe/calculators/util/landmark_projection_calculator_test.cc new file mode 100644 index 000000000..b15bb0f0c --- /dev/null +++ b/mediapipe/calculators/util/landmark_projection_calculator_test.cc @@ -0,0 +1,240 @@ +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +absl::StatusOr RunCalculator( + mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) { + mediapipe::CalculatorRunner runner( + ParseTextProtoOrDie(R"pb( + calculator: "LandmarkProjectionCalculator" + input_stream: "NORM_LANDMARKS:landmarks" + input_stream: "NORM_RECT:rect" + output_stream: "NORM_LANDMARKS:projected_landmarks" + )pb")); + runner.MutableInputs() + ->Tag("NORM_LANDMARKS") + .packets.push_back( + MakePacket(std::move(input)) + .At(Timestamp(1))); + runner.MutableInputs() + ->Tag("NORM_RECT") + .packets.push_back(MakePacket(std::move(rect)) + .At(Timestamp(1))); + + MP_RETURN_IF_ERROR(runner.Run()); + const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets; + RET_CHECK_EQ(output_packets.size(), 1); + return output_packets[0].Get(); +} + +TEST(LandmarkProjectionCalculatorTest, ProjectingWithDefaultRect) { + mediapipe::NormalizedLandmarkList landmarks = + ParseTextProtoOrDie(R"pb( + landmark { x: 10, y: 20, z: -0.5 } + )pb"); + mediapipe::NormalizedRect rect = + ParseTextProtoOrDie( + R"pb( + x_center: 0.5, + y_center: 0.5, + width: 1.0, + height: 1.0, + rotation: 0.0 + )pb"); + + auto status_or_result = RunCalculator(std::move(landmarks), std::move(rect)); + MP_ASSERT_OK(status_or_result); + + EXPECT_THAT( + status_or_result.value(), + EqualsProto(ParseTextProtoOrDie(R"pb( + landmark { x: 10, y: 20, z: -0.5 } + )pb"))); +} + +mediapipe::NormalizedRect GetCroppedRect() { + return ParseTextProtoOrDie( + R"pb( + x_center: 0.5, y_center: 0.5, width: 0.5, height: 2, rotation: 0.0 + )pb"); +} + +mediapipe::NormalizedLandmarkList GetCroppedRectTestInput() { + return ParseTextProtoOrDie(R"pb( + landmark { x: 1.0, y: 1.0, z: -0.5 } + )pb"); +} + +mediapipe::NormalizedLandmarkList GetCroppedRectTestExpectedResult() { + return ParseTextProtoOrDie(R"pb( + landmark { x: 0.75, y: 1.5, z: -0.25 } + )pb"); +} + +TEST(LandmarkProjectionCalculatorTest, 
ProjectingWithCroppedRect) { + auto status_or_result = + RunCalculator(GetCroppedRectTestInput(), GetCroppedRect()); + MP_ASSERT_OK(status_or_result); + + EXPECT_THAT(status_or_result.value(), + EqualsProto(GetCroppedRectTestExpectedResult())); +} + +absl::StatusOr RunCalculator( + mediapipe::NormalizedLandmarkList input, std::array matrix) { + mediapipe::CalculatorRunner runner( + ParseTextProtoOrDie(R"pb( + calculator: "LandmarkProjectionCalculator" + input_stream: "NORM_LANDMARKS:landmarks" + input_stream: "PROJECTION_MATRIX:matrix" + output_stream: "NORM_LANDMARKS:projected_landmarks" + )pb")); + runner.MutableInputs() + ->Tag("NORM_LANDMARKS") + .packets.push_back( + MakePacket(std::move(input)) + .At(Timestamp(1))); + runner.MutableInputs() + ->Tag("PROJECTION_MATRIX") + .packets.push_back(MakePacket>(std::move(matrix)) + .At(Timestamp(1))); + + MP_RETURN_IF_ERROR(runner.Run()); + const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets; + RET_CHECK_EQ(output_packets.size(), 1); + return output_packets[0].Get(); +} + +TEST(LandmarkProjectionCalculatorTest, ProjectingWithIdentityMatrix) { + mediapipe::NormalizedLandmarkList landmarks = + ParseTextProtoOrDie(R"pb( + landmark { x: 10, y: 20, z: -0.5 } + )pb"); + // clang-format off + std::array matrix = { + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + + auto status_or_result = + RunCalculator(std::move(landmarks), std::move(matrix)); + MP_ASSERT_OK(status_or_result); + + EXPECT_THAT( + status_or_result.value(), + EqualsProto(ParseTextProtoOrDie(R"pb( + landmark { x: 10, y: 20, z: -0.5 } + )pb"))); +} + +TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRectMatrix) { + constexpr int kRectWidth = 1280; + constexpr int kRectHeight = 720; + auto roi = GetRoi(kRectWidth, kRectHeight, GetCroppedRect()); + std::array matrix; + GetRotatedSubRectToRectTransformMatrix(roi, kRectWidth, kRectHeight, + /*flip_horizontaly=*/false, &matrix); + auto status_or_result = RunCalculator(GetCroppedRectTestInput(), matrix); + MP_ASSERT_OK(status_or_result); + + EXPECT_THAT(status_or_result.value(), + EqualsProto(GetCroppedRectTestExpectedResult())); +} + +TEST(LandmarkProjectionCalculatorTest, ProjectingWithScaleMatrix) { + mediapipe::NormalizedLandmarkList landmarks = + ParseTextProtoOrDie(R"pb( + landmark { x: 10, y: 20, z: -0.5 } + landmark { x: 5, y: 6, z: 7 } + )pb"); + // clang-format off + std::array matrix = { + 10.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 100.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + + auto status_or_result = + RunCalculator(std::move(landmarks), std::move(matrix)); + MP_ASSERT_OK(status_or_result); + + EXPECT_THAT( + status_or_result.value(), + EqualsProto(ParseTextProtoOrDie(R"pb( + landmark { x: 100, y: 2000, z: -5 } + landmark { x: 50, y: 600, z: 70 } + )pb"))); +} + +TEST(LandmarkProjectionCalculatorTest, ProjectingWithTranslateMatrix) { + mediapipe::NormalizedLandmarkList landmarks = + ParseTextProtoOrDie(R"pb( + landmark { x: 10, y: 20, z: -0.5 } + )pb"); + // clang-format off + std::array matrix = { + 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 1.0f, 0.0f, 2.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + + auto status_or_result = + RunCalculator(std::move(landmarks), std::move(matrix)); + MP_ASSERT_OK(status_or_result); + + EXPECT_THAT( + status_or_result.value(), + EqualsProto(ParseTextProtoOrDie(R"pb( + landmark { x: 11, y: 22, z: -0.5 } + 
)pb"))); +} + +TEST(LandmarkProjectionCalculatorTest, ProjectingWithRotationMatrix) { + mediapipe::NormalizedLandmarkList landmarks = + ParseTextProtoOrDie(R"pb( + landmark { x: 4, y: 0, z: -0.5 } + )pb"); + // clang-format off + // 90 degrees rotation matrix + std::array matrix = { + 0.0f, -1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + + auto status_or_result = + RunCalculator(std::move(landmarks), std::move(matrix)); + MP_ASSERT_OK(status_or_result); + + EXPECT_THAT( + status_or_result.value(), + EqualsProto(ParseTextProtoOrDie(R"pb( + landmark { x: 0, y: 4, z: -0.5 } + )pb"))); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index 79a740315..d94615228 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -1,3 +1,17 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "mediapipe/calculators/util/rect_to_render_scale_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/rect.pb.h" diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.proto b/mediapipe/calculators/util/rect_to_render_scale_calculator.proto index ef80cf3cf..dda6e2c9c 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.proto +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.proto @@ -1,3 +1,17 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + syntax = "proto2"; package mediapipe; diff --git a/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc new file mode 100644 index 000000000..08d9704e5 --- /dev/null +++ b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc @@ -0,0 +1,166 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h" + +#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" + +namespace mediapipe { + +namespace { + +inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); } + +absl::StatusOr> GetHwcFromDims( + const std::vector& dims) { + if (dims.size() == 3) { + return std::make_tuple(dims[0], dims[1], dims[2]); + } else if (dims.size() == 4) { + // BHWC format check B == 1 + RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap"; + return std::make_tuple(dims[1], dims[2], dims[3]); + } else { + RET_CHECK(false) << "Invalid shape size for heatmap tensor" << dims.size(); + } +} + +} // namespace + +namespace api2 { + +// Refines landmarks using correspond heatmap area. +// +// Input: +// NORM_LANDMARKS - Required. Input normalized landmarks to update. +// TENSORS - Required. Vector of input tensors. 0th element should be heatmap. +// The rest is unused. +// Output: +// NORM_LANDMARKS - Required. Updated normalized landmarks. +class RefineLandmarksFromHeatmapCalculatorImpl + : public NodeImpl { + public: + absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); } + + absl::Status Process(CalculatorContext* cc) override { + // Make sure we bypass landmarks if there is no detection. + if (kInLandmarks(cc).IsEmpty()) { + return absl::OkStatus(); + } + // If for some reason heatmap is missing, just return original landmarks. + if (kInTensors(cc).IsEmpty()) { + kOutLandmarks(cc).Send(*kInLandmarks(cc)); + return absl::OkStatus(); + } + + // Check basic prerequisites. + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()) << "Empty input tensors list. First " + "element is expeced to be a heatmap"; + + const auto& hm_tensor = input_tensors[0]; + const auto& in_lms = *kInLandmarks(cc); + auto hm_view = hm_tensor.GetCpuReadView(); + auto hm_raw = hm_view.buffer(); + const auto& options = + cc->Options(); + + ASSIGN_OR_RETURN(auto out_lms, RefineLandmarksFromHeatMap( + in_lms, hm_raw, hm_tensor.shape().dims, + options.kernel_size(), + options.min_confidence_to_refine())); + + kOutLandmarks(cc).Send(std::move(out_lms)); + return absl::OkStatus(); + } +}; + +} // namespace api2 + +// Runs actual refinement +// High level algorithm: +// +// Heatmap is accepted as tensor in HWC layout where i-th channel is a heatmap +// for the i-th landmark. +// +// For each landmark we replace original value with a value calculated from the +// area in heatmap close to original landmark position (in particular are +// covered with kernel of size options.kernel_size). To calculate new coordinate +// from heatmap we calculate an weighted average inside the kernel. We update +// the landmark iff heatmap is confident in it's prediction i.e. max(heatmap) in +// kernel is at least options.min_confidence_to_refine big. 
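// In formula form (matching the implementation below): for landmark index i
// with kernel K centered on its current (row, col) cell and sigmoid-activated
// heatmap values h, the refined coordinates are
//   x' = sum_{(r,c) in K} c * sigmoid(h[r,c,i]) / (width  * sum_{K} sigmoid(h[r,c,i]))
//   y' = sum_{(r,c) in K} r * sigmoid(h[r,c,i]) / (height * sum_{K} sigmoid(h[r,c,i]))
// and the update is applied only when the maximum of sigmoid(h[r,c,i]) over K
// is at least min_confidence_to_refine and the kernel sum is positive.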
+absl::StatusOr RefineLandmarksFromHeatMap( + const mediapipe::NormalizedLandmarkList& in_lms, + const float* heatmap_raw_data, const std::vector& heatmap_dims, + int kernel_size, float min_confidence_to_refine) { + ASSIGN_OR_RETURN(auto hm_dims, GetHwcFromDims(heatmap_dims)); + auto [hm_height, hm_width, hm_channels] = hm_dims; + + RET_CHECK_EQ(in_lms.landmark_size(), hm_channels) + << "Expected heatmap to have number of layers == to number of " + "landmarks"; + + int hm_row_size = hm_width * hm_channels; + int hm_pixel_size = hm_channels; + + mediapipe::NormalizedLandmarkList out_lms = in_lms; + for (int lm_index = 0; lm_index < out_lms.landmark_size(); ++lm_index) { + int center_col = out_lms.landmark(lm_index).x() * hm_width; + int center_row = out_lms.landmark(lm_index).y() * hm_height; + // Point is outside of the image let's keep it intact. + if (center_col < 0 || center_col >= hm_width || center_row < 0 || + center_col >= hm_height) { + continue; + } + + int offset = (kernel_size - 1) / 2; + // Calculate area to iterate over. Note that we decrease the kernel on + // the edges of the heatmap. Equivalent to zero border. + int begin_col = std::max(0, center_col - offset); + int end_col = std::min(hm_width, center_col + offset + 1); + int begin_row = std::max(0, center_row - offset); + int end_row = std::min(hm_height, center_row + offset + 1); + + float sum = 0; + float weighted_col = 0; + float weighted_row = 0; + float max_value = 0; + + // Main loop. Go over kernel and calculate weighted sum of coordinates, + // sum of weights and max weights. + for (int row = begin_row; row < end_row; ++row) { + for (int col = begin_col; col < end_col; ++col) { + // We expect memory to be in HWC layout without padding. + int idx = hm_row_size * row + hm_pixel_size * col + lm_index; + // Right now we hardcode sigmoid activation as it will be wasteful to + // calculate sigmoid for each value of heatmap in the model itself. If + // we ever have other activations it should be trivial to expand via + // options. + float confidence = Sigmoid(heatmap_raw_data[idx]); + sum += confidence; + max_value = std::max(max_value, confidence); + weighted_col += col * confidence; + weighted_row += row * confidence; + } + } + if (max_value >= min_confidence_to_refine && sum > 0) { + out_lms.mutable_landmark(lm_index)->set_x(weighted_col / hm_width / sum); + out_lms.mutable_landmark(lm_index)->set_y(weighted_row / hm_height / sum); + } + } + return out_lms; +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h new file mode 100644 index 000000000..9656347e1 --- /dev/null +++ b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h @@ -0,0 +1,50 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_ + +#include + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { +namespace api2 { + +class RefineLandmarksFromHeatmapCalculator : public NodeIntf { + public: + static constexpr Input kInLandmarks{ + "NORM_LANDMARKS"}; + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr Output kOutLandmarks{ + "NORM_LANDMARKS"}; + + MEDIAPIPE_NODE_INTERFACE(RefineLandmarksFromHeatmapCalculator, kInLandmarks, + kInTensors, kOutLandmarks); +}; + +} // namespace api2 + +// Exposed for testing. +absl::StatusOr RefineLandmarksFromHeatMap( + const mediapipe::NormalizedLandmarkList& in_lms, + const float* heatmap_raw_data, const std::vector& heatmap_dims, + int kernel_size, float min_confidence_to_refine); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_ diff --git a/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.proto b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.proto new file mode 100644 index 000000000..1f8ff04b6 --- /dev/null +++ b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.proto @@ -0,0 +1,27 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message RefineLandmarksFromHeatmapCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional RefineLandmarksFromHeatmapCalculatorOptions ext = 362281653; + } + optional int32 kernel_size = 1 [default = 9]; + optional float min_confidence_to_refine = 2 [default = 0.5]; +} diff --git a/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator_test.cc b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator_test.cc new file mode 100644 index 000000000..83afacbbc --- /dev/null +++ b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator_test.cc @@ -0,0 +1,152 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +mediapipe::NormalizedLandmarkList vec_to_lms( + const std::vector>& inp) { + mediapipe::NormalizedLandmarkList ret; + for (const auto& it : inp) { + auto new_lm = ret.add_landmark(); + new_lm->set_x(it.first); + new_lm->set_y(it.second); + } + return ret; +} + +std::vector> lms_to_vec( + const mediapipe::NormalizedLandmarkList& lst) { + std::vector> ret; + for (const auto& lm : lst.landmark()) { + ret.push_back({lm.x(), lm.y()}); + } + return ret; +} + +std::vector CHW_to_HWC(std::vector inp, int height, int width, + int depth) { + std::vector ret(inp.size()); + const float* inp_ptr = inp.data(); + for (int c = 0; c < depth; ++c) { + for (int row = 0; row < height; ++row) { + for (int col = 0; col < width; ++col) { + int dest_idx = width * depth * row + depth * col + c; + ret[dest_idx] = *inp_ptr; + ++inp_ptr; + } + } + } + return ret; +} + +using testing::ElementsAre; +using testing::FloatEq; +using testing::Pair; + +TEST(RefineLandmarksFromHeatmapTest, Smoke) { + float z = -10000000000000000; + // clang-format off + std::vector hm = { + z, z, z, + 1, z, z, + z, z, z}; + // clang-format on + + auto ret_or_error = RefineLandmarksFromHeatMap(vec_to_lms({{0.5, 0.5}}), + hm.data(), {3, 3, 1}, 3, 0.1); + MP_EXPECT_OK(ret_or_error); + EXPECT_THAT(lms_to_vec(*ret_or_error), + ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)))); +} + +TEST(RefineLandmarksFromHeatmapTest, MultiLayer) { + float z = -10000000000000000; + // clang-format off + std::vector hm = CHW_to_HWC({ + z, z, z, + 1, z, z, + z, z, z, + z, z, z, + 1, z, z, + z, z, z, + z, z, z, + 1, z, z, + z, z, z}, 3, 3, 3); + // clang-format on + + auto ret_or_error = RefineLandmarksFromHeatMap( + vec_to_lms({{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}}), hm.data(), {3, 3, 3}, 3, + 0.1); + MP_EXPECT_OK(ret_or_error); + EXPECT_THAT(lms_to_vec(*ret_or_error), + ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)), + Pair(FloatEq(0), FloatEq(1 / 3.)), + Pair(FloatEq(0), FloatEq(1 / 3.)))); +} + +TEST(RefineLandmarksFromHeatmapTest, KeepIfNotSure) { + float z = -10000000000000000; + // clang-format off + std::vector hm = CHW_to_HWC({ + z, z, z, + 0, z, z, + z, z, z, + z, z, z, + 0, z, z, + z, z, z, + z, z, z, + 0, z, z, + z, z, z}, 3, 3, 3); + // clang-format on + + auto ret_or_error = RefineLandmarksFromHeatMap( + vec_to_lms({{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}}), hm.data(), {3, 3, 3}, 3, + 0.6); + MP_EXPECT_OK(ret_or_error); + EXPECT_THAT(lms_to_vec(*ret_or_error), + ElementsAre(Pair(FloatEq(0.5), FloatEq(0.5)), + Pair(FloatEq(0.5), FloatEq(0.5)), + Pair(FloatEq(0.5), FloatEq(0.5)))); +} + +TEST(RefineLandmarksFromHeatmapTest, Border) { + float z = -10000000000000000; + // clang-format off + std::vector hm = CHW_to_HWC({ + z, z, z, + 0, z, 0, + z, z, z, + + z, z, z, + 0, z, 0, + z, z, 0}, 3, 3, 2); + // clang-format on + + auto ret_or_error = RefineLandmarksFromHeatMap( + vec_to_lms({{0.0, 0.0}, {0.9, 0.9}}), hm.data(), {3, 3, 2}, 3, 0.1); + MP_EXPECT_OK(ret_or_error); + EXPECT_THAT(lms_to_vec(*ret_or_error), + ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)), + Pair(FloatEq(2 / 3.), FloatEq(1 / 6. 
+ 2 / 6.)))); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/thresholding_calculator.cc b/mediapipe/calculators/util/thresholding_calculator.cc index 65876c075..c86e6ca52 100644 --- a/mediapipe/calculators/util/thresholding_calculator.cc +++ b/mediapipe/calculators/util/thresholding_calculator.cc @@ -101,7 +101,7 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) { } if (cc->InputSidePackets().HasTag("THRESHOLD")) { - threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get(); + threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get(); } return absl::OkStatus(); } diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu/BUILD index 44a6d6428..d9c660088 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu/BUILD @@ -43,8 +43,7 @@ android_binary( "//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/holistic_landmark:hand_recrop.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite", - "//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite", - "//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite", ], assets_dir = "", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD index 5eff6a833..d1c45345f 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD @@ -37,7 +37,7 @@ android_binary( srcs = glob(["*.java"]), assets = [ "//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb", - "//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite", ], assets_dir = "", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD deleted file mode 100644 index 50f9d643a..000000000 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2019 The MediaPipe Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -licenses(["notice"]) - -package(default_visibility = ["//visibility:private"]) - -cc_binary( - name = "libmediapipe_jni.so", - linkshared = 1, - linkstatic = 1, - deps = [ - "//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps", - "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", - ], -) - -cc_library( - name = "mediapipe_jni_lib", - srcs = [":libmediapipe_jni.so"], - alwayslink = 1, -) - -android_binary( - name = "upperbodyposetrackinggpu", - srcs = glob(["*.java"]), - assets = [ - "//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb", - "//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite", - "//mediapipe/modules/pose_detection:pose_detection.tflite", - ], - assets_dir = "", - manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", - manifest_values = { - "applicationId": "com.google.mediapipe.apps.upperbodyposetrackinggpu", - "appName": "Upper Body Pose Tracking", - "mainActivity": ".MainActivity", - "cameraFacingFront": "False", - "binaryGraphName": "upper_body_pose_tracking_gpu.binarypb", - "inputVideoStreamName": "input_video", - "outputVideoStreamName": "output_video", - "flipFramesVertically": "True", - "converterNumBuffers": "2", - }, - multidex = "native", - deps = [ - ":mediapipe_jni_lib", - "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", - "//mediapipe/framework/formats:landmark_java_proto_lite", - "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "@com_google_protobuf//:protobuf_javalite", - ], -) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/MainActivity.java deleted file mode 100644 index 99a3a81ed..000000000 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/MainActivity.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.apps.upperbodyposetrackinggpu; - -import android.os.Bundle; -import android.util.Log; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; -import com.google.mediapipe.framework.PacketGetter; -import com.google.protobuf.InvalidProtocolBufferException; - -/** Main activity of MediaPipe upper-body pose tracking app. 
*/ -public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { - private static final String TAG = "MainActivity"; - - private static final String OUTPUT_LANDMARKS_STREAM_NAME = "pose_landmarks"; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - // To show verbose logging, run: - // adb shell setprop log.tag.MainActivity VERBOSE - if (Log.isLoggable(TAG, Log.VERBOSE)) { - processor.addPacketCallback( - OUTPUT_LANDMARKS_STREAM_NAME, - (packet) -> { - Log.v(TAG, "Received pose landmarks packet."); - try { - NormalizedLandmarkList poseLandmarks = - PacketGetter.getProto(packet, NormalizedLandmarkList.class); - Log.v( - TAG, - "[TS:" - + packet.getTimestamp() - + "] " - + getPoseLandmarksDebugString(poseLandmarks)); - } catch (InvalidProtocolBufferException exception) { - Log.e(TAG, "Failed to get proto.", exception); - } - }); - } - } - - private static String getPoseLandmarksDebugString(NormalizedLandmarkList poseLandmarks) { - String poseLandmarkStr = "Pose landmarks: " + poseLandmarks.getLandmarkCount() + "\n"; - int landmarkIndex = 0; - for (NormalizedLandmark landmark : poseLandmarks.getLandmarkList()) { - poseLandmarkStr += - "\tLandmark [" - + landmarkIndex - + "]: (" - + landmark.getX() - + ", " - + landmark.getY() - + ", " - + landmark.getZ() - + ")\n"; - ++landmarkIndex; - } - return poseLandmarkStr; - } -} diff --git a/mediapipe/examples/coral/graphs/face_detection_desktop_live.pbtxt b/mediapipe/examples/coral/graphs/face_detection_desktop_live.pbtxt index 553212868..fe72b14d6 100644 --- a/mediapipe/examples/coral/graphs/face_detection_desktop_live.pbtxt +++ b/mediapipe/examples/coral/graphs/face_detection_desktop_live.pbtxt @@ -142,26 +142,13 @@ node { } } -# Maps detection label IDs to the corresponding label text ("Face"). The label -# map is provided in the label_map_path option. -node { - calculator: "DetectionLabelIdToTextCalculator" - input_stream: "filtered_detections" - output_stream: "labeled_detections" - options: { - [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] { - label_map_path: "mediapipe/models/face_detection_front_labelmap.txt" - } - } -} - # Adjusts detection locations (already normalized to [0.f, 1.f]) on the # letterboxed image (after image transformation with the FIT scale mode) to the # corresponding locations on the same image with the letterbox removed (the # input image to the graph before image transformation). node { calculator: "DetectionLetterboxRemovalCalculator" - input_stream: "DETECTIONS:labeled_detections" + input_stream: "DETECTIONS:filtered_detections" input_stream: "LETTERBOX_PADDING:letterbox_padding" output_stream: "DETECTIONS:output_detections" } diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc index 28c34b2b5..9194e3dde 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc @@ -33,6 +33,10 @@ constexpr char kDetections[] = "DETECTIONS"; constexpr char kDetectedBorders[] = "BORDERS"; constexpr char kCropRect[] = "CROP_RECT"; constexpr char kFirstCropRect[] = "FIRST_CROP_RECT"; +// Can be used to control whether an animated zoom should actually performed +// (configured through option us_to_first_rect). If provided, a non-zero integer +// will allow the animated zoom to be used when the first detections arrive. 
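// For instance (the stream name below is illustrative only), a graph could
// wire
//   input_stream: "ANIMATE_ZOOM:allow_animated_zoom"
// into this calculator; Process() then permits the animation only when the
// bool packet for the current frame is true (see may_start_animation below).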
+constexpr char kAnimateZoom[] = "ANIMATE_ZOOM"; // Field-of-view (degrees) of the camera's x-axis (width). // TODO: Parameterize FOV based on camera specs. constexpr float kFieldOfView = 60; @@ -76,10 +80,10 @@ class ContentZoomingCalculator : public CalculatorBase { absl::Status InitializeState(int frame_width, int frame_height); // Adjusts state to work with an updated frame size. absl::Status UpdateForResolutionChange(int frame_width, int frame_height); - // Returns true if we are zooming to the initial rect. - bool IsZoomingToInitialRect(const Timestamp& timestamp) const; - // Builds the output rectangle when zooming to the initial rect. - absl::StatusOr GetInitialZoomingRect( + // Returns true if we are animating to the first rect. + bool IsAnimatingToFirstRect(const Timestamp& timestamp) const; + // Builds the output rectangle when animating to the first rect. + absl::StatusOr GetAnimationRect( int frame_width, int frame_height, const Timestamp& timestamp) const; // Converts bounds to tilt offset, pan offset and height. absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin, @@ -97,7 +101,10 @@ class ContentZoomingCalculator : public CalculatorBase { std::unique_ptr path_solver_tilt_; // Are parameters initialized. bool initialized_; - // Stores the time of the first crop rectangle. + // Stores the time of the first crop rectangle. This is used to control + // animating to it. Until a first crop rectangle was computed, it has + // the value Timestamp::Unset(). If animating is not requested, it receives + // the value Timestamp::Done() instead of the time. Timestamp first_rect_timestamp_; // Stores the first crop rectangle. mediapipe::NormalizedRect first_rect_; @@ -135,6 +142,9 @@ absl::Status ContentZoomingCalculator::GetContract( if (cc->Inputs().HasTag(kDetections)) { cc->Inputs().Tag(kDetections).Set>(); } + if (cc->Inputs().HasTag(kAnimateZoom)) { + cc->Inputs().Tag(kAnimateZoom).Set(); + } if (cc->Outputs().HasTag(kDetectedBorders)) { cc->Outputs().Tag(kDetectedBorders).Set(); } @@ -419,10 +429,11 @@ absl::Status ContentZoomingCalculator::UpdateForResolutionChange( return absl::OkStatus(); } -bool ContentZoomingCalculator::IsZoomingToInitialRect( +bool ContentZoomingCalculator::IsAnimatingToFirstRect( const Timestamp& timestamp) const { if (options_.us_to_first_rect() == 0 || - first_rect_timestamp_ == Timestamp::Unset()) { + first_rect_timestamp_ == Timestamp::Unset() || + first_rect_timestamp_ == Timestamp::Done()) { return false; } @@ -443,10 +454,10 @@ double easeInOutQuad(double t) { double lerp(double a, double b, double i) { return a * (1 - i) + b * i; } } // namespace -absl::StatusOr ContentZoomingCalculator::GetInitialZoomingRect( +absl::StatusOr ContentZoomingCalculator::GetAnimationRect( int frame_width, int frame_height, const Timestamp& timestamp) const { - RET_CHECK(IsZoomingToInitialRect(timestamp)) - << "Must only be called if zooming to initial rect."; + RET_CHECK(IsAnimatingToFirstRect(timestamp)) + << "Must only be called if animating to first rect."; const int64 delta_us = (timestamp - first_rect_timestamp_).Value(); const int64 delay = options_.us_to_first_rect_delay(); @@ -538,15 +549,20 @@ absl::Status ContentZoomingCalculator::Process( } } - bool zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp()); + const bool may_start_animation = (options_.us_to_first_rect() != 0) && + (!cc->Inputs().HasTag(kAnimateZoom) || + cc->Inputs().Tag(kAnimateZoom).Get()); + bool is_animating = IsAnimatingToFirstRect(cc->InputTimestamp()); int 
offset_y, height, offset_x; - if (zooming_to_initial_rect) { - // If we are zooming to the first rect, ignore any new incoming detections. - height = last_measured_height_; - offset_x = last_measured_x_offset_; - offset_y = last_measured_y_offset_; - } else if (only_required_found) { + if (!is_animating && options_.start_zoomed_out() && !may_start_animation && + first_rect_timestamp_ == Timestamp::Unset()) { + // If we should start zoomed out and won't be doing an animation, + // initialize the path solvers using the full frame, ignoring detections. + height = max_frame_value_ * frame_height_; + offset_x = (target_aspect_ * height) / 2; + offset_y = frame_height_ / 2; + } else if (!is_animating && only_required_found) { // Convert bounds to tilt/zoom and in pixel coordinates. MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y, &offset_x, &height)); @@ -555,9 +571,9 @@ absl::Status ContentZoomingCalculator::Process( last_measured_height_ = height; last_measured_x_offset_ = offset_x; last_measured_y_offset_ = offset_y; - } else if (cc->InputTimestamp().Microseconds() - - last_only_required_detection_ >= - options_.us_before_zoomout()) { + } else if (!is_animating && cc->InputTimestamp().Microseconds() - + last_only_required_detection_ >= + options_.us_before_zoomout()) { // No only_require detections found within salient regions packets // arriving since us_before_zoomout duration. height = max_frame_value_ * frame_height_ + @@ -566,7 +582,8 @@ absl::Status ContentZoomingCalculator::Process( offset_x = (target_aspect_ * height) / 2; offset_y = frame_height_ / 2; } else { - // No only detection found but using last detection due to + // Either animating to the first rectangle, or + // no only detection found but using last detection due to // duration_before_zoomout_us setting. height = last_measured_height_; offset_x = last_measured_x_offset_; @@ -642,24 +659,28 @@ absl::Status ContentZoomingCalculator::Process( .AddPacket(Adopt(features.release()).At(cc->InputTimestamp())); } - if (first_rect_timestamp_ == Timestamp::Unset() && - options_.us_to_first_rect() != 0) { - first_rect_timestamp_ = cc->InputTimestamp(); + // Record the first crop rectangle + if (first_rect_timestamp_ == Timestamp::Unset()) { first_rect_.set_x_center(path_offset_x / static_cast(frame_width_)); first_rect_.set_width(path_height * target_aspect_ / static_cast(frame_width_)); first_rect_.set_y_center(path_offset_y / static_cast(frame_height_)); first_rect_.set_height(path_height / static_cast(frame_height_)); - // After setting the first rectangle, check whether we should zoom to it. - zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp()); + + // Record the time to serve as departure point for the animation. + // If we are not allowed to start the animation, set Timestamp::Done. + first_rect_timestamp_ = + may_start_animation ? cc->InputTimestamp() : Timestamp::Done(); + // After setting the first rectangle, check whether we should animate to it. + is_animating = IsAnimatingToFirstRect(cc->InputTimestamp()); } // Transmit downstream to glcroppingcalculator. 
if (cc->Outputs().HasTag(kCropRect)) { std::unique_ptr gpu_rect; - if (zooming_to_initial_rect) { - auto rect = GetInitialZoomingRect(frame_width, frame_height, - cc->InputTimestamp()); + if (is_animating) { + auto rect = + GetAnimationRect(frame_width, frame_height, cc->InputTimestamp()); MP_RETURN_IF_ERROR(rect.status()); gpu_rect = absl::make_unique(*rect); } else { diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto index 4564b88be..6516ed21f 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto @@ -19,7 +19,7 @@ package mediapipe.autoflip; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/framework/calculator.proto"; -// NextTag: 17 +// NextTag: 18 message ContentZoomingCalculatorOptions { extend mediapipe.CalculatorOptions { optional ContentZoomingCalculatorOptions ext = 313091992; @@ -58,9 +58,15 @@ message ContentZoomingCalculatorOptions { // Whether to keep state between frames or to compute the final crop rect. optional bool is_stateless = 14 [default = false]; - // Duration (in MicroSeconds) for moving to the first crop rect. + // If true, on the first packet start with the camera zoomed out and then zoom + // in on the subject. If false, the camera will start zoomed in on the + // subject. + optional bool start_zoomed_out = 17 [default = false]; + + // Duration (in MicroSeconds) for animating to the first crop rect. + // Note that if set, takes precedence over start_zoomed_out. optional int64 us_to_first_rect = 15 [default = 0]; - // Duration (in MicroSeconds) to delay moving to the first crop rect. + // Duration (in MicroSeconds) to delay animating to the first crop rect. // Used only if us_to_first_rect is set and is interpreted as part of the // us_to_first_rect time budget. 
optional int64 us_to_first_rect_delay = 16 [default = 0]; diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc index 6859da11f..7be2c86e6 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc @@ -127,6 +127,29 @@ const char kConfigD[] = R"( } )"; +const char kConfigE[] = R"( + calculator: "ContentZoomingCalculator" + input_stream: "VIDEO_SIZE:size" + input_stream: "DETECTIONS:detections" + input_stream: "ANIMATE_ZOOM:animate_zoom" + output_stream: "CROP_RECT:rect" + output_stream: "FIRST_CROP_RECT:first_rect" + options: { + [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } + } + } + )"; + void CheckBorder(const StaticFeatures& static_features, int width, int height, int top_border, int bottom_border) { ASSERT_EQ(2, static_features.border().size()); @@ -145,9 +168,14 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height, EXPECT_EQ(Border::BOTTOM, part.relative_position()); } +struct AddDetectionFlags { + std::optional animated_zoom; +}; + void AddDetectionFrameSize(const cv::Rect_& position, const int64 time, const int width, const int height, - CalculatorRunner* runner) { + CalculatorRunner* runner, + const AddDetectionFlags& flags = {}) { auto detections = std::make_unique>(); if (position.width > 0 && position.height > 0) { mediapipe::Detection detection; @@ -175,6 +203,14 @@ void AddDetectionFrameSize(const cv::Rect_& position, const int64 time, runner->MutableInputs() ->Tag("VIDEO_SIZE") .packets.push_back(Adopt(input_size.release()).At(Timestamp(time))); + + if (flags.animated_zoom.has_value()) { + runner->MutableInputs() + ->Tag("ANIMATE_ZOOM") + .packets.push_back( + mediapipe::MakePacket(flags.animated_zoom.value()) + .At(Timestamp(time))); + } } void AddDetection(const cv::Rect_& position, const int64 time, @@ -703,7 +739,33 @@ TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) { CheckCropRect(500, 500, 1000, 1000, 2, runner->Outputs().Tag("CROP_RECT").packets); } + TEST(ContentZoomingCalculatorTest, StartZoomedOut) { + auto config = ParseTextProtoOrDie(kConfigD); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->set_start_zoomed_out(true); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 400000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 800000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1000000, 1000, 1000, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 1000, 1000, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 880, 880, 1, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 760, 760, 2, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 655, 655, 3, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, AnimateToFirstRect) { auto config = ParseTextProtoOrDie(kConfigD); auto* options = 
config.mutable_options()->MutableExtension( ContentZoomingCalculatorOptions::ext);
@@ -733,6 +795,65 @@ TEST(ContentZoomingCalculatorTest, StartZoomedOut) { runner->Outputs().Tag("CROP_RECT").packets); } +TEST(ContentZoomingCalculatorTest, CanControlAnimation) { + auto config = ParseTextProtoOrDie(kConfigE); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->set_start_zoomed_out(true); + options->set_us_to_first_rect(1000000); + options->set_us_to_first_rect_delay(500000); + auto runner = ::absl::make_unique(config); + // Request the animation for the first frame. + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get(), {.animated_zoom = true}); + // We now stop requesting animated zoom and expect the already started + // animation to run to completion. This tests that the zoom-in continues in + // the call when it was started in the Meet greenroom. + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 400000, 1000, 1000, + runner.get(), {.animated_zoom = false}); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 800000, 1000, 1000, + runner.get(), {.animated_zoom = false}); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1000000, 1000, 1000, + runner.get(), {.animated_zoom = false}); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1500000, 1000, 1000, + runner.get(), {.animated_zoom = false}); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 1000, 1000, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 1000, 1000, 1, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 470, 470, 2, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 222, 222, 3, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 222, 222, 4, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, DoesNotAnimateIfDisabledViaInput) { + auto config = ParseTextProtoOrDie(kConfigE); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->set_start_zoomed_out(true); + options->set_us_to_first_rect(1000000); + options->set_us_to_first_rect_delay(500000); + auto runner = ::absl::make_unique(config); + // Disable the animation starting with the very first frame.
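As a configuration aside, here is a minimal pbtxt sketch of how a graph could wire up the new start_zoomed_out / us_to_first_rect options together with the ANIMATE_ZOOM input that these tests exercise. The tag names mirror kConfigE above; the stream names and timing values are illustrative placeholders, not anything prescribed by this change.

```proto
# Hypothetical node config; tags follow kConfigE, values follow the tests above.
node {
  calculator: "ContentZoomingCalculator"
  input_stream: "VIDEO_SIZE:size"
  input_stream: "DETECTIONS:detections"
  # Optional: when this bool stream is false on the first frame, the scripted
  # zoom-in animation to the first crop rect is not started (exercised by the
  # DoesNotAnimateIfDisabledViaInput test).
  input_stream: "ANIMATE_ZOOM:animate_zoom"
  output_stream: "CROP_RECT:crop_rect"
  output_stream: "FIRST_CROP_RECT:first_rect"
  options: {
    [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
      start_zoomed_out: true
      us_to_first_rect: 1000000       # 1 s animation to the first crop rect
      us_to_first_rect_delay: 500000  # first 0.5 s holds the zoomed-out frame
    }
  }
}
```

Per the option comments in content_zooming_calculator.proto, a non-zero us_to_first_rect takes precedence over start_zoomed_out.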
+ AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get(), {.animated_zoom = false}); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 400000, 1000, 1000, + runner.get(), {.animated_zoom = false}); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 800000, 1000, 1000, + runner.get(), {.animated_zoom = false}); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 1000, 1000, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 880, 880, 1, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 760, 760, 2, + runner->Outputs().Tag("CROP_RECT").packets); +} + TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) { auto config = ParseTextProtoOrDie(kConfigD); auto runner = ::absl::make_unique(config); diff --git a/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto index b00755d36..dd43de9da 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto +++ b/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto @@ -47,4 +47,10 @@ message FaceBoxAdjusterCalculatorOptions { // and height respectively. optional float ipd_face_box_width_ratio = 6 [default = 0.5566]; optional float ipd_face_box_height_ratio = 7 [default = 0.3131]; + + // The max look up angle before considering the eye distance unstable. + optional float max_head_tilt_angle_deg = 8 [default = 12.0]; + // The max amount of time to use an old eye distance when the face look angle + // is unstable. + optional int32 max_facesize_history_us = 9 [default = 8000000]; } diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc index 27867d31b..43f60470e 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc @@ -345,8 +345,7 @@ TEST(SceneCroppingCalculatorTest, ChecksPriorFrameBufferSize) { TEST(SceneCroppingCalculatorTest, ChecksDebugConfigWithoutCroppedFrame) { const CalculatorGraphConfig::Node config = ParseTextProtoOrDie(absl::Substitute( - kDebugConfigNoCroppedFrame, kTargetWidth, kTargetHeight, - kTargetSizeType, 0, kPriorFrameBufferSize)); + kDebugConfigNoCroppedFrame, kTargetWidth, kTargetHeight)); auto runner = absl::make_unique(config); const auto status = runner->Run(); EXPECT_FALSE(status.ok()); diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc index 899724921..c590f5a69 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc @@ -220,7 +220,7 @@ absl::Status KinematicPathSolver::GetTargetPosition(int* target_position) { absl::Status KinematicPathSolver::UpdatePixelsPerDegree( const float pixels_per_degree) { - RET_CHECK_GT(pixels_per_degree_, 0) + RET_CHECK_GT(pixels_per_degree, 0) << "pixels_per_degree must be larger than 0."; pixels_per_degree_ = pixels_per_degree; return absl::OkStatus(); diff --git a/mediapipe/examples/desktop/autoflip/subgraph/face_detection_subgraph.pbtxt b/mediapipe/examples/desktop/autoflip/subgraph/face_detection_subgraph.pbtxt index 2dfb0c532..2a40f1d06 100644 --- 
a/mediapipe/examples/desktop/autoflip/subgraph/face_detection_subgraph.pbtxt +++ b/mediapipe/examples/desktop/autoflip/subgraph/face_detection_subgraph.pbtxt @@ -38,7 +38,7 @@ node { output_stream: "TENSORS:detection_tensors" options: { [mediapipe.TfLiteInferenceCalculatorOptions.ext] { - model_path: "mediapipe/models/face_detection_back.tflite" + model_path: "mediapipe/modules/face_detection/face_detection_back.tflite" } } } @@ -111,26 +111,13 @@ node { } } -# Maps detection label IDs to the corresponding label text ("Face"). The label -# map is provided in the label_map_path option. -node { - calculator: "DetectionLabelIdToTextCalculator" - input_stream: "filtered_detections" - output_stream: "labeled_detections" - options: { - [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] { - label_map_path: "mediapipe/models/face_detection_back_labelmap.txt" - } - } -} - # Adjusts detection locations (already normalized to [0.f, 1.f]) on the # letterboxed image (after image transformation with the FIT scale mode) to the # corresponding locations on the same image with the letterbox removed (the # input image to the graph before image transformation). node { calculator: "DetectionLetterboxRemovalCalculator" - input_stream: "DETECTIONS:labeled_detections" + input_stream: "DETECTIONS:filtered_detections" input_stream: "LETTERBOX_PADDING:letterbox_padding" output_stream: "DETECTIONS:output_detections" } diff --git a/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD b/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD deleted file mode 100644 index 6240864a3..000000000 --- a/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2020 The MediaPipe Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -licenses(["notice"]) - -package(default_visibility = ["//mediapipe/examples:__subpackages__"]) - -cc_binary( - name = "upper_body_pose_tracking_cpu", - deps = [ - "//mediapipe/examples/desktop:demo_run_graph_main", - "//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_cpu_deps", - ], -) - -# Linux only -cc_binary( - name = "upper_body_pose_tracking_gpu", - deps = [ - "//mediapipe/examples/desktop:demo_run_graph_main_gpu", - "//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps", - ], -) diff --git a/mediapipe/examples/ios/holistictrackinggpu/BUILD b/mediapipe/examples/ios/holistictrackinggpu/BUILD index b8d6c00ab..b080564ce 100644 --- a/mediapipe/examples/ios/holistictrackinggpu/BUILD +++ b/mediapipe/examples/ios/holistictrackinggpu/BUILD @@ -61,8 +61,7 @@ objc_library( "//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/holistic_landmark:hand_recrop.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite", - "//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", - "//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite", ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD index c78c6a674..01a82cb4b 100644 --- a/mediapipe/examples/ios/posetrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -63,7 +63,7 @@ objc_library( data = [ "//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb", "//mediapipe/modules/pose_detection:pose_detection.tflite", - "//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite", ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", diff --git a/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD b/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD deleted file mode 100644 index 3455fbbf8..000000000 --- a/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2020 The MediaPipe Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load( - "@build_bazel_rules_apple//apple:ios.bzl", - "ios_application", -) -load( - "//mediapipe/examples/ios:bundle_id.bzl", - "BUNDLE_ID_PREFIX", - "example_provisioning", -) - -licenses(["notice"]) - -MIN_IOS_VERSION = "10.0" - -alias( - name = "upperbodyposetrackinggpu", - actual = "UpperBodyPoseTrackingGpuApp", -) - -ios_application( - name = "UpperBodyPoseTrackingGpuApp", - app_icons = ["//mediapipe/examples/ios/common:AppIcon"], - bundle_id = BUNDLE_ID_PREFIX + ".UpperBodyPoseTrackingGpu", - families = [ - "iphone", - "ipad", - ], - infoplists = [ - "//mediapipe/examples/ios/common:Info.plist", - "Info.plist", - ], - minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = example_provisioning(), - deps = [ - ":UpperBodyPoseTrackingGpuAppLibrary", - "@ios_opencv//:OpencvFramework", - ], -) - -objc_library( - name = "UpperBodyPoseTrackingGpuAppLibrary", - srcs = [ - "UpperBodyPoseTrackingViewController.mm", - ], - hdrs = [ - "UpperBodyPoseTrackingViewController.h", - ], - copts = ["-std=c++17"], - data = [ - "//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb", - "//mediapipe/modules/pose_detection:pose_detection.tflite", - "//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite", - ], - deps = [ - "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", - ] + select({ - "//mediapipe:ios_i386": [], - "//mediapipe:ios_x86_64": [], - "//conditions:default": [ - "//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", - ], - }), -) diff --git a/mediapipe/examples/ios/upperbodyposetrackinggpu/Info.plist b/mediapipe/examples/ios/upperbodyposetrackinggpu/Info.plist deleted file mode 100644 index ec4b768d8..000000000 --- a/mediapipe/examples/ios/upperbodyposetrackinggpu/Info.plist +++ /dev/null @@ -1,16 +0,0 @@ - - - - - CameraPosition - back - MainViewController - UpperBodyPoseTrackingViewController - GraphOutputStream - output_video - GraphInputStream - input_video - GraphName - upper_body_pose_tracking_gpu - - diff --git a/mediapipe/examples/ios/upperbodyposetrackinggpu/UpperBodyPoseTrackingViewController.mm b/mediapipe/examples/ios/upperbodyposetrackinggpu/UpperBodyPoseTrackingViewController.mm deleted file mode 100644 index 00a14bfb7..000000000 --- a/mediapipe/examples/ios/upperbodyposetrackinggpu/UpperBodyPoseTrackingViewController.mm +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "UpperBodyPoseTrackingViewController.h" - -#include "mediapipe/framework/formats/landmark.pb.h" - -static const char* kLandmarksOutputStream = "pose_landmarks"; - -@implementation UpperBodyPoseTrackingViewController - -#pragma mark - UIViewController methods - -- (void)viewDidLoad { - [super viewDidLoad]; - - [self.mediapipeGraph addFrameOutputStream:kLandmarksOutputStream - outputPacketType:MPPPacketTypeRaw]; -} - -#pragma mark - MPPGraphDelegate methods - -// Receives a raw packet from the MediaPipe graph. 
Invoked on a MediaPipe worker thread. -- (void)mediapipeGraph:(MPPGraph*)graph - didOutputPacket:(const ::mediapipe::Packet&)packet - fromStream:(const std::string&)streamName { - if (streamName == kLandmarksOutputStream) { - if (packet.IsEmpty()) { - NSLog(@"[TS:%lld] No pose landmarks", packet.Timestamp().Value()); - return; - } - const auto& landmarks = packet.Get<::mediapipe::NormalizedLandmarkList>(); - NSLog(@"[TS:%lld] Number of pose landmarks: %d", packet.Timestamp().Value(), - landmarks.landmark_size()); - for (int i = 0; i < landmarks.landmark_size(); ++i) { - NSLog(@"\tLandmark[%d]: (%f, %f, %f)", i, landmarks.landmark(i).x(), - landmarks.landmark(i).y(), landmarks.landmark(i).z()); - } - } -} - -@end diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 2124ca580..747a4eda8 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -15,6 +15,7 @@ # load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) @@ -27,10 +28,28 @@ package_group( ], ) -exports_files([ - "transitive_protos.bzl", - "encode_binary_proto.bzl", -]) +bzl_library( + name = "transitive_protos_bzl", + srcs = [ + "transitive_protos.bzl", + ], + visibility = ["//mediapipe/framework:__subpackages__"], +) + +bzl_library( + name = "encode_binary_proto_bzl", + srcs = [ + "encode_binary_proto.bzl", + ], + visibility = ["//visibility:public"], +) + +alias( + name = "encode_binary_proto", + actual = ":encode_binary_proto_bzl", + deprecation = "Use encode_binary_proto_bzl", + visibility = ["//visibility:public"], +) mediapipe_proto_library( name = "calculator_proto", diff --git a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD index 7c9a45e36..1a4faca13 100644 --- a/mediapipe/framework/api2/BUILD +++ b/mediapipe/framework/api2/BUILD @@ -1,15 +1,8 @@ package( - default_visibility = [":preview_users"], + default_visibility = ["//visibility:public"], features = ["-use_header_modules"], ) -package_group( - name = "preview_users", - packages = [ - "//mediapipe/...", - ], -) - licenses(["notice"]) cc_library( diff --git a/mediapipe/framework/calculator.proto b/mediapipe/framework/calculator.proto index 3c5ac66e5..b39b95aca 100644 --- a/mediapipe/framework/calculator.proto +++ b/mediapipe/framework/calculator.proto @@ -422,6 +422,9 @@ message CalculatorGraphConfig { // the graph config. string type = 20; - // Can be used for annotating a graph. + // The types and default values for graph options, in proto2 syntax. MediaPipeOptions options = 1001; + + // The types and default values for graph options, in proto3 syntax. 
+ repeated google.protobuf.Any graph_options = 1002; }
diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc
index ccbde4381..2c341128a 100644
--- a/mediapipe/framework/calculator_graph.cc
+++ b/mediapipe/framework/calculator_graph.cc
@@ -411,7 +411,8 @@ absl::Status CalculatorGraph::Initialize( absl::Status CalculatorGraph::ObserveOutputStream( const std::string& stream_name, - std::function<absl::Status(const Packet&)> packet_callback) { + std::function<absl::Status(const Packet&)> packet_callback, + bool observe_timestamp_bounds) { RET_CHECK(initialized_).SetNoLogging() << "CalculatorGraph is not initialized."; // TODO Allow output observers to be attached by graph level
@@ -425,7 +426,7 @@ absl::Status CalculatorGraph::ObserveOutputStream( auto observer = absl::make_unique<internal::OutputStreamObserver>(); MP_RETURN_IF_ERROR(observer->Initialize( stream_name, &any_packet_type_, std::move(packet_callback), - &output_stream_managers_[output_stream_index])); + &output_stream_managers_[output_stream_index], observe_timestamp_bounds)); graph_output_streams_.push_back(std::move(observer)); return absl::OkStatus(); }
diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h
index 4c9079f0a..77a1ee551 100644
--- a/mediapipe/framework/calculator_graph.h
+++ b/mediapipe/framework/calculator_graph.h
@@ -157,7 +157,8 @@ class CalculatorGraph { // TODO: Rename to AddOutputStreamCallback. absl::Status ObserveOutputStream( const std::string& stream_name, - std::function<absl::Status(const Packet&)> packet_callback); + std::function<absl::Status(const Packet&)> packet_callback, + bool observe_timestamp_bounds = false); // Adds an OutputStreamPoller for a stream. This provides a synchronous, // polling API for accessing a stream's output. Should only be called before
diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc
index 5e765c644..b55f9459d 100644
--- a/mediapipe/framework/calculator_graph_bounds_test.cc
+++ b/mediapipe/framework/calculator_graph_bounds_test.cc
@@ -1518,5 +1518,72 @@ TEST(CalculatorGraphBoundsTest, OffsetAndBound) { MP_ASSERT_OK(graph.WaitUntilDone()); } +// A Calculator that sends empty output stream packets. +class EmptyPacketCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set<int>(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return absl::OkStatus(); + } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } + absl::Status Process(CalculatorContext* cc) final { + if (cc->InputTimestamp().Value() % 2 == 0) { + cc->Outputs().Index(0).AddPacket(Packet().At(cc->InputTimestamp())); + } + return absl::OkStatus(); + } +}; +REGISTER_CALCULATOR(EmptyPacketCalculator); + +// This test shows that an output timestamp bound can be specified by outputting +// an empty packet with a settled timestamp. +TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { + // EmptyPacketCalculator only ever emits empty packets, which settle timestamp + // bounds that ProcessBoundToPacketCalculator then converts back into packets.
+ std::string config_str = R"( + input_stream: "input_0" + node { + calculator: "EmptyPacketCalculator" + input_stream: "input_0" + output_stream: "empty_0" + } + node { + calculator: "ProcessBoundToPacketCalculator" + input_stream: "empty_0" + output_stream: "output_0" + } + )"; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector output_0_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { + output_0_packets.push_back(p); + return absl::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send in packets. + for (int i = 0; i < 9; ++i) { + const int ts = 10 + i * 10; + Packet p = MakePacket(i).At(Timestamp(ts)); + MP_ASSERT_OK(graph.AddPacketToInputStream("input_0", p)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + } + + // 9 empty packets are converted to bounds and then to packets. + EXPECT_EQ(output_0_packets.size(), 9); + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); + } + + // Shutdown the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 062551342..4278671b0 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -16,11 +16,20 @@ # The dependencies of mediapipe. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +bzl_library( + name = "expand_template_bzl", + srcs = [ + "expand_template.bzl", + ], + visibility = ["//mediapipe/framework:__subpackages__"], +) + proto_library( name = "proto_descriptor_proto", srcs = ["proto_descriptor.proto"], diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 3067eb246..902524e10 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -295,6 +295,7 @@ cc_library( "//mediapipe/framework/formats:image_format_cc_proto", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", + "//mediapipe/framework:type_map", "//mediapipe/framework/port:logging", ] + select({ "//conditions:default": [ diff --git a/mediapipe/framework/formats/image.cc b/mediapipe/framework/formats/image.cc index b8944593c..0591c3c6c 100644 --- a/mediapipe/framework/formats/image.cc +++ b/mediapipe/framework/formats/image.cc @@ -14,6 +14,8 @@ #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/type_map.h" + namespace mediapipe { // TODO Refactor common code from GpuBufferToImageFrameCalculator @@ -67,8 +69,7 @@ bool Image::ConvertToGpu() const { #else if (use_gpu_) return true; // Already on GPU. 
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - auto packet = MakePacket(std::move(*image_frame_)); - image_frame_ = nullptr; + auto packet = PointToForeign(image_frame_.get()); CFHolder buffer; auto status = CreateCVPixelBufferForImageFramePacket(packet, true, &buffer); CHECK_OK(status); @@ -94,4 +95,7 @@ bool Image::ConvertToGpu() const { #endif // MEDIAPIPE_DISABLE_GPU } +MEDIAPIPE_REGISTER_TYPE(mediapipe::Image, "::mediapipe::Image", nullptr, + nullptr); + } // namespace mediapipe diff --git a/mediapipe/framework/formats/image.h b/mediapipe/framework/formats/image.h index 58184afca..bfddefacf 100644 --- a/mediapipe/framework/formats/image.h +++ b/mediapipe/framework/formats/image.h @@ -72,8 +72,8 @@ class Image { // Creates an Image representing the same image content as the ImageFrame // the input shared pointer points to, and retaining shared ownership. - explicit Image(ImageFrameSharedPtr frame_buffer) - : image_frame_(std::move(frame_buffer)) { + explicit Image(ImageFrameSharedPtr image_frame) + : image_frame_(std::move(image_frame)) { use_gpu_ = false; pixel_mutex_ = std::make_shared(); } diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index a5964567a..98f85f3c3 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -30,6 +30,9 @@ namespace mediapipe { +// Zero and negative values are not checked here. +bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; } + int BhwcBatchFromShape(const Tensor::Shape& shape) { LOG_IF(FATAL, shape.dims.empty()) << "Tensor::Shape must be non-empty to retrieve a named dimension"; @@ -237,6 +240,12 @@ void Tensor::AllocateOpenGlTexture2d() const { glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, texture_width_, texture_height_); } else { + // GLES2.0 supports only clamp addressing mode for NPOT textures. + // If any of dimensions is NPOT then both addressing modes are clamp. + if (!IsPowerOfTwo(texture_width_) || !IsPowerOfTwo(texture_height_)) { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + } // We assume all contexts will have the same extensions, so we only check // once for OES_texture_float extension, to save time. static bool has_oes_extension = diff --git a/mediapipe/framework/graph_output_stream.cc b/mediapipe/framework/graph_output_stream.cc index 0a2bc4c18..6639bb8bf 100644 --- a/mediapipe/framework/graph_output_stream.cc +++ b/mediapipe/framework/graph_output_stream.cc @@ -14,13 +14,16 @@ #include "mediapipe/framework/graph_output_stream.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/port/status.h" + namespace mediapipe { namespace internal { absl::Status GraphOutputStream::Initialize( const std::string& stream_name, const PacketType* packet_type, - OutputStreamManager* output_stream_manager) { + OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) { RET_CHECK(output_stream_manager); // Initializes input_stream_handler_ with one input stream as the observer. 
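For orientation, a rough usage sketch of the observe_timestamp_bounds flag plumbed through CalculatorGraph::ObserveOutputStream and GraphOutputStream::Initialize above; the helper function, stream name, and log messages below are placeholders rather than code from this change. When the flag is set, the callback is also invoked with empty packets whose timestamp carries the settled bound, which is the behavior OutputStreamObserver::Notify() implements in the next hunk.

```c++
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"

// Hypothetical helper: observe a stream and distinguish real packets from
// empty packets that only advance the timestamp bound.
absl::Status ObserveWithBounds(mediapipe::CalculatorGraph* graph) {
  return graph->ObserveOutputStream(
      "output_0",  // placeholder stream name
      [](const mediapipe::Packet& packet) {
        if (packet.IsEmpty()) {
          LOG(INFO) << "Bound settled at ts " << packet.Timestamp().Value();
        } else {
          LOG(INFO) << "Packet at ts " << packet.Timestamp().Value();
        }
        return absl::OkStatus();
      },
      /*observe_timestamp_bounds=*/true);
}
```

The EmptyPacketOutput and TestAddMultiStreamCallbackWithTimestampNotification tests in this change exercise the same pattern through the test utilities.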
@@ -31,6 +34,7 @@ absl::Status GraphOutputStream::Initialize( input_stream_handler_ = absl::make_unique( tag_map, /*cc_manager=*/nullptr, MediaPipeOptions(), /*calculator_run_in_parallel=*/false); + input_stream_handler_->SetProcessTimestampBounds(observe_timestamp_bounds); const CollectionItemId& id = tag_map->BeginId(); input_stream_ = absl::make_unique(); MP_RETURN_IF_ERROR(
@@ -52,20 +56,58 @@ void GraphOutputStream::PrepareForRun( absl::Status OutputStreamObserver::Initialize( const std::string& stream_name, const PacketType* packet_type, std::function<absl::Status(const Packet&)> packet_callback, - OutputStreamManager* output_stream_manager) { + OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) { RET_CHECK(output_stream_manager); packet_callback_ = std::move(packet_callback); + observe_timestamp_bounds_ = observe_timestamp_bounds; return GraphOutputStream::Initialize(stream_name, packet_type, - output_stream_manager); + output_stream_manager, + observe_timestamp_bounds); } absl::Status OutputStreamObserver::Notify() { + // Lets one thread perform packet notification as much as possible. + // Other threads should quit if a thread is already performing notification. + { + absl::MutexLock l(&mutex_); + + if (notifying_ == false) { + notifying_ = true; + } else { + return absl::OkStatus(); + } + } while (true) { bool empty; Timestamp min_timestamp = input_stream_->MinTimestampOrBound(&empty); if (empty) { - break; + // Emits an empty packet at timestamp_bound.PreviousAllowedInStream(). + if (observe_timestamp_bounds_ && min_timestamp < Timestamp::Done()) { + Timestamp settled = (min_timestamp == Timestamp::PostStream() + ? Timestamp::PostStream() + : min_timestamp.PreviousAllowedInStream()); + if (last_processed_ts_ < settled) { + MP_RETURN_IF_ERROR(packet_callback_(Packet().At(settled))); + last_processed_ts_ = settled; + } + } + // Last check to make sure that the min timestamp or bound hasn't changed. + // If it is unchanged, flips notifying_ to false so that other threads can + // perform notification when new packets/timestamp bounds arrive. Otherwise, + // if the min timestamp or bound has been updated, jumps back to the + // beginning of the notification loop for a new iteration. + { + absl::MutexLock l(&mutex_); + Timestamp new_min_timestamp = + input_stream_->MinTimestampOrBound(&empty); + if (new_min_timestamp == min_timestamp) { + notifying_ = false; + break; + } else { + continue; + } + } } int num_packets_dropped = 0; bool stream_is_done = false;
@@ -75,6 +117,7 @@ absl::Status OutputStreamObserver::Notify() { << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", num_packets_dropped, input_stream_->Name()); MP_RETURN_IF_ERROR(packet_callback_(packet)); + last_processed_ts_ = min_timestamp; } return absl::OkStatus(); }
diff --git a/mediapipe/framework/graph_output_stream.h b/mediapipe/framework/graph_output_stream.h
index 9a60fa4bd..393407aa3 100644
--- a/mediapipe/framework/graph_output_stream.h
+++ b/mediapipe/framework/graph_output_stream.h
@@ -52,7 +52,8 @@ class GraphOutputStream { // is not transferred to the graph output stream object. absl::Status Initialize(const std::string& stream_name, const PacketType* packet_type, - OutputStreamManager* output_stream_manager); + OutputStreamManager* output_stream_manager, + bool observe_timestamp_bounds = false); // Installs callbacks into its GraphOutputStreamHandler.
virtual void PrepareForRun(std::function notification_callback, @@ -99,6 +100,10 @@ class GraphOutputStream { } }; + bool observe_timestamp_bounds_; + absl::Mutex mutex_; + bool notifying_ ABSL_GUARDED_BY(mutex_) = false; + Timestamp last_processed_ts_ = Timestamp::Unstarted(); std::unique_ptr input_stream_handler_; std::unique_ptr input_stream_; }; @@ -112,7 +117,8 @@ class OutputStreamObserver : public GraphOutputStream { absl::Status Initialize( const std::string& stream_name, const PacketType* packet_type, std::function packet_callback, - OutputStreamManager* output_stream_manager); + OutputStreamManager* output_stream_manager, + bool observe_timestamp_bounds = false); // Notifies the observer of new packets emitted by the observed // output stream. @@ -128,6 +134,7 @@ class OutputStreamObserver : public GraphOutputStream { // OutputStreamPollerImpl that returns packets to the caller via // Next()/NextBatch(). +// TODO: Support observe_timestamp_bounds. class OutputStreamPollerImpl : public GraphOutputStream { public: virtual ~OutputStreamPollerImpl() {} diff --git a/mediapipe/framework/mediapipe_options.proto b/mediapipe/framework/mediapipe_options.proto index 5fa824d67..0ded9597a 100644 --- a/mediapipe/framework/mediapipe_options.proto +++ b/mediapipe/framework/mediapipe_options.proto @@ -20,6 +20,9 @@ syntax = "proto2"; package mediapipe; +option java_package = "com.google.mediapipe.proto"; +option java_outer_classname = "MediaPipeOptionsProto"; + // Options used by a MediaPipe object. message MediaPipeOptions { extensions 20000 to max; diff --git a/mediapipe/framework/output_stream_shard.cc b/mediapipe/framework/output_stream_shard.cc index 704a18d8c..1b096efbb 100644 --- a/mediapipe/framework/output_stream_shard.cc +++ b/mediapipe/framework/output_stream_shard.cc @@ -101,8 +101,8 @@ Status OutputStreamShard::AddPacketInternal(T&& packet) { } if (packet.IsEmpty()) { - return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) - << "Empty packet sent to stream \"" << Name() << "\"."; + SetNextTimestampBound(packet.Timestamp().NextAllowedInStream()); + return absl::OkStatus(); } const Timestamp timestamp = packet.Timestamp(); diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index fdf35b591..7beabd152 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -20,6 +20,7 @@ load( "mediapipe_binary_graph", ) load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) @@ -29,6 +30,30 @@ exports_files([ "simple_subgraph_template.cc", ]) +bzl_library( + name = "mediapipe_graph_bzl", + srcs = [ + "mediapipe_graph.bzl", + ], + visibility = ["//visibility:public"], + deps = [ + ":build_defs_bzl", + "//mediapipe/framework:encode_binary_proto", + "//mediapipe/framework:transitive_protos_bzl", + "//mediapipe/framework/deps:expand_template_bzl", + ], +) + +bzl_library( + name = "build_defs_bzl", + srcs = [ + "build_defs.bzl", + ], + visibility = [ + "//mediapipe/framework:__subpackages__", + ], +) + cc_library( name = "text_to_binary_graph", srcs = ["text_to_binary_graph.cc"], @@ -744,5 +769,7 @@ cc_test( exports_files( ["build_defs.bzl"], - visibility = ["//mediapipe/framework:__subpackages__"], + visibility = [ + "//mediapipe/framework:__subpackages__", + ], ) diff --git a/mediapipe/framework/tool/sink.cc b/mediapipe/framework/tool/sink.cc index 2ac17f8e1..4a181b43f 100644 --- a/mediapipe/framework/tool/sink.cc +++ b/mediapipe/framework/tool/sink.cc 
@@ -19,6 +19,7 @@ #include "mediapipe/framework/tool/sink.h" #include +#include #include #include "absl/strings/str_cat.h" @@ -168,8 +169,19 @@ void AddMultiStreamCallback( std::function&)> callback, CalculatorGraphConfig* config, std::pair* side_packet) { + std::map side_packets; + AddMultiStreamCallback(streams, callback, config, &side_packets, + /*observe_timestamp_bounds=*/false); + *side_packet = *side_packets.begin(); +} + +void AddMultiStreamCallback( + const std::vector& streams, + std::function&)> callback, + CalculatorGraphConfig* config, std::map* side_packets, + bool observe_timestamp_bounds) { CHECK(config); - CHECK(side_packet); + CHECK(side_packets); CalculatorGraphConfig::Node* sink_node = config->add_node(); const std::string name = GetUnusedNodeName( *config, absl::StrCat("multi_callback_", absl::StrJoin(streams, "_"))); @@ -179,15 +191,23 @@ void AddMultiStreamCallback( sink_node->add_input_stream(stream_name); } + if (observe_timestamp_bounds) { + const std::string observe_ts_bounds_packet_name = GetUnusedSidePacketName( + *config, absl::StrCat(name, "_observe_ts_bounds")); + sink_node->add_input_side_packet(absl::StrCat( + "OBSERVE_TIMESTAMP_BOUNDS:", observe_ts_bounds_packet_name)); + InsertIfNotPresent(side_packets, observe_ts_bounds_packet_name, + MakePacket(true)); + } const std::string input_side_packet_name = GetUnusedSidePacketName(*config, absl::StrCat(name, "_callback")); - side_packet->first = input_side_packet_name; sink_node->add_input_side_packet( absl::StrCat("VECTOR_CALLBACK:", input_side_packet_name)); - side_packet->second = + InsertIfNotPresent( + side_packets, input_side_packet_name, MakePacket&)>>( - std::move(callback)); + std::move(callback))); } void AddCallbackWithHeaderCalculator(const std::string& stream_name, @@ -240,6 +260,10 @@ absl::Status CallbackCalculator::GetContract(CalculatorContract* cc) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "InputSidePackets must use tags."; } + if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS")) { + cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Set(); + cc->SetProcessTimestampBounds(true); + } int count = allow_multiple_streams ? cc->Inputs().NumEntries("") : 1; for (int i = 0; i < count; ++i) { @@ -266,6 +290,12 @@ absl::Status CallbackCalculator::Open(CalculatorContext* cc) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "missing callback."; } + if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS") && + !cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Get()) { + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "The value of the OBSERVE_TIMESTAMP_BOUNDS input side packet " + "must be set to true"; + } return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/sink.h b/mediapipe/framework/tool/sink.h index 8f09269fc..d659115ee 100644 --- a/mediapipe/framework/tool/sink.h +++ b/mediapipe/framework/tool/sink.h @@ -115,6 +115,12 @@ void AddMultiStreamCallback( std::function&)> callback, CalculatorGraphConfig* config, std::pair* side_packet); +void AddMultiStreamCallback( + const std::vector& streams, + std::function&)> callback, + CalculatorGraphConfig* config, std::map* side_packets, + bool observe_timestamp_bounds = false); + // Add a CallbackWithHeaderCalculator to intercept packets sent on // stream stream_name, and the header packet on stream stream_header. 
// The input side packet with the produced name callback_side_packet_name diff --git a/mediapipe/framework/tool/sink_test.cc b/mediapipe/framework/tool/sink_test.cc index 7a3236f49..2b5f94f9f 100644 --- a/mediapipe/framework/tool/sink_test.cc +++ b/mediapipe/framework/tool/sink_test.cc @@ -146,5 +146,63 @@ TEST(CallbackTest, TestAddMultiStreamCallback) { EXPECT_THAT(sums, testing::ElementsAre(15, 7, 9)); } +class TimestampBoundTestCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + cc->Outputs().Index(1).Set(); + return absl::OkStatus(); + } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } + absl::Status Process(CalculatorContext* cc) final { + if (count_ % 5 == 0) { + cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_ + 1)); + cc->Outputs().Index(1).SetNextTimestampBound(Timestamp(count_ + 1)); + } + ++count_; + if (count_ == 13) { + return tool::StatusStop(); + } + return absl::OkStatus(); + } + + private: + int count_ = 1; +}; +REGISTER_CALCULATOR(TimestampBoundTestCalculator); + +TEST(CallbackTest, TestAddMultiStreamCallbackWithTimestampNotification) { + std::string config_str = R"( + node { + calculator: "TimestampBoundTestCalculator" + output_stream: "foo" + output_stream: "bar" + } + )"; + CalculatorGraphConfig graph_config = + mediapipe::ParseTextProtoOrDie(config_str); + + std::vector sums; + + std::map side_packets; + tool::AddMultiStreamCallback( + {"foo", "bar"}, + [&sums](const std::vector& packets) { + Packet foo_p = packets[0]; + Packet bar_p = packets[1]; + ASSERT_TRUE(foo_p.IsEmpty() && bar_p.IsEmpty()); + int foo = foo_p.Timestamp().Value(); + int bar = bar_p.Timestamp().Value(); + sums.push_back(foo + bar); + }, + &graph_config, &side_packets, true); + + CalculatorGraph graph(graph_config); + MP_ASSERT_OK(graph.StartRun(side_packets)); + MP_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_THAT(sums, testing::ElementsAre(10, 20)); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index e09f85407..1da8115a8 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -14,7 +14,7 @@ load("//mediapipe/gpu:metal.bzl", "metal_library") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library") load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") licenses(["notice"]) @@ -240,6 +240,12 @@ cc_library( ], ) +mediapipe_proto_library( + name = "gpu_origin_proto", + srcs = ["gpu_origin.proto"], + visibility = ["//visibility:public"], +) + objc_library( name = "pixel_buffer_pool_util", srcs = ["pixel_buffer_pool_util.mm"], @@ -460,6 +466,8 @@ cc_library( "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", + "//mediapipe/util:resource_cache", + "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ] + select({ @@ -760,8 +768,10 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", + ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", diff --git 
a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 05a18bf4c..0c6865a86 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -563,8 +563,14 @@ class GlFenceSyncPoint : public GlSyncPoint { void WaitOnGpu() override { if (!sync_) return; - // TODO: do not wait if we are already on the same context? + // TODO: do not wait if we are already on the same context? + // WebGL2 specifies a waitSync call, but since cross-context + // synchronization is not supported, it's actually a no-op. Firefox prints + // a warning when it's called, so let's just skip the call. See + // b/184637485 for details. +#ifndef __EMSCRIPTEN__ glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED); +#endif } bool IsReady() override { diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc index ca5ce0cee..31500ed9a 100644 --- a/mediapipe/gpu/gl_surface_sink_calculator.cc +++ b/mediapipe/gpu/gl_surface_sink_calculator.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "absl/synchronization/mutex.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -21,9 +22,11 @@ #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_quad_renderer.h" #include "mediapipe/gpu/gl_surface_sink_calculator.pb.h" +#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" namespace mediapipe { +namespace api2 { enum { kAttribVertex, kAttribTexturePosition, kNumberOfAttributes }; @@ -37,45 +40,52 @@ enum { kAttribVertex, kAttribTexturePosition, kNumberOfAttributes }; // GPU_SHARED: shared GPU resources. // // See GlSurfaceSinkCalculatorOptions for options. 
-class GlSurfaceSinkCalculator : public CalculatorBase { +class GlSurfaceSinkCalculator : public Node { public: - GlSurfaceSinkCalculator() : initialized_(false) {} - ~GlSurfaceSinkCalculator() override; + static constexpr Input< + OneOf>::Optional kInVideo{ + "VIDEO"}; + static constexpr Input< + OneOf>::Optional kIn{""}; + static constexpr SideInput> + kSurface{"SURFACE"}; - static absl::Status GetContract(CalculatorContract* cc); + MEDIAPIPE_NODE_INTERFACE(GlSurfaceSinkCalculator, kInVideo, kIn, kSurface); - absl::Status Open(CalculatorContext* cc) override; - absl::Status Process(CalculatorContext* cc) override; + ~GlSurfaceSinkCalculator(); + + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; private: - GlCalculatorHelper helper_; - EglSurfaceHolder* surface_holder_; - bool initialized_; - std::unique_ptr renderer_; - FrameScaleMode scale_mode_ = FrameScaleMode::kFillAndCrop; + mediapipe::GlCalculatorHelper helper_; + mediapipe::EglSurfaceHolder* surface_holder_; + bool initialized_ = false; + std::unique_ptr renderer_; + mediapipe::FrameScaleMode scale_mode_ = + mediapipe::FrameScaleMode::kFillAndCrop; }; -REGISTER_CALCULATOR(GlSurfaceSinkCalculator); +MEDIAPIPE_REGISTER_NODE(GlSurfaceSinkCalculator); // static -absl::Status GlSurfaceSinkCalculator::GetContract(CalculatorContract* cc) { - TagOrIndex(&(cc->Inputs()), "VIDEO", 0).Set(); - cc->InputSidePackets() - .Tag("SURFACE") - .Set>(); +absl::Status GlSurfaceSinkCalculator::UpdateContract(CalculatorContract* cc) { + RET_CHECK(kInVideo(cc).IsConnected() ^ kIn(cc).IsConnected()) + << "Only one of VIDEO or index 0 input is expected."; + // Currently we pass GL context information and other stuff as external // inputs, which are handled by the helper. - return GlCalculatorHelper::UpdateContract(cc); + return mediapipe::GlCalculatorHelper::UpdateContract(cc); } absl::Status GlSurfaceSinkCalculator::Open(CalculatorContext* cc) { - surface_holder_ = cc->InputSidePackets() - .Tag("SURFACE") - .Get>() - .get(); + surface_holder_ = kSurface(cc).Get().get(); scale_mode_ = FrameScaleModeFromProto( - cc->Options().frame_scale_mode(), - FrameScaleMode::kFillAndCrop); + cc->Options() + .frame_scale_mode(), + mediapipe::FrameScaleMode::kFillAndCrop); // Let the helper access the GL context information. 
return helper_.Open(cc); @@ -90,9 +100,20 @@ absl::Status GlSurfaceSinkCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } - const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get(); + mediapipe::Packet packet; + if (kInVideo(cc).IsConnected()) + packet = kInVideo(cc).packet(); + else + packet = kIn(cc).packet(); + + mediapipe::GpuBuffer input; + if (packet.ValidateAsType().ok()) + input = packet.Get(); + if (packet.ValidateAsType().ok()) + input = packet.Get().GetGpuBuffer(); + if (!initialized_) { - renderer_ = absl::make_unique(); + renderer_ = absl::make_unique(); MP_RETURN_IF_ERROR(renderer_->GlSetup()); initialized_ = true; } @@ -125,7 +146,7 @@ absl::Status GlSurfaceSinkCalculator::Process(CalculatorContext* cc) { MP_RETURN_IF_ERROR( renderer_->GlRender(src.width(), src.height(), dst_width, dst_height, - scale_mode_, FrameRotation::kNone, + scale_mode_, mediapipe::FrameRotation::kNone, /*flip_horizontal=*/false, /*flip_vertical=*/false, /*flip_texture=*/surface_holder_->flip_y)); @@ -145,7 +166,7 @@ absl::Status GlSurfaceSinkCalculator::Process(CalculatorContext* cc) { GlSurfaceSinkCalculator::~GlSurfaceSinkCalculator() { if (renderer_) { // TODO: use move capture when we have C++14 or better. - QuadRenderer* renderer = renderer_.release(); + mediapipe::QuadRenderer* renderer = renderer_.release(); helper_.RunInGlContext([renderer] { renderer->GlTeardown(); delete renderer; @@ -153,4 +174,5 @@ GlSurfaceSinkCalculator::~GlSurfaceSinkCalculator() { } } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index f6ee59d73..716a3b779 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -157,125 +157,19 @@ GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -void GpuBufferMultiPool::EntryList::Prepend(Entry* entry) { - if (head_ == nullptr) { - head_ = tail_ = entry; - } else { - entry->next = head_; - head_->prev = entry; - head_ = entry; - } - ++size_; -} - -void GpuBufferMultiPool::EntryList::Append(Entry* entry) { - if (tail_ == nullptr) { - head_ = tail_ = entry; - } else { - tail_->next = entry; - entry->prev = tail_; - tail_ = entry; - } - ++size_; -} - -void GpuBufferMultiPool::EntryList::Remove(Entry* entry) { - if (entry == head_) { - head_ = entry->next; - } else { - entry->prev->next = entry->next; - } - if (entry == tail_) { - tail_ = entry->prev; - } else { - entry->next->prev = entry->prev; - } - entry->prev = nullptr; - entry->next = nullptr; - --size_; -} - -void GpuBufferMultiPool::EntryList::InsertAfter(Entry* entry, Entry* after) { - if (after != nullptr) { - entry->next = after->next; - if (entry->next) entry->next->prev = entry; - entry->prev = after; - after->next = entry; - ++size_; - } else - Prepend(entry); -} - -void GpuBufferMultiPool::Evict(std::vector* evicted) { - // Remove excess entries. - while (entry_list_.size() > kMaxPoolCount) { - Entry* victim = entry_list_.tail(); - evicted->emplace_back(std::move(victim->pool)); - entry_list_.Remove(victim); - pools_.erase(victim->spec); - } - // Every kRequestCountScrubInterval requests, halve the request counts, and - // remove entries which have fallen to 0. - // This keeps sporadic requests from accumulating and eventually exceeding - // the minimum request threshold for allocating a pool. Also, it means that - // if the request regimen changes (e.g. 
a graph was always requesting a large - // size, but then switches to a small size to save memory or CPU), the pool - // can quickly adapt to it. - if (total_request_count_ >= kRequestCountScrubInterval) { - total_request_count_ = 0; - VLOG(2) << "begin pool scrub"; - for (Entry* entry = entry_list_.head(); entry != nullptr;) { - VLOG(2) << "entry for: " << entry->spec.width << "x" << entry->spec.height - << " request_count: " << entry->request_count - << " has pool: " << (entry->pool != nullptr); - entry->request_count /= 2; - Entry* next = entry->next; - if (entry->request_count == 0) { - evicted->emplace_back(std::move(entry->pool)); - entry_list_.Remove(entry); - pools_.erase(entry->spec); - } - entry = next; - } - } -} - GpuBufferMultiPool::SimplePool GpuBufferMultiPool::RequestPool( - const BufferSpec& key) { + const BufferSpec& spec) { SimplePool pool; std::vector evicted; { absl::MutexLock lock(&mutex_); - auto pool_it = pools_.find(key); - Entry* entry; - if (pool_it == pools_.end()) { - std::tie(pool_it, std::ignore) = - pools_.emplace(std::piecewise_construct, std::forward_as_tuple(key), - std::forward_as_tuple(key)); - entry = &pool_it->second; - CHECK_EQ(entry->request_count, 0); - entry->request_count = 1; - entry_list_.Append(entry); - if (entry->prev != nullptr) CHECK_GE(entry->prev->request_count, 1); - } else { - entry = &pool_it->second; - ++entry->request_count; - Entry* larger = entry->prev; - while (larger != nullptr && - larger->request_count < entry->request_count) { - larger = larger->prev; - } - if (larger != entry->prev) { - entry_list_.Remove(entry); - entry_list_.InsertAfter(entry, larger); - } - } - if (!entry->pool && entry->request_count >= kMinRequestsBeforePool) { - entry->pool = MakeSimplePool(key); - } - pool = entry->pool; - ++total_request_count_; - Evict(&evicted); + pool = + cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { + return (request_count >= kMinRequestsBeforePool) + ? MakeSimplePool(spec) + : nullptr; + }); + evicted = cache_.Evict(kMaxPoolCount, kRequestCountScrubInterval); } // Evicted pools, and their buffers, will be released without holding the // lock. diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 0a34a3017..5ea6e314f 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -22,12 +22,10 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ -#include -#include -#include - +#include "absl/hash/hash.h" #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/util/resource_cache.h" #ifdef __APPLE__ #include "mediapipe/gpu/pixel_buffer_pool_util.h" @@ -65,31 +63,24 @@ class GpuBufferMultiPool { void FlushTextureCaches(); #endif // defined(__APPLE__) - // This generates a "rol" instruction with both Clang and GCC. - inline static std::size_t RotateLeft(std::size_t x, int n) { - return (x << n) | (x >> (std::numeric_limits::digits - n)); - } - + // This class is not intended as part of the public api of this class. It is + // public only because it is used as a map key type, and the map + // implementation needs access to, e.g., the equality operator. 
struct BufferSpec { BufferSpec(int w, int h, mediapipe::GpuBufferFormat f) : width(w), height(h), format(f) {} + + template + friend H AbslHashValue(H h, const BufferSpec& spec) { + return H::combine(std::move(h), spec.width, spec.height, + static_cast(spec.format)); + } + int width; int height; mediapipe::GpuBufferFormat format; }; - struct BufferSpecHash { - std::size_t operator()(const BufferSpec& spec) const { - // Width and height are expected to be smaller than half the width of - // size_t. We can combine them into a single integer, and then use - // std::hash. - constexpr int kWidth = std::numeric_limits::digits; - return std::hash{}( - spec.width ^ RotateLeft(spec.height, kWidth / 2) ^ - RotateLeft(static_cast(spec.format), kWidth / 4)); - } - }; - private: #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER using SimplePool = std::shared_ptr; @@ -97,53 +88,18 @@ class GpuBufferMultiPool { using SimplePool = std::shared_ptr; #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - struct Entry { - Entry(const BufferSpec& spec) : spec(spec) {} - Entry* prev = nullptr; - Entry* next = nullptr; - BufferSpec spec; - int request_count = 0; - SimplePool pool; - }; - - // Unlike std::list, this is an intrusive list, meaning that the prev and next - // pointers live inside the element. Apart from not requiring an extra - // allocation, this means that once we look up an entry by key in the pools_ - // map we do not need to look it up separately in the list. - // - class EntryList { - public: - void Prepend(Entry* entry); - void Append(Entry* entry); - void Remove(Entry* entry); - void InsertAfter(Entry* entry, Entry* after); - - Entry* head() { return head_; } - Entry* tail() { return tail_; } - size_t size() { return size_; } - - private: - Entry* head_ = nullptr; - Entry* tail_ = nullptr; - size_t size_ = 0; - }; - SimplePool MakeSimplePool(const BufferSpec& spec); // Requests a simple buffer pool for the given spec. This may return nullptr // if we have not yet reached a sufficient number of requests to allocate a // pool, in which case the caller should invoke GetBufferWithoutPool instead // of GetBufferFromSimplePool. - SimplePool RequestPool(const BufferSpec& key); + SimplePool RequestPool(const BufferSpec& spec); GpuBuffer GetBufferFromSimplePool(BufferSpec spec, const SimplePool& pool); GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); - void Evict(std::vector* evicted) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); absl::Mutex mutex_; - std::unordered_map pools_ - ABSL_GUARDED_BY(mutex_); - EntryList entry_list_ ABSL_GUARDED_BY(mutex_); - int total_request_count_ = 0; + mediapipe::ResourceCache> + cache_ ABSL_GUARDED_BY(mutex_); #ifdef __APPLE__ // Texture caches used with this pool. diff --git a/mediapipe/examples/ios/upperbodyposetrackinggpu/UpperBodyPoseTrackingViewController.h b/mediapipe/gpu/gpu_origin.proto similarity index 64% rename from mediapipe/examples/ios/upperbodyposetrackinggpu/UpperBodyPoseTrackingViewController.h rename to mediapipe/gpu/gpu_origin.proto index e9dea0d1d..f4db83537 100644 --- a/mediapipe/examples/ios/upperbodyposetrackinggpu/UpperBodyPoseTrackingViewController.h +++ b/mediapipe/gpu/gpu_origin.proto @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,10 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#import +syntax = "proto2"; -#import "mediapipe/examples/ios/common/CommonViewController.h" +package mediapipe; -@interface UpperBodyPoseTrackingViewController : CommonViewController +message GpuOrigin { + enum Mode { + DEFAULT = 0; -@end + // OpenGL: bottom-left origin + // Metal : top-left origin + CONVENTIONAL = 1; + + // OpenGL: top-left origin + // Metal : top-left origin + TOP_LEFT = 2; + } +} diff --git a/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt b/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt index 943bf1767..e3c572e28 100644 --- a/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt +++ b/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt @@ -14,7 +14,7 @@ node: { output_stream: "luma_video" } -# Applies the Sobel filter to luminance images sotred in RGB format. +# Applies the Sobel filter to luminance images stored in RGB format. node: { calculator: "SobelEdgesCalculator" input_stream: "luma_video" diff --git a/mediapipe/graphs/face_detection/face_detection_back_desktop_live.pbtxt b/mediapipe/graphs/face_detection/face_detection_back_desktop_live.pbtxt index daccc2782..a70e4c134 100644 --- a/mediapipe/graphs/face_detection/face_detection_back_desktop_live.pbtxt +++ b/mediapipe/graphs/face_detection/face_detection_back_desktop_live.pbtxt @@ -64,7 +64,7 @@ node { output_stream: "TENSORS:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "mediapipe/models/face_detection_back.tflite" + model_path: "mediapipe/modules/face_detection/face_detection_back.tflite" } } } @@ -137,26 +137,13 @@ node { } } -# Maps detection label IDs to the corresponding label text ("Face"). The label -# map is provided in the label_map_path option. -node { - calculator: "DetectionLabelIdToTextCalculator" - input_stream: "filtered_detections" - output_stream: "labeled_detections" - node_options: { - [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "mediapipe/models/face_detection_back_labelmap.txt" - } - } -} - # Adjusts detection locations (already normalized to [0.f, 1.f]) on the # letterboxed image (after image transformation with the FIT scale mode) to the # corresponding locations on the same image with the letterbox removed (the # input image to the graph before image transformation). node { calculator: "DetectionLetterboxRemovalCalculator" - input_stream: "DETECTIONS:labeled_detections" + input_stream: "DETECTIONS:filtered_detections" input_stream: "LETTERBOX_PADDING:letterbox_padding" output_stream: "DETECTIONS:output_detections" } diff --git a/mediapipe/graphs/face_detection/face_detection_back_mobile_gpu.pbtxt b/mediapipe/graphs/face_detection/face_detection_back_mobile_gpu.pbtxt index 669b4b98b..893434190 100644 --- a/mediapipe/graphs/face_detection/face_detection_back_mobile_gpu.pbtxt +++ b/mediapipe/graphs/face_detection/face_detection_back_mobile_gpu.pbtxt @@ -65,7 +65,7 @@ node { output_stream: "TENSORS_GPU:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "mediapipe/models/face_detection_back.tflite" + model_path: "mediapipe/modules/face_detection/face_detection_back.tflite" } } } @@ -138,26 +138,13 @@ node { } } -# Maps detection label IDs to the corresponding label text ("Face"). The label -# map is provided in the label_map_path option. 
-node { - calculator: "DetectionLabelIdToTextCalculator" - input_stream: "filtered_detections" - output_stream: "labeled_detections" - node_options: { - [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "mediapipe/models/face_detection_back_labelmap.txt" - } - } -} - # Adjusts detection locations (already normalized to [0.f, 1.f]) on the # letterboxed image (after image transformation with the FIT scale mode) to the # corresponding locations on the same image with the letterbox removed (the # input image to the graph before image transformation). node { calculator: "DetectionLetterboxRemovalCalculator" - input_stream: "DETECTIONS:labeled_detections" + input_stream: "DETECTIONS:filtered_detections" input_stream: "LETTERBOX_PADDING:letterbox_padding" output_stream: "DETECTIONS:output_detections" } diff --git a/mediapipe/graphs/holistic_tracking/BUILD b/mediapipe/graphs/holistic_tracking/BUILD index 31dc72179..4d5a69439 100644 --- a/mediapipe/graphs/holistic_tracking/BUILD +++ b/mediapipe/graphs/holistic_tracking/BUILD @@ -36,7 +36,6 @@ mediapipe_simple_subgraph( "//mediapipe/calculators/util:landmarks_to_render_data_calculator", "//mediapipe/calculators/util:rect_to_render_data_calculator", "//mediapipe/calculators/util:rect_to_render_scale_calculator", - "//mediapipe/framework/tool:switch_container", "//mediapipe/modules/holistic_landmark:hand_wrist_for_pose", ], ) diff --git a/mediapipe/graphs/holistic_tracking/holistic_tracking_cpu.pbtxt b/mediapipe/graphs/holistic_tracking/holistic_tracking_cpu.pbtxt index 65957ed61..088bf3e9c 100644 --- a/mediapipe/graphs/holistic_tracking/holistic_tracking_cpu.pbtxt +++ b/mediapipe/graphs/holistic_tracking/holistic_tracking_cpu.pbtxt @@ -38,11 +38,11 @@ node { node { calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:0:upper_body_only" + output_side_packet: "PACKET:0:model_complexity" output_side_packet: "PACKET:1:smooth_landmarks" node_options: { [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { - packet { bool_value: false } + packet { int_value: 1 } packet { bool_value: true } } } @@ -51,7 +51,7 @@ node { node { calculator: "HolisticLandmarkCpu" input_stream: "IMAGE:throttled_input_video" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" output_stream: "POSE_LANDMARKS:pose_landmarks" output_stream: "POSE_ROI:pose_roi" @@ -77,7 +77,6 @@ node { input_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks" input_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks" input_stream: "FACE_LANDMARKS:face_landmarks" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" output_stream: "RENDER_DATA_VECTOR:render_data_vector" } diff --git a/mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt b/mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt index 13bf28e51..a4e2da01e 100644 --- a/mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt +++ b/mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt @@ -38,11 +38,11 @@ node { node { calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:0:upper_body_only" + output_side_packet: "PACKET:0:model_complexity" output_side_packet: "PACKET:1:smooth_landmarks" node_options: { [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { - packet { bool_value: false } + packet { int_value: 1 } packet { bool_value: true } } } @@ -51,7 +51,7 @@ node { node { calculator: 
"HolisticLandmarkGpu" input_stream: "IMAGE:throttled_input_video" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" output_stream: "POSE_LANDMARKS:pose_landmarks" output_stream: "POSE_ROI:pose_roi" @@ -77,7 +77,6 @@ node { input_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks" input_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks" input_stream: "FACE_LANDMARKS:face_landmarks" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" output_stream: "RENDER_DATA_VECTOR:render_data_vector" } diff --git a/mediapipe/graphs/holistic_tracking/holistic_tracking_to_render_data.pbtxt b/mediapipe/graphs/holistic_tracking/holistic_tracking_to_render_data.pbtxt index 7a326b46f..4b05123b9 100644 --- a/mediapipe/graphs/holistic_tracking/holistic_tracking_to_render_data.pbtxt +++ b/mediapipe/graphs/holistic_tracking/holistic_tracking_to_render_data.pbtxt @@ -15,10 +15,6 @@ input_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks" # Face landmarks. (NormalizedLandmarkList) input_stream: "FACE_LANDMARKS:face_landmarks" -# Whether to render the full set of pose landmarks, or only those on the -# upper body. If unspecified, functions as set to false. (bool) -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" - # Render data vector. (std::vector) output_stream: "RENDER_DATA_VECTOR:render_data_vector" @@ -81,28 +77,12 @@ node { # Gets pose landmarks after wrists. node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "SplitNormalizedLandmarkListCalculator" input_stream: "landmarks" output_stream: "landmarks_after_wrist" node_options: { - [type.googleapis.com/mediapipe.SwitchContainerOptions] { - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 23 end: 33 } - } - } - } - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 23 end: 25 } - } - } - } + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 23 end: 33 } } } } @@ -156,80 +136,40 @@ node { # Takes left pose landmarks. 
node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "SplitNormalizedLandmarkListCalculator" input_stream: "landmarks_merged" output_stream: "landmarks_left_side" node_options: { - [type.googleapis.com/mediapipe.SwitchContainerOptions] { - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 0 end: 1 } - ranges: { begin: 2 end: 3 } - ranges: { begin: 4 end: 5 } - ranges: { begin: 6 end: 7 } - ranges: { begin: 8 end: 9 } - ranges: { begin: 10 end: 11 } - ranges: { begin: 12 end: 13 } - ranges: { begin: 14 end: 15 } - combine_outputs: true - } - } - } - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 0 end: 1 } - ranges: { begin: 2 end: 3 } - ranges: { begin: 4 end: 5 } - ranges: { begin: 6 end: 7 } - combine_outputs: true - } - } - } + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 2 end: 3 } + ranges: { begin: 4 end: 5 } + ranges: { begin: 6 end: 7 } + ranges: { begin: 8 end: 9 } + ranges: { begin: 10 end: 11 } + ranges: { begin: 12 end: 13 } + ranges: { begin: 14 end: 15 } + combine_outputs: true } } } # Takes right pose landmarks. node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "SplitNormalizedLandmarkListCalculator" input_stream: "landmarks_merged" output_stream: "landmarks_right_side" node_options: { - [type.googleapis.com/mediapipe.SwitchContainerOptions] { - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 1 end: 2 } - ranges: { begin: 3 end: 4 } - ranges: { begin: 5 end: 6 } - ranges: { begin: 7 end: 8 } - ranges: { begin: 9 end: 10 } - ranges: { begin: 11 end: 12 } - ranges: { begin: 13 end: 14 } - ranges: { begin: 15 end: 16 } - combine_outputs: true - } - } - } - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 1 end: 2 } - ranges: { begin: 3 end: 4 } - ranges: { begin: 5 end: 6 } - ranges: { begin: 7 end: 8 } - combine_outputs: true - } - } - } + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 1 end: 2 } + ranges: { begin: 3 end: 4 } + ranges: { begin: 5 end: 6 } + ranges: { begin: 7 end: 8 } + ranges: { begin: 9 end: 10 } + ranges: { begin: 11 end: 12 } + ranges: { begin: 13 end: 14 } + ranges: { begin: 15 end: 16 } + combine_outputs: true } } } @@ -240,92 +180,55 @@ node { # Converts pose connections to white lines. 
node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "LandmarksToRenderDataCalculator" input_stream: "NORM_LANDMARKS:landmarks_merged" input_stream: "RENDER_SCALE:render_scale" output_stream: "RENDER_DATA:landmarks_render_data" node_options: { - [type.googleapis.com/mediapipe.SwitchContainerOptions] { - contained_node: { - calculator: "LandmarksToRenderDataCalculator" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_connections: 0 - landmark_connections: 1 - landmark_connections: 0 - landmark_connections: 2 - landmark_connections: 2 - landmark_connections: 4 - landmark_connections: 1 - landmark_connections: 3 - landmark_connections: 3 - landmark_connections: 5 - landmark_connections: 0 - landmark_connections: 6 - landmark_connections: 1 - landmark_connections: 7 - landmark_connections: 6 - landmark_connections: 7 - landmark_connections: 6 - landmark_connections: 8 - landmark_connections: 7 - landmark_connections: 9 - landmark_connections: 8 - landmark_connections: 10 - landmark_connections: 9 - landmark_connections: 11 - landmark_connections: 10 - landmark_connections: 12 - landmark_connections: 11 - landmark_connections: 13 - landmark_connections: 12 - landmark_connections: 14 - landmark_connections: 13 - landmark_connections: 15 - landmark_connections: 10 - landmark_connections: 14 - landmark_connections: 11 - landmark_connections: 15 + [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { + landmark_connections: 0 + landmark_connections: 1 + landmark_connections: 0 + landmark_connections: 2 + landmark_connections: 2 + landmark_connections: 4 + landmark_connections: 1 + landmark_connections: 3 + landmark_connections: 3 + landmark_connections: 5 + landmark_connections: 0 + landmark_connections: 6 + landmark_connections: 1 + landmark_connections: 7 + landmark_connections: 6 + landmark_connections: 7 + landmark_connections: 6 + landmark_connections: 8 + landmark_connections: 7 + landmark_connections: 9 + landmark_connections: 8 + landmark_connections: 10 + landmark_connections: 9 + landmark_connections: 11 + landmark_connections: 10 + landmark_connections: 12 + landmark_connections: 11 + landmark_connections: 13 + landmark_connections: 12 + landmark_connections: 14 + landmark_connections: 13 + landmark_connections: 15 + landmark_connections: 10 + landmark_connections: 14 + landmark_connections: 11 + landmark_connections: 15 - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.1 - } - } - } - contained_node: { - calculator: "LandmarksToRenderDataCalculator" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_connections: 0 - landmark_connections: 1 - landmark_connections: 0 - landmark_connections: 2 - landmark_connections: 2 - landmark_connections: 4 - landmark_connections: 1 - landmark_connections: 3 - landmark_connections: 3 - landmark_connections: 5 - landmark_connections: 0 - landmark_connections: 6 - landmark_connections: 1 - landmark_connections: 7 - landmark_connections: 6 - landmark_connections: 7 - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } - } + landmark_color { r: 255 g: 255 b: 255 } + connection_color { r: 255 g: 255 b: 255 
} + thickness: 3.0 + visualize_landmark_depth: false + utilize_visibility: true + visibility_threshold: 0.1 } } } diff --git a/mediapipe/graphs/pose_tracking/BUILD b/mediapipe/graphs/pose_tracking/BUILD index 54af332ca..53d5ef5e2 100644 --- a/mediapipe/graphs/pose_tracking/BUILD +++ b/mediapipe/graphs/pose_tracking/BUILD @@ -24,6 +24,7 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "pose_tracking_gpu_deps", deps = [ + "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/util:landmarks_smoothing_calculator", @@ -42,6 +43,7 @@ mediapipe_binary_graph( cc_library( name = "pose_tracking_cpu_deps", deps = [ + "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/util:landmarks_smoothing_calculator", @@ -56,39 +58,3 @@ mediapipe_binary_graph( output_name = "pose_tracking_cpu.binarypb", deps = [":pose_tracking_cpu_deps"], ) - -cc_library( - name = "upper_body_pose_tracking_gpu_deps", - deps = [ - "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/util:landmarks_smoothing_calculator", - "//mediapipe/graphs/pose_tracking/subgraphs:upper_body_pose_renderer_gpu", - "//mediapipe/modules/pose_landmark:pose_landmark_upper_body_gpu", - ], -) - -mediapipe_binary_graph( - name = "upper_body_pose_tracking_gpu_binary_graph", - graph = "upper_body_pose_tracking_gpu.pbtxt", - output_name = "upper_body_pose_tracking_gpu.binarypb", - deps = [":upper_body_pose_tracking_gpu_deps"], -) - -cc_library( - name = "upper_body_pose_tracking_cpu_deps", - deps = [ - "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/util:landmarks_smoothing_calculator", - "//mediapipe/graphs/pose_tracking/subgraphs:upper_body_pose_renderer_cpu", - "//mediapipe/modules/pose_landmark:pose_landmark_upper_body_cpu", - ], -) - -mediapipe_binary_graph( - name = "upper_body_pose_tracking_cpu_binary_graph", - graph = "upper_body_pose_tracking_cpu.pbtxt", - output_name = "upper_body_pose_tracking_cpu.binarypb", - deps = [":upper_body_pose_tracking_cpu_deps"], -) diff --git a/mediapipe/graphs/pose_tracking/pose_tracking_cpu.pbtxt b/mediapipe/graphs/pose_tracking/pose_tracking_cpu.pbtxt index 441fc67a6..380e9e04c 100644 --- a/mediapipe/graphs/pose_tracking/pose_tracking_cpu.pbtxt +++ b/mediapipe/graphs/pose_tracking/pose_tracking_cpu.pbtxt @@ -29,9 +29,20 @@ node { output_stream: "throttled_input_video" } +node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:model_complexity" + node_options: { + [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { + packet { int_value: 1 } + } + } +} + # Subgraph that detects poses and corresponding landmarks. 
node { calculator: "PoseLandmarkCpu" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_stream: "IMAGE:throttled_input_video" output_stream: "LANDMARKS:pose_landmarks" output_stream: "DETECTION:pose_detection" diff --git a/mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt b/mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt index d2712e16d..c47e76944 100644 --- a/mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt +++ b/mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt @@ -29,9 +29,20 @@ node { output_stream: "throttled_input_video" } +node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:model_complexity" + node_options: { + [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { + packet { int_value: 1 } + } + } +} + # Subgraph that detects poses and corresponding landmarks. node { calculator: "PoseLandmarkGpu" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_stream: "IMAGE:throttled_input_video" output_stream: "LANDMARKS:pose_landmarks" output_stream: "DETECTION:pose_detection" diff --git a/mediapipe/graphs/pose_tracking/subgraphs/BUILD b/mediapipe/graphs/pose_tracking/subgraphs/BUILD index 3a1825704..bb089feb8 100644 --- a/mediapipe/graphs/pose_tracking/subgraphs/BUILD +++ b/mediapipe/graphs/pose_tracking/subgraphs/BUILD @@ -48,31 +48,3 @@ mediapipe_simple_subgraph( "//mediapipe/calculators/util:rect_to_render_scale_calculator", ], ) - -mediapipe_simple_subgraph( - name = "upper_body_pose_renderer_gpu", - graph = "upper_body_pose_renderer_gpu.pbtxt", - register_as = "UpperBodyPoseRendererGpu", - deps = [ - "//mediapipe/calculators/core:split_normalized_landmark_list_calculator", - "//mediapipe/calculators/util:annotation_overlay_calculator", - "//mediapipe/calculators/util:detections_to_render_data_calculator", - "//mediapipe/calculators/util:landmarks_to_render_data_calculator", - "//mediapipe/calculators/util:rect_to_render_data_calculator", - "//mediapipe/calculators/util:rect_to_render_scale_calculator", - ], -) - -mediapipe_simple_subgraph( - name = "upper_body_pose_renderer_cpu", - graph = "upper_body_pose_renderer_cpu.pbtxt", - register_as = "UpperBodyPoseRendererCpu", - deps = [ - "//mediapipe/calculators/core:split_normalized_landmark_list_calculator", - "//mediapipe/calculators/util:annotation_overlay_calculator", - "//mediapipe/calculators/util:detections_to_render_data_calculator", - "//mediapipe/calculators/util:landmarks_to_render_data_calculator", - "//mediapipe/calculators/util:rect_to_render_data_calculator", - "//mediapipe/calculators/util:rect_to_render_scale_calculator", - ], -) diff --git a/mediapipe/graphs/pose_tracking/subgraphs/upper_body_pose_renderer_cpu.pbtxt b/mediapipe/graphs/pose_tracking/subgraphs/upper_body_pose_renderer_cpu.pbtxt deleted file mode 100644 index 6dcef9566..000000000 --- a/mediapipe/graphs/pose_tracking/subgraphs/upper_body_pose_renderer_cpu.pbtxt +++ /dev/null @@ -1,254 +0,0 @@ -# MediaPipe pose landmarks rendering subgraph. - -type: "UpperBodyPoseRendererCpu" - -# CPU image. (ImageFrame) -input_stream: "IMAGE:input_image" -# Pose landmarks. (NormalizedLandmarkList) -input_stream: "LANDMARKS:pose_landmarks" -# Region of interest calculated based on landmarks. (NormalizedRect) -input_stream: "ROI:roi" -# Detected pose. (Detection) -input_stream: "DETECTION:detection" - -# CPU image with rendered data. 
(ImageFrame) -output_stream: "IMAGE:output_image" - -node { - calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE:input_image" - output_stream: "SIZE:image_size" -} - -# Calculates rendering scale based on the pose roi. -node { - calculator: "RectToRenderScaleCalculator" - input_stream: "NORM_RECT:roi" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "RENDER_SCALE:render_scale" - node_options: { - [type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] { - multiplier: 0.0012 - } - } -} - -# Converts detections to drawing primitives for annotation overlay. -node { - calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION:detection" - output_stream: "RENDER_DATA:detection_render_data" - node_options: { - [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { - thickness: 4.0 - color { r: 0 g: 255 b: 0 } - } - } -} - -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "visible_pose_landmarks" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 0 end: 25 } - } - } -} - -# Converts landmarks to drawing primitives for annotation overlay. -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:pose_landmarks" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_connections: 0 - landmark_connections: 1 - landmark_connections: 1 - landmark_connections: 2 - landmark_connections: 2 - landmark_connections: 3 - landmark_connections: 3 - landmark_connections: 7 - landmark_connections: 0 - landmark_connections: 4 - landmark_connections: 4 - landmark_connections: 5 - landmark_connections: 5 - landmark_connections: 6 - landmark_connections: 6 - landmark_connections: 8 - landmark_connections: 9 - landmark_connections: 10 - landmark_connections: 11 - landmark_connections: 12 - landmark_connections: 11 - landmark_connections: 13 - landmark_connections: 13 - landmark_connections: 15 - landmark_connections: 15 - landmark_connections: 17 - landmark_connections: 15 - landmark_connections: 19 - landmark_connections: 15 - landmark_connections: 21 - landmark_connections: 17 - landmark_connections: 19 - landmark_connections: 12 - landmark_connections: 14 - landmark_connections: 14 - landmark_connections: 16 - landmark_connections: 16 - landmark_connections: 18 - landmark_connections: 16 - landmark_connections: 20 - landmark_connections: 16 - landmark_connections: 22 - landmark_connections: 18 - landmark_connections: 20 - landmark_connections: 11 - landmark_connections: 23 - landmark_connections: 12 - landmark_connections: 24 - landmark_connections: 23 - landmark_connections: 24 - - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Take left pose landmarks. 
-node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "landmarks_left_side" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 1 end: 4 } - ranges: { begin: 7 end: 8 } - ranges: { begin: 9 end: 10 } - ranges: { begin: 11 end: 12 } - ranges: { begin: 13 end: 14 } - ranges: { begin: 15 end: 16 } - ranges: { begin: 17 end: 18 } - ranges: { begin: 19 end: 20 } - ranges: { begin: 21 end: 22 } - ranges: { begin: 23 end: 24 } - - combine_outputs: true - } - } -} - -# Take right pose landmarks. -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "landmarks_right_side" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 4 end: 7 } - ranges: { begin: 8 end: 9 } - ranges: { begin: 10 end: 11 } - ranges: { begin: 12 end: 13 } - ranges: { begin: 14 end: 15 } - ranges: { begin: 16 end: 17 } - ranges: { begin: 18 end: 19 } - ranges: { begin: 20 end: 21 } - ranges: { begin: 22 end: 23 } - ranges: { begin: 24 end: 25 } - - combine_outputs: true - } - } -} - -# Render pose joints as big white circles. -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:visible_pose_landmarks" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_background_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 5.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Render pose left side joints as orange circles (inside white ones). -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:landmarks_left_side" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_left_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 255 g: 138 b: 0 } - connection_color { r: 255 g: 138 b: 0 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Render pose right side joints as cyan circles (inside white ones). -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:landmarks_right_side" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_right_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 0 g: 217 b: 231 } - connection_color { r: 0 g: 217 b: 231 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Converts normalized rects to drawing primitives for annotation overlay. -node { - calculator: "RectToRenderDataCalculator" - input_stream: "NORM_RECT:roi" - output_stream: "RENDER_DATA:roi_render_data" - node_options: { - [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] { - filled: false - color { r: 255 g: 0 b: 0 } - thickness: 4.0 - } - } -} - -# Draws annotations and overlays them on top of the input images. 
-node { - calculator: "AnnotationOverlayCalculator" - input_stream: "IMAGE:input_image" - input_stream: "detection_render_data" - input_stream: "landmarks_render_data" - input_stream: "landmarks_background_joints_render_data" - input_stream: "landmarks_left_joints_render_data" - input_stream: "landmarks_right_joints_render_data" - input_stream: "roi_render_data" - output_stream: "IMAGE:output_image" -} diff --git a/mediapipe/graphs/pose_tracking/subgraphs/upper_body_pose_renderer_gpu.pbtxt b/mediapipe/graphs/pose_tracking/subgraphs/upper_body_pose_renderer_gpu.pbtxt deleted file mode 100644 index 567ad16ac..000000000 --- a/mediapipe/graphs/pose_tracking/subgraphs/upper_body_pose_renderer_gpu.pbtxt +++ /dev/null @@ -1,254 +0,0 @@ -# MediaPipe pose landmarks rendering subgraph. - -type: "UpperBodyPoseRendererGpu" - -# GPU image. (GpuBuffer) -input_stream: "IMAGE:input_image" -# Pose landmarks. (NormalizedLandmarkList) -input_stream: "LANDMARKS:pose_landmarks" -# Region of interest calculated based on landmarks. (NormalizedRect) -input_stream: "ROI:roi" -# Detected pose. (Detection) -input_stream: "DETECTION:detection" - -# GPU image with rendered data. (GpuBuffer) -output_stream: "IMAGE:output_image" - -node { - calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE_GPU:input_image" - output_stream: "SIZE:image_size" -} - -# Calculates rendering scale based on the pose roi. -node { - calculator: "RectToRenderScaleCalculator" - input_stream: "NORM_RECT:roi" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "RENDER_SCALE:render_scale" - node_options: { - [type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] { - multiplier: 0.0012 - } - } -} - -# Converts detections to drawing primitives for annotation overlay. -node { - calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION:detection" - output_stream: "RENDER_DATA:detection_render_data" - node_options: { - [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { - thickness: 4.0 - color { r: 0 g: 255 b: 0 } - } - } -} - -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "visible_pose_landmarks" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 0 end: 25 } - } - } -} - -# Converts landmarks to drawing primitives for annotation overlay. 
-node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:pose_landmarks" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_connections: 0 - landmark_connections: 1 - landmark_connections: 1 - landmark_connections: 2 - landmark_connections: 2 - landmark_connections: 3 - landmark_connections: 3 - landmark_connections: 7 - landmark_connections: 0 - landmark_connections: 4 - landmark_connections: 4 - landmark_connections: 5 - landmark_connections: 5 - landmark_connections: 6 - landmark_connections: 6 - landmark_connections: 8 - landmark_connections: 9 - landmark_connections: 10 - landmark_connections: 11 - landmark_connections: 12 - landmark_connections: 11 - landmark_connections: 13 - landmark_connections: 13 - landmark_connections: 15 - landmark_connections: 15 - landmark_connections: 17 - landmark_connections: 15 - landmark_connections: 19 - landmark_connections: 15 - landmark_connections: 21 - landmark_connections: 17 - landmark_connections: 19 - landmark_connections: 12 - landmark_connections: 14 - landmark_connections: 14 - landmark_connections: 16 - landmark_connections: 16 - landmark_connections: 18 - landmark_connections: 16 - landmark_connections: 20 - landmark_connections: 16 - landmark_connections: 22 - landmark_connections: 18 - landmark_connections: 20 - landmark_connections: 11 - landmark_connections: 23 - landmark_connections: 12 - landmark_connections: 24 - landmark_connections: 23 - landmark_connections: 24 - - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Take left pose landmarks. -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "landmarks_left_side" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 1 end: 4 } - ranges: { begin: 7 end: 8 } - ranges: { begin: 9 end: 10 } - ranges: { begin: 11 end: 12 } - ranges: { begin: 13 end: 14 } - ranges: { begin: 15 end: 16 } - ranges: { begin: 17 end: 18 } - ranges: { begin: 19 end: 20 } - ranges: { begin: 21 end: 22 } - ranges: { begin: 23 end: 24 } - - combine_outputs: true - } - } -} - -# Take right pose landmarks. -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "landmarks_right_side" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 4 end: 7 } - ranges: { begin: 8 end: 9 } - ranges: { begin: 10 end: 11 } - ranges: { begin: 12 end: 13 } - ranges: { begin: 14 end: 15 } - ranges: { begin: 16 end: 17 } - ranges: { begin: 18 end: 19 } - ranges: { begin: 20 end: 21 } - ranges: { begin: 22 end: 23 } - ranges: { begin: 24 end: 25 } - - combine_outputs: true - } - } -} - -# Render pose joints as big white circles. 
-node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:visible_pose_landmarks" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_background_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 5.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Render pose left side joints as orange circles (inside white ones). -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:landmarks_left_side" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_left_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 255 g: 138 b: 0 } - connection_color { r: 255 g: 138 b: 0 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Render pose right side joints as cyan circles (inside white ones). -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:landmarks_right_side" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_right_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 0 g: 217 b: 231 } - connection_color { r: 0 g: 217 b: 231 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Converts normalized rects to drawing primitives for annotation overlay. -node { - calculator: "RectToRenderDataCalculator" - input_stream: "NORM_RECT:roi" - output_stream: "RENDER_DATA:roi_render_data" - node_options: { - [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] { - filled: false - color { r: 255 g: 0 b: 0 } - thickness: 4.0 - } - } -} - -# Draws annotations and overlays them on top of the input images. -node { - calculator: "AnnotationOverlayCalculator" - input_stream: "IMAGE_GPU:input_image" - input_stream: "detection_render_data" - input_stream: "landmarks_render_data" - input_stream: "landmarks_background_joints_render_data" - input_stream: "landmarks_left_joints_render_data" - input_stream: "landmarks_right_joints_render_data" - input_stream: "roi_render_data" - output_stream: "IMAGE_GPU:output_image" -} diff --git a/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt b/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt deleted file mode 100644 index 4e4b5da38..000000000 --- a/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt +++ /dev/null @@ -1,72 +0,0 @@ -# MediaPipe graph that performs upper-body pose tracking with TensorFlow Lite on CPU. - -# CPU buffer. (ImageFrame) -input_stream: "input_video" - -# Output image with rendered results. (ImageFrame) -output_stream: "output_video" -# Pose landmarks. (NormalizedLandmarkList) -output_stream: "pose_landmarks" - -# Throttles the images flowing downstream for flow control. It passes through -# the very first incoming image unaltered, and waits for downstream nodes -# (calculators and subgraphs) in the graph to finish their tasks before it -# passes through another image. All images that come in while waiting are -# dropped, limiting the number of in-flight images in most part of the graph to -# 1. 
This prevents the downstream nodes from queuing up incoming images and data -# excessively, which leads to increased latency and memory usage, unwanted in -# real-time mobile applications. It also eliminates unnecessarily computation, -# e.g., the output produced by a node may get dropped downstream if the -# subsequent nodes are still busy processing previous inputs. -node { - calculator: "FlowLimiterCalculator" - input_stream: "input_video" - input_stream: "FINISHED:output_video" - input_stream_info: { - tag_index: "FINISHED" - back_edge: true - } - output_stream: "throttled_input_video" -} - -# Subgraph that detects poses and corresponding landmarks. -node { - calculator: "PoseLandmarkUpperBodyCpu" - input_stream: "IMAGE:throttled_input_video" - output_stream: "LANDMARKS:pose_landmarks" - output_stream: "DETECTION:pose_detection" - output_stream: "ROI_FROM_LANDMARKS:roi_from_landmarks" -} - -# Calculates size of the image. -node { - calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE:throttled_input_video" - output_stream: "SIZE:image_size" -} - -# Smoothes pose landmarks in order to reduce jitter. -node { - calculator: "LandmarksSmoothingCalculator" - input_stream: "NORM_LANDMARKS:pose_landmarks" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_smoothed" - node_options: { - [type.googleapis.com/mediapipe.LandmarksSmoothingCalculatorOptions] { - velocity_filter: { - window_size: 5 - velocity_scale: 10.0 - } - } - } -} - -# Subgraph that renders pose-landmark annotation onto the input image. -node { - calculator: "UpperBodyPoseRendererCpu" - input_stream: "IMAGE:throttled_input_video" - input_stream: "LANDMARKS:pose_landmarks_smoothed" - input_stream: "ROI:roi_from_landmarks" - input_stream: "DETECTION:pose_detection" - output_stream: "IMAGE:output_video" -} diff --git a/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt b/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt deleted file mode 100644 index 5f6084690..000000000 --- a/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt +++ /dev/null @@ -1,72 +0,0 @@ -# MediaPipe graph that performs upper-body pose tracking with TensorFlow Lite on GPU. - -# GPU buffer. (GpuBuffer) -input_stream: "input_video" - -# Output image with rendered results. (GpuBuffer) -output_stream: "output_video" -# Pose landmarks. (NormalizedLandmarkList) -output_stream: "pose_landmarks" - -# Throttles the images flowing downstream for flow control. It passes through -# the very first incoming image unaltered, and waits for downstream nodes -# (calculators and subgraphs) in the graph to finish their tasks before it -# passes through another image. All images that come in while waiting are -# dropped, limiting the number of in-flight images in most part of the graph to -# 1. This prevents the downstream nodes from queuing up incoming images and data -# excessively, which leads to increased latency and memory usage, unwanted in -# real-time mobile applications. It also eliminates unnecessarily computation, -# e.g., the output produced by a node may get dropped downstream if the -# subsequent nodes are still busy processing previous inputs. -node { - calculator: "FlowLimiterCalculator" - input_stream: "input_video" - input_stream: "FINISHED:output_video" - input_stream_info: { - tag_index: "FINISHED" - back_edge: true - } - output_stream: "throttled_input_video" -} - -# Subgraph that detects poses and corresponding landmarks. 
-node { - calculator: "PoseLandmarkUpperBodyGpu" - input_stream: "IMAGE:throttled_input_video" - output_stream: "LANDMARKS:pose_landmarks" - output_stream: "DETECTION:pose_detection" - output_stream: "ROI_FROM_LANDMARKS:roi_from_landmarks" -} - -# Calculates size of the image. -node { - calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE_GPU:throttled_input_video" - output_stream: "SIZE:image_size" -} - -# Smoothes pose landmarks in order to reduce jitter. -node { - calculator: "LandmarksSmoothingCalculator" - input_stream: "NORM_LANDMARKS:pose_landmarks" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_smoothed" - node_options: { - [type.googleapis.com/mediapipe.LandmarksSmoothingCalculatorOptions] { - velocity_filter: { - window_size: 5 - velocity_scale: 10.0 - } - } - } -} - -# Subgraph that renders pose-landmark annotation onto the input image. -node { - calculator: "UpperBodyPoseRendererGpu" - input_stream: "IMAGE:throttled_input_video" - input_stream: "LANDMARKS:pose_landmarks_smoothed" - input_stream: "ROI:roi_from_landmarks" - input_stream: "DETECTION:pose_detection" - output_stream: "IMAGE:output_video" -} diff --git a/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java b/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java index 1a3485591..ba37b511c 100644 --- a/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java +++ b/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java @@ -34,15 +34,22 @@ import android.view.Surface; import androidx.camera.core.Camera; import androidx.camera.core.CameraSelector; import androidx.camera.core.CameraX; +import androidx.camera.core.ImageCapture; +import androidx.camera.core.ImageCapture.OnImageSavedCallback; +import androidx.camera.core.ImageCapture.OutputFileOptions; import androidx.camera.core.Preview; import androidx.camera.lifecycle.ProcessCameraProvider; import androidx.core.content.ContextCompat; import com.google.common.util.concurrent.ListenableFuture; import com.google.mediapipe.glutil.EglManager; +import java.io.File; import java.util.Arrays; import java.util.List; import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.RejectedExecutionException; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import javax.microedition.khronos.egl.EGLSurface; @@ -102,12 +109,17 @@ public class CameraXPreviewHelper extends CameraHelper { private ProcessCameraProvider cameraProvider; private Preview preview; + private ImageCapture imageCapture; + private ImageCapture.Builder imageCaptureBuilder; + private ExecutorService imageCaptureExecutorService; private Camera camera; // Size of the camera-preview frames from the camera. private Size frameSize; // Rotation of the camera-preview frames in degrees. private int frameRotation; + // Checks if the image capture use case is enabled. + private boolean isImageCaptureEnabled = false; @Nullable private CameraCharacteristics cameraCharacteristics = null; @@ -144,11 +156,29 @@ public class CameraXPreviewHelper extends CameraHelper { startCamera(activity, (LifecycleOwner) activity, cameraFacing, targetSize); } + /** + * Initializes the camera and sets it up for accessing frames. This constructor also enables the + * image capture use case from {@link CameraX}. 
+ * + * @param imageCaptureBuilder Builder for an {@link ImageCapture}, this builder must contain the + * desired configuration options for the image capture being build (e.g. target resolution). + * @param targetSize the preview size to use. If set to {@code null}, the helper will default to + * 1280 * 720. + */ + public void startCamera( + Activity activity, + @Nonnull ImageCapture.Builder imageCaptureBuilder, + CameraFacing cameraFacing, + @Nullable Size targetSize) { + this.imageCaptureBuilder = imageCaptureBuilder; + startCamera(activity, (LifecycleOwner) activity, cameraFacing, targetSize); + } + /** * Initializes the camera and sets it up for accessing frames. * * @param targetSize the preview size to use. If set to {@code null}, the helper will default to - * 1280 * 720. + * 1280 * 720. */ public void startCamera( Context context, @@ -232,12 +262,38 @@ public class CameraXPreviewHelper extends CameraHelper { // the way the activity is currently structured. cameraProvider.unbindAll(); - // Bind preview use case to camera. - camera = cameraProvider.bindToLifecycle(lifecycleOwner, cameraSelector, preview); + // Bind use case(s) to camera. + if (imageCaptureBuilder != null) { + imageCapture = imageCaptureBuilder.build(); + camera = + cameraProvider.bindToLifecycle( + lifecycleOwner, cameraSelector, preview, imageCapture); + imageCaptureExecutorService = Executors.newSingleThreadExecutor(); + isImageCaptureEnabled = true; + } else { + camera = cameraProvider.bindToLifecycle(lifecycleOwner, cameraSelector, preview); + } }, mainThreadExecutor); } + /** + * Captures a new still image and saves to a file along with application specified metadata. This + * method works when {@link CameraXPreviewHelper#startCamera(Activity, ImageCapture.Builder, + * CameraFacing, Size)} has been called previously enabling image capture. The callback will be + * called only once for every invocation of this method. + * + * @param outputFile Save location for captured image. + * @param onImageSavedCallback Callback to be called for the newly captured image. + */ + public void takePicture(File outputFile, OnImageSavedCallback onImageSavedCallback) { + if (isImageCaptureEnabled) { + OutputFileOptions outputFileOptions = new OutputFileOptions.Builder(outputFile).build(); + imageCapture.takePicture( + outputFileOptions, imageCaptureExecutorService, onImageSavedCallback); + } + } + @Override public boolean isCameraRotated() { return frameRotation % 180 == 90; diff --git a/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java b/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java index 26a28cc1c..4446633d3 100644 --- a/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java +++ b/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java @@ -72,6 +72,8 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor private int numAudioChannels = 1; // Sample rate of audio data sent to the MediaPipe graph. private double audioSampleRate; + // Use new Image container(true), or existing GpuBuffer(false). Configure via setUseImage(bool); + private boolean useImage = false; /** * Constructor for video input/output. @@ -118,6 +120,15 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor initializeGraphAndPacketCreator(graphConfig); } + /** + * Use Image container (if true), or existing GpuBuffer (if false, default). + * + *
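As a usage illustration of the image-capture support that CameraXPreviewHelper gains above, the following minimal sketch may help; the activity, the output file name, and the callback bodies are assumptions made for the example, not part of this change.

import android.app.Activity;
import androidx.camera.core.ImageCapture;
import androidx.camera.core.ImageCaptureException;
import com.google.mediapipe.components.CameraHelper;
import com.google.mediapipe.components.CameraXPreviewHelper;
import java.io.File;

class StillCaptureExample {
  private final CameraXPreviewHelper cameraHelper = new CameraXPreviewHelper();

  void startPreviewWithCapture(Activity activity) {
    // Passing an ImageCapture.Builder enables the image-capture use case next to the preview.
    ImageCapture.Builder captureBuilder = new ImageCapture.Builder();
    cameraHelper.startCamera(
        activity, captureBuilder, CameraHelper.CameraFacing.BACK, /*targetSize=*/ null);
  }

  void captureStill(Activity activity) {
    File photoFile = new File(activity.getFilesDir(), "capture.jpg");
    cameraHelper.takePicture(
        photoFile,
        new ImageCapture.OnImageSavedCallback() {
          @Override
          public void onImageSaved(ImageCapture.OutputFileResults results) {
            // The still image has been written to photoFile by the capture executor.
          }

          @Override
          public void onError(ImageCaptureException exception) {
            // Capture or file I/O failed; handle or log the exception.
          }
        });
  }
}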

Note: should be called before calling {@link #onNewFrame(TextureFrame)}. + */ + public void setUseImage(boolean use) { + useImage = use; + } + /** * Initializes a graph for processing data in real time. * @@ -387,6 +398,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor * *
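For the new useImage toggle, a brief sketch follows; the context and EglManager variables, graph file name, and stream names are placeholders assumed for illustration, and the flag has to be set before the first frame reaches the processor.

// Assuming an existing Context and EglManager, and a graph whose image input stream
// expects the mediapipe::Image type rather than GpuBuffer.
FrameProcessor processor =
    new FrameProcessor(
        context, eglManager.getNativeContext(), "mobile_gpu.binarypb",
        "input_video", "output_video");
// Must be called before onNewFrame(); incoming TextureFrames are then wrapped with
// PacketCreator#createImage instead of PacketCreator#createGpuBuffer.
processor.setUseImage(true);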

Normally the graph is initialized when the first frame arrives. You can optionally call this * method to initialize it ahead of time. + * * @throws MediaPipeException for any error status. */ public void preheat() { @@ -432,15 +444,18 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor addFrameListener.onWillAddFrame(timestamp); } - imagePacket = packetCreator.createGpuBuffer(frame); + if (useImage) { + imagePacket = packetCreator.createImage(frame); + } else { + imagePacket = packetCreator.createGpuBuffer(frame); + } // imagePacket takes ownership of frame and will release it. frame = null; try { // addConsumablePacketToInputStream allows the graph to take exclusive ownership of the // packet, which may allow for more memory optimizations. - mediapipeGraph.addConsumablePacketToInputStream( - videoInputStream, imagePacket, timestamp); + mediapipeGraph.addConsumablePacketToInputStream(videoInputStream, imagePacket, timestamp); // If addConsumablePacket succeeded, we don't need to release the packet ourselves. imagePacket = null; } catch (MediaPipeException e) { @@ -531,6 +546,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor /** * Starts running the MediaPipe graph. + * * @throws MediaPipeException for any error status. */ private void startGraph() { diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java index 5e1a7a135..69c0ebeb6 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java @@ -15,6 +15,7 @@ package com.google.mediapipe.framework; import android.graphics.Bitmap; +import java.nio.ByteBuffer; // TODO: use Preconditions in this file. /** @@ -46,6 +47,14 @@ public class AndroidPacketCreator extends PacketCreator { return Packet.create(nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), bitmap)); } + /** Creates a 4 channel RGBA Image packet from a {@link Bitmap}. */ + public Packet createRgbaImage(Bitmap bitmap) { + if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { + throw new RuntimeException("bitmap must use ARGB_8888 config."); + } + return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); + } + /** * Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on * failure. @@ -57,4 +66,10 @@ public class AndroidPacketCreator extends PacketCreator { * failure. */ private native long nativeCreateRgbaImageFrame(long context, Bitmap bitmap); + + /** + * Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on + * failure. + */ + private native long nativeCreateRgbaImage(long context, Bitmap bitmap); } diff --git a/mediapipe/java/com/google/mediapipe/framework/Graph.java b/mediapipe/java/com/google/mediapipe/framework/Graph.java index b90e51d8a..276adc797 100644 --- a/mediapipe/java/com/google/mediapipe/framework/Graph.java +++ b/mediapipe/java/com/google/mediapipe/framework/Graph.java @@ -165,13 +165,29 @@ public class Graph { */ public synchronized void addMultiStreamCallback( List streamNames, PacketListCallback callback) { + addMultiStreamCallback(streamNames, callback, false); + } + + /** + * Adds a {@link PacketListCallback} to the context for callback during graph running. + * + * @param streamNames The output stream names in the graph for callback. 
+ * @param callback The callback for handling the call when all output streams listed in + * streamNames get {@link Packet}. + * @param observeTimestampBounds Whether to output an empty packet when a timestamp bound change + * is observed with no output data. This can happen when an input packet is processed but no + * corresponding output packet is immediately generated. + * @throws MediaPipeException for any error status. + */ + public synchronized void addMultiStreamCallback( + List streamNames, PacketListCallback callback, boolean observeTimestampBounds) { Preconditions.checkState( nativeGraphHandle != 0, "Invalid context, tearDown() might have been called already."); Preconditions.checkNotNull(streamNames); Preconditions.checkNotNull(callback); Preconditions.checkState(!graphRunning && !startRunningGraphCalled); callbacks.add(callback); - nativeAddMultiStreamCallback(nativeGraphHandle, streamNames, callback); + nativeAddMultiStreamCallback(nativeGraphHandle, streamNames, callback, observeTimestampBounds); } /** @@ -180,7 +196,7 @@ public class Graph { *
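To show how the new observeTimestampBounds flag is meant to be consumed, here is a small sketch; the graph variable and stream names are assumed, the lambda presumes PacketListCallback exposes a single process(List<Packet>) method, and Packet#isEmpty() is the accessor added elsewhere in this change.

import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.Packet;
import java.util.Arrays;
import java.util.List;

class MultiStreamObserverExample {
  static void observeOutputs(Graph graph) {
    graph.addMultiStreamCallback(
        Arrays.asList("pose_landmarks", "pose_detection"),
        (List<Packet> packets) -> {
          for (Packet packet : packets) {
            // With observeTimestampBounds=true the callback also fires when only the
            // timestamp bound advances; such packets are empty and carry no data.
            if (packet.isEmpty()) {
              continue;
            }
            // Process the synchronized, non-empty packets here.
          }
        },
        /*observeTimestampBounds=*/ true);
  }
}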

Multiple outputs can be attached to the same stream. * * @param streamName The output stream name in the graph. - * @result a new SurfaceOutput. + * @return a new SurfaceOutput. */ public synchronized SurfaceOutput addSurfaceOutput(String streamName) { Preconditions.checkState( @@ -600,7 +616,10 @@ public class Graph { long context, String streamName, PacketCallback callback); private native void nativeAddMultiStreamCallback( - long context, List streamName, PacketListCallback callback); + long context, + List streamName, + PacketListCallback callback, + boolean observeTimestampBounds); private native long nativeAddSurfaceOutput(long context, String streamName); diff --git a/mediapipe/java/com/google/mediapipe/framework/Packet.java b/mediapipe/java/com/google/mediapipe/framework/Packet.java index f34573d0a..36096134f 100644 --- a/mediapipe/java/com/google/mediapipe/framework/Packet.java +++ b/mediapipe/java/com/google/mediapipe/framework/Packet.java @@ -31,26 +31,31 @@ public class Packet { /** * Creates a Java packet from a native mediapipe packet handle. * - * @return A Packet from a native internal::PacketWithContext handle. + * Returns a Packet from a native internal::PacketWithContext handle. */ public static Packet create(long nativeHandle) { return new Packet(nativeHandle); } /** - * @return The native handle of the packet. + * Returns the native handle of the packet. */ public long getNativeHandle() { return nativePacketHandle; } - /** @return The timestamp of the Packet. */ + /** Returns the timestamp of the Packet. */ public long getTimestamp() { return nativeGetTimestamp(nativePacketHandle); } + /** Returns true if the Packet is empty. */ + public boolean isEmpty() { + return nativeIsEmpty(nativePacketHandle); + } + /** - * @return a shared copy of the Packet. + * Returns a shared copy of the Packet. *

This is essentially increasing the reference count to the data encapsulated in the * native mediapipe packet. */ @@ -82,4 +87,6 @@ public class Packet { private native long nativeCopyPacket(long packetHandle); private native long nativeGetTimestamp(long packetHandle); + + private native boolean nativeIsEmpty(long packetHandle); } diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index 47b1814be..f9ea6760c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -267,8 +267,7 @@ public class PacketCreator { /** Creates a {@link Packet} containing a protobuf MessageLite. */ public Packet createProto(MessageLite message) { SerializedMessage serialized = ProtoUtil.pack(message); - return Packet.create( - nativeCreateProto(mediapipeGraph.getNativeHandle(), serialized)); + return Packet.create(nativeCreateProto(mediapipeGraph.getNativeHandle(), serialized)); } /** Creates a {@link Packet} containing the given camera intrinsics. */ @@ -327,13 +326,59 @@ public class PacketCreator { frame)); } + /** + * Creates a mediapipe::Image with the provided {@link TextureFrame}. + * + *

Note: in order for MediaPipe to be able to access the texture, the application's GL context + * must be linked with MediaPipe's. This is ensured by calling {@link + * Graph#createGlRunner(String,long)} with the native handle to the application's GL context as + * the second argument. + */ + public Packet createImage(TextureFrame frame) { + return Packet.create( + nativeCreateGpuImage( + mediapipeGraph.getNativeHandle(), + frame.getTextureName(), + frame.getWidth(), + frame.getHeight(), + frame)); + } + + /** + * Creates a 1, 3, or 4 channel 8-bit Image packet from a U8, RGB, or RGBA byte buffer. + * + *

Use {@link ByteBuffer#allocateDirect} when allocating the buffer. + * + *

For 3 and 4 channel images, the pixel rows should have 4-byte alignment. + */ + public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) { + if (numChannels == 4) { + if (buffer.capacity() != width * height * 4) { + throw new RuntimeException("buffer doesn't have the correct size."); + } + } else if (numChannels == 3) { + int widthStep = (((width * 3) + 3) / 4) * 4; + if (widthStep * height != buffer.capacity()) { + throw new RuntimeException("The size of the buffer should be: " + widthStep * height); + } + } else if (numChannels == 1) { + if (width * height != buffer.capacity()) { + throw new RuntimeException( + "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); + } + } else { + throw new RuntimeException("Channels should be: 1, 3, or 4, but is " + numChannels); + } + return Packet.create( + nativeCreateCpuImage(mediapipeGraph.getNativeHandle(), buffer, width, height, numChannels)); + } + /** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */ private void releaseWithSyncToken(long nativeSyncToken, TextureReleaseCallback releaseCallback) { releaseCallback.release(new GraphGlSyncToken(nativeSyncToken)); } private native long nativeCreateReferencePacket(long context, long packet); - private native long nativeCreateRgbImage(long context, ByteBuffer buffer, int width, int height); private native long nativeCreateAudioPacket( long context, byte[] data, int offset, int numChannels, int numSamples); @@ -344,32 +389,55 @@ public class PacketCreator { private native long nativeCreateRgbImageFromRgba( long context, ByteBuffer buffer, int width, int height); + private native long nativeCreateRgbImage(long context, ByteBuffer buffer, int width, int height); + private native long nativeCreateGrayscaleImage( long context, ByteBuffer buffer, int width, int height); private native long nativeCreateRgbaImageFrame( long context, ByteBuffer buffer, int width, int height); + private native long nativeCreateFloatImageFrame( long context, FloatBuffer buffer, int width, int height); + private native long nativeCreateInt16(long context, short value); + private native long nativeCreateInt32(long context, int value); + private native long nativeCreateInt64(long context, long value); + private native long nativeCreateFloat32(long context, float value); + private native long nativeCreateFloat64(long context, double value); + private native long nativeCreateBool(long context, boolean value); + private native long nativeCreateString(long context, String value); + private native long nativeCreateVideoHeader(long context, int width, int height); + private native long nativeCreateTimeSeriesHeader( long context, int numChannels, double sampleRate); + private native long nativeCreateMatrix(long context, int rows, int cols, float[] data); + private native long nativeCreateGpuBuffer( long context, int name, int width, int height, TextureReleaseCallback releaseCallback); + + private native long nativeCreateGpuImage( + long context, int name, int width, int height, TextureReleaseCallback releaseCallback); + + private native long nativeCreateCpuImage( + long context, ByteBuffer buffer, int width, int height, int numChannels); + private native long nativeCreateInt32Array(long context, int[] data); + private native long nativeCreateFloat32Array(long context, float[] data); private native long nativeCreateFloat32Vector(long context, float[] data); private native long nativeCreateStringFromByteArray(long context, byte[] 
data); + private native long nativeCreateProto(long context, SerializedMessage data); private native long nativeCreateCalculatorOptions(long context, byte[] data); diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index 849cf76db..3d59c3e0a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -63,7 +63,7 @@ public final class PacketGetter { * @param packet A MediaPipe packet that contains a pair of packets. */ public static PacketPair getPairOfPackets(final Packet packet) { - long[] handles = nativeGetPairPackets(packet.getNativeHandle()); + long[] handles = nativeGetPairPackets(packet.getNativeHandle()); return new PacketPair(Packet.create(handles[0]), Packet.create(handles[1])); } @@ -75,12 +75,12 @@ public final class PacketGetter { * @param packet A MediaPipe packet that contains a vector of packets. */ public static List getVectorOfPackets(final Packet packet) { - long[] handles = nativeGetVectorPackets(packet.getNativeHandle()); + long[] handles = nativeGetVectorPackets(packet.getNativeHandle()); List packets = new ArrayList<>(handles.length); - for (long handle : handles) { + for (long handle : handles) { packets.add(Packet.create(handle)); - } - return packets; + } + return packets; } public static short getInt16(final Packet packet) { @@ -292,33 +292,53 @@ public final class PacketGetter { } private static native long nativeGetPacketFromReference(long nativePacketHandle); + private static native long[] nativeGetPairPackets(long nativePacketHandle); + private static native long[] nativeGetVectorPackets(long nativePacketHandle); private static native short nativeGetInt16(long nativePacketHandle); + private static native int nativeGetInt32(long nativePacketHandle); + private static native long nativeGetInt64(long nativePacketHandle); + private static native float nativeGetFloat32(long nativePacketHandle); + private static native double nativeGetFloat64(long nativePacketHandle); + private static native boolean nativeGetBool(long nativePacketHandle); + private static native String nativeGetString(long nativePacketHandle); + private static native byte[] nativeGetBytes(long nativePacketHandle); + private static native byte[] nativeGetProtoBytes(long nativePacketHandle); + private static native void nativeGetProto(long nativePacketHandle, SerializedMessage result); + private static native short[] nativeGetInt16Vector(long nativePacketHandle); + private static native int[] nativeGetInt32Vector(long nativePacketHandle); + private static native long[] nativeGetInt64Vector(long nativePacketHandle); + private static native float[] nativeGetFloat32Vector(long nativePacketHandle); + private static native double[] nativeGetFloat64Vector(long nativePacketHandle); private static native byte[][] nativeGetProtoVector(long nativePacketHandle); private static native int nativeGetImageWidth(long nativePacketHandle); + private static native int nativeGetImageHeight(long nativePacketHandle); + private static native boolean nativeGetImageData(long nativePacketHandle, ByteBuffer buffer); + private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer); // Retrieves the values that are in the VideoHeader. 
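As a usage reference for the createImage() method added to PacketCreator above, here is a minimal Java sketch. Only the createImage(ByteBuffer, int, int, int) signature, the direct-buffer requirement, and the 4-byte row alignment for 3-channel images come from the change itself; the wrapper class and the way the PacketCreator instance is obtained are illustrative assumptions.

```java
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketCreator;
import java.nio.ByteBuffer;

/** Illustrative helper; assumes a PacketCreator obtained from a running Graph. */
final class ImagePacketExample {
  static Packet createRgbImagePacket(PacketCreator packetCreator, int width, int height) {
    // 3-channel rows must be padded to a 4-byte boundary, mirroring the
    // widthStep check inside createImage().
    int widthStep = (((width * 3) + 3) / 4) * 4;
    ByteBuffer buffer = ByteBuffer.allocateDirect(widthStep * height);
    // ... fill `buffer` with RGB pixels, widthStep bytes per row ...
    return packetCreator.createImage(buffer, width, height, /* numChannels= */ 3);
  }

  private ImagePacketExample() {}
}
```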
private static native int nativeGetVideoHeaderWidth(long nativepackethandle); + private static native int nativeGetVideoHeaderHeight(long nativepackethandle); // Retrieves the values that are in the mediapipe::TimeSeriesHeader. private static native int nativeGetTimeSeriesHeaderNumChannels(long nativepackethandle); @@ -331,8 +351,11 @@ public final class PacketGetter { private static native float[] nativeGetMatrixData(long nativePacketHandle); private static native int nativeGetMatrixRows(long nativePacketHandle); + private static native int nativeGetMatrixCols(long nativePacketHandle); + private static native int nativeGetGpuBufferName(long nativePacketHandle); + private static native long nativeGetGpuBuffer(long nativePacketHandle); private PacketGetter() {} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index cd98e4595..e16f140b4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -95,6 +95,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", "//mediapipe/framework:camera_intrinsics", + "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:video_stream_header", @@ -119,6 +120,7 @@ cc_library( "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_surface_sink_calculator", + "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:graph_support", ], diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.cc index 05ebe26f3..cda84ac16 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.cc @@ -20,6 +20,7 @@ #include #include "absl/memory/memory.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/logging.h" @@ -37,6 +38,58 @@ int64_t CreatePacketWithContext(jlong context, return mediapipe_graph->WrapPacketIntoContext(packet); } +// Create 3 or 4 channel 8-bit ImageFrame shared pointer from a Java Bitmap. 
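The matrix creator/getter natives touched above (nativeCreateMatrix, nativeGetMatrixRows, the new nativeGetMatrixCols, nativeGetMatrixData) pair up as a simple round trip. A hedged sketch follows; the public wrapper names createMatrix, getMatrixRows, getMatrixCols and getMatrixData are assumed to front these natives in the usual way, and the only hard requirement taken from the diff is that the data array length must equal rows * cols.

```java
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketCreator;
import com.google.mediapipe.framework.PacketGetter;

/** Illustrative round trip; the public wrapper method names are assumptions. */
final class MatrixPacketExample {
  static float[] roundTrip(PacketCreator packetCreator) {
    // The native creator rejects data whose length is not rows * cols.
    float[] data = new float[] {1f, 2f, 3f, 4f, 5f, 6f};
    Packet packet = packetCreator.createMatrix(2, 3, data);
    try {
      // The new nativeGetMatrixCols binding makes the shape fully recoverable.
      assert PacketGetter.getMatrixRows(packet) == 2;
      assert PacketGetter.getMatrixCols(packet) == 3;
      return PacketGetter.getMatrixData(packet);
    } finally {
      packet.release();  // Packets wrap native memory and must be released.
    }
  }

  private MatrixPacketExample() {}
}
```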
+std::unique_ptr CreateImageFrameFromBitmap( + JNIEnv* env, jobject bitmap, int width, int height, int stride, + mediapipe::ImageFormat::Format format) { + auto image_frame = std::make_unique( + format, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + + void* pixel_addr = nullptr; + int result = AndroidBitmap_lockPixels(env, bitmap, &pixel_addr); + if (result != ANDROID_BITMAP_RESULT_SUCCESS) { + LOG(ERROR) << "AndroidBitmap_lockPixels() failed with result code " + << result; + return nullptr; + } + + if (format == mediapipe::ImageFormat::SRGBA) { + const int64_t buffer_size = stride * height; + if (buffer_size != image_frame->PixelDataSize()) { + LOG(ERROR) << "Bitmap stride: " << stride + << " times bitmap height: " << height + << " is not equal to the expected size: " + << image_frame->PixelDataSize(); + return nullptr; + } + std::memcpy(image_frame->MutablePixelData(), pixel_addr, + image_frame->PixelDataSize()); + } else if (format == mediapipe::ImageFormat::SRGB) { + if (stride != width * 4) { + LOG(ERROR) << "Bitmap stride: " << stride + << "is not equal to 4 times bitmap width: " << width; + return nullptr; + } + const uint8_t* rgba_data = static_cast(pixel_addr); + mediapipe::android::RgbaToRgb(rgba_data, stride, width, height, + image_frame->MutablePixelData(), + image_frame->WidthStep()); + } else { + LOG(ERROR) << "unsupported image format: " << format; + return nullptr; + } + + result = AndroidBitmap_unlockPixels(env, bitmap); + if (result != ANDROID_BITMAP_RESULT_SUCCESS) { + LOG(ERROR) << "AndroidBitmap_unlockPixels() failed with result code " + << result; + return nullptr; + } + + return image_frame; +} + } // namespace JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD( @@ -48,31 +101,12 @@ JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD( LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " << result; return 0L; } - if (info.stride != info.width * 4) { - LOG(ERROR) << "Bitmap stride: " << info.stride - << "is not equal to 4 times bitmap width: " << info.width; - return 0L; - } - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::SRGB, info.width, info.height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - void* pixel_addr = nullptr; - result = AndroidBitmap_lockPixels(env, bitmap, &pixel_addr); - if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_lockPixels() failed with result code " - << result; - return 0L; - } - const uint8_t* rgba_data = static_cast(pixel_addr); - mediapipe::android::RgbaToRgb(rgba_data, info.stride, info.width, info.height, - image_frame->MutablePixelData(), - image_frame->WidthStep()); - result = AndroidBitmap_unlockPixels(env, bitmap); - if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_unlockPixels() failed with result code " - << result; - return 0L; - } + + auto image_frame = + CreateImageFrameFromBitmap(env, bitmap, info.width, info.height, + info.stride, mediapipe::ImageFormat::SRGB); + if (nullptr == image_frame) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); return CreatePacketWithContext(context, packet); } @@ -86,32 +120,31 @@ JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD( LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " << result; return 0L; } - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::SRGBA, info.width, info.height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = info.stride * info.height; - if (buffer_size 
!= image_frame->PixelDataSize()) { - LOG(ERROR) << "Bitmap stride: " << info.stride - << " times bitmap height: " << info.height - << " is not equal to the expected size: " - << image_frame->PixelDataSize(); - return 0L; - } - void* pixel_addr = nullptr; - result = AndroidBitmap_lockPixels(env, bitmap, &pixel_addr); - if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_lockPixels() failed with result code " - << result; - return 0L; - } - std::memcpy(image_frame->MutablePixelData(), pixel_addr, - image_frame->PixelDataSize()); - result = AndroidBitmap_unlockPixels(env, bitmap); - if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_unlockPixels() failed with result code " - << result; - return 0L; - } + + auto image_frame = + CreateImageFrameFromBitmap(env, bitmap, info.width, info.height, + info.stride, mediapipe::ImageFormat::SRGBA); + if (nullptr == image_frame) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); return CreatePacketWithContext(context, packet); } + +JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD(nativeCreateRgbaImage)( + JNIEnv* env, jobject thiz, jlong context, jobject bitmap) { + AndroidBitmapInfo info; + int result = AndroidBitmap_getInfo(env, bitmap, &info); + if (result != ANDROID_BITMAP_RESULT_SUCCESS) { + LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " << result; + return 0L; + } + + auto image_frame = + CreateImageFrameFromBitmap(env, bitmap, info.width, info.height, + info.stride, mediapipe::ImageFormat::SRGBA); + if (nullptr == image_frame) return 0L; + + mediapipe::Packet packet = + mediapipe::MakePacket(std::move(image_frame)); + return CreatePacketWithContext(context, packet); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.h index a1fc587d9..457f26fed 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.h @@ -32,6 +32,9 @@ JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD( nativeCreateRgbaImageFrame)(JNIEnv* env, jobject thiz, jlong context, jobject bitmap); +JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD(nativeCreateRgbaImage)( + JNIEnv* env, jobject thiz, jlong context, jobject bitmap); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h b/mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h index 1cf0fb2ce..c6fc6217e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h @@ -43,6 +43,8 @@ class ClassRegistry { "com/google/mediapipe/framework/Compat"; static constexpr char const* kGraphClassName = "com/google/mediapipe/framework/Graph"; + static constexpr char const* kGraphProfilerClassName = + "com/google/mediapipe/framework/GraphProfiler"; static constexpr char const* kPacketClassName = "com/google/mediapipe/framework/Packet"; static constexpr char const* kMediaPipeExceptionClassName = diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index e244c1186..e24df24f9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -202,17 +202,16 @@ 
absl::Status Graph::AddCallbackHandler(std::string output_stream_name, } absl::Status Graph::AddMultiStreamCallbackHandler( - std::vector output_stream_names, jobject java_callback) { + std::vector output_stream_names, jobject java_callback, + bool observe_timestamp_bounds) { if (!graph_config()) { return absl::InternalError("Graph is not loaded!"); } auto handler = absl::make_unique(this, java_callback); - std::pair side_packet_pair; - tool::AddMultiStreamCallback(output_stream_names, - handler->CreatePacketListCallback(), - graph_config(), &side_packet_pair); - side_packets_[side_packet_pair.first] = side_packet_pair.second; + tool::AddMultiStreamCallback( + output_stream_names, handler->CreatePacketListCallback(), graph_config(), + &side_packets_, observe_timestamp_bounds); EnsureMinimumExecutorStackSizeForJava(); callback_handlers_.emplace_back(std::move(handler)); return absl::OkStatus(); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h index 2a29c04cb..87ac516bd 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h @@ -53,7 +53,8 @@ class Graph { jobject java_callback); // Adds a callback for multiple output streams. absl::Status AddMultiStreamCallbackHandler( - std::vector output_stream_names, jobject java_callback); + std::vector output_stream_names, jobject java_callback, + bool observe_timestamp_bounds); // Loads a binary graph from a file. absl::Status LoadBinaryGraph(std::string path_to_graph); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc index ec8cc3efd..2b761bf60 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc @@ -191,7 +191,7 @@ GRAPH_METHOD(nativeAddPacketCallback)(JNIEnv* env, jobject thiz, jlong context, JNIEXPORT void JNICALL GRAPH_METHOD(nativeAddMultiStreamCallback)( JNIEnv* env, jobject thiz, jlong context, jobject stream_names, - jobject callback) { + jobject callback, jboolean observe_timestamp_bounds) { mediapipe::android::Graph* mediapipe_graph = reinterpret_cast(context); std::vector output_stream_names = @@ -214,7 +214,8 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeAddMultiStreamCallback)( return; } ThrowIfError(env, mediapipe_graph->AddMultiStreamCallbackHandler( - output_stream_names, global_callback_ref)); + output_stream_names, global_callback_ref, + observe_timestamp_bounds)); } JNIEXPORT jlong JNICALL GRAPH_METHOD(nativeAddSurfaceOutput)( diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h index c7c321171..c416c1e4f 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h @@ -64,7 +64,7 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeAddPacketCallback)( JNIEXPORT void JNICALL GRAPH_METHOD(nativeAddMultiStreamCallback)( JNIEnv* env, jobject thiz, jlong context, jobject stream_names, - jobject callback); + jobject callback, jboolean observe_timestamp_bounds); JNIEXPORT jlong JNICALL GRAPH_METHOD(nativeAddSurfaceOutput)( JNIEnv* env, jobject thiz, jlong context, jstring stream_name); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.cc 
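On the Java side, the extra boolean threaded through AddMultiStreamCallbackHandler and nativeAddMultiStreamCallback above surfaces as an observe-timestamp-bounds flag when attaching a multi-stream callback. A hedged sketch follows; the method name addMultiStreamCallback, the PacketListCallback type location, and the stream names are assumptions, and only the additional boolean argument (the trailing Z in the updated JNI signature) comes from the diff.

```java
import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.PacketListCallback;
import java.util.Arrays;

/** Illustrative wiring; Java-side method and type names are assumptions. */
final class MultiStreamCallbackExample {
  static void attach(Graph graph, PacketListCallback callback) {
    graph.addMultiStreamCallback(
        Arrays.asList("pose_landmarks", "face_landmarks"),  // example stream names
        callback,
        /* observeTimestampBounds= */ true);  // new flag; with it set, callbacks may also
                                              // fire with empty packets on timestamp-bound
                                              // updates (cf. the nativeIsEmpty addition below)
  }

  private MultiStreamCallbackExample() {}
}
```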
index 6f9df3bee..b5181a278 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.cc @@ -17,29 +17,32 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_profile.pb.h" -JNIEXPORT void JNICALL GRAPH_METHOD(nativeReset)(JNIEnv* env, jobject thiz, - jlong handle) { +JNIEXPORT void JNICALL GRAPH_PROFILER_METHOD(nativeReset)(JNIEnv* env, + jobject thiz, + jlong handle) { mediapipe::ProfilingContext* profiling_context = reinterpret_cast(handle); profiling_context->Reset(); } -JNIEXPORT void JNICALL GRAPH_METHOD(nativePause)(JNIEnv* env, jobject thiz, - jlong handle) { +JNIEXPORT void JNICALL GRAPH_PROFILER_METHOD(nativePause)(JNIEnv* env, + jobject thiz, + jlong handle) { mediapipe::ProfilingContext* profiling_context = reinterpret_cast(handle); profiling_context->Pause(); } -JNIEXPORT void JNICALL GRAPH_METHOD(nativeResume)(JNIEnv* env, jobject thiz, - jlong handle) { +JNIEXPORT void JNICALL GRAPH_PROFILER_METHOD(nativeResume)(JNIEnv* env, + jobject thiz, + jlong handle) { mediapipe::ProfilingContext* profiling_context = reinterpret_cast(handle); profiling_context->Resume(); } -JNIEXPORT jobjectArray JNICALL GRAPH_METHOD(nativeGetCalculatorProfiles)( - JNIEnv* env, jobject thiz, jlong handle) { +JNIEXPORT jobjectArray JNICALL GRAPH_PROFILER_METHOD( + nativeGetCalculatorProfiles)(JNIEnv* env, jobject thiz, jlong handle) { mediapipe::ProfilingContext* profiling_context = reinterpret_cast(handle); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.h index 720814e31..13e600761 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.h @@ -21,21 +21,22 @@ extern "C" { #endif // __cplusplus -#define GRAPH_METHOD(METHOD_NAME) \ +#define GRAPH_PROFILER_METHOD(METHOD_NAME) \ Java_com_google_mediapipe_framework_GraphProfiler_##METHOD_NAME -JNIEXPORT void JNICALL GRAPH_METHOD(nativeReset)(JNIEnv* env, jobject thiz, - jlong profiling_context); - -JNIEXPORT void JNICALL GRAPH_METHOD(nativeResume)(JNIEnv* env, jobject thiz, - jlong profiling_context); - -JNIEXPORT void JNICALL GRAPH_METHOD(nativePause)(JNIEnv* env, jobject thiz, - jlong profiling_context); - -JNIEXPORT jobjectArray JNICALL GRAPH_METHOD(nativeGetCalculatorProfiles)( +JNIEXPORT void JNICALL GRAPH_PROFILER_METHOD(nativeReset)( JNIEnv* env, jobject thiz, jlong profiling_context); +JNIEXPORT void JNICALL GRAPH_PROFILER_METHOD(nativeResume)( + JNIEnv* env, jobject thiz, jlong profiling_context); + +JNIEXPORT void JNICALL GRAPH_PROFILER_METHOD(nativePause)( + JNIEnv* env, jobject thiz, jlong profiling_context); + +JNIEXPORT jobjectArray JNICALL GRAPH_PROFILER_METHOD( + nativeGetCalculatorProfiles)(JNIEnv* env, jobject thiz, + jlong profiling_context); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.cc index 12a2a92c1..450258aad 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.cc @@ -34,6 +34,12 @@ JNIEXPORT jlong JNICALL PACKET_METHOD(nativeGetTimestamp)(JNIEnv* env, .Value(); } +JNIEXPORT jboolean JNICALL 
PACKET_METHOD(nativeIsEmpty)(JNIEnv* env, + jobject thiz, + jlong packet) { + return mediapipe::android::Graph::GetPacketFromHandle(packet).IsEmpty(); +} + JNIEXPORT jlong JNICALL PACKET_METHOD(nativeCopyPacket)(JNIEnv* env, jobject thiz, jlong packet) { diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.h index 44d8bb137..46d11dd2e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.h @@ -34,6 +34,11 @@ JNIEXPORT jlong JNICALL PACKET_METHOD(nativeGetTimestamp)(JNIEnv* env, jobject thiz, jlong packet); +// Returns true if the packet is empty. +JNIEXPORT jboolean JNICALL PACKET_METHOD(nativeIsEmpty)(JNIEnv* env, + jobject thiz, + jlong packet); + // Make a copy of a mediapipe packet, basically increase the reference count. JNIEXPORT jlong JNICALL PACKET_METHOD(nativeCopyPacket)(JNIEnv* env, jobject thiz, diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index c1bfd09ac..c9c8553fd 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/camera_intrinsics.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/matrix.h" @@ -26,6 +27,7 @@ #include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" @@ -53,6 +55,110 @@ int64_t CreatePacketWithContext(jlong context, reinterpret_cast(context); return mediapipe_graph->WrapPacketIntoContext(packet); } + +#if !MEDIAPIPE_DISABLE_GPU +mediapipe::GpuBuffer CreateGpuBuffer(JNIEnv* env, jobject thiz, jlong context, + jint name, jint width, jint height, + jobject texture_release_callback) { + mediapipe::android::Graph* mediapipe_graph = + reinterpret_cast(context); + auto* gpu_resources = mediapipe_graph->GetGpuResources(); + CHECK(gpu_resources) << "Cannot create a mediapipe::GpuBuffer packet on a " + "graph without GPU support"; + mediapipe::GlTextureBuffer::DeletionCallback cc_callback; + + if (texture_release_callback) { + // TODO: see if this can be cached. + // Note: we don't get this from the object because people may pass a + // subclass of PacketCreator, and the method is private. 
+ jclass my_class = + env->FindClass("com/google/mediapipe/framework/PacketCreator"); + jmethodID release_method = + env->GetMethodID(my_class, "releaseWithSyncToken", + "(JL" + "com/google/mediapipe/framework/TextureReleaseCallback" + ";)V"); + CHECK(release_method); + env->DeleteLocalRef(my_class); + + jobject java_callback = env->NewGlobalRef(texture_release_callback); + jobject packet_creator = env->NewGlobalRef(thiz); + cc_callback = [packet_creator, release_method, + java_callback](mediapipe::GlSyncToken release_token) { + JNIEnv* env = mediapipe::java::GetJNIEnv(); + + jlong raw_token = reinterpret_cast( + new mediapipe::GlSyncToken(std::move(release_token))); + env->CallVoidMethod(packet_creator, release_method, raw_token, + java_callback); + + // Note that this callback is called only once, and is not saved + // anywhere else, so we can and should delete it here. + env->DeleteGlobalRef(java_callback); + env->DeleteGlobalRef(packet_creator); + }; + } + return mediapipe::GpuBuffer(mediapipe::GlTextureBuffer::Wrap( + GL_TEXTURE_2D, name, width, height, mediapipe::GpuBufferFormat::kBGRA32, + gpu_resources->gl_context(), cc_callback)); +} +#endif // !MEDIAPIPE_DISABLE_GPU + +// Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java +// ByteBuffer. +std::unique_ptr CreateImageFrameFromByteBuffer( + JNIEnv* env, jobject byte_buffer, jint width, jint height, + mediapipe::ImageFormat::Format format) { + switch (format) { + case mediapipe::ImageFormat::SRGBA: + case mediapipe::ImageFormat::SRGB: + case mediapipe::ImageFormat::GRAY8: + break; + default: + LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8."; + return nullptr; + } + + auto image_frame = std::make_unique( + format, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + + const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + const int num_channels = image_frame->NumberOfChannels(); + const int expected_buffer_size = + num_channels == 1 ? width * height : image_frame->PixelDataSize(); + + if (buffer_size != expected_buffer_size) { + if (num_channels != 1) + LOG(ERROR) << "The input image buffer should have 4 bytes alignment."; + LOG(ERROR) << "Please check the input buffer size."; + LOG(ERROR) << "Buffer size: " << buffer_size + << ", Buffer size needed: " << expected_buffer_size + << ", Image width: " << width; + return nullptr; + } + + // Copy buffer data to image frame's pixel_data_. + if (num_channels == 1) { + const int width_step = image_frame->WidthStep(); + const char* src_row = + reinterpret_cast(env->GetDirectBufferAddress(byte_buffer)); + char* dst_row = reinterpret_cast(image_frame->MutablePixelData()); + for (int i = height; i > 0; --i) { + std::memcpy(dst_row, src_row, width); + src_row += width; + dst_row += width_step; + } + } else { + // 3 and 4 channels. 
+ const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + std::memcpy(image_frame->MutablePixelData(), buffer_data, + image_frame->PixelDataSize()); + } + + return image_frame; +} + } // namespace JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( @@ -69,20 +175,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const void* data = env->GetDirectBufferAddress(byte_buffer); - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::SRGB, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != image_frame->PixelDataSize()) { - LOG(ERROR) << "The input image buffer should have 4 bytes alignment."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - return 0L; - } - std::memcpy(image_frame->MutablePixelData(), data, - image_frame->PixelDataSize()); + auto image_frame = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); + if (nullptr == image_frame) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); return CreatePacketWithContext(context, packet); } @@ -113,28 +209,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::GRAY8, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != width * height) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << width * height - << ", Image height: " << height; - return 0L; - } + auto image_frame = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); + if (nullptr == image_frame) return 0L; - int width_step = image_frame->WidthStep(); - // Copy buffer data to image frame's pixel_data_. 
- const char* src_row = - reinterpret_cast(env->GetDirectBufferAddress(byte_buffer)); - char* dst_row = reinterpret_cast(image_frame->MutablePixelData()); - for (int i = height; i > 0; --i) { - std::memcpy(dst_row, src_row, width); - src_row += width; - dst_row += width_step; - } mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); return CreatePacketWithContext(context, packet); } @@ -163,20 +241,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const void* rgba_data = env->GetDirectBufferAddress(byte_buffer); - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::SRGBA, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != image_frame->PixelDataSize()) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - return 0L; - } - std::memcpy(image_frame->MutablePixelData(), rgba_data, - image_frame->PixelDataSize()); + auto image_frame = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); + if (nullptr == image_frame) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); return CreatePacketWithContext(context, packet); } @@ -287,8 +355,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols, jfloatArray data) { if (env->GetArrayLength(data) != rows * cols) { - LOG(ERROR) << "Please check the matrix data size, " - "has to be rows * cols = " + LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = " << rows * cols; return 0L; } @@ -300,54 +367,51 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( return CreatePacketWithContext(context, packet); } +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( + JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, + jint height, jint num_channels) { + mediapipe::ImageFormat::Format format; + switch (num_channels) { + case 4: + format = mediapipe::ImageFormat::SRGBA; + break; + case 3: + format = mediapipe::ImageFormat::SRGB; + break; + case 1: + format = mediapipe::ImageFormat::GRAY8; + break; + default: + LOG(ERROR) << "Channels must be either 1, 3, or 4."; + return 0L; + } + + auto image_frame = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); + if (nullptr == image_frame) return 0L; + + mediapipe::Packet packet = + mediapipe::MakePacket(std::move(image_frame)); + return CreatePacketWithContext(context, packet); +} + #if !MEDIAPIPE_DISABLE_GPU +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( + JNIEnv* env, jobject thiz, jlong context, jint name, jint width, + jint height, jobject texture_release_callback) { + mediapipe::Packet image_packet = + mediapipe::MakePacket(CreateGpuBuffer( + env, thiz, context, name, width, height, texture_release_callback)); + return CreatePacketWithContext(context, image_packet); +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuBuffer)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, jint height, jobject texture_release_callback) { - 
mediapipe::android::Graph* mediapipe_graph = - reinterpret_cast(context); - auto* gpu_resources = mediapipe_graph->GetGpuResources(); - CHECK(gpu_resources) << "Cannot create a mediapipe::GpuBuffer packet on a " - "graph without GPU support"; - mediapipe::GlTextureBuffer::DeletionCallback cc_callback; - - if (texture_release_callback) { - // TODO: see if this can be cached. - // Note: we don't get this from the object because people may pass a - // subclass of PacketCreator, and the method is private. - jclass my_class = - env->FindClass("com/google/mediapipe/framework/PacketCreator"); - jmethodID release_method = - env->GetMethodID(my_class, "releaseWithSyncToken", - "(JL" - "com/google/mediapipe/framework/TextureReleaseCallback" - ";)V"); - CHECK(release_method); - env->DeleteLocalRef(my_class); - - jobject java_callback = env->NewGlobalRef(texture_release_callback); - jobject packet_creator = env->NewGlobalRef(thiz); - cc_callback = [mediapipe_graph, packet_creator, release_method, - java_callback](mediapipe::GlSyncToken release_token) { - JNIEnv* env = mediapipe::java::GetJNIEnv(); - - jlong raw_token = reinterpret_cast( - new mediapipe::GlSyncToken(std::move(release_token))); - env->CallVoidMethod(packet_creator, release_method, raw_token, - java_callback); - - // Note that this callback is called only once, and is not saved - // anywhere else, so we can and should delete it here. - env->DeleteGlobalRef(java_callback); - env->DeleteGlobalRef(packet_creator); - }; - } - mediapipe::Packet packet = mediapipe::MakePacket( - mediapipe::GlTextureBuffer::Wrap(GL_TEXTURE_2D, name, width, height, - mediapipe::GpuBufferFormat::kBGRA32, - gpu_resources->gl_context(), - cc_callback)); + mediapipe::Packet packet = + mediapipe::MakePacket(CreateGpuBuffer( + env, thiz, context, name, width, height, texture_release_callback)); return CreatePacketWithContext(context, packet); } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index 0b448ae79..d6f44b0a3 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -97,6 +97,14 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols, jfloatArray data); +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( + JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, + jint height, jint num_channels); + +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( + JNIEnv* env, jobject thiz, jlong context, jint name, jint width, + jint height, jobject texture_release_callback); + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuBuffer)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, jint height, jobject texture_release_callback); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index 0b40dd642..30ec19a25 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -21,6 +21,7 @@ #include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/gpu/gpu_buffer.h" #include 
"mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" @@ -255,22 +256,43 @@ JNIEXPORT jdoubleArray JNICALL PACKET_GETTER_METHOD(nativeGetFloat64Vector)( JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageWidth)(JNIEnv* env, jobject thiz, jlong packet) { + mediapipe::Packet mediapipe_packet = + mediapipe::android::Graph::GetPacketFromHandle(packet); + const bool is_image = + mediapipe_packet.ValidateAsType().ok(); const mediapipe::ImageFrame& image = - GetFromNativeHandle(packet); + is_image ? *GetFromNativeHandle(packet) + .GetImageFrameSharedPtr() + .get() + : GetFromNativeHandle(packet); return image.Width(); } JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageHeight)( JNIEnv* env, jobject thiz, jlong packet) { + mediapipe::Packet mediapipe_packet = + mediapipe::android::Graph::GetPacketFromHandle(packet); + const bool is_image = + mediapipe_packet.ValidateAsType().ok(); const mediapipe::ImageFrame& image = - GetFromNativeHandle(packet); + is_image ? *GetFromNativeHandle(packet) + .GetImageFrameSharedPtr() + .get() + : GetFromNativeHandle(packet); return image.Height(); } JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer) { + mediapipe::Packet mediapipe_packet = + mediapipe::android::Graph::GetPacketFromHandle(packet); + const bool is_image = + mediapipe_packet.ValidateAsType().ok(); const mediapipe::ImageFrame& image = - GetFromNativeHandle(packet); + is_image ? *GetFromNativeHandle(packet) + .GetImageFrameSharedPtr() + .get() + : GetFromNativeHandle(packet); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); @@ -418,10 +440,17 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetGpuBufferName)( JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)(JNIEnv* env, jobject thiz, jlong packet) { - const mediapipe::GpuBuffer& gpu_buffer = - GetFromNativeHandle(packet); - const mediapipe::GlTextureBufferSharedPtr& ptr = - gpu_buffer.GetGlTextureBufferSharedPtr(); + mediapipe::Packet mediapipe_packet = + mediapipe::android::Graph::GetPacketFromHandle(packet); + mediapipe::GlTextureBufferSharedPtr ptr; + if (mediapipe_packet.ValidateAsType().ok()) { + const mediapipe::Image& buffer = mediapipe_packet.Get(); + ptr = buffer.GetGlTextureBufferSharedPtr(); + } else { + const mediapipe::GpuBuffer& buffer = + mediapipe_packet.Get(); + ptr = buffer.GetGlTextureBufferSharedPtr(); + } ptr->WaitUntilComplete(); return reinterpret_cast( new mediapipe::GlTextureBufferSharedPtr(ptr)); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc index d01fb5316..2d71dd58f 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc @@ -24,6 +24,7 @@ #endif #include "mediapipe/java/com/google/mediapipe/framework/jni/compat_jni.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h" +#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_context_jni.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" @@ -114,7 +115,7 @@ void 
RegisterGraphNatives(JNIEnv *env) { std::string packet_list_callback_name = class_registry.GetClassName( mediapipe::android::ClassRegistry::kPacketListCallbackClassName); std::string native_add_multi_stream_callback_signature = - absl::StrFormat("(JLjava/util/List;L%s;)V", packet_list_callback_name); + absl::StrFormat("(JLjava/util/List;L%s;Z)V", packet_list_callback_name); AddJNINativeMethod(&graph_methods, graph, "nativeAddMultiStreamCallback", native_add_multi_stream_callback_signature.c_str(), (void *)&GRAPH_METHOD(nativeAddMultiStreamCallback)); @@ -133,9 +134,25 @@ void RegisterGraphNatives(JNIEnv *env) { (void *)&GRAPH_METHOD(nativeWaitUntilGraphDone)); AddJNINativeMethod(&graph_methods, graph, "nativeReleaseGraph", "(J)V", (void *)&GRAPH_METHOD(nativeReleaseGraph)); + AddJNINativeMethod(&graph_methods, graph, "nativeGetProfiler", "(J)J", + (void *)&GRAPH_METHOD(nativeGetProfiler)); RegisterNativesVector(env, graph_class, graph_methods); } +void RegisterGraphProfilerNatives(JNIEnv *env) { + auto &class_registry = mediapipe::android::ClassRegistry::GetInstance(); + std::string graph_profiler( + mediapipe::android::ClassRegistry::kGraphProfilerClassName); + std::string graph_profiler_name = class_registry.GetClassName(graph_profiler); + jclass graph_profiler_class = env->FindClass(graph_profiler_name.c_str()); + + std::vector graph_profiler_methods; + AddJNINativeMethod( + &graph_profiler_methods, graph_profiler, "nativeGetCalculatorProfiles", + "(J)[[B", (void *)&GRAPH_PROFILER_METHOD(nativeGetCalculatorProfiles)); + RegisterNativesVector(env, graph_profiler_class, graph_profiler_methods); +} + void RegisterAndroidAssetUtilNatives(JNIEnv *env) { #if defined(__ANDROID__) auto &class_registry = mediapipe::android::ClassRegistry::GetInstance(); @@ -249,6 +266,8 @@ void RegisterPacketNatives(JNIEnv *env) { (void *)&PACKET_METHOD(nativeCopyPacket)); AddJNINativeMethod(&packet_methods, packet, "nativeGetTimestamp", "(J)J", (void *)&PACKET_METHOD(nativeGetTimestamp)); + AddJNINativeMethod(&packet_methods, packet, "nativeIsEmpty", "(J)Z", + (void *)&PACKET_METHOD(nativeIsEmpty)); RegisterNativesVector(env, packet_class, packet_methods); } @@ -271,6 +290,7 @@ void RegisterCompatNatives(JNIEnv *env) { void RegisterAllNatives(JNIEnv *env) { RegisterGraphNatives(env); + RegisterGraphProfilerNatives(env); RegisterAndroidAssetUtilNatives(env); RegisterAndroidPacketCreatorNatives(env); RegisterPacketCreatorNatives(env); diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 7f8ee7079..a0c5a503a 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -12,58 +12,115 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Generate MediaPipe AAR including different variants of .so in jni folder. +"""Generates MediaPipe AAR including different variants of .so in jni folder. Usage: -Create a new mediapipe_aar() target in a BUILD file. For example, +Creates a new mediapipe_aar() target in a BUILD file. For example, putting the following code into mediapipe/examples/android/aar_demo/BUILD. ``` load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar") mediapipe_aar( - name = "my_aar", + name = "demo", calculators = ["//mediapipe/calculators/core:pass_through_calculator"], ) ``` -Then, run the following Bazel command to generate the AAR. 
+Then, runs the following Bazel command to generate the aar. ``` -$ bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a mediapipe/examples/android/aar_demo:my_aar +$ bazel build --strip=always -s -c opt \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --fat_apk_cpu=arm64-v8a,armeabi-v7a \ + mediapipe/examples/android/aar_demo:demo.aar ``` -Finally, import the AAR into Android Studio. +Finally, imports the aar into Android Studio. """ load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library") -def mediapipe_aar(name, calculators = [], assets = [], assets_dir = ""): - """Generate MediaPipe AAR. +def mediapipe_aar( + name, + srcs = [], + calculators = [], + assets = [], + assets_dir = ""): + """Generates MediaPipe android archive library. Args: - name: the name of the AAR. - calculators: the calculator libraries to be compiled into the .so. + name: the name of the aar. + srcs: the additional java source code to be added into the android library. + calculators: the calculator libraries to be compiled into the jni library. assets: additional assets to be included into the archive. assets_dir: path where the assets will the packaged. """ - native.cc_binary( - name = "libmediapipe_jni.so", - linkshared = 1, - linkstatic = 1, + _mediapipe_jni( + name = name + "_jni", + calculators = calculators, + ) + + _mediapipe_proto( + name = name + "_proto", + ) + + android_library( + name = name + "_android_lib", + srcs = srcs + [ + "//mediapipe/java/com/google/mediapipe/components:java_src", + "//mediapipe/java/com/google/mediapipe/framework:java_src", + "//mediapipe/java/com/google/mediapipe/glutil:java_src", + "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", + "com/google/mediapipe/formats/proto/ClassificationProto.java", + "com/google/mediapipe/formats/proto/DetectionProto.java", + "com/google/mediapipe/formats/proto/LandmarkProto.java", + "com/google/mediapipe/formats/proto/LocationDataProto.java", + "com/google/mediapipe/proto/CalculatorProto.java", + ], + manifest = "AndroidManifest.xml", + proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"], deps = [ - "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", - ] + calculators, + ":" + name + "_jni_cc_lib", + ":" + name + "_jni_opencv_cc_lib", + "//mediapipe/framework:calculator_java_proto_lite", + "//mediapipe/framework:calculator_profile_java_proto_lite", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework:mediapipe_options_java_proto_lite", + "//mediapipe/framework:packet_factory_java_proto_lite", + "//mediapipe/framework:packet_generator_java_proto_lite", + "//mediapipe/framework:status_handler_java_proto_lite", + "//mediapipe/framework:stream_handler_java_proto_lite", + "//mediapipe/framework/tool:calculator_graph_template_java_proto_lite", + "//third_party:androidx_annotation", + "//third_party:androidx_appcompat", + "//third_party:androidx_core", + "//third_party:androidx_legacy_support_v4", + "//third_party:autovalue", + "//third_party:camerax_core", + "//third_party:camerax_camera2", + "//third_party:camerax_lifecycle", + "@com_google_protobuf//:protobuf_java", + "@maven//:com_google_code_findbugs_jsr305", + "@maven//:com_google_flogger_flogger", + "@maven//:com_google_flogger_flogger_system_backend", + "@maven//:com_google_guava_guava", + "@maven//:androidx_lifecycle_lifecycle_common", + ], + assets = assets, + assets_dir = assets_dir, ) - native.cc_library( - name = name + 
"_mediapipe_jni_lib", - srcs = [":libmediapipe_jni.so"], - alwayslink = 1, - ) + _aar_with_jni(name, name + "_android_lib") +def _mediapipe_proto(name): + """Generates MediaPipe java proto libraries. + + Args: + name: the name of the target. + """ native.genrule( name = name + "_aar_manifest_generator", outs = ["AndroidManifest.xml"], @@ -121,51 +178,15 @@ cat > $(OUTS) <_dummy_app target below) native.genrule( name = name + "_binary_manifest_generator", @@ -200,7 +254,7 @@ EOF """, ) - # Generate dummy apk including .so files. + # Generates dummy apk including .so files. # We extract out .so files and throw away the apk. android_binary( name = name + "_dummy_app", diff --git a/mediapipe/models/face_detection_back_labelmap.txt b/mediapipe/models/face_detection_back_labelmap.txt deleted file mode 100644 index fd770b6d1..000000000 --- a/mediapipe/models/face_detection_back_labelmap.txt +++ /dev/null @@ -1 +0,0 @@ -Face diff --git a/mediapipe/models/face_detection_front.tflite b/mediapipe/models/face_detection_front.tflite deleted file mode 100755 index 659bce896..000000000 Binary files a/mediapipe/models/face_detection_front.tflite and /dev/null differ diff --git a/mediapipe/models/face_detection_front_labelmap.txt b/mediapipe/models/face_detection_front_labelmap.txt deleted file mode 100644 index fd770b6d1..000000000 --- a/mediapipe/models/face_detection_front_labelmap.txt +++ /dev/null @@ -1 +0,0 @@ -Face diff --git a/mediapipe/models/face_landmark.tflite b/mediapipe/models/face_landmark.tflite deleted file mode 100644 index 9058eaa33..000000000 Binary files a/mediapipe/models/face_landmark.tflite and /dev/null differ diff --git a/mediapipe/models/face_detection_back.tflite b/mediapipe/modules/face_detection/face_detection_back.tflite similarity index 100% rename from mediapipe/models/face_detection_back.tflite rename to mediapipe/modules/face_detection/face_detection_back.tflite diff --git a/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt b/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt index e850d7cc1..835951376 100644 --- a/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt +++ b/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt @@ -19,19 +19,19 @@ # - "pose_detection.tflite" is available at # "mediapipe/modules/pose_detection/pose_detection.tflite" # -# - "pose_landmark_full_body.tflite" or "pose_landmark_upper_body.tflite" is -# available at -# "mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite" -# or -# "mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite" +# - "pose_landmark_lite.tflite" or "pose_landmark_full.tflite" or +# "pose_landmark_heavy.tflite" is available at +# "mediapipe/modules/pose_landmark/pose_landmark_lite.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_full.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite" # path respectively during execution, depending on the specification in the -# UPPER_BODY_ONLY input side packet. +# MODEL_COMPLEXITY input side packet. # # EXAMPLE: # node { # calculator: "HolisticLandmarkCpu" # input_stream: "IMAGE:input_video" -# input_side_packet: UPPER_BODY_ONLY:upper_body_only +# input_side_packet: "MODEL_COMPLEXITY:model_complexity" # input_side_packet: SMOOTH_LANDMARKS:smooth_landmarks # output_stream: "POSE_LANDMARKS:pose_landmarks" # output_stream: "FACE_LANDMARKS:face_landmarks" @@ -50,17 +50,17 @@ type: "HolisticLandmarkCpu" # CPU image. 
(ImageFrame) input_stream: "IMAGE:image" -# Whether to detect/predict the full set of pose landmarks (see below), or only -# those on the upper body. If unspecified, functions as set to false. (bool) -# Note that upper-body-only prediction may be more accurate for use cases where -# the lower-body parts are mostly out of view. -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# Complexity of the pose landmark model: 0, 1 or 2. Landmark accuracy as well as +# inference latency generally go up with the model complexity. If unspecified, +# functions as set to 0. (int) +input_side_packet: "MODEL_COMPLEXITY:model_complexity" + # Whether to filter landmarks across different input images to reduce jitter. # If unspecified, functions as set to true. (bool) input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" # Pose landmarks. (NormalizedLandmarkList) -# We have 33 landmarks or 25 landmarks if UPPER_BODY_ONLY is set to true. +# 33 pose landmarks. output_stream: "POSE_LANDMARKS:pose_landmarks" # 21 left hand landmarks. (NormalizedLandmarkList) output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks" @@ -77,7 +77,7 @@ output_stream: "POSE_DETECTION:pose_detection" node { calculator: "PoseLandmarkCpu" input_stream: "IMAGE:image" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" output_stream: "LANDMARKS:pose_landmarks" output_stream: "ROI_FROM_LANDMARKS:pose_landmarks_roi" diff --git a/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt b/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt index 990b21d1d..21cf8d881 100644 --- a/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt +++ b/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt @@ -19,19 +19,19 @@ # - "pose_detection.tflite" is available at # "mediapipe/modules/pose_detection/pose_detection.tflite" # -# - "pose_landmark_full_body.tflite" or "pose_landmark_upper_body.tflite" is -# available at -# "mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite" -# or -# "mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite" +# - "pose_landmark_lite.tflite" or "pose_landmark_full.tflite" or +# "pose_landmark_heavy.tflite" is available at +# "mediapipe/modules/pose_landmark/pose_landmark_lite.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_full.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite" # path respectively during execution, depending on the specification in the -# UPPER_BODY_ONLY input side packet. +# MODEL_COMPLEXITY input side packet. # # EXAMPLE: # node { # calculator: "HolisticLandmarkGpu" # input_stream: "IMAGE:input_video" -# input_side_packet: UPPER_BODY_ONLY:upper_body_only +# input_side_packet: "MODEL_COMPLEXITY:model_complexity" # input_side_packet: SMOOTH_LANDMARKS:smooth_landmarks # output_stream: "POSE_LANDMARKS:pose_landmarks" # output_stream: "FACE_LANDMARKS:face_landmarks" @@ -50,17 +50,17 @@ type: "HolisticLandmarkGpu" # GPU image. (GpuBuffer) input_stream: "IMAGE:image" -# Whether to detect/predict the full set of pose landmarks (see below), or only -# those on the upper body. If unspecified, functions as set to false. (bool) -# Note that upper-body-only prediction may be more accurate for use cases where -# the lower-body parts are mostly out of view. -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# Complexity of the pose landmark model: 0, 1 or 2. 
Landmark accuracy as well as +# inference latency generally go up with the model complexity. If unspecified, +# functions as set to 0. (int) +input_side_packet: "MODEL_COMPLEXITY:model_complexity" + # Whether to filter landmarks across different input images to reduce jitter. # If unspecified, functions as set to true. (bool) input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" # Pose landmarks. (NormalizedLandmarkList) -# We have 33 landmarks or 25 landmarks if UPPER_BODY_ONLY is set to true. +# 33 pose landmarks. output_stream: "POSE_LANDMARKS:pose_landmarks" # 21 left hand landmarks. (NormalizedLandmarkList) output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks" @@ -77,7 +77,7 @@ output_stream: "POSE_DETECTION:pose_detection" node { calculator: "PoseLandmarkGpu" input_stream: "IMAGE:image" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" output_stream: "LANDMARKS:pose_landmarks" output_stream: "ROI_FROM_LANDMARKS:pose_landmarks_roi" diff --git a/mediapipe/modules/pose_detection/pose_detection.tflite b/mediapipe/modules/pose_detection/pose_detection.tflite index 55deb3c20..57f89be58 100755 Binary files a/mediapipe/modules/pose_detection/pose_detection.tflite and b/mediapipe/modules/pose_detection/pose_detection.tflite differ diff --git a/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt b/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt index 7f10b59a1..d46d6a5c5 100644 --- a/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt +++ b/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt @@ -36,7 +36,7 @@ input_stream: "IMAGE:image" # this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" -# Transforms the input image into a 128x128 while keeping the aspect ratio +# Transforms the input image into a 224x224 one while keeping the aspect ratio # (what is expected by the corresponding model), resulting in potential # letterboxing in the transformed image. 
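For callers of the holistic/pose landmark graphs updated above, the practical effect of this change is that the UPPER_BODY_ONLY side packet is replaced by an integer MODEL_COMPLEXITY one. A hedged Java sketch of supplying it follows; createInt32 and createBool are assumed to back the nativeCreateInt32/nativeCreateBool bindings declared earlier in this change, while setInputSidePackets and the exact side-packet wiring are assumptions about the surrounding app code.

```java
import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketCreator;
import java.util.HashMap;
import java.util.Map;

/** Illustrative side-packet setup; method names outside the graph comments are assumptions. */
final class ModelComplexityExample {
  static void configure(Graph graph, PacketCreator packetCreator) {
    Map<String, Packet> sidePackets = new HashMap<>();
    // 0, 1 or 2, selecting the lite/full/heavy pose landmark model; accuracy and
    // latency generally rise with complexity, per the graph comments above.
    sidePackets.put("model_complexity", packetCreator.createInt32(1));
    sidePackets.put("smooth_landmarks", packetCreator.createBool(true));
    graph.setInputSidePackets(sidePackets);
  }

  private ModelComplexityExample() {}
}
```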
node: { @@ -46,8 +46,8 @@ node: { output_stream: "LETTERBOX_PADDING:letterbox_padding" options: { [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 128 - output_tensor_height: 128 + output_tensor_width: 224 + output_tensor_height: 224 keep_aspect_ratio: true output_tensor_float_range { min: -1.0 @@ -74,6 +74,7 @@ node { model_path: "mediapipe/modules/pose_detection/pose_detection.tflite" delegate { xnnpack {} } } + # } } @@ -84,17 +85,18 @@ node { output_side_packet: "anchors" options: { [mediapipe.SsdAnchorsCalculatorOptions.ext] { - num_layers: 4 + num_layers: 5 min_scale: 0.1484375 max_scale: 0.75 - input_size_height: 128 - input_size_width: 128 + input_size_height: 224 + input_size_width: 224 anchor_offset_x: 0.5 anchor_offset_y: 0.5 strides: 8 strides: 16 - strides: 16 - strides: 16 + strides: 32 + strides: 32 + strides: 32 aspect_ratios: 1.0 fixed_anchor_size: true } @@ -112,7 +114,7 @@ node { options: { [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { num_classes: 1 - num_boxes: 896 + num_boxes: 2254 num_coords: 12 box_coord_offset: 0 keypoint_coord_offset: 4 @@ -121,10 +123,10 @@ node { sigmoid_score: true score_clipping_thresh: 100.0 reverse_output_order: true - x_scale: 128.0 - y_scale: 128.0 - h_scale: 128.0 - w_scale: 128.0 + x_scale: 224.0 + y_scale: 224.0 + h_scale: 224.0 + w_scale: 224.0 min_score_thresh: 0.5 } } diff --git a/mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt b/mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt index b61f44477..98917d910 100644 --- a/mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt +++ b/mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt @@ -36,7 +36,7 @@ input_stream: "IMAGE:image" # this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" -# Transforms the input image into a 128x128 while keeping the aspect ratio +# Transforms the input image into a 224x224 one while keeping the aspect ratio # (what is expected by the corresponding model), resulting in potential # letterboxing in the transformed image. 
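As a sanity check on the retuned detector options in the pose_detection CPU hunk above and the matching GPU hunk below (assuming the same two-anchors-per-cell SSD layout implied by the previous configuration): a 224x224 input with strides 8, 16, 32, 32, 32 produces feature maps of 28x28, 14x14 and three of 7x7, i.e. 784 + 196 + 3*49 = 1127 cells and 2*1127 = 2254 anchors, which is exactly the new num_boxes; the old 128x128 input with strides 8, 16, 16, 16 gave 2*(256 + 3*64) = 896, matching the value being replaced.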
node: { @@ -46,8 +46,8 @@ node: { output_stream: "LETTERBOX_PADDING:letterbox_padding" options: { [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 128 - output_tensor_height: 128 + output_tensor_width: 224 + output_tensor_height: 224 keep_aspect_ratio: true output_tensor_float_range { min: -1.0 @@ -80,17 +80,18 @@ node { output_side_packet: "anchors" options: { [mediapipe.SsdAnchorsCalculatorOptions.ext] { - num_layers: 4 + num_layers: 5 min_scale: 0.1484375 max_scale: 0.75 - input_size_height: 128 - input_size_width: 128 + input_size_height: 224 + input_size_width: 224 anchor_offset_x: 0.5 anchor_offset_y: 0.5 strides: 8 strides: 16 - strides: 16 - strides: 16 + strides: 32 + strides: 32 + strides: 32 aspect_ratios: 1.0 fixed_anchor_size: true } @@ -108,7 +109,7 @@ node { options: { [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { num_classes: 1 - num_boxes: 896 + num_boxes: 2254 num_coords: 12 box_coord_offset: 0 keypoint_coord_offset: 4 @@ -117,10 +118,10 @@ node { sigmoid_score: true score_clipping_thresh: 100.0 reverse_output_order: true - x_scale: 128.0 - y_scale: 128.0 - h_scale: 128.0 - w_scale: 128.0 + x_scale: 224.0 + y_scale: 224.0 + h_scale: 224.0 + w_scale: 224.0 min_score_thresh: 0.5 } } diff --git a/mediapipe/modules/pose_landmark/BUILD b/mediapipe/modules/pose_landmark/BUILD index a8d4008ff..90edbb8a0 100644 --- a/mediapipe/modules/pose_landmark/BUILD +++ b/mediapipe/modules/pose_landmark/BUILD @@ -48,8 +48,8 @@ mediapipe_simple_subgraph( "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", "//mediapipe/calculators/util:landmark_letterbox_removal_calculator", "//mediapipe/calculators/util:landmark_projection_calculator", + "//mediapipe/calculators/util:refine_landmarks_from_heatmap_calculator", "//mediapipe/calculators/util:thresholding_calculator", - "//mediapipe/framework/tool:switch_container", ], ) @@ -68,8 +68,8 @@ mediapipe_simple_subgraph( "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", "//mediapipe/calculators/util:landmark_letterbox_removal_calculator", "//mediapipe/calculators/util:landmark_projection_calculator", + "//mediapipe/calculators/util:refine_landmarks_from_heatmap_calculator", "//mediapipe/calculators/util:thresholding_calculator", - "//mediapipe/framework/tool:switch_container", ], ) @@ -124,30 +124,11 @@ mediapipe_simple_subgraph( ], ) -mediapipe_simple_subgraph( - name = "pose_landmark_upper_body_gpu", - graph = "pose_landmark_upper_body_gpu.pbtxt", - register_as = "PoseLandmarkUpperBodyGpu", - deps = [ - ":pose_landmark_gpu", - "//mediapipe/calculators/core:constant_side_packet_calculator", - ], -) - -mediapipe_simple_subgraph( - name = "pose_landmark_upper_body_cpu", - graph = "pose_landmark_upper_body_cpu.pbtxt", - register_as = "PoseLandmarkUpperBodyCpu", - deps = [ - ":pose_landmark_cpu", - "//mediapipe/calculators/core:constant_side_packet_calculator", - ], -) - exports_files( srcs = [ - "pose_landmark_full_body.tflite", - "pose_landmark_upper_body.tflite", + "pose_landmark_full.tflite", + "pose_landmark_heavy.tflite", + "pose_landmark_lite.tflite", ], ) @@ -158,7 +139,6 @@ mediapipe_simple_subgraph( deps = [ "//mediapipe/calculators/util:alignment_points_to_rects_calculator", "//mediapipe/calculators/util:rect_transformation_calculator", - "//mediapipe/framework/tool:switch_container", ], ) diff --git a/mediapipe/modules/pose_landmark/README.md b/mediapipe/modules/pose_landmark/README.md index 35f4062aa..57528382a 100644 --- a/mediapipe/modules/pose_landmark/README.md +++ 
b/mediapipe/modules/pose_landmark/README.md @@ -2,9 +2,7 @@ Subgraphs|Details :--- | :--- -[`PoseLandmarkByRoiCpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_by_roi_cpu.pbtxt)| Detects landmarks of a single body pose, full-body by default but can be configured (via an input side packet) to cover upper-body only. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body_topology.svg). (CPU input, and inference is executed on CPU.) -[`PoseLandmarkByRoiGpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_by_roi_gpu.pbtxt)| Detects landmarks of a single body pose, full-body by default but can be configured (via an input side packet) to cover upper-body only. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body_topology.svg). (GPU input, and inference is executed on GPU) -[`PoseLandmarkCpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt)| Detects landmarks of a single body pose, full-body by default but can be configured (via an input side packet) to cover upper-body only. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body_topology.svg). (CPU input, and inference is executed on CPU) -[`PoseLandmarkGpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt)| Detects landmarks of a single body pose, full-body by default but can be configured (via an input side packet) to cover upper-body only. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body_topology.svg). (GPU input, and inference is executed on GPU.) -[`PoseLandmarkUpperBodyCpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body_cpu.pbtxt)| Detects and tracks landmarks of a single upper-body pose. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body_topology.svg). (CPU input, and inference is executed on CPU) -[`PoseLandmarkUpperBodyGpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body_gpu.pbtxt)| Detects and tracks landmarks of a single upper-body pose. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body_topology.svg). (GPU input, and inference is executed on GPU.) +[`PoseLandmarkByRoiCpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_by_roi_cpu.pbtxt)| Detects landmarks of a single body pose. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_topology.svg). (CPU input, and inference is executed on CPU.) +[`PoseLandmarkByRoiGpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_by_roi_gpu.pbtxt)| Detects landmarks of a single body pose. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_topology.svg). 
(GPU input, and inference is executed on GPU) +[`PoseLandmarkCpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt)| Detects landmarks of a single body pose. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_topology.svg). (CPU input, and inference is executed on CPU) +[`PoseLandmarkGpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt)| Detects landmarks of a single body pose. See landmarks (aka keypoints) [scheme](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_topology.svg). (GPU input, and inference is executed on GPU.) diff --git a/mediapipe/modules/pose_landmark/pose_detection_to_roi.pbtxt b/mediapipe/modules/pose_landmark/pose_detection_to_roi.pbtxt index b348837cb..47f82bbaf 100644 --- a/mediapipe/modules/pose_landmark/pose_detection_to_roi.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_detection_to_roi.pbtxt @@ -9,44 +9,22 @@ type: "PoseDetectionToRoi" input_stream: "DETECTION:detection" # Frame size (width and height). (std::pair) input_stream: "IMAGE_SIZE:image_size" -# Whether to detect/predict the full set of pose landmarks, or only those on the -# upper body. If unspecified, functions as set to false. (bool) -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" # ROI according to the first detection of input detections. (NormalizedRect) output_stream: "ROI:roi" # Converts pose detection into a rectangle based on center and scale alignment -# points. Pose detection contains four key points: first two for full-body pose -# and two more for upper-body pose. +# points. node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "AlignmentPointsRectsCalculator" input_stream: "DETECTION:detection" input_stream: "IMAGE_SIZE:image_size" output_stream: "NORM_RECT:raw_roi" - options { - [mediapipe.SwitchContainerOptions.ext] { - contained_node: { - calculator: "AlignmentPointsRectsCalculator" - options: { - [mediapipe.DetectionsToRectsCalculatorOptions.ext] { - rotation_vector_start_keypoint_index: 0 - rotation_vector_end_keypoint_index: 1 - rotation_vector_target_angle_degrees: 90 - } - } - } - contained_node: { - calculator: "AlignmentPointsRectsCalculator" - options: { - [mediapipe.DetectionsToRectsCalculatorOptions.ext] { - rotation_vector_start_keypoint_index: 2 - rotation_vector_end_keypoint_index: 3 - rotation_vector_target_angle_degrees: 90 - } - } - } + options: { + [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + rotation_vector_start_keypoint_index: 0 + rotation_vector_end_keypoint_index: 1 + rotation_vector_target_angle_degrees: 90 } } } @@ -59,8 +37,8 @@ node { output_stream: "roi" options: { [mediapipe.RectTransformationCalculatorOptions.ext] { - scale_x: 1.5 - scale_y: 1.5 + scale_x: 1.25 + scale_y: 1.25 square_long: true } } diff --git a/mediapipe/modules/pose_landmark/pose_landmark_by_roi_cpu.pbtxt b/mediapipe/modules/pose_landmark/pose_landmark_by_roi_cpu.pbtxt index 3ff3d9897..c4527f95c 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_by_roi_cpu.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmark_by_roi_cpu.pbtxt @@ -1,18 +1,18 @@ -# MediaPipe graph to detect/predict upper-body pose landmarks. (CPU input, and -# inference is executed on CPU.) +# MediaPipe graph to detect/predict pose landmarks. (CPU input, and inference is +# executed on CPU.) 
# -# It is required that "pose_landmark_full_body.tflite" or -# "pose_landmark_upper_body.tflite" is available at -# "mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite" -# or -# "mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite" +# It is required that "pose_landmark_lite.tflite" or +# "pose_landmark_full.tflite" or "pose_landmark_heavy.tflite" is available at +# "mediapipe/modules/pose_landmark/pose_landmark_lite.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_full.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite" # path respectively during execution, depending on the specification in the -# UPPER_BODY_ONLY input side packet. +# MODEL_COMPLEXITY input side packet. # # EXAMPLE: # node { # calculator: "PoseLandmarkByRoiCpu" -# input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# input_side_packet: "MODEL_COMPLEXITY:model_complexity" # input_stream: "IMAGE:image" # input_stream: "ROI:roi" # output_stream: "LANDMARKS:landmarks" @@ -26,16 +26,14 @@ input_stream: "IMAGE:image" # (NormalizedRect) input_stream: "ROI:roi" -# Whether to detect/predict the full set of pose landmarks (see below), or only -# those on the upper body. If unspecified, functions as set to false. (bool) -# Note that upper-body-only prediction may be more accurate for use cases where -# the lower-body parts are mostly out of view. -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# Complexity of the pose landmark model: 0, 1 or 2. Landmark accuracy as well as +# inference latency generally go up with the model complexity. If unspecified, +# functions as set to 0. (int) +input_side_packet: "MODEL_COMPLEXITY:model_complexity" # Pose landmarks within the given ROI. (NormalizedLandmarkList) -# We have 33 landmarks (see pose_landmark_full_body_topology.svg) with the -# first 25 fall on the upper body (see pose_landmark_upper_body_topology.svg), -# and there are other auxiliary key points. +# We have 33 landmarks (see pose_landmark_topology.svg) and there are other +# auxiliary key points. # 0 - nose # 1 - left eye (inner) # 2 - left eye @@ -104,7 +102,7 @@ node: { # Loads the pose landmark TF Lite model. node { calculator: "PoseLandmarkModelLoader" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" output_side_packet: "MODEL:model" } @@ -128,10 +126,12 @@ node { input_stream: "output_tensors" output_stream: "landmark_tensors" output_stream: "pose_flag_tensor" + output_stream: "heatmap_tensor" options: { [mediapipe.SplitVectorCalculatorOptions.ext] { ranges: { begin: 0 end: 1 } ranges: { begin: 1 end: 2 } + ranges: { begin: 3 end: 4 } } } } @@ -168,36 +168,29 @@ node { # Decodes the landmark tensors into a vector of landmarks, where the landmark # coordinates are normalized by the size of the input image to the model. 
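The third tensor split out above is a per-landmark heatmap, consumed by the `RefineLandmarksFromHeatmapCalculator` node added further below with `kernel_size: 7`. The general idea is to nudge each regressed landmark toward the probability mass of its heatmap channel inside a small window; the following NumPy sketch is illustrative only, not the calculator's exact implementation, and the tensor layout is assumed.

```python
import numpy as np

def refine_from_heatmap(landmarks, heatmap, kernel_size=7):
    """landmarks: [(x, y), ...] normalized to [0, 1]; heatmap: (H, W, num_landmarks)."""
    h, w, _ = heatmap.shape
    r = kernel_size // 2
    refined = []
    for i, (x, y) in enumerate(landmarks):
        cx, cy = int(x * w), int(y * h)
        x0, x1 = max(cx - r, 0), min(cx + r + 1, w)
        y0, y1 = max(cy - r, 0), min(cy + r + 1, h)
        # Sigmoid scores of the heatmap patch act as weights for a local average.
        weights = 1.0 / (1.0 + np.exp(-heatmap[y0:y1, x0:x1, i]))
        ys, xs = np.mgrid[y0:y1, x0:x1]
        refined.append(((weights * xs).sum() / weights.sum() / w,
                        (weights * ys).sum() / weights.sum() / h))
    return refined
```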
node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "TensorsToLandmarksCalculator" input_stream: "TENSORS:ensured_landmark_tensors" output_stream: "NORM_LANDMARKS:raw_landmarks" options: { - [mediapipe.SwitchContainerOptions.ext] { - contained_node: { - calculator: "TensorsToLandmarksCalculator" - options: { - [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { - num_landmarks: 35 - input_image_width: 256 - input_image_height: 256 - visibility_activation: SIGMOID - presence_activation: SIGMOID - } - } - } - contained_node: { - calculator: "TensorsToLandmarksCalculator" - options: { - [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { - num_landmarks: 27 - input_image_width: 256 - input_image_height: 256 - visibility_activation: SIGMOID - presence_activation: SIGMOID - } - } - } + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 39 + input_image_width: 256 + input_image_height: 256 + visibility_activation: SIGMOID + presence_activation: SIGMOID + } + } +} + +# Refines landmarks with the heatmap tensor. +node { + calculator: "RefineLandmarksFromHeatmapCalculator" + input_stream: "NORM_LANDMARKS:raw_landmarks" + input_stream: "TENSORS:heatmap_tensor" + output_stream: "NORM_LANDMARKS:refined_landmarks" + options: { + [mediapipe.RefineLandmarksFromHeatmapCalculatorOptions.ext] { + kernel_size: 7 } } } @@ -208,7 +201,7 @@ node { # image before image transformation). node { calculator: "LandmarkLetterboxRemovalCalculator" - input_stream: "LANDMARKS:raw_landmarks" + input_stream: "LANDMARKS:refined_landmarks" input_stream: "LETTERBOX_PADDING:letterbox_padding" output_stream: "LANDMARKS:adjusted_landmarks" } @@ -225,31 +218,14 @@ node { # Splits the landmarks into two sets: the actual pose landmarks and the # auxiliary landmarks. node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "SplitNormalizedLandmarkListCalculator" input_stream: "all_landmarks" output_stream: "landmarks" output_stream: "auxiliary_landmarks" options: { - [mediapipe.SwitchContainerOptions.ext] { - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - options: { - [mediapipe.SplitVectorCalculatorOptions.ext] { - ranges: { begin: 0 end: 33 } - ranges: { begin: 33 end: 35 } - } - } - } - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - options: { - [mediapipe.SplitVectorCalculatorOptions.ext] { - ranges: { begin: 0 end: 25 } - ranges: { begin: 25 end: 27 } - } - } - } + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 33 } + ranges: { begin: 33 end: 35 } } } } diff --git a/mediapipe/modules/pose_landmark/pose_landmark_by_roi_gpu.pbtxt b/mediapipe/modules/pose_landmark/pose_landmark_by_roi_gpu.pbtxt index 75ede0774..0ffa50f77 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_by_roi_gpu.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmark_by_roi_gpu.pbtxt @@ -1,18 +1,18 @@ -# MediaPipe graph to detect/predict upper-body pose landmarks. (GPU input, and -# inference is executed on GPU.) +# MediaPipe graph to detect/predict pose landmarks. (GPU input, and inference is +# executed on GPU.) 
# -# It is required that "pose_landmark_full_body.tflite" or -# "pose_landmark_upper_body.tflite" is available at -# "mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite" -# or -# "mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite" +# It is required that "pose_landmark_lite.tflite" or +# "pose_landmark_full.tflite" or "pose_landmark_heavy.tflite" is available at +# "mediapipe/modules/pose_landmark/pose_landmark_lite.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_full.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite" # path respectively during execution, depending on the specification in the -# UPPER_BODY_ONLY input side packet. +# MODEL_COMPLEXITY input side packet. # # EXAMPLE: # node { # calculator: "PoseLandmarkByRoiGpu" -# input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# input_side_packet: "MODEL_COMPLEXITY:model_complexity" # input_stream: "IMAGE:image" # input_stream: "ROI:roi" # output_stream: "LANDMARKS:landmarks" @@ -26,16 +26,14 @@ input_stream: "IMAGE:image" # (NormalizedRect) input_stream: "ROI:roi" -# Whether to detect/predict the full set of pose landmarks (see below), or only -# those on the upper body. If unspecified, functions as set to false. (bool) -# Note that upper-body-only prediction may be more accurate for use cases where -# the lower-body parts are mostly out of view. -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# Complexity of the pose landmark model: 0, 1 or 2. Landmark accuracy as well as +# inference latency generally go up with the model complexity. If unspecified, +# functions as set to 0. (int) +input_side_packet: "MODEL_COMPLEXITY:model_complexity" # Pose landmarks within the given ROI. (NormalizedLandmarkList) -# We have 33 landmarks (see pose_landmark_full_body_topology.svg) with the -# first 25 fall on the upper body (see pose_landmark_upper_body_topology.svg), -# and there are other auxiliary key points. +# We have 33 landmarks (see pose_landmark_topology.svg), and there are other +# auxiliary key points. # 0 - nose # 1 - left eye (inner) # 2 - left eye @@ -105,7 +103,7 @@ node: { # Loads the pose landmark TF Lite model. node { calculator: "PoseLandmarkModelLoader" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" output_side_packet: "MODEL:model" } @@ -115,6 +113,15 @@ node { input_side_packet: "MODEL:model" input_stream: "TENSORS:input_tensors" output_stream: "TENSORS:output_tensors" + options: { + [mediapipe.InferenceCalculatorOptions.ext] { + delegate { + gpu { + allow_precision_loss: false + } + } + } + } } # Splits a vector of TFLite tensors to multiple vectors according to the ranges @@ -124,10 +131,12 @@ node { input_stream: "output_tensors" output_stream: "landmark_tensors" output_stream: "pose_flag_tensor" + output_stream: "heatmap_tensor" options: { [mediapipe.SplitVectorCalculatorOptions.ext] { ranges: { begin: 0 end: 1 } ranges: { begin: 1 end: 2 } + ranges: { begin: 3 end: 4 } } } } @@ -164,36 +173,29 @@ node { # Decodes the landmark tensors into a vector of landmarks, where the landmark # coordinates are normalized by the size of the input image to the model. 
node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "TensorsToLandmarksCalculator" input_stream: "TENSORS:ensured_landmark_tensors" output_stream: "NORM_LANDMARKS:raw_landmarks" options: { - [mediapipe.SwitchContainerOptions.ext] { - contained_node: { - calculator: "TensorsToLandmarksCalculator" - options: { - [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { - num_landmarks: 35 - input_image_width: 256 - input_image_height: 256 - visibility_activation: SIGMOID - presence_activation: SIGMOID - } - } - } - contained_node: { - calculator: "TensorsToLandmarksCalculator" - options: { - [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { - num_landmarks: 27 - input_image_width: 256 - input_image_height: 256 - visibility_activation: SIGMOID - presence_activation: SIGMOID - } - } - } + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 39 + input_image_width: 256 + input_image_height: 256 + visibility_activation: SIGMOID + presence_activation: SIGMOID + } + } +} + +# Refines landmarks with the heatmap tensor. +node { + calculator: "RefineLandmarksFromHeatmapCalculator" + input_stream: "NORM_LANDMARKS:raw_landmarks" + input_stream: "TENSORS:heatmap_tensor" + output_stream: "NORM_LANDMARKS:refined_landmarks" + options: { + [mediapipe.RefineLandmarksFromHeatmapCalculatorOptions.ext] { + kernel_size: 7 } } } @@ -204,7 +206,7 @@ node { # image before image transformation). node { calculator: "LandmarkLetterboxRemovalCalculator" - input_stream: "LANDMARKS:raw_landmarks" + input_stream: "LANDMARKS:refined_landmarks" input_stream: "LETTERBOX_PADDING:letterbox_padding" output_stream: "LANDMARKS:adjusted_landmarks" } @@ -221,31 +223,14 @@ node { # Splits the landmarks into two sets: the actual pose landmarks and the # auxiliary landmarks. node { - calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + calculator: "SplitNormalizedLandmarkListCalculator" input_stream: "all_landmarks" output_stream: "landmarks" output_stream: "auxiliary_landmarks" options: { - [mediapipe.SwitchContainerOptions.ext] { - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - options: { - [mediapipe.SplitVectorCalculatorOptions.ext] { - ranges: { begin: 0 end: 33 } - ranges: { begin: 33 end: 35 } - } - } - } - contained_node: { - calculator: "SplitNormalizedLandmarkListCalculator" - options: { - [mediapipe.SplitVectorCalculatorOptions.ext] { - ranges: { begin: 0 end: 25 } - ranges: { begin: 25 end: 27 } - } - } - } + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 33 } + ranges: { begin: 33 end: 35 } } } } diff --git a/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt b/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt index a2da957c6..78513ca70 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt @@ -6,18 +6,18 @@ # "mediapipe/modules/pose_detection/pose_detection.tflite" # path during execution. 
# -# It is required that "pose_landmark_full_body.tflite" or -# "pose_landmark_upper_body.tflite" is available at -# "mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite" -# or -# "mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite" +# It is required that "pose_landmark_lite.tflite" or +# "pose_landmark_full.tflite" or "pose_landmark_heavy.tflite" is available at +# "mediapipe/modules/pose_landmark/pose_landmark_lite.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_full.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite" # path respectively during execution, depending on the specification in the -# UPPER_BODY_ONLY input side packet. +# MODEL_COMPLEXITY input side packet. # # EXAMPLE: # node { # calculator: "PoseLandmarkCpu" -# input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# input_side_packet: "MODEL_COMPLEXITY:model_complexity" # input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" # input_stream: "IMAGE:image" # output_stream: "LANDMARKS:pose_landmarks" @@ -28,20 +28,18 @@ type: "PoseLandmarkCpu" # CPU image. (ImageFrame) input_stream: "IMAGE:image" -# Whether to detect/predict the full set of pose landmarks (see below), or only -# those on the upper body. If unspecified, functions as set to false. (bool) -# Note that upper-body-only prediction may be more accurate for use cases where -# the lower-body parts are mostly out of view. -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" - # Whether to filter landmarks across different input images to reduce jitter. # If unspecified, functions as set to false. (bool) input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" +# Complexity of the pose landmark model: 0, 1 or 2. Landmark accuracy as well as +# inference latency generally go up with the model complexity. If unspecified, +# functions as set to 0. (int) +input_side_packet: "MODEL_COMPLEXITY:model_complexity" + # Pose landmarks within the given ROI. (NormalizedLandmarkList) -# We have 33 landmarks (see pose_landmark_full_body_topology.svg) with the -# first 25 fall on the upper body (see pose_landmark_upper_body_topology.svg), -# and there are other auxiliary key points. +# We have 33 landmarks (see pose_landmark_topology.svg), and there are other +# auxiliary key points. # 0 - nose # 1 - left eye (inner) # 2 - left eye @@ -164,7 +162,6 @@ node { # to detect landmarks. node { calculator: "PoseDetectionToRoi" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" input_stream: "DETECTION:pose_detection" input_stream: "IMAGE_SIZE:image_size_for_pose_detection" output_stream: "ROI:pose_rect_from_detection" @@ -183,7 +180,7 @@ node { # Detects pose landmarks within specified region of interest of the image. 
node { calculator: "PoseLandmarkByRoiCpu" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_stream: "IMAGE:image" input_stream: "ROI:pose_rect" output_stream: "LANDMARKS:unfiltered_pose_landmarks" diff --git a/mediapipe/modules/pose_landmark/pose_landmark_full.tflite b/mediapipe/modules/pose_landmark/pose_landmark_full.tflite new file mode 100755 index 000000000..922be2044 Binary files /dev/null and b/mediapipe/modules/pose_landmark/pose_landmark_full.tflite differ diff --git a/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite b/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite deleted file mode 100755 index 713130c2e..000000000 Binary files a/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite and /dev/null differ diff --git a/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt b/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt index 72ed00f5f..4acd5dc59 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt @@ -6,18 +6,18 @@ # "mediapipe/modules/pose_detection/pose_detection.tflite" # path during execution. # -# It is required that "pose_landmark_full_body.tflite" or -# "pose_landmark_upper_body.tflite" is available at -# "mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite" -# or -# "mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite" +# It is required that "pose_landmark_lite.tflite" or +# "pose_landmark_full.tflite" or "pose_landmark_heavy.tflite" is available at +# "mediapipe/modules/pose_landmark/pose_landmark_lite.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_full.tflite" or +# "mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite" # path respectively during execution, depending on the specification in the -# UPPER_BODY_ONLY input side packet. +# MODEL_COMPLEXITY input side packet. # # EXAMPLE: # node { # calculator: "PoseLandmarkGpu" -# input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# input_side_packet: "MODEL_COMPLEXITY:model_complexity" # input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" # input_stream: "IMAGE:image" # output_stream: "LANDMARKS:pose_landmarks" @@ -28,20 +28,18 @@ type: "PoseLandmarkGpu" # GPU image. (GpuBuffer) input_stream: "IMAGE:image" -# Whether to detect/predict the full set of pose landmarks (see below), or only -# those on the upper body. If unspecified, functions as set to false. (bool) -# Note that upper-body-only prediction may be more accurate for use cases where -# the lower-body parts are mostly out of view. -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" - # Whether to filter landmarks across different input images to reduce jitter. # If unspecified, functions as set to false. (bool) input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" +# Complexity of the pose landmark model: 0, 1 or 2. Landmark accuracy as well as +# inference latency generally go up with the model complexity. If unspecified, +# functions as set to 0. (int) +input_side_packet: "MODEL_COMPLEXITY:model_complexity" + # Pose landmarks within the given ROI. (NormalizedLandmarkList) -# We have 33 landmarks (see pose_landmark_full_body_topology.svg) with the -# first 25 fall on the upper body (see pose_landmark_upper_body_topology.svg), -# and there are other auxiliary key points. +# We have 33 landmarks (see pose_landmark_topology.svg), and there are other +# auxiliary key points. 
# 0 - nose # 1 - left eye (inner) # 2 - left eye @@ -164,7 +162,6 @@ node { # to detect landmarks. node { calculator: "PoseDetectionToRoi" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" input_stream: "DETECTION:pose_detection" input_stream: "IMAGE_SIZE:image_size_for_pose_detection" output_stream: "ROI:pose_rect_from_detection" @@ -183,7 +180,7 @@ node { # Detects pose landmarks within specified region of interest of the image. node { calculator: "PoseLandmarkByRoiGpu" - input_side_packet: "UPPER_BODY_ONLY:upper_body_only" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_stream: "IMAGE:image" input_stream: "ROI:pose_rect" output_stream: "LANDMARKS:unfiltered_pose_landmarks" diff --git a/mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite b/mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite new file mode 100755 index 000000000..e72fc0369 Binary files /dev/null and b/mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite differ diff --git a/mediapipe/modules/pose_landmark/pose_landmark_lite.tflite b/mediapipe/modules/pose_landmark/pose_landmark_lite.tflite new file mode 100755 index 000000000..54e217be7 Binary files /dev/null and b/mediapipe/modules/pose_landmark/pose_landmark_lite.tflite differ diff --git a/mediapipe/modules/pose_landmark/pose_landmark_model_loader.pbtxt b/mediapipe/modules/pose_landmark/pose_landmark_model_loader.pbtxt index d40cf8416..d5a912b6d 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_model_loader.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmark_model_loader.pbtxt @@ -2,18 +2,19 @@ type: "PoseLandmarkModelLoader" -# Whether to load the full-body landmark model or the upper-body on. (bool) -input_side_packet: "UPPER_BODY_ONLY:upper_body_only" +# Complexity of the pose landmark model: 0, 1 or 2. Landmark accuracy as well as +# inference latency generally go up with the model complexity. If unspecified, +# functions as set to 0. (int) +input_side_packet: "MODEL_COMPLEXITY:model_complexity" # TF Lite model represented as a FlatBuffer. # (std::unique_ptr>) output_side_packet: "MODEL:model" -# Determines path to the desired pose landmark model file based on specification -# in the input side packet. +# Determines path to the desired pose landmark model file. 
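In the SwitchContainer below, the `SELECT:model_complexity` side packet picks the contained node by index, so the complexity values map to model files as summarized in this sketch (an editorial summary of the pbtxt that follows, not part of the patch):

```python
# Index -> model file chosen by the SwitchContainer in pose_landmark_model_loader.pbtxt.
POSE_LANDMARK_MODEL_BY_COMPLEXITY = {
    0: "mediapipe/modules/pose_landmark/pose_landmark_lite.tflite",
    1: "mediapipe/modules/pose_landmark/pose_landmark_full.tflite",
    2: "mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite",
}
```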
node { calculator: "SwitchContainer" - input_side_packet: "ENABLE:upper_body_only" + input_side_packet: "SELECT:model_complexity" output_side_packet: "PACKET:model_path" options: { [mediapipe.SwitchContainerOptions.ext] { @@ -22,7 +23,7 @@ node { options: { [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { packet { - string_value: "mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite" + string_value: "mediapipe/modules/pose_landmark/pose_landmark_lite.tflite" } } } @@ -32,7 +33,17 @@ node { options: { [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { packet { - string_value: "mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite" + string_value: "mediapipe/modules/pose_landmark/pose_landmark_full.tflite" + } + } + } + } + contained_node: { + calculator: "ConstantSidePacketCalculator" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { + string_value: "mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite" } } } diff --git a/mediapipe/modules/pose_landmark/pose_landmark_full_body_topology.svg b/mediapipe/modules/pose_landmark/pose_landmark_topology.svg similarity index 99% rename from mediapipe/modules/pose_landmark/pose_landmark_full_body_topology.svg rename to mediapipe/modules/pose_landmark/pose_landmark_topology.svg index bc4afa734..a57269ded 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_full_body_topology.svg +++ b/mediapipe/modules/pose_landmark/pose_landmark_topology.svg @@ -11,7 +11,7 @@ width="1000" height="1400" viewBox="0 0 1000 1400" - sodipodi:docname="pose_landmark_full_body_topology.svg"> + sodipodi:docname="pose_landmark_topology.svg"> - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 6 - 5 - 4 - 1 - 2 - 3 - 0 - 8 - 7 - 10 - 9 - 12 - 11 - 21 - 22 - 20 - 18 - 16 - 14 - 13 - 15 - 17 - 19 - 23 - 24 - - diff --git a/mediapipe/modules/pose_landmark/pose_landmarks_to_roi.pbtxt b/mediapipe/modules/pose_landmark/pose_landmarks_to_roi.pbtxt index 3d7fd28b2..b1fe0e3be 100644 --- a/mediapipe/modules/pose_landmark/pose_landmarks_to_roi.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmarks_to_roi.pbtxt @@ -43,8 +43,8 @@ node { output_stream: "roi" options: { [mediapipe.RectTransformationCalculatorOptions.ext] { - scale_x: 1.5 - scale_y: 1.5 + scale_x: 1.25 + scale_y: 1.25 square_long: true } } diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index a63f618c5..e622aba6d 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -72,6 +72,7 @@ objc_library( ":util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:mediapipe_profiling", + "//mediapipe/framework/formats:image", "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:source_location", diff --git a/mediapipe/objc/MPPGraph.h b/mediapipe/objc/MPPGraph.h index 06bf1552b..c9c06cd36 100644 --- a/mediapipe/objc/MPPGraph.h +++ b/mediapipe/objc/MPPGraph.h @@ -59,16 +59,21 @@ typedef NS_ENUM(int, MPPPacketType) { /// Calls mediapipeGraph:didOutputPacket:fromStream: MPPPacketTypeRaw, - /// CFHolder. + /// GpuBuffer packet. /// Calls mediapipeGraph:didOutputPixelBuffer:fromStream: /// Use this packet type to pass GPU frames to calculators. MPPPacketTypePixelBuffer, - /// ImageFrame. + /// Image packet. + /// Calls mediapipeGraph:didOutputPixelBuffer:fromStream: + /// Use this packet type to pass GPU frames to calculators. + MPPPacketTypeImage, + + /// ImageFrame packet. 
/// Calls mediapipeGraph:didOutputPixelBuffer:fromStream: MPPPacketTypeImageFrame, - /// RGBA ImageFrame, but do not swap the channels if the input pixel buffer + /// RGBA ImageFrame packet, but do not swap the channels if the input pixel buffer /// is BGRA. This is useful when the graph needs RGBA ImageFrames, but the /// calculators do not care about the order of the channels, so BGRA data can /// be used as-is. @@ -164,6 +169,9 @@ typedef NS_ENUM(int, MPPPacketType) { - (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)pixelBuffer packetType:(MPPPacketType)packetType; +/// Creates a MediaPipe packet of type Image, wrapping the given CVPixelBufferRef. +- (mediapipe::Packet)imagePacketWithPixelBuffer:(CVPixelBufferRef)pixelBuffer; + /// Sends a pixel buffer into a graph input stream, using the specified packet /// type. The graph must have been started before calling this. Drops frames and /// returns NO if maxFramesInFlight is exceeded. If allowOverwrite is set to YES, diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 04a96bb1b..bc9eff69f 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -21,6 +21,7 @@ #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/graph_service.h" #include "mediapipe/gpu/MPPGraphGPUData.h" @@ -163,10 +164,15 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, _GTMDevLog(@"unsupported ImageFormat: %d", format); } #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - } else if (packetType == MPPPacketTypePixelBuffer) { + } else if (packetType == MPPPacketTypePixelBuffer || + packetType == MPPPacketTypeImage) { wrapper->_framesInFlight--; - CVPixelBufferRef pixelBuffer = packet.Get().GetCVPixelBufferRef(); - if ([wrapper.delegate + CVPixelBufferRef pixelBuffer; + if (packetType == MPPPacketTypePixelBuffer) + pixelBuffer = packet.Get().GetCVPixelBufferRef(); + else + pixelBuffer = packet.Get().GetCVPixelBufferRef(); +if ([wrapper.delegate respondsToSelector:@selector (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) { [wrapper.delegate mediapipeGraph:wrapper @@ -315,6 +321,16 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } else if (packetType == MPPPacketTypePixelBuffer) { packet = mediapipe::MakePacket(imageBuffer); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } else if (packetType == MPPPacketTypeImage) { +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + // GPU + packet = mediapipe::MakePacket(imageBuffer); +#else + // CPU + auto frame = CreateImageFrameForCVPixelBuffer(imageBuffer, /* canOverwrite = */ false, + /* bgrAsRgb = */ false); + packet = mediapipe::MakePacket(std::move(frame)); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } else { _GTMDevLog(@"unsupported packet type: %d", packetType); @@ -322,6 +338,10 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, return packet; } +- (mediapipe::Packet)imagePacketWithPixelBuffer:(CVPixelBufferRef)pixelBuffer { + return [self packetWithPixelBuffer:(pixelBuffer) packetType:(MPPPacketTypeImage)]; +} + - (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer intoStream:(const std::string&)inputName packetType:(MPPPacketType)packetType diff --git a/mediapipe/objc/MPPGraphTests.mm b/mediapipe/objc/MPPGraphTests.mm index 7c1ea8e06..c3cf48047 100644 --- 
a/mediapipe/objc/MPPGraphTests.mm +++ b/mediapipe/objc/MPPGraphTests.mm @@ -16,6 +16,7 @@ #import #include "absl/memory/memory.h" +#include "mediapipe/framework/formats/image.h" #import "mediapipe/objc/MPPGraph.h" #import "mediapipe/objc/MPPGraphTestBase.h" #import "mediapipe/objc/NSError+util_status.h" @@ -333,4 +334,21 @@ REGISTER_CALCULATOR(ErrorCalculator); [self waitForExpectationsWithTimeout:3.0 handler:NULL]; } +- (void)testPixelBufferToImage { + CFHolder pixelBufferIn; + absl::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &pixelBufferIn); + XCTAssert(status.ok()); + + mediapipe::CalculatorGraphConfig config; + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + + mediapipe::Packet packet = [_graph imagePacketWithPixelBuffer:*pixelBufferIn]; + CVPixelBufferRef pixelBufferOut = packet.Get().GetCVPixelBufferRef(); + + XCTAssertTrue([self pixelBuffer:*pixelBufferIn + isCloseTo:pixelBufferOut + maxLocalDifference:0 + maxAverageDifference:0]); +} + @end diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 3bd188de4..08a299589 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -34,6 +34,7 @@ pybind_extension( deps = [ ":builtin_calculators", "//mediapipe/python/pybind:calculator_graph", + "//mediapipe/python/pybind:image", "//mediapipe/python/pybind:image_frame", "//mediapipe/python/pybind:matrix", "//mediapipe/python/pybind:packet", diff --git a/mediapipe/python/__init__.py b/mediapipe/python/__init__.py index 5cc9c8155..9f00ffbcd 100644 --- a/mediapipe/python/__init__.py +++ b/mediapipe/python/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 The MediaPipe Authors. +# Copyright 2020-2021 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ from mediapipe.python._framework_bindings import resource_util from mediapipe.python._framework_bindings.calculator_graph import CalculatorGraph from mediapipe.python._framework_bindings.calculator_graph import GraphInputStreamAddMode +from mediapipe.python._framework_bindings.image import Image from mediapipe.python._framework_bindings.image_frame import ImageFormat from mediapipe.python._framework_bindings.image_frame import ImageFrame from mediapipe.python._framework_bindings.matrix import Matrix diff --git a/mediapipe/python/framework_bindings.cc b/mediapipe/python/framework_bindings.cc index 0238fac47..d4022d9df 100644 --- a/mediapipe/python/framework_bindings.cc +++ b/mediapipe/python/framework_bindings.cc @@ -1,4 +1,4 @@ -// Copyright 2020 The MediaPipe Authors. +// Copyright 2020-2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ // limitations under the License. #include "mediapipe/python/pybind/calculator_graph.h" +#include "mediapipe/python/pybind/image.h" #include "mediapipe/python/pybind/image_frame.h" #include "mediapipe/python/pybind/matrix.h" #include "mediapipe/python/pybind/packet.h" @@ -27,6 +28,7 @@ namespace python { PYBIND11_MODULE(_framework_bindings, m) { ResourceUtilSubmodule(&m); + ImageSubmodule(&m); ImageFrameSubmodule(&m); MatrixSubmodule(&m); TimestampSubmodule(&m); diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py new file mode 100644 index 000000000..d942b5a72 --- /dev/null +++ b/mediapipe/python/image_test.py @@ -0,0 +1,183 @@ +# Copyright 2021 The MediaPipe Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for mediapipe.python._framework_bindings.image.""" + +import gc +import random +import sys +from absl.testing import absltest +import cv2 +import mediapipe as mp +import numpy as np +import PIL.Image + + +# TODO: Add unit tests specifically for memory management. +class ImageTest(absltest.TestCase): + + def test_create_image_from_gray_cv_mat(self): + w, h = random.randrange(3, 100), random.randrange(3, 100) + mat = cv2.cvtColor( + np.random.randint(2**8 - 1, size=(h, w, 3), dtype=np.uint8), + cv2.COLOR_RGB2GRAY) + mat[2, 2] = 42 + image = mp.Image(image_format=mp.ImageFormat.GRAY8, data=mat) + self.assertTrue(np.array_equal(mat, image.numpy_view())) + with self.assertRaisesRegex(IndexError, 'index dimension mismatch'): + print(image[w, h, 1]) + with self.assertRaisesRegex(IndexError, 'out of bounds'): + print(image[w, h]) + self.assertEqual(42, image[2, 2]) + + def test_create_image_from_rgb_cv_mat(self): + w, h, channels = random.randrange(3, 100), random.randrange(3, 100), 3 + mat = cv2.cvtColor( + np.random.randint(2**8 - 1, size=(h, w, channels), dtype=np.uint8), + cv2.COLOR_RGB2BGR) + mat[2, 2, 1] = 42 + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mat) + self.assertTrue(np.array_equal(mat, image.numpy_view())) + with self.assertRaisesRegex(IndexError, 'out of bounds'): + print(image[w, h, channels]) + self.assertEqual(42, image[2, 2, 1]) + + def test_create_image_from_rgb48_cv_mat(self): + w, h, channels = random.randrange(3, 100), random.randrange(3, 100), 3 + mat = cv2.cvtColor( + np.random.randint(2**16 - 1, size=(h, w, channels), dtype=np.uint16), + cv2.COLOR_RGB2BGR) + mat[2, 2, 1] = 42 + image = mp.Image(image_format=mp.ImageFormat.SRGB48, data=mat) + self.assertTrue(np.array_equal(mat, image.numpy_view())) + with self.assertRaisesRegex(IndexError, 'out of bounds'): + print(image[w, h, channels]) + self.assertEqual(42, image[2, 2, 1]) + + def test_create_image_from_gray_pil_image(self): + w, h = random.randrange(3, 100), random.randrange(3, 100) + img = PIL.Image.fromarray( + np.random.randint(2**8 - 1, size=(h, w), dtype=np.uint8), 'L') + image = mp.Image(image_format=mp.ImageFormat.GRAY8, data=np.asarray(img)) + self.assertTrue(np.array_equal(np.asarray(img), image.numpy_view())) + with self.assertRaisesRegex(IndexError, 'index dimension mismatch'): + print(image[w, h, 1]) + with self.assertRaisesRegex(IndexError, 'out of bounds'): + print(image[w, h]) + + def test_create_image_from_rgb_pil_image(self): + w, h, channels = random.randrange(3, 100), random.randrange(3, 100), 3 + img = PIL.Image.fromarray( + np.random.randint(2**8 - 1, size=(h, w, channels), dtype=np.uint8), + 'RGB') + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(img)) + self.assertTrue(np.array_equal(np.asarray(img), image.numpy_view())) + with self.assertRaisesRegex(IndexError, 'out of bounds'): + print(image[w, h, channels]) + + def test_create_image_from_rgba64_pil_image(self): + w, h, channels = random.randrange(3, 
100), random.randrange(3, 100), 4 + img = PIL.Image.fromarray( + np.random.randint(2**16 - 1, size=(h, w, channels), dtype=np.uint16), + 'RGBA') + image = mp.Image( + image_format=mp.ImageFormat.SRGBA64, + data=np.asarray(img, dtype=np.uint16)) + self.assertTrue(np.array_equal(np.asarray(img), image.numpy_view())) + with self.assertRaisesRegex(IndexError, 'out of bounds'): + print(image[1000, 1000, 1000]) + + def test_image_numpy_view(self): + w, h, channels = random.randrange(3, 100), random.randrange(3, 100), 3 + mat = cv2.cvtColor( + np.random.randint(2**8 - 1, size=(h, w, channels), dtype=np.uint8), + cv2.COLOR_RGB2BGR) + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mat) + output_ndarray = image.numpy_view() + self.assertTrue(np.array_equal(mat, image.numpy_view())) + # The output of numpy_view() is a reference to the internal data and it's + # unwritable after creation. + with self.assertRaisesRegex(ValueError, + 'assignment destination is read-only'): + output_ndarray[0, 0, 0] = 0 + copied_ndarray = np.copy(output_ndarray) + copied_ndarray[0, 0, 0] = 0 + + def test_cropped_gray8_image(self): + w, h = random.randrange(20, 100), random.randrange(20, 100) + channels, offset = 3, 10 + mat = cv2.cvtColor( + np.random.randint(2**8 - 1, size=(h, w, channels), dtype=np.uint8), + cv2.COLOR_RGB2GRAY) + image = mp.Image( + image_format=mp.ImageFormat.GRAY8, + data=np.ascontiguousarray(mat[offset:-offset, offset:-offset])) + self.assertTrue( + np.array_equal(mat[offset:-offset, offset:-offset], image.numpy_view())) + + def test_cropped_rgb_image(self): + w, h = random.randrange(20, 100), random.randrange(20, 100) + channels, offset = 3, 10 + mat = cv2.cvtColor( + np.random.randint(2**8 - 1, size=(h, w, channels), dtype=np.uint8), + cv2.COLOR_RGB2BGR) + image = mp.Image( + image_format=mp.ImageFormat.SRGB, + data=np.ascontiguousarray(mat[offset:-offset, offset:-offset, :])) + self.assertTrue( + np.array_equal(mat[offset:-offset, offset:-offset, :], + image.numpy_view())) + + # For image frames that store contiguous data, the output of numpy_view() + # points to the pixel data of the original image frame object. The life cycle + # of the data array should tie to the image frame object. + def test_image_numpy_view_with_contiguous_data(self): + w, h = 640, 480 + mat = np.random.randint(2**8 - 1, size=(h, w, 3), dtype=np.uint8) + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mat) + self.assertTrue(image.is_contiguous()) + initial_ref_count = sys.getrefcount(image) + self.assertTrue(np.array_equal(mat, image.numpy_view())) + # Get 2 data array objects and verify that the image frame's ref count is + # increased by 2. + np_view = image.numpy_view() + self.assertEqual(sys.getrefcount(image), initial_ref_count + 1) + np_view2 = image.numpy_view() + self.assertEqual(sys.getrefcount(image), initial_ref_count + 2) + del np_view + del np_view2 + gc.collect() + # After the two data array objects are destroyed, the current ref count + # should equal the initial ref count. + self.assertEqual(sys.getrefcount(image), initial_ref_count) + + # For image frames that store non contiguous data, the output of numpy_view() + # stores a copy of the pixel data of the image frame object. The life cycle of + # the data array doesn't tie to the image frame object.

+ def test_image_numpy_view_with_non_contiguous_data(self): + w, h = 641, 481 + mat = np.random.randint(2**8 - 1, size=(h, w, 3), dtype=np.uint8) + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mat) + self.assertFalse(image.is_contiguous()) + initial_ref_count = sys.getrefcount(image) + self.assertTrue(np.array_equal(mat, image.numpy_view())) + np_view = image.numpy_view() + self.assertEqual(sys.getrefcount(image), initial_ref_count) + del np_view + gc.collect() + self.assertEqual(sys.getrefcount(image), initial_ref_count) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/python/packet_creator.py b/mediapipe/python/packet_creator.py index 112b3f336..6d388a341 100644 --- a/mediapipe/python/packet_creator.py +++ b/mediapipe/python/packet_creator.py @@ -21,6 +21,7 @@ import numpy as np from google.protobuf import message from mediapipe.python._framework_bindings import _packet_creator +from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image_frame from mediapipe.python._framework_bindings import packet @@ -115,8 +116,11 @@ def create_image_frame(data: Union[image_frame.ImageFrame, np.ndarray], raise ValueError( 'The provided image_format doesn\'t match the one from the data arg.') if copy is not None and not copy: + # Taking a reference will make the created packet be mutable since the + # ImageFrame object can still be manipulated in Python, which voids packet + # immutability. raise ValueError( - 'Creating image frame packet by taking a reference of another image frame object is not supported yet.' + 'Creating ImageFrame packet by taking a reference of another ImageFrame object is not supported yet.' ) # pylint:disable=protected-access return _packet_creator._create_image_frame_from_image_frame(data) @@ -144,6 +148,104 @@ def create_image_frame(data: Union[image_frame.ImageFrame, np.ndarray], # pylint:enable=protected-access +def create_image(data: Union[image.Image, np.ndarray], + *, + image_format: image_frame.ImageFormat = None, + copy: bool = None) -> packet.Packet: + """Create a MediaPipe Image packet. + + A MediaPipe Image packet can be created from an existing MediaPipe + Image object and the data will be realigned and copied into a new + Image object inside of the packet. + + A MediaPipe Image packet can also be created from the raw pixel data + represented as a numpy array with one of the uint8, uint16, and float data + types. There are three data ownership modes depending on how the 'copy' arg + is set. + + i) Default mode + If copy is not set, mutable data is always copied while the immutable data + is by reference. + + ii) Copy mode (safe) + If copy is set to True, the data will be realigned and copied into an + Image object inside of the packet regardless the immutablity of the + original data. + + iii) Reference mode (dangerous) + If copy is set to False, the data will be forced to be shared. If the data is + mutable (data.flags.writeable is True), a warning will be raised. + + Args: + data: A MediaPipe Image object or the raw pixel data that is represnted as a + numpy ndarray. + image_format: One of the mp.ImageFormat enum types. + copy: Indicate if the packet should copy the data from the numpy nparray. + + Returns: + A MediaPipe Image Packet. + + Raises: + ValueError: + i) When "data" is a numpy ndarray, "image_format" is not provided or + the "data" array is not c_contiguous in the reference mode. 
+ ii) When "data" is an Image object, the "image_format" arg doesn't + match the image format of the "data" Image object or "copy" is + explicitly set to False. + TypeError: If "image format" doesn't match "data" array's data type. + + Examples: + np_array = np.random.randint(255, size=(321, 123, 3), dtype=np.uint8) + # Copy mode by default if the data array is writable. + image_packet = mp.packet_creator.create_image( + image_format=mp.ImageFormat.SRGB, data=np_array) + + # Make the array unwriteable to trigger the reference mode. + np_array.flags.writeable = False + image_packet = mp.packet_creator.create_image( + image_format=mp.ImageFormat.SRGB, data=np_array) + + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np_array) + image_packet = mp.packet_creator.create_image(image) + + """ + if isinstance(data, image.Image): + if image_format is not None and data.image_format != image_format: + raise ValueError( + 'The provided image_format doesn\'t match the one from the data arg.') + if copy is not None and not copy: + # Taking a reference will make the created packet be mutable since the + # Image object can still be manipulated in Python, which voids packet + # immutability. + raise ValueError( + 'Creating Image packet by taking a reference of another Image object is not supported yet.' + ) + # pylint:disable=protected-access + return _packet_creator._create_image_from_image(data) + # pylint:enable=protected-access + else: + if image_format is None: + raise ValueError('Please provide \'image_format\' with \'data\'.') + # If copy arg is not set, copying the data if it's immutable. Otherwise, + # take a reference of the immutable data to avoid data copy. + if copy is None: + copy = True if data.flags.writeable else False + if not copy: + # TODO: Investigate why the first 2 bytes of the data has data + # corruption when "data" is not c_contiguous. + if not data.flags.c_contiguous: + raise ValueError( + 'Reference mode is unavailable if \'data\' is not c_contiguous.') + if data.flags.writeable: + warnings.warn( + '\'data\' is still writeable. Taking a reference of the data to create Image packet is dangerous.', + RuntimeWarning, 2) + # pylint:disable=protected-access + return _packet_creator._create_image_from_pixel_data( + image_format, data, copy) + # pylint:enable=protected-access + + def create_proto(proto_message: message.Message) -> packet.Packet: """Create a MediaPipe protobuf message packet. 
diff --git a/mediapipe/python/packet_getter.py b/mediapipe/python/packet_getter.py index e8a6629df..af1ecece5 100644 --- a/mediapipe/python/packet_getter.py +++ b/mediapipe/python/packet_getter.py @@ -33,6 +33,7 @@ get_float_list = _packet_getter.get_float_list get_str_list = _packet_getter.get_str_list get_packet_list = _packet_getter.get_packet_list get_str_to_packet_dict = _packet_getter.get_str_to_packet_dict +get_image = _packet_getter.get_image get_image_frame = _packet_getter.get_image_frame get_matrix = _packet_getter.get_matrix diff --git a/mediapipe/python/packet_test.py b/mediapipe/python/packet_test.py index 766c5304d..eb7b3d4ea 100644 --- a/mediapipe/python/packet_test.py +++ b/mediapipe/python/packet_test.py @@ -232,34 +232,46 @@ class PacketTest(absltest.TestCase): self.assertEqual(mp.packet_getter.get_str(output_list['string']), '42') self.assertEqual(p.timestamp, 100) - def test_uint8_image_frame_packet(self): + def test_uint8_image_packet(self): uint8_img = np.random.randint( 2**8 - 1, size=(random.randrange(3, 100), random.randrange(3, 100), 3), dtype=np.uint8) - p = mp.packet_creator.create_image_frame( + image_frame_packet = mp.packet_creator.create_image_frame( mp.ImageFrame(image_format=mp.ImageFormat.SRGB, data=uint8_img)) - output_image_frame = mp.packet_getter.get_image_frame(p) + output_image_frame = mp.packet_getter.get_image_frame(image_frame_packet) self.assertTrue(np.array_equal(output_image_frame.numpy_view(), uint8_img)) + image_packet = mp.packet_creator.create_image( + mp.Image(image_format=mp.ImageFormat.SRGB, data=uint8_img)) + output_image = mp.packet_getter.get_image(image_packet) + self.assertTrue(np.array_equal(output_image.numpy_view(), uint8_img)) - def test_uint16_image_frame_packet(self): + def test_uint16_image_packet(self): uint16_img = np.random.randint( 2**16 - 1, size=(random.randrange(3, 100), random.randrange(3, 100), 4), dtype=np.uint16) - p = mp.packet_creator.create_image_frame( + image_frame_packet = mp.packet_creator.create_image_frame( mp.ImageFrame(image_format=mp.ImageFormat.SRGBA64, data=uint16_img)) - output_image_frame = mp.packet_getter.get_image_frame(p) + output_image_frame = mp.packet_getter.get_image_frame(image_frame_packet) self.assertTrue(np.array_equal(output_image_frame.numpy_view(), uint16_img)) + image_packet = mp.packet_creator.create_image( + mp.Image(image_format=mp.ImageFormat.SRGBA64, data=uint16_img)) + output_image = mp.packet_getter.get_image(image_packet) + self.assertTrue(np.array_equal(output_image.numpy_view(), uint16_img)) def test_float_image_frame_packet(self): float_img = np.float32( np.random.random_sample( (random.randrange(3, 100), random.randrange(3, 100), 2))) - p = mp.packet_creator.create_image_frame( + image_frame_packet = mp.packet_creator.create_image_frame( mp.ImageFrame(image_format=mp.ImageFormat.VEC32F2, data=float_img)) - output_image_frame = mp.packet_getter.get_image_frame(p) + output_image_frame = mp.packet_getter.get_image_frame(image_frame_packet) self.assertTrue(np.allclose(output_image_frame.numpy_view(), float_img)) + image_packet = mp.packet_creator.create_image( + mp.Image(image_format=mp.ImageFormat.VEC32F2, data=float_img)) + output_image = mp.packet_getter.get_image(image_packet) + self.assertTrue(np.array_equal(output_image.numpy_view(), float_img)) def test_image_frame_packet_creation_copy_mode(self): w, h, channels = random.randrange(3, 100), random.randrange(3, 100), 3 @@ -362,6 +374,107 @@ class PacketTest(absltest.TestCase): # copy mode. 
     self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count)
 
+  def test_image_packet_creation_copy_mode(self):
+    w, h, channels = random.randrange(3, 100), random.randrange(3, 100), 3
+    rgb_data = np.random.randint(255, size=(h, w, channels), dtype=np.uint8)
+    # rgb_data is c_contiguous.
+    self.assertTrue(rgb_data.flags.c_contiguous)
+    initial_ref_count = sys.getrefcount(rgb_data)
+    p = mp.packet_creator.create_image(
+        image_format=mp.ImageFormat.SRGB, data=rgb_data)
+    # copy mode doesn't increase the ref count of the data.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count)
+
+    rgb_data = rgb_data[:, :, ::-1]
+    # rgb_data is now not c_contiguous, but copy mode shouldn't be affected.
+    self.assertFalse(rgb_data.flags.c_contiguous)
+    initial_ref_count = sys.getrefcount(rgb_data)
+    p = mp.packet_creator.create_image(
+        image_format=mp.ImageFormat.SRGB, data=rgb_data)
+    # copy mode doesn't increase the ref count of the data.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count)
+
+    output_image = mp.packet_getter.get_image(p)
+    self.assertEqual(output_image.height, h)
+    self.assertEqual(output_image.width, w)
+    self.assertEqual(output_image.channels, channels)
+    self.assertTrue(np.array_equal(output_image.numpy_view(), rgb_data))
+
+    del p
+    del output_image
+    gc.collect()
+    # Destroying the packet also doesn't affect the ref count because of the
+    # copy mode.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count)
+
+  def test_image_packet_creation_reference_mode(self):
+    w, h, channels = random.randrange(3, 100), random.randrange(3, 100), 3
+    rgb_data = np.random.randint(255, size=(h, w, channels), dtype=np.uint8)
+    rgb_data.flags.writeable = False
+    initial_ref_count = sys.getrefcount(rgb_data)
+    image_packet = mp.packet_creator.create_image(
+        image_format=mp.ImageFormat.SRGB, data=rgb_data)
+    # Reference mode increases the ref count of rgb_data by 1.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count + 1)
+    del image_packet
+    gc.collect()
+    # Deleting image_packet should decrease the ref count of rgb_data by 1.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count)
+    rgb_data_copy = np.copy(rgb_data)
+    # rgb_data_copy is a copy of rgb_data and should not increase the ref count.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count)
+    text_config = """
+      node {
+        calculator: 'PassThroughCalculator'
+        input_side_packet: "in"
+        output_side_packet: "out"
+      }
+    """
+    graph = mp.CalculatorGraph(graph_config=text_config)
+    graph.start_run(
+        input_side_packets={
+            'in':
+                mp.packet_creator.create_image(
+                    image_format=mp.ImageFormat.SRGB, data=rgb_data)
+        })
+    # Reference mode increases the ref count of rgb_data by 1.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count + 1)
+    graph.wait_until_done()
+    output_packet = graph.get_output_side_packet('out')
+    del rgb_data
+    del graph
+    gc.collect()
+    # The pixel data of the output image packet should still be valid
+    # after the graph and the original rgb_data are deleted.
+    self.assertTrue(
+        np.array_equal(
+            mp.packet_getter.get_image(output_packet).numpy_view(),
+            rgb_data_copy))
+
+  def test_image_packet_copy_creation_with_cropping(self):
+    w, h = random.randrange(40, 100), random.randrange(40, 100)
+    channels, offset = 3, 10
+    rgb_data = np.random.randint(255, size=(h, w, channels), dtype=np.uint8)
+    initial_ref_count = sys.getrefcount(rgb_data)
+    p = mp.packet_creator.create_image(
+        image_format=mp.ImageFormat.SRGB,
+        data=rgb_data[offset:-offset, offset:-offset, :])
+    # copy mode doesn't increase the ref count of the data.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count)
+    output_image = mp.packet_getter.get_image(p)
+    self.assertEqual(output_image.height, h - 2 * offset)
+    self.assertEqual(output_image.width, w - 2 * offset)
+    self.assertEqual(output_image.channels, channels)
+    self.assertTrue(
+        np.array_equal(rgb_data[offset:-offset, offset:-offset, :],
+                       output_image.numpy_view()))
+    del p
+    del output_image
+    gc.collect()
+    # Destroying the packet also doesn't affect the ref count because of the
+    # copy mode.
+    self.assertEqual(sys.getrefcount(rgb_data), initial_ref_count)
+
   def test_matrix_packet(self):
     np_matrix = np.array([[.1, .2, .3], [.4, .5, .6]])
     initial_ref_count = sys.getrefcount(np_matrix)
diff --git a/mediapipe/python/pybind/BUILD b/mediapipe/python/pybind/BUILD
index 9a0f83141..d5183b35f 100644
--- a/mediapipe/python/pybind/BUILD
+++ b/mediapipe/python/pybind/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2020 The MediaPipe Authors.
+# Copyright 2020-2021 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,6 +36,18 @@ pybind_library(
     ],
 )
 
+pybind_library(
+    name = "image",
+    srcs = ["image.cc"],
+    hdrs = ["image.h"],
+    deps = [
+        ":image_frame_util",
+        ":util",
+        "//mediapipe/framework:type_map",
+        "//mediapipe/framework/formats:image",
+    ],
+)
+
 pybind_library(
     name = "image_frame",
     srcs = ["image_frame.cc"],
@@ -51,6 +63,7 @@ pybind_library(
     name = "image_frame_util",
     hdrs = ["image_frame_util.h"],
     deps = [
+        ":util",
         "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/port:logging",
@@ -88,6 +101,7 @@ pybind_library(
         ":util",
         "//mediapipe/framework:packet",
         "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:matrix",
         "//mediapipe/framework/port:integral_types",
         "@com_google_absl//absl/memory",
@@ -104,6 +118,7 @@ pybind_library(
         ":util",
         "//mediapipe/framework:packet",
         "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:matrix",
         "//mediapipe/framework/port:integral_types",
     ],
diff --git a/mediapipe/python/pybind/calculator_graph.cc b/mediapipe/python/pybind/calculator_graph.cc
index 9801b3be9..a13f03b01 100644
--- a/mediapipe/python/pybind/calculator_graph.cc
+++ b/mediapipe/python/pybind/calculator_graph.cc
@@ -394,13 +394,15 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
   calculator_graph.def(
       "observe_output_stream",
       [](CalculatorGraph* self, const std::string& stream_name,
-         pybind11::function callback_fn) {
+         pybind11::function callback_fn, bool observe_timestamp_bounds) {
         RaisePyErrorIfNotOk(self->ObserveOutputStream(
-            stream_name, [callback_fn, stream_name](const Packet& packet) {
+            stream_name,
+            [callback_fn, stream_name](const Packet& packet) {
              absl::MutexLock lock(&callback_mutex);
callback_fn(stream_name, packet); return absl::OkStatus(); - })); + }, + observe_timestamp_bounds)); }, R"doc(Observe the named output stream. @@ -411,6 +413,8 @@ void CalculatorGraphSubmodule(pybind11::module* module) { stream_name: The name of the output stream. callback_fn: The callback function to invoke on every packet emitted by the output stream. + observe_timestamp_bounds: If true, emits an empty packet at + timestamp_bound -1 when timestamp bound changes. Raises: RuntimeError: If the calculator graph isn't initialized or the stream @@ -422,7 +426,9 @@ void CalculatorGraphSubmodule(pybind11::module* module) { graph.observe_output_stream('out', lambda stream_name, packet: out.append(packet)) -)doc"); +)doc", + py::arg("stream_name"), py::arg("callback_fn"), + py::arg("observe_timestamp_bounds") = false); calculator_graph.def( "close", diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc new file mode 100644 index 000000000..d0c77b7df --- /dev/null +++ b/mediapipe/python/pybind/image.cc @@ -0,0 +1,234 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/image.h" + +#include "mediapipe/python/pybind/image_frame_util.h" +#include "mediapipe/python/pybind/util.h" +#include "pybind11/stl.h" + +namespace mediapipe { +namespace python { + +namespace py = pybind11; + +void ImageSubmodule(pybind11::module* module) { + py::module m = module->def_submodule("image", "MediaPipe image module"); + + py::options options; + options.disable_function_signatures(); + + // Image + py::class_ image( + m, "Image", + R"doc(A container for storing an image or a video frame, in one of several formats. + + Formats supported by Image are listed in the ImageFormat enum. + Pixels are encoded row-major in an interleaved fashion. Image supports + uint8, uint16, and float as its data types. + + Image can be created by copying the data from a numpy ndarray that stores + the pixel data continuously. An Image may realign the input data on its + default alignment boundary during creation. The data in an Image will + become immutable after creation. + + Creation examples: + import cv2 + cv_mat = cv2.imread(input_file)[:, :, ::-1] + rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) + gray_frame = mp.Image( + format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + + from PIL import Image + pil_img = Image.new('RGB', (60, 30), color = 'red') + image = mp.Image( + format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + + The pixel data in an Image can be retrieved as a numpy ndarray by calling + `Image.numpy_view()`. The returned numpy ndarray is a reference to the + internal data and itself is unwritable. If the callers want to modify the + numpy ndarray, it's required to obtain a copy of it. 
+ + Pixel data retrieval examples: + for channel in range(num_channel): + for col in range(width): + for row in range(height): + print(image[row, col, channel]) + + output_ndarray = image.numpy_view() + print(output_ndarray[0, 0, 0]) + copied_ndarray = np.copy(output_ndarray) + copied_ndarray[0,0,0] = 0 + )doc", + py::dynamic_attr()); + + image + .def( + py::init([](mediapipe::ImageFormat::Format format, + const py::array_t& data) { + if (format != mediapipe::ImageFormat::GRAY8 && + format != mediapipe::ImageFormat::SRGB && + format != mediapipe::ImageFormat::SRGBA) { + throw RaisePyError(PyExc_RuntimeError, + "uint8 image data should be one of the GRAY8, " + "SRGB, and SRGBA MediaPipe image formats."); + } + return Image(std::make_shared( + std::move(*CreateImageFrame(format, data).release()))); + }), + R"doc(For uint8 data type, valid ImageFormat are GRAY8, SGRB, and SRGBA.)doc", + py::arg("image_format"), py::arg("data").noconvert()) + .def( + py::init([](mediapipe::ImageFormat::Format format, + const py::array_t& data) { + if (format != mediapipe::ImageFormat::GRAY16 && + format != mediapipe::ImageFormat::SRGB48 && + format != mediapipe::ImageFormat::SRGBA64) { + throw RaisePyError( + PyExc_RuntimeError, + "uint16 image data should be one of the GRAY16, " + "SRGB48, and SRGBA64 MediaPipe image formats."); + } + return Image(std::make_shared( + std::move(*CreateImageFrame(format, data).release()))); + }), + R"doc(For uint16 data type, valid ImageFormat are GRAY16, SRGB48, and SRGBA64.)doc", + py::arg("image_format"), py::arg("data").noconvert()) + .def( + py::init([](mediapipe::ImageFormat::Format format, + const py::array_t& data) { + if (format != mediapipe::ImageFormat::VEC32F1 && + format != mediapipe::ImageFormat::VEC32F2) { + throw RaisePyError( + PyExc_RuntimeError, + "float image data should be either VEC32F1 or VEC32F2 " + "MediaPipe image formats."); + } + return Image(std::make_shared( + std::move(*CreateImageFrame(format, data).release()))); + }), + R"doc(For float data type, valid ImageFormat are VEC32F1 and VEC32F2.)doc", + py::arg("image_format"), py::arg("data").noconvert()); + + image.def( + "numpy_view", + [](Image& self) { + py::object py_object = + py::cast(self, py::return_value_policy::reference); + // If the image data is contiguous, generates the data pyarray object + // on demand because 1) making a pyarray by referring to the existing + // image pixel data is relatively cheap and 2) caching the pyarray + // object in an attribute of the image is problematic: the image object + // and the data pyarray object refer to each other, which causes gc + // fails to free the pyarray after use. + // For the non-contiguous cases, gets a cached data pyarray object from + // the image pyobject attribute. This optimization is to avoid the + // expensive data realignment and copy operations happening more than + // once. + return self.GetImageFrameSharedPtr()->IsContiguous() + ? GenerateDataPyArrayOnDemand(*self.GetImageFrameSharedPtr(), + py_object) + : GetCachedContiguousDataAttr(*self.GetImageFrameSharedPtr(), + py_object); + }, + R"doc(Return the image pixel data as an unwritable numpy ndarray. + + Realign the pixel data to be stored contiguously and return a reference to the + unwritable numpy ndarray. If the callers want to modify the numpy array data, + it's required to obtain a copy of the ndarray. + + Returns: + An unwritable numpy ndarray. 
+ + Examples: + output_ndarray = image.numpy_view() + copied_ndarray = np.copy(output_ndarray) + copied_ndarray[0,0,0] = 0 +)doc"); + + image.def( + "__getitem__", + [](Image& self, const std::vector& pos) { + if (pos.size() != 3 && !(pos.size() == 2 && self.channels() == 1)) { + throw RaisePyError( + PyExc_IndexError, + absl::StrCat("Invalid index dimension: ", pos.size()).c_str()); + } + py::object py_object = + py::cast(self, py::return_value_policy::reference); + switch (self.GetImageFrameSharedPtr()->ByteDepth()) { + case 1: + return GetValue(*self.GetImageFrameSharedPtr(), pos, + py_object); + case 2: + return GetValue(*self.GetImageFrameSharedPtr(), pos, + py_object); + case 4: + return GetValue(*self.GetImageFrameSharedPtr(), pos, + py_object); + default: + return py::object(); + } + }, + R"doc(Use the indexer operators to access pixel data. + + Raises: + IndexError: If the index is invalid or out of bounds. + + Examples: + for channel in range(num_channel): + for col in range(width): + for row in range(height): + print(image[row, col, channel]) +)doc"); + + image + .def("uses_gpu", &Image::UsesGpu, + R"doc(Return True if data is currently on the GPU.)doc") + .def( + "is_contiguous", + [](Image& self) { + return self.GetImageFrameSharedPtr()->IsContiguous(); + }, + R"doc(Return True if the pixel data is stored contiguously (without any alignment padding areas).)doc") + .def( + "is_empty", + [](Image& self) { return self.GetImageFrameSharedPtr()->IsEmpty(); }, + R"doc(Return True if the pixel data is unallocated.)doc") + .def( + "is_aligned", + [](Image& self, uint32 alignment_boundary) { + return self.GetImageFrameSharedPtr()->IsAligned(alignment_boundary); + }, + R"doc(Return True if each row of the data is aligned to alignment boundary, which must be 1 or a power of 2. + + Args: + alignment_boundary: An integer. + + Returns: + A boolean. + + Examples: + image.is_aligned(16) +)doc"); + + image.def_property_readonly("width", &Image::width) + .def_property_readonly("height", &Image::height) + .def_property_readonly("channels", &Image::channels) + .def_property_readonly("step", &Image::step) + .def_property_readonly("image_format", &Image::image_format); +} + +} // namespace python +} // namespace mediapipe diff --git a/mediapipe/python/pybind/image.h b/mediapipe/python/pybind/image.h new file mode 100644 index 000000000..26c1f09aa --- /dev/null +++ b/mediapipe/python/pybind/image.h @@ -0,0 +1,28 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_PYTHON_PYBIND_IMAGE_H_ +#define MEDIAPIPE_PYTHON_PYBIND_IMAGE_H_ + +#include "pybind11/pybind11.h" + +namespace mediapipe { +namespace python { + +void ImageSubmodule(pybind11::module* module); + +} // namespace python +} // namespace mediapipe + +#endif // MEDIAPIPE_PYTHON_PYBIND_IMAGE_H_ diff --git a/mediapipe/python/pybind/image_frame.cc b/mediapipe/python/pybind/image_frame.cc index 8bd79a8aa..a7fc6bfe4 100644 --- a/mediapipe/python/pybind/image_frame.cc +++ b/mediapipe/python/pybind/image_frame.cc @@ -1,4 +1,4 @@ -// Copyright 2020 The MediaPipe Authors. +// Copyright 2020-2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,118 +18,6 @@ namespace mediapipe { namespace python { -namespace { - -template -py::array GenerateContiguousDataArrayHelper(const ImageFrame& image_frame, - const py::object& py_object) { - std::vector shape{image_frame.Height(), image_frame.Width()}; - if (image_frame.NumberOfChannels() > 1) { - shape.push_back(image_frame.NumberOfChannels()); - } - py::array_t contiguous_data; - if (image_frame.IsContiguous()) { - contiguous_data = py::array_t( - shape, reinterpret_cast(image_frame.PixelData()), py_object); - } else { - auto contiguous_data_copy = - absl::make_unique(image_frame.Width() * image_frame.Height() * - image_frame.NumberOfChannels()); - image_frame.CopyToBuffer(contiguous_data_copy.get(), - image_frame.PixelDataSizeStoredContiguously()); - auto capsule = py::capsule(contiguous_data_copy.get(), [](void* data) { - if (data) { - delete[] reinterpret_cast(data); - } - }); - contiguous_data = py::array_t( - shape, contiguous_data_copy.release(), capsule); - } - - // In both cases, the underlying data is not writable in Python. - py::detail::array_proxy(contiguous_data.ptr())->flags &= - ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_; - return contiguous_data; -} - -py::array GenerateContiguousDataArray(const ImageFrame& image_frame, - const py::object& py_object) { - switch (image_frame.ChannelSize()) { - case sizeof(uint8): - return GenerateContiguousDataArrayHelper(image_frame, py_object) - .cast(); - case sizeof(uint16): - return GenerateContiguousDataArrayHelper(image_frame, py_object) - .cast(); - case sizeof(float): - return GenerateContiguousDataArrayHelper(image_frame, py_object) - .cast(); - break; - default: - throw RaisePyError(PyExc_RuntimeError, - "Unsupported image frame channel size. Data is not " - "uint8, uint16, or float?"); - } -} - -// Generates a contiguous data pyarray object on demand. -// This function only accepts an image frame object that already stores -// contiguous data. The output py::array points to the raw pixel data array of -// the image frame object directly. -py::array GenerateDataPyArrayOnDemand(const ImageFrame& image_frame, - const py::object& py_object) { - if (!image_frame.IsContiguous()) { - throw RaisePyError(PyExc_RuntimeError, - "GenerateDataPyArrayOnDemand must take an ImageFrame " - "object that stores contiguous data."); - } - return GenerateContiguousDataArray(image_frame, py_object); -} - -// Gets the cached contiguous data array from the "__contiguous_data" attribute. -// If the attribute doesn't exist, the function calls -// GenerateContiguousDataArray() to generate the contiguous data pyarray object, -// which realigns and copies the data from the original image frame object. -// Then, the data array object is cached in the "__contiguous_data" attribute. 
-// This function only accepts an image frame object that stores non-contiguous -// data. -py::array GetCachedContiguousDataAttr(const ImageFrame& image_frame, - const py::object& py_object) { - if (image_frame.IsContiguous()) { - throw RaisePyError(PyExc_RuntimeError, - "GetCachedContiguousDataAttr must take an ImageFrame " - "object that stores non-contiguous data."); - } - py::object get_data_attr = - py::getattr(py_object, "__contiguous_data", py::none()); - if (image_frame.IsEmpty()) { - throw RaisePyError(PyExc_RuntimeError, "ImageFrame is unallocated."); - } - // If __contiguous_data attr doesn't store data yet, generates the contiguous - // data array object and caches the result. - if (get_data_attr.is_none()) { - py_object.attr("__contiguous_data") = - GenerateContiguousDataArray(image_frame, py_object); - } - return py_object.attr("__contiguous_data").cast(); -} - -template -py::object GetValue(const ImageFrame& image_frame, const std::vector& pos, - const py::object& py_object) { - py::array_t output_array = - image_frame.IsContiguous() - ? GenerateDataPyArrayOnDemand(image_frame, py_object) - : GetCachedContiguousDataAttr(image_frame, py_object); - if (pos.size() == 2) { - return py::cast(static_cast(output_array.at(pos[0], pos[1]))); - } else if (pos.size() == 3) { - return py::cast(static_cast(output_array.at(pos[0], pos[1], pos[2]))); - } - return py::none(); -} - -} // namespace namespace py = pybind11; diff --git a/mediapipe/python/pybind/image_frame_util.h b/mediapipe/python/pybind/image_frame_util.h index 9de9cb4ed..6acc30524 100644 --- a/mediapipe/python/pybind/image_frame_util.h +++ b/mediapipe/python/pybind/image_frame_util.h @@ -1,4 +1,4 @@ -// Copyright 2020 The MediaPipe Authors. +// Copyright 2020-2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/python/pybind/util.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" @@ -57,6 +58,115 @@ std::unique_ptr CreateImageFrame( return image_frame; } +template +py::array GenerateContiguousDataArrayHelper(const ImageFrame& image_frame, + const py::object& py_object) { + std::vector shape{image_frame.Height(), image_frame.Width()}; + if (image_frame.NumberOfChannels() > 1) { + shape.push_back(image_frame.NumberOfChannels()); + } + py::array_t contiguous_data; + if (image_frame.IsContiguous()) { + contiguous_data = py::array_t( + shape, reinterpret_cast(image_frame.PixelData()), py_object); + } else { + auto contiguous_data_copy = + absl::make_unique(image_frame.Width() * image_frame.Height() * + image_frame.NumberOfChannels()); + image_frame.CopyToBuffer(contiguous_data_copy.get(), + image_frame.PixelDataSizeStoredContiguously()); + auto capsule = py::capsule(contiguous_data_copy.get(), [](void* data) { + if (data) { + delete[] reinterpret_cast(data); + } + }); + contiguous_data = py::array_t( + shape, contiguous_data_copy.release(), capsule); + } + + // In both cases, the underlying data is not writable in Python. 
+ py::detail::array_proxy(contiguous_data.ptr())->flags &= + ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_; + return contiguous_data; +} + +inline py::array GenerateContiguousDataArray(const ImageFrame& image_frame, + const py::object& py_object) { + switch (image_frame.ChannelSize()) { + case sizeof(uint8): + return GenerateContiguousDataArrayHelper(image_frame, py_object) + .cast(); + case sizeof(uint16): + return GenerateContiguousDataArrayHelper(image_frame, py_object) + .cast(); + case sizeof(float): + return GenerateContiguousDataArrayHelper(image_frame, py_object) + .cast(); + break; + default: + throw RaisePyError(PyExc_RuntimeError, + "Unsupported image frame channel size. Data is not " + "uint8, uint16, or float?"); + } +} + +// Generates a contiguous data pyarray object on demand. +// This function only accepts an image frame object that already stores +// contiguous data. The output py::array points to the raw pixel data array of +// the image frame object directly. +inline py::array GenerateDataPyArrayOnDemand(const ImageFrame& image_frame, + const py::object& py_object) { + if (!image_frame.IsContiguous()) { + throw RaisePyError(PyExc_RuntimeError, + "GenerateDataPyArrayOnDemand must take an ImageFrame " + "object that stores contiguous data."); + } + return GenerateContiguousDataArray(image_frame, py_object); +} + +// Gets the cached contiguous data array from the "__contiguous_data" attribute. +// If the attribute doesn't exist, the function calls +// GenerateContiguousDataArray() to generate the contiguous data pyarray object, +// which realigns and copies the data from the original image frame object. +// Then, the data array object is cached in the "__contiguous_data" attribute. +// This function only accepts an image frame object that stores non-contiguous +// data. +inline py::array GetCachedContiguousDataAttr(const ImageFrame& image_frame, + const py::object& py_object) { + if (image_frame.IsContiguous()) { + throw RaisePyError(PyExc_RuntimeError, + "GetCachedContiguousDataAttr must take an ImageFrame " + "object that stores non-contiguous data."); + } + py::object get_data_attr = + py::getattr(py_object, "__contiguous_data", py::none()); + if (image_frame.IsEmpty()) { + throw RaisePyError(PyExc_RuntimeError, "ImageFrame is unallocated."); + } + // If __contiguous_data attr doesn't store data yet, generates the contiguous + // data array object and caches the result. + if (get_data_attr.is_none()) { + py_object.attr("__contiguous_data") = + GenerateContiguousDataArray(image_frame, py_object); + } + return py_object.attr("__contiguous_data").cast(); +} + +template +py::object GetValue(const ImageFrame& image_frame, const std::vector& pos, + const py::object& py_object) { + py::array_t output_array = + image_frame.IsContiguous() + ? 
GenerateDataPyArrayOnDemand(image_frame, py_object) + : GetCachedContiguousDataAttr(image_frame, py_object); + if (pos.size() == 2) { + return py::cast(static_cast(output_array.at(pos[0], pos[1]))); + } else if (pos.size() == 3) { + return py::cast(static_cast(output_array.at(pos[0], pos[1], pos[2]))); + } + return py::none(); +} + } // namespace python } // namespace mediapipe diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index 2212732cd..5cc66a310 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -16,6 +16,7 @@ #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/integral_types.h" @@ -49,6 +50,28 @@ Packet CreateImageFramePacket(mediapipe::ImageFormat::Format format, return Packet(); } +Packet CreateImagePacket(mediapipe::ImageFormat::Format format, + const py::array& data, bool copy) { + if (format == mediapipe::ImageFormat::SRGB || + format == mediapipe::ImageFormat::SRGBA || + format == mediapipe::ImageFormat::GRAY8) { + return MakePacket(std::make_shared( + std::move(*CreateImageFrame(format, data, copy).release()))); + } else if (format == mediapipe::ImageFormat::GRAY16 || + format == mediapipe::ImageFormat::SRGB48 || + format == mediapipe::ImageFormat::SRGBA64) { + return MakePacket(std::make_shared( + std::move(*CreateImageFrame(format, data, copy).release()))); + } else if (format == mediapipe::ImageFormat::VEC32F1 || + format == mediapipe::ImageFormat::VEC32F2) { + return MakePacket(std::make_shared( + std::move(*CreateImageFrame(format, data, copy).release()))); + } + throw RaisePyError(PyExc_RuntimeError, + absl::StrCat("Unsupported ImageFormat: ", format).c_str()); + return Packet(); +} + } // namespace namespace py = pybind11; @@ -586,6 +609,10 @@ void InternalPacketCreators(pybind11::module* m) { py::arg("format"), py::arg("data").noconvert(), py::arg("copy"), py::return_value_policy::move); + m->def("_create_image_from_pixel_data", &CreateImagePacket, py::arg("format"), + py::arg("data").noconvert(), py::arg("copy"), + py::return_value_policy::move); + m->def( "_create_image_frame_from_image_frame", [](ImageFrame& image_frame) { @@ -598,6 +625,19 @@ void InternalPacketCreators(pybind11::module* m) { }, py::arg("image_frame").noconvert(), py::return_value_policy::move); + m->def( + "_create_image_from_image", + [](Image& image) { + auto image_frame_copy = absl::make_unique(); + // Set alignment_boundary to kGlDefaultAlignmentBoundary so that + // both GPU and CPU can process it. 
+ image_frame_copy->CopyFrom(*image.GetImageFrameSharedPtr(), + ImageFrame::kGlDefaultAlignmentBoundary); + return MakePacket(std::make_shared( + std::move(*image_frame_copy.release()))); + }, + py::arg("image").noconvert(), py::return_value_policy::move); + m->def( "_create_proto", [](const std::string& type_name, const py::bytes& serialized_proto) { diff --git a/mediapipe/python/pybind/packet_getter.cc b/mediapipe/python/pybind/packet_getter.cc index f88e48b4c..271184409 100644 --- a/mediapipe/python/pybind/packet_getter.cc +++ b/mediapipe/python/pybind/packet_getter.cc @@ -14,6 +14,7 @@ #include "mediapipe/python/pybind/packet_getter.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/integral_types.h" @@ -322,6 +323,24 @@ void PublicPacketGetters(pybind11::module* m) { )doc", py::return_value_policy::reference_internal); + m->def("get_image", &GetContent, + R"doc(Get the content of a MediaPipe Image Packet as an Image object. + + Args: + packet: A MediaPipe Image Packet. + + Returns: + A MediaPipe Image object. + + Raises: + ValueError: If the Packet doesn't contain Image. + + Examples: + packet = packet_creator.create_image(frame) + data = packet_getter.get_image(packet) +)doc", + py::return_value_policy::reference_internal); + m->def( "get_matrix", [](const Packet& packet) { diff --git a/mediapipe/python/solution_base.py b/mediapipe/python/solution_base.py index 7b535c832..ece3a09fa 100644 --- a/mediapipe/python/solution_base.py +++ b/mediapipe/python/solution_base.py @@ -89,7 +89,8 @@ class _PacketDataType(enum.Enum): FLOAT = 'float' FLOAT_LIST = 'float_list' AUDIO = 'matrix' - IMAGE = 'image_frame' + IMAGE = 'image' + IMAGE_FRAME = 'image_frame' PROTO = 'proto' PROTO_LIST = 'proto_list' @@ -114,7 +115,7 @@ NAME_TO_TYPE: Mapping[str, '_PacketDataType'] = { '::mediapipe::Matrix': _PacketDataType.AUDIO, '::mediapipe::ImageFrame': - _PacketDataType.IMAGE, + _PacketDataType.IMAGE_FRAME, '::mediapipe::Classification': _PacketDataType.PROTO, '::mediapipe::ClassificationList': @@ -139,6 +140,8 @@ NAME_TO_TYPE: Mapping[str, '_PacketDataType'] = { _PacketDataType.PROTO, '::mediapipe::NormalizedLandmarkList': _PacketDataType.PROTO, + '::mediapipe::Image': + _PacketDataType.IMAGE, '::std::vector<::mediapipe::Classification>': _PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::ClassificationList>': @@ -242,7 +245,7 @@ class SolutionBase: self._graph_outputs[stream_name] = output_packet for stream_name in self._output_stream_type_info.keys(): - self._graph.observe_output_stream(stream_name, callback) + self._graph.observe_output_stream(stream_name, callback, True) input_side_packets = { name: self._make_packet(self._side_input_type_info[name], data) @@ -296,12 +299,14 @@ class SolutionBase: # input. self._simulated_timestamp += 33333 for stream_name, data in input_dict.items(): - if self._input_stream_type_info[stream_name] == _PacketDataType.IMAGE: + input_stream_type = self._input_stream_type_info[stream_name] + if (input_stream_type == _PacketDataType.IMAGE_FRAME or + input_stream_type == _PacketDataType.IMAGE): if data.shape[2] != RGB_CHANNELS: raise ValueError('Input image must contain three channel rgb data.') self._graph.add_packet_to_input_stream( stream=stream_name, - packet=self._make_packet(_PacketDataType.IMAGE, + packet=self._make_packet(input_stream_type, data).at(self._simulated_timestamp)) else: # TODO: Support audio data. 
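To make the dispatch above concrete: the declared stream type ("::mediapipe::ImageFrame" vs. "::mediapipe::Image") now selects which creator is called, and because the output streams are observed with observe_timestamp_bounds set to True, a stream that only advances its timestamp bound delivers an empty packet that the getter later maps to None. A rough sketch of the two equivalent standalone creator calls, assuming an SRGB uint8 frame and the package imported as `mp` (illustrative, not part of the patch):

```python
import numpy as np
import mediapipe as mp

frame = np.zeros((480, 640, 3), dtype=np.uint8)

# Stream declared as ::mediapipe::ImageFrame -> IMAGE_FRAME packet.
image_frame_packet = mp.packet_creator.create_image_frame(
    image_format=mp.ImageFormat.SRGB, data=frame)

# Stream declared as ::mediapipe::Image -> IMAGE packet.
image_packet = mp.packet_creator.create_image(
    image_format=mp.ImageFormat.SRGB, data=frame)
```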
@@ -476,18 +481,34 @@ class SolutionBase: def _make_packet(self, packet_data_type: _PacketDataType, data: Any) -> packet.Packet: - if packet_data_type == _PacketDataType.IMAGE: - return packet_creator.create_image_frame( + if (packet_data_type == _PacketDataType.IMAGE_FRAME or + packet_data_type == _PacketDataType.IMAGE): + return getattr(packet_creator, 'create_' + packet_data_type.value)( data, image_format=image_frame.ImageFormat.SRGB) else: return getattr(packet_creator, 'create_' + packet_data_type.value)(data) def _get_packet_content(self, packet_data_type: _PacketDataType, output_packet: packet.Packet) -> Any: + """Gets packet content from a packet by type. + + Args: + packet_data_type: The supported packet data type. + output_packet: The packet to get content from. + + Returns: + Packet content by packet data type. None to indicate "no output". + + """ + + if output_packet.is_empty(): + return None if packet_data_type == _PacketDataType.STRING: return packet_getter.get_str(output_packet) - elif packet_data_type == _PacketDataType.IMAGE: - return packet_getter.get_image_frame(output_packet).numpy_view() + elif (packet_data_type == _PacketDataType.IMAGE_FRAME or + packet_data_type == _PacketDataType.IMAGE): + return getattr(packet_getter, 'get_' + + packet_data_type.value)(output_packet).numpy_view() else: return getattr(packet_getter, 'get_' + packet_data_type.value)( output_packet) diff --git a/mediapipe/python/solutions/face_detection_test.py b/mediapipe/python/solutions/face_detection_test.py index 86eb794c5..25f5b33fd 100644 --- a/mediapipe/python/solutions/face_detection_test.py +++ b/mediapipe/python/solutions/face_detection_test.py @@ -14,6 +14,8 @@ """Tests for mediapipe.python.solutions.face_detection.""" import os +import tempfile # pylint: disable=unused-import +from typing import NamedTuple from absl.testing import absltest import cv2 @@ -21,16 +23,25 @@ import numpy as np import numpy.testing as npt # resources dependency +# undeclared dependency +from mediapipe.python.solutions import drawing_utils as mp_drawing from mediapipe.python.solutions import face_detection as mp_faces TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' -EXPECTED_FACE_KEY_POINTS = [[182, 368], [186, 467], [236, 416], [284, 415], - [203, 310], [212, 521]] -DIFF_THRESHOLD = 10 # pixels +EXPECTED_FACE_KEY_POINTS = [[182, 363], [186, 460], [241, 420], [284, 417], + [199, 295], [198, 502]] +DIFF_THRESHOLD = 5 # pixels class FaceDetectionTest(absltest.TestCase): + def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int): + for detection in results.detections: + mp_drawing.draw_detection(frame, detection) + path = os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] + + '_frame_{}.png'.format(idx)) + cv2.imwrite(path, frame) + def test_invalid_image_shape(self): with mp_faces.FaceDetection() as faces: with self.assertRaisesRegex( @@ -46,11 +57,11 @@ class FaceDetectionTest(absltest.TestCase): def test_face(self): image_path = os.path.join(os.path.dirname(__file__), 'testdata/face.jpg') - image = cv2.flip(cv2.imread(image_path), 1) - + image = cv2.imread(image_path) with mp_faces.FaceDetection(min_detection_confidence=0.5) as faces: - for _ in range(5): + for idx in range(5): results = faces.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + self._annotate(image.copy(), results, idx) location_data = results.detections[0].location_data x = [keypoint.x for keypoint in location_data.relative_keypoints] y = [keypoint.y for keypoint in location_data.relative_keypoints] diff --git 
a/mediapipe/python/solutions/face_mesh_test.py b/mediapipe/python/solutions/face_mesh_test.py index e53479821..cf112044d 100644 --- a/mediapipe/python/solutions/face_mesh_test.py +++ b/mediapipe/python/solutions/face_mesh_test.py @@ -15,6 +15,8 @@ """Tests for mediapipe.python.solutions.face_mesh.""" import os +import tempfile # pylint: disable=unused-import +from typing import NamedTuple from absl.testing import absltest from absl.testing import parameterized @@ -23,48 +25,61 @@ import numpy as np import numpy.testing as npt # resources dependency +# undeclared dependency +from mediapipe.python.solutions import drawing_utils as mp_drawing from mediapipe.python.solutions import face_mesh as mp_faces TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' -DIFF_THRESHOLD = 20 # pixels +DIFF_THRESHOLD = 5 # pixels EYE_INDICES_TO_LANDMARKS = { - 33: [176, 350], - 7: [177, 353], - 163: [178, 357], - 144: [179, 362], - 145: [179, 369], - 153: [179, 376], - 154: [178, 382], - 155: [177, 386], - 133: [177, 388], - 246: [175, 352], - 161: [174, 355], - 160: [172, 360], - 159: [170, 367], - 158: [171, 374], - 157: [172, 381], - 173: [175, 386], - 263: [176, 475], - 249: [177, 471], - 390: [177, 467], - 373: [178, 462], - 374: [179, 454], - 380: [179, 448], - 381: [178, 441], - 382: [177, 437], - 362: [177, 435], - 466: [175, 473], - 388: [173, 469], - 387: [171, 464], - 386: [170, 457], - 385: [171, 450], - 384: [172, 443], - 398: [175, 438] + 33: [178, 345], + 7: [179, 348], + 163: [178, 352], + 144: [179, 357], + 145: [179, 365], + 153: [179, 371], + 154: [178, 378], + 155: [177, 381], + 133: [177, 383], + 246: [175, 347], + 161: [174, 350], + 160: [172, 355], + 159: [170, 362], + 158: [171, 368], + 157: [172, 375], + 173: [175, 380], + 263: [176, 467], + 249: [177, 464], + 390: [177, 460], + 373: [178, 455], + 374: [179, 448], + 380: [179, 441], + 381: [178, 435], + 382: [177, 432], + 362: [177, 430], + 466: [175, 465], + 388: [173, 462], + 387: [171, 457], + 386: [170, 450], + 385: [171, 444], + 384: [172, 437], + 398: [175, 432] } class FaceMeshTest(parameterized.TestCase): + def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int): + drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) + for face_landmarks in results.multi_face_landmarks: + mp_drawing.draw_landmarks( + image=frame, + landmark_list=face_landmarks, + landmark_drawing_spec=drawing_spec) + path = os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] + + '_frame_{}.png'.format(idx)) + cv2.imwrite(path, frame) + def test_invalid_image_shape(self): with mp_faces.FaceMesh() as faces: with self.assertRaisesRegex( @@ -82,13 +97,13 @@ class FaceMeshTest(parameterized.TestCase): ('video_mode', False, 5)) def test_face(self, static_image_mode: bool, num_frames: int): image_path = os.path.join(os.path.dirname(__file__), 'testdata/face.jpg') - image = cv2.flip(cv2.imread(image_path), 1) - + image = cv2.imread(image_path) with mp_faces.FaceMesh( static_image_mode=static_image_mode, min_detection_confidence=0.5) as faces: - for _ in range(num_frames): + for idx in range(num_frames): results = faces.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + self._annotate(image.copy(), results, idx) multi_face_landmarks = [] for landmarks in results.multi_face_landmarks: self.assertLen(landmarks.landmark, 468) @@ -98,9 +113,9 @@ class FaceMeshTest(parameterized.TestCase): multi_face_landmarks.append(face_landmarks) self.assertLen(multi_face_landmarks, 1) # Verify the eye landmarks are correct as sanity check. 
- for idx, gt_lds in EYE_INDICES_TO_LANDMARKS.items(): + for eye_idx, gt_lds in EYE_INDICES_TO_LANDMARKS.items(): prediction_error = np.abs( - np.asarray(multi_face_landmarks[0][idx]) - np.asarray(gt_lds)) + np.asarray(multi_face_landmarks[0][eye_idx]) - np.asarray(gt_lds)) npt.assert_array_less(prediction_error, DIFF_THRESHOLD) diff --git a/mediapipe/python/solutions/hands.py b/mediapipe/python/solutions/hands.py index 15760ed75..a4bd035ab 100644 --- a/mediapipe/python/solutions/hands.py +++ b/mediapipe/python/solutions/hands.py @@ -44,7 +44,7 @@ class HandLandmark(enum.IntEnum): WRIST = 0 THUMB_CMC = 1 THUMB_MCP = 2 - THUMB_IP = 3 + THUMB_DIP = 3 THUMB_TIP = 4 INDEX_FINGER_MCP = 5 INDEX_FINGER_PIP = 6 @@ -68,8 +68,8 @@ BINARYPB_FILE_PATH = 'mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu HAND_CONNECTIONS = frozenset([ (HandLandmark.WRIST, HandLandmark.THUMB_CMC), (HandLandmark.THUMB_CMC, HandLandmark.THUMB_MCP), - (HandLandmark.THUMB_MCP, HandLandmark.THUMB_IP), - (HandLandmark.THUMB_IP, HandLandmark.THUMB_TIP), + (HandLandmark.THUMB_MCP, HandLandmark.THUMB_DIP), + (HandLandmark.THUMB_DIP, HandLandmark.THUMB_TIP), (HandLandmark.WRIST, HandLandmark.INDEX_FINGER_MCP), (HandLandmark.INDEX_FINGER_MCP, HandLandmark.INDEX_FINGER_PIP), (HandLandmark.INDEX_FINGER_PIP, HandLandmark.INDEX_FINGER_DIP), diff --git a/mediapipe/python/solutions/hands_test.py b/mediapipe/python/solutions/hands_test.py index 1ea4dc563..7e262c1e4 100644 --- a/mediapipe/python/solutions/hands_test.py +++ b/mediapipe/python/solutions/hands_test.py @@ -15,6 +15,8 @@ """Tests for mediapipe.python.solutions.hands.""" import os +import tempfile # pylint: disable=unused-import +from typing import NamedTuple from absl.testing import absltest from absl.testing import parameterized @@ -23,28 +25,38 @@ import numpy as np import numpy.testing as npt # resources dependency +# undeclared dependency +from mediapipe.python.solutions import drawing_utils as mp_drawing from mediapipe.python.solutions import hands as mp_hands TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' -DIFF_THRESHOLD = 20 # pixels -EXPECTED_HAND_COORDINATES_PREDICTION = [[[332, 144], [323, 211], [286, 257], +DIFF_THRESHOLD = 15 # pixels +EXPECTED_HAND_COORDINATES_PREDICTION = [[[345, 144], [323, 211], [286, 257], [237, 289], [203, 322], [216, 219], [138, 238], [90, 249], [51, 253], [204, 177], [115, 184], [60, 187], [19, 185], [208, 138], [127, 131], [77, 124], [36, 117], [222, 106], [159, 92], [124, 79], [93, 68]], - [[43, 570], [56, 504], [94, 459], + [[40, 577], [56, 504], [94, 459], [146, 429], [182, 397], [167, 496], [245, 479], [292, 469], [330, 464], [177, 540], [265, 534], [319, 533], [360, 536], [172, 581], [252, 587], [304, 593], [346, 599], [157, 615], - [219, 628], [255, 638], [288, 648]]] + [223, 628], [258, 638], [288, 648]]] class HandsTest(parameterized.TestCase): + def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int): + for hand_landmarks in results.multi_hand_landmarks: + mp_drawing.draw_landmarks(frame, hand_landmarks, + mp_hands.HAND_CONNECTIONS) + path = os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] + + '_frame_{}.png'.format(idx)) + cv2.imwrite(path, frame) + def test_invalid_image_shape(self): with mp_hands.Hands() as hands: with self.assertRaisesRegex( @@ -63,14 +75,14 @@ class HandsTest(parameterized.TestCase): ('video_mode', False, 5)) def test_multi_hands(self, static_image_mode, num_frames): image_path = os.path.join(os.path.dirname(__file__), 'testdata/hands.jpg') - image = 
cv2.flip(cv2.imread(image_path), 1) - + image = cv2.imread(image_path) with mp_hands.Hands( static_image_mode=static_image_mode, max_num_hands=2, min_detection_confidence=0.5) as hands: - for _ in range(num_frames): + for idx in range(num_frames): results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + self._annotate(image.copy(), results, idx) handedness = [ handedness.classification[0].label for handedness in results.multi_handedness diff --git a/mediapipe/python/solutions/holistic.py b/mediapipe/python/solutions/holistic.py index 898bb1562..ac204e243 100644 --- a/mediapipe/python/solutions/holistic.py +++ b/mediapipe/python/solutions/holistic.py @@ -41,7 +41,6 @@ from mediapipe.python.solutions.hands import HAND_CONNECTIONS from mediapipe.python.solutions.hands import HandLandmark from mediapipe.python.solutions.pose import POSE_CONNECTIONS from mediapipe.python.solutions.pose import PoseLandmark -from mediapipe.python.solutions.pose import UPPER_BODY_POSE_CONNECTIONS # pylint: enable=unused-import BINARYPB_FILE_PATH = 'mediapipe/modules/holistic_landmark/holistic_landmark_cpu.binarypb' @@ -60,7 +59,7 @@ class Holistic(SolutionBase): def __init__(self, static_image_mode=False, - upper_body_only=False, + model_complexity=1, smooth_landmarks=True, min_detection_confidence=0.5, min_tracking_confidence=0.5): @@ -70,9 +69,8 @@ class Holistic(SolutionBase): static_image_mode: Whether to treat the input images as a batch of static and possibly unrelated images, or a video stream. See details in https://solutions.mediapipe.dev/holistic#static_image_mode. - upper_body_only: Whether to track the full set of 33 pose landmarks or - only the 25 upper-body pose landmarks. See details in - https://solutions.mediapipe.dev/holistic#upper_body_only. + model_complexity: Complexity of the pose landmark model: 0, 1 or 2. See + details in https://solutions.mediapipe.dev/holistic#model_complexity. smooth_landmarks: Whether to filter landmarks across different input images to reduce jitter. See details in https://solutions.mediapipe.dev/holistic#smooth_landmarks. 
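For call sites, the swap from upper_body_only to model_complexity looks like the sketch below, assuming the package is imported as `mp` (illustrative; the removed keyword is shown only as a comment):

```python
import mediapipe as mp

# Before this change, the reduced pose model was selected via:
#   holistic = mp.solutions.holistic.Holistic(upper_body_only=True)

# After this change, pick the lite (0), full (1), or heavy (2) pose model:
with mp.solutions.holistic.Holistic(
    static_image_mode=False,
    model_complexity=1,
    smooth_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5) as holistic:
  pass  # holistic.process(rgb_image) as before
```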
@@ -86,7 +84,7 @@ class Holistic(SolutionBase): super().__init__( binary_graph_path=BINARYPB_FILE_PATH, side_inputs={ - 'upper_body_only': upper_body_only, + 'model_complexity': model_complexity, 'smooth_landmarks': smooth_landmarks and not static_image_mode, }, calculator_params={ diff --git a/mediapipe/python/solutions/holistic_test.py b/mediapipe/python/solutions/holistic_test.py index cfd29e24c..62e96205a 100644 --- a/mediapipe/python/solutions/holistic_test.py +++ b/mediapipe/python/solutions/holistic_test.py @@ -14,6 +14,8 @@ """Tests for mediapipe.python.solutions.pose.""" import os +import tempfile # pylint: disable=unused-import +from typing import NamedTuple from absl.testing import absltest from absl.testing import parameterized @@ -22,45 +24,38 @@ import numpy as np import numpy.testing as npt # resources dependency +# undeclared dependency +from mediapipe.python.solutions import drawing_utils as mp_drawing from mediapipe.python.solutions import holistic as mp_holistic TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' POSE_DIFF_THRESHOLD = 30 # pixels HAND_DIFF_THRESHOLD = 30 # pixels -EXPECTED_UPPER_BODY_LANDMARKS = np.array([[457, 289], [465, 278], [467, 278], - [470, 277], [461, 279], [461, 279], - [461, 279], [485, 277], [474, 278], - [468, 296], [463, 297], [542, 324], - [449, 327], [614, 321], [376, 318], - [680, 322], [312, 310], [697, 320], - [293, 305], [699, 314], [289, 302], - [693, 316], [296, 305], [515, 451], - [467, 453]]) -EXPECTED_FULL_BODY_LANDMARKS = np.array([[460, 287], [469, 277], [472, 276], - [475, 276], [464, 277], [463, 277], - [463, 276], [492, 277], [472, 277], - [471, 295], [465, 295], [542, 323], - [448, 318], [619, 319], [372, 313], - [695, 316], [296, 308], [717, 313], - [273, 304], [718, 304], [280, 298], - [709, 307], [289, 303], [521, 470], - [459, 466], [626, 533], [364, 500], - [704, 616], [347, 614], [710, 631], - [357, 633], [737, 625], [306, 639]]) -EXPECTED_LEFT_HAND_LANDMARKS = np.array([[698, 314], [712, 314], [721, 314], - [727, 314], [732, 313], [728, 309], - [738, 309], [745, 308], [751, 307], - [724, 310], [735, 309], [742, 309], - [747, 307], [719, 312], [727, 313], - [729, 312], [731, 311], [713, 315], - [717, 315], [719, 314], [719, 313]]) -EXPECTED_RIGHT_HAND_LANDMARKS = np.array([[293, 307], [284, 306], [277, 304], - [271, 303], [266, 303], [271, 302], - [261, 302], [254, 301], [247, 299], - [272, 303], [261, 303], [253, 301], - [245, 299], [275, 304], [266, 303], - [258, 302], [252, 300], [279, 305], - [273, 305], [268, 304], [263, 303]]) +EXPECTED_POSE_LANDMARKS = np.array([[782, 243], [791, 232], [796, 233], + [801, 233], [773, 231], [766, 231], + [759, 232], [802, 242], [751, 239], + [791, 258], [766, 258], [830, 301], + [708, 298], [910, 248], [635, 234], + [954, 161], [593, 136], [961, 137], + [583, 110], [952, 132], [592, 106], + [950, 141], [596, 115], [793, 500], + [724, 502], [874, 626], [640, 629], + [965, 756], [542, 760], [962, 779], + [533, 781], [1025, 797], [487, 803]]) +EXPECTED_LEFT_HAND_LANDMARKS = np.array([[958, 167], [950, 161], [945, 151], + [945, 141], [947, 134], [945, 136], + [939, 122], [935, 113], [931, 106], + [951, 134], [946, 118], [942, 108], + [938, 100], [957, 135], [954, 120], + [951, 111], [948, 103], [964, 138], + [964, 128], [965, 122], [965, 117]]) +EXPECTED_RIGHT_HAND_LANDMARKS = np.array([[590, 135], [602, 125], [609, 114], + [613, 103], [617, 96], [596, 100], + [595, 84], [594, 74], [593, 68], + [588, 100], [586, 84], [585, 73], + [584, 65], [581, 103], [579, 89], + [579, 79], 
[579, 72], [575, 109], + [571, 99], [570, 93], [569, 87]]) class PoseTest(parameterized.TestCase): @@ -73,6 +68,22 @@ class PoseTest(parameterized.TestCase): def _assert_diff_less(self, array1, array2, threshold): npt.assert_array_less(np.abs(array1 - array2), threshold) + def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int): + drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) + mp_drawing.draw_landmarks( + image=frame, + landmark_list=results.face_landmarks, + landmark_drawing_spec=drawing_spec) + mp_drawing.draw_landmarks(frame, results.left_hand_landmarks, + mp_holistic.HAND_CONNECTIONS) + mp_drawing.draw_landmarks(frame, results.right_hand_landmarks, + mp_holistic.HAND_CONNECTIONS) + mp_drawing.draw_landmarks(frame, results.pose_landmarks, + mp_holistic.POSE_CONNECTIONS) + path = os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] + + '_frame_{}.png'.format(idx)) + cv2.imwrite(path, frame) + def test_invalid_image_shape(self): with mp_holistic.Holistic() as holistic: with self.assertRaisesRegex( @@ -86,44 +97,24 @@ class PoseTest(parameterized.TestCase): results = holistic.process(image) self.assertIsNone(results.pose_landmarks) - @parameterized.named_parameters(('static_image_mode', True, 3), - ('video_mode', False, 3)) - def test_upper_body_model(self, static_image_mode, num_frames): - image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg') - with mp_holistic.Holistic( - static_image_mode=static_image_mode, upper_body_only=True) as holistic: - image = cv2.imread(image_path) - for _ in range(num_frames): - results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - self._assert_diff_less( - self._landmarks_list_to_array(results.pose_landmarks, image.shape), - EXPECTED_UPPER_BODY_LANDMARKS, - POSE_DIFF_THRESHOLD) - self._assert_diff_less( - self._landmarks_list_to_array(results.left_hand_landmarks, - image.shape), - EXPECTED_LEFT_HAND_LANDMARKS, - HAND_DIFF_THRESHOLD) - self._assert_diff_less( - self._landmarks_list_to_array(results.right_hand_landmarks, - image.shape), - EXPECTED_RIGHT_HAND_LANDMARKS, - HAND_DIFF_THRESHOLD) - # TODO: Verify the correctness of the face landmarks. 
- self.assertLen(results.face_landmarks.landmark, 468) - - @parameterized.named_parameters(('static_image_mode', True, 3), - ('video_mode', False, 3)) - def test_full_body_model(self, static_image_mode, num_frames): - image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg') + @parameterized.named_parameters(('static_lite', True, 0, 3), + ('static_full', True, 1, 3), + ('static_heavy', True, 2, 3), + ('video_lite', False, 0, 3), + ('video_full', False, 1, 3), + ('video_heavy', False, 2, 3)) + def test_on_image(self, static_image_mode, model_complexity, num_frames): + image_path = os.path.join(os.path.dirname(__file__), + 'testdata/holistic.jpg') image = cv2.imread(image_path) - - with mp_holistic.Holistic(static_image_mode=static_image_mode) as holistic: - for _ in range(num_frames): + with mp_holistic.Holistic(static_image_mode=static_image_mode, + model_complexity=model_complexity) as holistic: + for idx in range(num_frames): results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + self._annotate(image.copy(), results, idx) self._assert_diff_less( self._landmarks_list_to_array(results.pose_landmarks, image.shape), - EXPECTED_FULL_BODY_LANDMARKS, + EXPECTED_POSE_LANDMARKS, POSE_DIFF_THRESHOLD) self._assert_diff_less( self._landmarks_list_to_array(results.left_hand_landmarks, diff --git a/mediapipe/python/solutions/pose.py b/mediapipe/python/solutions/pose.py index 74d8166af..47d2d87f6 100644 --- a/mediapipe/python/solutions/pose.py +++ b/mediapipe/python/solutions/pose.py @@ -42,7 +42,7 @@ from mediapipe.python.solution_base import SolutionBase class PoseLandmark(enum.IntEnum): - """The 25 (upper-body) pose landmarks.""" + """The 33 pose landmarks.""" NOSE = 0 LEFT_EYE_INNER = 1 LEFT_EYE = 2 @@ -78,7 +78,7 @@ class PoseLandmark(enum.IntEnum): RIGHT_FOOT_INDEX = 32 BINARYPB_FILE_PATH = 'mediapipe/modules/pose_landmark/pose_landmark_cpu.binarypb' -UPPER_BODY_POSE_CONNECTIONS = frozenset([ +POSE_CONNECTIONS = frozenset([ (PoseLandmark.NOSE, PoseLandmark.RIGHT_EYE_INNER), (PoseLandmark.RIGHT_EYE_INNER, PoseLandmark.RIGHT_EYE), (PoseLandmark.RIGHT_EYE, PoseLandmark.RIGHT_EYE_OUTER), @@ -104,21 +104,17 @@ UPPER_BODY_POSE_CONNECTIONS = frozenset([ (PoseLandmark.RIGHT_SHOULDER, PoseLandmark.RIGHT_HIP), (PoseLandmark.LEFT_SHOULDER, PoseLandmark.LEFT_HIP), (PoseLandmark.RIGHT_HIP, PoseLandmark.LEFT_HIP), + (PoseLandmark.RIGHT_HIP, PoseLandmark.RIGHT_KNEE), + (PoseLandmark.LEFT_HIP, PoseLandmark.LEFT_KNEE), + (PoseLandmark.RIGHT_KNEE, PoseLandmark.RIGHT_ANKLE), + (PoseLandmark.LEFT_KNEE, PoseLandmark.LEFT_ANKLE), + (PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_HEEL), + (PoseLandmark.LEFT_ANKLE, PoseLandmark.LEFT_HEEL), + (PoseLandmark.RIGHT_HEEL, PoseLandmark.RIGHT_FOOT_INDEX), + (PoseLandmark.LEFT_HEEL, PoseLandmark.LEFT_FOOT_INDEX), + (PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_FOOT_INDEX), + (PoseLandmark.LEFT_ANKLE, PoseLandmark.LEFT_FOOT_INDEX), ]) -POSE_CONNECTIONS = frozenset.union( - UPPER_BODY_POSE_CONNECTIONS, - frozenset([ - (PoseLandmark.RIGHT_HIP, PoseLandmark.RIGHT_KNEE), - (PoseLandmark.LEFT_HIP, PoseLandmark.LEFT_KNEE), - (PoseLandmark.RIGHT_KNEE, PoseLandmark.RIGHT_ANKLE), - (PoseLandmark.LEFT_KNEE, PoseLandmark.LEFT_ANKLE), - (PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_HEEL), - (PoseLandmark.LEFT_ANKLE, PoseLandmark.LEFT_HEEL), - (PoseLandmark.RIGHT_HEEL, PoseLandmark.RIGHT_FOOT_INDEX), - (PoseLandmark.LEFT_HEEL, PoseLandmark.LEFT_FOOT_INDEX), - (PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_FOOT_INDEX), - (PoseLandmark.LEFT_ANKLE, 
PoseLandmark.LEFT_FOOT_INDEX), - ])) class Pose(SolutionBase): @@ -133,7 +129,7 @@ class Pose(SolutionBase): def __init__(self, static_image_mode=False, - upper_body_only=False, + model_complexity=1, smooth_landmarks=True, min_detection_confidence=0.5, min_tracking_confidence=0.5): @@ -143,9 +139,8 @@ class Pose(SolutionBase): static_image_mode: Whether to treat the input images as a batch of static and possibly unrelated images, or a video stream. See details in https://solutions.mediapipe.dev/pose#static_image_mode. - upper_body_only: Whether to track the full set of 33 pose landmarks or - only the 25 upper-body pose landmarks. See details in - https://solutions.mediapipe.dev/pose#upper_body_only. + model_complexity: Complexity of the pose landmark model: 0, 1 or 2. See + details in https://solutions.mediapipe.dev/pose#model_complexity. smooth_landmarks: Whether to filter landmarks across different input images to reduce jitter. See details in https://solutions.mediapipe.dev/pose#smooth_landmarks. @@ -159,7 +154,7 @@ class Pose(SolutionBase): super().__init__( binary_graph_path=BINARYPB_FILE_PATH, side_inputs={ - 'upper_body_only': upper_body_only, + 'model_complexity': model_complexity, 'smooth_landmarks': smooth_landmarks and not static_image_mode, }, calculator_params={ diff --git a/mediapipe/python/solutions/pose_test.py b/mediapipe/python/solutions/pose_test.py index b5d108460..5022514ea 100644 --- a/mediapipe/python/solutions/pose_test.py +++ b/mediapipe/python/solutions/pose_test.py @@ -16,6 +16,7 @@ import json import os import tempfile +from typing import NamedTuple from absl.testing import absltest from absl.testing import parameterized @@ -24,30 +25,23 @@ import numpy as np import numpy.testing as npt # resources dependency +# undeclared dependency +from mediapipe.python.solutions import drawing_utils as mp_drawing from mediapipe.python.solutions import pose as mp_pose TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' DIFF_THRESHOLD = 30 # pixels -EXPECTED_UPPER_BODY_LANDMARKS = np.array([[457, 289], [465, 278], [467, 278], - [470, 277], [461, 279], [461, 279], - [461, 279], [485, 277], [474, 278], - [468, 296], [463, 297], [542, 324], - [449, 327], [614, 321], [376, 318], - [680, 322], [312, 310], [697, 320], - [293, 305], [699, 314], [289, 302], - [693, 316], [296, 305], [515, 451], - [467, 453]]) -EXPECTED_FULL_BODY_LANDMARKS = np.array([[460, 287], [469, 277], [472, 276], - [475, 276], [464, 277], [463, 277], - [463, 276], [492, 277], [472, 277], - [471, 295], [465, 295], [542, 323], - [448, 318], [619, 319], [372, 313], - [695, 316], [296, 308], [717, 313], - [273, 304], [718, 304], [280, 298], - [709, 307], [289, 303], [521, 470], - [459, 466], [626, 533], [364, 500], - [704, 616], [347, 614], [710, 631], - [357, 633], [737, 625], [306, 639]]) +EXPECTED_POSE_LANDMARKS = np.array([[460, 287], [469, 277], [472, 276], + [475, 276], [464, 277], [463, 277], + [463, 276], [492, 277], [472, 277], + [471, 295], [465, 295], [542, 323], + [448, 318], [619, 319], [372, 313], + [695, 316], [296, 308], [717, 313], + [273, 304], [718, 304], [280, 298], + [709, 307], [289, 303], [521, 470], + [459, 466], [626, 533], [364, 500], + [704, 616], [347, 614], [710, 631], + [357, 633], [737, 625], [306, 639]]) class PoseTest(parameterized.TestCase): @@ -60,6 +54,13 @@ class PoseTest(parameterized.TestCase): def _assert_diff_less(self, array1, array2, threshold): npt.assert_array_less(np.abs(array1 - array2), threshold) + def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: 
int): + mp_drawing.draw_landmarks(frame, results.pose_landmarks, + mp_pose.POSE_CONNECTIONS) + path = os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] + + '_frame_{}.png'.format(idx)) + cv2.imwrite(path, frame) + def test_invalid_image_shape(self): with mp_pose.Pose() as pose: with self.assertRaisesRegex( @@ -73,38 +74,28 @@ class PoseTest(parameterized.TestCase): results = pose.process(image) self.assertIsNone(results.pose_landmarks) - @parameterized.named_parameters(('static_image_mode', True, 3), - ('video_mode', False, 3)) - def test_upper_body_model(self, static_image_mode, num_frames): - image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg') - with mp_pose.Pose( - static_image_mode=static_image_mode, upper_body_only=True) as pose: - image = cv2.imread(image_path) - for _ in range(num_frames): - results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - self._assert_diff_less( - self._landmarks_list_to_array(results.pose_landmarks, - image.shape)[:, :2], - EXPECTED_UPPER_BODY_LANDMARKS, DIFF_THRESHOLD) - - @parameterized.named_parameters(('static_image_mode', True, 3), - ('video_mode', False, 3)) - def test_full_body_model(self, static_image_mode, num_frames): + @parameterized.named_parameters(('static_lite', True, 0, 3), + ('static_full', True, 1, 3), + ('static_heavy', True, 2, 3), + ('video_lite', False, 0, 3), + ('video_full', False, 1, 3), + ('video_heavy', False, 2, 3)) + def test_on_image(self, static_image_mode, model_complexity, num_frames): image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg') image = cv2.imread(image_path) - - with mp_pose.Pose(static_image_mode=static_image_mode) as pose: - for _ in range(num_frames): + with mp_pose.Pose(static_image_mode=static_image_mode, + model_complexity=model_complexity) as pose: + for idx in range(num_frames): results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + self._annotate(image.copy(), results, idx) self._assert_diff_less( self._landmarks_list_to_array(results.pose_landmarks, image.shape)[:, :2], - EXPECTED_FULL_BODY_LANDMARKS, DIFF_THRESHOLD) + EXPECTED_POSE_LANDMARKS, DIFF_THRESHOLD) @parameterized.named_parameters( - ('full_body', False, 'pose_squats.full_body.npz'), - ('upper_body', True, 'pose_squats.upper_body.npz')) - def test_on_video(self, upper_body_only, expected_name): + ('full', 1, 'pose_squats.full.npz')) + def test_on_video(self, model_complexity, expected_name): """Tests pose models on a video.""" # If set to `True` will dump actual predictions to .npz and JSON files. dump_predictions = False @@ -120,8 +111,9 @@ class PoseTest(parameterized.TestCase): # Predict pose landmarks for each frame. video_cap = cv2.VideoCapture(video_path) actual_per_frame = [] - with mp_pose.Pose( - static_image_mode=False, upper_body_only=upper_body_only) as pose: + frame_idx = 0 + with mp_pose.Pose(static_image_mode=False, + model_complexity=model_complexity) as pose: while True: # Get next frame of the video. 
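The pose.py and pose_test.py hunks above replace the old `upper_body_only` flag with `model_complexity` (0 = lite, 1 = full, 2 = heavy) and extend `POSE_CONNECTIONS` to cover the full body. For reference, a minimal usage sketch of the updated Python API as exercised by these tests; the image paths are placeholders, not files from this patch:

```python
import cv2

from mediapipe.python.solutions import drawing_utils as mp_drawing
from mediapipe.python.solutions import pose as mp_pose

# MediaPipe expects RGB input; OpenCV reads BGR.
image = cv2.imread('input.jpg')  # placeholder path
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# model_complexity selects the lite (0), full (1) or heavy (2) landmark model.
with mp_pose.Pose(static_image_mode=True, model_complexity=1) as pose:
    results = pose.process(rgb)

if results.pose_landmarks:
    annotated = image.copy()
    mp_drawing.draw_landmarks(annotated, results.pose_landmarks,
                              mp_pose.POSE_CONNECTIONS)
    cv2.imwrite('annotated.png', annotated)  # placeholder path
```

This mirrors what the new `_annotate` test helper does per frame: draw the detected landmarks with the full-body connection set and dump the result for inspection.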
success, input_frame = video_cap.read() @@ -135,6 +127,10 @@ class PoseTest(parameterized.TestCase): input_frame.shape) actual_per_frame.append(pose_landmarks) + + input_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR) + self._annotate(input_frame, result, frame_idx) + frame_idx += 1 actual = np.asarray(actual_per_frame) if dump_predictions: diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index 1fa71c6e7..da122442e 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -199,6 +199,25 @@ cc_library( }), ) +cc_library( + name = "resource_cache", + hdrs = ["resource_cache.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/functional:function_ref", + ], +) + +cc_test( + name = "resource_cache_test", + srcs = ["resource_cache_test.cc"], + deps = [ + ":resource_cache", + "//mediapipe/framework/port:gtest_main", + ], +) + cc_library( name = "tensor_to_detection", srcs = ["tensor_to_detection.cc"], diff --git a/mediapipe/util/resource_cache.h b/mediapipe/util/resource_cache.h new file mode 100644 index 000000000..4cd869f6a --- /dev/null +++ b/mediapipe/util/resource_cache.h @@ -0,0 +1,181 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_UTIL_RESOURCE_CACHE_H_ +#define MEDIAPIPE_UTIL_RESOURCE_CACHE_H_ + +#include + +#include "absl/functional/function_ref.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +// Maintains a cache for resources of type `Value`, where the type of the +// resource (e.g., image dimension for an image pool) is described bye the `Key` +// type. The `Value` type must include an unset value, with implicit conversion +// to bool reflecting set/unset state. +template +class ResourceCache { + public: + Value Lookup( + const Key& key, + absl::FunctionRef create) { + auto map_it = map_.find(key); + Entry* entry; + if (map_it == map_.end()) { + std::tie(map_it, std::ignore) = + map_.emplace(std::piecewise_construct, std::forward_as_tuple(key), + std::forward_as_tuple(key)); + entry = &map_it->second; + CHECK_EQ(entry->request_count, 0); + entry->request_count = 1; + entry_list_.Append(entry); + if (entry->prev != nullptr) CHECK_GE(entry->prev->request_count, 1); + } else { + entry = &map_it->second; + ++entry->request_count; + Entry* larger = entry->prev; + while (larger != nullptr && + larger->request_count < entry->request_count) { + larger = larger->prev; + } + if (larger != entry->prev) { + entry_list_.Remove(entry); + entry_list_.InsertAfter(entry, larger); + } + } + if (!entry->value) { + entry->value = create(entry->key, entry->request_count); + } + ++total_request_count_; + return entry->value; + } + + std::vector Evict(int max_count, int request_count_scrub_interval) { + std::vector evicted; + + // Remove excess entries. 
+ while (entry_list_.size() > max_count) { + Entry* victim = entry_list_.tail(); + evicted.emplace_back(std::move(victim->value)); + entry_list_.Remove(victim); + map_.erase(victim->key); + } + // Every request_count_scrub_interval, halve the request counts, and + // remove entries which have fallen to 0. + // This keeps sporadic requests from accumulating and eventually exceeding + // the minimum request threshold for allocating a pool. Also, it means that + // if the request regimen changes (e.g. a graph was always requesting a + // large size, but then switches to a small size to save memory or CPU), the + // pool can quickly adapt to it. + bool scrub = total_request_count_ >= request_count_scrub_interval; + if (scrub) { + total_request_count_ = 0; + for (Entry* entry = entry_list_.head(); entry != nullptr;) { + entry->request_count /= 2; + Entry* next = entry->next; + if (entry->request_count == 0) { + evicted.emplace_back(std::move(entry->value)); + entry_list_.Remove(entry); + map_.erase(entry->key); + } + entry = next; + } + } + return evicted; + } + + private: + struct Entry { + Entry(const Key& key) : key(key) {} + Entry* prev = nullptr; + Entry* next = nullptr; + int request_count = 0; + Key key; + Value value; + }; + + // Unlike std::list, this is an intrusive list, meaning that the prev and next + // pointers live inside the element. Apart from not requiring an extra + // allocation, this means that once we look up an entry by key in the pools_ + // map we do not need to look it up separately in the list. + // + class EntryList { + public: + void Prepend(Entry* entry) { + if (head_ == nullptr) { + head_ = tail_ = entry; + } else { + entry->next = head_; + head_->prev = entry; + head_ = entry; + } + ++size_; + } + void Append(Entry* entry) { + if (tail_ == nullptr) { + head_ = tail_ = entry; + } else { + tail_->next = entry; + entry->prev = tail_; + tail_ = entry; + } + ++size_; + } + void Remove(Entry* entry) { + if (entry == head_) { + head_ = entry->next; + } else { + entry->prev->next = entry->next; + } + if (entry == tail_) { + tail_ = entry->prev; + } else { + entry->next->prev = entry->prev; + } + entry->prev = nullptr; + entry->next = nullptr; + --size_; + } + void InsertAfter(Entry* entry, Entry* after) { + if (after != nullptr) { + entry->next = after->next; + if (entry->next) entry->next->prev = entry; + entry->prev = after; + after->next = entry; + ++size_; + } else { + Prepend(entry); + } + } + + Entry* head() { return head_; } + Entry* tail() { return tail_; } + size_t size() { return size_; } + + private: + Entry* head_ = nullptr; + Entry* tail_ = nullptr; + size_t size_ = 0; + }; + + std::unordered_map map_; + EntryList entry_list_; + int total_request_count_ = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_RESOURCE_CACHE_H_ diff --git a/mediapipe/util/resource_cache_test.cc b/mediapipe/util/resource_cache_test.cc new file mode 100644 index 000000000..ddba117e4 --- /dev/null +++ b/mediapipe/util/resource_cache_test.cc @@ -0,0 +1,150 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
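The new `mediapipe/util/resource_cache.h` above documents its policy in comments: `Lookup` bumps a per-key request count and lazily creates the value, while `Evict` drops the least-requested entries beyond `max_count` and, every `request_count_scrub_interval` requests, halves all counts and evicts entries that fall to zero so sporadic requests fade out. As a rough illustration only — the real implementation is the C++ template above, ordered with an intrusive list — a simplified Python paraphrase of that policy:

```python
class ResourceCacheSketch:
    """Loose Python rendition of the ResourceCache request-count policy."""

    def __init__(self):
        self._entries = {}        # key -> [request_count, value-or-None]
        self._total_requests = 0

    def lookup(self, key, create):
        entry = self._entries.setdefault(key, [0, None])
        entry[0] += 1
        self._total_requests += 1
        if entry[1] is None:
            # create() sees the key and current request count, mirroring
            # absl::FunctionRef<Value(const Key&, int)> in the header.
            entry[1] = create(key, entry[0])
        return entry[1]

    def evict(self, max_count, scrub_interval):
        evicted = []
        # Drop the least-requested entries beyond max_count.
        while len(self._entries) > max_count:
            key = min(self._entries, key=lambda k: self._entries[k][0])
            evicted.append(self._entries.pop(key)[1])
        # Periodic scrub: halve counts so stale entries eventually vanish.
        if self._total_requests >= scrub_interval:
            self._total_requests = 0
            for key in list(self._entries):
                self._entries[key][0] //= 2
                if self._entries[key][0] == 0:
                    evicted.append(self._entries.pop(key)[1])
        return evicted
```

The C++ version avoids the `min()` scan by keeping entries in an intrusive list sorted by request count, so eviction simply pops the tail.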
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/resource_cache.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +#define EXPECT_BETWEEN(low, high, value) \ + do { \ + EXPECT_LE((low), (value)); \ + EXPECT_GE((high), (value)); \ + } while (0) + +namespace mediapipe { +namespace { + +using ::testing::_; +using ::testing::MockFunction; +using ::testing::Return; + +using IntCache = ResourceCache, std::hash>; +using MockCreate = + MockFunction(const int& key, int request_count)>; + +TEST(ResourceCacheTest, ReturnsNull) { + IntCache cache; + MockCreate create; + + EXPECT_CALL(create, Call(1, 1)).WillOnce(Return(nullptr)); + EXPECT_EQ(nullptr, cache.Lookup(1, create.AsStdFunction())); +} + +TEST(ResourceCacheTest, CountsRequests) { + IntCache cache; + MockCreate create11; + MockCreate create12; + MockCreate create21; + + EXPECT_CALL(create11, Call(1, 1)).WillOnce(Return(nullptr)); + EXPECT_CALL(create12, Call(1, 2)).WillOnce(Return(nullptr)); + EXPECT_CALL(create11, Call(2, 1)).WillOnce(Return(nullptr)); + + // Verify that request counts are updated, and separate by key. + EXPECT_EQ(nullptr, cache.Lookup(1, create11.AsStdFunction())); + EXPECT_EQ(nullptr, cache.Lookup(1, create12.AsStdFunction())); + EXPECT_EQ(nullptr, cache.Lookup(2, create11.AsStdFunction())); +} + +TEST(ResourceCacheTest, CachesValues) { + IntCache cache; + auto value1 = std::make_shared(1); + auto value2 = std::make_shared(2); + + MockCreate create1; + MockCreate create2; + MockCreate no_create; + + EXPECT_CALL(create1, Call(1, 1)).WillOnce(Return(value1)); + EXPECT_CALL(create2, Call(2, 1)).WillOnce(Return(value2)); + EXPECT_CALL(no_create, Call(_, _)).Times(0); + // Calls creating the values. + EXPECT_EQ(value1, cache.Lookup(1, create1.AsStdFunction())); + EXPECT_EQ(value2, cache.Lookup(2, create2.AsStdFunction())); + + // Calls returning existing values. + EXPECT_EQ(value1, cache.Lookup(1, no_create.AsStdFunction())); + EXPECT_EQ(value2, cache.Lookup(2, no_create.AsStdFunction())); +} + +TEST(ResourceCacheTest, EvictToMaxSize) { + IntCache cache; + MockCreate create; + + EXPECT_CALL(create, Call(_, 1)) + .WillRepeatedly([](int key, int request_count) { + return std::make_shared(key); + }); + + // Add three entries. + EXPECT_NE(nullptr, cache.Lookup(1, create.AsStdFunction())); + EXPECT_NE(nullptr, cache.Lookup(2, create.AsStdFunction())); + EXPECT_NE(nullptr, cache.Lookup(3, create.AsStdFunction())); + + // Keep only two. + auto evicted = cache.Evict(/*max_count=*/2, + /*request_count_scrub_interval=*/4); + ASSERT_EQ(1, evicted.size()); + int evicted_entry = *evicted[0]; + EXPECT_BETWEEN(1, 3, evicted_entry); + + MockCreate no_create; + EXPECT_CALL(no_create, Call(_, 1)).WillOnce(Return(nullptr)); + EXPECT_EQ(nullptr, cache.Lookup(evicted_entry, no_create.AsStdFunction())); + for (int key = 1; key <= 3; key++) { + if (key != evicted_entry) { + EXPECT_NE(nullptr, cache.Lookup(key, no_create.AsStdFunction())); + } + } +} + +TEST(ResourceCacheTest, EvictWithScrub) { + IntCache cache; + MockCreate create; + + EXPECT_CALL(create, Call(_, 1)) + .WillRepeatedly([](int key, int request_count) { + return std::make_shared(key); + }); + + EXPECT_NE(nullptr, cache.Lookup(1, create.AsStdFunction())); + EXPECT_NE(nullptr, cache.Lookup(2, create.AsStdFunction())); + EXPECT_NE(nullptr, cache.Lookup(3, create.AsStdFunction())); + + // 3 entries, total request count 4, so nothing evicted from this call. 
+ EXPECT_TRUE( + cache.Evict(/*max_count=*/3, /*request_count_scrub_interval=*/4).empty()); + + // Increment request counts. + EXPECT_NE(nullptr, cache.Lookup(1, create.AsStdFunction())); + EXPECT_NE(nullptr, cache.Lookup(3, create.AsStdFunction())); + + // Expected to evict entry 2, and halve request counts for the other two + // entries. + auto evicted = + cache.Evict(/*max_count=*/3, /*request_count_scrub_interval=*/5); + ASSERT_EQ(1, evicted.size()); + EXPECT_EQ(2, *evicted[0]); + + // Increment request count. + EXPECT_NE(nullptr, cache.Lookup(3, create.AsStdFunction())); + // Expected to evict entry 1. + evicted = cache.Evict(/*max_count=*/3, /*request_count_scrub_interval=*/1); + ASSERT_EQ(1, evicted.size()); + EXPECT_EQ(1, *evicted[0]); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/util/tracking/region_flow_computation.cc b/mediapipe/util/tracking/region_flow_computation.cc index 9685c5e72..cfd5c23c2 100644 --- a/mediapipe/util/tracking/region_flow_computation.cc +++ b/mediapipe/util/tracking/region_flow_computation.cc @@ -392,7 +392,7 @@ struct RegionFlowComputation::FrameTrackingData { void BuildPyramid(int levels, int window_size, bool with_derivative) { if (use_cv_tracking) { -#if CV_MAJOR_VERSION == 3 +#if CV_MAJOR_VERSION >= 3 // No-op if not called for opencv 3.0 (c interface computes // pyramids in place). // OpenCV changed how window size gets specified from our radius setting @@ -761,7 +761,7 @@ RegionFlowComputation::RegionFlowComputation( // Tracking algorithm dependent on cv support and flag. use_cv_tracking_ = options_.tracking_options().use_cv_tracking_algorithm(); -#if CV_MAJOR_VERSION != 3 +#if CV_MAJOR_VERSION < 3 if (use_cv_tracking_) { LOG(WARNING) << "Compiled without OpenCV 3.0 but cv_tracking_algorithm " << "was requested. Falling back to older algorithm"; @@ -2577,7 +2577,7 @@ void RegionFlowComputation::TrackFeatures(FrameTrackingData* from_data_ptr, input_mean, gain_image_.get()); } -#if CV_MAJOR_VERSION == 3 +#if CV_MAJOR_VERSION >= 3 // OpenCV changed how window size gets specified from our radius setting // < 2.2 to diameter in 2.2+. 
const cv::Size cv_window_size(track_win_size * 2 + 1, track_win_size * 2 + 1); @@ -2599,7 +2599,7 @@ void RegionFlowComputation::TrackFeatures(FrameTrackingData* from_data_ptr, feature_track_error_.resize(num_features); feature_status_.resize(num_features); if (use_cv_tracking_) { -#if CV_MAJOR_VERSION == 3 +#if CV_MAJOR_VERSION >= 3 if (gain_correction) { if (!frame1_gain_reference) { input_frame1 = cv::_InputArray(*gain_image_); @@ -2788,7 +2788,7 @@ void RegionFlowComputation::TrackFeatures(FrameTrackingData* from_data_ptr, feature_status_.resize(num_to_verify); if (use_cv_tracking_) { -#if CV_MAJOR_VERSION == 3 +#if CV_MAJOR_VERSION >= 3 cv::calcOpticalFlowPyrLK(input_frame2, input_frame1, verify_features, verify_features_tracked, feature_status_, verify_track_error, cv_window_size, diff --git a/requirements.txt b/requirements.txt index 709a31d3e..fcb9ad315 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ absl-py attrs>=19.1.0 -dataclasses numpy opencv-contrib-python protobuf>=3.11.4 diff --git a/setup.py b/setup.py index 8a4c71274..c19ecf992 100644 --- a/setup.py +++ b/setup.py @@ -85,8 +85,10 @@ def _check_bazel(): sys.exit(-1) try: bazel_version_info = subprocess.check_output(['bazel', '--version']) - except subprocess.CalledProcessError: - sys.stderr.write('fail to get bazel version by $ bazel --version.') + except subprocess.CalledProcessError as e: + sys.stderr.write('fail to get bazel version by $ bazel --version: ' + + str(e.output)) + sys.exit(-1) bazel_version_info = bazel_version_info.decode('UTF-8').strip() version = bazel_version_info.split('bazel ')[1].split('-')[0] version_segments = version.split('.') diff --git a/setup_android_sdk_and_ndk.sh b/setup_android_sdk_and_ndk.sh index 11f33555e..e574552f0 100644 --- a/setup_android_sdk_and_ndk.sh +++ b/setup_android_sdk_and_ndk.sh @@ -17,7 +17,7 @@ # Script to setup Android SDK and NDK. # usage: # $ cd -# $ bash ./setup_android_sdk_and_ndk.sh ~/Android/Sdk ~/Android/Ndk r18b +# $ bash ./setup_android_sdk_and_ndk.sh ~/Android/Sdk ~/Android/Ndk r19c set -e @@ -54,8 +54,8 @@ fi if [ -z $3 ] then - echo "Warning: ndk_version (argument 3) is not specified. Fallback to r18b." - ndk_version="r18b" + echo "Warning: ndk_version (argument 3) is not specified. Fallback to r19c." 
+ ndk_version="r19c" fi if [ -d "$android_sdk_path" ] diff --git a/third_party/BUILD b/third_party/BUILD index 654f0cb72..5800098fb 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -286,3 +286,24 @@ android_library( "@maven//:androidx_camera_camera_lifecycle", ], ) + +java_plugin( + name = "autovalue_plugin", + processor_class = "com.google.auto.value.processor.AutoValueProcessor", + deps = [ + "@maven//:com_google_auto_value_auto_value", + "@maven//:com_google_auto_value_auto_value_annotations", + ], +) + +java_library( + name = "autovalue", + exported_plugins = [ + ":autovalue_plugin", + ], + neverlink = 1, + exports = [ + "@maven//:com_google_auto_value_auto_value", + "@maven//:com_google_auto_value_auto_value_annotations", + ], +) diff --git a/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff b/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff index 89e80a9c3..471cf2aa6 100644 --- a/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff +++ b/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff @@ -45,8 +45,8 @@ index 0b5e6ee..be5a506 100644 + const int level = AndroidLogLevel((int)data_->severity_); + const std::string text = std::string(data_->message_text_); + __android_log_write(level, "native", text.substr(0,data_->num_chars_to_log_).c_str()); -+#endif // !defined(__ANDROID__) ++#endif // defined(__ANDROID__) + if (append_newline) { // Fix the ostrstream back how it was before we screwed with it. - // It's 99.44% certain that we don't need to worry about doing this. \ No newline at end of file + // It's 99.44% certain that we don't need to worry about doing this. diff --git a/third_party/opencv_android.BUILD b/third_party/opencv_android.BUILD index 3bdfc88d3..6c00457e7 100644 --- a/third_party/opencv_android.BUILD +++ b/third_party/opencv_android.BUILD @@ -29,3 +29,14 @@ OPENCVANDROIDSDK_JNI_PATH = "sdk/native/jni/" "x86", "x86_64", ]] + +[alias( + name = "libopencv_java3_so_" + arch, + actual = OPENCVANDROIDSDK_NATIVELIBS_PATH + arch + "/" + OPENCV_LIBRARY_NAME, + visibility = ["//visibility:public"], +) for arch in [ + "arm64-v8a", + "armeabi-v7a", + "x86", + "x86_64", +]] diff --git a/third_party/opencv_windows.BUILD b/third_party/opencv_windows.BUILD index 23d3fa591..ecf788ee0 100644 --- a/third_party/opencv_windows.BUILD +++ b/third_party/opencv_windows.BUILD @@ -7,15 +7,31 @@ exports_files(["LICENSE"]) OPENCV_VERSION = "3410" # 3.4.10 +config_setting( + name = "opt_build", + values = {"compilation_mode": "opt"}, +) + +config_setting( + name = "dbg_build", + values = {"compilation_mode": "dbg"}, +) + # The following build rule assumes that the executable "opencv-3.4.10-vc14_vc15.exe" # is downloaded and the files are extracted to local. # If you install OpenCV separately, please modify the build rule accordingly. cc_library( name = "opencv", - srcs = [ - "x64/vc15/lib/opencv_world" + OPENCV_VERSION + ".lib", - "x64/vc15/bin/opencv_world" + OPENCV_VERSION + ".dll", - ], + srcs = select({ + ":opt_build": [ + "x64/vc15/lib/opencv_world" + OPENCV_VERSION + ".lib", + "x64/vc15/bin/opencv_world" + OPENCV_VERSION + ".dll", + ], + ":dbg_build": [ + "x64/vc15/lib/opencv_world" + OPENCV_VERSION + "d.lib", + "x64/vc15/bin/opencv_world" + OPENCV_VERSION + "d.dll", + ], + }), hdrs = glob(["include/opencv2/**/*.h*"]), includes = ["include/"], linkstatic = 1,
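The setup.py hunk earlier in this section tightens the Bazel probe so a failing `bazel --version` surfaces the captured output and aborts instead of continuing with an undefined version string. Condensed into a self-contained sketch of that pattern (the function name here is illustrative; the real helper is `_check_bazel` and performs additional checks):

```python
import subprocess
import sys


def check_bazel_version():
    """Return the bazel version segments, or exit with the captured error."""
    try:
        output = subprocess.check_output(['bazel', '--version'])
    except subprocess.CalledProcessError as e:
        sys.stderr.write('fail to get bazel version by $ bazel --version: ' +
                         str(e.output))
        sys.exit(-1)
    # e.g. "bazel 3.7.2" -> "3.7.2" -> ["3", "7", "2"]
    version = output.decode('UTF-8').strip().split('bazel ')[1].split('-')[0]
    return version.split('.')
```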