diff --git a/.bazelrc b/.bazelrc index 5a586f3ca..37a0bc114 100644 --- a/.bazelrc +++ b/.bazelrc @@ -5,25 +5,28 @@ common --experimental_repo_remote_exec # Basic build settings build --jobs 128 -build --define='absl=1' +build --define='absl=1' # for gtest build --enable_platform_specific_config +# Enable stack traces +test --test_env="GTEST_INSTALL_FAILURE_SIGNAL_HANDLER=1" + # Linux -build:linux --cxxopt=-std=c++14 -build:linux --host_cxxopt=-std=c++14 +build:linux --cxxopt=-std=c++17 +build:linux --host_cxxopt=-std=c++17 build:linux --copt=-w # windows -build:windows --cxxopt=/std:c++14 -build:windows --host_cxxopt=/std:c++14 +build:windows --cxxopt=/std:c++17 +build:windows --host_cxxopt=/std:c++17 build:windows --copt=/w # For using M_* math constants on Windows with MSVC. build:windows --copt=/D_USE_MATH_DEFINES build:windows --host_copt=/D_USE_MATH_DEFINES # macOS -build:macos --cxxopt=-std=c++14 -build:macos --host_cxxopt=-std=c++14 +build:macos --cxxopt=-std=c++17 +build:macos --host_cxxopt=-std=c++17 build:macos --copt=-w # Sets the default Apple platform to macOS. @@ -83,3 +86,9 @@ build:ios_fat --watchos_cpus=armv7k build:darwin_x86_64 --apple_platform_type=macos build:darwin_x86_64 --macos_minimum_os=10.12 build:darwin_x86_64 --cpu=darwin_x86_64 + +# This bazelrc file is meant to be written by a setup script. +try-import %workspace%/.configure.bazelrc + +# This bazelrc file can be used for user-specific custom build settings. +try-import %workspace%/.user.bazelrc diff --git a/.gitignore b/.gitignore index aa1bde53e..b3a881711 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ bazel-* mediapipe/MediaPipe.xcodeproj mediapipe/MediaPipe.tulsiproj/*.tulsiconf-user mediapipe/provisioning_profile.mobileprovision +.configure.bazelrc +.user.bazelrc diff --git a/Dockerfile b/Dockerfile index 6267a5f00..dc3b034a2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -54,7 +54,7 @@ RUN pip3 install tf_slim RUN ln -s /usr/bin/python3 /usr/bin/python # Install bazel -ARG BAZEL_VERSION=2.0.0 +ARG BAZEL_VERSION=3.4.1 RUN mkdir /bazel && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ diff --git a/MANIFEST.in b/MANIFEST.in index 33a48428c..1994721f3 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -7,5 +7,6 @@ include MANIFEST.in include README.md include requirements.txt -recursive-include mediapipe/modules *.tflite *.txt -recursive-include mediapipe/graphs *.binarypb +recursive-include mediapipe/modules *.tflite *.txt *.binarypb +exclude mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite +exclude mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite diff --git a/README.md b/README.md index d7287dc35..06fa39b5e 100644 --- a/README.md +++ b/README.md @@ -8,47 +8,60 @@ nav_order: 1 -------------------------------------------------------------------------------- -## Cross-platform ML solutions made simple +## Live ML anywhere -[MediaPipe](https://google.github.io/mediapipe/) is the simplest way for researchers -and developers to build world-class ML solutions and applications for mobile, -desktop/cloud, web and IoT devices. +[MediaPipe](https://google.github.io/mediapipe/) offers cross-platform, customizable +ML solutions for live and streaming media. 
![accelerated.png](docs/images/accelerated_small.png) | ![cross_platform.png](docs/images/cross_platform_small.png) :------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------: -***End-to-End acceleration***: *built-in fast ML inference and processing accelerated even on common hardware* | ***Build one, deploy anywhere***: *Unified solution works across Android, iOS, desktop/cloud, web and IoT* +***End-to-End acceleration***: *Built-in fast ML inference and processing accelerated even on common hardware* | ***Build once, deploy anywhere***: *Unified solution works across Android, iOS, desktop/cloud, web and IoT* ![ready_to_use.png](docs/images/ready_to_use_small.png) | ![open_source.png](docs/images/open_source_small.png) ***Ready-to-use solutions***: *Cutting-edge ML solutions demonstrating full power of the framework* | ***Free and open source***: *Framework and solutions both under Apache 2.0, fully extensible and customizable* ## ML solutions in MediaPipe -Face Detection | Face Mesh | Iris | Hands | Pose | Hair Segmentation -:----------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :---------------: -[![face_detection](docs/images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](docs/images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](docs/images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](docs/images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](docs/images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](docs/images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation) +Face Detection | Face Mesh | Iris | Hands | Pose | Holistic +:----------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :------: +[![face_detection](docs/images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](docs/images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](docs/images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | 
[![hand](docs/images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](docs/images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](docs/images/mobile/holistic_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/holistic) -Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT -:----------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :---: -[![object_detection](docs/images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](docs/images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](docs/images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](docs/images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](docs/images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift) +Hair Segmentation | Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT +:-------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :---: +[![hair_segmentation](docs/images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation) | [![object_detection](docs/images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](docs/images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](docs/images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](docs/images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](docs/images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift) -[]() | Android | iOS | Desktop | Python | Web | Coral -:---------------------------------------------------------------------------------------- | :-----: | :-: | :-----: | :----: | :-: | :---: -[Face 
Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | ✅ | ✅ -[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | | | -[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | ✅ | -[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | | ✅ | -[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | ✅ | -[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ -[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | -[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | | -[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | -[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | -[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | -[YouTube 8M](https://google.github.io/mediapipe/solutions/youtube_8m) | | | ✅ | | | +[]() | [Android](https://google.github.io/mediapipe/getting_started/android) | [iOS](https://google.github.io/mediapipe/getting_started/ios) | [C++](https://google.github.io/mediapipe/getting_started/cpp) | [Python](https://google.github.io/mediapipe/getting_started/python) | [JS](https://google.github.io/mediapipe/getting_started/javascript) | [Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/README.md) +:---------------------------------------------------------------------------------------- | :-------------------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------------: | :-----------------------------------------------------------: | :--------------------------------------------------------------------: +[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ +[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | | +[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Holistic](https://google.github.io/mediapipe/solutions/holistic) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | | +[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ +[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | +[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | +[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | +[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | +[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | +[YouTube 8M](https://google.github.io/mediapipe/solutions/youtube_8m) | | | ✅ | | | + +See also +[MediaPipe Models and Model 
Cards](https://google.github.io/mediapipe/solutions/models) +for ML models released in MediaPipe. + +## MediaPipe in Python + +MediaPipe offers customizable Python solutions as a prebuilt Python package on +[PyPI](https://pypi.org/project/mediapipe/), which can be installed simply with +`pip install mediapipe`. It also provides tools for users to build their own +solutions. Please see +[MediaPipe in Python](https://google.github.io/mediapipe/getting_started/python) +for more info. ## MediaPipe on the Web @@ -89,7 +102,13 @@ run code search using ## Publications -* [Instant Motion Tracking With MediaPipe](https://mediapipe.page.link/instant-motion-tracking-blog) +* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) + in Google AI Blog +* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) + in Google AI Blog +* [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html) + in Google Developers Blog +* [Instant Motion Tracking With MediaPipe](https://developers.googleblog.com/2020/08/instant-motion-tracking-with-mediapipe.html) in Google Developers Blog * [BlazePose - On-device Real-time Body Pose Tracking](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) in Google AI Blog diff --git a/WORKSPACE b/WORKSPACE index eb3efd275..32b466e6c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -10,14 +10,14 @@ http_archive( sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0", ) load("@bazel_skylib//lib:versions.bzl", "versions") -versions.check(minimum_bazel_version = "2.0.0") +versions.check(minimum_bazel_version = "3.4.0") -# ABSL cpp library lts_2020_02_25 +# ABSL cpp library lts_2020_09_23 http_archive( name = "com_google_absl", urls = [ - "https://github.com/abseil/abseil-cpp/archive/20200225.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/20200923.tar.gz", ], # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. 
patches = [ @@ -26,8 +26,8 @@ http_archive( patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20200225", - sha256 = "728a813291bdec2aa46eab8356ace9f75ac2ed9dfe2df5ab603c4e6c09f1c353" + strip_prefix = "abseil-cpp-20200923", + sha256 = "b3744a4f7a249d5eaf2309daad597631ce77ea62e0fc6abffbab4b4c3dc0fc08" ) http_archive( @@ -38,8 +38,8 @@ http_archive( http_archive( name = "rules_foreign_cc", - strip_prefix = "rules_foreign_cc-master", - url = "https://github.com/bazelbuild/rules_foreign_cc/archive/master.zip", + strip_prefix = "rules_foreign_cc-main", + url = "https://github.com/bazelbuild/rules_foreign_cc/archive/main.zip", ) load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies") @@ -99,7 +99,7 @@ http_archive( "https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip", ], patches = [ - "@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff" + "@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff", ], patch_args = [ "-p1", @@ -170,15 +170,15 @@ http_archive( http_archive( name = "ceres_solver", - url = "https://github.com/ceres-solver/ceres-solver/archive/1.14.0.zip", + url = "https://github.com/ceres-solver/ceres-solver/archive/2.0.0.zip", patches = [ "@//third_party:ceres_solver_compatibility_fixes.diff" ], patch_args = [ "-p1", ], - strip_prefix = "ceres-solver-1.14.0", - sha256 = "5ba6d0db4e784621fda44a50c58bb23b0892684692f0c623e2063f9c19f192f1" + strip_prefix = "ceres-solver-2.0.0", + sha256 = "db12d37b4cebb26353ae5b7746c7985e00877baa8e7b12dc4d3a1512252fff3b" ) http_archive( @@ -324,8 +324,9 @@ maven_install( "androidx.lifecycle:lifecycle-common:2.2.0", "androidx.annotation:annotation:aar:1.1.0", "androidx.appcompat:appcompat:aar:1.1.0-rc01", - "androidx.camera:camera-core:aar:1.0.0-alpha06", - "androidx.camera:camera-camera2:aar:1.0.0-alpha06", + "androidx.camera:camera-core:1.0.0-beta10", + "androidx.camera:camera-camera2:1.0.0-beta10", + "androidx.camera:camera-lifecycle:1.0.0-beta10", "androidx.constraintlayout:constraintlayout:aar:1.1.3", "androidx.core:core:aar:1.1.0-rc03", "androidx.legacy:legacy-support-v4:aar:1.0.0", @@ -337,6 +338,7 @@ maven_install( "com.google.flogger:flogger-system-backend:0.3.1", "com.google.flogger:flogger:0.3.1", "com.google.guava:guava:27.0.1-android", + "com.google.guava:listenablefuture:1.0", "junit:junit:4.12", "org.hamcrest:hamcrest-library:1.3", ], @@ -362,9 +364,9 @@ http_archive( ) #Tensorflow repo should always go after the other external dependencies. 
-# 2020-08-30 -_TENSORFLOW_GIT_COMMIT = "57b009e31e59bd1a7ae85ef8c0232ed86c9b71db" -_TENSORFLOW_SHA256= "de7f5f06204e057383028c7e53f3b352cdf85b3a40981b1a770c9a415a792c0e" +# 2020-12-09 +_TENSORFLOW_GIT_COMMIT = "0eadbb13cef1226b1bae17c941f7870734d97f8a" +_TENSORFLOW_SHA256= "4ae06daa5b09c62f31b7bc1f781fd59053f286dd64355830d8c2ac601b795ef0" http_archive( name = "org_tensorflow", urls = [ @@ -372,6 +374,7 @@ http_archive( ], patches = [ "@//third_party:org_tensorflow_compatibility_fixes.diff", + "@//third_party:org_tensorflow_objc_cxx17.diff", ], patch_args = [ "-p1", diff --git a/build_android_examples.sh b/build_android_examples.sh index 58d6c681e..75ec54199 100644 --- a/build_android_examples.sh +++ b/build_android_examples.sh @@ -89,7 +89,6 @@ for app in ${apps}; do fi target="${app}:${target_name}" bin="${bin_dir}/${app}/${target_name}.apk" - apk="${out_dir}/${target_name}.apk" echo "=== Target: ${target}" @@ -99,32 +98,36 @@ for app in ${apps}; do if [[ $strip == true ]]; then bazel_flags+=(--linkopt=-s) fi - - if [[ ${app_name} == "templatematchingcpu" ]]; then - switch_to_opencv_4 - fi - bazel "${bazel_flags[@]}" - cp -f "${bin}" "${apk}" - if [[ ${app_name} == "templatematchingcpu" ]]; then - switch_to_opencv_3 - fi fi if [[ ${app_name} == "objectdetection3d" ]]; then - orig_apk=${apk} - apk="${out_dir}/${target_name}_shoes.apk" - cp -f "${orig_apk}" "${apk}" - apks+=(${apk}) - - apk="${out_dir}/${target_name}_chairs.apk" + categories=("shoe" "chair" "cup" "camera" "shoe_1stage" "chair_1stage") + for category in "${categories[@]}"; do + apk="${out_dir}/${target_name}_${category}.apk" + if [[ $install_only == false ]]; then + bazel_flags_extended=("${bazel_flags[@]}") + if [[ ${category} != "shoe" ]]; then + bazel_flags_extended+=(--define ${category}=true) + fi + bazel "${bazel_flags_extended[@]}" + cp -f "${bin}" "${apk}" + fi + apks+=(${apk}) + done + else + apk="${out_dir}/${target_name}.apk" if [[ $install_only == false ]]; then - bazel_flags+=(--define chair=true) + if [[ ${app_name} == "templatematchingcpu" ]]; then + switch_to_opencv_4 + fi bazel "${bazel_flags[@]}" cp -f "${bin}" "${apk}" + if [[ ${app_name} == "templatematchingcpu" ]]; then + switch_to_opencv_3 + fi fi + apks+=(${apk}) fi - - apks+=(${apk}) fi done diff --git a/build_desktop_examples.sh b/build_desktop_examples.sh new file mode 100644 index 000000000..5e493e79c --- /dev/null +++ b/build_desktop_examples.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +# +# Script to build/run all MediaPipe desktop example apps (with webcam input). +# +# To build and run all apps and store them in out_dir: +# $ ./build_ios_examples.sh -d out_dir +# Omitting -d and the associated directory saves all generated apps in the +# current directory. 
+# To build all apps and store them in out_dir: +# $ ./build_ios_examples.sh -d out_dir -b +# Omitting -d and the associated directory saves all generated apps in the +# current directory. +# To run all apps already stored in out_dir: +# $ ./build_ios_examples.sh -d out_dir -r +# Omitting -d and the associated directory assumes all apps are in the current +# directory. + +set -e + +out_dir="." +build_only=false +run_only=false +app_dir="mediapipe/examples/desktop" +bin_dir="bazel-bin" +declare -a default_bazel_flags=(build -c opt --define MEDIAPIPE_DISABLE_GPU=1) + +while [[ -n $1 ]]; do + case $1 in + -d) + shift + out_dir=$1 + ;; + -b) + build_only=true + ;; + -r) + run_only=true + ;; + *) + echo "Unsupported input argument $1." + exit 1 + ;; + esac + shift +done + +echo "app_dir: $app_dir" +echo "out_dir: $out_dir" + +declare -a bazel_flags + +apps="${app_dir}/*" +for app in ${apps}; do + if [[ -d "${app}" ]]; then + target_name=${app##*/} + if [[ "${target_name}" == "autoflip" || + "${target_name}" == "hello_world" || + "${target_name}" == "media_sequence" || + "${target_name}" == "object_detection_3d" || + "${target_name}" == "template_matching" || + "${target_name}" == "youtube8m" ]]; then + continue + fi + target="${app}:${target_name}_cpu" + + echo "=== Target: ${target}" + + if [[ $run_only == false ]]; then + bazel_flags=("${default_bazel_flags[@]}") + bazel_flags+=(${target}) + + bazel "${bazel_flags[@]}" + cp -f "${bin_dir}/${app}/"*"_cpu" "${out_dir}" + fi + if [[ $build_only == false ]]; then + if [[ ${target_name} == "object_tracking" ]]; then + graph_name="tracking/object_detection_tracking" + elif [[ ${target_name} == "upper_body_pose_tracking" ]]; then + graph_name="pose_tracking/upper_body_pose_tracking" + else + graph_name="${target_name}/${target_name}" + fi + if [[ ${target_name} == "holistic_tracking" || + ${target_name} == "iris_tracking" || + ${target_name} == "pose_tracking" || + ${target_name} == "upper_body_pose_tracking" ]]; then + graph_suffix="cpu" + else + graph_suffix="desktop_live" + fi + GLOG_logtostderr=1 "${out_dir}/${target_name}_cpu" \ + --calculator_graph_config_file=mediapipe/graphs/"${graph_name}_${graph_suffix}.pbtxt" + fi + fi +done diff --git a/docs/framework_concepts/calculators.md b/docs/framework_concepts/calculators.md index 0ee3473e6..3e1236aaa 100644 --- a/docs/framework_concepts/calculators.md +++ b/docs/framework_concepts/calculators.md @@ -67,26 +67,26 @@ class CalculatorBase { // The subclasses of CalculatorBase must implement GetContract. // ... - static ::MediaPipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // Open is called before any Process() calls, on a freshly constructed // calculator. Subclasses may override this method to perform necessary // setup, and possibly output Packets and/or set output streams' headers. // ... - virtual ::MediaPipe::Status Open(CalculatorContext* cc) { - return ::MediaPipe::OkStatus(); + virtual absl::Status Open(CalculatorContext* cc) { + return absl::OkStatus(); } // Processes the incoming inputs. May call the methods on cc to access // inputs and produce outputs. // ... - virtual ::MediaPipe::Status Process(CalculatorContext* cc) = 0; + virtual absl::Status Process(CalculatorContext* cc) = 0; // Is called if Open() was called and succeeded. Is called either // immediately after processing is complete or after a graph run has ended // (if an error occurred in the graph). ... 
- virtual ::MediaPipe::Status Close(CalculatorContext* cc) { - return ::MediaPipe::OkStatus(); + virtual absl::Status Close(CalculatorContext* cc) { + return absl::OkStatus(); } ... @@ -199,7 +199,7 @@ name and index number. In the function below input are output are identified: // c++ Code snippet describing the SomeAudioVideoCalculator GetContract() method class SomeAudioVideoCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); // SetAny() is used to specify that whatever the type of the // stream is, it's acceptable. This does not mean that any @@ -209,13 +209,13 @@ class SomeAudioVideoCalculator : public CalculatorBase { cc->Outputs().Tag("VIDEO").Set(); cc->Outputs().Get("AUDIO", 0).Set(); cc->Outputs().Get("AUDIO", 1).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } ``` ## Processing -`Process()` called on a non-source node must return `::mediapipe::OkStatus()` to +`Process()` called on a non-source node must return `absl::OkStatus()` to indicate that all went well, or any other status code to signal an error If a non-source calculator returns `tool::StatusStop()`, then this signals the @@ -224,12 +224,12 @@ input streams will be closed (and remaining Packets will propagate through the graph). A source node in a graph will continue to have `Process()` called on it as long -as it returns `::mediapipe::OkStatus(`). To indicate that there is no more data -to be generated return `tool::StatusStop()`. Any other status indicates an error -has occurred. +as it returns `absl::OkStatus(`). To indicate that there is no more data to be +generated return `tool::StatusStop()`. Any other status indicates an error has +occurred. -`Close()` returns `::mediapipe::OkStatus()` to indicate success. Any other -status indicates a failure. +`Close()` returns `absl::OkStatus()` to indicate success. Any other status +indicates a failure. Here is the basic `Process()` function. It uses the `Input()` method (which can be used only if the calculator has a single input) to request its input data. It @@ -238,13 +238,13 @@ and does the calculations. When done it releases the pointer when adding it to the output stream. ```c++ -::util::Status MyCalculator::Process() { +absl::Status MyCalculator::Process() { const Matrix& input = Input()->Get(); std::unique_ptr output(new Matrix(input.rows(), input.cols())); // do your magic here.... // output->row(n) = ... 
Output()->Add(output.release(), InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } ``` @@ -312,7 +312,7 @@ namespace mediapipe { // class PacketClonerCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const int tick_signal_index = cc->Inputs().NumEntries() - 1; // cc->Inputs().NumEntries() returns the number of input streams // for the PacketClonerCalculator @@ -322,10 +322,10 @@ class PacketClonerCalculator : public CalculatorBase { cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); } cc->Inputs().Index(tick_signal_index).SetAny(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { tick_signal_index_ = cc->Inputs().NumEntries() - 1; current_.resize(tick_signal_index_); // Pass along the header for each stream if present. @@ -336,10 +336,10 @@ class PacketClonerCalculator : public CalculatorBase { // the header for the input stream of index i } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Store input signals. for (int i = 0; i < tick_signal_index_; ++i) { if (!cc->Inputs().Index(i).Value().IsEmpty()) { @@ -364,7 +364,7 @@ class PacketClonerCalculator : public CalculatorBase { } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/docs/framework_concepts/gpu.md b/docs/framework_concepts/gpu.md index 77d566e8d..8f9df6067 100644 --- a/docs/framework_concepts/gpu.md +++ b/docs/framework_concepts/gpu.md @@ -66,10 +66,10 @@ calculator derived from base class GlSimpleCalculator. The GPU calculator // See GlSimpleCalculator for inputs, outputs and input side packets. class LuminanceCalculator : public GlSimpleCalculator { public: - ::mediapipe::Status GlSetup() override; - ::mediapipe::Status GlRender(const GlTexture& src, - const GlTexture& dst) override; - ::mediapipe::Status GlTeardown() override; + absl::Status GlSetup() override; + absl::Status GlRender(const GlTexture& src, + const GlTexture& dst) override; + absl::Status GlTeardown() override; private: GLuint program_ = 0; @@ -77,8 +77,8 @@ class LuminanceCalculator : public GlSimpleCalculator { }; REGISTER_CALCULATOR(LuminanceCalculator); -::mediapipe::Status LuminanceCalculator::GlRender(const GlTexture& src, - const GlTexture& dst) { +absl::Status LuminanceCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -128,7 +128,7 @@ REGISTER_CALCULATOR(LuminanceCalculator); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } ``` diff --git a/docs/framework_concepts/graphs.md b/docs/framework_concepts/graphs.md index 83f95e5bb..2bbdd6856 100644 --- a/docs/framework_concepts/graphs.md +++ b/docs/framework_concepts/graphs.md @@ -219,23 +219,23 @@ packet timestamps 0, 1, 2, 3, ... 
```c++ class UnitDelayCalculator : public Calculator { public: -  static ::util::Status FillExpectations( +  static absl::Status FillExpectations(      const CalculatorOptions& extendable_options, PacketTypeSet* inputs,      PacketTypeSet* outputs, PacketTypeSet* input_side_packets) {    inputs->Index(0)->Set("An integer.");    outputs->Index(0)->Set("The input delayed by one time unit."); -    return ::mediapipe::OkStatus(); +    return absl::OkStatus();  } -  ::util::Status Open() final { +  absl::Status Open() final {    Output()->Add(new int(0), Timestamp(0)); -    return ::mediapipe::OkStatus(); +    return absl::OkStatus();  } -  ::util::Status Process() final { +  absl::Status Process() final {    const Packet& packet = Input()->Value();    Output()->AddPacket(packet.At(packet.Timestamp().NextAllowedInStream())); -    return ::mediapipe::OkStatus(); +    return absl::OkStatus();  } }; ``` diff --git a/docs/getting_started/android.md b/docs/getting_started/android.md new file mode 100644 index 000000000..855f5fa29 --- /dev/null +++ b/docs/getting_started/android.md @@ -0,0 +1,191 @@ +--- +layout: default +title: MediaPipe on Android +parent: Getting Started +has_children: true +has_toc: false +nav_order: 1 +--- + +# MediaPipe on Android +{: .no_toc } + +1. TOC +{:toc} +--- + +Please follow instructions below to build Android example apps in the supported +MediaPipe [solutions](../solutions/solutions.md). To learn more about these +example apps, start from [Hello World! on Android](./hello_world_android.md). To +incorporate MediaPipe into an existing Android Studio project, see these +[instructions](./android_archive_library.md) that use Android Archive (AAR) and +Gradle. + +## Building Android example apps + +### Prerequisite + +* Install MediaPipe following these [instructions](./install.md). +* Setup Java Runtime. +* Setup Android SDK release 28.0.3 and above. +* Setup Android NDK r18b and above. + +MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see +below for Android Studio setup). However, if you prefer using MediaPipe without +Android Studio, please run +[`setup_android_sdk_and_ndk.sh`](https://github.com/google/mediapipe/blob/master/setup_android_sdk_and_ndk.sh) +to download and setup Android SDK and NDK before building any Android example +apps. + +If Android SDK and NDK are already installed (e.g., by Android Studio), set +$ANDROID_HOME and $ANDROID_NDK_HOME to point to the installed SDK and NDK. + +```bash +export ANDROID_HOME= +export ANDROID_NDK_HOME= +``` + +In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch +to a lower Android API level. You can achieve this by specifying `api_level = +$YOUR_INTENDED_API_LEVEL` in android_ndk_repository() and/or +android_sdk_repository() in the +[`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) file. + +Please verify all the necessary packages are installed. + +* Android SDK Platform API Level 28 or 29 +* Android SDK Build-Tools 28 or 29 +* Android SDK Platform-Tools 28 or 29 +* Android SDK Tools 26.1.1 +* Android NDK 17c or above + +### Option 1: Build with Bazel in Command Line + +Tip: You can run this +[script](https://github.com/google/mediapipe/blob/master/build_android_examples.sh) +to build (and install) all MediaPipe Android example apps. + +1. To build an Android example app, build against the corresponding + `android_binary` build target. 
For instance, for + [MediaPipe Hands](../solutions/hands.md) the target is `handtrackinggpu` in + the + [BUILD](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/BUILD) + file: + + Note: To reduce the binary size, consider appending `--linkopt="-s"` to the + command below to strip symbols. + + ```bash + bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu:handtrackinggpu + ``` + +2. Install it on a device with: + + ```bash + adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/handtrackinggpu.apk + ``` + +### Option 2: Build with Bazel in Android Studio + +The MediaPipe project can be imported into Android Studio using the Bazel +plugins. This allows the MediaPipe examples to be built and modified in Android +Studio. + +To incorporate MediaPipe into an existing Android Studio project, see these +[instructions](./android_archive_library.md) that use Android Archive (AAR) and +Gradle. + +The steps below use Android Studio 3.5 to build and install a MediaPipe example +app: + +1. Install and launch Android Studio 3.5. + +2. Select `Configure` -> `SDK Manager` -> `SDK Platforms`. + + * Verify that Android SDK Platform API Level 28 or 29 is installed. + * Take note of the Android SDK Location, e.g., + `/usr/local/home/Android/Sdk`. + +3. Select `Configure` -> `SDK Manager` -> `SDK Tools`. + + * Verify that Android SDK Build-Tools 28 or 29 is installed. + * Verify that Android SDK Platform-Tools 28 or 29 is installed. + * Verify that Android SDK Tools 26.1.1 is installed. + * Verify that Android NDK 17c or above is installed. + * Take note of the Android NDK Location, e.g., + `/usr/local/home/Android/Sdk/ndk-bundle` or + `/usr/local/home/Android/Sdk/ndk/20.0.5594570`. + +4. Set environment variables `$ANDROID_HOME` and `$ANDROID_NDK_HOME` to point + to the installed SDK and NDK. + + ```bash + export ANDROID_HOME=/usr/local/home/Android/Sdk + + # If the NDK libraries are installed by a previous version of Android Studio, do + export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk-bundle + # If the NDK libraries are installed by Android Studio 3.5, do + export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk/ + ``` + +5. Select `Configure` -> `Plugins` to install `Bazel`. + +6. On Linux, select `File` -> `Settings` -> `Bazel settings`. On macos, select + `Android Studio` -> `Preferences` -> `Bazel settings`. Then, modify `Bazel + binary location` to be the same as the output of `$ which bazel`. + +7. Select `Import Bazel Project`. + + * Select `Workspace`: `/path/to/mediapipe` and select `Next`. + * Select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` and select + `Next`. + * Modify `Project View` to be the following and select `Finish`. + + ``` + directories: + # read project settings, e.g., .bazelrc + . + -mediapipe/objc + -mediapipe/examples/ios + + targets: + //mediapipe/examples/android/...:all + //mediapipe/java/...:all + + android_sdk_platform: android-29 + + sync_flags: + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain + ``` + +8. Select `Bazel` -> `Sync` -> `Sync project with Build files`. 
+ + Note: Even after doing step 4, if you still see the error: `"no such package + '@androidsdk//': Either the path attribute of android_sdk_repository or the + ANDROID_HOME environment variable must be set."`, please modify the + [`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) + file to point to your SDK and NDK library locations, as below: + + ``` + android_sdk_repository( + name = "androidsdk", + path = "/path/to/android/sdk" + ) + + android_ndk_repository( + name = "androidndk", + path = "/path/to/android/ndk" + ) + ``` + +9. Connect an Android device to the workstation. + +10. Select `Run...` -> `Edit Configurations...`. + + * Select `Templates` -> `Bazel Command`. + * Enter Target Expression: + `//mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu:handtrackinggpu` + * Enter Bazel command: `mobile-install`. + * Enter Bazel flags: `-c opt --config=android_arm64`. + * Press the `[+]` button to add the new configuration. + * Select `Run` to run the example app on the connected Android device. diff --git a/docs/getting_started/android_archive_library.md b/docs/getting_started/android_archive_library.md index feeabfd39..735bd7a39 100644 --- a/docs/getting_started/android_archive_library.md +++ b/docs/getting_started/android_archive_library.md @@ -1,8 +1,9 @@ --- layout: default title: MediaPipe Android Archive -parent: Getting Started -nav_order: 7 +parent: MediaPipe on Android +grand_parent: Getting Started +nav_order: 2 --- # MediaPipe Android Archive @@ -44,7 +45,8 @@ each project. 2. Run the Bazel build command to generate the AAR. ```bash - bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --fat_apk_cpu=arm64-v8a,armeabi-v7a \ + bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --fat_apk_cpu=arm64-v8a,armeabi-v7a --strip=ALWAYS \ //path/to/the/aar/build/file:aar_name ``` @@ -85,16 +87,14 @@ each project. Build the MediaPipe binary graph and copy the assets into app/src/main/assets, e.g., for the face detection graph, you need to build and copy - [the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41), - [the tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite), + [the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41) and - [the label map](https://github.com/google/mediapipe/blob/master/mediapipe/models/face_detection_front_labelmap.txt). + [the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite). ```bash bazel build -c opt mediapipe/mediapipe/graphs/face_detection:mobile_gpu_binary_graph cp bazel-bin/mediapipe/graphs/face_detection/mobile_gpu.binarypb /path/to/your/app/src/main/assets/ - cp mediapipe/models/face_detection_front.tflite /path/to/your/app/src/main/assets/ - cp mediapipe/models/face_detection_front_labelmap.txt /path/to/your/app/src/main/assets/ + cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/ ``` ![Screenshot](../images/mobile/assets_location.png) @@ -132,9 +132,10 @@ each project. 
implementation 'com.google.guava:guava:27.0.1-android' implementation 'com.google.protobuf:protobuf-java:3.11.4' // CameraX core library - def camerax_version = "1.0.0-alpha06" + def camerax_version = "1.0.0-beta10" implementation "androidx.camera:camera-core:$camerax_version" implementation "androidx.camera:camera-camera2:$camerax_version" + implementation "androidx.camera:camera-lifecycle:$camerax_version" } ``` diff --git a/docs/getting_started/building_examples.md b/docs/getting_started/building_examples.md index 842f1b155..2244b2736 100644 --- a/docs/getting_started/building_examples.md +++ b/docs/getting_started/building_examples.md @@ -2,7 +2,7 @@ layout: default title: Building MediaPipe Examples parent: Getting Started -nav_order: 2 +nav_exclude: true --- # Building MediaPipe Examples @@ -12,496 +12,22 @@ nav_order: 2 {:toc} --- -## Android +### Android -### Prerequisite +Please see these [instructions](./android.md). -* Java Runtime. -* Android SDK release 28.0.3 and above. -* Android NDK r18b and above. +### iOS -MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see -below for Android Studio setup). However, if you prefer using MediaPipe without -Android Studio, please run -[`setup_android_sdk_and_ndk.sh`](https://github.com/google/mediapipe/blob/master/setup_android_sdk_and_ndk.sh) -to download and setup Android SDK and NDK before building any Android example -apps. +Please see these [instructions](./ios.md). -If Android SDK and NDK are already installed (e.g., by Android Studio), set -$ANDROID_HOME and $ANDROID_NDK_HOME to point to the installed SDK and NDK. +### Python -```bash -export ANDROID_HOME= -export ANDROID_NDK_HOME= -``` +Please see these [instructions](./python.md). -In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch -to a lower Android API level. You can achieve this by specifying `api_level = -$YOUR_INTENDED_API_LEVEL` in android_ndk_repository() and/or -android_sdk_repository() in the -[`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) file. +### JavaScript -Please verify all the necessary packages are installed. +Please see these [instructions](./javascript.md). -* Android SDK Platform API Level 28 or 29 -* Android SDK Build-Tools 28 or 29 -* Android SDK Platform-Tools 28 or 29 -* Android SDK Tools 26.1.1 -* Android NDK 17c or above +### C++ -### Option 1: Build with Bazel in Command Line - -Tip: You can run this -[script](https://github.com/google/mediapipe/blob/master/build_android_examples.sh) -to build (and install) all MediaPipe Android example apps. - -1. To build an Android example app, build against the corresponding - `android_binary` build target. For instance, for - [MediaPipe Hands](../solutions/hands.md) the target is `handtrackinggpu` in - the - [BUILD](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/BUILD) - file: - - Note: To reduce the binary size, consider appending `--linkopt="-s"` to the - command below to strip symbols. - - ```bash - bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu:handtrackinggpu - ``` - -2. Install it on a device with: - - ```bash - adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/handtrackinggpu.apk - ``` - -### Option 2: Build with Bazel in Android Studio - -The MediaPipe project can be imported into Android Studio using the Bazel -plugins. 
This allows the MediaPipe examples to be built and modified in Android -Studio. - -To incorporate MediaPipe into an existing Android Studio project, see these -[instructions](./android_archive_library.md) that use Android Archive (AAR) and -Gradle. - -The steps below use Android Studio 3.5 to build and install a MediaPipe example -app: - -1. Install and launch Android Studio 3.5. - -2. Select `Configure` -> `SDK Manager` -> `SDK Platforms`. - - * Verify that Android SDK Platform API Level 28 or 29 is installed. - * Take note of the Android SDK Location, e.g., - `/usr/local/home/Android/Sdk`. - -3. Select `Configure` -> `SDK Manager` -> `SDK Tools`. - - * Verify that Android SDK Build-Tools 28 or 29 is installed. - * Verify that Android SDK Platform-Tools 28 or 29 is installed. - * Verify that Android SDK Tools 26.1.1 is installed. - * Verify that Android NDK 17c or above is installed. - * Take note of the Android NDK Location, e.g., - `/usr/local/home/Android/Sdk/ndk-bundle` or - `/usr/local/home/Android/Sdk/ndk/20.0.5594570`. - -4. Set environment variables `$ANDROID_HOME` and `$ANDROID_NDK_HOME` to point - to the installed SDK and NDK. - - ```bash - export ANDROID_HOME=/usr/local/home/Android/Sdk - - # If the NDK libraries are installed by a previous version of Android Studio, do - export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk-bundle - # If the NDK libraries are installed by Android Studio 3.5, do - export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk/ - ``` - -5. Select `Configure` -> `Plugins` to install `Bazel`. - -6. On Linux, select `File` -> `Settings` -> `Bazel settings`. On macos, select - `Android Studio` -> `Preferences` -> `Bazel settings`. Then, modify `Bazel - binary location` to be the same as the output of `$ which bazel`. - -7. Select `Import Bazel Project`. - - * Select `Workspace`: `/path/to/mediapipe` and select `Next`. - * Select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` and select - `Next`. - * Modify `Project View` to be the following and select `Finish`. - - ``` - directories: - # read project settings, e.g., .bazelrc - . - -mediapipe/objc - -mediapipe/examples/ios - - targets: - //mediapipe/examples/android/...:all - //mediapipe/java/...:all - - android_sdk_platform: android-29 - - sync_flags: - --host_crosstool_top=@bazel_tools//tools/cpp:toolchain - ``` - -8. Select `Bazel` -> `Sync` -> `Sync project with Build files`. - - Note: Even after doing step 4, if you still see the error: `"no such package - '@androidsdk//': Either the path attribute of android_sdk_repository or the - ANDROID_HOME environment variable must be set."`, please modify the - [`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) - file to point to your SDK and NDK library locations, as below: - - ``` - android_sdk_repository( - name = "androidsdk", - path = "/path/to/android/sdk" - ) - - android_ndk_repository( - name = "androidndk", - path = "/path/to/android/ndk" - ) - ``` - -9. Connect an Android device to the workstation. - -10. Select `Run...` -> `Edit Configurations...`. - - * Select `Templates` -> `Bazel Command`. - * Enter Target Expression: - `//mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu:handtrackinggpu` - * Enter Bazel command: `mobile-install`. - * Enter Bazel flags: `-c opt --config=android_arm64`. - * Press the `[+]` button to add the new configuration. - * Select `Run` to run the example app on the connected Android device. - -## iOS - -### Prerequisite - -1. 
Install [Xcode](https://developer.apple.com/xcode/), then install the - Command Line Tools using: - - ```bash - xcode-select --install - ``` - -2. Install [Bazel](https://bazel.build/). - - We recommend using [Homebrew](https://brew.sh/) to get the latest version. - -3. Set Python 3.7 as the default Python version and install the Python "six" - library. This is needed for TensorFlow. - - ```bash - pip3 install --user six - ``` - -4. Clone the MediaPipe repository. - - ```bash - git clone https://github.com/google/mediapipe.git - ``` - -### Set up a bundle ID prefix - -All iOS apps must have a bundle ID, and you must have a provisioning profile -that lets you install an app with that ID onto your phone. To avoid clashes -between different MediaPipe users, you need to configure a unique prefix for the -bundle IDs of our iOS demo apps. - -If you have a custom provisioning profile, see -[Custom provisioning](#custom-provisioning) below. - -Otherwise, run this command to generate a unique prefix: - -```bash -python3 mediapipe/examples/ios/link_local_profiles.py -``` - -### Create an Xcode project - -This allows you to edit and debug one of the example apps in Xcode. It also -allows you to make use of automatic provisioning (see later section). - -1. We will use a tool called [Tulsi](https://tulsi.bazel.build/) for generating - Xcode projects from Bazel build configurations. - - ```bash - # cd out of the mediapipe directory, then: - git clone https://github.com/bazelbuild/tulsi.git - cd tulsi - # remove Xcode version from Tulsi's .bazelrc (see http://github.com/bazelbuild/tulsi#building-and-installing): - sed -i .orig '/xcode_version/d' .bazelrc - # build and run Tulsi: - sh build_and_run.sh - ``` - - This will install `Tulsi.app` inside the `Applications` directory in your - home directory. - -2. Open `mediapipe/Mediapipe.tulsiproj` using the Tulsi app. - - Tip: If Tulsi displays an error saying "Bazel could not be found", press the - "Bazel..." button in the Packages tab and select the `bazel` executable in - your homebrew `/bin/` directory. - -3. Select the MediaPipe config in the Configs tab, then press the Generate - button below. You will be asked for a location to save the Xcode project. - Once the project is generated, it will be opened in Xcode. - - If you get an error about bundle IDs, see the - [previous section](#set-up-a-bundle-id-prefix). - -### Set up provisioning - -To install applications on an iOS device, you need a provisioning profile. There -are two options: - -1. Automatic provisioning. This allows you to build and install an app to your - personal device. The provisining profile is managed by Xcode, and has to be - updated often (it is valid for about a week). - -2. Custom provisioning. This uses a provisioning profile associated with an - Apple developer account. These profiles have a longer validity period and - can target multiple devices, but you need a paid developer account with - Apple to obtain one. - -#### Automatic provisioning - -1. Create an Xcode project for MediaPipe, as discussed - [earlier](#create-an-xcode-project). - -2. In the project navigator in the left sidebar, select the "Mediapipe" - project. - -3. Select one of the application targets, e.g. HandTrackingGpuApp. - -4. Select the "Signing & Capabilities" tab. - -5. Check "Automatically manage signing", and confirm the dialog box. - -6. Select "_Your Name_ (Personal Team)" in the Team pop-up menu. - -7. This set-up needs to be done once for each application you want to install. 
- Repeat steps 3-6 as needed. - -This generates provisioning profiles for each app you have selected. Now we need -to tell Bazel to use them. We have provided a script to make this easier. - -1. In the terminal, to the `mediapipe` directory where you cloned the - repository. - -2. Run this command: - - ```bash - python3 mediapipe/examples/ios/link_local_profiles.py - ``` - -This will find and link the provisioning profile for all applications for which -you have enabled automatic provisioning in Xcode. - -Note: once a profile expires, Xcode will generate a new one; you must then run -this script again to link the updated profiles. - -#### Custom provisioning - -1. Obtain a provisioning profile from Apple. - -Tip: You can use this command to see the provisioning profiles you have -previously downloaded using Xcode: `open ~/Library/MobileDevice/"Provisioning -Profiles"`. If there are none, generate and download a profile on -[Apple's developer site](https://developer.apple.com/account/resources/). - -1. Symlink or copy your provisioning profile to - `mediapipe/mediapipe/provisioning_profile.mobileprovision`. - - ```bash - cd mediapipe - ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision - ``` - -Note: if you had previously set up automatic provisioning, you should remove the -`provisioning_profile.mobileprovision` symlink in each example's directory, -since it will take precedence over the common one. You can also overwrite it -with you own profile if you need a different profile for different apps. - -1. Open `mediapipe/examples/ios/bundle_id.bzl`, and change the - `BUNDLE_ID_PREFIX` to a prefix associated with your provisioning profile. - -### Build and run an app using Xcode - -1. Create the Xcode project, and make sure you have set up either automatic or - custom provisioning. - -2. You can now select any of the MediaPipe demos in the target menu, and build - and run them as normal. - -Note: When you ask Xcode to run an app, by default it will use the Debug -configuration. Some of our demos are computationally heavy; you may want to use -the Release configuration for better performance. - -Tip: To switch build configuration in Xcode, click on the target menu, choose -"Edit Scheme...", select the Run action, and switch the Build Configuration from -Debug to Release. Note that this is set independently for each target. - -Tip: On the device, in Settings > General > Device Management, make sure the -developer (yourself) is trusted. - -### Build an app using the command line - -1. Make sure you have set up either automatic or custom provisioning. - -2. Using [MediaPipe Hands](../solutions/hands.md) for example, run: - - ```bash - bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp - ``` - - You may see a permission request from `codesign` in order to sign the app. - - Tip: If you are using custom provisioning, you can run this - [script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh) - to build all MediaPipe iOS example apps. - -3. In Xcode, open the `Devices and Simulators` window (command-shift-2). - -4. Make sure your device is connected. You will see a list of installed apps. - Press the "+" button under the list, and select the `.ipa` file built by - Bazel. - -5. You can now run the app on your device. - -Tip: On the device, in Settings > General > Device Management, make sure the -developer (yourself) is trusted. - -## Desktop - -### Option 1: Running on CPU - -1. 
To build, for example, [MediaPipe Hands](../solutions/hands.md), run: - - ```bash - bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu - ``` - -2. To run the application: - - ```bash - GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \ - --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt - ``` - - This will open up your webcam as long as it is connected and on. Any errors - is likely due to your webcam being not accessible. - -### Option 2: Running on GPU - -Note: This currently works only on Linux, and please first follow -[OpenGL ES Setup on Linux Desktop](./gpu_support.md#opengl-es-setup-on-linux-desktop). - -1. To build, for example, [MediaPipe Hands](../solutions/hands.md), run: - - ```bash - bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 \ - mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu - ``` - -2. To run the application: - - ```bash - GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \ - --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt - ``` - - This will open up your webcam as long as it is connected and on. Any errors - is likely due to your webcam being not accessible, or GPU drivers not setup - properly. - -## Python - -MediaPipe Python package is available on -[PyPI](https://pypi.org/project/mediapipe/), and can be installed simply by `pip -install mediapipe` on Linux and macOS, as described below in -[Run in python interpreter](#run-in-python-interpreter) and in this -[colab](https://mediapipe.page.link/mp-py-colab). - -### Run in Python interpreter - -Using [MediaPipe Pose](../solutions/pose.md) as an example: - -```bash -# Activate a Python virtual environment. -$ python3 -m venv mp_env && source mp_env/bin/activate - -# Install MediaPipe Python package -(mp_env)$ pip install mediapipe - -# Run in Python interpreter -(mp_env)$ python3 ->>> import mediapipe as mp ->>> pose_tracker = mp.examples.UpperBodyPoseTracker() - -# For image input ->>> pose_landmarks, _ = pose_tracker.run(input_file='/path/to/input/file', output_file='/path/to/output/file') ->>> pose_landmarks, annotated_image = pose_tracker.run(input_file='/path/to/file') - -# For live camera input -# (Press Esc within the output image window to stop the run or let it self terminate after 30 seconds.) ->>> pose_tracker.run_live() - -# Close the tracker. ->>> pose_tracker.close() -``` - -Tip: Use command `deactivate` to exit the Python virtual environment. - -### Building Python package from source - -Follow these steps only if you have local changes and need to build the Python -package from source. Otherwise, we strongly encourage our users to simply run -`pip install mediapipe`, more convenient and much faster. - -1. Make sure that Bazel and OpenCV are correctly installed and configured for - MediaPipe. Please see [Installation](./install.md) for how to setup Bazel - and OpenCV for MediaPipe on Linux and macOS. - -2. Install the following dependencies. - - ```bash - # Debian or Ubuntu - $ sudo apt install python3-dev - $ sudo apt install python3-venv - $ sudo apt install -y protobuf-compiler - ``` - - ```bash - # macOS - $ brew install protobuf - ``` - -3. Activate a Python virtual environment. - - ```bash - $ python3 -m venv mp_env && source mp_env/bin/activate - ``` - -4. In the virtual environment, go to the MediaPipe repo directory. - -5. 
Install the required Python packages. - - ```bash - (mp_env)mediapipe$ pip3 install -r requirements.txt - ``` - -6. Generate and install MediaPipe package. - - ```bash - (mp_env)mediapipe$ python3 setup.py gen_protos - (mp_env)mediapipe$ python3 setup.py install --link-opencv - ``` +Please see these [instructions](./cpp.md). diff --git a/docs/getting_started/cpp.md b/docs/getting_started/cpp.md new file mode 100644 index 000000000..8fc091fea --- /dev/null +++ b/docs/getting_started/cpp.md @@ -0,0 +1,62 @@ +--- +layout: default +title: MediaPipe in C++ +parent: Getting Started +has_children: true +has_toc: false +nav_order: 5 +--- + +# MediaPipe in C++ +{: .no_toc } + +1. TOC +{:toc} +--- + +Please follow instructions below to build C++ command-line example apps in the +supported MediaPipe [solutions](../solutions/solutions.md). To learn more about +these example apps, start from [Hello World! in C++](./hello_world_cpp.md). + +## Building C++ command-line example apps + +### Option 1: Running on CPU + +1. To build, for example, [MediaPipe Hands](../solutions/hands.md), run: + + ```bash + bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu + ``` + +2. To run the application: + + ```bash + GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \ + --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt + ``` + + This will open up your webcam as long as it is connected and on. Any errors + is likely due to your webcam being not accessible. + +### Option 2: Running on GPU + +Note: This currently works only on Linux, and please first follow +[OpenGL ES Setup on Linux Desktop](./gpu_support.md#opengl-es-setup-on-linux-desktop). + +1. To build, for example, [MediaPipe Hands](../solutions/hands.md), run: + + ```bash + bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 \ + mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu + ``` + +2. To run the application: + + ```bash + GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \ + --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt + ``` + + This will open up your webcam as long as it is connected and on. Any errors + is likely due to your webcam being not accessible, or GPU drivers not setup + properly. diff --git a/docs/getting_started/gpu_support.md b/docs/getting_started/gpu_support.md index 2aae63a2e..38bab9be3 100644 --- a/docs/getting_started/gpu_support.md +++ b/docs/getting_started/gpu_support.md @@ -2,7 +2,7 @@ layout: default title: GPU Support parent: Getting Started -nav_order: 6 +nav_order: 7 --- # GPU Support diff --git a/docs/getting_started/hello_world_android.md b/docs/getting_started/hello_world_android.md index e4e8286f7..9f277f799 100644 --- a/docs/getting_started/hello_world_android.md +++ b/docs/getting_started/hello_world_android.md @@ -1,8 +1,9 @@ --- layout: default title: Hello World! on Android -parent: Getting Started -nav_order: 3 +parent: MediaPipe on Android +grand_parent: Getting Started +nav_order: 1 --- # Hello World! on Android @@ -58,7 +59,7 @@ node: { output_stream: "luma_video" } -# Applies the Sobel filter to luminance images sotred in RGB format. +# Applies the Sobel filter to luminance images stored in RGB format. node: { calculator: "SobelEdgesCalculator" input_stream: "luma_video" @@ -446,8 +447,8 @@ visible so that we can start seeing frames from the `previewFrameTexture`. 
However, before starting the camera, we need to decide which camera we want to use. [`CameraXPreviewHelper`] inherits from [`CameraHelper`] which provides two options, `FRONT` and `BACK`. We can pass in the decision from the `BUILD` file -as metadata such that no code change is required to build a another version of -the app using a different camera. +as metadata such that no code change is required to build another version of the +app using a different camera. Assuming we want to use `BACK` camera to perform edge detection on a live scene that we view from the camera, add the metadata into `AndroidManifest.xml`: @@ -496,7 +497,7 @@ CameraHelper.CameraFacing cameraFacing = applicationInfo.metaData.getBoolean("cameraFacingFront", false) ? CameraHelper.CameraFacing.FRONT : CameraHelper.CameraFacing.BACK; -cameraHelper.startCamera(this, cameraFacing, /*surfaceTexture=*/ null); +cameraHelper.startCamera(this, cameraFacing, /*unusedSurfaceTexture=*/ null); ``` At this point, the application should build successfully. However, when you run diff --git a/docs/getting_started/hello_world_desktop.md b/docs/getting_started/hello_world_cpp.md similarity index 96% rename from docs/getting_started/hello_world_desktop.md rename to docs/getting_started/hello_world_cpp.md index 61e9b6471..e3d34d9b4 100644 --- a/docs/getting_started/hello_world_desktop.md +++ b/docs/getting_started/hello_world_cpp.md @@ -1,11 +1,12 @@ --- layout: default -title: Hello World! on Desktop (C++) -parent: Getting Started -nav_order: 5 +title: Hello World! in C++ +parent: MediaPipe in C++ +grand_parent: Getting Started +nav_order: 1 --- -# Hello World! on Desktop (C++) +# Hello World! in C++ {: .no_toc } 1. TOC @@ -43,7 +44,7 @@ nav_order: 5 `PrintHelloWorld()` function, defined in a [`CalculatorGraphConfig`] proto. ```C++ - ::mediapipe::Status PrintHelloWorld() { + absl::Status PrintHelloWorld() { // Configures a simple graph, which concatenates 2 PassThroughCalculators. CalculatorGraphConfig config = ParseTextProtoOrDie(R"( input_stream: "in" diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md index 19de67d01..06d79c67d 100644 --- a/docs/getting_started/hello_world_ios.md +++ b/docs/getting_started/hello_world_ios.md @@ -1,8 +1,9 @@ --- layout: default title: Hello World! on iOS -parent: Getting Started -nav_order: 4 +parent: MediaPipe on iOS +grand_parent: Getting Started +nav_order: 1 --- # Hello World! on iOS @@ -193,8 +194,7 @@ bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/helloworld:HelloWor Then, go back to XCode, open Window > Devices and Simulators, select your device, and add the `.ipa` file generated by the command above to your device. -Here is the document on [setting up and compiling](./building_examples.md#ios) -iOS MediaPipe apps. +Here is the document on [setting up and compiling](./ios.md) iOS MediaPipe apps. Open the application on your device. Since it is empty, it should display a blank white screen. @@ -492,6 +492,9 @@ in our app: if (![self.mediapipeGraph startWithError:&error]) { NSLog(@"Failed to start graph: %@", error); } + else if (![self.mediapipeGraph waitUntilIdleWithError:&error]) { + NSLog(@"Failed to complete graph initial run: %@", error); + } dispatch_async(_videoQueue, ^{ [_cameraSource start]; @@ -500,8 +503,9 @@ in our app: }]; ``` -Note: It is important to start the graph before starting the camera, so that the -graph is ready to process frames as soon as the camera starts sending them. 
+Note: It is important to start the graph before starting the camera and wait +until completion, so that the graph is ready to process frames as soon as the +camera starts sending them. Earlier, when we received frames from the camera in the `processVideoFrame` function, we displayed them in the `_liveView` using the `_renderer`. Now, we diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index b9be6e498..7a02def53 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -2,7 +2,7 @@ layout: default title: Installation parent: Getting Started -nav_order: 1 +nav_order: 6 --- # Installation @@ -12,7 +12,7 @@ nav_order: 1 {:toc} --- -Note: To interoperate with OpenCV, OpenCV 3.x and above are preferred. OpenCV +Note: To interoperate with OpenCV, OpenCV 3.x to 4.1 are preferred. OpenCV 2.x currently works but interoperability support may be deprecated in the future. @@ -23,39 +23,38 @@ Note: To make Mediapipe work with TensorFlow, please set Python 3.7 as the default Python version and install the Python "six" library by running `pip3 install --user six`. -Note: To build and run Android example apps, see these -[instructions](./building_examples.md#android). To build and run iOS example -apps, see these [instructions](./building_examples.md#ios). - ## Installing on Debian and Ubuntu -1. Checkout MediaPipe repository. +1. Install Bazel. + + Follow the official + [Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html) + to install Bazel 3.4 or higher. + + For Nvidia Jetson and Raspberry Pi devices with aarch64 Linux, Bazel needs + to be built from source: ```bash + # For Bazel 3.4.1 + mkdir $HOME/bazel-3.4.1 + cd $HOME/bazel-3.4.1 + wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-dist.zip + sudo apt-get install build-essential openjdk-8-jdk python zip unzip + unzip bazel-3.4.1-dist.zip + env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh + sudo cp output/bazel /usr/local/bin/ + ``` + +2. Checkout MediaPipe repository. + + ```bash + $ cd $HOME $ git clone https://github.com/google/mediapipe.git # Change directory into MediaPipe root directory $ cd mediapipe ``` -2. Install Bazel. - - Follow the official - [Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html) - to install Bazel 2.0 or higher. - - For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, Bazel needs to - be built from source. - - ```bash - # For Bazel 3.0.0 - wget https://github.com/bazelbuild/bazel/releases/download/3.0.0/bazel-3.0.0-dist.zip - sudo apt-get install build-essential openjdk-8-jdk python zip unzip - unzip bazel-3.0.0-dist.zip - env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh - sudo cp output/bazel /usr/local/bin/ - ``` - 3. Install OpenCV and FFmpeg. Option 1. Use package manager tool to install the pre-compiled OpenCV @@ -174,7 +173,7 @@ apps, see these [instructions](./building_examples.md#ios). # when building GPU examples. ``` -5. Run the [Hello World desktop example](./hello_world_desktop.md). +5. Run the [Hello World! in C++ example](./hello_world_cpp.md). ```bash $ export GLOG_logtostderr=1 @@ -208,7 +207,13 @@ build issues. **Disclaimer**: Running MediaPipe on CentOS is experimental. -1. Checkout MediaPipe repository. +1. Install Bazel. + + Follow the official + [Bazel documentation](https://docs.bazel.build/versions/master/install-redhat.html) + to install Bazel 3.4 or higher. + +2. Checkout MediaPipe repository. 
```bash $ git clone https://github.com/google/mediapipe.git @@ -217,12 +222,6 @@ build issues. $ cd mediapipe ``` -2. Install Bazel. - - Follow the official - [Bazel documentation](https://docs.bazel.build/versions/master/install-redhat.html) - to install Bazel 2.0 or higher. - 3. Install OpenCV. Option 1. Use package manager tool to install the pre-compiled version. @@ -304,7 +303,7 @@ build issues. ) ``` -4. Run the [Hello World desktop example](./hello_world_desktop.md). +4. Run the [Hello World! in C++ example](./hello_world_cpp.md). ```bash $ export GLOG_logtostderr=1 @@ -337,7 +336,13 @@ build issues. * Install [Xcode](https://developer.apple.com/xcode/) and its Command Line Tools by `xcode-select --install`. -2. Checkout MediaPipe repository. +2. Install Bazel. + + Follow the official + [Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x) + to install Bazel 3.4 or higher. + +3. Checkout MediaPipe repository. ```bash $ git clone https://github.com/google/mediapipe.git @@ -345,19 +350,6 @@ build issues. $ cd mediapipe ``` -3. Install Bazel. - - Option 1. Use package manager tool to install Bazel - - ```bash - $ brew install bazel - # Run 'bazel version' to check version of bazel - ``` - - Option 2. Follow the official - [Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x) - to install Bazel 2.0 or higher. - 4. Install OpenCV and FFmpeg. Option 1. Use HomeBrew package manager tool to install the pre-compiled @@ -427,7 +419,6 @@ build issues. linkstatic = 1, visibility = ["//visibility:public"], ) - ``` 5. Make sure that Python 3 and the Python "six" library are installed. @@ -440,7 +431,7 @@ build issues. $ pip3 install --user six ``` -6. Run the [Hello World desktop example](./hello_world_desktop.md). +6. Run the [Hello World! in C++ example](./hello_world_cpp.md). ```bash $ export GLOG_logtostderr=1 @@ -506,7 +497,7 @@ next section. Follow the official [Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html) - to install Bazel 2.0 or higher. + to install Bazel 3.4 or higher. 6. Set Bazel variables. @@ -541,7 +532,7 @@ next section. ) ``` -9. Run the [Hello World desktop example](./hello_world_desktop.md). +9. Run the [Hello World! in C++ example](./hello_world_cpp.md). Note: For building MediaPipe on Windows, please add `--action_env PYTHON_BIN_PATH="C://path//to//python.exe"` to the build command. @@ -567,7 +558,6 @@ next section. # I20200514 20:43:12.279618 1200 hello_world.cc:56] Hello World! # I20200514 20:43:12.279618 1200 hello_world.cc:56] Hello World! # I20200514 20:43:12.280613 1200 hello_world.cc:56] Hello World! - ``` If you run into a build error, please read @@ -607,14 +597,14 @@ cameras. Alternatively, you use a video file as input. 
```bash username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \ - https://storage.googleapis.com/bazel/3.0.0/release/bazel-3.0.0-installer-linux-x86_64.sh && \ - sudo mkdir -p /usr/local/bazel/3.0.0 && \ - chmod 755 bazel-3.0.0-installer-linux-x86_64.sh && \ - sudo ./bazel-3.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.0.0 && \ - source /usr/local/bazel/3.0.0/lib/bazel/bin/bazel-complete.bash + https://storage.googleapis.com/bazel/3.4.1/release/bazel-3.4.1-installer-linux-x86_64.sh && \ + sudo mkdir -p /usr/local/bazel/3.4.1 && \ + chmod 755 bazel-3.4.1-installer-linux-x86_64.sh && \ + sudo ./bazel-3.4.1-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.1 && \ + source /usr/local/bazel/3.4.1/lib/bazel/bin/bazel-complete.bash - username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.0.0/lib/bazel/bin/bazel version && \ - alias bazel='/usr/local/bazel/3.0.0/lib/bazel/bin/bazel' + username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.1/lib/bazel/bin/bazel version && \ + alias bazel='/usr/local/bazel/3.4.1/lib/bazel/bin/bazel' ``` 6. Checkout MediaPipe repository. @@ -675,7 +665,7 @@ cameras. Alternatively, you use a video file as input. ) ``` -8. Run the [Hello World desktop example](./hello_world_desktop.md). +8. Run the [Hello World! in C++ example](./hello_world_cpp.md). ```bash username@DESKTOP-TMVLBJ1:~/mediapipe$ export GLOG_logtostderr=1 @@ -731,7 +721,7 @@ This will use a Docker image that will isolate mediapipe's installation from the # Successfully tagged mediapipe:latest ``` -3. Run the [Hello World desktop example](./hello_world_desktop.md). +3. Run the [Hello World! in C++ example](./hello_world_cpp.md). ```bash $ docker run -it --name mediapipe mediapipe:latest diff --git a/docs/getting_started/ios.md b/docs/getting_started/ios.md new file mode 100644 index 000000000..cd11828af --- /dev/null +++ b/docs/getting_started/ios.md @@ -0,0 +1,222 @@ +--- +layout: default +title: MediaPipe on iOS +parent: Getting Started +has_children: true +has_toc: false +nav_order: 2 +--- + +# MediaPipe on iOS +{: .no_toc } + +1. TOC +{:toc} +--- + +Please follow instructions below to build iOS example apps in the supported +MediaPipe [solutions](../solutions/solutions.md). To learn more about these +example apps, start from, start from +[Hello World! on iOS](./hello_world_ios.md). + +## Building iOS example apps + +### Prerequisite + +1. Install MediaPipe following these [instructions](./install.md). + +2. Install [Xcode](https://developer.apple.com/xcode/), then install the + Command Line Tools using: + + ```bash + xcode-select --install + ``` + +3. Install [Bazel](https://bazel.build/). + + We recommend using [Homebrew](https://brew.sh/) to get the latest version. + +4. Set Python 3.7 as the default Python version and install the Python "six" + library. This is needed for TensorFlow. + + ```bash + pip3 install --user six + ``` + +5. Clone the MediaPipe repository. + + ```bash + git clone https://github.com/google/mediapipe.git + ``` + +### Set up a bundle ID prefix + +All iOS apps must have a bundle ID, and you must have a provisioning profile +that lets you install an app with that ID onto your phone. To avoid clashes +between different MediaPipe users, you need to configure a unique prefix for the +bundle IDs of our iOS demo apps. + +If you have a custom provisioning profile, see +[Custom provisioning](#custom-provisioning) below. 
+ +Otherwise, run this command to generate a unique prefix: + +```bash +python3 mediapipe/examples/ios/link_local_profiles.py +``` + +### Create an Xcode project + +This allows you to edit and debug one of the example apps in Xcode. It also +allows you to make use of automatic provisioning (see later section). + +1. We will use a tool called [Tulsi](https://tulsi.bazel.build/) for generating + Xcode projects from Bazel build configurations. + + ```bash + # cd out of the mediapipe directory, then: + git clone https://github.com/bazelbuild/tulsi.git + cd tulsi + # remove Xcode version from Tulsi's .bazelrc (see http://github.com/bazelbuild/tulsi#building-and-installing): + sed -i .orig '/xcode_version/d' .bazelrc + # build and run Tulsi: + sh build_and_run.sh + ``` + + This will install `Tulsi.app` inside the `Applications` directory in your + home directory. + +2. Open `mediapipe/Mediapipe.tulsiproj` using the Tulsi app. + + Tip: If Tulsi displays an error saying "Bazel could not be found", press the + "Bazel..." button in the Packages tab and select the `bazel` executable in + your homebrew `/bin/` directory. + +3. Select the MediaPipe config in the Configs tab, then press the Generate + button below. You will be asked for a location to save the Xcode project. + Once the project is generated, it will be opened in Xcode. + + If you get an error about bundle IDs, see the + [previous section](#set-up-a-bundle-id-prefix). + +### Set up provisioning + +To install applications on an iOS device, you need a provisioning profile. There +are two options: + +1. Automatic provisioning. This allows you to build and install an app to your + personal device. The provisining profile is managed by Xcode, and has to be + updated often (it is valid for about a week). + +2. Custom provisioning. This uses a provisioning profile associated with an + Apple developer account. These profiles have a longer validity period and + can target multiple devices, but you need a paid developer account with + Apple to obtain one. + +#### Automatic provisioning + +1. Create an Xcode project for MediaPipe, as discussed + [earlier](#create-an-xcode-project). + +2. In the project navigator in the left sidebar, select the "Mediapipe" + project. + +3. Select one of the application targets, e.g. HandTrackingGpuApp. + +4. Select the "Signing & Capabilities" tab. + +5. Check "Automatically manage signing", and confirm the dialog box. + +6. Select "_Your Name_ (Personal Team)" in the Team pop-up menu. + +7. This set-up needs to be done once for each application you want to install. + Repeat steps 3-6 as needed. + +This generates provisioning profiles for each app you have selected. Now we need +to tell Bazel to use them. We have provided a script to make this easier. + +1. In the terminal, to the `mediapipe` directory where you cloned the + repository. + +2. Run this command: + + ```bash + python3 mediapipe/examples/ios/link_local_profiles.py + ``` + +This will find and link the provisioning profile for all applications for which +you have enabled automatic provisioning in Xcode. + +Note: once a profile expires, Xcode will generate a new one; you must then run +this script again to link the updated profiles. + +#### Custom provisioning + +1. Obtain a provisioning profile from Apple. + +Tip: You can use this command to see the provisioning profiles you have +previously downloaded using Xcode: `open ~/Library/MobileDevice/"Provisioning +Profiles"`. 
If there are none, generate and download a profile on +[Apple's developer site](https://developer.apple.com/account/resources/). + +1. Symlink or copy your provisioning profile to + `mediapipe/mediapipe/provisioning_profile.mobileprovision`. + + ```bash + cd mediapipe + ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision + ``` + +Note: if you had previously set up automatic provisioning, you should remove the +`provisioning_profile.mobileprovision` symlink in each example's directory, +since it will take precedence over the common one. You can also overwrite it +with you own profile if you need a different profile for different apps. + +1. Open `mediapipe/examples/ios/bundle_id.bzl`, and change the + `BUNDLE_ID_PREFIX` to a prefix associated with your provisioning profile. + +### Build and run an app using Xcode + +1. Create the Xcode project, and make sure you have set up either automatic or + custom provisioning. + +2. You can now select any of the MediaPipe demos in the target menu, and build + and run them as normal. + +Note: When you ask Xcode to run an app, by default it will use the Debug +configuration. Some of our demos are computationally heavy; you may want to use +the Release configuration for better performance. + +Tip: To switch build configuration in Xcode, click on the target menu, choose +"Edit Scheme...", select the Run action, and switch the Build Configuration from +Debug to Release. Note that this is set independently for each target. + +Tip: On the device, in Settings > General > Device Management, make sure the +developer (yourself) is trusted. + +### Build an app using the command line + +1. Make sure you have set up either automatic or custom provisioning. + +2. Using [MediaPipe Hands](../solutions/hands.md) for example, run: + + ```bash + bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp + ``` + + You may see a permission request from `codesign` in order to sign the app. + + Tip: If you are using custom provisioning, you can run this + [script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh) + to build all MediaPipe iOS example apps. + +3. In Xcode, open the `Devices and Simulators` window (command-shift-2). + +4. Make sure your device is connected. You will see a list of installed apps. + Press the "+" button under the list, and select the `.ipa` file built by + Bazel. + +5. You can now run the app on your device. + +Tip: On the device, in Settings > General > Device Management, make sure the +developer (yourself) is trusted. diff --git a/docs/getting_started/javascript.md b/docs/getting_started/javascript.md new file mode 100644 index 000000000..c6df75bd8 --- /dev/null +++ b/docs/getting_started/javascript.md @@ -0,0 +1,94 @@ +--- +layout: default +title: MediaPipe in JavaScript +parent: Getting Started +nav_order: 4 +--- + +# MediaPipe in JavaScript +{: .no_toc } + +1. 
TOC
+{:toc}
+---
+
+## Ready-to-use JavaScript Solutions
+
+MediaPipe currently offers the following solutions:
+
+Solution | NPM Package | Example
+----------------- | ----------------------------- | -------
+[Face Mesh][F-pg] | [@mediapipe/face_mesh][F-npm] | [mediapipe.dev/demo/face_mesh][F-demo]
+[Face Detection][Fd-pg] | [@mediapipe/face_detection][Fd-npm] | [mediapipe.dev/demo/face_detection][Fd-demo]
+[Hands][H-pg] | [@mediapipe/hands][H-npm] | [mediapipe.dev/demo/hands][H-demo]
+[Holistic][Ho-pg] | [@mediapipe/holistic][Ho-npm] | [mediapipe.dev/demo/holistic][Ho-demo]
+[Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo]
+
+Click on a solution link above for more information, including API and code
+snippets.
+
+The quickest way to get acclimated is to look at the examples above. Each demo
+has a link to a [CodePen][codepen] so that you can edit the code and try it
+yourself. We have included a number of utility packages to help you get started:
+
+* [@mediapipe/drawing_utils][draw-npm] - Utilities to draw landmarks and
+  connectors.
+* [@mediapipe/camera_utils][cam-npm] - Utilities to operate the camera.
+* [@mediapipe/control_utils][ctrl-npm] - Utilities to show sliders and FPS
+  widgets.
+
+Note: See these demos and more at [MediaPipe on CodePen][codepen].
+
+All of these solutions are staged in [NPM][npm]. You can install any package
+locally with `npm install`. Example:
+
+```
+npm install @mediapipe/holistic
+```
+
+If you would rather not stage these locally, you can rely on a CDN (e.g.,
+[jsDelivr](https://www.jsdelivr.com/)). This will allow you to add scripts
+directly to your HTML:
+
+```
+<script src="https://cdn.jsdelivr.net/npm/@mediapipe/camera_utils/camera_utils.js" crossorigin="anonymous"></script>
+<script src="https://cdn.jsdelivr.net/npm/@mediapipe/control_utils/control_utils.js" crossorigin="anonymous"></script>
+<script src="https://cdn.jsdelivr.net/npm/@mediapipe/drawing_utils/drawing_utils.js" crossorigin="anonymous"></script>
+<script src="https://cdn.jsdelivr.net/npm/@mediapipe/holistic/holistic.js" crossorigin="anonymous"></script>
+```
+
+Note: You can specify version numbers to both NPM and jsDelivr. They are
+structured as `<major>.<minor>.<patch>`. To prevent breaking changes from
+affecting your work, restrict your request to a `<minor>` number, e.g.
+`@mediapipe/holistic@0.1`.
+
+[Ho-pg]: ../solutions/holistic#javascript-solution-api
+[F-pg]: ../solutions/face_mesh#javascript-solution-api
+[Fd-pg]: ../solutions/face_detection#javascript-solution-api
+[H-pg]: ../solutions/hands#javascript-solution-api
+[P-pg]: ../solutions/pose#javascript-solution-api
+[Ho-npm]: https://www.npmjs.com/package/@mediapipe/holistic
+[F-npm]: https://www.npmjs.com/package/@mediapipe/face_mesh
+[Fd-npm]: https://www.npmjs.com/package/@mediapipe/face_detection
+[H-npm]: https://www.npmjs.com/package/@mediapipe/hands
+[P-npm]: https://www.npmjs.com/package/@mediapipe/pose
+[draw-npm]: https://www.npmjs.com/package/@mediapipe/drawing_utils
+[cam-npm]: https://www.npmjs.com/package/@mediapipe/camera_utils
+[ctrl-npm]: https://www.npmjs.com/package/@mediapipe/control_utils
+[Ho-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/holistic
+[F-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_mesh
+[Fd-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_detection
+[H-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/hands
+[P-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/pose
+[Ho-pen]: https://code.mediapipe.dev/codepen/holistic
+[F-pen]: https://code.mediapipe.dev/codepen/face_mesh
+[Fd-pen]: https://code.mediapipe.dev/codepen/face_detection
+[H-pen]: https://code.mediapipe.dev/codepen/hands
+[P-pen]: https://code.mediapipe.dev/codepen/pose
+[Ho-demo]: https://mediapipe.dev/demo/holistic
+[F-demo]: https://mediapipe.dev/demo/face_mesh
+[Fd-demo]: https://mediapipe.dev/demo/face_detection
+[H-demo]: https://mediapipe.dev/demo/hands
+[P-demo]: https://mediapipe.dev/demo/pose
+[npm]: https://www.npmjs.com/package/@mediapipe
+[codepen]: https://code.mediapipe.dev/codepen
diff --git a/docs/getting_started/python.md b/docs/getting_started/python.md
new file mode 100644
index 000000000..5d4bc2fb9
--- /dev/null
+++ b/docs/getting_started/python.md
@@ -0,0 +1,144 @@
+---
+layout: default
+title: MediaPipe in Python
+parent: Getting Started
+has_children: true
+has_toc: false
+nav_order: 3
+---
+
+# MediaPipe in Python
+{: .no_toc }
+
+1. TOC
+{:toc}
+---
+
+## Ready-to-use Python Solutions
+
+MediaPipe offers ready-to-use yet customizable Python solutions as a prebuilt
+Python package. The MediaPipe Python package is available on
+[PyPI](https://pypi.org/project/mediapipe/) for Linux, macOS and Windows.
+
+You can, for instance, activate a Python virtual environment:
+
+```bash
+$ python3 -m venv mp_env && source mp_env/bin/activate
+```
+
+Install the MediaPipe Python package and start the Python interpreter:
+
+```bash
+(mp_env)$ pip install mediapipe
+(mp_env)$ python3
+```
+
+In the Python interpreter, import the package and start using one of the
+solutions:
+
+```python
+import mediapipe as mp
+mp_face_mesh = mp.solutions.face_mesh
+```
+
+Tip: Use command `deactivate` to later exit the Python virtual environment.
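+
+For instance, a minimal sketch of running one of the solutions on a single
+image might look like the following (this assumes OpenCV is installed for
+image I/O and that `/path/to/image.png` is a placeholder for a real file; see
+each solution page for the exact configuration options):
+
+```python
+import cv2
+import mediapipe as mp
+
+mp_face_mesh = mp.solutions.face_mesh
+
+# Load an image with OpenCV and convert BGR to RGB, since the solutions
+# expect RGB input.
+image = cv2.cvtColor(cv2.imread('/path/to/image.png'), cv2.COLOR_BGR2RGB)
+
+# Run Face Mesh on the image and print the first landmark of the first face,
+# if any face was detected.
+with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1) as face_mesh:
+    results = face_mesh.process(image)
+    if results.multi_face_landmarks:
+        print(results.multi_face_landmarks[0].landmark[0])
+```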
+ +To learn more about configuration options and usage examples, please find +details in each solution via the links below: + +* [MediaPipe Face Detection](../solutions/face_detection#python-solution-api) +* [MediaPipe Face Mesh](../solutions/face_mesh#python-solution-api) +* [MediaPipe Hands](../solutions/hands#python-solution-api) +* [MediaPipe Holistic](../solutions/holistic#python-solution-api) +* [MediaPipe Objectron](../solutions/objectron#python-solution-api) +* [MediaPipe Pose](../solutions/pose#python-solution-api) + +## MediaPipe on Google Colab + +* [MediaPipe Face Detection Colab](https://mediapipe.page.link/face_detection_py_colab) +* [MediaPipe Face Mesh Colab](https://mediapipe.page.link/face_mesh_py_colab) +* [MediaPipe Hands Colab](https://mediapipe.page.link/hands_py_colab) +* [MediaPipe Holistic Colab](https://mediapipe.page.link/holistic_py_colab) +* [MediaPipe Objectron Colab](https://mediapipe.page.link/objectron_py_colab) +* [MediaPipe Pose Colab](https://mediapipe.page.link/pose_py_colab) +* [MediaPipe Pose Classification Colab (Basic)](https://mediapipe.page.link/pose_classification_basic) +* [MediaPipe Pose Classification Colab (Extended)](https://mediapipe.page.link/pose_classification_extended) + +## MediaPipe Python Framework + +The ready-to-use solutions are built upon the MediaPipe Python framework, which +can be used by advanced users to run their own MediaPipe graphs in Python. +Please see [here](./python_framework.md) for more info. + +## Building MediaPipe Python Package + +Follow the steps below only if you have local changes and need to build the +Python package from source. Otherwise, we strongly encourage our users to simply +run `pip install mediapipe` to use the ready-to-use solutions, more convenient +and much faster. + +MediaPipe PyPI currently doesn't provide aarch64 Python wheel +files. For building and using MediaPipe Python on aarch64 Linux systems such as +Nvidia Jetson and Raspberry Pi, please read +[here](https://github.com/jiuqiant/mediapipe-python-aarch64). + +1. Make sure that Bazel and OpenCV are correctly installed and configured for + MediaPipe. Please see [Installation](./install.md) for how to setup Bazel + and OpenCV for MediaPipe on Linux and macOS. + +2. Install the following dependencies. + + Debian or Ubuntu: + + ```bash + $ sudo apt install python3-dev + $ sudo apt install python3-venv + $ sudo apt install -y protobuf-compiler + + # If you need to build opencv from source. + $ sudo apt install cmake + ``` + + macOS: + + ```bash + $ brew install protobuf + + # If you need to build opencv from source. + $ brew install cmake + ``` + + Windows: + + Download the latest protoc win64 zip from + [the Protobuf GitHub repo](https://github.com/protocolbuffers/protobuf/releases), + unzip the file, and copy the protoc.exe executable to a preferred + location. Please ensure that location is added into the Path environment + variable. + +3. Activate a Python virtual environment. + + ```bash + $ python3 -m venv mp_env && source mp_env/bin/activate + ``` + +4. In the virtual environment, go to the MediaPipe repo directory. + +5. Install the required Python packages. + + ```bash + (mp_env)mediapipe$ pip3 install -r requirements.txt + ``` + +6. Generate and install MediaPipe package. 
+ + ```bash + (mp_env)mediapipe$ python3 setup.py gen_protos + (mp_env)mediapipe$ python3 setup.py install --link-opencv + ``` + + or + + ```bash + (mp_env)mediapipe$ python3 setup.py gen_protos + (mp_env)mediapipe$ python3 setup.py bdist_wheel + ``` diff --git a/docs/getting_started/python_framework.md b/docs/getting_started/python_framework.md new file mode 100644 index 000000000..ece14bc91 --- /dev/null +++ b/docs/getting_started/python_framework.md @@ -0,0 +1,268 @@ +--- +layout: default +title: MediaPipe Python Framework +parent: MediaPipe in Python +grand_parent: Getting Started +nav_order: 1 +--- + +# MediaPipe Python Framework +{: .no_toc } + +1. TOC +{:toc} +--- + +The MediaPipe Python framework grants direct access to the core components of +the MediaPipe C++ framework such as Timestamp, Packet, and CalculatorGraph, +whereas the +[ready-to-use Python solutions](./python.md#ready-to-use-python-solutions) hide +the technical details of the framework and simply return the readable model +inference results back to the callers. + +MediaPipe framework sits on top of +[the pybind11 library](https://pybind11.readthedocs.io/en/stable/index.html). +The C++ core framework is exposed in Python via a C++/Python language binding. +The content below assumes that the reader already has a basic understanding of +the MediaPipe C++ framework. Otherwise, you can find useful information in +[Framework Concepts](../framework_concepts/framework_concepts.md). + +### Packet + +The packet is the basic data flow unit in MediaPipe. A packet consists of a +numeric timestamp and a shared pointer to an immutable payload. In Python, a +MediaPipe packet can be created by calling one of the packet creator methods in +the +[`mp.packet_creator`](https://github.com/google/mediapipe/tree/master/mediapipe/python/pybind/packet_creator.cc) +module. Correspondingly, the packet payload can be retrieved by using one of the +packet getter methods in the +[`mp.packet_getter`](https://github.com/google/mediapipe/tree/master/mediapipe/python/pybind/packet_getter.cc) +module. Note that the packet payload becomes **immutable** after packet +creation. Thus, the modification of the retrieved packet content doesn't affect +the actual payload in the packet. MediaPipe framework Python API supports the +most commonly used data types of MediaPipe (e.g., ImageFrame, Matrix, Protocol +Buffers, and the primitive data types) in the core binding. The comprehensive +table below shows the type mappings between the Python and the C++ data type +along with the packet creator and the content getter method for each data type +supported by the MediaPipe Python framework API. 
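+
+For example, a minimal round trip through the creator and getter methods
+(assuming the `mediapipe` Python package is installed; the values are purely
+illustrative) might look like:
+
+```python
+import mediapipe as mp
+
+# Create packets holding an int and a UTF-8 string.
+int_packet = mp.packet_creator.create_int(42)
+str_packet = mp.packet_creator.create_string('hello')
+
+# Read the immutable payloads back with the matching getter methods.
+print(mp.packet_getter.get_int(int_packet))  # 42
+print(mp.packet_getter.get_str(str_packet))  # hello
+```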
+ +Python Data Type | C++ Data Type | Packet Creator | Content Getter +------------------------------------ | ----------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------- +bool | bool | create_bool(True) | get_bool(packet) +int or np.intc | int_t | create_int(1) | get_int(packet) +int or np.int8 | int8_t | create_int8(2**7-1) | get_int(packet) +int or np.int16 | int16_t | create_int16(2**15-1) | get_int(packet) +int or np.int32 | int32_t | create_int32(2**31-1) | get_int(packet) +int or np.int64 | int64_t | create_int64(2**63-1) | get_int(packet) +int or np.uint8 | uint8_t | create_uint8(2**8-1) | get_uint(packet) +int or np.uint16 | uint16_t | create_uint16(2**16-1) | get_uint(packet) +int or np.uint32 | uint32_t | create_uint32(2**32-1) | get_uint(packet) +int or np.uint64 | uint64_t | create_uint64(2**64-1) | get_uint(packet) +float or np.float32 | float | create_float(1.1) | get_float(packet) +float or np.double | double | create_double(1.1) | get_float(packet) +str (UTF-8) | std::string | create_string('abc') | get_str(packet) +bytes | std::string | create_string(b'\xd0\xd0\xd0') | get_bytes(packet) +mp.Packet | mp::Packet | create_packet(p) | get_packet(packet) +List\[bool\] | std::vector\ | create_bool_vector(\[True, False\]) | get_bool_list(packet) +List\[int\] or List\[np.intc\] | int\[\] | create_int_array(\[1, 2, 3\]) | get_int_list(packet, size=10) +List\[int\] or List\[np.intc\] | std::vector\ | create_int_vector(\[1, 2, 3\]) | get_int_list(packet) +List\[float\] or List\[np.float\] | float\[\] | create_float_arrary(\[0.1, 0.2\]) | get_float_list(packet, size=10) +List\[float\] or List\[np.float\] | std::vector\ | create_float_vector(\[0.1, 0.2\]) | get_float_list(packet, size=10) +List\[str\] | std::vector\ | create_string_vector(\['a'\]) | get_str_list(packet) +List\[mp.Packet\] | std::vector\ | create_packet_vector(
        \[packet1, packet2\]) | get_packet_list(p) +Mapping\[str, Packet\] | std::map | create_string_to_packet_map(
        {'a': packet1, 'b': packet2}) | get_str_to_packet_dict(packet) +np.ndarray
(cv.mat and PIL.Image) | mp::ImageFrame | create_image_frame(
        format=ImageFormat.SRGB,
        data=mat) | get_image_frame(packet) +np.ndarray | mp::Matrix | create_matrix(data) | get_matrix(packet) +Google Proto Message | Google Proto Message | create_proto(proto) | get_proto(packet) +List\[Proto\] | std::vector\ | create_proto_vector(proto_list) | get_proto_list(packet) + +It's not uncommon that users create custom C++ classes and and send those into +the graphs and calculators. To allow the custom classes to be used in Python +with MediaPipe, you may extend the Packet API for a new data type in the +following steps: + +1. Write the pybind11 + [class binding code](https://pybind11.readthedocs.io/en/stable/advanced/classes.html) + or + [a custom type caster](https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html?highlight=custom%20type%20caster) + for the custom type in a cc file. + + ```c++ + #include "path/to/my_type/header/file.h" + #include "pybind11/pybind11.h" + + namespace py = pybind11; + + PYBIND11_MODULE(my_type_binding, m) { + // Write binding code or a custom type caster for MyType. + py::class_(m, "MyType") + .def(py::init<>()) + .def(...); + } + ``` + +2. Create a new packet creator and getter method of the custom type in a + separate cc file. + + ```c++ + #include "path/to/my_type/header/file.h" + #include "mediapipe/framework/packet.h" + #include "pybind11/pybind11.h" + + namespace mediapipe { + namespace py = pybind11; + + PYBIND11_MODULE(my_packet_methods, m) { + m.def( + "create_my_type", + [](const MyType& my_type) { return MakePacket(my_type); }); + + m.def( + "get_my_type", + [](const Packet& packet) { + if(!packet.ValidateAsType().ok()) { + PyErr_SetString(PyExc_ValueError, "Packet data type mismatch."); + return py::error_already_set(); + } + return packet.Get(); + }); + } // namespace mediapipe + ``` + +3. Add two bazel build rules for the custom type binding and the new packet + methods in the BUILD file. + + ``` + load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + + pybind_extension( + name = "my_type_binding", + srcs = ["my_type_binding.cc"], + deps = [":my_type"], + ) + + pybind_extension( + name = "my_packet_methods", + srcs = ["my_packet_methods.cc"], + deps = [ + ":my_type", + "//mediapipe/framework:packet" + ], + ) + ``` + +4. Build the pybind extension targets (with the suffix .so) by Bazel and move the generated dynamic libraries into one of the $LD_LIBRARY_PATH dirs. + +5. Use the binding modules in Python. + + ```python + import my_type_binding + import my_packet_methods + + packet = my_packet_methods.create_my_type(my_type_binding.MyType()) + my_type = my_packet_methods.get_my_type(packet) + ``` + +### Timestamp + +Each packet contains a timestamp that is in units of microseconds. In Python, +the Packet API provides a convenience method `packet.at()` to define the numeric +timestamp of a packet. More generally, `packet.timestamp` is the packet class +property for accessing the underlying timestamp. To convert an Unix epoch to a +MediaPipe timestamp, +[the Timestamp API](https://github.com/google/mediapipe/tree/master/mediapipe/python/pybind/timestamp.cc) +offers a method `mp.Timestamp.from_seconds()` for this purpose. + +### ImageFrame + +ImageFrame is the container for storing an image or a video frame. Formats +supported by ImageFrame are listed in +[the ImageFormat enum](https://github.com/google/mediapipe/tree/master/mediapipe/python/pybind/image_frame.cc#l=170). +Pixels are encoded row-major with interleaved color components, and ImageFrame +supports uint8, uint16, and float as its data types. 
MediaPipe provides +[an ImageFrame Python API](https://github.com/google/mediapipe/tree/master/mediapipe/python/pybind/image_frame.cc) +to access the ImageFrame C++ class. In Python, the easiest way to retrieve the +pixel data is to call `image_frame.numpy_view()` to get a numpy ndarray. Note +that the returned numpy ndarray, a reference to the internal pixel data, is +unwritable. If the callers need to modify the numpy ndarray, it's required to +explicitly call a copy operation to obtain a copy. When MediaPipe takes a numpy +ndarray to make an ImageFrame, it assumes that the data is stored contiguously. +Correspondingly, the pixel data of an ImageFrame will be realigned to be +contiguous when it's returned to the Python side. + +### Graph + +In MediaPipe, all processing takes places within the context of a +CalculatorGraph. +[The CalculatorGraph Python API](https://github.com/google/mediapipe/tree/master/mediapipe/python/pybind/calculator_graph.cc) +is a direct binding to the C++ CalculatorGraph class. The major difference is +the CalculatorGraph Python API raises a Python error instead of returning a +non-OK Status when an error occurs. Therefore, as a Python user, you can handle +the exceptions as you normally do. The life cycle of a CalculatorGraph contains +three stages: initialization and setup, graph run, and graph shutdown. + +1. Initialize a CalculatorGraph with a CalculatorGraphConfig protobuf or binary + protobuf file, and provide callback method(s) to observe the output + stream(s). + + Option 1. Initialize a CalculatorGraph with a CalculatorGraphConfig protobuf + or its text representation, and observe the output stream(s): + + ```python + import mediapipe as mp + + config_text = """ + input_stream: 'in_stream' + output_stream: 'out_stream' + node { + calculator: 'PassThroughCalculator' + input_stream: 'in_stream' + output_stream: 'out_stream' + } + """ + graph = mp.CalculatorGraph(graph_config=config_text) + output_packets = [] + graph.observe_output_stream( + 'out_stream', + lambda stream_name, packet: + output_packets.append(mp.packet_getter.get_str(packet))) + ``` + + Option 2. Initialize a CalculatorGraph with with a binary protobuf file, and + observe the output stream(s). + + ```python + import mediapipe as mp + # resources dependency + + graph = mp.CalculatorGraph( + binary_graph=os.path.join( + resources.GetRunfilesDir(), 'path/to/your/graph.binarypb')) + graph.observe_output_stream( + 'out_stream', + lambda stream_name, packet: print(f'Get {packet} from {stream_name}')) + ``` + +2. Start the graph run and feed packets into the graph. + + ```python + graph.start_run() + + graph.add_packet_to_input_stream( + 'in_stream', mp.packet_creator.create_str('abc').at(0)) + + rgb_img = cv2.cvtColor(cv2.imread('/path/to/your/image.png'), cv2.COLOR_BGR2RGB) + graph.add_packet_to_input_stream( + 'in_stream', + mp.packet_creator.create_image_frame(format=mp.ImageFormat.SRGB, + data=rgb_img).at(1)) + ``` + +3. Close the graph after finish. You may restart the graph for another graph + run after the call to `close()`. + + ```python + graph.close() + ``` + +The Python script can be run by your local Python runtime. diff --git a/docs/images/box_coordinate.svg b/docs/images/box_coordinate.svg new file mode 100644 index 000000000..f436de896 --- /dev/null +++ b/docs/images/box_coordinate.svg @@ -0,0 +1,3 @@ + + +
[SVG text labels: +Z, UP, Front, (0, 0, 0), +Y, +X]
diff --git a/docs/images/camera_coordinate.svg b/docs/images/camera_coordinate.svg new file mode 100644 index 000000000..4cd3158ee --- /dev/null +++ b/docs/images/camera_coordinate.svg @@ -0,0 +1,3 @@ + + +
[SVG text labels: +Z, +Y, +X, -Z, (l, t, -n), (l, b, -n), (r, t, n), (r, b, -n)]
diff --git a/docs/images/face_geometry_metric_3d_space.gif b/docs/images/face_geometry_metric_3d_space.gif new file mode 100644 index 000000000..1ecd20921 Binary files /dev/null and b/docs/images/face_geometry_metric_3d_space.gif differ diff --git a/docs/images/face_geometry_renderer.gif b/docs/images/face_geometry_renderer.gif new file mode 100644 index 000000000..1f18f765f Binary files /dev/null and b/docs/images/face_geometry_renderer.gif differ diff --git a/docs/images/face_mesh_ar_effects.gif b/docs/images/face_mesh_ar_effects.gif index 868a40c4d..cf56ec719 100644 Binary files a/docs/images/face_mesh_ar_effects.gif and b/docs/images/face_mesh_ar_effects.gif differ diff --git a/docs/images/mobile/hand_landmarks.png b/docs/images/mobile/hand_landmarks.png new file mode 100644 index 000000000..f13746a86 Binary files /dev/null and b/docs/images/mobile/hand_landmarks.png differ diff --git a/docs/images/mobile/holistic_pipeline_example.jpg b/docs/images/mobile/holistic_pipeline_example.jpg new file mode 100644 index 000000000..a35b3784b Binary files /dev/null and b/docs/images/mobile/holistic_pipeline_example.jpg differ diff --git a/docs/images/mobile/holistic_sports_and_gestures_example.gif b/docs/images/mobile/holistic_sports_and_gestures_example.gif new file mode 100644 index 000000000..d579e77ab Binary files /dev/null and b/docs/images/mobile/holistic_sports_and_gestures_example.gif differ diff --git a/docs/images/mobile/holistic_tracking_android_gpu_small.gif b/docs/images/mobile/holistic_tracking_android_gpu_small.gif new file mode 100644 index 000000000..8cf0c226f Binary files /dev/null and b/docs/images/mobile/holistic_tracking_android_gpu_small.gif differ diff --git a/docs/images/mobile/objectron_camera_android_gpu.gif b/docs/images/mobile/objectron_camera_android_gpu.gif new file mode 100644 index 000000000..2ac32104d Binary files /dev/null and b/docs/images/mobile/objectron_camera_android_gpu.gif differ diff --git a/docs/images/mobile/objectron_chair_android_gpu.gif b/docs/images/mobile/objectron_chair_android_gpu.gif index abd1652ca..d2e0ef671 100644 Binary files a/docs/images/mobile/objectron_chair_android_gpu.gif and b/docs/images/mobile/objectron_chair_android_gpu.gif differ diff --git a/docs/images/mobile/objectron_chair_android_gpu_small.gif b/docs/images/mobile/objectron_chair_android_gpu_small.gif index bef4c5b18..919bc0335 100644 Binary files a/docs/images/mobile/objectron_chair_android_gpu_small.gif and b/docs/images/mobile/objectron_chair_android_gpu_small.gif differ diff --git a/docs/images/mobile/objectron_cup_android_gpu.gif b/docs/images/mobile/objectron_cup_android_gpu.gif new file mode 100644 index 000000000..6b49e8f17 Binary files /dev/null and b/docs/images/mobile/objectron_cup_android_gpu.gif differ diff --git a/docs/images/mobile/objectron_shoe_android_gpu.gif b/docs/images/mobile/objectron_shoe_android_gpu.gif index 117cdc5de..ad0ae3697 100644 Binary files a/docs/images/mobile/objectron_shoe_android_gpu.gif and b/docs/images/mobile/objectron_shoe_android_gpu.gif differ diff --git a/docs/images/mobile/pose_classification_pairwise_distances.png b/docs/images/mobile/pose_classification_pairwise_distances.png new file mode 100644 index 000000000..1aa2206df Binary files /dev/null and b/docs/images/mobile/pose_classification_pairwise_distances.png differ diff --git a/docs/images/mobile/pose_classification_pushups_and_squats.gif b/docs/images/mobile/pose_classification_pushups_and_squats.gif new file mode 100644 index 000000000..fe75f3bca Binary files /dev/null and 
b/docs/images/mobile/pose_classification_pushups_and_squats.gif differ diff --git a/docs/images/mobile/pose_classification_pushups_un_and_down_samples.jpg b/docs/images/mobile/pose_classification_pushups_un_and_down_samples.jpg new file mode 100644 index 000000000..269e1b86b Binary files /dev/null and b/docs/images/mobile/pose_classification_pushups_un_and_down_samples.jpg differ diff --git a/docs/images/mobile/pose_tracking_full_body_landmarks.png b/docs/images/mobile/pose_tracking_full_body_landmarks.png new file mode 100644 index 000000000..89530d9e4 Binary files /dev/null and b/docs/images/mobile/pose_tracking_full_body_landmarks.png differ diff --git a/docs/images/mobile/pose_tracking_upper_body_landmarks.png b/docs/images/mobile/pose_tracking_upper_body_landmarks.png index cb18ad567..e2e964ec1 100644 Binary files a/docs/images/mobile/pose_tracking_upper_body_landmarks.png and b/docs/images/mobile/pose_tracking_upper_body_landmarks.png differ diff --git a/docs/images/ndc_coordinate.svg b/docs/images/ndc_coordinate.svg new file mode 100644 index 000000000..038660fd4 --- /dev/null +++ b/docs/images/ndc_coordinate.svg @@ -0,0 +1,3 @@ + + +
[SVG text labels: +Z, (0, 0, 0), +Y, +X, (-1, 1, -1), (1, -1, -1), (-1, -1, -1), (-1, -1, 1), (1, -1, 1), (1, 1, 1)]
diff --git a/docs/images/objectron_2stage_network_architecture.png b/docs/images/objectron_2stage_network_architecture.png new file mode 100644 index 000000000..591f31f64 Binary files /dev/null and b/docs/images/objectron_2stage_network_architecture.png differ diff --git a/docs/index.md b/docs/index.md index 3b67a53fa..d3db8892d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,47 +8,60 @@ nav_order: 1 -------------------------------------------------------------------------------- -## Cross-platform ML solutions made simple +## Live ML anywhere -[MediaPipe](https://google.github.io/mediapipe/) is the simplest way for researchers -and developers to build world-class ML solutions and applications for mobile, -desktop/cloud, web and IoT devices. +[MediaPipe](https://google.github.io/mediapipe/) offers cross-platform, customizable +ML solutions for live and streaming media. ![accelerated.png](images/accelerated_small.png) | ![cross_platform.png](images/cross_platform_small.png) :------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------: -***End-to-End acceleration***: *built-in fast ML inference and processing accelerated even on common hardware* | ***Build one, deploy anywhere***: *Unified solution works across Android, iOS, desktop/cloud, web and IoT* +***End-to-End acceleration***: *Built-in fast ML inference and processing accelerated even on common hardware* | ***Build once, deploy anywhere***: *Unified solution works across Android, iOS, desktop/cloud, web and IoT* ![ready_to_use.png](images/ready_to_use_small.png) | ![open_source.png](images/open_source_small.png) ***Ready-to-use solutions***: *Cutting-edge ML solutions demonstrating full power of the framework* | ***Free and open source***: *Framework and solutions both under Apache 2.0, fully extensible and customizable* ## ML solutions in MediaPipe -Face Detection | Face Mesh | Iris | Hands | Pose | Hair Segmentation -:----------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :---------------: -[![face_detection](images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation) +Face Detection | Face Mesh | Iris | Hands | Pose | Holistic +:----------------------------------------------------------------------------------------------------------------------------: | 
:-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :------: +[![face_detection](images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](images/mobile/holistic_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/holistic) -Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT -:----------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :---: -[![object_detection](images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift) +Hair Segmentation | Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT +:-------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :---: +[![hair_segmentation](images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation) | [![object_detection](images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | 
[![box_tracking](images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift) -[]() | Android | iOS | Desktop | Python | Web | Coral -:---------------------------------------------------------------------------------------- | :-----: | :-: | :-----: | :----: | :-: | :---: -[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | ✅ | ✅ -[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | | | -[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | ✅ | -[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | | ✅ | -[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | ✅ | -[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ -[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | -[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | | -[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | -[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | -[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | -[YouTube 8M](https://google.github.io/mediapipe/solutions/youtube_8m) | | | ✅ | | | +[]() | [Android](https://google.github.io/mediapipe/getting_started/android) | [iOS](https://google.github.io/mediapipe/getting_started/ios) | [C++](https://google.github.io/mediapipe/getting_started/cpp) | [Python](https://google.github.io/mediapipe/getting_started/python) | [JS](https://google.github.io/mediapipe/getting_started/javascript) | [Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/README.md) +:---------------------------------------------------------------------------------------- | :-------------------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------------: | :-----------------------------------------------------------: | :--------------------------------------------------------------------: +[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ +[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | | +[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Holistic](https://google.github.io/mediapipe/solutions/holistic) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | | +[Object 
Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ +[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | +[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | +[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | +[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | +[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | +[YouTube 8M](https://google.github.io/mediapipe/solutions/youtube_8m) | | | ✅ | | | + +See also +[MediaPipe Models and Model Cards](https://google.github.io/mediapipe/solutions/models) +for ML models released in MediaPipe. + +## MediaPipe in Python + +MediaPipe offers customizable Python solutions as a prebuilt Python package on +[PyPI](https://pypi.org/project/mediapipe/), which can be installed simply with +`pip install mediapipe`. It also provides tools for users to build their own +solutions. Please see +[MediaPipe in Python](https://google.github.io/mediapipe/getting_started/python) +for more info. ## MediaPipe on the Web @@ -89,7 +102,13 @@ run code search using ## Publications -* [Instant Motion Tracking With MediaPipe](https://mediapipe.page.link/instant-motion-tracking-blog) +* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) + in Google AI Blog +* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) + in Google AI Blog +* [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html) + in Google Developers Blog +* [Instant Motion Tracking With MediaPipe](https://developers.googleblog.com/2020/08/instant-motion-tracking-with-mediapipe.html) in Google Developers Blog * [BlazePose - On-device Real-time Body Pose Tracking](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) in Google AI Blog diff --git a/docs/solutions/autoflip.md b/docs/solutions/autoflip.md index 3dec7719b..0e118cc55 100644 --- a/docs/solutions/autoflip.md +++ b/docs/solutions/autoflip.md @@ -2,14 +2,20 @@ layout: default title: AutoFlip (Saliency-aware Video Cropping) parent: Solutions -nav_order: 12 +nav_order: 13 --- # AutoFlip: Saliency-aware Video Cropping {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview diff --git a/docs/solutions/box_tracking.md b/docs/solutions/box_tracking.md index 34fed0277..0e7550e7f 100644 --- a/docs/solutions/box_tracking.md +++ b/docs/solutions/box_tracking.md @@ -2,14 +2,20 @@ layout: default title: Box Tracking parent: Solutions -nav_order: 8 +nav_order: 9 --- # MediaPipe Box Tracking {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview @@ -105,9 +111,8 @@ new detections to remove obsolete or duplicated boxes. ## Example Apps Please first see general instructions for -[Android](../getting_started/building_examples.md#android), [iOS](../getting_started/building_examples.md#ios) -and [desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe -examples. +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md) and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 036624332..f04af27d7 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -8,8 +8,14 @@ nav_order: 1 # MediaPipe Face Detection {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
 ---
 
 ## Overview
@@ -33,12 +39,174 @@ section.
 
 ![face_detection_android_gpu.gif](../images/mobile/face_detection_android_gpu.gif)
 
+## Solution APIs
+
+### Configuration Options
+
+Naming style and availability may differ slightly across platforms/languages.
+
+#### min_detection_confidence
+
+Minimum confidence value (`[0.0, 1.0]`) from the face detection model for the
+detection to be considered successful. Default to `0.5`.
+
+### Output
+
+Naming style may differ slightly across platforms/languages.
+
+#### detections
+
+Collection of detected faces, where each face is represented as a detection
+proto message that contains a bounding box and 6 key points (right eye, left
+eye, nose tip, mouth center, right ear tragion, and left ear tragion). The
+bounding box is composed of `xmin` and `width` (both normalized to `[0.0, 1.0]`
+by the image width) and `ymin` and `height` (both normalized to `[0.0, 1.0]` by
+the image height). Each key point is composed of `x` and `y`, which are
+normalized to `[0.0, 1.0]` by the image width and height respectively.
+
+### Python Solution API
+
+Please first follow general [instructions](../getting_started/python.md) to
+install MediaPipe Python package, then learn more in the companion
+[Python Colab](#resources) and the following usage example.
+
+Supported configuration options:
+
+* [min_detection_confidence](#min_detection_confidence)
+
+```python
+import cv2
+import mediapipe as mp
+mp_face_detection = mp.solutions.face_detection
+mp_drawing = mp.solutions.drawing_utils
+
+# For static images:
+with mp_face_detection.FaceDetection(
+    min_detection_confidence=0.5) as face_detection:
+  for idx, file in enumerate(file_list):
+    image = cv2.imread(file)
+    # Convert the BGR image to RGB and process it with MediaPipe Face Detection.
+    results = face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+
+    # Draw face detections of each face.
+    if not results.detections:
+      continue
+    annotated_image = image.copy()
+    for detection in results.detections:
+      print('Nose tip:')
+      print(mp_face_detection.get_key_point(
+          detection, mp_face_detection.FaceKeyPoint.NOSE_TIP))
+      mp_drawing.draw_detection(annotated_image, detection)
+    cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
+
+# For webcam input:
+cap = cv2.VideoCapture(0)
+with mp_face_detection.FaceDetection(
+    min_detection_confidence=0.5) as face_detection:
+  while cap.isOpened():
+    success, image = cap.read()
+    if not success:
+      print("Ignoring empty camera frame.")
+      # If loading a video, use 'break' instead of 'continue'.
+      continue
+
+    # Flip the image horizontally for a later selfie-view display, and convert
+    # the BGR image to RGB.
+    image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB)
+    # To improve performance, optionally mark the image as not writeable to
+    # pass by reference.
+    image.flags.writeable = False
+    results = face_detection.process(image)
+
+    # Draw the face detection annotations on the image.
+    image.flags.writeable = True
+    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+    if results.detections:
+      for detection in results.detections:
+        mp_drawing.draw_detection(image, detection)
+    cv2.imshow('MediaPipe Face Detection', image)
+    if cv2.waitKey(5) & 0xFF == 27:
+      break
+cap.release()
+```
+
+### JavaScript Solution API
+
+Please first see general [introduction](../getting_started/javascript.md) on
+MediaPipe in JavaScript, then learn more in the companion [web demo](#resources)
+and the following usage example.
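+The following is a minimal sketch of that flow. It assumes the
+`@mediapipe/face_detection` and `@mediapipe/camera_utils` packages are loaded
+via script tags (for example from the jsDelivr path used in `locateFile`
+below), and that the page provides a video element with class `input_video`
+and a canvas with class `output_canvas`; those element names are illustrative
+only.
+
+```javascript
+const videoElement = document.getElementsByClassName('input_video')[0];
+const canvasElement = document.getElementsByClassName('output_canvas')[0];
+const canvasCtx = canvasElement.getContext('2d');
+
+function onResults(results) {
+  // Draw the current camera frame and report how many faces were detected.
+  canvasCtx.save();
+  canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height);
+  canvasCtx.drawImage(
+      results.image, 0, 0, canvasElement.width, canvasElement.height);
+  console.log('Detected faces:', results.detections.length);
+  canvasCtx.restore();
+}
+
+const faceDetection = new FaceDetection({locateFile: (file) =>
+    `https://cdn.jsdelivr.net/npm/@mediapipe/face_detection/${file}`});
+faceDetection.setOptions({minDetectionConfidence: 0.5});
+faceDetection.onResults(onResults);
+
+// Feed webcam frames to the solution.
+const camera = new Camera(videoElement, {
+  onFrame: async () => {
+    await faceDetection.send({image: videoElement});
+  },
+  width: 1280,
+  height: 720
+});
+camera.start();
+```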
+ +Supported configuration options: + +* [minDetectionConfidence](#min_detection_confidence) + +```html + + + + + + + + + + + +
+ + +
+ + +``` + +```javascript + +``` + ## Example Apps Please first see general instructions for -[Android](../getting_started/building_examples.md#android), [iOS](../getting_started/building_examples.md#ios) -and [desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe -examples. +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md) and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how @@ -102,9 +270,6 @@ to cross-compile and run MediaPipe examples on the [BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs](https://arxiv.org/abs/1907.05047) ([presentation](https://docs.google.com/presentation/d/1YCtASfnYyZtH-41QvnW5iZxELFnf0MF-pPWSLGj8yjQ/present?slide=id.g5bc8aeffdd_1_0)) ([poster](https://drive.google.com/file/d/1u6aB6wxDY7X2TmeUUKgFydulNtXkb3pu/view)) -* For front-facing/selfie camera: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite), - [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite) -* For back-facing camera: - [TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_back.tflite) -* [Model card](https://mediapipe.page.link/blazeface-mc) +* [Models and model cards](./models.md#face_detection) +* [Web demo](https://code.mediapipe.dev/codepen/face_detection) +* [Python Colab](https://mediapipe.page.link/face_detection_py_colab) diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index 712ea5b0b..0c620120c 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -8,8 +8,14 @@ nav_order: 2 # MediaPipe Face Mesh {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview @@ -19,13 +25,18 @@ landmarks in real-time even on mobile devices. It employs machine learning (ML) to infer the 3D surface geometry, requiring only a single camera input without the need for a dedicated depth sensor. Utilizing lightweight model architectures together with GPU acceleration throughout the pipeline, the solution delivers -real-time performance critical for live experiences. The core of the solution is -the same as what powers -[YouTube Stories](https://youtube-creators.googleblog.com/2018/11/introducing-more-ways-to-share-your.html)' -creator effects, the -[Augmented Faces API in ARCore](https://developers.google.com/ar/develop/java/augmented-faces/) -and the -[ML Kit Face Contour Detection API](https://firebase.google.com/docs/ml-kit/face-detection-concepts#contours). +real-time performance critical for live experiences. + +Additionally, the solution is bundled with the Face Geometry module that bridges +the gap between the face landmark estimation and useful real-time augmented +reality (AR) applications. It establishes a metric 3D space and uses the face +landmark screen positions to estimate face geometry within that space. The face +geometry data consists of common 3D geometry primitives, including a face pose +transformation matrix and a triangular face mesh. Under the hood, a lightweight +statistical analysis method called +[Procrustes Analysis](https://en.wikipedia.org/wiki/Procrustes_analysis) is +employed to drive a robust, performant and portable logic. The analysis runs on +CPU and has a minimal speed/memory footprint on top of the ML model inference. ![face_mesh_ar_effects.gif](../images/face_mesh_ar_effects.gif) | :-------------------------------------------------------------: | @@ -67,15 +78,15 @@ Note: To visualize a graph, copy the graph and paste it into to visualize its associated subgraphs, please see [visualizer documentation](../tools/visualizer.md). -## Models +### Models -### Face Detection Model +#### Face Detection Model The face detector is the same [BlazeFace](https://arxiv.org/abs/1907.05047) model used in [MediaPipe Face Detection](./face_detection.md). Please refer to [MediaPipe Face Detection](./face_detection.md) for details. -### Face Landmark Model +#### Face Landmark Model For 3D face landmarks we employed transfer learning and trained a network with several objectives: the network simultaneously predicts 3D landmark coordinates @@ -98,20 +109,336 @@ You can find more information about the face landmark model in this ![face_mesh_android_gpu.gif](../images/mobile/face_mesh_android_gpu.gif) | :------------------------------------------------------------------------: | -*Fig 2. Output of MediaPipe Face Mesh: the red box indicates the cropped area as input to the landmark model, the red dots represent the 468 landmarks in 3D, and the green lines connecting landmarks illustrate the contours around the eyes, eyebrows, lips and the entire face.* | +*Fig 2. 
Face landmarks: the red box indicates the cropped area as input to the landmark model, the red dots represent the 468 landmarks in 3D, and the green lines connecting landmarks illustrate the contours around the eyes, eyebrows, lips and the entire face.* |
+
+## Face Geometry Module
+
+The [Face Landmark Model](#face-landmark-model) performs a single-camera face landmark
+detection in the screen coordinate space: the X- and Y- coordinates are
+normalized screen coordinates, while the Z coordinate is relative and is scaled
+as the X coordinate under the
+[weak perspective projection camera model](https://en.wikipedia.org/wiki/3D_projection#Weak_perspective_projection).
+This format is well-suited for some applications; however, it does not directly
+enable the full spectrum of augmented reality (AR) features like aligning a
+virtual 3D object with a detected face.
+
+The
+[Face Geometry module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry)
+moves away from the screen coordinate space towards a metric 3D space and
+provides necessary primitives to handle a detected face as a regular 3D object.
+By design, you'll be able to use a perspective camera to project the final 3D
+scene back into the screen coordinate space with a guarantee that the face
+landmark positions are not changed.
+
+### Key Concepts
+
+#### Metric 3D Space
+
+The **Metric 3D space** established within the Face Geometry module is a
+right-handed orthonormal metric 3D coordinate space. Within the space, there is
+a **virtual perspective camera** located at the space origin and pointed in the
+negative direction of the Z-axis. In the current pipeline, it is assumed that
+the input camera frames are observed by exactly this virtual camera and
+therefore its parameters are later used to convert the screen landmark
+coordinates back into the Metric 3D space. The *virtual camera parameters* can
+be set freely; however, for better results it is advised to set them as close
+to the *real physical camera parameters* as possible.
+
+![face_geometry_metric_3d_space.gif](../images/face_geometry_metric_3d_space.gif) |
+:----------------------------------------------------------------------------: |
+*Fig 3. A visualization of multiple key elements in the Metric 3D space.* |
+
+#### Canonical Face Model
+
+The **Canonical Face Model** is a static 3D model of a human face, which follows
+the 468 3D face landmark topology of the
+[Face Landmark Model](#face-landmark-model). The model bears two important
+functions:
+
+- **Defines metric units**: the scale of the canonical face model defines the
+  metric units of the Metric 3D space. A metric unit used by the
+  [default canonical face model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/data/canonical_face_model.fbx)
+  is a centimeter;
+- **Bridges static and runtime spaces**: the face pose transformation matrix
+  is - in fact - a linear map from the canonical face model into the runtime
+  face landmark set estimated on each frame. This way, virtual 3D assets
+  modeled around the canonical face model can be aligned with a tracked face
+  by applying the face pose transformation matrix to them.
+
+### Components
+
+#### Geometry Pipeline
+
+The **Geometry Pipeline** is a key component, which is responsible for
+estimating face geometry objects within the Metric 3D space.
On each frame, the +following steps are executed in the given order: + +- Face landmark screen coordinates are converted into the Metric 3D space + coordinates; +- Face pose transformation matrix is estimated as a rigid linear mapping from + the canonical face metric landmark set into the runtime face metric landmark + set in a way that minimizes a difference between the two; +- A face mesh is created using the runtime face metric landmarks as the vertex + positions (XYZ), while both the vertex texture coordinates (UV) and the + triangular topology are inherited from the canonical face model. + +The geometry pipeline is implemented as a MediaPipe +[calculator](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc). +For your convenience, the face geometry pipeline calculator is bundled together +with corresponding metadata into a unified MediaPipe +[subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt). +The face geometry format is defined as a Protocol Buffer +[message](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/face_geometry.proto). + +#### Effect Renderer + +The **Effect Renderer** is a component, which serves as a working example of a +face effect renderer. It targets the *OpenGL ES 2.0* API to enable a real-time +performance on mobile devices and supports the following rendering modes: + +- **3D object rendering mode**: a virtual object is aligned with a detected + face to emulate an object attached to the face (example: glasses); +- **Face mesh rendering mode**: a texture is stretched on top of the face mesh + surface to emulate a face painting technique. + +In both rendering modes, the face mesh is first rendered as an occluder straight +into the depth buffer. This step helps to create a more believable effect via +hiding invisible elements behind the face surface. + +The effect renderer is implemented as a MediaPipe +[calculator](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/effect_renderer_calculator.cc). + +| ![face_geometry_renderer.gif](../images/face_geometry_renderer.gif) | +| :---------------------------------------------------------------------: | +| *Fig 4. An example of face effects rendered by the Face Geometry Effect Renderer.* | + +## Solution APIs + +### Configuration Options + +Naming style and availability may differ slightly across platforms/languages. + +#### static_image_mode + +If set to `false`, the solution treats the input images as a video stream. It +will try to detect faces in the first input images, and upon a successful +detection further localizes the face landmarks. In subsequent images, once all +[max_num_faces](#max_num_faces) faces are detected and the corresponding face +landmarks are localized, it simply tracks those landmarks without invoking +another detection until it loses track of any of the faces. This reduces latency +and is ideal for processing video frames. If set to `true`, face detection runs +on every input image, ideal for processing a batch of static, possibly +unrelated, images. Default to `false`. + +#### max_num_faces + +Maximum number of faces to detect. Default to `1`. + +#### min_detection_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the face detection model for the +detection to be considered successful. Default to `0.5`. 
+ +#### min_tracking_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the landmark-tracking model for the +face landmarks to be considered tracked successfully, or otherwise face +detection will be invoked automatically on the next input image. Setting it to a +higher value can increase robustness of the solution, at the expense of a higher +latency. Ignored if [static_image_mode](#static_image_mode) is `true`, where +face detection simply runs on every image. Default to `0.5`. + +### Output + +Naming style may differ slightly across platforms/languages. + +#### multi_face_landmarks + +Collection of detected/tracked faces, where each face is represented as a list +of 468 face landmarks and each landmark is composed of `x`, `y` and `z`. `x` and +`y` are normalized to `[0.0, 1.0]` by the image width and height respectively. +`z` represents the landmark depth with the depth at center of the head being the +origin, and the smaller the value the closer the landmark is to the camera. The +magnitude of `z` uses roughly the same scale as `x`. + +### Python Solution API + +Please first follow general [instructions](../getting_started/python.md) to +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. + +Supported configuration options: + +* [static_image_mode](#static_image_mode) +* [max_num_faces](#max_num_faces) +* [min_detection_confidence](#min_detection_confidence) +* [min_tracking_confidence](#min_tracking_confidence) + +```python +import cv2 +import mediapipe as mp +mp_drawing = mp.solutions.drawing_utils +mp_face_mesh = mp.solutions.face_mesh + +# For static images: +drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) +with mp_face_mesh.FaceMesh( + static_image_mode=True, + max_num_faces=1, + min_detection_confidence=0.5) as face_mesh: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + # Convert the BGR image to RGB before processing. + results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Print and draw face mesh landmarks on the image. + if not results.multi_face_landmarks: + continue + annotated_image = image.copy() + for face_landmarks in results.multi_face_landmarks: + print('face_landmarks:', face_landmarks) + mp_drawing.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACE_CONNECTIONS, + landmark_drawing_spec=drawing_spec, + connection_drawing_spec=drawing_spec) + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + +# For webcam input: +drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) +cap = cv2.VideoCapture(0) +with mp_face_mesh.FaceMesh( + min_detection_confidence=0.5, + min_tracking_confidence=0.5) as face_mesh: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = face_mesh.process(image) + + # Draw the face mesh annotations on the image. 
+ image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if results.multi_face_landmarks: + for face_landmarks in results.multi_face_landmarks: + mp_drawing.draw_landmarks( + image=image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACE_CONNECTIONS, + landmark_drawing_spec=drawing_spec, + connection_drawing_spec=drawing_spec) + cv2.imshow('MediaPipe FaceMesh', image) + if cv2.waitKey(5) & 0xFF == 27: + break +cap.release() +``` + +### JavaScript Solution API + +Please first see general [introduction](../getting_started/javascript.md) on +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. + +Supported configuration options: + +* [maxNumFaces](#max_num_faces) +* [minDetectionConfidence](#min_detection_confidence) +* [minTrackingConfidence](#min_tracking_confidence) + +```html + + + + + + + + + + + +
+ + +
+ + +``` + +```javascript + +``` ## Example Apps Please first see general instructions for -[Android](../getting_started/building_examples.md#android), [iOS](../getting_started/building_examples.md#ios) and -[desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe examples. +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md) and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see [visualizer documentation](../tools/visualizer.md). -### Mobile +### Face Landmark Example + +Face landmark example showcases real-time, cross-platform face landmark +detection. For visual reference, please refer to *Fig. 2*. + +#### Mobile * Graph: [`mediapipe/graphs/face_mesh/face_mesh_mobile.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_mesh/face_mesh_mobile.pbtxt) @@ -127,7 +454,7 @@ it, for Android modify `NUM_FACES` in and for iOS modify `kNumFaces` in [FaceMeshGpuViewController.mm](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/facemeshgpu/FaceMeshGpuViewController.mm). -### Desktop +#### Desktop * Running on CPU * Graph: @@ -143,18 +470,37 @@ and for iOS modify `kNumFaces` in Tip: Maximum number of faces to detect/process is set to 1 by default. To change it, in the graph file modify the option of `ConstantSidePacketCalculator`. +### Face Effect Example + +Face effect example showcases real-time mobile face effect application use case +for the Face Mesh solution. To enable a better user experience, this example +only works for a single face. For visual reference, please refer to *Fig. 4*. 
+
+#### Mobile
+
+* Graph:
+    [`mediapipe/graphs/face_effect/face_effect_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_effect/face_effect_gpu.pbtxt)
+* Android target:
+    [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1ccnaDnffEuIXriBZr2SK_Eu4FpO7K44s)
+    [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/BUILD)
+* iOS target:
+    [`mediapipe/examples/ios/faceeffect`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/faceeffect/BUILD)
+
 ## Resources
 
 * Google AI Blog:
   [Real-Time AR Self-Expression with Machine Learning](https://ai.googleblog.com/2019/03/real-time-ar-self-expression-with.html)
 * TensorFlow Blog:
   [Face and hand tracking in the browser with MediaPipe and TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
+* Google Developers Blog:
+    [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html)
 * Paper:
   [Real-time Facial Surface Geometry from Monocular Video on Mobile GPUs](https://arxiv.org/abs/1907.06724)
   ([poster](https://docs.google.com/presentation/d/1-LWwOMO9TzEVdrZ1CS1ndJzciRHfYDJfbSxH_ke_JRg/present?slide=id.g5986dd4b4c_4_212))
-* Face detection model:
-  [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite)
-* Face landmark model:
-  [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark.tflite),
-  [TF.js model](https://tfhub.dev/mediapipe/facemesh/1)
-* [Model card](https://mediapipe.page.link/facemesh-mc)
+* Canonical face model:
+    [FBX](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/data/canonical_face_model.fbx),
+    [OBJ](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/data/canonical_face_model.obj),
+    [UV visualization](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/data/canonical_face_model_uv_visualization.png)
+* [Models and model cards](./models.md#face_mesh)
+* [Web demo](https://code.mediapipe.dev/codepen/face_mesh)
+* [Python Colab](https://mediapipe.page.link/face_mesh_py_colab)
diff --git a/docs/solutions/hair_segmentation.md b/docs/solutions/hair_segmentation.md
index 0521ad60d..5e2e4a7c5 100644
--- a/docs/solutions/hair_segmentation.md
+++ b/docs/solutions/hair_segmentation.md
@@ -2,14 +2,20 @@
 layout: default
 title: Hair Segmentation
 parent: Solutions
-nav_order: 6
+nav_order: 7
 ---
 
 # MediaPipe Hair Segmentation
 {: .no_toc }
 
+
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ![hair_segmentation_android_gpu_gif](../images/mobile/hair_segmentation_android_gpu.gif) @@ -17,9 +23,8 @@ nav_order: 6 ## Example Apps Please first see general instructions for -[Android](../getting_started/building_examples.md#android), [iOS](../getting_started/building_examples.md#ios) -and [desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe -examples. +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md) and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how @@ -54,5 +59,4 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web). [Real-time Hair segmentation and recoloring on Mobile GPUs](https://arxiv.org/abs/1907.06740) ([presentation](https://drive.google.com/file/d/1C8WYlWdDRNtU1_pYBvkkG5Z5wqYqf0yj/view)) ([supplementary video](https://drive.google.com/file/d/1LPtM99Ch2ogyXYbDNpEqnUfhFq0TfLuf/view)) -* [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/hair_segmentation.tflite) -* [Model card](https://mediapipe.page.link/hairsegmentation-mc) +* [Models and model cards](./models.md#hair_segmentation) diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md index 8edfd5850..ac10124f2 100644 --- a/docs/solutions/hands.md +++ b/docs/solutions/hands.md @@ -8,8 +8,14 @@ nav_order: 4 # MediaPipe Hands {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview @@ -55,13 +61,21 @@ frame, and only when the landmark model could no longer identify hand presence is palm detection invoked to relocalize the hand. The pipeline is implemented as a MediaPipe -[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt), -which internally utilizes a -[palm/hand detection subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt), -a -[hand landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt) -and a -[renderer subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/renderer_gpu.pbtxt). +[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt) +that uses a +[hand landmark tracking subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu.pbtxt) +from the +[hand landmark module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark), +and renders using a dedicated +[hand renderer subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_renderer_gpu.pbtxt). +The +[hand landmark tracking subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu.pbtxt) +internally uses a +[hand landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_gpu.pbtxt) +from the same module and a +[palm detection subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection_gpu.pbtxt) +from the +[palm detection module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection). Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how @@ -77,13 +91,14 @@ To detect initial hand locations, we designed a mobile real-time uses in a manner similar to the face detection model in [MediaPipe Face Mesh](./face_mesh.md). Detecting hands is a decidedly complex task: our -[model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite) has -to work across a variety of hand sizes with a large scale span (~20x) relative -to the image frame and be able to detect occluded and self-occluded hands. -Whereas faces have high contrast patterns, e.g., in the eye and mouth region, -the lack of such features in hands makes it comparatively difficult to detect -them reliably from their visual features alone. Instead, providing additional -context, like arm, body, or person features, aids accurate hand localization. +[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite) +has to work across a variety of hand sizes with a large scale span (~20x) +relative to the image frame and be able to detect occluded and self-occluded +hands. Whereas faces have high contrast patterns, e.g., in the eye and mouth +region, the lack of such features in hands makes it comparatively difficult to +detect them reliably from their visual features alone. Instead, providing +additional context, like arm, body, or person features, aids accurate hand +localization. Our method addresses the above challenges using different strategies. 
First, we train a palm detector instead of a hand detector, since estimating bounding @@ -105,7 +120,7 @@ just 86.22%. ### Hand Landmark Model After the palm detection over the whole image our subsequent hand landmark -[model](https://github.com/google/mediapipe/tree/master/mediapipe/models/hand_landmark.tflite) +[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite) performs precise keypoint localization of 21 3D hand-knuckle coordinates inside the detected hand regions via regression, that is direct coordinate prediction. The model learns a consistent internal hand pose representation and is robust @@ -118,16 +133,236 @@ and provide additional supervision on the nature of hand geometry, we also render a high-quality synthetic hand model over various backgrounds and map it to the corresponding 3D coordinates. -| ![hand_crops.png](../images/mobile/hand_crops.png) | -| :-------------------------------------------------------------------------: | -| *Fig 2. Top: Aligned hand crops passed to the tracking network with ground truth annotation. Bottom: Rendered synthetic hand images with ground truth annotation.* | +![hand_landmarks.png](../images/mobile/hand_landmarks.png) | +:--------------------------------------------------------: | +*Fig 2. 21 hand landmarks.* | + +![hand_crops.png](../images/mobile/hand_crops.png) | +:-------------------------------------------------------------------------: | +*Fig 3. Top: Aligned hand crops passed to the tracking network with ground truth annotation. Bottom: Rendered synthetic hand images with ground truth annotation.* | + +## Solution APIs + +### Configuration Options + +Naming style and availability may differ slightly across platforms/languages. + +#### static_image_mode + +If set to `false`, the solution treats the input images as a video stream. It +will try to detect hands in the first input images, and upon a successful +detection further localizes the hand landmarks. In subsequent images, once all +[max_num_hands](#max_num_hands) hands are detected and the corresponding hand +landmarks are localized, it simply tracks those landmarks without invoking +another detection until it loses track of any of the hands. This reduces latency +and is ideal for processing video frames. If set to `true`, hand detection runs +on every input image, ideal for processing a batch of static, possibly +unrelated, images. Default to `false`. + +#### max_num_hands + +Maximum number of hands to detect. Default to `2`. + +#### min_detection_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the hand detection model for the +detection to be considered successful. Default to `0.5`. + +#### min_tracking_confidence: + +Minimum confidence value (`[0.0, 1.0]`) from the landmark-tracking model for the +hand landmarks to be considered tracked successfully, or otherwise hand +detection will be invoked automatically on the next input image. Setting it to a +higher value can increase robustness of the solution, at the expense of a higher +latency. Ignored if [static_image_mode](#static_image_mode) is `true`, where +hand detection simply runs on every image. Default to `0.5`. + +### Output + +Naming style may differ slightly across platforms/languages. + +#### multi_hand_landmarks + +Collection of detected/tracked hands, where each hand is represented as a list +of 21 hand landmarks and each landmark is composed of `x`, `y` and `z`. `x` and +`y` are normalized to `[0.0, 1.0]` by the image width and height respectively. 
+`z` represents the landmark depth with the depth at the wrist being the origin, +and the smaller the value the closer the landmark is to the camera. The +magnitude of `z` uses roughly the same scale as `x`. + +#### multi_handedness + +Collection of handedness of the detected/tracked hands (i.e. is it a left or +right hand). Each hand is composed of `label` and `score`. `label` is a string +of value either `"Left"` or `"Right"`. `score` is the estimated probability of +the predicted handedness and is always greater than or equal to `0.5` (and the +opposite handedness has an estimated probability of `1 - score`). + +Note that handedness is determined assuming the input image is mirrored, i.e., +taken with a front-facing/selfie camera with images flipped horizontally. If it +is not the case, please swap the handedness output in the application. + +### Python Solution API + +Please first follow general [instructions](../getting_started/python.md) to +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. + +Supported configuration options: + +* [static_image_mode](#static_image_mode) +* [max_num_hands](#max_num_hands) +* [min_detection_confidence](#min_detection_confidence) +* [min_tracking_confidence](#min_tracking_confidence) + +```python +import cv2 +import mediapipe as mp +mp_drawing = mp.solutions.drawing_utils +mp_hands = mp.solutions.hands + +# For static images: +with mp_hands.Hands( + static_image_mode=True, + max_num_hands=2, + min_detection_confidence=0.5) as hands: + for idx, file in enumerate(file_list): + # Read an image, flip it around y-axis for correct handedness output (see + # above). + image = cv2.flip(cv2.imread(file), 1) + # Convert the BGR image to RGB before processing. + results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Print handedness and draw hand landmarks on the image. + print('Handedness:', results.multi_handedness) + if not results.multi_hand_landmarks: + continue + image_height, image_width, _ = image.shape + annotated_image = image.copy() + for hand_landmarks in results.multi_hand_landmarks: + print('hand_landmarks:', hand_landmarks) + print( + f'Index finger tip coordinates: (', + f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].x * image_width}, ' + f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})' + ) + mp_drawing.draw_landmarks( + annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS) + cv2.imwrite( + '/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1)) + +# For webcam input: +cap = cv2.VideoCapture(0) +with mp_hands.Hands( + min_detection_confidence=0.5, + min_tracking_confidence=0.5) as hands: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = hands.process(image) + + # Draw the hand annotations on the image. 
+ image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if results.multi_hand_landmarks: + for hand_landmarks in results.multi_hand_landmarks: + mp_drawing.draw_landmarks( + image, hand_landmarks, mp_hands.HAND_CONNECTIONS) + cv2.imshow('MediaPipe Hands', image) + if cv2.waitKey(5) & 0xFF == 27: + break +cap.release() +``` + +### JavaScript Solution API + +Please first see general [introduction](../getting_started/javascript.md) on +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and a [fun application], and the following usage example. + +Supported configuration options: + +* [maxNumHands](#max_num_hands) +* [minDetectionConfidence](#min_detection_confidence) +* [minTrackingConfidence](#min_tracking_confidence) + +```html + + + + + + + + + + + +
+ + +
+ + +``` + +```javascript + +``` ## Example Apps Please first see general instructions for -[Android](../getting_started/building_examples.md#android), [iOS](../getting_started/building_examples.md#ios) -and [desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe -examples. +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md) and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how @@ -146,34 +381,11 @@ to visualize its associated subgraphs, please see * iOS target: [`mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/handtrackinggpu/BUILD) -#### With Multi-hand Support - -* Graph: - [`mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt) -* Android target: - [(or download prebuilt ARM64 APK)](https://drive.google.com/open?id=1Wk6V9EVaz1ks_MInPqqVGvvJD01SGXDc) - [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu:multihandtrackinggpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/BUILD) -* iOS target: - [`mediapipe/examples/ios/multihandtrackinggpu:MultiHandTrackingGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/multihandtrackinggpu/BUILD) - -There are two key differences between this graph and that in the -[main example](#main-example) (which handles only one hand): - -1. There is a `NormalizedRectVectorHasMinSize` calculator, that checks if in - input vector of `NormalizedRect` objects has a minimum size equal to `N`. In - this graph, if the vector contains fewer than `N` objects, - `MultiHandDetection` subgraph runs. Otherwise, the `GateCalculator` doesn't - send any image packets to the `MultiHandDetection` subgraph. This way, the - main graph is efficient in that it avoids running the costly hand detection - step when there are already `N` hands in the frame. -2. The `MergeCalculator` has been replaced by the `AssociationNormRect` - calculator. This `AssociationNormRect` takes as input a vector of - `NormalizedRect` objects from the `MultiHandDetection` subgraph on the - current frame, and a vector of `NormalizedRect` objects from the - `MultiHandLandmark` subgraph from the previous frame, and performs an - association operation between these objects. This calculator ensures that - the output vector doesn't contain overlapping regions based on the specified - `min_similarity_threshold`. +Tip: Maximum number of hands to detect/process is set to 2 by default. To change +it, for Android modify `NUM_HANDS` in +[MainActivity.java](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java), +and for iOS modify `kNumHands` in +[HandTrackingViewController.mm](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/handtrackinggpu/HandTrackingViewController.mm). 
#### Palm/Hand Detection Only (no landmarks) @@ -187,8 +399,6 @@ There are two key differences between this graph and that in the ### Desktop -#### Main Example - * Running on CPU * Graph: [`mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt) @@ -196,26 +406,12 @@ There are two key differences between this graph and that in the [`mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/hand_tracking/BUILD) * Running on GPU * Graph: - [`mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt) + [`mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_tracking_desktop_gpu.pbtxt) * Target: [`mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/hand_tracking/BUILD) -#### With Multi-hand Support - -* Running on CPU - * Graph: - [`mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop_live.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop_live) - * Target: - [`mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/multi_hand_tracking/BUILD) -* Running on GPU - * Graph: - [`mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt) - * Target: - [`mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/multi_hand_tracking/BUILD) - -### Web - -Please refer to [these instructions](../index.md#mediapipe-on-the-web). +Tip: Maximum number of hands to detect/process is set to 2 by default. To change +it, in the graph file modify the option of `ConstantSidePacketCalculator`. ## Resources @@ -226,10 +422,7 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web). * Paper: [MediaPipe Hands: On-device Real-time Hand Tracking](https://arxiv.org/abs/2006.10214) ([presentation](https://www.youtube.com/watch?v=I-UOrvxxXEk)) -* Palm detection model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite), - [TF.js model](https://tfhub.dev/mediapipe/handdetector/1) -* Hand landmark model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/hand_landmark.tflite), - [TF.js model](https://tfhub.dev/mediapipe/handskeleton/1) -* [Model card](https://mediapipe.page.link/handmc) +* [Models and model cards](./models.md#hands) +* [Web demo](https://code.mediapipe.dev/codepen/hands) +* [Fun application](https://code.mediapipe.dev/codepen/defrost) +* [Python Colab](https://mediapipe.page.link/hands_py_colab) diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md new file mode 100644 index 000000000..8ee0f8ff6 --- /dev/null +++ b/docs/solutions/holistic.md @@ -0,0 +1,412 @@ +--- +layout: default +title: Holistic +parent: Solutions +nav_order: 6 +--- + +# MediaPipe Holistic +{: .no_toc } + +
+ + Table of contents + + {: .text-delta } +1. TOC +{:toc} +
+--- + +## Overview + +Live perception of simultaneous [human pose](./pose.md), +[face landmarks](./face_mesh.md), and [hand tracking](./hands.md) in real-time +on mobile devices can enable various modern life applications: fitness and sport +analysis, gesture control and sign language recognition, augmented reality +try-on and effects. MediaPipe already offers fast and accurate, yet separate, +solutions for these tasks. Combining them all in real-time into a semantically +consistent end-to-end solution is a uniquely difficult problem requiring +simultaneous inference of multiple, dependent neural networks. + +![holistic_sports_and_gestures_example.gif](../images/mobile/holistic_sports_and_gestures_example.gif) | +:----------------------------------------------------------------------------------------------------: | +*Fig 1. Example of MediaPipe Holistic.* | + +## ML Pipeline + +The MediaPipe Holistic pipeline integrates separate models for +[pose](./pose.md), [face](./face_mesh.md) and [hand](./hands.md) components, +each of which are optimized for their particular domain. However, because of +their different specializations, the input to one component is not well-suited +for the others. The pose estimation model, for example, takes a lower, fixed +resolution video frame (256x256) as input. But if one were to crop the hand and +face regions from that image to pass to their respective models, the image +resolution would be too low for accurate articulation. Therefore, we designed +MediaPipe Holistic as a multi-stage pipeline, which treats the different regions +using a region appropriate image resolution. + +First, we estimate the human pose (top of Fig 2) with [BlazePose](./pose.md)’s +pose detector and subsequent landmark model. Then, using the inferred pose +landmarks we derive three regions of interest (ROI) crops for each hand (2x) and +the face, and employ a re-crop model to improve the ROI. We then crop the +full-resolution input frame to these ROIs and apply task-specific face and hand +models to estimate their corresponding landmarks. Finally, we merge all +landmarks with those of the pose model to yield the full 540+ landmarks. + +![holistic_pipeline_example.jpg](../images/mobile/holistic_pipeline_example.jpg) | +:------------------------------------------------------------------------------: | +*Fig 2. MediaPipe Holistic Pipeline Overview.* | + +To streamline the identification of ROIs for face and hands, we utilize a +tracking approach similar to the one we use for standalone +[face](./face_mesh.md) and [hand](./hands.md) pipelines. It assumes that the +object doesn't move significantly between frames and uses estimation from the +previous frame as a guide to the object region on the current one. However, +during fast movements, the tracker can lose the target, which requires the +detector to re-localize it in the image. MediaPipe Holistic uses +[pose](./pose.md) prediction (on every frame) as an additional ROI prior to +reduce the response time of the pipeline when reacting to fast movements. This +also enables the model to retain semantic consistency across the body and its +parts by preventing a mixup between left and right hands or body parts of one +person in the frame with another. + +In addition, the resolution of the input frame to the pose model is low enough +that the resulting ROIs for face and hands are still too inaccurate to guide the +re-cropping of those regions, which require a precise input crop to remain +lightweight. 
To close this accuracy gap, we use lightweight face and hand re-crop
+models that play the role of
+[spatial transformers](https://arxiv.org/abs/1506.02025) and cost only ~10% of
+the corresponding model's inference time.
+
+The pipeline is implemented as a MediaPipe
+[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt)
+that uses a
+[holistic landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt)
+from the
+[holistic landmark module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/holistic_landmark)
+and renders using a dedicated
+[holistic renderer subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/holistic_tracking/holistic_tracking_to_render_data.pbtxt).
+The
+[holistic landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt)
+internally uses a
+[pose landmark module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark),
+a
+[hand landmark module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark)
+and a
+[face landmark module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/).
+Please check them for implementation details.
+
+Note: To visualize a graph, copy the graph and paste it into
+[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
+to visualize its associated subgraphs, please see
+[visualizer documentation](../tools/visualizer.md).
+
+## Models
+
+### Landmark Models
+
+MediaPipe Holistic utilizes the pose, face and hand landmark models in
+[MediaPipe Pose](./pose.md), [MediaPipe Face Mesh](./face_mesh.md) and
+[MediaPipe Hands](./hands.md) respectively to generate a total of 543 landmarks
+(33 pose landmarks, 468 face landmarks, and 21 hand landmarks per hand).
+
+### Hand Recrop Model
+
+For cases when the accuracy of the pose model is low enough that the resulting
+ROIs for hands are still too inaccurate, we run an additional lightweight hand
+re-crop model that plays the role of a
+[spatial transformer](https://arxiv.org/abs/1506.02025) and costs only ~10% of
+the hand model's inference time.
+
+## Solution APIs
+
+### Cross-platform Configuration Options
+
+Naming style and availability may differ slightly across platforms/languages.
+
+#### static_image_mode
+
+If set to `false`, the solution treats the input images as a video stream. It
+will try to detect the most prominent person in the very first images, and upon
+a successful detection further localizes the pose and other landmarks. In
+subsequent images, it then simply tracks those landmarks without invoking
+another detection until it loses track, thus reducing computation and latency.
+If set to `true`, person detection runs on every input image, ideal for
+processing a batch of static, possibly unrelated, images. Default to `false`.
+
+#### upper_body_only
+
+If set to `true`, the solution outputs only the 25 upper-body pose landmarks
+(535 in total) instead of the full set of 33 pose landmarks (543 in total). Note
+that upper-body-only prediction may be more accurate for use cases where the
+lower-body parts are mostly out of view. Default to `false`.
+
+#### smooth_landmarks
+
+If set to `true`, the solution filters pose landmarks across different input
+images to reduce jitter, but this is ignored if [static_image_mode](#static_image_mode)
+is also set to `true`.
Default to `true`. + +#### min_detection_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the person-detection model for the +detection to be considered successful. Default to `0.5`. + +#### min_tracking_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the landmark-tracking model for the +pose landmarks to be considered tracked successfully, or otherwise person +detection will be invoked automatically on the next input image. Setting it to a +higher value can increase robustness of the solution, at the expense of a higher +latency. Ignored if [static_image_mode](#static_image_mode) is `true`, where +person detection simply runs on every image. Default to `0.5`. + +### Output + +Naming style may differ slightly across platforms/languages. + +#### pose_landmarks + +A list of pose landmarks. Each landmark consists of the following: + +* `x` and `y`: Landmark coordinates normalized to `[0.0, 1.0]` by the image + width and height respectively. +* `z`: Should be discarded as currently the model is not fully trained to + predict depth, but this is something on the roadmap. +* `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the + landmark being visible (present and not occluded) in the image. + +#### face_landmarks + +A list of 468 face landmarks. Each landmark consists of `x`, `y` and `z`. `x` +and `y` are normalized to `[0.0, 1.0]` by the image width and height +respectively. `z` represents the landmark depth with the depth at center of the +head being the origin, and the smaller the value the closer the landmark is to +the camera. The magnitude of `z` uses roughly the same scale as `x`. + +#### left_hand_landmarks + +A list of 21 hand landmarks on the left hand. Each landmark consists of `x`, `y` +and `z`. `x` and `y` are normalized to `[0.0, 1.0]` by the image width and +height respectively. `z` represents the landmark depth with the depth at the +wrist being the origin, and the smaller the value the closer the landmark is to +the camera. The magnitude of `z` uses roughly the same scale as `x`. + +#### right_hand_landmarks + +A list of 21 hand landmarks on the right hand, in the same representation as +[left_hand_landmarks](#left_hand_landmarks). + +### Python Solution API + +Please first follow general [instructions](../getting_started/python.md) to +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. + +Supported configuration options: + +* [static_image_mode](#static_image_mode) +* [upper_body_only](#upper_body_only) +* [smooth_landmarks](#smooth_landmarks) +* [min_detection_confidence](#min_detection_confidence) +* [min_tracking_confidence](#min_tracking_confidence) + +```python +import cv2 +import mediapipe as mp +mp_drawing = mp.solutions.drawing_utils +mp_holistic = mp.solutions.holistic + +# For static images: +with mp_holistic.Holistic(static_image_mode=True) as holistic: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + image_height, image_width, _ = image.shape + # Convert the BGR image to RGB before processing. + results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + if results.pose_landmarks: + print( + f'Nose coordinates: (' + f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' + f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' + ) + # Draw pose, left and right hands, and face landmarks on the image. 
+ annotated_image = image.copy() + mp_drawing.draw_landmarks( + annotated_image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) + mp_drawing.draw_landmarks( + annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) + mp_drawing.draw_landmarks( + annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) + # Use mp_holistic.UPPER_BODY_POSE_CONNECTIONS for drawing below when + # upper_body_only is set to True. + mp_drawing.draw_landmarks( + annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + +# For webcam input: +cap = cv2.VideoCapture(0) +with mp_holistic.Holistic( + min_detection_confidence=0.5, + min_tracking_confidence=0.5) as holistic: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = holistic.process(image) + + # Draw landmark annotation on the image. + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + mp_drawing.draw_landmarks( + image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) + mp_drawing.draw_landmarks( + image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) + mp_drawing.draw_landmarks( + image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) + mp_drawing.draw_landmarks( + image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) + cv2.imshow('MediaPipe Holistic', image) + if cv2.waitKey(5) & 0xFF == 27: + break +cap.release() +``` + +### JavaScript Solution API + +Please first see general [introduction](../getting_started/javascript.md) on +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. + +Supported configuration options: + +* [upperBodyOnly](#upper_body_only) +* [smoothLandmarks](#smooth_landmarks) +* [minDetectionConfidence](#min_detection_confidence) +* [minTrackingConfidence](#min_tracking_confidence) + +```html + + + + + + + + + + + +
+ + +
+ + +``` + +```javascript + +``` + +## Example Apps + +Please first see general instructions for +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md), and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. + +Note: To visualize a graph, copy the graph and paste it into +[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how +to visualize its associated subgraphs, please see +[visualizer documentation](../tools/visualizer.md). + +### Mobile + +* Graph: + [`mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt) +* Android target: + [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1o-Trp2GIRitA0OvmZWUQjVMa476xpfgK/view?usp=sharing) + [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu:holistictrackinggpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu/BUILD) +* iOS target: + [`mediapipe/examples/ios/holistictrackinggpu:HolisticTrackingGpuApp`](http:/mediapipe/examples/ios/holistictrackinggpu/BUILD) + +### Desktop + +Please first see general instructions for [desktop](../getting_started/cpp.md) +on how to build MediaPipe examples. + +* Running on CPU + * Graph: + [`mediapipe/graphs/holistic_tracking/holistic_tracking_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/holistic_tracking/holistic_tracking_cpu.pbtxt) + * Target: + [`mediapipe/examples/desktop/holistic_tracking:holistic_tracking_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/holistic_tracking/BUILD) +* Running on GPU + * Graph: + [`mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/holistic_tracking/holistic_tracking_gpu.pbtxt) + * Target: + [`mediapipe/examples/desktop/holistic_tracking:holistic_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/holistic_tracking/BUILD) + +## Resources + +* Google AI Blog: + [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) +* [Models and model cards](./models.md#holistic) +* [Web demo](https://code.mediapipe.dev/codepen/holistic) +* [Python Colab](https://mediapipe.page.link/holistic_py_colab) diff --git a/docs/solutions/instant_motion_tracking.md b/docs/solutions/instant_motion_tracking.md index cf23a7b8c..36e5e83e0 100644 --- a/docs/solutions/instant_motion_tracking.md +++ b/docs/solutions/instant_motion_tracking.md @@ -2,14 +2,20 @@ layout: default title: Instant Motion Tracking parent: Solutions -nav_order: 9 +nav_order: 10 --- # MediaPipe Instant Motion Tracking {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview @@ -104,19 +110,37 @@ and connected camera. ## Example Apps Please first see general instructions for -[Android](../getting_started/building_examples.md#android) on how to build -MediaPipe examples. +[Android](../getting_started/android.md) on how to build MediaPipe examples. * Graph: [mediapipe/graphs/instant_motion_tracking/instant_motion_tracking.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/instant_motion_tracking.pbtxt) * Android target (or download prebuilt [ARM64 APK](https://drive.google.com/file/d/1KnaBBoKpCHR73nOBJ4fL_YdWVTAcwe6L/view?usp=sharing)): [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking:instantmotiontracking`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/BUILD) +* Assets rendered by the [GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc) must be preprocessed into an OpenGL-ready custom .uuu format. This can be done +for user assets as follows: +> First run +> +> ```shell +> ./mediapipe/graphs/object_detection_3d/obj_parser/obj_cleanup.sh [INPUT_DIR] [INTERMEDIATE_OUTPUT_DIR] +> ``` +> and then run +> +> ```build +> bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR] +> ``` +> INPUT_DIR should be the folder with initial asset .obj files to be processed, +> and OUTPUT_DIR is the folder where the processed asset .uuu file will be placed. +> +> Note: ObjParser combines all .obj files found in the given directory into a +> single .uuu animation file, using the order given by sorting the filenames alphanumerically. Also the ObjParser directory inputs must be given as +> absolute paths, not relative paths. See parser utility library at [`mediapipe/graphs/object_detection_3d/obj_parser/`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/obj_parser/) for more details. + ## Resources -* Google Developers Blog: - [Instant Motion Tracking With MediaPipe](https://mediapipe.page.link/instant-motion-tracking-blog) -* Google AI Blog: - [The Instant Motion Tracking Behind Motion Stills AR](https://ai.googleblog.com/2018/02/the-instant-motion-tracking-behind.html) -* Paper: - [Instant Motion Tracking and Its Applications to Augmented Reality](https://arxiv.org/abs/1907.06796) +* Google Developers Blog: + [Instant Motion Tracking With MediaPipe](https://developers.googleblog.com/2020/08/instant-motion-tracking-with-mediapipe.html) +* Google AI Blog: + [The Instant Motion Tracking Behind Motion Stills AR](https://ai.googleblog.com/2018/02/the-instant-motion-tracking-behind.html) +* Paper: + [Instant Motion Tracking and Its Applications to Augmented Reality](https://arxiv.org/abs/1907.06796) diff --git a/docs/solutions/iris.md b/docs/solutions/iris.md index 8bf207402..61ca8049c 100644 --- a/docs/solutions/iris.md +++ b/docs/solutions/iris.md @@ -8,8 +8,14 @@ nav_order: 3 # MediaPipe Iris {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview @@ -116,10 +122,8 @@ along with some simple geometric arguments. For more details please refer to our ## Example Apps Please first see general instructions for -[Android](../getting_started/building_examples.md#android), -[iOS](../getting_started/building_examples.md#ios) and -[desktop](../getting_started/building_examples.md#desktop) on how to build -MediaPipe examples. +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md) and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how @@ -140,9 +144,8 @@ to visualize its associated subgraphs, please see #### Live Camera Input -Please first see general instructions for -[desktop](../getting_started/building_examples.md#desktop) on how to build -MediaPipe examples. +Please first see general instructions for [desktop](../getting_started/cpp.md) +on how to build MediaPipe examples. * Running on CPU * Graph: @@ -199,11 +202,4 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web). * Paper: [Real-time Pupil Tracking from Monocular Video for Digital Puppetry](https://arxiv.org/abs/2006.11341) ([presentation](https://youtu.be/cIhXkiiapQI)) -* Face detection model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite) -* Face landmark model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark.tflite), - [TF.js model](https://tfhub.dev/mediapipe/facemesh/1) -* Iris landmark model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/iris_landmark/iris_landmark.tflite) -* [Model card](https://mediapipe.page.link/iris-mc) +* [Models and model cards](./models.md#iris) diff --git a/docs/solutions/knift.md b/docs/solutions/knift.md index 8e4ed98b0..41691c418 100644 --- a/docs/solutions/knift.md +++ b/docs/solutions/knift.md @@ -2,14 +2,20 @@ layout: default title: KNIFT (Template-based Feature Matching) parent: Solutions -nav_order: 11 +nav_order: 12 --- # MediaPipe KNIFT {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview @@ -67,7 +73,7 @@ you'd like to use your own template images, see ![template_matching_mobile_template.jpg](../images/mobile/template_matching_mobile_template.jpg) Please first see general instructions for -[Android](../getting_started/building_examples.md#android) on how to build MediaPipe examples. +[Android](../getting_started/android.md) on how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how @@ -139,7 +145,4 @@ to run regular TFLite inference. * Google Developers Blog: [MediaPipe KNIFT: Template-based feature matching](https://developers.googleblog.com/2020/04/mediapipe-knift-template-based-feature-matching.html) -* [TFLite model for up to 200 keypoints](https://github.com/google/mediapipe/tree/master/mediapipe/models/knift_float.tflite) -* [TFLite model for up to 400 keypoints](https://github.com/google/mediapipe/tree/master/mediapipe/models/knift_float_400.tflite) -* [TFLite model for up to 1000 keypoints](https://github.com/google/mediapipe/tree/master/mediapipe/models/knift_float_1k.tflite) -* [Model card](https://mediapipe.page.link/knift-mc) +* [Models and model cards](./models.md#knift) diff --git a/docs/solutions/media_sequence.md b/docs/solutions/media_sequence.md index 16a2278cd..cd3b7ecef 100644 --- a/docs/solutions/media_sequence.md +++ b/docs/solutions/media_sequence.md @@ -2,14 +2,20 @@ layout: default title: Dataset Preparation with MediaSequence parent: Solutions -nav_order: 13 +nav_order: 14 --- # Dataset Preparation with MediaSequence {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview diff --git a/docs/solutions/models.md b/docs/solutions/models.md new file mode 100644 index 000000000..b0f1fad7a --- /dev/null +++ b/docs/solutions/models.md @@ -0,0 +1,90 @@ +--- +layout: default +title: Models and Model Cards +parent: Solutions +nav_order: 30 +--- + +# MediaPipe Models and Model Cards +{: .no_toc } + +1. TOC +{:toc} +--- + +### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection) + +* Face detection model for front-facing/selfie camera: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite), + [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite) +* Face detection model for back-facing camera: + [TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_back.tflite) +* [Model card](https://mediapipe.page.link/blazeface-mc) + +### [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) + +* Face landmark model: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark.tflite), + [TF.js model](https://tfhub.dev/mediapipe/facemesh/1) +* [Model card](https://mediapipe.page.link/facemesh-mc) + +### [Iris](https://google.github.io/mediapipe/solutions/iris) + +* Iris landmark model: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/iris_landmark/iris_landmark.tflite) +* [Model card](https://mediapipe.page.link/iris-mc) + +### [Hands](https://google.github.io/mediapipe/solutions/hands) + +* Palm detection model: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite), + [TF.js model](https://tfhub.dev/mediapipe/handdetector/1) +* Hand landmark model: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite), + [TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite), + [TF.js model](https://tfhub.dev/mediapipe/handskeleton/1) +* [Model card](https://mediapipe.page.link/handmc), [Model card (sparse)](https://mediapipe.page.link/handmc-sparse) + +### [Pose](https://google.github.io/mediapipe/solutions/pose) + +* Pose detection model: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection.tflite) +* Full-body pose landmark model: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite) +* Upper-body pose landmark model: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite) +* [Model card](https://mediapipe.page.link/blazepose-mc) + +### [Holistic](https://google.github.io/mediapipe/solutions/holistic) + +* Hand recrop model: + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/holistic_landmark/hand_recrop.tflite) + +### [Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) + +* [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/hair_segmentation.tflite) +* [Model card](https://mediapipe.page.link/hairsegmentation-mc) + +### [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) + +* [TFLite 
model](https://github.com/google/mediapipe/tree/master/mediapipe/models/ssdlite_object_detection.tflite) +* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite) +* [TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model) +* [Model information](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md) + +### [Objectron](https://google.github.io/mediapipe/solutions/objectron) + +* [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_sneakers.tflite) +* [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_chair.tflite) +* [TFLite model for cameras](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_camera.tflite) +* [TFLite model for cups](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_cup.tflite) +* [Single-stage TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite) +* [Single-stage TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite) +* [Model card](https://mediapipe.page.link/objectron-mc) + +### [KNIFT](https://google.github.io/mediapipe/solutions/knift) + +* [TFLite model for up to 200 keypoints](https://github.com/google/mediapipe/tree/master/mediapipe/models/knift_float.tflite) +* [TFLite model for up to 400 keypoints](https://github.com/google/mediapipe/tree/master/mediapipe/models/knift_float_400.tflite) +* [TFLite model for up to 1000 keypoints](https://github.com/google/mediapipe/tree/master/mediapipe/models/knift_float_1k.tflite) +* [Model card](https://mediapipe.page.link/knift-mc) diff --git a/docs/solutions/object_detection.md b/docs/solutions/object_detection.md index 1cb353d0e..044748537 100644 --- a/docs/solutions/object_detection.md +++ b/docs/solutions/object_detection.md @@ -2,14 +2,20 @@ layout: default title: Object Detection parent: Solutions -nav_order: 7 +nav_order: 8 --- # MediaPipe Object Detection {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ![object_detection_android_gpu.gif](../images/mobile/object_detection_android_gpu.gif) @@ -24,8 +30,8 @@ to visualize its associated subgraphs, please see ### Mobile Please first see general instructions for -[Android](../getting_started/building_examples.md#android) and -[iOS](../getting_started/building_examples.md#ios) on how to build MediaPipe examples. +[Android](../getting_started/android.md) and [iOS](../getting_started/ios.md) on +how to build MediaPipe examples. #### GPU Pipeline @@ -56,8 +62,8 @@ same configuration as the GPU pipeline, runs entirely on CPU. #### Live Camera Input -Please first see general instructions for -[desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe examples. +Please first see general instructions for [desktop](../getting_started/cpp.md) +on how to build MediaPipe examples. * Graph: [`mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt) @@ -144,7 +150,4 @@ to cross-compile and run MediaPipe examples on the ## Resources -* [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/ssdlite_object_detection.tflite) -* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite) -* [TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model) -* [Model information](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md) +* [Models and model cards](./models.md#object_detection) diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index 4c18f9f0f..c689f9c40 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -2,26 +2,31 @@ layout: default title: Objectron (3D Object Detection) parent: Solutions -nav_order: 10 +nav_order: 11 --- # MediaPipe Objectron {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview MediaPipe Objectron is a mobile real-time 3D object detection solution for -everyday objects. It detects objects in 2D images, and estimates their poses and -sizes through a machine learning (ML) model, trained on a newly created 3D -dataset. +everyday objects. It detects objects in 2D images, and estimates their poses +through a machine learning (ML) model, trained on the [Objectron dataset](https://github.com/google-research-datasets/Objectron). -![objectron_shoe_android_gpu.gif](../images/mobile/objectron_shoe_android_gpu.gif) | ![objectron_chair_android_gpu.gif](../images/mobile/objectron_chair_android_gpu.gif) -:--------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------: -*Fig 1(a). Objectron for Shoes.* | *Fig 1(b). Objectron for Chairs.* +![objectron_shoe_android_gpu.gif](../images/mobile/objectron_shoe_android_gpu.gif) | ![objectron_chair_android_gpu.gif](../images/mobile/objectron_chair_android_gpu.gif) | ![objectron_camera_android_gpu.gif](../images/mobile/objectron_camera_android_gpu.gif) | ![objectron_cup_android_gpu.gif](../images/mobile/objectron_cup_android_gpu.gif) +:--------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------: +*Fig 1a. Shoe Objectron* | *Fig 1b. Chair Objectron* | *Fig 1c. Camera Objectron* | *Fig 1d. Cup Objectron* Object detection is an extensively studied computer vision problem, but most of the research has focused on @@ -85,15 +90,42 @@ able to increase the accuracy by about 10%. :-------------------------------------------------------------------------------------------: | *Fig 4. An example of AR synthetic data generation. The virtual white-brown cereal box is rendered into the real scene, next to the real blue book.* | -## ML Model for 3D Object Detection +## ML Pipelines for 3D Object Detection + +We built two ML pipelines to predict the 3D bounding box of an object from a +single RGB image: one is a two-stage pipeline and the other is a single-stage +pipeline. The two-stage pipeline is 3x faster than the single-stage pipeline +with similar or better accuracy. The single stage pipeline is good at detecting +multiple objects, whereas the two stage pipeline is good for a single dominant +object. + +### Two-stage Pipeline + +Our two-stage pipeline is illustrated by the diagram in Fig 5. The first stage +uses an object detector to find the 2D crop of the object. The second stage +takes the image crop and estimates the 3D bounding box. At the same time, it +also computes the 2D crop of the object for the next frame, such that the object +detector does not need to run every frame. + +![objectron_network_architecture.png](../images/objectron_2stage_network_architecture.png) | +:----------------------------------------------------------------------------------------: | +*Fig 5. Network architecture and post-processing for two-stage 3D object detection.* | + +We can use any 2D object detector for the first stage. In this solution, we use +[TensorFlow Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection) trained +with the [Open Images dataset](https://storage.googleapis.com/openimages/web/index.html). 
+The second stage 3D bounding box predictor we released runs 83FPS on Adreno 650 +mobile GPU. + +### Single-stage Pipeline ![objectron_network_architecture.png](../images/objectron_network_architecture.png) | :---------------------------------------------------------------------------------: | -*Fig 5. Network architecture and post-processing for 3D object detection.* | +*Fig 6. Network architecture and post-processing for single-stage 3D object detection.* | -We [built a single-stage model](https://arxiv.org/abs/2003.03522) to predict the -pose and physical size of an object from a single RGB image. The model backbone -has an encoder-decoder architecture, built upon +Our [single-stage pipeline](https://arxiv.org/abs/2003.03522) is illustrated by +the diagram in Fig 6, the model backbone has an encoder-decoder architecture, +built upon [MobileNetv2](https://ai.googleblog.com/2018/04/mobilenetv2-next-generation-of-on.html). We employ a multi-task learning approach, jointly predicting an object's shape with detection and regression. The shape task predicts the object's shape @@ -114,9 +146,9 @@ size of the object. The model is light enough to run real-time on mobile devices ![objectron_sample_network_results.png](../images/objectron_sample_network_results.png) | :-------------------------------------------------------------------------------------: | -*Fig 6. Sample results of our network — (Left) original 2D image with estimated bounding boxes, (Middle) object detection by Gaussian distribution, (Right) predicted segmentation mask.* | +*Fig 7. Sample results of our network — (Left) original 2D image with estimated bounding boxes, (Middle) object detection by Gaussian distribution, (Right) predicted segmentation mask.* | -## Detection and Tracking Pipeline +#### Detection and Tracking When the model is applied to every frame captured by the mobile device, it can suffer from jitter due to the ambiguity of the 3D bounding box estimated in each @@ -130,11 +162,11 @@ temporally consistent, reducing the jitter. The Objectron 3D object detection and tracking pipeline is implemented as a MediaPipe -[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/shoe_classic_occlusion_tracking.pbtxt), +[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking_1stage.pbtxt), which internally uses a -[detection subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/subgraphs/objectron_detection_gpu.pbtxt) +[detection subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/objectron_detection_1stage_gpu.pbtxt) and a -[tracking subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/subgraphs/objectron_tracking_gpu.pbtxt). +[tracking subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/objectron_tracking_1stage_gpu.pbtxt). The detection subgraph performs ML inference only once every few frames to reduce computation load, and decodes the output tensor to a FrameAnnotation that contains nine keypoints: the 3D bounding box's center and its eight vertices. @@ -147,43 +179,357 @@ new detection becomes available from the detection subgraph, the tracking subgraph is also responsible for consolidation between the detection and tracking results, based on the area of overlap. 
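To make the overlap-based consolidation step concrete, here is a minimal sketch of how a newly detected box could be reconciled with already-tracked boxes. It is only an illustration, not the calculator MediaPipe ships: the box representation (normalized `(x_min, y_min, x_max, y_max)` tuples), the IoU overlap measure and the `0.5` threshold are all assumptions made for the example.

```python
# Illustrative sketch: consolidate a new detection with tracked boxes by their
# area of overlap (IoU). Boxes are hypothetical (x_min, y_min, x_max, y_max)
# tuples in normalized image coordinates; the threshold is an assumed value.

def iou(box_a, box_b):
    """Intersection-over-union of two axis-aligned 2D boxes."""
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0.0 else 0.0

def consolidate(tracked_boxes, new_detection, iou_threshold=0.5):
    """Refresh an existing track if the detection overlaps it enough,
    otherwise start tracking the detection as a new object."""
    for i, tracked in enumerate(tracked_boxes):
        if iou(tracked, new_detection) >= iou_threshold:
            tracked_boxes[i] = new_detection  # same object: refresh the track
            return tracked_boxes
    return tracked_boxes + [new_detection]    # previously unseen object

print(consolidate([(0.1, 0.1, 0.4, 0.4)], (0.12, 0.08, 0.42, 0.38)))
```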
+## Objectron Dataset + +We also released our [Objectron dataset](http://objectron.dev), with which we +trained our 3D object detection models. The technical details of the Objectron +dataset, including usage and tutorials, are available on +the [dataset website](https://github.com/google-research-datasets/Objectron/). + +## Solution APIs + +### Cross-platform Configuration Options + +Naming style and availability may differ slightly across platforms/languages. + +#### static_image_mode + +If set to `false`, the solution treats the input images as a video stream. It +will try to detect objects in the very first images, and upon successful +detection further localizes the 3D bounding box landmarks. In subsequent images, +once all [max_num_objects](#max_num_objects) objects are detected and the +corresponding 3D bounding box landmarks are localized, it simply tracks those +landmarks without invoking another detection until it loses track of any of the +objects. This reduces latency and is ideal for processing video frames. If set +to `true`, object detection runs every input image, ideal for processing a batch +of static, possibly unrelated, images. Default to `false`. + +#### max_num_objects + +Maximum number of objects to detect. Default to `5`. + +#### min_detection_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the object-detection model for the +detection to be considered successful. Default to `0.5`. + +#### min_tracking_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the landmark-tracking model for the +3D bounding box landmarks to be considered tracked successfully, or otherwise +object detection will be invoked automatically on the next input image. Setting +it to a higher value can increase robustness of the solution, at the expense of +a higher latency. Ignored if [static_image_mode](#static_image_mode) is `true`, +where object detection simply runs on every image. Default to `0.99`. + +#### model_name + +Name of the model to use for predicting 3D bounding box landmarks. Currently supports +`{'Shoe', 'Chair', 'Cup', 'Camera'}`. + +#### focal_length + +Camera focal length `(fx, fy)`, by default is defined in +[NDC space](#ndc-space). To use focal length `(fx_pixel, fy_pixel)` in +[pixel space](#pixel-space), users should provide `image_size` = `(image_width, +image_height)` to enable conversions inside the API. For further details about +NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). + +#### principal_point + +Camera principal point `(px, py)`, by default is defined in +[NDC space](#ndc-space). To use principal point `(px_pixel, py_pixel)` in +[pixel space](#pixel-space), users should provide `image_size` = `(image_width, +image_height)` to enable conversions inside the API. For further details about +NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). + +#### image_size + +(**Optional**) size `(image_width, image_height)` of the input image, **ONLY** +needed when use `focal_length` and `principal_point` in pixel space. + +### Output + + + +#### detected_objects + +A list of detected 3D bounding box. Each 3D bounding box consists of the +following: + +* `landmarks_2d` : 2D landmarks of the object's 3D bounding box. The landmark + coordinates are normalized to `[0.0, 1.0]` by the image width and height + respectively. + +* `landmarks_3d` : 3D landmarks of the object's 3D bounding box. The landmark + coordinates are represented in [camera coordinate](#camera-coordinate) + frame. 
+ +* `rotation` : rotation matrix from object coordinate frame to camera + coordinate frame. + +* `translation` : translation vector from object coordinate frame to camera + coordinate frame. + +* `scale` : relative scale of the object along `x`, `y` and `z` directions. + +## Python Solution API + +Please first follow general [instructions](../getting_started/python.md) to +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. + +Supported configuration options: + +* [static_image_mode](#static_image_mode) +* [max_num_objects](#max_num_objects) +* [min_detection_confidence](#min_detection_confidence) +* [min_tracking_confidence](#min_tracking_confidence) +* [model_name](#model_name) +* [focal_length](#focal_length) +* [principal_point](#principal_point) +* [image_size](#image_size) + +```python +import cv2 +import mediapipe as mp +mp_drawing = mp.solutions.drawing_utils +mp_objectron = mp.solutions.objectron + +# For static images: +with mp_objectron.Objectron(static_image_mode=True, + max_num_objects=5, + min_detection_confidence=0.5, + model_name='Shoe') as objectron: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + # Convert the BGR image to RGB and process it with MediaPipe Objectron. + results = objectron.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Draw box landmarks. + if not results.detected_objects: + print(f'No box landmarks detected on {file}') + continue + print(f'Box landmarks of {file}:') + annotated_image = image.copy() + for detected_object in results.detected_objects: + mp_drawing.draw_landmarks( + annotated_image, detected_object.landmarks_2d, mp_objectron.BOX_CONNECTIONS) + mp_drawing.draw_axis(annotated_image, detected_object.rotation, + detected_object.translation) + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + +# For webcam input: +cap = cv2.VideoCapture(0) +with mp_objectron.Objectron(static_image_mode=False, + max_num_objects=5, + min_detection_confidence=0.5, + min_tracking_confidence=0.99, + model_name='Shoe') as objectron: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Convert the BGR image to RGB. + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = objectron.process(image) + + # Draw the box landmarks on the image. + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if results.detected_objects: + for detected_object in results.detected_objects: + mp_drawing.draw_landmarks( + image, detected_object.landmarks_2d, mp_objectron.BOX_CONNECTIONS) + mp_drawing.draw_axis(image, detected_object.rotation, + detected_object.translation) + cv2.imshow('MediaPipe Objectron', image) + if cv2.waitKey(5) & 0xFF == 27: + break +cap.release() +``` + ## Example Apps Please first see general instructions for -[Android](../getting_started/building_examples.md#android) and -[iOS](../getting_started/building_examples.md#ios) on how to build MediaPipe examples. +[Android](../getting_started/android.md) and [iOS](../getting_started/ios.md) on +how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). 
For more information on how to visualize its associated subgraphs, please see [visualizer documentation](../tools/visualizer.md). -### Objectron for Shoes +### Two-stage Objectron * Graph: - [`mediapipe/graphs/object_detection_3d/shoe_classic_occlusion_tracking.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/shoe_classic_occlusion_tracking.pbtxt) -* Android target: - [(or download prebuilt ARM64 APK)](https://drive.google.com/open?id=1S0K4hbWt3o31FfQ4QU3Rz7IHrvOUMx1d) - [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD) -* iOS target: Not available + [`mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt) -### Objectron for Chairs - -* Graph: - [`mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/chair_classic_occlusion_tracking.pbtxt) * Android target: - [(or download prebuilt ARM64 APK)](https://drive.google.com/open?id=1MM8K-13bXLCVS1EHQ-KgkVyEahEPrKej) - [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD) - and add `--define chair=true` to the build command, i.e., + [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD). 
+ + Build for **shoes** (default) with: + [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1ANW9WDOCb8QO1r8gDC03A4UgrPkICdPP/view?usp=sharing) + + ```bash + bazel build -c opt --config android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d + ``` + + Build for **chairs** with: + [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1lcUv1TBnv_SxnKSQwdOqbdLa9mkaTJHy/view?usp=sharing) ```bash bazel build -c opt --config android_arm64 --define chair=true mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d ``` + Build for **cups** with: + [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1bf77KDkowwrduleiC9B1M1XnEhjnOQbX/view?usp=sharing) + + ```bash + bazel build -c opt --config android_arm64 --define cup=true mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d + ``` + + Build for **cameras** with: + [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1GM7lPO-s5URVxIzQur1bLsionEJs3yIl/view?usp=sharing) + + ```bash + bazel build -c opt --config android_arm64 --define camera=true mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d + ``` + * iOS target: Not available +### Single-stage Objectron + +* Graph: + [`mediapipe/graphs/object_detection_3d/object_occlusion_tracking_1stage.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt) + +* Android target: + [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD). + + Build with **single-stage** model for **shoes** with: + [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1MvaEg4dkvKN8jAU1Z2GtudyXi1rQHYsE/view?usp=sharing) + + ```bash + bazel build -c opt --config android_arm64 --define shoe_1stage=true mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d + ``` + + Build with **single-stage** model for **chairs** with: + [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1GJL4z3jr-wD1jMHGd4NBfOG-Yoq5t167/view?usp=sharing) + + ```bash + bazel build -c opt --config android_arm64 --define chair_1stage=true mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d + ``` + +* iOS target: Not available + +### Assets + +Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc) using a parsing of the sequenced .obj file + format into a custom .uuu format. This can be done for user assets as follows: +> First run +> +> ```shell +> ./mediapipe/graphs/object_detection_3d/obj_parser/obj_cleanup.sh [INPUT_DIR] [INTERMEDIATE_OUTPUT_DIR] +> ``` +> and then run +> +> ```build +> bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR] +> ``` +> INPUT_DIR should be the folder with initial asset .obj files to be processed, +> and OUTPUT_DIR is the folder where the processed asset .uuu file will be placed. 
+>
+> Note: ObjParser combines all .obj files found in the given directory into a
+> single .uuu animation file, using the order given by sorting the filenames
+> alphanumerically. Also the ObjParser directory inputs must be given as
+> absolute paths, not relative paths. See parser utility library at
+> [`mediapipe/graphs/object_detection_3d/obj_parser/`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/obj_parser/)
+> for more details.
+
+### Coordinate Systems
+
+#### Object Coordinate
+
+Each object has its own object coordinate frame. We use the below object
+coordinate definition, with `+x` pointing right, `+y` pointing up and `+z`
+pointing front; the origin is at the center of the 3D bounding box.
+
+![box_coordinate.svg](../images/box_coordinate.svg)
+
+#### Camera Coordinate
+
+A 3D object is parameterized by its `scale`, `rotation` and `translation` with
+regard to the camera coordinate frame. In this API we use the below camera
+coordinate definition, with `+x` pointing right, `+y` pointing up and `-z`
+pointing to the scene.
+
+![camera_coordinate.svg](../images/camera_coordinate.svg)
+
+To work with box landmarks, one can first derive landmark coordinates in the
+object frame by scaling an origin-centered unit box with `scale`, then
+transform them to the camera frame by applying `rotation` and `translation`:
+
+```
+landmarks_3d = rotation * scale * unit_box + translation
+```
+
+#### NDC Space
+
+In this API we use
+[NDC (normalized device coordinates)](http://www.songho.ca/opengl/gl_projectionmatrix.html)
+as an intermediate space when projecting points from 3D to 2D. In NDC space,
+`x` and `y` are confined to `[-1, 1]`.
+
+![ndc_coordinate.svg](../images/ndc_coordinate.svg)
+
+By default the camera parameters `(fx, fy)` and `(px, py)` are defined in NDC
+space. Given `(X, Y, Z)` of 3D points in the camera coordinate frame, one can
+project them to NDC space as follows:
+
+```
+x_ndc = -fx * X / Z + px
+y_ndc = -fy * Y / Z + py
+z_ndc = 1 / Z
+```
+
+#### Pixel Space
+
+In this API we set the upper-left corner of the image as the origin of the
+pixel coordinate system.
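Before the NDC-to-pixel conversion given below, the following sketch ties the camera-coordinate and NDC-space formulas above together: it derives the eight 3D box landmarks in the camera frame and projects them to NDC. Every numeric value (scale, rotation, translation, focal length, principal point) is made up for illustration and does not come from any MediaPipe model.

```python
import numpy as np

# Illustrative sketch of the formulas above; all numeric values are made up.
unit_box = np.array([[x, y, z]
                     for x in (-0.5, 0.5)
                     for y in (-0.5, 0.5)
                     for z in (-0.5, 0.5)])      # origin-centered unit box

scale = np.array([0.2, 0.1, 0.3])                # object size along x, y, z
rotation = np.eye(3)                             # object-to-camera rotation
translation = np.array([0.0, 0.0, -1.5])         # box center 1.5 units into the scene

# landmarks_3d = rotation * scale * unit_box + translation
landmarks_3d = (rotation @ (unit_box * scale).T).T + translation

fx, fy, px, py = 1.0, 1.0, 0.0, 0.0              # camera parameters in NDC space
X, Y, Z = landmarks_3d[:, 0], landmarks_3d[:, 1], landmarks_3d[:, 2]
x_ndc = -fx * X / Z + px
y_ndc = -fy * Y / Z + py
print(np.stack([x_ndc, y_ndc], axis=1))          # eight projected corners in NDC
```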
One can convert from NDC to pixel space as follows: + +``` +x_pixel = (1 + x_ndc) / 2.0 * image_width +y_pixel = (1 - y_ndc) / 2.0 * image_height +``` + +Alternatively one can directly project from camera coordinate to pixel +coordinate with camera parameters `(fx_pixel, fy_pixel)` and `(px_pixel, +py_pixel)` defined in pixel space as follows: + +``` +x_pixel = -fx_pixel * X / Z + px_pixel +y_pixel = fy_pixel * Y / Z + py_pixel +``` + +Conversion of camera parameters from pixel space to NDC space: + +``` +fx = fx_pixel * 2.0 / image_width +fy = fy_pixel * 2.0 / image_height +``` + +``` +px = -px_pixel * 2.0 / image_width + 1.0 +py = -py_pixel * 2.0 / image_height + 1.0 +``` + ## Resources +* Google AI Blog: + [Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html) * Google AI Blog: [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html) * Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak @@ -191,5 +537,5 @@ to visualize its associated subgraphs, please see * Paper: [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8) ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)) -* [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers.tflite) -* [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair.tflite) +* [Models and model cards](./models.md#objectron) +* [Python Colab](https://mediapipe.page.link/objectron_py_colab) diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 6b3fa3868..064e2eb19 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -2,34 +2,43 @@ layout: default title: Pose parent: Solutions +has_children: true +has_toc: false nav_order: 5 --- -# MediaPipe BlazePose +# MediaPipe Pose {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- ## Overview Human pose estimation from video plays a critical role in various applications -such as quantifying physical exercises, sign language recognition, and full-body -gesture control. For example, it can form the basis for yoga, dance, and fitness -applications. It can also enable the overlay of digital content and information -on top of the physical world in augmented reality. +such as [quantifying physical exercises](./pose_classification.md), sign +language recognition, and full-body gesture control. For example, it can form +the basis for yoga, dance, and fitness applications. It can also enable the +overlay of digital content and information on top of the physical world in +augmented reality. -MediaPipe Pose is a ML solution for high-fidelity upper-body pose tracking, -inferring 25 2D upper-body landmarks from RGB video frames utilizing our +MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring +33 3D landmarks on the whole body (or 25 upper-body landmarks) from RGB video +frames utilizing our [BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) -research. Current state-of-the-art approaches rely primarily on powerful desktop +research that also powers the +[ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection). +Current state-of-the-art approaches rely primarily on powerful desktop environments for inference, whereas our method achieves real-time performance on most modern [mobile phones](#mobile), [desktops/laptops](#desktop), in -[python](#python) and even on the [web](#web). A variant of MediaPipe Pose that -performs full-body pose tracking on mobile phones will be included in an -upcoming release of -[ML Kit](https://developers.google.com/ml-kit/early-access/pose-detection). +[python](#python-solution-api) and even on the [web](#javascript-solution-api). ![pose_tracking_upper_body_example.gif](../images/mobile/pose_tracking_upper_body_example.gif) | :--------------------------------------------------------------------------------------------: | @@ -40,23 +49,24 @@ upcoming release of The solution utilizes a two-step detector-tracker ML pipeline, proven to be effective in our [MediaPipe Hands](./hands.md) and [MediaPipe Face Mesh](./face_mesh.md) solutions. Using a detector, the pipeline -first locates the pose region-of-interest (ROI) within the frame. The tracker -subsequently predicts the pose landmarks within the ROI using the ROI-cropped -frame as input. Note that for video use cases the detector is invoked only as -needed, i.e., for the very first frame and when the tracker could no longer -identify body pose presence in the previous frame. For other frames the pipeline -simply derives the ROI from the previous frame’s pose landmarks. +first locates the person/pose region-of-interest (ROI) within the frame. The +tracker subsequently predicts the pose landmarks within the ROI using the +ROI-cropped frame as input. Note that for video use cases the detector is +invoked only as needed, i.e., for the very first frame and when the tracker +could no longer identify body pose presence in the previous frame. For other +frames the pipeline simply derives the ROI from the previous frame’s pose +landmarks. 
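The detector-tracker hand-off described above can be summarized in a few lines of control flow. This is only a schematic sketch of the idea, not MediaPipe's implementation: `run_detector`, `run_tracker` and `roi_from_landmarks` are hypothetical stand-ins for the detection and landmark models.

```python
# Schematic sketch of the detector-tracker loop described above. The three
# helpers below are hypothetical stand-ins, not part of the MediaPipe API.

def run_detector(frame):
    """Stand-in for the person/pose detector: returns an ROI or None."""
    return (0.25, 0.25, 0.75, 0.75)

def run_tracker(frame, roi):
    """Stand-in for the landmark model: returns landmarks, or None if lost."""
    return [("nose", 0.5, 0.3)]

def roi_from_landmarks(landmarks):
    """Stand-in: derive the next frame's ROI from the current landmarks."""
    return (0.2, 0.2, 0.8, 0.8)

def process_video(frames):
    roi = None
    for frame in frames:
        if roi is None:                        # first frame, or tracking was lost
            roi = run_detector(frame)
        landmarks = run_tracker(frame, roi) if roi else None
        # As long as tracking succeeds, the detector is skipped and the ROI is
        # simply derived from the previous frame's landmarks.
        roi = roi_from_landmarks(landmarks) if landmarks else None
        yield landmarks

for landmarks in process_video(frames=[object(), object()]):
    print(landmarks)
```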
The pipeline is implemented as a MediaPipe -[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt) +[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt) that uses a -[pose landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body_gpu.pbtxt) +[pose landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt) from the [pose landmark module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark) and renders using a dedicated -[upper-body pose renderer subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/subgraphs/upper_body_pose_renderer_gpu.pbtxt). +[pose renderer subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_gpu.pbtxt). The -[pose landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body_gpu.pbtxt) +[pose landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt) internally uses a [pose detection subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt) from the @@ -69,7 +79,7 @@ to visualize its associated subgraphs, please see ## Models -### Pose Detection Model (BlazePose Detector) +### Person/pose Detection Model (BlazePose Detector) The detector is inspired by our own lightweight [BlazeFace](https://arxiv.org/abs/1907.05047) model, used in @@ -85,28 +95,240 @@ hip midpoints. :----------------------------------------------------------------------------------------------------: | *Fig 2. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* | -### Pose Landmark Model (BlazePose Tracker) +### Pose Landmark Model (BlazePose GHUM 3D) -The landmark model currently included in MediaPipe Pose predicts the location of -25 upper-body landmarks (see figure below), with three degrees of freedom each -(x, y location and visibility), plus two virtual alignment keypoints. It shares -the same architecture as the full-body version that predicts 33 landmarks, -described in more detail in the -[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) -and in this [paper](https://arxiv.org/abs/2006.10204). +The landmark model in MediaPipe Pose comes in two versions: a full-body model +that predicts the location of 33 pose landmarks (see figure below), and an +upper-body version that only predicts the first 25. The latter may be more +accurate than the former in scenarios where the lower-body parts are mostly out +of view. -![pose_tracking_upper_body_landmarks.png](../images/mobile/pose_tracking_upper_body_landmarks.png) | -:------------------------------------------------------------------------------------------------: | -*Fig 3. 25 upper-body pose landmarks.* | +Please find more detail in the +[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), +this [paper](https://arxiv.org/abs/2006.10204) and +[the model card](./models.md#pose), and the attributes in each landmark +[below](#pose_landmarks). 
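As a rough illustration of the alignment in Fig 2: the BlazePose paper and blog post linked above describe the detector's virtual keypoints in terms of a body center, the radius of a circle circumscribing the whole person, and an incline angle, from which a rotated square ROI can be derived for the landmark model. The sketch below assumes exactly that parameterization; the padding factor and the helper name are made up, and this is not MediaPipe's actual ROI calculator.

```python
import math

# Illustrative sketch only: derive a rotated square ROI from an assumed body
# center, circumscribing-circle radius and incline angle, all in normalized
# image coordinates. The 1.25 padding factor is a made-up value.
def roi_from_alignment(center_x, center_y, radius, angle_rad, padding=1.25):
    size = 2.0 * radius * padding                     # side length of the ROI
    corners = []
    for dx, dy in ((-0.5, -0.5), (0.5, -0.5), (0.5, 0.5), (-0.5, 0.5)):
        # Rotate each unit-square corner by the incline, then scale and shift.
        rx = dx * math.cos(angle_rad) - dy * math.sin(angle_rad)
        ry = dx * math.sin(angle_rad) + dy * math.cos(angle_rad)
        corners.append((center_x + rx * size, center_y + ry * size))
    return corners

print(roi_from_alignment(0.5, 0.6, 0.2, math.radians(10)))
```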
+
+![pose_tracking_full_body_landmarks.png](../images/mobile/pose_tracking_full_body_landmarks.png) |
+:----------------------------------------------------------------------------------------------: |
+*Fig 3. 33 pose landmarks.* |
+
+## Solution APIs
+
+### Cross-platform Configuration Options
+
+Naming style and availability may differ slightly across platforms/languages.
+
+#### static_image_mode
+
+If set to `false`, the solution treats the input images as a video stream. It
+will try to detect the most prominent person in the very first images, and upon
+a successful detection further localizes the pose landmarks. In subsequent
+images, it then simply tracks those landmarks without invoking another
+detection until it loses track, thereby reducing computation and latency. If
+set to `true`, person detection runs on every input image, ideal for processing
+a batch of static, possibly unrelated, images. Default to `false`.
+
+#### upper_body_only
+
+If set to `true`, the solution outputs only the 25 upper-body pose landmarks.
+Otherwise, it outputs the full set of 33 pose landmarks. Note that
+upper-body-only prediction may be more accurate for use cases where the
+lower-body parts are mostly out of view. Default to `false`.
+
+#### smooth_landmarks
+
+If set to `true`, the solution filters pose landmarks across different input
+images to reduce jitter, but this is ignored if
+[static_image_mode](#static_image_mode) is also set to `true`. Default to
+`true`.
+
+#### min_detection_confidence
+
+Minimum confidence value (`[0.0, 1.0]`) from the person-detection model for the
+detection to be considered successful. Default to `0.5`.
+
+#### min_tracking_confidence
+
+Minimum confidence value (`[0.0, 1.0]`) from the landmark-tracking model for the
+pose landmarks to be considered tracked successfully, or otherwise person
+detection will be invoked automatically on the next input image. Setting it to a
+higher value can increase robustness of the solution, at the expense of a higher
+latency. Ignored if [static_image_mode](#static_image_mode) is `true`, where
+person detection simply runs on every image. Default to `0.5`.
+
+### Output
+
+Naming style may differ slightly across platforms/languages.
+
+#### pose_landmarks
+
+A list of pose landmarks. Each landmark consists of the following:
+
+* `x` and `y`: Landmark coordinates normalized to `[0.0, 1.0]` by the image
+  width and height respectively.
+* `z`: Represents the landmark depth with the depth at the midpoint of hips
+  being the origin, and the smaller the value the closer the landmark is to
+  the camera. The magnitude of `z` uses roughly the same scale as `x`.
+
+  Note: `z` is predicted only in full-body mode, and should be discarded when
+  [upper_body_only](#upper_body_only) is `true`.
+
+* `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the
+  landmark being visible (present and not occluded) in the image.
+
+### Python Solution API
+
+Please first follow general [instructions](../getting_started/python.md) to
+install MediaPipe Python package, then learn more in the companion
+[Python Colab](#resources) and the following usage example.
+ +Supported configuration options: + +* [static_image_mode](#static_image_mode) +* [upper_body_only](#upper_body_only) +* [smooth_landmarks](#smooth_landmarks) +* [min_detection_confidence](#min_detection_confidence) +* [min_tracking_confidence](#min_tracking_confidence) + +```python +import cv2 +import mediapipe as mp +mp_drawing = mp.solutions.drawing_utils +mp_pose = mp.solutions.pose + +# For static images: +with mp_pose.Pose( + static_image_mode=True, min_detection_confidence=0.5) as pose: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + image_height, image_width, _ = image.shape + # Convert the BGR image to RGB before processing. + results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + if not results.pose_landmarks: + continue + print( + f'Nose coordinates: (' + f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' + f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' + ) + # Draw pose landmarks on the image. + annotated_image = image.copy() + # Use mp_pose.UPPER_BODY_POSE_CONNECTIONS for drawing below when + # upper_body_only is set to True. + mp_drawing.draw_landmarks( + annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + +# For webcam input: +cap = cv2.VideoCapture(0) +with mp_pose.Pose( + min_detection_confidence=0.5, + min_tracking_confidence=0.5) as pose: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = pose.process(image) + + # Draw the pose annotation on the image. + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + mp_drawing.draw_landmarks( + image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) + cv2.imshow('MediaPipe Pose', image) + if cv2.waitKey(5) & 0xFF == 27: + break +cap.release() +``` + +### JavaScript Solution API + +Please first see general [introduction](../getting_started/javascript.md) on +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. + +Supported configuration options: + +* [upperBodyOnly](#upper_body_only) +* [smoothLandmarks](#smooth_landmarks) +* [minDetectionConfidence](#min_detection_confidence) +* [minTrackingConfidence](#min_tracking_confidence) + +```html + + + + + + + + + + + +
+
+```javascript
+// Illustrative usage of the MediaPipe Pose JS solution; the element class
+// names below are assumptions matching the HTML snippet above.
+const videoElement = document.getElementsByClassName('input_video')[0];
+const canvasElement = document.getElementsByClassName('output_canvas')[0];
+const canvasCtx = canvasElement.getContext('2d');
+
+function onResults(results) {
+  canvasCtx.save();
+  canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height);
+  canvasCtx.drawImage(results.image, 0, 0, canvasElement.width, canvasElement.height);
+  drawConnectors(canvasCtx, results.poseLandmarks, POSE_CONNECTIONS, {color: '#00FF00'});
+  drawLandmarks(canvasCtx, results.poseLandmarks, {color: '#FF0000'});
+  canvasCtx.restore();
+}
+
+const pose = new Pose({locateFile: (file) =>
+    `https://cdn.jsdelivr.net/npm/@mediapipe/pose/${file}`});
+pose.setOptions({
+  upperBodyOnly: false,
+  smoothLandmarks: true,
+  minDetectionConfidence: 0.5,
+  minTrackingConfidence: 0.5
+});
+pose.onResults(onResults);
+
+const camera = new Camera(videoElement, {
+  onFrame: async () => { await pose.send({image: videoElement}); },
+  width: 1280,
+  height: 720
+});
+camera.start();
+```

 ## Example Apps

 Please first see general instructions for
-[Android](../getting_started/building_examples.md#android),
-[iOS](../getting_started/building_examples.md#ios),
-[desktop](../getting_started/building_examples.md#desktop) and
-[Python](../getting_started/building_examples.md#python) on how to build
-MediaPipe examples.
+[Android](../getting_started/android.md), [iOS](../getting_started/ios.md), and
+[desktop](../getting_started/cpp.md) on how to build MediaPipe examples.

 Note: To visualize a graph, copy the graph and paste it into
 [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
@@ -115,6 +337,18 @@ to visualize its associated subgraphs, please see

 ### Mobile

+#### Main Example
+
+* Graph:
+  [`mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt)
+* Android target:
+  [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/17GFIrqEJS6W8UHKXlYevTtSCLxN9pWlY/view?usp=sharing)
+  [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu:posetrackinggpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD)
+* iOS target:
+  [`mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/posetrackinggpu/BUILD)
+
+#### Upper-body Only
+
 * Graph:
   [`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt)
 * Android target:
@@ -125,9 +359,23 @@ to visualize its associated subgraphs, please see

 ### Desktop

-Please first see general instructions for
-[desktop](../getting_started/building_examples.md#desktop) on how to build
-MediaPipe examples.
+Please first see general instructions for [desktop](../getting_started/cpp.md)
+on how to build MediaPipe examples.
+
+#### Main Example
+
+* Running on CPU
+  * Graph:
+    [`mediapipe/graphs/pose_tracking/pose_tracking_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/pose_tracking_cpu.pbtxt)
+  * Target:
+    [`mediapipe/examples/desktop/pose_tracking:pose_tracking_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/pose_tracking/BUILD)
+* Running on GPU
+  * Graph:
+    [`mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/pose_tracking_gpu.pbtxt)
+  * Target:
+    [`mediapipe/examples/desktop/pose_tracking:pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/pose_tracking/BUILD)
+
+#### Upper-body Only

 * Running on CPU
   * Graph:
@@ -140,48 +388,6 @@ MediaPipe examples.
  * Target:
    [`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD)

-### Python
-
-MediaPipe Python package is available on
-[PyPI](https://pypi.org/project/mediapipe/), and can be installed simply by `pip
-install mediapipe` on Linux and macOS, as described below and in this
-[colab](https://mediapipe.page.link/mp-py-colab). If you do need to build the
-Python package from source, see
-[additional instructions](../getting_started/building_examples.md#python).
-
-```bash
-# Activate a Python virtual environment.
-$ python3 -m venv mp_env && source mp_env/bin/activate - -# Install MediaPipe Python package -(mp_env)$ pip install mediapipe - -# Run in Python interpreter -(mp_env)$ python3 ->>> import mediapipe as mp ->>> pose_tracker = mp.examples.UpperBodyPoseTracker() - -# For image input ->>> pose_landmarks, _ = pose_tracker.run(input_file='/path/to/input/file', output_file='/path/to/output/file') ->>> pose_landmarks, annotated_image = pose_tracker.run(input_file='/path/to/file') -# To print out the pose landmarks, you can simply do "print(pose_landmarks)". -# However, the data points can be more accessible with the following approach. ->>> [print('x is', data_point.x, 'y is', data_point.y, 'z is', data_point.z, 'visibility is', data_point.visibility) for data_point in pose_landmarks.landmark] - -# For live camera input -# (Press Esc within the output image window to stop the run or let it self terminate after 30 seconds.) ->>> pose_tracker.run_live() - -# Close the tracker. ->>> pose_tracker.close() -``` - -Tip: Use command `deactivate` to exit the Python virtual environment. - -### Web - -Please refer to [these instructions](../index.md#mediapipe-on-the-web). - ## Resources * Google AI Blog: @@ -189,8 +395,6 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web). * Paper: [BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204) ([presentation](https://youtu.be/YPpUOTRn5tA)) -* Pose detection model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection.tflite) -* Upper-body pose landmark model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite) -* [Model card](https://mediapipe.page.link/blazepose-mc) +* [Models and model cards](./models.md#pose) +* [Web demo](https://code.mediapipe.dev/codepen/pose) +* [Python Colab](https://mediapipe.page.link/pose_py_colab) diff --git a/docs/solutions/pose_classification.md b/docs/solutions/pose_classification.md new file mode 100644 index 000000000..9595dc7d1 --- /dev/null +++ b/docs/solutions/pose_classification.md @@ -0,0 +1,142 @@ +--- +layout: default +title: Pose Classification +parent: Pose +grand_parent: Solutions +nav_order: 1 +--- + +# Pose Classification +{: .no_toc } + +
+<details open markdown="block">
+  <summary>
+    Table of contents
+  </summary>
+  {: .text-delta }
+1. TOC
+{:toc}
+</details>
+
+---
+
+## Overview
+
+One of the applications
+[BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
+can enable is fitness. More specifically: pose classification and repetition
+counting. In this section, we'll provide basic guidance on building a custom
+pose classifier with the help of [Colabs](#colabs) and wrap it in a simple
+[fitness app](https://mediapipe.page.link/mlkit-pose-classification-demo-app)
+powered by [ML Kit](https://developers.google.com/ml-kit). Push-ups and squats
+are used for demonstration purposes, as they are the most common exercises.
+
+![pose_classification_pushups_and_squats.gif](../images/mobile/pose_classification_pushups_and_squats.gif) |
+:--------------------------------------------------------------------------------------------------------: |
+*Fig 1. Pose classification and repetition counting with MediaPipe Pose.* |
+
+We picked the
+[k-nearest neighbors algorithm](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)
+(k-NN) as the classifier. It's simple and easy to start with. The algorithm
+determines the object's class based on the closest samples in the training set.
+
+**To build it, one needs to:**
+
+1. Collect image samples of the target exercises and run pose prediction on
+   them,
+2. Convert the obtained pose landmarks to a representation suitable for the
+   k-NN classifier and form a training set using these [Colabs](#colabs),
+3. Perform the classification itself followed by repetition counting (e.g., in
+   the
+   [ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app)).
+
+## Training Set
+
+To build a good classifier, appropriate samples should be collected for the
+training set: about a few hundred samples for each terminal state of each
+exercise (e.g., "up" and "down" positions for push-ups). It's important that
+the collected samples cover different camera angles, environment conditions,
+body shapes, and exercise variations.
+
+![pose_classification_pushups_un_and_down_samples.jpg](../images/mobile/pose_classification_pushups_un_and_down_samples.jpg) |
+:--------------------------------------------------------------------------------------------------------------------------: |
+*Fig 2. Two terminal states of push-ups.* |
+
+To transform samples into a k-NN classifier training set, both the
+[`Pose Classification Colab (Basic)`] and the
+[`Pose Classification Colab (Extended)`] can be used. They use the
+[Python Solution API](./pose.md#python-solution-api) to run the BlazePose models
+on the given images and dump the predicted pose landmarks to a CSV file.
+Additionally, the [`Pose Classification Colab (Extended)`] provides useful tools
+to find outliers (e.g., wrongly predicted poses) and underrepresented classes
+(e.g., not covering all camera angles) by classifying each sample against the
+entire training set. After that, you'll be able to test the classifier on an
+arbitrary video right in the Colab.
+
+## Classification
+
+The code of the classifier is available both in the
+[`Pose Classification Colab (Extended)`] and in the
+[ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app).
+Please refer to them for details of the approach described below.
+
+The k-NN algorithm used for pose classification requires a feature vector
+representation of each sample and a metric to compute the distance between two
+such vectors, in order to find the nearest pose samples to a target one.
+
+To convert pose landmarks to a feature vector, we use pairwise distances between
+predefined lists of pose joints, such as distances between wrist and shoulder,
+ankle and hip, and the two wrists. Since the algorithm relies on distances, all
+poses are normalized to have the same torso size and vertical torso orientation
+before the conversion.
+
+![pose_classification_pairwise_distances.png](../images/mobile/pose_classification_pairwise_distances.png) |
+:--------------------------------------------------------------------------------------------------------: |
+*Fig 3. Main pairwise distances used for the pose feature vector.* |
+
+To get a better classification result, k-NN search is invoked twice with
+different distance metrics:
+
+* First, to filter out samples that are almost the same as the target one but
+  have only a few different values in the feature vector (which means
+  differently bent joints and thus a different pose class), the minimum
+  per-coordinate distance is used as the distance metric,
+* Then the average per-coordinate distance is used to find the nearest pose
+  cluster among those returned by the first search.
+
+Finally, we apply
+[exponential moving average](https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average)
+(EMA) smoothing to level any noise from pose prediction or classification. To do
+that, we not only search for the nearest pose cluster, but also calculate a
+probability for each cluster and use it for smoothing over time.
+
+## Repetition Counting
+
+To count the repetitions, the algorithm monitors the probability of a target
+pose class. Let's take push-ups with their "up" and "down" terminal states:
+
+* When the probability of the "down" pose class passes a certain threshold for
+  the first time, the algorithm marks that the "down" pose class is entered.
+* Once the probability drops below the threshold, the algorithm marks that the
+  "down" pose class has been exited and increases the counter.
+
+To avoid cases when the probability fluctuates around the threshold (e.g., when
+the user pauses between "up" and "down" states) and causes phantom counts, the
+threshold used to detect when the state is exited is actually slightly lower
+than the one used to detect when the state is entered. This creates an interval
+in which the pose class and the counter can't change. A simplified code sketch
+of this classification and counting logic is given near the end of this page.
+
+## Future Work
+
+We are actively working on improving BlazePose GHUM 3D's Z prediction. It will
+allow us to use joint angles in the feature vectors, which are more natural and
+easier to configure (although distances can still be useful to detect touches
+between body parts), and to perform rotation normalization of poses, reducing
+the number of camera angles required for accurate k-NN classification.
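+
+## Classification and Counting Sketch
+
+For illustration only, here is a minimal Python sketch of the ideas described
+above: a pairwise-distance embedding with torso normalization, a simple k-NN
+vote, EMA smoothing, and a repetition counter with threshold hysteresis. All
+helper names and threshold values are hypothetical, and the two-stage distance
+filtering used in the Colabs is omitted for brevity; refer to the
+[Colabs](#colabs) and the
+[ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app)
+for the actual implementation.
+
+```python
+import numpy as np
+
+# BlazePose landmark indices for the torso keypoints.
+_LEFT_SHOULDER, _RIGHT_SHOULDER, _LEFT_HIP, _RIGHT_HIP = 11, 12, 23, 24
+
+
+def pose_embedding(landmarks, joint_pairs):
+  """Turns an array of (x, y, z) landmarks into a vector of pairwise distances."""
+  pts = np.asarray(landmarks, dtype=float)
+  # Normalize translation and scale so that all poses share the same torso size.
+  hips = (pts[_LEFT_HIP] + pts[_RIGHT_HIP]) / 2.0
+  shoulders = (pts[_LEFT_SHOULDER] + pts[_RIGHT_SHOULDER]) / 2.0
+  torso_size = np.linalg.norm(shoulders - hips)
+  pts = (pts - hips) / max(torso_size, 1e-6)
+  return np.array([np.linalg.norm(pts[i] - pts[j]) for i, j in joint_pairs])
+
+
+def knn_votes(embedding, samples, k=10):
+  """Returns how many of the k nearest training samples belong to each class."""
+  # `samples` is a list of (embedding, class_name) pairs built from the CSVs.
+  nearest = sorted(samples, key=lambda s: np.mean(np.abs(s[0] - embedding)))[:k]
+  votes = {}
+  for _, name in nearest:
+    votes[name] = votes.get(name, 0) + 1
+  return votes
+
+
+def ema_smooth(previous, current, alpha=0.2):
+  """Exponential moving average over per-class vote counts."""
+  keys = set(previous) | set(current)
+  return {key: alpha * current.get(key, 0.0) + (1.0 - alpha) * previous.get(key, 0.0)
+          for key in keys}
+
+
+class RepetitionCounter:
+  """Counts class entries/exits, with hysteresis to suppress phantom counts."""
+
+  def __init__(self, class_name, enter_threshold=6.0, exit_threshold=4.0):
+    self._class_name = class_name
+    self._enter_threshold = enter_threshold
+    self._exit_threshold = exit_threshold
+    self._entered = False
+    self.count = 0
+
+  def update(self, smoothed_votes):
+    confidence = smoothed_votes.get(self._class_name, 0.0)
+    if not self._entered:
+      # Enter the state only when confidence clears the higher threshold.
+      self._entered = confidence > self._enter_threshold
+    elif confidence < self._exit_threshold:
+      # Exit (and count) only when confidence drops below the lower threshold.
+      self._entered = False
+      self.count += 1
+    return self.count
+```
+
+For each video frame, one would compute the embedding from the predicted
+landmarks, smooth the k-NN votes with `ema_smooth`, and feed the result to a
+`RepetitionCounter` instance per exercise.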
+ +## Colabs + +* [`Pose Classification Colab (Basic)`] +* [`Pose Classification Colab (Extended)`] + +[`Pose Classification Colab (Basic)`]: https://mediapipe.page.link/pose_classification_basic +[`Pose Classification Colab (Extended)`]: https://mediapipe.page.link/pose_classification_extended diff --git a/docs/solutions/solutions.md b/docs/solutions/solutions.md index 6a852b751..c78dffea0 100644 --- a/docs/solutions/solutions.md +++ b/docs/solutions/solutions.md @@ -16,19 +16,24 @@ has_toc: false -[]() | Android | iOS | Desktop | Python | Web | Coral -:---------------------------------------------------------------------------------------- | :-----: | :-: | :-----: | :----: | :-: | :---: -[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | ✅ | ✅ -[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | | | -[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | ✅ | -[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | | ✅ | -[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | ✅ | -[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ -[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | -[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | | -[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | -[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | -[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | -[YouTube 8M](https://google.github.io/mediapipe/solutions/youtube_8m) | | | ✅ | | | +[]() | [Android](https://google.github.io/mediapipe/getting_started/android) | [iOS](https://google.github.io/mediapipe/getting_started/ios) | [C++](https://google.github.io/mediapipe/getting_started/cpp) | [Python](https://google.github.io/mediapipe/getting_started/python) | [JS](https://google.github.io/mediapipe/getting_started/javascript) | [Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/README.md) +:---------------------------------------------------------------------------------------- | :-------------------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------------: | :-----------------------------------------------------------: | :--------------------------------------------------------------------: +[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ +[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | | +[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Holistic](https://google.github.io/mediapipe/solutions/holistic) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | | +[Object 
Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ +[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | +[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | +[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | +[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | +[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | +[YouTube 8M](https://google.github.io/mediapipe/solutions/youtube_8m) | | | ✅ | | | + +See also +[MediaPipe Models and Model Cards](https://google.github.io/mediapipe/solutions/models) +for ML models released in MediaPipe. diff --git a/docs/solutions/youtube_8m.md b/docs/solutions/youtube_8m.md index f6d05bbca..abef6f1b6 100644 --- a/docs/solutions/youtube_8m.md +++ b/docs/solutions/youtube_8m.md @@ -2,14 +2,20 @@ layout: default title: YouTube-8M Feature Extraction and Model Inference parent: Solutions -nav_order: 14 +nav_order: 15 --- # YouTube-8M Feature Extraction and Model Inference {: .no_toc } +
+ + Table of contents + + {: .text-delta } 1. TOC {:toc} +
--- MediaPipe is a useful and general framework for media processing that can assist diff --git a/docs/tools/tracing_and_profiling.md b/docs/tools/tracing_and_profiling.md index 055993349..ed58eb61b 100644 --- a/docs/tools/tracing_and_profiling.md +++ b/docs/tools/tracing_and_profiling.md @@ -26,9 +26,10 @@ To enable tracing and profiling of a mediapipe graph: 1. The profiling library must be linked to the framework. 2. Tracing and profiling must be enabled in the graph configuration. -The profiling library is linked to the framework by default. If needed, -the profiling library can be omitted from the framework using the bazel -command line option: `--define MEDIAPIPE_PROFILING=0`. +The profiling library is linked to the framework by default for Desktop. +If needed, it can be omitted from the framework using the bazel command line +option: `--define MEDIAPIPE_PROFILING=0`. For other platforms, you can use the +bazel command line option `--define MEDIAPIPE_PROFILING=1` to link it. To enable tracing and profiling, the `CalculatorGraphConfig` (in [calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto)) @@ -38,6 +39,7 @@ is a simple setup that turns on tracing and keeps 100 seconds of timing events: ``` profiler_config { trace_enabled: true + enable_profiler: true trace_log_interval_count: 200 } ``` @@ -71,6 +73,9 @@ MediaPipe will emit data into a pre-specified directory: You can open the Download Container. Logs will be located in `application container/.xcappdata/AppData/Documents/` + If XCode shows empty content for the downloaded container file, you can + right click and select 'Show Package Contents' in Finder. Logs + will be located in 'AppData/Documents/' ![Windows Download Container](../images/visualizer/ios_download_container.png) @@ -144,6 +149,7 @@ we record ten intervals of half a second each. This can be overridden by adding ```bash profiler_config { trace_enabled: true + enable_profiler: true trace_log_path: "/sdcard/profiles/" } ``` diff --git a/docs/tools/visualizer.md b/docs/tools/visualizer.md index ecd4487a8..9324576a2 100644 --- a/docs/tools/visualizer.md +++ b/docs/tools/visualizer.md @@ -37,7 +37,7 @@ The graph can be modified by adding and editing code in the Editor view. ![New Button](../images/upload_button.png) * Pressing the "Upload" button will prompt the user to select a local PBTXT - file, which will everwrite the current code within the editor. + file, which will overwrite the current code within the editor. * Alternatively, code can be pasted directly into the editor window. 
diff --git a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen index 7d501c803..d3cd4971a 100644 --- a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen +++ b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen @@ -2,34 +2,40 @@ "additionalFilePaths" : [ "/BUILD", "mediapipe/BUILD", - "mediapipe/objc/BUILD", - "mediapipe/framework/BUILD", - "mediapipe/gpu/BUILD", - "mediapipe/objc/testing/app/BUILD", "mediapipe/examples/ios/common/BUILD", - "mediapipe/examples/ios/helloworld/BUILD", "mediapipe/examples/ios/facedetectioncpu/BUILD", "mediapipe/examples/ios/facedetectiongpu/BUILD", + "mediapipe/examples/ios/faceeffect/BUILD", "mediapipe/examples/ios/facemeshgpu/BUILD", "mediapipe/examples/ios/handdetectiongpu/BUILD", "mediapipe/examples/ios/handtrackinggpu/BUILD", + "mediapipe/examples/ios/helloworld/BUILD", + "mediapipe/examples/ios/holistictrackinggpu/BUILD", "mediapipe/examples/ios/iristrackinggpu/BUILD", - "mediapipe/examples/ios/multihandtrackinggpu/BUILD", "mediapipe/examples/ios/objectdetectioncpu/BUILD", "mediapipe/examples/ios/objectdetectiongpu/BUILD", - "mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD" + "mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD", + "mediapipe/examples/ios/posetrackinggpu/BUILD", + "mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD", + "mediapipe/framework/BUILD", + "mediapipe/gpu/BUILD", + "mediapipe/objc/BUILD", + "mediapipe/objc/testing/app/BUILD" ], "buildTargets" : [ - "//mediapipe/examples/ios/helloworld:HelloWorldApp", "//mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp", "//mediapipe/examples/ios/facedetectiongpu:FaceDetectionGpuApp", + "//mediapipe/examples/ios/faceeffect:FaceEffectApp", "//mediapipe/examples/ios/facemeshgpu:FaceMeshGpuApp", "//mediapipe/examples/ios/handdetectiongpu:HandDetectionGpuApp", "//mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp", + "//mediapipe/examples/ios/helloworld:HelloWorldApp", + "//mediapipe/examples/ios/holistictrackinggpu:HolisticTrackingGpuApp", "//mediapipe/examples/ios/iristrackinggpu:IrisTrackingGpuApp", - "//mediapipe/examples/ios/multihandtrackinggpu:MultiHandTrackingGpuApp", "//mediapipe/examples/ios/objectdetectioncpu:ObjectDetectionCpuApp", "//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp", + "//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp", + "//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp", "//mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp", "//mediapipe/objc:mediapipe_framework_ios" ], @@ -87,15 +93,18 @@ "mediapipe/examples/ios", "mediapipe/examples/ios/common", "mediapipe/examples/ios/common/Base.lproj", - "mediapipe/examples/ios/helloworld", "mediapipe/examples/ios/facedetectioncpu", "mediapipe/examples/ios/facedetectiongpu", + "mediapipe/examples/ios/faceeffect", + "mediapipe/examples/ios/faceeffect/Base.lproj", "mediapipe/examples/ios/handdetectiongpu", "mediapipe/examples/ios/handtrackinggpu", + "mediapipe/examples/ios/helloworld", + "mediapipe/examples/ios/holistictrackinggpu", "mediapipe/examples/ios/iristrackinggpu", - "mediapipe/examples/ios/multihandtrackinggpu", "mediapipe/examples/ios/objectdetectioncpu", "mediapipe/examples/ios/objectdetectiongpu", + "mediapipe/examples/ios/posetrackinggpu", "mediapipe/examples/ios/upperbodyposetrackinggpu", "mediapipe/framework", "mediapipe/framework/deps", @@ -110,6 +119,7 @@ "mediapipe/graphs", "mediapipe/graphs/edge_detection", 
"mediapipe/graphs/face_detection", + "mediapipe/graphs/face_geometry", "mediapipe/graphs/hand_tracking", "mediapipe/graphs/object_detection", "mediapipe/graphs/pose_tracking", diff --git a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf index 432316521..7303828ad 100644 --- a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf +++ b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf @@ -9,18 +9,21 @@ "packages" : [ "", "mediapipe", - "mediapipe/objc", "mediapipe/examples/ios", "mediapipe/examples/ios/facedetectioncpu", "mediapipe/examples/ios/facedetectiongpu", + "mediapipe/examples/ios/faceeffect", "mediapipe/examples/ios/facemeshgpu", "mediapipe/examples/ios/handdetectiongpu", "mediapipe/examples/ios/handtrackinggpu", + "mediapipe/examples/ios/holistictrackinggpu", "mediapipe/examples/ios/iristrackinggpu", - "mediapipe/examples/ios/multihandtrackinggpu", "mediapipe/examples/ios/objectdetectioncpu", "mediapipe/examples/ios/objectdetectiongpu", - "mediapipe/examples/ios/upperbodyposetrackinggpu" + "mediapipe/examples/ios/objectdetectiontrackinggpu", + "mediapipe/examples/ios/posetrackinggpu", + "mediapipe/examples/ios/upperbodyposetrackinggpu", + "mediapipe/objc" ], "projectName" : "Mediapipe", "workspaceRoot" : "../.." diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index b32529b79..9667e11d5 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2019, 2021 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -167,7 +167,7 @@ cc_library( "//mediapipe/util:time_series_util", "@com_google_absl//absl/strings", "@com_google_audio_tools//audio/dsp:resampler", - "@com_google_audio_tools//audio/dsp:resampler_rational_factor", + "@com_google_audio_tools//audio/dsp:resampler_q", "@eigen_archive//:eigen", ], alwayslink = 1, @@ -242,6 +242,7 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], diff --git a/mediapipe/calculators/audio/audio_decoder_calculator.cc b/mediapipe/calculators/audio/audio_decoder_calculator.cc index b80b64bae..49c201b37 100644 --- a/mediapipe/calculators/audio/audio_decoder_calculator.cc +++ b/mediapipe/calculators/audio/audio_decoder_calculator.cc @@ -48,18 +48,17 @@ namespace mediapipe { // TODO: support decoding multiple streams. 
class AudioDecoderCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: std::unique_ptr decoder_; }; -::mediapipe::Status AudioDecoderCalculator::GetContract( - CalculatorContract* cc) { +absl::Status AudioDecoderCalculator::GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set(); if (cc->InputSidePackets().HasTag("OPTIONS")) { cc->InputSidePackets().Tag("OPTIONS").Set(); @@ -68,10 +67,10 @@ class AudioDecoderCalculator : public CalculatorBase { if (cc->Outputs().HasTag("AUDIO_HEADER")) { cc->Outputs().Tag("AUDIO_HEADER").SetNone(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) { +absl::Status AudioDecoderCalculator::Open(CalculatorContext* cc) { const std::string& input_file_path = cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get(); const auto& decoder_options = @@ -88,10 +87,10 @@ class AudioDecoderCalculator : public CalculatorBase { cc->Outputs().Tag("AUDIO_HEADER").SetHeader(Adopt(header.release())); } cc->Outputs().Tag("AUDIO_HEADER").Close(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AudioDecoderCalculator::Process(CalculatorContext* cc) { +absl::Status AudioDecoderCalculator::Process(CalculatorContext* cc) { Packet data; int options_index = -1; auto status = decoder_->GetData(&options_index, &data); @@ -101,7 +100,7 @@ class AudioDecoderCalculator : public CalculatorBase { return status; } -::mediapipe::Status AudioDecoderCalculator::Close(CalculatorContext* cc) { +absl::Status AudioDecoderCalculator::Close(CalculatorContext* cc) { return decoder_->Close(); } diff --git a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc index be0cd1836..33ab9e04f 100644 --- a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc +++ b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc @@ -15,6 +15,7 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" diff --git a/mediapipe/calculators/audio/basic_time_series_calculators.cc b/mediapipe/calculators/audio/basic_time_series_calculators.cc index 4d966f47f..f7b24f6f6 100644 --- a/mediapipe/calculators/audio/basic_time_series_calculators.cc +++ b/mediapipe/calculators/audio/basic_time_series_calculators.cc @@ -38,7 +38,7 @@ static bool SafeMultiply(int x, int y, int* result) { } } // namespace -::mediapipe::Status BasicTimeSeriesCalculatorBase::GetContract( +absl::Status BasicTimeSeriesCalculatorBase::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Input stream with TimeSeriesHeader. 
@@ -46,10 +46,10 @@ static bool SafeMultiply(int x, int y, int* result) { cc->Outputs().Index(0).Set( // Output stream with TimeSeriesHeader. ); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) { +absl::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) { TimeSeriesHeader input_header; MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( cc->Inputs().Index(0).Header(), &input_header)); @@ -57,11 +57,13 @@ static bool SafeMultiply(int x, int y, int* result) { auto output_header = new TimeSeriesHeader(input_header); MP_RETURN_IF_ERROR(MutateHeader(output_header)); cc->Outputs().Index(0).SetHeader(Adopt(output_header)); - return ::mediapipe::OkStatus(); + + cc->SetOffset(0); + + return absl::OkStatus(); } -::mediapipe::Status BasicTimeSeriesCalculatorBase::Process( - CalculatorContext* cc) { +absl::Status BasicTimeSeriesCalculatorBase::Process(CalculatorContext* cc) { const Matrix& input = cc->Inputs().Index(0).Get(); MP_RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader( input, cc->Inputs().Index(0).Header().Get())); @@ -71,12 +73,12 @@ static bool SafeMultiply(int x, int y, int* result) { *output, cc->Outputs().Index(0).Header().Get())); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BasicTimeSeriesCalculatorBase::MutateHeader( +absl::Status BasicTimeSeriesCalculatorBase::MutateHeader( TimeSeriesHeader* output_header) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Calculator to sum an input time series across channels. This is @@ -86,9 +88,9 @@ static bool SafeMultiply(int x, int y, int* result) { class SumTimeSeriesAcrossChannelsCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_channels(1); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -104,9 +106,9 @@ REGISTER_CALCULATOR(SumTimeSeriesAcrossChannelsCalculator); class AverageTimeSeriesAcrossChannelsCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_channels(1); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -122,7 +124,7 @@ REGISTER_CALCULATOR(AverageTimeSeriesAcrossChannelsCalculator); // Options proto: None. 
class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { if (output_header->num_channels() != 1) { return tool::StatusInvalid( absl::StrCat("Expected single-channel input, got ", @@ -131,7 +133,7 @@ class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase { output_header->set_num_channels(output_header->num_samples()); output_header->set_num_samples(1); output_header->set_sample_rate(output_header->packet_rate()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -160,7 +162,7 @@ REGISTER_CALCULATOR(ReverseChannelOrderCalculator); // Options proto: None. class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { const int num_input_channels = output_header->num_channels(); const int num_input_samples = output_header->num_samples(); RET_CHECK(num_input_channels >= 0) @@ -174,7 +176,7 @@ class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase { output_header->set_num_channels(output_num_channels); output_header->set_num_samples(1); output_header->set_sample_rate(output_header->packet_rate()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -253,10 +255,10 @@ REGISTER_CALCULATOR(DivideByMeanAcrossChannelsCalculator); // Options proto: None. class MeanCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_samples(1); output_header->set_sample_rate(output_header->packet_rate()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -272,10 +274,10 @@ REGISTER_CALCULATOR(MeanCalculator); // Options proto: None. class StandardDeviationCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_samples(1); output_header->set_sample_rate(output_header->packet_rate()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -293,9 +295,9 @@ REGISTER_CALCULATOR(StandardDeviationCalculator); // Options proto: None. class CovarianceCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_samples(output_header->num_channels()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -313,9 +315,9 @@ REGISTER_CALCULATOR(CovarianceCalculator); // Options proto: None. 
class L2NormCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_channels(1); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -385,12 +387,12 @@ REGISTER_CALCULATOR(ElementwiseSquareCalculator); // Options proto: None. class FirstHalfSlicerCalculator : public BasicTimeSeriesCalculatorBase { protected: - ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { const int num_input_samples = output_header->num_samples(); RET_CHECK(num_input_samples >= 0) << "FirstHalfSlicerCalculator: num_input_samples < 0"; output_header->set_num_samples(num_input_samples / 2); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { diff --git a/mediapipe/calculators/audio/basic_time_series_calculators.h b/mediapipe/calculators/audio/basic_time_series_calculators.h index 3727d66b0..ef31f3448 100644 --- a/mediapipe/calculators/audio/basic_time_series_calculators.h +++ b/mediapipe/calculators/audio/basic_time_series_calculators.h @@ -28,16 +28,16 @@ namespace mediapipe { class BasicTimeSeriesCalculatorBase : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; protected: // Open() calls this method to mutate the output stream header. The input // to this function will contain a copy of the input stream header, so // subclasses that do not need to mutate the header do not need to override // it. - virtual ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header); + virtual absl::Status MutateHeader(TimeSeriesHeader* output_header); // Process() calls this method on each packet to compute the output matrix. virtual Matrix ProcessMatrix(const Matrix& input_matrix) = 0; diff --git a/mediapipe/calculators/audio/mfcc_mel_calculators.cc b/mediapipe/calculators/audio/mfcc_mel_calculators.cc index 93c44b1fb..a63b9d6ea 100644 --- a/mediapipe/calculators/audio/mfcc_mel_calculators.cc +++ b/mediapipe/calculators/audio/mfcc_mel_calculators.cc @@ -66,7 +66,7 @@ std::string PortableDebugString(const TimeSeriesHeader& header) { // rows corresponding to the new feature space). class FramewiseTransformCalculatorBase : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Sequence of Matrices, each column describing a particular time frame, // each row a feature dimension, with TimeSeriesHeader. @@ -75,11 +75,11 @@ class FramewiseTransformCalculatorBase : public CalculatorBase { // Sequence of Matrices, each column describing a particular time frame, // each row a feature dimension, with TimeSeriesHeader. 
); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; int num_output_channels(void) { return num_output_channels_; } @@ -90,8 +90,8 @@ class FramewiseTransformCalculatorBase : public CalculatorBase { private: // Takes header and options, and sets up state including calling // set_num_output_channels() on the base object. - virtual ::mediapipe::Status ConfigureTransform(const TimeSeriesHeader& header, - CalculatorContext* cc) = 0; + virtual absl::Status ConfigureTransform(const TimeSeriesHeader& header, + CalculatorContext* cc) = 0; // Takes a vector corresponding to an input frame, and // perform the specific transformation to produce an output frame. @@ -102,23 +102,23 @@ class FramewiseTransformCalculatorBase : public CalculatorBase { int num_output_channels_; }; -::mediapipe::Status FramewiseTransformCalculatorBase::Open( - CalculatorContext* cc) { +absl::Status FramewiseTransformCalculatorBase::Open(CalculatorContext* cc) { TimeSeriesHeader input_header; MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( cc->Inputs().Index(0).Header(), &input_header)); - ::mediapipe::Status status = ConfigureTransform(input_header, cc); + absl::Status status = ConfigureTransform(input_header, cc); auto output_header = new TimeSeriesHeader(input_header); output_header->set_num_channels(num_output_channels_); cc->Outputs().Index(0).SetHeader(Adopt(output_header)); + cc->SetOffset(0); + return status; } -::mediapipe::Status FramewiseTransformCalculatorBase::Process( - CalculatorContext* cc) { +absl::Status FramewiseTransformCalculatorBase::Process(CalculatorContext* cc) { const Matrix& input = cc->Inputs().Index(0).Get(); const int num_frames = input.cols(); std::unique_ptr output(new Matrix(num_output_channels_, num_frames)); @@ -145,7 +145,7 @@ class FramewiseTransformCalculatorBase : public CalculatorBase { } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Calculator wrapper around the dsp/mfcc/mfcc.cc routine. @@ -170,13 +170,13 @@ class FramewiseTransformCalculatorBase : public CalculatorBase { // } class MfccCalculator : public FramewiseTransformCalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { return FramewiseTransformCalculatorBase::GetContract(cc); } private: - ::mediapipe::Status ConfigureTransform(const TimeSeriesHeader& header, - CalculatorContext* cc) override { + absl::Status ConfigureTransform(const TimeSeriesHeader& header, + CalculatorContext* cc) override { MfccCalculatorOptions mfcc_options = cc->Options(); mfcc_.reset(new audio_dsp::Mfcc()); int input_length = header.num_channels(); @@ -194,7 +194,7 @@ class MfccCalculator : public FramewiseTransformCalculatorBase { // audio_dsp::MelFilterBank needs to know this to // correctly interpret the spectrogram bins. 
if (!header.has_audio_sample_rate()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ", PortableDebugString(header))); } @@ -203,10 +203,10 @@ class MfccCalculator : public FramewiseTransformCalculatorBase { mfcc_->Initialize(input_length, header.audio_sample_rate()); if (initialized) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { - return ::mediapipe::Status(mediapipe::StatusCode::kInternal, - "Mfcc::Initialize returned uninitialized"); + return absl::Status(absl::StatusCode::kInternal, + "Mfcc::Initialize returned uninitialized"); } } @@ -228,13 +228,13 @@ REGISTER_CALCULATOR(MfccCalculator); // if you ask for too many channels. class MelSpectrumCalculator : public FramewiseTransformCalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { return FramewiseTransformCalculatorBase::GetContract(cc); } private: - ::mediapipe::Status ConfigureTransform(const TimeSeriesHeader& header, - CalculatorContext* cc) override { + absl::Status ConfigureTransform(const TimeSeriesHeader& header, + CalculatorContext* cc) override { MelSpectrumCalculatorOptions mel_spectrum_options = cc->Options(); mel_filterbank_.reset(new audio_dsp::MelFilterbank()); @@ -245,7 +245,7 @@ class MelSpectrumCalculator : public FramewiseTransformCalculatorBase { // audio_dsp::MelFilterBank needs to know this to // correctly interpret the spectrogram bins. if (!header.has_audio_sample_rate()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ", PortableDebugString(header))); } @@ -255,10 +255,10 @@ class MelSpectrumCalculator : public FramewiseTransformCalculatorBase { mel_spectrum_options.max_frequency_hertz()); if (initialized) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { - return ::mediapipe::Status(mediapipe::StatusCode::kInternal, - "mfcc::Initialize returned uninitialized"); + return absl::Status(absl::StatusCode::kInternal, + "mfcc::Initialize returned uninitialized"); } } diff --git a/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc b/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc index b2ceacf00..e7e312db9 100644 --- a/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc +++ b/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc @@ -84,7 +84,7 @@ class FramewiseTransformCalculatorTest num_samples_per_packet_ = GenerateRandomNonnegInputStream(kNumPackets); } - ::mediapipe::Status Run() { return this->RunGraph(); } + absl::Status Run() { return this->RunGraph(); } void CheckResults(int expected_num_channels) { const auto& output_header = diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.cc b/mediapipe/calculators/audio/rational_factor_resample_calculator.cc index 3ed67bd88..1a4210c30 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator.cc +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019, 2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -16,22 +16,18 @@ #include "mediapipe/calculators/audio/rational_factor_resample_calculator.h" -#include "audio/dsp/resampler_rational_factor.h" +#include "audio/dsp/resampler_q.h" -using audio_dsp::DefaultResamplingKernel; -using audio_dsp::RationalFactorResampler; using audio_dsp::Resampler; namespace mediapipe { -::mediapipe::Status RationalFactorResampleCalculator::Process( - CalculatorContext* cc) { +absl::Status RationalFactorResampleCalculator::Process(CalculatorContext* cc) { return ProcessInternal(cc->Inputs().Index(0).Get(), false, cc); } -::mediapipe::Status RationalFactorResampleCalculator::Close( - CalculatorContext* cc) { +absl::Status RationalFactorResampleCalculator::Close(CalculatorContext* cc) { if (initial_timestamp_ == Timestamp::Unstarted()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix empty_input_frame(num_channels_, 0); return ProcessInternal(empty_input_frame, true, cc); @@ -40,11 +36,8 @@ namespace mediapipe { namespace { void CopyChannelToVector(const Matrix& matrix, int channel, std::vector* vec) { - vec->clear(); - vec->reserve(matrix.cols()); - for (int sample = 0; sample < matrix.cols(); ++sample) { - vec->push_back(matrix(channel, sample)); - } + vec->resize(matrix.cols()); + Eigen::Map(vec->data(), vec->size()) = matrix.row(channel); } void CopyVectorToChannel(const std::vector& vec, Matrix* matrix, @@ -53,17 +46,14 @@ void CopyVectorToChannel(const std::vector& vec, Matrix* matrix, matrix->resize(matrix->rows(), vec.size()); } else { CHECK_EQ(vec.size(), matrix->cols()); - CHECK_LT(channel, matrix->rows()); - } - for (int sample = 0; sample < matrix->cols(); ++sample) { - (*matrix)(channel, sample) = vec[sample]; } + CHECK_LT(channel, matrix->rows()); + matrix->row(channel) = + Eigen::Map(vec.data(), vec.size()); } - } // namespace -::mediapipe::Status RationalFactorResampleCalculator::Open( - CalculatorContext* cc) { +absl::Status RationalFactorResampleCalculator::Open(CalculatorContext* cc) { RationalFactorResampleCalculatorOptions resample_options = cc->Options(); @@ -88,7 +78,7 @@ void CopyVectorToChannel(const std::vector& vec, Matrix* matrix, resample_options); if (!r) { LOG(ERROR) << "Failed to initialize resampler."; - return ::mediapipe::UnknownError("Failed to initialize resampler."); + return absl::UnknownError("Failed to initialize resampler."); } } } @@ -106,10 +96,10 @@ void CopyVectorToChannel(const std::vector& vec, Matrix* matrix, initial_timestamp_ = Timestamp::Unstarted(); check_inconsistent_timestamps_ = resample_options.check_inconsistent_timestamps(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RationalFactorResampleCalculator::ProcessInternal( +absl::Status RationalFactorResampleCalculator::ProcessInternal( const Matrix& input_frame, bool should_flush, CalculatorContext* cc) { if (initial_timestamp_ == Timestamp::Unstarted()) { initial_timestamp_ = cc->InputTimestamp(); @@ -131,7 +121,7 @@ void CopyVectorToChannel(const std::vector& vec, Matrix* matrix, *output_frame = input_frame; } else { if (!Resample(input_frame, output_frame.get(), should_flush)) { - return ::mediapipe::UnknownError("Resample() failed."); + return absl::UnknownError("Resample() failed."); } } cumulative_output_samples_ += output_frame->cols(); @@ -139,7 +129,7 @@ void CopyVectorToChannel(const std::vector& vec, Matrix* matrix, if (output_frame->cols() > 0) { cc->Outputs().Index(0).Add(output_frame.release(), output_timestamp); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } bool 
RationalFactorResampleCalculator::Resample(const Matrix& input_frame, @@ -167,25 +157,28 @@ RationalFactorResampleCalculator::ResamplerFromOptions( std::unique_ptr> resampler; const auto& rational_factor_options = options.resampler_rational_factor_options(); - std::unique_ptr kernel; + audio_dsp::QResamplerParams params; if (rational_factor_options.has_radius() && rational_factor_options.has_cutoff() && rational_factor_options.has_kaiser_beta()) { - kernel = absl::make_unique( - source_sample_rate, target_sample_rate, - rational_factor_options.radius(), rational_factor_options.cutoff(), - rational_factor_options.kaiser_beta()); - } else { - kernel = absl::make_unique(source_sample_rate, - target_sample_rate); + // Convert RationalFactorResampler kernel parameters to QResampler + // settings. + params.filter_radius_factor = + rational_factor_options.radius() * + std::min(1.0, target_sample_rate / source_sample_rate); + params.cutoff_proportion = 2 * rational_factor_options.cutoff() / + std::min(source_sample_rate, target_sample_rate); + params.kaiser_beta = rational_factor_options.kaiser_beta(); } - // Set large enough so that the resampling factor between common sample // rates (e.g. 8kHz, 16kHz, 22.05kHz, 32kHz, 44.1kHz, 48kHz) is exact, and // that any factor is represented with error less than 0.025%. - const int kMaxDenominator = 2000; - resampler = absl::make_unique>( - *kernel, kMaxDenominator); + params.max_denominator = 2000; + + // NOTE: QResampler supports multichannel resampling, so the code might be + // simplified using a single instance rather than one per channel. + resampler = absl::make_unique>( + source_sample_rate, target_sample_rate, /*num_channels=*/1, params); if (resampler != nullptr && !resampler->Valid()) { resampler = std::unique_ptr>(); } diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.h b/mediapipe/calculators/audio/rational_factor_resample_calculator.h index 745ac8f0d..325886dc7 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator.h +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019, 2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,28 +36,31 @@ namespace mediapipe { // stream's sampling rate is specified by target_sample_rate in the // RationalFactorResampleCalculatorOptions. The output time series may have // a varying number of samples per frame. +// +// NOTE: This calculator uses QResampler, despite the name, which supersedes +// RationalFactorResampler. class RationalFactorResampleCalculator : public CalculatorBase { public: struct TestAccess; - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Single input stream with TimeSeriesHeader. ); cc->Outputs().Index(0).Set( // Resampled stream with TimeSeriesHeader. ); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns FAIL if the input stream header is invalid or if the // resampler cannot be initialized. - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; // Resamples a packet of TimeSeries data. Returns FAIL if the // resampler state becomes inconsistent. 
- ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Flushes any remaining state. Returns FAIL if the resampler state // becomes inconsistent. - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; protected: typedef audio_dsp::Resampler ResamplerType; @@ -72,8 +75,8 @@ class RationalFactorResampleCalculator : public CalculatorBase { // Does Timestamp bookkeeping and resampling common to Process() and // Close(). Returns FAIL if the resampler state becomes // inconsistent. - ::mediapipe::Status ProcessInternal(const Matrix& input_frame, - bool should_flush, CalculatorContext* cc); + absl::Status ProcessInternal(const Matrix& input_frame, bool should_flush, + CalculatorContext* cc); // Uses the internal resampler_ objects to actually resample each // row of the input TimeSeries. Returns false if the resampler diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.proto b/mediapipe/calculators/audio/rational_factor_resample_calculator.proto index 6eb36e672..97d7f202c 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator.proto +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.proto @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019, 2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +// NOTE: This calculator uses QResampler, despite the name, which supersedes +// RationalFactorResampler. message RationalFactorResampleCalculatorOptions { extend CalculatorOptions { optional RationalFactorResampleCalculatorOptions ext = 259760074; @@ -27,8 +29,7 @@ message RationalFactorResampleCalculatorOptions { // stream. Required. Must be greater than 0. optional double target_sample_rate = 1; - // Parameters for initializing the RationalFactorResampler. See - // RationalFactorResampler for more details. + // Parameters for initializing QResampler. See QResampler for more details. message ResamplerRationalFactorOptions { // Kernel radius in units of input samples. optional double radius = 1; diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc b/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc index f21cff516..6ae360303 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc @@ -80,7 +80,7 @@ class RationalFactorResampleCalculatorTest } // Initializes and runs the test graph. - ::mediapipe::Status Run(double output_sample_rate) { + absl::Status Run(double output_sample_rate) { options_.set_target_sample_rate(output_sample_rate); InitializeGraph(); @@ -120,7 +120,6 @@ class RationalFactorResampleCalculatorTest // The exact number of expected samples may vary based on the implementation // of the resampler since the exact value is not an integer. - // TODO: Reduce this offset to + 1 once cl/185829520 is submitted. 
const double expected_num_output_samples = num_input_samples_ * factor; EXPECT_LE(ceil(expected_num_output_samples), num_output_samples); EXPECT_GE(ceil(expected_num_output_samples) + 11, num_output_samples); diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index 7bac73ff7..bd2234f86 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -66,7 +66,7 @@ namespace mediapipe { // analysis frame will advance from its predecessor by the same time step. class SpectrogramCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Input stream with TimeSeriesHeader. ); @@ -96,26 +96,34 @@ class SpectrogramCalculator : public CalculatorBase { ); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns FAIL if the input stream header is invalid. - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; // Outputs at most one packet consisting of a single Matrix with one or // more columns containing the spectral values from as many input frames // as are completed by the input samples. Always returns OK. - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Performs zero-padding and processing of any remaining samples // if pad_final_packet is set. // Returns OK. - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: Timestamp CurrentOutputTimestamp(CalculatorContext* cc) { if (use_local_timestamp_) { - return cc->InputTimestamp(); + const Timestamp now = cc->InputTimestamp(); + if (now == Timestamp::Done()) { + // During Close the timestamp is not available, send an estimate. + return last_local_output_timestamp_ + + round(last_completed_frames_ * frame_step_samples() * + Timestamp::kTimestampUnitsPerSecond / input_sample_rate_); + } + last_local_output_timestamp_ = now; + return now; } return CumulativeOutputTimestamp(); } @@ -138,17 +146,20 @@ class SpectrogramCalculator : public CalculatorBase { // Convert the output of the spectrogram object into a Matrix (or an // Eigen::MatrixXcf if complex-valued output is requested) and pass to // MediaPipe output. - ::mediapipe::Status ProcessVector(const Matrix& input_stream, - CalculatorContext* cc); + absl::Status ProcessVector(const Matrix& input_stream, CalculatorContext* cc); // Templated function to process either real- or complex-output spectrogram. template - ::mediapipe::Status ProcessVectorToOutput( + absl::Status ProcessVectorToOutput( const Matrix& input_stream, const OutputMatrixType postprocess_output_fn(const OutputMatrixType&), CalculatorContext* cc); + // Use the MediaPipe timestamp instead of the estimated one. Useful when the + // data is intermittent. bool use_local_timestamp_; + Timestamp last_local_output_timestamp_; + double input_sample_rate_; bool pad_final_packet_; int frame_duration_samples_; @@ -157,6 +168,9 @@ class SpectrogramCalculator : public CalculatorBase { int64 cumulative_input_samples_; // How many frames we've emitted, used for calculating output time stamps. 
int64 cumulative_completed_frames_; + // How many frames were emitted last, used for estimating the timestamp on + // Close when use_local_timestamp_ is true; + int64 last_completed_frames_; Timestamp initial_input_timestamp_; int num_input_channels_; // How many frequency bins we emit (=N_FFT/2 + 1). @@ -177,7 +191,7 @@ REGISTER_CALCULATOR(SpectrogramCalculator); // Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0). const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518; -::mediapipe::Status SpectrogramCalculator::Open(CalculatorContext* cc) { +absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { SpectrogramCalculatorOptions spectrogram_options = cc->Options(); @@ -271,11 +285,20 @@ const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518; Adopt(multichannel_output_header.release())); } cumulative_completed_frames_ = 0; + last_completed_frames_ = 0; initial_input_timestamp_ = Timestamp::Unstarted(); - return ::mediapipe::OkStatus(); + if (use_local_timestamp_) { + // Inform the framework that the calculator will output packets at the same + // timestamps as input packets to enable packet queueing optimizations. The + // final packet (emitted from Close()) does not follow this rule but it's + // sufficient that its timestamp is strictly greater than the timestamp of + // the previous packet. + cc->SetOffset(0); + } + return absl::OkStatus(); } -::mediapipe::Status SpectrogramCalculator::Process(CalculatorContext* cc) { +absl::Status SpectrogramCalculator::Process(CalculatorContext* cc) { if (initial_input_timestamp_ == Timestamp::Unstarted()) { initial_input_timestamp_ = cc->InputTimestamp(); } @@ -291,7 +314,7 @@ const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518; } template -::mediapipe::Status SpectrogramCalculator::ProcessVectorToOutput( +absl::Status SpectrogramCalculator::ProcessVectorToOutput( const Matrix& input_stream, const OutputMatrixType postprocess_output_fn(const OutputMatrixType&), CalculatorContext* cc) { @@ -311,8 +334,8 @@ template if (!spectrogram_generators_[channel]->ComputeSpectrogram( input_vector, &output_vectors)) { - return ::mediapipe::Status(mediapipe::StatusCode::kInternal, - "Spectrogram returned failure"); + return absl::Status(absl::StatusCode::kInternal, + "Spectrogram returned failure"); } if (channel == 0) { // Record the number of time frames we expect from each channel. @@ -354,12 +377,19 @@ template CurrentOutputTimestamp(cc)); } cumulative_completed_frames_ += output_vectors.size(); + last_completed_frames_ = output_vectors.size(); + if (!use_local_timestamp_) { + // In non-local timestamp mode the timestamp of the next packet will be + // equal to CumulativeOutputTimestamp(). Inform the framework about this + // fact to enable packet queueing optimizations. + cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp()); + } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SpectrogramCalculator::ProcessVector( - const Matrix& input_stream, CalculatorContext* cc) { +absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream, + CalculatorContext* cc) { switch (output_type_) { // These blocks deliberately ignore clang-format to preserve the // "silhouette" of the different cases. 
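Note on the new Close-time timestamp estimate: when use_local_timestamp_ is set, the packet flushed from Close() has no input timestamp to copy, so SpectrogramCalculator now advances the last locally observed output timestamp by the duration of the frames emitted in the previous Process() call, converted to timestamp units (MediaPipe timestamps are in microseconds). A standalone sketch of that arithmetic with hypothetical values:

    #include <cmath>
    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical values: 2 frames were emitted last, the frame step is
      // 160 samples, and the input runs at 16 kHz.
      const int64_t last_completed_frames = 2;
      const double frame_step_samples = 160.0;
      const double input_sample_rate = 16000.0;
      const double kTimestampUnitsPerSecond = 1e6;  // microsecond timestamps

      const int64_t last_local_output_timestamp = 1'000'000;  // hypothetical
      // Advance by the time span those frames cover, in timestamp units.
      const int64_t estimated_close_timestamp =
          last_local_output_timestamp +
          static_cast<int64_t>(std::round(last_completed_frames *
                                          frame_step_samples *
                                          kTimestampUnitsPerSecond /
                                          input_sample_rate));
      // 2 * 160 / 16000 s = 20 ms, so this prints 1020000.
      std::cout << estimated_close_timestamp << "\n";
      return 0;
    }

The estimate only needs to be strictly greater than the previous packet's timestamp, which this advance guarantees whenever at least one frame was emitted.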
@@ -394,13 +424,13 @@ template } // clang-format on default: { - return ::mediapipe::Status(mediapipe::StatusCode::kInvalidArgument, - "Unrecognized spectrogram output type."); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Unrecognized spectrogram output type."); } } } -::mediapipe::Status SpectrogramCalculator::Close(CalculatorContext* cc) { +absl::Status SpectrogramCalculator::Close(CalculatorContext* cc) { if (cumulative_input_samples_ > 0 && pad_final_packet_) { // We can flush any remaining samples by sending frame_step_samples - 1 // zeros to the Process method, and letting it do its thing, @@ -416,7 +446,7 @@ template Matrix::Zero(num_input_channels_, required_padding_samples), cc); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/audio/spectrogram_calculator_test.cc b/mediapipe/calculators/audio/spectrogram_calculator_test.cc index 200bdee11..3c2b8435d 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator_test.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator_test.cc @@ -50,7 +50,7 @@ class SpectrogramCalculatorTest } // Initializes and runs the test graph. - ::mediapipe::Status Run() { + absl::Status Run() { // Now that options are set, we can set up some internal constants. frame_duration_samples_ = round(options_.frame_duration_seconds() * input_sample_rate_); diff --git a/mediapipe/calculators/audio/stabilized_log_calculator.cc b/mediapipe/calculators/audio/stabilized_log_calculator.cc index b5623ee0f..0c697a196 100644 --- a/mediapipe/calculators/audio/stabilized_log_calculator.cc +++ b/mediapipe/calculators/audio/stabilized_log_calculator.cc @@ -41,17 +41,17 @@ namespace mediapipe { // } class StabilizedLogCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Input stream with TimeSeriesHeader. ); cc->Outputs().Index(0).Set( // Output stabilized log stream with TimeSeriesHeader. 
); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { StabilizedLogCalculatorOptions stabilized_log_calculator_options = cc->Options(); @@ -70,23 +70,23 @@ class StabilizedLogCalculator : public CalculatorBase { cc->Outputs().Index(0).SetHeader( Adopt(new TimeSeriesHeader(input_header))); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { auto input_matrix = cc->Inputs().Index(0).Get(); if (input_matrix.array().isNaN().any()) { - return ::mediapipe::InvalidArgumentError("NaN input to log operation."); + return absl::InvalidArgumentError("NaN input to log operation."); } if (check_nonnegativity_) { if (input_matrix.minCoeff() < 0.0) { - return ::mediapipe::OutOfRangeError("Negative input to log operation."); + return absl::OutOfRangeError("Negative input to log operation."); } } std::unique_ptr output_frame(new Matrix( output_scale_ * (input_matrix.array() + stabilizer_).log().matrix())); cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.cc b/mediapipe/calculators/audio/time_series_framer_calculator.cc index 04f593bca..fbbf34226 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator.cc @@ -66,26 +66,26 @@ namespace mediapipe { // cumulative_completed_samples / sample_rate_. class TimeSeriesFramerCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Input stream with TimeSeriesHeader. ); cc->Outputs().Index(0).Set( // Fixed length time series Packets with TimeSeriesHeader. ); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns FAIL if the input stream header is invalid. - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; // Outputs as many framed packets as possible given the accumulated // input. Always returns OK. - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Flushes any remaining samples in a zero-padded packet. Always // returns OK. - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: // Adds input data to the internal buffer. @@ -134,7 +134,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase { // emulate_fractional_frame_overlap is true. 
double average_frame_step_samples_; int samples_still_to_drop_; - int64 cumulative_input_samples_; int64 cumulative_output_frames_; // "Completed" samples are samples that are no longer needed because // the framer has completely stepped past them (taking into account @@ -163,8 +162,6 @@ void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) { sample_buffer_.emplace_back(std::make_pair( input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i))); } - - cumulative_input_samples_ += input_frame.cols(); } void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { @@ -203,9 +200,15 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { ++cumulative_output_frames_; cumulative_completed_samples_ += frame_step_samples; } + if (!use_local_timestamp_) { + // In non-local timestamp mode the timestamp of the next packet will be + // equal to CumulativeOutputTimestamp(). Inform the framework about this + // fact to enable packet queueing optimizations. + cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp()); + } } -::mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { +absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { if (initial_input_timestamp_ == Timestamp::Unstarted()) { initial_input_timestamp_ = cc->InputTimestamp(); current_timestamp_ = initial_input_timestamp_; @@ -214,10 +217,10 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { EnqueueInput(cc); FrameOutput(cc); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) { +absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) { while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) { sample_buffer_.pop_front(); --samples_still_to_drop_; @@ -234,10 +237,10 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { CurrentOutputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { +absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { TimeSeriesFramerCalculatorOptions framer_options = cc->Options(); @@ -286,7 +289,6 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { } cc->Outputs().Index(0).SetHeader(Adopt(output_header)); cumulative_completed_samples_ = 0; - cumulative_input_samples_ = 0; cumulative_output_frames_ = 0; samples_still_to_drop_ = 0; initial_input_timestamp_ = Timestamp::Unstarted(); @@ -317,7 +319,7 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { } use_local_timestamp_ = framer_options.use_local_timestamp(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc index cd0c38e13..ca88cebb5 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc @@ -69,7 +69,7 @@ class TimeSeriesFramerCalculatorTest } // Initializes and runs the test graph. 
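Note on the SetNextTimestampBound additions: in non-local timestamp mode both audio calculators now promise the framework the earliest timestamp of their next output packet, so downstream nodes can be scheduled instead of waiting on this stream. A minimal sketch of the pattern in an ordinary calculator, assuming the MediaPipe framework headers (the calculator name is illustrative):

    #include "mediapipe/framework/calculator_framework.h"

    namespace mediapipe {

    // Forwards each packet and then tells the framework that nothing earlier
    // than the next allowed timestamp will ever appear on the output stream.
    class ExampleTimestampBoundCalculator : public CalculatorBase {
     public:
      static absl::Status GetContract(CalculatorContract* cc) {
        cc->Inputs().Index(0).SetAny();
        cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
        return absl::OkStatus();
      }

      absl::Status Process(CalculatorContext* cc) override {
        cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
        // Publish a lower bound for the next output timestamp; this is what
        // enables the packet queueing optimizations mentioned in the comments.
        cc->Outputs().Index(0).SetNextTimestampBound(
            cc->InputTimestamp().NextAllowedInStream());
        return absl::OkStatus();
      }
    };
    REGISTER_CALCULATOR(ExampleTimestampBoundCalculator);

    }  // namespace mediapipe

In the framer and spectrogram calculators the bound is the cumulative output timestamp rather than the input timestamp, because output frames are paced by sample count, not by input packet arrival.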
- ::mediapipe::Status Run() { + absl::Status Run() { InitializeGraph(); FillInputHeader(); @@ -441,7 +441,7 @@ class TimeSeriesFramerCalculatorTimestampingTest } } - ::mediapipe::Status RunTimestampTest() { + absl::Status RunTimestampTest() { InitializeGraph(); InitializeInputForTimeStampingTest(); FillInputHeader(); diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 7f9ffd7f8..61d402f74 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -13,181 +13,131 @@ # limitations under the License. # -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -proto_library( +mediapipe_proto_library( name = "concatenate_vector_calculator_proto", srcs = ["concatenate_vector_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "dequantize_byte_array_calculator_proto", srcs = ["dequantize_byte_array_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "packet_cloner_calculator_proto", srcs = ["packet_cloner_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "packet_resampler_calculator_proto", srcs = ["packet_resampler_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "packet_thinner_calculator_proto", srcs = ["packet_thinner_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "split_vector_calculator_proto", srcs = ["split_vector_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "quantize_float_vector_calculator_proto", srcs = ["quantize_float_vector_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "sequence_shift_calculator_proto", srcs = ["sequence_shift_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "gate_calculator_proto", srcs = ["gate_calculator.proto"], 
visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "constant_side_packet_calculator_proto", srcs = ["constant_side_packet_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/framework/formats:classification_proto", + ], +) + +mediapipe_proto_library( + name = "clip_vector_size_calculator_proto", + srcs = ["clip_vector_size_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( - name = "clip_vector_size_calculator_proto", - srcs = ["clip_vector_size_calculator.proto"], +mediapipe_proto_library( + name = "flow_limiter_calculator_proto", + srcs = ["flow_limiter_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "packet_cloner_calculator_cc_proto", - srcs = ["packet_cloner_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":packet_cloner_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "packet_resampler_calculator_cc_proto", - srcs = ["packet_resampler_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":packet_resampler_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "packet_thinner_calculator_cc_proto", - srcs = ["packet_thinner_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":packet_thinner_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "split_vector_calculator_cc_proto", - srcs = ["split_vector_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":split_vector_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "concatenate_vector_calculator_cc_proto", - srcs = ["concatenate_vector_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":concatenate_vector_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "clip_vector_size_calculator_cc_proto", - srcs = ["clip_vector_size_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":clip_vector_size_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "dequantize_byte_array_calculator_cc_proto", - srcs = ["dequantize_byte_array_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":dequantize_byte_array_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "quantize_float_vector_calculator_cc_proto", - srcs = ["quantize_float_vector_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":quantize_float_vector_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "sequence_shift_calculator_cc_proto", - srcs = ["sequence_shift_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = 
[":sequence_shift_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "gate_calculator_cc_proto", - srcs = ["gate_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":gate_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "constant_side_packet_calculator_cc_proto", - srcs = ["constant_side_packet_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":constant_side_packet_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_library( @@ -196,6 +146,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", ], @@ -286,15 +237,28 @@ cc_library( name = "concatenate_vector_calculator", srcs = ["concatenate_vector_calculator.cc"], hdrs = ["concatenate_vector_calculator.h"], + copts = select({ + # Needed for "//mediapipe/framework/formats:tensor" compatibility on Apple + # platforms for Metal pulled in via the tensor.h header. + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework:calculator_framework", + "//mediapipe/util:render_data_cc_proto", "@org_tensorflow//tensorflow/lite:framework", ] + select({ "//mediapipe/gpu:disable_gpu": [], @@ -325,6 +289,7 @@ cc_library( deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -432,6 +397,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:status", ], alwayslink = 1, @@ -445,7 +411,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", "@eigen_archive//:eigen", @@ -461,7 +427,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", "@eigen_archive//:eigen", @@ -477,6 +443,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/stream_handler:mux_input_stream_handler", ], @@ -589,6 +556,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", ], alwayslink = 1, @@ -645,6 +613,7 @@ cc_library( 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", @@ -657,6 +626,7 @@ cc_library( srcs = ["flow_limiter_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", "//mediapipe/framework:timestamp", @@ -832,6 +802,7 @@ cc_test( srcs = ["flow_limiter_calculator_test.cc"], deps = [ ":flow_limiter_calculator", + ":flow_limiter_calculator_cc_proto", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", @@ -843,6 +814,8 @@ cc_test( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:simulation_clock", + "//mediapipe/framework/tool:simulation_clock_executor", "//mediapipe/framework/tool:sink", "@com_google_absl//absl/time", ], @@ -852,14 +825,23 @@ cc_library( name = "split_vector_calculator", srcs = ["split_vector_calculator.cc"], hdrs = ["split_vector_calculator.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/util:resource_util", @@ -984,6 +966,7 @@ cc_library( deps = [ ":sequence_shift_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:status", ], alwayslink = 1, @@ -1033,6 +1016,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1068,6 +1052,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", ], @@ -1121,6 +1106,7 @@ cc_library( ":constant_side_packet_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/calculators/core/add_header_calculator.cc b/mediapipe/calculators/core/add_header_calculator.cc index 393c12225..1c636afd0 100644 --- a/mediapipe/calculators/core/add_header_calculator.cc +++ b/mediapipe/calculators/core/add_header_calculator.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under 
the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { +namespace api2 { // Attach the header from a stream or side input to another stream. // @@ -42,49 +44,41 @@ namespace mediapipe { // output_stream: "audio_with_header" // } // -class AddHeaderCalculator : public CalculatorBase { +class AddHeaderCalculator : public Node { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - bool has_side_input = false; - bool has_header_stream = false; - if (cc->InputSidePackets().HasTag("HEADER")) { - cc->InputSidePackets().Tag("HEADER").SetAny(); - has_side_input = true; - } - if (cc->Inputs().HasTag("HEADER")) { - cc->Inputs().Tag("HEADER").SetNone(); - has_header_stream = true; - } - if (has_side_input == has_header_stream) { - return mediapipe::InvalidArgumentError( + static constexpr Input::Optional kHeader{"HEADER"}; + static constexpr SideInput::Optional kHeaderSide{"HEADER"}; + static constexpr Input kData{"DATA"}; + static constexpr Output> kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kHeader, kHeaderSide, kData, kOut); + + static absl::Status UpdateContract(CalculatorContract* cc) { + if (kHeader(cc).IsConnected() == kHeaderSide(cc).IsConnected()) { + return absl::InvalidArgumentError( "Header must be provided via exactly one of side input and input " "stream"); } - cc->Inputs().Tag("DATA").SetAny(); - cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Tag("DATA")); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { - Packet header; - if (cc->InputSidePackets().HasTag("HEADER")) { - header = cc->InputSidePackets().Tag("HEADER"); - } - if (cc->Inputs().HasTag("HEADER")) { - header = cc->Inputs().Tag("HEADER").Header(); - } + absl::Status Open(CalculatorContext* cc) override { + const PacketBase& header = + kHeader(cc).IsConnected() ? kHeader(cc).Header() : kHeaderSide(cc); if (!header.IsEmpty()) { - cc->Outputs().Index(0).SetHeader(header); + kOut(cc).SetHeader(header); } - cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + cc->SetOffset(0); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - cc->Outputs().Index(0).AddPacket(cc->Inputs().Tag("DATA").Value()); - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + kOut(cc).Send(kData(cc).packet()); + return absl::OkStatus(); } }; -REGISTER_CALCULATOR(AddHeaderCalculator); +MEDIAPIPE_REGISTER_NODE(AddHeaderCalculator); + +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/add_header_calculator_test.cc b/mediapipe/calculators/core/add_header_calculator_test.cc index 01ea986f1..4e197918d 100644 --- a/mediapipe/calculators/core/add_header_calculator_test.cc +++ b/mediapipe/calculators/core/add_header_calculator_test.cc @@ -153,7 +153,7 @@ TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) { } // Run should fail because header can only be provided one way. 
- EXPECT_EQ(runner.Run().code(), ::mediapipe::InvalidArgumentError("").code()); + EXPECT_EQ(runner.Run().code(), absl::InvalidArgumentError("").code()); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc b/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc index 716151b69..b627e5b23 100644 --- a/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc +++ b/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc @@ -42,22 +42,22 @@ REGISTER_CALCULATOR(BeginLoopIntegerCalculator); class IncrementCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const int& input_int = cc->Inputs().Index(0).Get(); auto output_int = absl::make_unique(input_int + 1); cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; @@ -166,19 +166,19 @@ TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) { // bound update. class PassThroughOrEmptyVectorCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->SetProcessTimestampBounds(true); cc->Inputs().Index(0).Set>(); cc->Outputs().Index(0).Set>(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (!cc->Inputs().Index(0).IsEmpty()) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); } else { @@ -186,7 +186,7 @@ class PassThroughOrEmptyVectorCalculator : public CalculatorBase { MakePacket>(std::vector()) .At(cc->InputTimestamp())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; @@ -311,24 +311,24 @@ TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest, MultipleVectors) { class MultiplierCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Inputs().Index(1).Set(); cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const int& input_int = cc->Inputs().Index(0).Get(); const int& multiplier_int = cc->Inputs().Index(1).Get(); auto output_int = absl::make_unique(input_int * multiplier_int); 
cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/core/begin_loop_calculator.h b/mediapipe/calculators/core/begin_loop_calculator.h index ec59e1012..a9d29e687 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.h +++ b/mediapipe/calculators/core/begin_loop_calculator.h @@ -61,7 +61,7 @@ class BeginLoopCalculator : public CalculatorBase { using ItemT = typename IterableT::value_type; public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { // The below enables processing of timestamp bound updates, and that enables // correct timestamp propagation by the companion EndLoopCalculator. // @@ -106,10 +106,10 @@ class BeginLoopCalculator : public CalculatorBase { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { Timestamp last_timestamp = loop_internal_timestamp_; if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) { const IterableT& collection = @@ -139,7 +139,7 @@ class BeginLoopCalculator : public CalculatorBase { .AddPacket(MakePacket(cc->InputTimestamp()) .At(Timestamp(loop_internal_timestamp_ - 1))); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/clip_vector_size_calculator.h b/mediapipe/calculators/core/clip_vector_size_calculator.h index 89a4945d4..00de9be7f 100644 --- a/mediapipe/calculators/core/clip_vector_size_calculator.h +++ b/mediapipe/calculators/core/clip_vector_size_calculator.h @@ -33,7 +33,7 @@ namespace mediapipe { // input_stream: "input_vector" // output_stream: "output_vector" // options { -// [mediapipe.ClipIntVectorSizeCalculatorOptions.ext] { +// [mediapipe.ClipVectorSizeCalculatorOptions.ext] { // max_vec_size: 5 // } // } @@ -43,13 +43,13 @@ namespace mediapipe { template class ClipVectorSizeCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == 1); RET_CHECK(cc->Outputs().NumEntries() == 1); if (cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>() .max_vec_size() < 1) { - return ::mediapipe::InternalError( + return absl::InternalError( "max_vec_size should be greater than or equal to 1."); } @@ -60,10 +60,10 @@ class ClipVectorSizeCalculator : public CalculatorBase { cc->InputSidePackets().Index(0).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); max_vec_size_ = cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>() .max_vec_size(); @@ -72,23 +72,23 @@ class ClipVectorSizeCalculator : public CalculatorBase { !cc->InputSidePackets().Index(0).IsEmpty()) { max_vec_size_ = cc->InputSidePackets().Index(0).Get(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (max_vec_size_ < 1) { - return ::mediapipe::InternalError( + return absl::InternalError( "max_vec_size should be greater than or equal to 1."); } if (cc->Inputs().Index(0).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } 
return ClipVectorSize(std::is_copy_constructible(), cc); } template - ::mediapipe::Status ClipVectorSize(std::true_type, CalculatorContext* cc) { + absl::Status ClipVectorSize(std::true_type, CalculatorContext* cc) { auto output = absl::make_unique>(); const std::vector& input_vector = cc->Inputs().Index(0).Get>(); @@ -100,24 +100,23 @@ class ClipVectorSizeCalculator : public CalculatorBase { } } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } template - ::mediapipe::Status ClipVectorSize(std::false_type, CalculatorContext* cc) { + absl::Status ClipVectorSize(std::false_type, CalculatorContext* cc) { return ConsumeAndClipVectorSize(std::is_move_constructible(), cc); } template - ::mediapipe::Status ConsumeAndClipVectorSize(std::true_type, - CalculatorContext* cc) { + absl::Status ConsumeAndClipVectorSize(std::true_type, CalculatorContext* cc) { auto output = absl::make_unique>(); - ::mediapipe::StatusOr>> input_status = + absl::StatusOr>> input_status = cc->Inputs().Index(0).Value().Consume>(); if (input_status.ok()) { std::unique_ptr> input_vector = - std::move(input_status).ValueOrDie(); + std::move(input_status).value(); auto begin_it = input_vector->begin(); auto end_it = input_vector->end(); if (max_vec_size_ < input_vector->size()) { @@ -129,13 +128,13 @@ class ClipVectorSizeCalculator : public CalculatorBase { return input_status.status(); } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } template - ::mediapipe::Status ConsumeAndClipVectorSize(std::false_type, - CalculatorContext* cc) { - return ::mediapipe::InternalError( + absl::Status ConsumeAndClipVectorSize(std::false_type, + CalculatorContext* cc) { + return absl::InternalError( "Cannot copy or move input vectors and clip their size."); } diff --git a/mediapipe/calculators/core/clip_vector_size_calculator.proto b/mediapipe/calculators/core/clip_vector_size_calculator.proto index 5dea660d6..6044f77c8 100644 --- a/mediapipe/calculators/core/clip_vector_size_calculator.proto +++ b/mediapipe/calculators/core/clip_vector_size_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message ClipVectorSizeCalculatorOptions { extend CalculatorOptions { optional ClipVectorSizeCalculatorOptions ext = 274674998; diff --git a/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc b/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc index 161a323cf..fd7d324b2 100644 --- a/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc +++ b/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019-2020 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
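Note on the StatusOr changes above: the deprecated ValueOrDie() accessor is replaced by absl::StatusOr's value(), and calling value() on an rvalue moves the owned object out, which is exactly what the consume path in ClipVectorSizeCalculator relies on. A small self-contained sketch of that pattern, using a stand-in for Packet::Consume (the function name here is hypothetical):

    #include <iostream>
    #include <memory>
    #include <utility>
    #include <vector>

    #include "absl/status/statusor.h"

    // Stand-in for Packet::Consume<std::vector<int>>(): either hands over
    // ownership of the stored vector or reports why it cannot.
    absl::StatusOr<std::unique_ptr<std::vector<int>>> ConsumeVector() {
      return std::make_unique<std::vector<int>>(std::vector<int>{1, 2, 3, 4, 5});
    }

    int main() {
      absl::StatusOr<std::unique_ptr<std::vector<int>>> input_status =
          ConsumeVector();
      if (!input_status.ok()) {
        std::cerr << input_status.status() << "\n";
        return 1;
      }
      // std::move(...).value() moves the owned pointer out; this replaces the
      // old std::move(...).ValueOrDie() spelling.
      std::unique_ptr<std::vector<int>> input = std::move(input_status).value();
      std::cout << "consumed " << input->size() << " elements\n";
      return 0;
    }
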
@@ -20,14 +20,16 @@ namespace mediapipe { // Example config: +// // node { // calculator: "ConcatenateDetectionVectorCalculator" // input_stream: "detection_vector_1" // input_stream: "detection_vector_2" // output_stream: "concatenated_detection_vector" // } +// typedef ConcatenateVectorCalculator<::mediapipe::Detection> ConcatenateDetectionVectorCalculator; -REGISTER_CALCULATOR(ConcatenateDetectionVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateDetectionVectorCalculator); } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc b/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc index 54c3e05b9..f0a4043a7 100644 --- a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc +++ b/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc @@ -16,6 +16,7 @@ #define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -23,61 +24,55 @@ #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // Concatenates several NormalizedLandmarkList protos following stream index // order. This class assumes that every input stream contains a // NormalizedLandmarkList proto object. -class ConcatenateNormalizedLandmarkListCalculator : public CalculatorBase { +class ConcatenateNormalizedLandmarkListCalculator : public Node { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().NumEntries() != 0); - RET_CHECK(cc->Outputs().NumEntries() == 1); + static constexpr Input::Multiple kIn{""}; + static constexpr Output kOut{""}; - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - cc->Inputs().Index(i).Set(); - } + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); - cc->Outputs().Index(0).Set(); - - return ::mediapipe::OkStatus(); + static absl::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_GE(kIn(cc).Count(), 1); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { - cc->SetOffset(TimestampDiff(0)); + absl::Status Open(CalculatorContext* cc) override { only_emit_if_all_present_ = cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>() .only_emit_if_all_present(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (only_emit_if_all_present_) { - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - if (cc->Inputs().Index(i).IsEmpty()) return ::mediapipe::OkStatus(); + for (const auto& input : kIn(cc)) { + if (input.IsEmpty()) return absl::OkStatus(); } } NormalizedLandmarkList output; - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - if (cc->Inputs().Index(i).IsEmpty()) continue; - const NormalizedLandmarkList& input = - cc->Inputs().Index(i).Get(); - for (int j = 0; j < input.landmark_size(); ++j) { - const NormalizedLandmark& input_landmark = input.landmark(j); - *output.add_landmark() = input_landmark; + for (const auto& input : kIn(cc)) { + if (input.IsEmpty()) continue; + const NormalizedLandmarkList& list = *input; + for (int j = 0; j < list.landmark_size(); ++j) { + *output.add_landmark() = list.landmark(j); } 
} - cc->Outputs().Index(0).AddPacket( - MakePacket(output).At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + kOut(cc).Send(std::move(output)); + return absl::OkStatus(); } private: bool only_emit_if_all_present_; }; +MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListCalculator); -REGISTER_CALCULATOR(ConcatenateNormalizedLandmarkListCalculator); - +} // namespace api2 } // namespace mediapipe // NOLINTNEXTLINE diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.cc b/mediapipe/calculators/core/concatenate_vector_calculator.cc index c57f84f1e..20d6a3286 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator.cc @@ -18,12 +18,14 @@ #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/util/render_data.pb.h" #include "tensorflow/lite/interpreter.h" #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) namespace mediapipe { @@ -35,7 +37,7 @@ namespace mediapipe { // output_stream: "concatenated_float_vector" // } typedef ConcatenateVectorCalculator ConcatenateFloatVectorCalculator; -REGISTER_CALCULATOR(ConcatenateFloatVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateFloatVectorCalculator); // Example config: // node { @@ -45,10 +47,13 @@ REGISTER_CALCULATOR(ConcatenateFloatVectorCalculator); // output_stream: "concatenated_int32_vector" // } typedef ConcatenateVectorCalculator ConcatenateInt32VectorCalculator; -REGISTER_CALCULATOR(ConcatenateInt32VectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateInt32VectorCalculator); typedef ConcatenateVectorCalculator ConcatenateUInt64VectorCalculator; -REGISTER_CALCULATOR(ConcatenateUInt64VectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator); + +typedef ConcatenateVectorCalculator ConcatenateBoolVectorCalculator; +MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator); // Example config: // node { @@ -59,24 +64,31 @@ REGISTER_CALCULATOR(ConcatenateUInt64VectorCalculator); // } typedef ConcatenateVectorCalculator ConcatenateTfLiteTensorVectorCalculator; -REGISTER_CALCULATOR(ConcatenateTfLiteTensorVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateTfLiteTensorVectorCalculator); + +typedef ConcatenateVectorCalculator ConcatenateTensorVectorCalculator; +MEDIAPIPE_REGISTER_NODE(ConcatenateTensorVectorCalculator); typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark> ConcatenateLandmarkVectorCalculator; -REGISTER_CALCULATOR(ConcatenateLandmarkVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkVectorCalculator); typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList> ConcatenateLandmarListVectorCalculator; -REGISTER_CALCULATOR(ConcatenateLandmarListVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarListVectorCalculator); typedef ConcatenateVectorCalculator ConcatenateClassificationListVectorCalculator; -REGISTER_CALCULATOR(ConcatenateClassificationListVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListVectorCalculator); #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) typedef ConcatenateVectorCalculator<::tflite::gpu::gl::GlBuffer> ConcatenateGlBufferVectorCalculator; -REGISTER_CALCULATOR(ConcatenateGlBufferVectorCalculator); 
+MEDIAPIPE_REGISTER_NODE(ConcatenateGlBufferVectorCalculator); #endif +typedef ConcatenateVectorCalculator + ConcatenateRenderDataVectorCalculator; +MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.h b/mediapipe/calculators/core/concatenate_vector_calculator.h index ef72cb0dc..c6687814c 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.h +++ b/mediapipe/calculators/core/concatenate_vector_calculator.h @@ -20,122 +20,96 @@ #include #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +// Note: since this is a calculator template that can be included by other +// source files, we do not place this in namespace api2 directly, but qualify +// the api2 names below, to avoid changing the visible name of the class. +// We cannot simply write "using mediapipe::api2" since it's a header file. +// This distinction will go away once api2 is finalized. // Concatenates several objects of type T or std::vector following stream // index order. This class assumes that every input stream contains either T or // vector type. To use this class for a particular type T, regisiter a // calculator using ConcatenateVectorCalculator. template -class ConcatenateVectorCalculator : public CalculatorBase { +class ConcatenateVectorCalculator : public api2::Node { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().NumEntries() != 0); - RET_CHECK(cc->Outputs().NumEntries() == 1); + static constexpr + typename api2::Input>>::Multiple kIn{""}; + static constexpr api2::Output> kOut{""}; - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - // Actual type T or vector will be validated in Process(). 
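Note on the registration pattern used above: a concrete concatenation calculator is produced by typedef-ing the ConcatenateVectorCalculator template for an element type and registering it, and with the api2 conversion that registration goes through MEDIAPIPE_REGISTER_NODE. A sketch for a hypothetical project-specific type (MyFeature and the resulting node name are illustrative, not part of this change):

    #include "mediapipe/calculators/core/concatenate_vector_calculator.h"

    namespace mediapipe {

    // Hypothetical element type; any copyable or movable type works because
    // the template dispatches on std::is_copy_constructible and
    // std::is_move_constructible.
    struct MyFeature {
      float score = 0.f;
    };

    // Instantiate the template and register it; graphs can then reference
    // calculator: "ConcatenateMyFeatureVectorCalculator".
    typedef ConcatenateVectorCalculator<MyFeature>
        ConcatenateMyFeatureVectorCalculator;
    MEDIAPIPE_REGISTER_NODE(ConcatenateMyFeatureVectorCalculator);

    }  // namespace mediapipe

Each input stream may carry either a single MyFeature or a std::vector of them; the api2 ports handle that OneOf dispatch that the old SetAny()/ValidateAsType code did by hand.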
- cc->Inputs().Index(i).SetAny(); - } + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); - cc->Outputs().Index(0).Set>(); - - return ::mediapipe::OkStatus(); + static absl::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_GE(kIn(cc).Count(), 1); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { - cc->SetOffset(TimestampDiff(0)); + absl::Status Open(CalculatorContext* cc) override { only_emit_if_all_present_ = cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>() .only_emit_if_all_present(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (only_emit_if_all_present_) { - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - if (cc->Inputs().Index(i).IsEmpty()) return ::mediapipe::OkStatus(); + for (const auto& input : kIn(cc)) { + if (input.IsEmpty()) return ::absl::OkStatus(); } } - return ConcatenateVectors(std::is_copy_constructible(), cc); } template - ::mediapipe::Status ConcatenateVectors(std::true_type, - CalculatorContext* cc) { - auto output = absl::make_unique>(); - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - auto& input = cc->Inputs().Index(i); - + absl::Status ConcatenateVectors(std::true_type, CalculatorContext* cc) { + auto output = std::vector(); + for (const auto& input : kIn(cc)) { if (input.IsEmpty()) continue; - - if (input.Value().ValidateAsType().ok()) { - const U& value = input.Get(); - output->push_back(value); - } else if (input.Value().ValidateAsType>().ok()) { - const std::vector& value = input.Get>(); - output->insert(output->end(), value.begin(), value.end()); - } else { - return ::mediapipe::InvalidArgumentError("Invalid input stream type."); - } + input.Visit([&output](const U& value) { output.push_back(value); }, + [&output](const std::vector& value) { + output.insert(output.end(), value.begin(), value.end()); + }); } - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + kOut(cc).Send(std::move(output)); + return absl::OkStatus(); } template - ::mediapipe::Status ConcatenateVectors(std::false_type, - CalculatorContext* cc) { + absl::Status ConcatenateVectors(std::false_type, CalculatorContext* cc) { return ConsumeAndConcatenateVectors(std::is_move_constructible(), cc); } template - ::mediapipe::Status ConsumeAndConcatenateVectors(std::true_type, - CalculatorContext* cc) { - auto output = absl::make_unique>(); - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - auto& input = cc->Inputs().Index(i); - + absl::Status ConsumeAndConcatenateVectors(std::true_type, + CalculatorContext* cc) { + auto output = std::vector(); + for (auto input : kIn(cc)) { if (input.IsEmpty()) continue; - - if (input.Value().ValidateAsType().ok()) { - ::mediapipe::StatusOr> value_status = - input.Value().Consume(); - if (value_status.ok()) { - std::unique_ptr value = std::move(value_status).ValueOrDie(); - output->push_back(std::move(*value)); - } else { - return value_status.status(); - } - } else if (input.Value().ValidateAsType>().ok()) { - ::mediapipe::StatusOr>> value_status = - input.Value().Consume>(); - if (value_status.ok()) { - std::unique_ptr> value = - std::move(value_status).ValueOrDie(); - output->insert(output->end(), std::make_move_iterator(value->begin()), - std::make_move_iterator(value->end())); - } else { - return value_status.status(); - } - } else { - return ::mediapipe::InvalidArgumentError("Invalid input stream 
type."); - } + MP_RETURN_IF_ERROR(input.ConsumeAndVisit( + [&output](std::unique_ptr value) { + output.push_back(std::move(*value)); + }, + [&output](std::unique_ptr> value) { + output.insert(output.end(), std::make_move_iterator(value->begin()), + std::make_move_iterator(value->end())); + })); } - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + kOut(cc).Send(std::move(output)); + return absl::OkStatus(); } template - ::mediapipe::Status ConsumeAndConcatenateVectors(std::false_type, - CalculatorContext* cc) { - return ::mediapipe::InternalError( + absl::Status ConsumeAndConcatenateVectors(std::false_type, + CalculatorContext* cc) { + return absl::InternalError( "Cannot copy or move inputs to concatenate them"); } diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.proto b/mediapipe/calculators/core/concatenate_vector_calculator.proto index bddb8af95..3753ffb5d 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.proto +++ b/mediapipe/calculators/core/concatenate_vector_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message ConcatenateVectorCalculatorOptions { extend CalculatorOptions { optional ConcatenateVectorCalculatorOptions ext = 259397839; diff --git a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc index eaf23700c..83f058086 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc @@ -28,7 +28,7 @@ namespace mediapipe { typedef ConcatenateVectorCalculator TestConcatenateIntVectorCalculator; -REGISTER_CALCULATOR(TestConcatenateIntVectorCalculator); +MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator); void AddInputVector(int index, const std::vector& input, int64 timestamp, CalculatorRunner* runner) { @@ -384,7 +384,7 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) { typedef ConcatenateVectorCalculator> TestConcatenateUniqueIntPtrCalculator; -REGISTER_CALCULATOR(TestConcatenateUniqueIntPtrCalculator); +MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator); TEST(TestConcatenateUniqueIntVectorCalculatorTest, ConsumeOneTimestamp) { /* Note: We don't use CalculatorRunner for this test because it keeps copies diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.cc b/mediapipe/calculators/core/constant_side_packet_calculator.cc index 7541ccd66..ff328377e 100644 --- a/mediapipe/calculators/core/constant_side_packet_calculator.cc +++ b/mediapipe/calculators/core/constant_side_packet_calculator.cc @@ -17,6 +17,7 @@ #include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/ret_check.h" @@ -24,6 +25,8 @@ namespace mediapipe { +namespace {} // namespace + // Generates an output side packet or multiple output side packets according to // the specified options. 
// @@ -51,7 +54,7 @@ namespace mediapipe { // } class ConstantSidePacketCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options<::mediapipe::ConstantSidePacketCalculatorOptions>(); RET_CHECK_EQ(cc->OutputSidePackets().NumEntries(kPacketTag), @@ -74,15 +77,17 @@ class ConstantSidePacketCalculator : public CalculatorBase { packet.Set(); } else if (packet_options.has_uint64_value()) { packet.Set(); + } else if (packet_options.has_classification_list_value()) { + packet.Set(); } else { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "None of supported values were specified in options."); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const auto& options = cc->Options<::mediapipe::ConstantSidePacketCalculatorOptions>(); int index = 0; @@ -100,16 +105,19 @@ class ConstantSidePacketCalculator : public CalculatorBase { packet.Set(MakePacket(packet_options.string_value())); } else if (packet_options.has_uint64_value()) { packet.Set(MakePacket(packet_options.uint64_value())); + } else if (packet_options.has_classification_list_value()) { + packet.Set(MakePacket( + packet_options.classification_list_value())); } else { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "None of supported values were specified in options."); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.proto b/mediapipe/calculators/core/constant_side_packet_calculator.proto index 6b3feebde..57f5dc545 100644 --- a/mediapipe/calculators/core/constant_side_packet_calculator.proto +++ b/mediapipe/calculators/core/constant_side_packet_calculator.proto @@ -17,6 +17,9 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/formats/classification.proto"; + +option objc_class_prefix = "MediaPipe"; message ConstantSidePacketCalculatorOptions { extend CalculatorOptions { @@ -30,6 +33,7 @@ message ConstantSidePacketCalculatorOptions { bool bool_value = 3; string string_value = 4; uint64 uint64_value = 5; + ClassificationList classification_list_value = 6; } } diff --git a/mediapipe/calculators/core/constant_side_packet_calculator_test.cc b/mediapipe/calculators/core/constant_side_packet_calculator_test.cc index dee0a219e..1357a99a5 100644 --- a/mediapipe/calculators/core/constant_side_packet_calculator_test.cc +++ b/mediapipe/calculators/core/constant_side_packet_calculator_test.cc @@ -40,7 +40,7 @@ void DoTestSingleSidePacket(absl::string_view packet_spec, } )"; CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( absl::Substitute(graph_config_template, packet_spec)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(graph_config)); @@ -49,7 +49,7 @@ void DoTestSingleSidePacket(absl::string_view packet_spec, MP_ASSERT_OK(graph.GetOutputSidePacket("packet")); auto actual_value = - graph.GetOutputSidePacket("packet").ValueOrDie().template Get(); + graph.GetOutputSidePacket("packet").value().template Get(); 
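Note on the new classification_list_value option: ConstantSidePacketCalculator can now emit a ClassificationList side packet. A hypothetical graph-config sketch in the same ParseTextProtoOrDie style as these tests, assuming the usual Classification fields (index, score, label) from mediapipe's classification.proto; the packet and side-packet names are illustrative:

    CalculatorGraphConfig graph_config =
        mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
          node {
            calculator: "ConstantSidePacketCalculator"
            output_side_packet: "PACKET:default_classifications"
            options: {
              [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
                packet {
                  classification_list_value {
                    classification { index: 0 score: 1.0 label: "background" }
                  }
                }
              }
            }
          }
        )");

As with the other option types, the packet is created once in Open() and exposed through GetOutputSidePacket("default_classifications").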
EXPECT_EQ(actual_value, expected_value); } @@ -62,7 +62,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) { TEST(ConstantSidePacketCalculatorTest, MultiplePackets) { CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie(R"( + mediapipe::ParseTextProtoOrDie(R"( node { calculator: "ConstantSidePacketCalculator" output_side_packet: "PACKET:0:int_packet" @@ -89,33 +89,29 @@ TEST(ConstantSidePacketCalculatorTest, MultiplePackets) { MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.GetOutputSidePacket("int_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("int_packet").ValueOrDie().Get(), - 256); + EXPECT_EQ(graph.GetOutputSidePacket("int_packet").value().Get(), 256); MP_ASSERT_OK(graph.GetOutputSidePacket("float_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("float_packet").ValueOrDie().Get(), + EXPECT_EQ(graph.GetOutputSidePacket("float_packet").value().Get(), 0.5f); MP_ASSERT_OK(graph.GetOutputSidePacket("bool_packet")); - EXPECT_FALSE( - graph.GetOutputSidePacket("bool_packet").ValueOrDie().Get()); + EXPECT_FALSE(graph.GetOutputSidePacket("bool_packet").value().Get()); MP_ASSERT_OK(graph.GetOutputSidePacket("string_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("string_packet") - .ValueOrDie() - .Get(), - "string"); + EXPECT_EQ( + graph.GetOutputSidePacket("string_packet").value().Get(), + "string"); MP_ASSERT_OK(graph.GetOutputSidePacket("another_string_packet")); EXPECT_EQ(graph.GetOutputSidePacket("another_string_packet") - .ValueOrDie() + .value() .Get(), "another string"); MP_ASSERT_OK(graph.GetOutputSidePacket("another_int_packet")); - EXPECT_EQ( - graph.GetOutputSidePacket("another_int_packet").ValueOrDie().Get(), - 128); + EXPECT_EQ(graph.GetOutputSidePacket("another_int_packet").value().Get(), + 128); } TEST(ConstantSidePacketCalculatorTest, ProcessingPacketsWithCorrectTagOnly) { CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie(R"( + mediapipe::ParseTextProtoOrDie(R"( node { calculator: "ConstantSidePacketCalculator" output_side_packet: "PACKET:0:int_packet" @@ -142,24 +138,21 @@ TEST(ConstantSidePacketCalculatorTest, ProcessingPacketsWithCorrectTagOnly) { MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.GetOutputSidePacket("int_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("int_packet").ValueOrDie().Get(), - 256); + EXPECT_EQ(graph.GetOutputSidePacket("int_packet").value().Get(), 256); MP_ASSERT_OK(graph.GetOutputSidePacket("float_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("float_packet").ValueOrDie().Get(), + EXPECT_EQ(graph.GetOutputSidePacket("float_packet").value().Get(), 0.5f); MP_ASSERT_OK(graph.GetOutputSidePacket("bool_packet")); - EXPECT_FALSE( - graph.GetOutputSidePacket("bool_packet").ValueOrDie().Get()); + EXPECT_FALSE(graph.GetOutputSidePacket("bool_packet").value().Get()); MP_ASSERT_OK(graph.GetOutputSidePacket("string_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("string_packet") - .ValueOrDie() - .Get(), - "string"); + EXPECT_EQ( + graph.GetOutputSidePacket("string_packet").value().Get(), + "string"); } TEST(ConstantSidePacketCalculatorTest, IncorrectConfig_MoreOptionsThanPackets) { CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie(R"( + mediapipe::ParseTextProtoOrDie(R"( node { calculator: "ConstantSidePacketCalculator" output_side_packet: "PACKET:int_packet" @@ -177,7 +170,7 @@ TEST(ConstantSidePacketCalculatorTest, IncorrectConfig_MoreOptionsThanPackets) { TEST(ConstantSidePacketCalculatorTest, IncorrectConfig_MorePacketsThanOptions) { CalculatorGraphConfig 
graph_config = - ::mediapipe::ParseTextProtoOrDie(R"( + mediapipe::ParseTextProtoOrDie(R"( node { calculator: "ConstantSidePacketCalculator" output_side_packet: "PACKET:0:int_packet" diff --git a/mediapipe/calculators/core/counting_source_calculator.cc b/mediapipe/calculators/core/counting_source_calculator.cc index 7b2f79a0c..0b731d9ce 100644 --- a/mediapipe/calculators/core/counting_source_calculator.cc +++ b/mediapipe/calculators/core/counting_source_calculator.cc @@ -30,7 +30,7 @@ namespace mediapipe { // provided, then batches are of size 1. class CountingSourceCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) { @@ -55,13 +55,13 @@ class CountingSourceCalculator : public CalculatorBase { if (cc->InputSidePackets().HasTag("INCREMENT")) { cc->InputSidePackets().Tag("INCREMENT").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") && cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get()) { - return ::mediapipe::NotFoundError("expected error"); + return absl::NotFoundError("expected error"); } if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get(); @@ -83,12 +83,12 @@ class CountingSourceCalculator : public CalculatorBase { RET_CHECK_LT(0, increment_); } RET_CHECK(error_count_ >= 0 || max_count_ >= 0); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (error_count_ >= 0 && batch_counter_ >= error_count_) { - return ::mediapipe::InternalError("expected error"); + return absl::InternalError("expected error"); } if (max_count_ >= 0 && batch_counter_ >= max_count_) { return tool::StatusStop(); @@ -98,7 +98,7 @@ class CountingSourceCalculator : public CalculatorBase { counter_ += increment_; } ++batch_counter_; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/dequantize_byte_array_calculator.cc b/mediapipe/calculators/core/dequantize_byte_array_calculator.cc index 4f1a3ed86..04a7e55a0 100644 --- a/mediapipe/calculators/core/dequantize_byte_array_calculator.cc +++ b/mediapipe/calculators/core/dequantize_byte_array_calculator.cc @@ -37,34 +37,34 @@ namespace mediapipe { class DequantizeByteArrayCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("ENCODED").Set(); cc->Outputs().Tag("FLOAT_VECTOR").Set>(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { const auto options = cc->Options<::mediapipe::DequantizeByteArrayCalculatorOptions>(); if (!options.has_max_quantized_value() || !options.has_min_quantized_value()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Both max_quantized_value and min_quantized_value must be provided " "in DequantizeByteArrayCalculatorOptions."); } float max_quantized_value = options.max_quantized_value(); float min_quantized_value = 
options.min_quantized_value(); if (max_quantized_value < min_quantized_value + FLT_EPSILON) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "max_quantized_value must be greater than min_quantized_value."); } float range = max_quantized_value - min_quantized_value; scalar_ = range / 255.0; bias_ = (range / 512.0) + min_quantized_value; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const std::string& encoded = cc->Inputs().Tag("ENCODED").Value().Get(); std::vector float_vector; @@ -77,7 +77,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase { .Tag("FLOAT_VECTOR") .AddPacket(MakePacket>(float_vector) .At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/dequantize_byte_array_calculator.proto b/mediapipe/calculators/core/dequantize_byte_array_calculator.proto index 3032dbf48..3af8e11ef 100644 --- a/mediapipe/calculators/core/dequantize_byte_array_calculator.proto +++ b/mediapipe/calculators/core/dequantize_byte_array_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message DequantizeByteArrayCalculatorOptions { extend CalculatorOptions { optional DequantizeByteArrayCalculatorOptions ext = 272316343; diff --git a/mediapipe/calculators/core/end_loop_calculator.h b/mediapipe/calculators/core/end_loop_calculator.h index 869cc29a2..e40301e81 100644 --- a/mediapipe/calculators/core/end_loop_calculator.h +++ b/mediapipe/calculators/core/end_loop_calculator.h @@ -57,7 +57,7 @@ class EndLoopCalculator : public CalculatorBase { using ItemT = typename IterableT::value_type; public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("BATCH_END")) << "Missing BATCH_END tagged input_stream."; cc->Inputs().Tag("BATCH_END").Set(); @@ -67,10 +67,10 @@ class EndLoopCalculator : public CalculatorBase { RET_CHECK(cc->Outputs().HasTag("ITERABLE")); cc->Outputs().Tag("ITERABLE").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (!cc->Inputs().Tag("ITEM").IsEmpty()) { if (!input_stream_collection_) { input_stream_collection_.reset(new IterableT); @@ -94,7 +94,7 @@ class EndLoopCalculator : public CalculatorBase { .SetNextTimestampBound(Timestamp(loop_control_ts.Value() + 1)); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc index 6d595e6cd..4fbfced96 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -16,6 +16,7 @@ #include #include +#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -23,41 +24,23 @@ namespace mediapipe { -// FlowLimiterCalculator is used to limit the number of pipelined processing -// operations in a section of the graph. 
+// FlowLimiterCalculator is used to limit the number of frames in flight
+// by dropping input frames when necessary.
 //
-// Typical topology:
+// The input stream "FINISHED" is used to signal the FlowLimiterCalculator
+// when a frame is finished processing. Either a non-empty "FINISHED" packet
+// or a timestamp bound should be received for each processed frame.
 //
-// in ->-[FLC]-[foo]-...-[bar]-+->- out
-// ^_____________________|
-// FINISHED
+// The combination of `max_in_flight: 1` and `max_in_queue: 1` generally gives
+// the best throughput/latency balance. Throughput is nearly optimal because
+// the graph is never idle: there is always something in the queue. Latency is
+// nearly optimal because the queue always stores the latest available frame.
 //
-// By connecting the output of the graph section to this calculator's FINISHED
-// input with a backwards edge, this allows FLC to keep track of how many
-// timestamps are currently being processed.
-//
-// The limit defaults to 1, and can be overridden with the MAX_IN_FLIGHT side
-// packet.
-//
-// As long as the number of timestamps being processed ("in flight") is below
-// the limit, FLC allows input to pass through. When the limit is reached,
-// FLC starts dropping input packets, keeping only the most recent. When the
-// processing count decreases again, as signaled by the receipt of a packet on
-// FINISHED, FLC allows packets to flow again, releasing the most recently
-// queued packet, if any.
-//
-// If there are multiple input streams, packet dropping is synchronized.
-//
-// IMPORTANT: for each timestamp where FLC forwards a packet (or a set of
-// packets, if using multiple data streams), a packet must eventually arrive on
-// the FINISHED stream. Dropping packets in the section between FLC and
-// FINISHED will make the in-flight count incorrect.
-//
-// TODO: Remove this comment when graph-level ISH has been removed.
-// NOTE: this calculator should always use the ImmediateInputStreamHandler and
-// uses it by default. However, if the graph specifies a graph-level
-// InputStreamHandler, to override that setting, the InputStreamHandler must
-// be explicitly specified as shown below.
+// Increasing `max_in_flight` to 2 or more can yield better throughput when
+// the graph exhibits a high degree of pipeline parallelism. Decreasing
+// `max_in_flight` to 0 can yield a better average latency, but at the cost of
+// lower throughput (lower framerate) due to the time during which the graph
+// is idle awaiting the next input frame.
 //
 // Example config:
 // node {
@@ -68,131 +51,178 @@ namespace mediapipe {
 //     tag_index: 'FINISHED'
 //     back_edge: true
 //   }
-//   input_stream_handler {
-//     input_stream_handler: 'ImmediateInputStreamHandler'
-//   }
-//   output_stream: "gated_frames"
+//   output_stream: "sampled_frames"
+//   output_stream: "ALLOW:allowed_timestamps"
 // }
+//
+// The "ALLOW" stream indicates the transition between accepting frames and
+// dropping frames. "ALLOW = true" indicates the start of accepting frames
+// including the current timestamp, and "ALLOW = false" indicates the start of
+// dropping frames including the current timestamp.
+//
+// FlowLimiterCalculator provides limited support for multiple input streams.
+// The first input stream is treated as the main input stream and successive
+// input streams are treated as auxiliary input streams. The auxiliary input
+// streams are limited to timestamps passed on the main input stream.
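To make the new contract concrete, the sketch below (an editor's illustration, not part of this patch) expands the example config from the comment above into a full CalculatorGraphConfig with the FINISHED back edge closed. The options fields come from the FlowLimiterCalculatorOptions proto added later in this patch; the stream names "input_frames" and "output_frames" are hypothetical.

// Hedged sketch only; mirrors the example config in the comment above.
// "input_frames" and "output_frames" are hypothetical stream names.
CalculatorGraphConfig config =
    ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
      input_stream: "input_frames"
      node {
        calculator: "FlowLimiterCalculator"
        input_stream: "input_frames"
        input_stream: "FINISHED:output_frames"
        input_stream_info: { tag_index: "FINISHED" back_edge: true }
        output_stream: "sampled_frames"
        output_stream: "ALLOW:allowed_timestamps"
        options {
          [mediapipe.FlowLimiterCalculatorOptions.ext] {
            max_in_flight: 1
            max_in_queue: 1
          }
        }
      }
      # The rest of the graph consumes "sampled_frames" and eventually
      # produces "output_frames", which closes the FINISHED back edge.
    )pb");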
+// class FlowLimiterCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - int num_data_streams = cc->Inputs().NumEntries(""); - RET_CHECK_GE(num_data_streams, 1); - RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams) - << "Output streams must correspond input streams except for the " - "finish indicator input stream."; - for (int i = 0; i < num_data_streams; ++i) { + static absl::Status GetContract(CalculatorContract* cc) { + auto& side_inputs = cc->InputSidePackets(); + side_inputs.Tag("OPTIONS").Set().Optional(); + cc->Inputs().Tag("OPTIONS").Set().Optional(); + RET_CHECK_GE(cc->Inputs().NumEntries(""), 1); + for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { cc->Inputs().Get("", i).SetAny(); cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i))); } cc->Inputs().Get("FINISHED", 0).SetAny(); - if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { - cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set(); - } - if (cc->Outputs().HasTag("ALLOW")) { - cc->Outputs().Tag("ALLOW").Set(); - } - + cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set().Optional(); + cc->Outputs().Tag("ALLOW").Set().Optional(); cc->SetInputStreamHandler("ImmediateInputStreamHandler"); - - return ::mediapipe::OkStatus(); + cc->SetProcessTimestampBounds(true); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { - finished_id_ = cc->Inputs().GetId("FINISHED", 0); - max_in_flight_ = 1; + absl::Status Open(CalculatorContext* cc) final { + options_ = cc->Options(); + options_ = tool::RetrieveOptions(options_, cc->InputSidePackets()); if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { - max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get(); + options_.set_max_in_flight( + cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get()); } - RET_CHECK_GE(max_in_flight_, 1); - num_in_flight_ = 0; - - allowed_id_ = cc->Outputs().GetId("ALLOW", 0); - allow_ctr_ts_ = Timestamp(0); - - num_data_streams_ = cc->Inputs().NumEntries(""); - data_stream_bound_ts_.resize(num_data_streams_); + input_queues_.resize(cc->Inputs().NumEntries("")); RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - bool Allow() { return num_in_flight_ < max_in_flight_; } + // Returns true if an additional frame can be released for processing. + // The "ALLOW" output stream indicates this condition at each input frame. + bool ProcessingAllowed() { + return frames_in_flight_.size() < options_.max_in_flight(); + } - ::mediapipe::Status Process(CalculatorContext* cc) final { - bool old_allow = Allow(); - Timestamp lowest_incomplete_ts = Timestamp::Done(); - - // Process FINISHED stream. - if (!cc->Inputs().Get(finished_id_).Value().IsEmpty()) { - RET_CHECK_GT(num_in_flight_, 0) - << "Received a FINISHED packet, but we had none in flight."; - --num_in_flight_; + // Outputs a packet indicating whether a frame was sent or dropped. + void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) { + if (cc->Outputs().HasTag("ALLOW")) { + cc->Outputs().Tag("ALLOW").AddPacket(MakePacket(allow).At(ts)); } + } - // Process data streams. 
- for (int i = 0; i < num_data_streams_; ++i) { - auto& stream = cc->Inputs().Get("", i); - auto& out = cc->Outputs().Get("", i); - Packet& packet = stream.Value(); - auto ts = packet.Timestamp(); - if (ts.IsRangeValue() && data_stream_bound_ts_[i] <= ts) { - data_stream_bound_ts_[i] = ts + 1; - // Note: it's ok to update the output bound here, before sending the - // packet, because updates are batched during the Process function. - out.SetNextTimestampBound(data_stream_bound_ts_[i]); - } - lowest_incomplete_ts = - std::min(lowest_incomplete_ts, data_stream_bound_ts_[i]); + // Sets the timestamp bound or closes an output stream. + void SetNextTimestampBound(Timestamp bound, OutputStream* stream) { + if (bound > Timestamp::Max()) { + stream->Close(); + } else { + stream->SetNextTimestampBound(bound); + } + } - if (packet.IsEmpty()) { - // If the input stream is closed, close the corresponding output. - if (stream.IsDone() && !out.IsClosed()) { - out.Close(); + // Returns true if a certain timestamp is being processed. + bool IsInFlight(Timestamp timestamp) { + return std::find(frames_in_flight_.begin(), frames_in_flight_.end(), + timestamp) != frames_in_flight_.end(); + } + + // Releases input packets up to the latest settled input timestamp. + void ProcessAuxiliaryInputs(CalculatorContext* cc) { + Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound(); + for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) { + // Release settled frames from each input queue. + while (!input_queues_[i].empty() && + input_queues_[i].front().Timestamp() < settled_bound) { + Packet packet = input_queues_[i].front(); + input_queues_[i].pop_front(); + if (IsInFlight(packet.Timestamp())) { + cc->Outputs().Get("", i).AddPacket(packet); } - // TODO: if the packet is empty, the ts is unset, and we - // cannot read the timestamp bound, even though we'd like to propagate - // it. - } else if (mediapipe::ContainsKey(pending_ts_, ts)) { - // If we have already sent this timestamp (on another stream), send it - // on this stream too. - out.AddPacket(std::move(packet)); - } else if (Allow() && (ts > last_dropped_ts_)) { - // If the in-flight is under the limit, and if we have not already - // dropped this or a later timestamp on another stream, then send - // the packet and add an in-flight timestamp. - out.AddPacket(std::move(packet)); - pending_ts_.insert(ts); - ++num_in_flight_; + } + + // Propagate each input timestamp bound. + if (!input_queues_[i].empty()) { + Timestamp bound = input_queues_[i].front().Timestamp(); + SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); } else { - // Otherwise, we'll drop the packet. - last_dropped_ts_ = std::max(last_dropped_ts_, ts); + Timestamp bound = + cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream(); + SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); + } + } + } + + // Releases input packets allowed by the max_in_flight constraint. + absl::Status Process(CalculatorContext* cc) final { + options_ = tool::RetrieveOptions(options_, cc->Inputs()); + + // Process the FINISHED input stream. + Packet finished_packet = cc->Inputs().Tag("FINISHED").Value(); + if (finished_packet.Timestamp() == cc->InputTimestamp()) { + while (!frames_in_flight_.empty() && + frames_in_flight_.front() <= finished_packet.Timestamp()) { + frames_in_flight_.pop_front(); } } - // Remove old pending_ts_ entries. 
- auto it = std::lower_bound(pending_ts_.begin(), pending_ts_.end(), - lowest_incomplete_ts); - pending_ts_.erase(pending_ts_.begin(), it); - - // Update ALLOW signal. - if ((old_allow != Allow()) && allowed_id_.IsValid()) { - cc->Outputs() - .Get(allowed_id_) - .AddPacket(MakePacket(Allow()).At(++allow_ctr_ts_)); + // Process the frame input streams. + for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { + Packet packet = cc->Inputs().Get("", i).Value(); + if (!packet.IsEmpty()) { + input_queues_[i].push_back(packet); + } } - return ::mediapipe::OkStatus(); + + // Abandon expired frames in flight. Note that old frames are abandoned + // when much newer frame timestamps arrive regardless of elapsed time. + TimestampDiff timeout = options_.in_flight_timeout(); + Timestamp latest_ts = cc->Inputs().Get("", 0).Value().Timestamp(); + if (timeout > 0 && latest_ts == cc->InputTimestamp() && + latest_ts < Timestamp::Max()) { + while (!frames_in_flight_.empty() && + (latest_ts - frames_in_flight_.front()) > timeout) { + frames_in_flight_.pop_front(); + } + } + + // Release allowed frames from the main input queue. + auto& input_queue = input_queues_[0]; + while (ProcessingAllowed() && !input_queue.empty()) { + Packet packet = input_queue.front(); + input_queue.pop_front(); + cc->Outputs().Get("", 0).AddPacket(packet); + SendAllow(true, packet.Timestamp(), cc); + frames_in_flight_.push_back(packet.Timestamp()); + } + + // Limit the number of queued frames. + // Note that frames can be dropped after frames are released because + // frame-packets and FINISH-packets never arrive in the same Process call. + while (input_queue.size() > options_.max_in_queue()) { + Packet packet = input_queue.front(); + input_queue.pop_front(); + SendAllow(false, packet.Timestamp(), cc); + } + + // Propagate the input timestamp bound. + if (!input_queue.empty()) { + Timestamp bound = input_queue.front().Timestamp(); + SetNextTimestampBound(bound, &cc->Outputs().Get("", 0)); + } else { + Timestamp bound = + cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream(); + SetNextTimestampBound(bound, &cc->Outputs().Get("", 0)); + if (cc->Outputs().HasTag("ALLOW")) { + SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW")); + } + } + + ProcessAuxiliaryInputs(cc); + return absl::OkStatus(); } private: - std::set pending_ts_; - Timestamp last_dropped_ts_; - int num_data_streams_; - int num_in_flight_; - int max_in_flight_; - CollectionItemId finished_id_; - CollectionItemId allowed_id_; - Timestamp allow_ctr_ts_; - std::vector data_stream_bound_ts_; + FlowLimiterCalculatorOptions options_; + std::vector> input_queues_; + std::deque frames_in_flight_; }; REGISTER_CALCULATOR(FlowLimiterCalculator); diff --git a/mediapipe/calculators/core/flow_limiter_calculator.proto b/mediapipe/calculators/core/flow_limiter_calculator.proto new file mode 100644 index 000000000..0f7c925ae --- /dev/null +++ b/mediapipe/calculators/core/flow_limiter_calculator.proto @@ -0,0 +1,40 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +option objc_class_prefix = "MediaPipe"; + +message FlowLimiterCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional FlowLimiterCalculatorOptions ext = 326963320; + } + + // The maximum number of frames released for processing at one time. + // The default value limits to 1 frame processing at a time. + optional int32 max_in_flight = 1 [default = 1]; + + // The maximum number of frames queued waiting for processing. + // The default value limits to 1 frame awaiting processing. + optional int32 max_in_queue = 2 [default = 0]; + + // The maximum time in microseconds to wait for a frame to finish processing. + // The default value stops waiting after 1 sec. + // The value 0 specifies no timeout. + optional int64 in_flight_timeout = 3 [default = 1000000]; +} diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 895c88e6d..303c1a053 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -19,6 +19,7 @@ #include "absl/time/clock.h" #include "absl/time/time.h" +#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/formats/image_frame.h" @@ -28,6 +29,8 @@ #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/simulation_clock.h" +#include "mediapipe/framework/tool/simulation_clock_executor.h" #include "mediapipe/framework/tool/sink.h" namespace mediapipe { @@ -67,144 +70,49 @@ std::vector PacketValues(const std::vector& packets) { return result; } -constexpr int kNumImageFrames = 5; -constexpr int kNumFinished = 3; -CalculatorGraphConfig::Node GetDefaultNode() { - return ParseTextProtoOrDie(R"( - calculator: "FlowLimiterCalculator" - input_stream: "raw_frames" - input_stream: "FINISHED:finished" - input_stream_info: { tag_index: "FINISHED" back_edge: true } - output_stream: "gated_frames" - )"); -} - -// Simple test to make sure that the FlowLimiterCalculator outputs just one -// packet when MAX_IN_FLIGHT is 1. -TEST(FlowLimiterCalculator, OneOutputTest) { - // Setup the calculator runner and add only ImageFrame packets. - CalculatorRunner runner(GetDefaultNode()); - for (int i = 0; i < kNumImageFrames; ++i) { - Timestamp timestamp = Timestamp(i * Timestamp::kTimestampUnitsPerSecond); - runner.MutableInputs()->Index(0).packets.push_back( - MakePacket().At(timestamp)); - } - - // Run the calculator. - MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& frame_output_packets = - runner.Outputs().Index(0).packets; - - EXPECT_EQ(frame_output_packets.size(), 1); -} - -// Simple test to make sure that the FlowLimiterCalculator waits for all -// input streams to have at least one packet available before publishing. -TEST(FlowLimiterCalculator, BasicTest) { - // Setup the calculator runner and add both ImageFrame and finish packets. 
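The FlowLimiterCalculatorOptions message defined above can also be supplied when the graph starts, through the OPTIONS input side packet, which is how the rewritten tests below configure the limiter. A minimal sketch, not part of this patch, assuming the FlowLimiterCalculator node declares input_side_packet: "OPTIONS:limiter_options" (the side-packet name follows the tests) and that graph_config holds a config like the one sketched earlier:

// Hedged sketch only: configuring FlowLimiterCalculator at StartRun time via
// the OPTIONS side packet, as the updated tests in this file do.
FlowLimiterCalculatorOptions limiter_options;
limiter_options.set_max_in_flight(1);
limiter_options.set_max_in_queue(1);
limiter_options.set_in_flight_timeout(100000);  // 100 ms, in microseconds.

CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun(
    {{"limiter_options",
      MakePacket<FlowLimiterCalculatorOptions>(limiter_options)}}));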
- CalculatorRunner runner(GetDefaultNode()); - for (int i = 0; i < kNumImageFrames; ++i) { - Timestamp timestamp = Timestamp(i * Timestamp::kTimestampUnitsPerSecond); - runner.MutableInputs()->Index(0).packets.push_back( - MakePacket().At(timestamp)); - } - for (int i = 0; i < kNumFinished; ++i) { - Timestamp timestamp = - Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond); - runner.MutableInputs() - ->Tag("FINISHED") - .packets.push_back(MakePacket(true).At(timestamp)); - } - - // Run the calculator. - MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& frame_output_packets = - runner.Outputs().Index(0).packets; - - // Only outputs packets if both input streams are available. - int expected_num_packets = std::min(kNumImageFrames, kNumFinished + 1); - EXPECT_EQ(frame_output_packets.size(), expected_num_packets); -} - // A Calculator::Process callback function. -typedef std::function<::mediapipe::Status(const InputStreamShardSet&, - OutputStreamShardSet*)> +typedef std::function ProcessFunction; // A testing callback function that passes through all packets. -::mediapipe::Status PassthroughFunction(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status PassthroughFunction(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -// A Calculator that runs a testing callback function in Close. -class CloseCallbackCalculator : public CalculatorBase { +// Tests demonstrating an FlowLimiterCalculator operating in a cyclic graph. +class FlowLimiterCalculatorSemaphoreTest : public testing::Test { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - for (CollectionItemId id = cc->Inputs().BeginId(); - id < cc->Inputs().EndId(); ++id) { - cc->Inputs().Get(id).SetAny(); - } - for (CollectionItemId id = cc->Outputs().BeginId(); - id < cc->Outputs().EndId(); ++id) { - cc->Outputs().Get(id).SetAny(); - } - cc->InputSidePackets().Index(0).Set>(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) override { - return PassthroughFunction(cc->Inputs(), &(cc->Outputs())); - } - - ::mediapipe::Status Close(CalculatorContext* cc) override { - const auto& callback = cc->InputSidePackets() - .Index(0) - .Get>(); - return callback(); - } -}; -REGISTER_CALCULATOR(CloseCallbackCalculator); - -// Tests demostrating an FlowLimiterCalculator operating in a cyclic graph. -// TODO: clean up these tests. 
-class FlowLimiterCalculatorTest : public testing::Test { - public: - FlowLimiterCalculatorTest() : enter_semaphore_(0), exit_semaphore_(0) {} + FlowLimiterCalculatorSemaphoreTest() : exit_semaphore_(0) {} void SetUp() override { graph_config_ = InflightGraphConfig(); tool::AddVectorSink("out_1", &graph_config_, &out_1_packets_); - tool::AddVectorSink("out_2", &graph_config_, &out_2_packets_); } void InitializeGraph(int max_in_flight) { - ProcessFunction semaphore_0_func = [&](const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { - enter_semaphore_.Release(1); - return PassthroughFunction(inputs, outputs); - }; ProcessFunction semaphore_1_func = [&](const InputStreamShardSet& inputs, OutputStreamShardSet* outputs) { exit_semaphore_.Acquire(1); return PassthroughFunction(inputs, outputs); }; - std::function<::mediapipe::Status()> close_func = [this]() { - close_count_++; - return ::mediapipe::OkStatus(); - }; + FlowLimiterCalculatorOptions options; + options.set_max_in_flight(max_in_flight); + options.set_max_in_queue(1); MP_ASSERT_OK(graph_.Initialize( graph_config_, { - {"max_in_flight", MakePacket(max_in_flight)}, - {"callback_0", Adopt(new auto(semaphore_0_func))}, + {"limiter_options", Adopt(new auto(options))}, {"callback_1", Adopt(new auto(semaphore_1_func))}, - {"callback_2", Adopt(new auto(close_func))}, })); + + allow_poller_.reset( + new OutputStreamPoller(graph_.AddOutputStreamPoller("allow").value())); } // Adds a packet to a graph input stream. @@ -216,44 +124,24 @@ class FlowLimiterCalculatorTest : public testing::Test { // A calculator graph starting with an FlowLimiterCalculator and // ending with a InFlightFinishCalculator. // Back-edge "finished" limits processing to one frame in-flight. - // The two LambdaCalculators are used to keep certain packet sets in flight. + // The LambdaCalculator is used to keep certain frames in flight. CalculatorGraphConfig InflightGraphConfig() { return ParseTextProtoOrDie(R"( input_stream: 'in_1' - input_stream: 'in_2' node { calculator: 'FlowLimiterCalculator' - input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' + input_side_packet: 'OPTIONS:limiter_options' input_stream: 'in_1' - input_stream: 'in_2' input_stream: 'FINISHED:out_1' input_stream_info: { tag_index: 'FINISHED' back_edge: true } output_stream: 'in_1_sampled' - output_stream: 'in_2_sampled' - } - node { - calculator: 'LambdaCalculator' - input_side_packet: 'callback_0' - input_stream: 'in_1_sampled' - input_stream: 'in_2_sampled' - output_stream: 'queue_1' - output_stream: 'queue_2' + output_stream: 'ALLOW:allow' } node { calculator: 'LambdaCalculator' input_side_packet: 'callback_1' - input_stream: 'queue_1' - input_stream: 'queue_2' - output_stream: 'close_1' - output_stream: 'close_2' - } - node { - calculator: 'CloseCallbackCalculator' - input_side_packet: 'callback_2' - input_stream: 'close_1' - input_stream: 'close_2' + input_stream: 'in_1_sampled' output_stream: 'out_1' - output_stream: 'out_2' } )"); } @@ -261,21 +149,19 @@ class FlowLimiterCalculatorTest : public testing::Test { protected: CalculatorGraphConfig graph_config_; CalculatorGraph graph_; - AtomicSemaphore enter_semaphore_; AtomicSemaphore exit_semaphore_; std::vector out_1_packets_; - std::vector out_2_packets_; - int close_count_ = 0; + std::unique_ptr allow_poller_; }; // A test demonstrating an FlowLimiterCalculator operating in a cyclic // graph. This test shows that: // -// (1) Timestamps are passed through unaltered. 
-// (2) All output streams including the back_edge stream are closed when -// the first input stream is closed. +// (1) Frames exceeding the queue size are dropped. +// (2) The "ALLOW" signal is produced. +// (3) Timestamps are passed through unaltered. // -TEST_F(FlowLimiterCalculatorTest, BackEdgeCloses) { +TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) { InitializeGraph(1); MP_ASSERT_OK(graph_.StartRun({})); @@ -284,210 +170,590 @@ TEST_F(FlowLimiterCalculatorTest, BackEdgeCloses) { input_name, MakePacket(n).At(Timestamp(n)))); }; - for (int i = 0; i < 10; i++) { - send_packet("in_1", i * 10); - // This next input should be dropped. + Packet allow_packet; + send_packet("in_1", 0); + for (int i = 0; i < 9; i++) { + EXPECT_TRUE(allow_poller_->Next(&allow_packet)); + EXPECT_TRUE(allow_packet.Get()); + // This input should wait in the limiter input queue. send_packet("in_1", i * 10 + 5); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - send_packet("in_2", i * 10); + // This input should drop the previous input. + send_packet("in_1", i * 10 + 10); + EXPECT_TRUE(allow_poller_->Next(&allow_packet)); + EXPECT_FALSE(allow_packet.Get()); exit_semaphore_.Release(1); - MP_EXPECT_OK(graph_.WaitUntilIdle()); } + exit_semaphore_.Release(1); MP_EXPECT_OK(graph_.CloseInputStream("in_1")); - MP_EXPECT_OK(graph_.CloseInputStream("in_2")); MP_EXPECT_OK(graph_.WaitUntilIdle()); // All output streams are closed and all output packets are delivered, - // with stream "in_1" and stream "in_2" closed. + // with stream "in_1" closed. EXPECT_EQ(10, out_1_packets_.size()); - EXPECT_EQ(10, out_2_packets_.size()); - // Timestamps have not been messed with. + // Timestamps have not been altered. EXPECT_EQ(PacketValues(out_1_packets_), TimestampValues(out_1_packets_)); - EXPECT_EQ(PacketValues(out_2_packets_), - TimestampValues(out_2_packets_)); - // Extra inputs on in_1 have been dropped + // Extra inputs on in_1 have been dropped. EXPECT_EQ(TimestampValues(out_1_packets_), (std::vector{0, 10, 20, 30, 40, 50, 60, 70, 80, 90})); - EXPECT_EQ(TimestampValues(out_1_packets_), TimestampValues(out_2_packets_)); - - // The closing of the stream has been propagated. - EXPECT_EQ(1, close_count_); } -// A test demonstrating that all output streams are closed when all -// input streams are closed after the last input packet has been processed. -TEST_F(FlowLimiterCalculatorTest, AllStreamsClose) { - InitializeGraph(1); - MP_ASSERT_OK(graph_.StartRun({})); - - exit_semaphore_.Release(10); - for (int i = 0; i < 10; i++) { - AddPacket("in_1", i); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - AddPacket("in_2", i); - MP_EXPECT_OK(graph_.WaitUntilIdle()); +// A calculator that sleeps during Process. 
+class SleepCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag("PACKET").SetAny(); + cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); + cc->InputSidePackets().Tag("SLEEP_TIME").Set(); + cc->InputSidePackets().Tag("WARMUP_TIME").Set(); + cc->InputSidePackets().Tag("CLOCK").Set(); + cc->SetTimestampOffset(0); + return absl::OkStatus(); } - MP_EXPECT_OK(graph_.CloseAllInputStreams()); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(TimestampValues(out_1_packets_), TimestampValues(out_2_packets_)); - EXPECT_EQ(TimestampValues(out_1_packets_), - (std::vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})); - EXPECT_EQ(1, close_count_); + absl::Status Open(CalculatorContext* cc) final { + clock_ = cc->InputSidePackets().Tag("CLOCK").Get(); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + ++packet_count; + absl::Duration sleep_time = absl::Microseconds( + packet_count == 1 + ? cc->InputSidePackets().Tag("WARMUP_TIME").Get() + : cc->InputSidePackets().Tag("SLEEP_TIME").Get()); + clock_->Sleep(sleep_time); + cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); + return absl::OkStatus(); + } + + private: + ::mediapipe::Clock* clock_ = nullptr; + int packet_count = 0; +}; +REGISTER_CALCULATOR(SleepCalculator); + +// A calculator that drops a packet occasionally. +// Drops the 3rd packet, and optionally the corresponding timestamp bound. +class DropCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag("PACKET").SetAny(); + cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); + cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set(); + cc->SetProcessTimestampBounds(true); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) { + ++packet_count; + } + bool drop = (packet_count == 3); + if (!drop && !cc->Inputs().Tag("PACKET").Value().IsEmpty()) { + cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); + } + if (!drop || !cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get()) { + cc->Outputs().Tag("PACKET").SetNextTimestampBound( + cc->InputTimestamp().NextAllowedInStream()); + } + return absl::OkStatus(); + } + + private: + int packet_count = 0; +}; +REGISTER_CALCULATOR(DropCalculator); + +// Tests demonstrating an FlowLimiterCalculator processing FINISHED timestamps. +class FlowLimiterCalculatorTest : public testing::Test { + protected: + CalculatorGraphConfig InflightGraphConfig() { + return ParseTextProtoOrDie(R"( + input_stream: 'in_1' + node { + calculator: 'FlowLimiterCalculator' + input_side_packet: 'OPTIONS:limiter_options' + input_stream: 'in_1' + input_stream: 'FINISHED:out_1' + input_stream_info: { tag_index: 'FINISHED' back_edge: true } + output_stream: 'in_1_sampled' + output_stream: 'ALLOW:allow' + } + node { + calculator: 'SleepCalculator' + input_side_packet: 'WARMUP_TIME:warmup_time' + input_side_packet: 'SLEEP_TIME:sleep_time' + input_side_packet: 'CLOCK:clock' + input_stream: 'PACKET:in_1_sampled' + output_stream: 'PACKET:out_1_sampled' + } + node { + calculator: 'DropCalculator' + input_side_packet: "DROP_TIMESTAMPS:drop_timesamps" + input_stream: 'PACKET:out_1_sampled' + output_stream: 'PACKET:out_1' + } + )"); + } + + // Parse an absl::Time from RFC3339 format. 
+ absl::Time ParseTime(const std::string& date_time_str) { + absl::Time result; + absl::ParseTime(absl::RFC3339_sec, date_time_str, &result, nullptr); + return result; + } + + // The point in simulated time when the test starts. + absl::Time StartTime() { return ParseTime("2020-11-03T20:00:00Z"); } + + // Initialize the test clock to follow simulated time. + void SetUpSimulationClock() { + auto executor = std::make_shared(8); + simulation_clock_ = executor->GetClock(); + clock_ = simulation_clock_.get(); + simulation_clock_->ThreadStart(); + clock_->SleepUntil(StartTime()); + simulation_clock_->ThreadFinish(); + MP_ASSERT_OK(graph_.SetExecutor("", executor)); + } + + // Initialize the test clock to follow wall time. + void SetUpRealClock() { clock_ = mediapipe::Clock::RealClock(); } + + // Create a few mediapipe input Packets holding ints. + void SetUpInputData() { + for (int i = 0; i < 100; ++i) { + input_packets_.push_back(MakePacket(i).At(Timestamp(i * 10000))); + } + } + + protected: + CalculatorGraph graph_; + mediapipe::Clock* clock_; + std::shared_ptr simulation_clock_; + std::vector input_packets_; + std::vector out_1_packets_; + std::vector allow_packets_; +}; + +// Shows that "FINISHED" can be indicated with either a packet or a timestamp +// bound. DropCalculator periodically drops one packet but always propagates +// the timestamp bound. Input packets are released or dropped promptly after +// each "FINISH" packet or a timestamp bound arrives. +TEST_F(FlowLimiterCalculatorTest, FinishedTimestamps) { + // Configure the test. + SetUpInputData(); + SetUpSimulationClock(); + CalculatorGraphConfig graph_config = InflightGraphConfig(); + auto limiter_options = ParseTextProtoOrDie(R"( + max_in_flight: 1 + max_in_queue: 1 + )"); + std::map side_packets = { + {"limiter_options", + MakePacket(limiter_options)}, + {"warmup_time", MakePacket(22000)}, + {"sleep_time", MakePacket(22000)}, + {"drop_timesamps", MakePacket(false)}, + {"clock", MakePacket(clock_)}, + }; + + // Start the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { + out_1_packets_.push_back(p); + return absl::OkStatus(); + })); + MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { + allow_packets_.push_back(p); + return absl::OkStatus(); + })); + simulation_clock_->ThreadStart(); + MP_ASSERT_OK(graph_.StartRun(side_packets)); + + // Add 9 input packets. + // 1. packet-0 is released, + // 2. packet-1 is queued, + // 3. packet-2 is queued and packet-1 is dropped, + // 4. packet-2 is released, and so forth. + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0])); + clock_->Sleep(absl::Microseconds(1)); + EXPECT_EQ(allow_packets_.size(), 1); + EXPECT_EQ(allow_packets_.back().Get(), true); + clock_->Sleep(absl::Microseconds(10000)); + for (int i = 1; i < 8; i += 2) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + clock_->Sleep(absl::Microseconds(10000)); + EXPECT_EQ(allow_packets_.size(), i); + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i + 1])); + clock_->Sleep(absl::Microseconds(1)); + EXPECT_EQ(allow_packets_.size(), i + 1); + EXPECT_EQ(allow_packets_.back().Get(), false); + clock_->Sleep(absl::Microseconds(10000)); + EXPECT_EQ(allow_packets_.size(), i + 2); + EXPECT_EQ(allow_packets_.back().Get(), true); + } + + // Finish the graph. 
+ MP_EXPECT_OK(graph_.CloseAllPacketSources()); + clock_->Sleep(absl::Microseconds(40000)); + MP_EXPECT_OK(graph_.WaitUntilDone()); + simulation_clock_->ThreadFinish(); + + // Validate the output. + // input_packets_[4] is dropped by the DropCalculator. + std::vector expected_output = {input_packets_[0], input_packets_[2], + input_packets_[6], input_packets_[8]}; + EXPECT_EQ(out_1_packets_, expected_output); } -TEST(FlowLimiterCalculator, TwoStreams) { - std::vector a_passed; - std::vector b_passed; - CalculatorGraphConfig graph_config_ = - ParseTextProtoOrDie(R"( - input_stream: 'in_a' - input_stream: 'in_b' - input_stream: 'finished' - node { - name: 'input_dropper' - calculator: 'FlowLimiterCalculator' - input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' - input_stream: 'in_a' - input_stream: 'in_b' - input_stream: 'FINISHED:finished' - input_stream_info: { tag_index: 'FINISHED' back_edge: true } - output_stream: 'in_a_sampled' - output_stream: 'in_b_sampled' - output_stream: 'ALLOW:allow' - } - )"); - std::string allow_cb_name; - tool::AddVectorSink("in_a_sampled", &graph_config_, &a_passed); - tool::AddVectorSink("in_b_sampled", &graph_config_, &b_passed); - tool::AddCallbackCalculator("allow", &graph_config_, &allow_cb_name, true); - - bool allow = true; - auto allow_cb = [&allow](const Packet& packet) { - allow = packet.Get(); +// Shows that an output packet can be lost completely, and the +// FlowLimiterCalculator will stop waiting for it after in_flight_timeout. +// DropCalculator completely loses one packet including its timestamp bound. +// FlowLimiterCalculator waits 100 ms, and then starts releasing packets again. +TEST_F(FlowLimiterCalculatorTest, FinishedLost) { + // Configure the test. + SetUpInputData(); + SetUpSimulationClock(); + CalculatorGraphConfig graph_config = InflightGraphConfig(); + auto limiter_options = ParseTextProtoOrDie(R"( + max_in_flight: 1 + max_in_queue: 1 + in_flight_timeout: 100000 # 100 ms + )"); + std::map side_packets = { + {"limiter_options", + MakePacket(limiter_options)}, + {"warmup_time", MakePacket(22000)}, + {"sleep_time", MakePacket(22000)}, + {"drop_timesamps", MakePacket(true)}, + {"clock", MakePacket(clock_)}, }; - CalculatorGraph graph_; - MP_EXPECT_OK(graph_.Initialize( - graph_config_, - { - {"max_in_flight", MakePacket(1)}, - {allow_cb_name, - MakePacket>(allow_cb)}, - })); + // Start the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { + out_1_packets_.push_back(p); + return absl::OkStatus(); + })); + MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { + allow_packets_.push_back(p); + return absl::OkStatus(); + })); + simulation_clock_->ThreadStart(); + MP_ASSERT_OK(graph_.StartRun(side_packets)); - MP_EXPECT_OK(graph_.StartRun({})); + // Add 21 input packets. + // 1. packet-0 is released, packet-1 queued and dropped, and so forth. + // 2. packet-4 is lost by DropCalculator. + // 3. packet-5 through 13 are dropped while waiting for packet-4. + // 4. packet-4 expires and queued packet-14 is released. + // 5. packet-17, 19, and 20 are released on time. 
+ MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0])); + clock_->Sleep(absl::Microseconds(10000)); + for (int i = 1; i < 21; ++i) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + clock_->Sleep(absl::Microseconds(10000)); + } - auto send_packet = [&graph_](const std::string& input_name, int n) { - MP_EXPECT_OK(graph_.AddPacketToInputStream( - input_name, MakePacket(n).At(Timestamp(n)))); - }; - send_packet("in_a", 1); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(allow, false); - EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); - EXPECT_EQ(TimestampValues(b_passed), (std::vector{})); - - send_packet("in_a", 2); - send_packet("in_b", 1); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); - EXPECT_EQ(TimestampValues(b_passed), (std::vector{1})); - EXPECT_EQ(allow, false); - - send_packet("finished", 1); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); - EXPECT_EQ(TimestampValues(b_passed), (std::vector{1})); - EXPECT_EQ(allow, true); - - send_packet("in_b", 2); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); - EXPECT_EQ(TimestampValues(b_passed), (std::vector{1})); - EXPECT_EQ(allow, true); - - send_packet("in_b", 3); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); - EXPECT_EQ(TimestampValues(b_passed), (std::vector{1, 3})); - EXPECT_EQ(allow, false); - - send_packet("in_b", 4); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); - EXPECT_EQ(TimestampValues(b_passed), (std::vector{1, 3})); - EXPECT_EQ(allow, false); - - send_packet("in_a", 3); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(TimestampValues(a_passed), (std::vector{1, 3})); - EXPECT_EQ(TimestampValues(b_passed), (std::vector{1, 3})); - EXPECT_EQ(allow, false); - - send_packet("finished", 3); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(TimestampValues(a_passed), (std::vector{1, 3})); - EXPECT_EQ(TimestampValues(b_passed), (std::vector{1, 3})); - EXPECT_EQ(allow, true); - - MP_EXPECT_OK(graph_.CloseAllInputStreams()); + // Finish the graph. + MP_EXPECT_OK(graph_.CloseAllPacketSources()); + clock_->Sleep(absl::Microseconds(40000)); MP_EXPECT_OK(graph_.WaitUntilDone()); + simulation_clock_->ThreadFinish(); + + // Validate the output. + // input_packets_[4] is lost by the DropCalculator. + std::vector expected_output = { + input_packets_[0], input_packets_[2], input_packets_[14], + input_packets_[17], input_packets_[19], input_packets_[20], + }; + EXPECT_EQ(out_1_packets_, expected_output); } -TEST(FlowLimiterCalculator, CanConsume) { - std::vector in_sampled_packets_; - CalculatorGraphConfig graph_config_ = +// Shows what happens when a finish packet is delayed beyond in_flight_timeout. +// After in_flight_timeout, FlowLimiterCalculator continues releasing packets. +// Temporarily, more than max_in_flight frames are in flight. +// Eventually, the number of frames in flight returns to max_in_flight. +TEST_F(FlowLimiterCalculatorTest, FinishedDelayed) { + // Configure the test. 
+ SetUpInputData(); + SetUpSimulationClock(); + CalculatorGraphConfig graph_config = InflightGraphConfig(); + auto limiter_options = ParseTextProtoOrDie(R"( + max_in_flight: 1 + max_in_queue: 1 + in_flight_timeout: 100000 # 100 ms + )"); + std::map side_packets = { + {"limiter_options", + MakePacket(limiter_options)}, + {"warmup_time", MakePacket(500000)}, + {"sleep_time", MakePacket(22000)}, + {"drop_timesamps", MakePacket(false)}, + {"clock", MakePacket(clock_)}, + }; + + // Start the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { + out_1_packets_.push_back(p); + return absl::OkStatus(); + })); + MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { + allow_packets_.push_back(p); + return absl::OkStatus(); + })); + simulation_clock_->ThreadStart(); + MP_ASSERT_OK(graph_.StartRun(side_packets)); + + // Add 71 input packets. + // 1. During the 500 ms WARMUP_TIME, the in_flight_timeout releases + // packets 0, 10, 20, 30, 40, 50, which are queued at the SleepCalculator. + // 2. During the next 120 ms, these 6 packets are processed. + // 3. After the graph is finally finished with warmup and the backlog packets, + // packets 60 through 70 are released and processed on time. + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0])); + clock_->Sleep(absl::Microseconds(10000)); + for (int i = 1; i < 71; ++i) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + clock_->Sleep(absl::Microseconds(10000)); + } + + // Finish the graph. + MP_EXPECT_OK(graph_.CloseAllPacketSources()); + clock_->Sleep(absl::Microseconds(40000)); + MP_EXPECT_OK(graph_.WaitUntilDone()); + simulation_clock_->ThreadFinish(); + + // Validate the output. + // The graph is warming up or backlogged until packet 60. + std::vector expected_output = { + input_packets_[0], input_packets_[10], input_packets_[30], + input_packets_[40], input_packets_[50], input_packets_[60], + input_packets_[63], input_packets_[65], input_packets_[67], + input_packets_[69], input_packets_[70], + }; + EXPECT_EQ(out_1_packets_, expected_output); +} + +// Shows that packets on auxiliary input streams are relesed for the same +// timestamps as the main input stream, whether the auxiliary packets arrive +// early or late. +TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { + // Configure the test. 
+ SetUpInputData(); + SetUpSimulationClock(); + CalculatorGraphConfig graph_config = ParseTextProtoOrDie(R"( - input_stream: 'in' - input_stream: 'finished' + input_stream: 'in_1' + input_stream: 'in_2' node { - name: 'input_dropper' calculator: 'FlowLimiterCalculator' - input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' - input_stream: 'in' - input_stream: 'FINISHED:finished' + input_side_packet: 'OPTIONS:limiter_options' + input_stream: 'in_1' + input_stream: 'in_2' + input_stream: 'FINISHED:out_1' input_stream_info: { tag_index: 'FINISHED' back_edge: true } - output_stream: 'in_sampled' + output_stream: 'in_1_sampled' + output_stream: 'in_2_sampled' output_stream: 'ALLOW:allow' } + node { + calculator: 'SleepCalculator' + input_side_packet: 'WARMUP_TIME:warmup_time' + input_side_packet: 'SLEEP_TIME:sleep_time' + input_side_packet: 'CLOCK:clock' + input_stream: 'PACKET:in_1_sampled' + output_stream: 'PACKET:out_1_sampled' + } + node { + calculator: 'DropCalculator' + input_side_packet: "DROP_TIMESTAMPS:drop_timesamps" + input_stream: 'PACKET:out_1_sampled' + output_stream: 'PACKET:out_1' + } )"); - std::string allow_cb_name; - tool::AddVectorSink("in_sampled", &graph_config_, &in_sampled_packets_); - tool::AddCallbackCalculator("allow", &graph_config_, &allow_cb_name, true); - bool allow = true; - auto allow_cb = [&allow](const Packet& packet) { - allow = packet.Get(); + auto limiter_options = ParseTextProtoOrDie(R"( + max_in_flight: 1 + max_in_queue: 1 + in_flight_timeout: 100000 # 100 ms + )"); + std::map side_packets = { + {"limiter_options", + MakePacket(limiter_options)}, + {"warmup_time", MakePacket(22000)}, + {"sleep_time", MakePacket(22000)}, + {"drop_timesamps", MakePacket(true)}, + {"clock", MakePacket(clock_)}, }; - CalculatorGraph graph_; - MP_EXPECT_OK(graph_.Initialize( - graph_config_, - { - {"max_in_flight", MakePacket(1)}, - {allow_cb_name, - MakePacket>(allow_cb)}, - })); + // Start the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { + out_1_packets_.push_back(p); + return absl::OkStatus(); + })); + std::vector out_2_packets; + MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) { + out_2_packets.push_back(p); + return absl::OkStatus(); + })); + MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { + allow_packets_.push_back(p); + return absl::OkStatus(); + })); + simulation_clock_->ThreadStart(); + MP_ASSERT_OK(graph_.StartRun(side_packets)); - MP_EXPECT_OK(graph_.StartRun({})); + // Add packets 0..9 to stream in_1, and packets 0..10 to stream in_2. + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0])); + clock_->Sleep(absl::Microseconds(10000)); + for (int i = 1; i < 10; ++i) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i - 1])); + clock_->Sleep(absl::Microseconds(10000)); + } - auto send_packet = [&graph_](const std::string& input_name, int n) { - MP_EXPECT_OK(graph_.AddPacketToInputStream( - input_name, MakePacket(n).At(Timestamp(n)))); - }; - send_packet("in", 1); - MP_EXPECT_OK(graph_.WaitUntilIdle()); - EXPECT_EQ(allow, false); - EXPECT_EQ(TimestampValues(in_sampled_packets_), (std::vector{1})); + // Add packets 10..20 to stream in_1, and packets 11..21 to stream in_2. 
+ for (int i = 10; i < 21; ++i) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i + 1])); + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + clock_->Sleep(absl::Microseconds(10000)); + } - MP_EXPECT_OK(in_sampled_packets_[0].Consume()); - - MP_EXPECT_OK(graph_.CloseAllInputStreams()); + // Finish the graph run. + MP_EXPECT_OK(graph_.CloseAllPacketSources()); + clock_->Sleep(absl::Microseconds(40000)); MP_EXPECT_OK(graph_.WaitUntilDone()); + simulation_clock_->ThreadFinish(); + + // Validate the output. + // Packet input_packets_[4] is lost by the DropCalculator. + std::vector expected_output = { + input_packets_[0], input_packets_[2], input_packets_[14], + input_packets_[17], input_packets_[19], input_packets_[20], + }; + EXPECT_EQ(out_1_packets_, expected_output); + // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. + std::vector expected_output_2 = { + input_packets_[0], input_packets_[2], input_packets_[4], + input_packets_[14], input_packets_[17], input_packets_[19], + input_packets_[20], + }; + EXPECT_EQ(out_2_packets, expected_output_2); +} + +// Shows how FlowLimiterCalculator releases packets with max_in_queue 0. +// Shows how auxiliary input streams still work with max_in_queue 0. +// The processing time "sleep_time" is reduced from 22ms to 12ms to create +// the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. +TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { + // Configure the test. + SetUpInputData(); + SetUpSimulationClock(); + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"( + input_stream: 'in_1' + input_stream: 'in_2' + node { + calculator: 'FlowLimiterCalculator' + input_side_packet: 'OPTIONS:limiter_options' + input_stream: 'in_1' + input_stream: 'in_2' + input_stream: 'FINISHED:out_1' + input_stream_info: { tag_index: 'FINISHED' back_edge: true } + output_stream: 'in_1_sampled' + output_stream: 'in_2_sampled' + output_stream: 'ALLOW:allow' + } + node { + calculator: 'SleepCalculator' + input_side_packet: 'WARMUP_TIME:warmup_time' + input_side_packet: 'SLEEP_TIME:sleep_time' + input_side_packet: 'CLOCK:clock' + input_stream: 'PACKET:in_1_sampled' + output_stream: 'PACKET:out_1_sampled' + } + node { + calculator: 'DropCalculator' + input_side_packet: "DROP_TIMESTAMPS:drop_timesamps" + input_stream: 'PACKET:out_1_sampled' + output_stream: 'PACKET:out_1' + } + )"); + + auto limiter_options = ParseTextProtoOrDie(R"( + max_in_flight: 1 + max_in_queue: 0 + in_flight_timeout: 100000 # 100 ms + )"); + std::map side_packets = { + {"limiter_options", + MakePacket(limiter_options)}, + {"warmup_time", MakePacket(12000)}, + {"sleep_time", MakePacket(12000)}, + {"drop_timesamps", MakePacket(true)}, + {"clock", MakePacket(clock_)}, + }; + + // Start the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { + out_1_packets_.push_back(p); + return absl::OkStatus(); + })); + std::vector out_2_packets; + MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) { + out_2_packets.push_back(p); + return absl::OkStatus(); + })); + MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { + allow_packets_.push_back(p); + return absl::OkStatus(); + })); + simulation_clock_->ThreadStart(); + MP_ASSERT_OK(graph_.StartRun(side_packets)); + + // Add packets 0..9 to stream in_1, and packets 0..10 to stream in_2. 
+ MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0])); + clock_->Sleep(absl::Microseconds(10000)); + for (int i = 1; i < 10; ++i) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i - 1])); + clock_->Sleep(absl::Microseconds(10000)); + } + + // Add packets 10..20 to stream in_1, and packets 11..21 to stream in_2. + for (int i = 10; i < 21; ++i) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i + 1])); + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + clock_->Sleep(absl::Microseconds(10000)); + } + + // Finish the graph run. + MP_EXPECT_OK(graph_.CloseAllPacketSources()); + clock_->Sleep(absl::Microseconds(40000)); + MP_EXPECT_OK(graph_.WaitUntilDone()); + simulation_clock_->ThreadFinish(); + + // Validate the output. + // Packet input_packets_[4] is lost by the DropCalculator. + std::vector expected_output = { + input_packets_[0], input_packets_[2], input_packets_[15], + input_packets_[17], input_packets_[19], + }; + EXPECT_EQ(out_1_packets_, expected_output); + // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. + std::vector expected_output_2 = { + input_packets_[0], input_packets_[2], input_packets_[4], + input_packets_[15], input_packets_[17], input_packets_[19], + }; + EXPECT_EQ(out_2_packets, expected_output_2); } } // anonymous namespace diff --git a/mediapipe/calculators/core/gate_calculator.cc b/mediapipe/calculators/core/gate_calculator.cc index ea0f7b81b..189671860 100644 --- a/mediapipe/calculators/core/gate_calculator.cc +++ b/mediapipe/calculators/core/gate_calculator.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019-2020 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -82,8 +82,7 @@ class GateCalculator : public CalculatorBase { public: GateCalculator() {} - static ::mediapipe::Status CheckAndInitAllowDisallowInputs( - CalculatorContract* cc) { + static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) { bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") || cc->InputSidePackets().HasTag("DISALLOW"); bool input_via_stream = @@ -110,10 +109,10 @@ class GateCalculator : public CalculatorBase { cc->Inputs().Tag("DISALLOW").Set<bool>(); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc)); const int num_data_streams = cc->Inputs().NumEntries(""); @@ -130,10 +129,10 @@ class GateCalculator : public CalculatorBase { cc->Outputs().Tag("STATE_CHANGE").Set<bool>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { use_side_packet_for_allow_disallow_ = false; if (cc->InputSidePackets().HasTag("ALLOW")) { use_side_packet_for_allow_disallow_ = true; @@ -153,10 +152,10 @@ class GateCalculator : public CalculatorBase { const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>(); empty_packets_as_allow_ = options.empty_packets_as_allow(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { bool allow = empty_packets_as_allow_; if (use_side_packet_for_allow_disallow_) { allow = allow_by_side_packet_decision_; @@ -187,7 +186,15 @@ class GateCalculator : public CalculatorBase { last_gate_state_ = new_gate_state; if (!allow) { - return ::mediapipe::OkStatus(); + // Close the output streams if the gate will be permanently closed. + // Prevents buffering in calculators whose parents do not use SetOffset. + for (int i = 0; i < num_data_streams_; ++i) { + if (!cc->Outputs().Get("", i).IsClosed() && + use_side_packet_for_allow_disallow_) { + cc->Outputs().Get("", i).Close(); + } + } + return absl::OkStatus(); } // Process data streams. @@ -197,7 +204,7 @@ class GateCalculator : public CalculatorBase { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/gate_calculator.proto b/mediapipe/calculators/core/gate_calculator.proto index 0ef2c3e1c..76bacc74e 100644 --- a/mediapipe/calculators/core/gate_calculator.proto +++ b/mediapipe/calculators/core/gate_calculator.proto @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019-2020 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message GateCalculatorOptions { extend mediapipe.CalculatorOptions { optional GateCalculatorOptions ext = 261754847; diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc index fc34f6e97..0b78b9b75 100644 --- a/mediapipe/calculators/core/gate_calculator_test.cc +++ b/mediapipe/calculators/core/gate_calculator_test.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019-2020 The MediaPipe Authors.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ namespace { class GateCalculatorTest : public ::testing::Test { protected: // Helper to run a graph and return status. - static ::mediapipe::Status RunGraph(const std::string& proto) { + static absl::Status RunGraph(const std::string& proto) { auto runner = absl::make_unique( ParseTextProtoOrDie(proto)); return runner->Run(); diff --git a/mediapipe/calculators/core/immediate_mux_calculator.cc b/mediapipe/calculators/core/immediate_mux_calculator.cc index 007fbf73e..0e51cda5e 100644 --- a/mediapipe/calculators/core/immediate_mux_calculator.cc +++ b/mediapipe/calculators/core/immediate_mux_calculator.cc @@ -29,9 +29,7 @@ namespace mediapipe { // received. // // This Calculator can be used with an ImmediateInputStreamHandler or with the -// default ISH. Note that currently ImmediateInputStreamHandler seems to -// interfere with timestamp bound propagation, so it is better to use the -// default unless the immediate one is needed. (b/118387598) +// default ISH. // // This Calculator is designed to work with a Demux calculator such as // the RoundRobinDemuxCalculator. Therefore, packets from different @@ -45,17 +43,16 @@ class ImmediateMuxCalculator : public CalculatorBase { public: // This calculator combines any set of input streams into a single // output stream. All input stream types must match the output stream type. - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // Passes any input packet to the output stream immediately, unless the // packet timestamp is lower than a previously passed packet. - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(ImmediateMuxCalculator); -::mediapipe::Status ImmediateMuxCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ImmediateMuxCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Outputs().NumEntries() >= 1 && cc->Outputs().NumEntries() <= 2) << "This calculator produces only one or two output streams."; cc->Outputs().Index(0).SetAny(); @@ -65,15 +62,15 @@ REGISTER_CALCULATOR(ImmediateMuxCalculator); for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetSameAs(&cc->Outputs().Index(0)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImmediateMuxCalculator::Open(CalculatorContext* cc) { +absl::Status ImmediateMuxCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) { +absl::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) { // Pass along the first packet, unless it has been superseded. 
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { const Packet& packet = cc->Inputs().Index(i).Value(); @@ -91,7 +88,7 @@ REGISTER_CALCULATOR(ImmediateMuxCalculator); } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/immediate_mux_calculator_test.cc b/mediapipe/calculators/core/immediate_mux_calculator_test.cc index 4afe358f2..6913fd000 100644 --- a/mediapipe/calculators/core/immediate_mux_calculator_test.cc +++ b/mediapipe/calculators/core/immediate_mux_calculator_test.cc @@ -289,19 +289,19 @@ TEST_F(ImmediateMuxCalculatorTest, SimultaneousTimestamps) { } // A Calculator::Process callback function. -typedef std::function<::mediapipe::Status(const InputStreamShardSet&, - OutputStreamShardSet*)> +typedef std::function ProcessFunction; // A testing callback function that passes through all packets. -::mediapipe::Status PassThrough(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status PassThrough(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } TEST_F(ImmediateMuxCalculatorTest, Demux) { @@ -325,7 +325,7 @@ TEST_F(ImmediateMuxCalculatorTest, Demux) { auto out_cb = [&](const Packet& p) { absl::MutexLock lock(&out_mutex); out_packets.push_back(p); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }; auto wait_for = [&](std::function cond) { absl::MutexLock lock(&out_mutex); diff --git a/mediapipe/calculators/core/make_pair_calculator.cc b/mediapipe/calculators/core/make_pair_calculator.cc index 8eb4cb67b..561656861 100644 --- a/mediapipe/calculators/core/make_pair_calculator.cc +++ b/mediapipe/calculators/core/make_pair_calculator.cc @@ -15,10 +15,12 @@ #include #include +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // Given two input streams (A, B), output a single stream containing a pair. 
@@ -30,32 +32,27 @@ namespace mediapipe { // input_stream: "packet_b" // output_stream: "output_pair_a_b" // } -class MakePairCalculator : public CalculatorBase { +class MakePairCalculator : public Node { public: - MakePairCalculator() {} - ~MakePairCalculator() override {} + static constexpr Input::Multiple kIn{""}; + // Note that currently api2::Packet is a different type from mediapipe::Packet + static constexpr Output> + kPair{""}; - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - cc->Inputs().Index(1).SetAny(); - cc->Outputs().Index(0).Set>(); - return ::mediapipe::OkStatus(); + MEDIAPIPE_NODE_CONTRACT(kIn, kPair); + + static absl::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_EQ(kIn(cc).Count(), 2); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { - cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) override { - cc->Outputs().Index(0).Add( - new std::pair(cc->Inputs().Index(0).Value(), - cc->Inputs().Index(1).Value()), - cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()}); + return absl::OkStatus(); } }; -REGISTER_CALCULATOR(MakePairCalculator); +MEDIAPIPE_REGISTER_NODE(MakePairCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_multiply_calculator.cc b/mediapipe/calculators/core/matrix_multiply_calculator.cc index 8dc60b763..d52e0c2fa 100644 --- a/mediapipe/calculators/core/matrix_multiply_calculator.cc +++ b/mediapipe/calculators/core/matrix_multiply_calculator.cc @@ -13,11 +13,13 @@ // limitations under the License. #include "Eigen/Core" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // Perform a (left) matrix multiply. Meaning (output = A * input) // where A is the matrix which is provided as an input side packet. // @@ -28,39 +30,22 @@ namespace mediapipe { // output_stream: "multiplied_samples" // input_side_packet: "multiplication_matrix" // } -class MatrixMultiplyCalculator : public CalculatorBase { +class MatrixMultiplyCalculator : public Node { public: - MatrixMultiplyCalculator() {} - ~MatrixMultiplyCalculator() override {} + static constexpr Input kIn{""}; + static constexpr Output kOut{""}; + static constexpr SideInput kSide{""}; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kSide); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; -REGISTER_CALCULATOR(MatrixMultiplyCalculator); +MEDIAPIPE_REGISTER_NODE(MatrixMultiplyCalculator); -// static -::mediapipe::Status MatrixMultiplyCalculator::GetContract( - CalculatorContract* cc) { - cc->Inputs().Index(0).Set(); - cc->Outputs().Index(0).Set(); - cc->InputSidePackets().Index(0).Set(); - return ::mediapipe::OkStatus(); -} - -::mediapipe::Status MatrixMultiplyCalculator::Open(CalculatorContext* cc) { - // The output is at the same timestamp as the input. 
- cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); -} - -::mediapipe::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) { - Matrix* multiplied = new Matrix(); - *multiplied = cc->InputSidePackets().Index(0).Get() * - cc->Inputs().Index(0).Get(); - cc->Outputs().Index(0).Add(multiplied, cc->InputTimestamp()); - return ::mediapipe::OkStatus(); +absl::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) { + kOut(cc).Send(*kSide(cc) * *kIn(cc)); + return absl::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_subtract_calculator.cc b/mediapipe/calculators/core/matrix_subtract_calculator.cc index af13a0d38..09471a5ee 100644 --- a/mediapipe/calculators/core/matrix_subtract_calculator.cc +++ b/mediapipe/calculators/core/matrix_subtract_calculator.cc @@ -13,11 +13,13 @@ // limitations under the License. #include "Eigen/Core" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // Subtract input matrix from the side input matrix and vice versa. The matrices // must have the same dimension. @@ -41,83 +43,39 @@ namespace mediapipe { // input_side_packet: "MINUEND:side_matrix" // output_stream: "output_matrix" // } -class MatrixSubtractCalculator : public CalculatorBase { +class MatrixSubtractCalculator : public Node { public: - MatrixSubtractCalculator() {} - ~MatrixSubtractCalculator() override {} + static constexpr Input::SideFallback kMinuend{"MINUEND"}; + static constexpr Input::SideFallback kSubtrahend{"SUBTRAHEND"}; + static constexpr Output kOut{""}; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + MEDIAPIPE_NODE_CONTRACT(kMinuend, kSubtrahend, kOut); + static absl::Status UpdateContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - - private: - bool subtract_from_input_ = false; + absl::Status Process(CalculatorContext* cc) override; }; -REGISTER_CALCULATOR(MatrixSubtractCalculator); +MEDIAPIPE_REGISTER_NODE(MatrixSubtractCalculator); // static -::mediapipe::Status MatrixSubtractCalculator::GetContract( - CalculatorContract* cc) { - if (cc->Inputs().NumEntries() != 1 || - cc->InputSidePackets().NumEntries() != 1) { - return ::mediapipe::InvalidArgumentError( - "MatrixSubtractCalculator only accepts exactly one input stream and " - "one " - "input side packet"); - } - if (cc->Inputs().HasTag("MINUEND") && - cc->InputSidePackets().HasTag("SUBTRAHEND")) { - cc->Inputs().Tag("MINUEND").Set(); - cc->InputSidePackets().Tag("SUBTRAHEND").Set(); - } else if (cc->Inputs().HasTag("SUBTRAHEND") && - cc->InputSidePackets().HasTag("MINUEND")) { - cc->Inputs().Tag("SUBTRAHEND").Set(); - cc->InputSidePackets().Tag("MINUEND").Set(); - } else { - return ::mediapipe::InvalidArgumentError( - "Must specify exactly one minuend and one subtrahend."); - } - cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); +absl::Status MatrixSubtractCalculator::UpdateContract(CalculatorContract* cc) { + // TODO: the next restriction could be relaxed. 
+ RET_CHECK(kMinuend(cc).IsStream() ^ kSubtrahend(cc).IsStream()) + << "MatrixSubtractCalculator only accepts exactly one input stream and " + "one input side packet"; + return absl::OkStatus(); } -::mediapipe::Status MatrixSubtractCalculator::Open(CalculatorContext* cc) { - // The output is at the same timestamp as the input. - cc->SetOffset(TimestampDiff(0)); - if (cc->Inputs().HasTag("MINUEND")) { - subtract_from_input_ = true; +absl::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) { + const Matrix& minuend = *kMinuend(cc); + const Matrix& subtrahend = *kSubtrahend(cc); + if (minuend.rows() != subtrahend.rows() || + minuend.cols() != subtrahend.cols()) { + return absl::InvalidArgumentError( + "Minuend and subtrahend must have the same dimensions."); } - return ::mediapipe::OkStatus(); -} - -::mediapipe::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) { - Matrix* subtracted = new Matrix(); - if (subtract_from_input_) { - const Matrix& input_matrix = cc->Inputs().Tag("MINUEND").Get(); - const Matrix& side_input_matrix = - cc->InputSidePackets().Tag("SUBTRAHEND").Get(); - if (input_matrix.rows() != side_input_matrix.rows() || - input_matrix.cols() != side_input_matrix.cols()) { - return ::mediapipe::InvalidArgumentError( - "Input matrix and the input side matrix must have the same " - "dimension."); - } - *subtracted = input_matrix - side_input_matrix; - } else { - const Matrix& input_matrix = cc->Inputs().Tag("SUBTRAHEND").Get(); - const Matrix& side_input_matrix = - cc->InputSidePackets().Tag("MINUEND").Get(); - if (input_matrix.rows() != side_input_matrix.rows() || - input_matrix.cols() != side_input_matrix.cols()) { - return ::mediapipe::InvalidArgumentError( - "Input matrix and the input side matrix must have the same " - "dimension."); - } - *subtracted = side_input_matrix - input_matrix; - } - cc->Outputs().Index(0).Add(subtracted, cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + kOut(cc).Send(minuend - subtrahend); + return absl::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc index 162d10e0c..92291050d 100644 --- a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc +++ b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc @@ -89,9 +89,8 @@ TEST(MatrixSubtractCalculatorTest, WrongConfig2) { )"); CalculatorRunner runner(node_config); auto status = runner.Run(); - EXPECT_THAT( - status.message(), - testing::HasSubstr("specify exactly one minuend and one subtrahend.")); + EXPECT_THAT(status.message(), testing::HasSubstr("must be connected")); + EXPECT_THAT(status.message(), testing::HasSubstr("not both")); } TEST(MatrixSubtractCalculatorTest, SubtractFromInput) { diff --git a/mediapipe/calculators/core/matrix_to_vector_calculator.cc b/mediapipe/calculators/core/matrix_to_vector_calculator.cc index b02fda77c..90a36053b 100644 --- a/mediapipe/calculators/core/matrix_to_vector_calculator.cc +++ b/mediapipe/calculators/core/matrix_to_vector_calculator.cc @@ -21,6 +21,7 @@ #include "Eigen/Core" #include "absl/memory/memory.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/integral_types.h" @@ -30,6 +31,7 @@ #include "mediapipe/util/time_series_util.h" namespace mediapipe { +namespace api2 { // A calculator that converts a Matrix M to a vector 
containing all the // entries of M in column-major order. @@ -40,33 +42,27 @@ namespace mediapipe { // input_stream: "input_matrix" // output_stream: "column_major_vector" // } -class MatrixToVectorCalculator : public CalculatorBase { +class MatrixToVectorCalculator : public Node { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).Set( - // Input Packet containing a Matrix. - ); - cc->Outputs().Index(0).Set>( - // Output Packet containing a vector, one for each input Packet. - ); - return ::mediapipe::OkStatus(); - } + static constexpr Input kIn{""}; + static constexpr Output> kOut{""}; - ::mediapipe::Status Open(CalculatorContext* cc) override; + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Open(CalculatorContext* cc) override; // Outputs a packet containing a vector for each input packet. - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; -REGISTER_CALCULATOR(MatrixToVectorCalculator); +MEDIAPIPE_REGISTER_NODE(MatrixToVectorCalculator); -::mediapipe::Status MatrixToVectorCalculator::Open(CalculatorContext* cc) { - // Inform the framework that we don't alter timestamps. - cc->SetOffset(mediapipe::TimestampDiff(0)); - return ::mediapipe::OkStatus(); +absl::Status MatrixToVectorCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(0); + return mediapipe::OkStatus(); } -::mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) { - const Matrix& input = cc->Inputs().Index(0).Get(); +absl::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) { + const Matrix& input = *kIn(cc); auto output = absl::make_unique>(); // The following lines work to convert the Matrix to a vector because Matrix @@ -76,8 +72,9 @@ REGISTER_CALCULATOR(MatrixToVectorCalculator); Eigen::Map(output->data(), input.rows(), input.cols()); output_as_matrix = input; - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + kOut(cc).Send(std::move(output)); + return absl::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/merge_calculator.cc b/mediapipe/calculators/core/merge_calculator.cc index e85ae0c12..a283842ae 100644 --- a/mediapipe/calculators/core/merge_calculator.cc +++ b/mediapipe/calculators/core/merge_calculator.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // This calculator takes a set of input streams and combines them into a single // output stream. 
The packets from different streams do not need to contain the @@ -41,51 +43,43 @@ namespace mediapipe { // output_stream: "merged_shot_infos" // } // -class MergeCalculator : public CalculatorBase { +class MergeCalculator : public Node { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK_GT(cc->Inputs().NumEntries(), 0) - << "Needs at least one input stream"; - RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); - if (cc->Inputs().NumEntries() == 1) { + static constexpr Input::Multiple kIn{""}; + static constexpr Output kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + static absl::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream"; + if (kIn(cc).Count() == 1) { LOG(WARNING) << "MergeCalculator expects multiple input streams to merge but is " "receiving only one. Make sure the calculator is configured " "correctly or consider removing this calculator to reduce " "unnecessary overhead."; } - - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - cc->Inputs().Index(i).SetAny(); - } - cc->Outputs().Index(0).SetAny(); - - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { - cc->SetOffset(TimestampDiff(0)); - - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Output the packet from the first input stream with a packet ready at this // timestamp. - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - if (!cc->Inputs().Index(i).IsEmpty()) { - cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(i).Value()); - return ::mediapipe::OkStatus(); + for (const auto& input : kIn(cc)) { + if (!input.IsEmpty()) { + kOut(cc).Send(input.packet()); + return absl::OkStatus(); } } LOG(WARNING) << "Empty input packets at timestamp " << cc->InputTimestamp().Value(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; -REGISTER_CALCULATOR(MergeCalculator); +MEDIAPIPE_REGISTER_NODE(MergeCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index 8ca25bdd0..a0ce2ae34 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -12,97 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" namespace mediapipe { - -namespace { -constexpr char kSelectTag[] = "SELECT"; -constexpr char kInputTag[] = "INPUT"; -} // namespace +namespace api2 { // A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ..., -// using the integer value (0, 1, ...) in the packet on the kSelectTag input +// using the integer value (0, 1, ...) in the packet on the "SELECT" input // stream, and passes the packet on the selected input stream to the "OUTPUT" // output stream. -// The kSelectTag input can also be passed in as an input side packet, instead -// of as an input stream. Either of input stream or input side packet must be -// specified but not both. // // Note that this calculator defaults to use MuxInputStreamHandler, which is // required for this calculator. However, it can be overridden to work with // other InputStreamHandlers. 
Check out the unit tests on for an example usage // with DefaultInputStreamHandler. -class MuxCalculator : public CalculatorBase { +// TODO: why would you need to use DefaultISH? Perhaps b/167596925? +class MuxCalculator : public Node { public: - static ::mediapipe::Status CheckAndInitAllowDisallowInputs( - CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag(kSelectTag) ^ - cc->InputSidePackets().HasTag(kSelectTag)); - if (cc->Inputs().HasTag(kSelectTag)) { - cc->Inputs().Tag(kSelectTag).Set(); - } else { - cc->InputSidePackets().Tag(kSelectTag).Set(); + static constexpr Input::SideFallback kSelect{"SELECT"}; + // TODO: this currently sets them all to Any independently, instead + // of the first being Any and the others being SameAs. + static constexpr Input::Multiple kIn{"INPUT"}; + static constexpr Output> kOut{"OUTPUT"}; + + MEDIAPIPE_NODE_CONTRACT(kSelect, kIn, kOut, + StreamHandler("MuxInputStreamHandler")); + + absl::Status Process(CalculatorContext* cc) final { + int select = *kSelect(cc); + RET_CHECK(0 <= select && select < kIn(cc).Count()); + if (!kIn(cc)[select].IsEmpty()) { + kOut(cc).Send(kIn(cc)[select].packet()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc)); - CollectionItemId data_input_id = cc->Inputs().BeginId(kInputTag); - PacketType* data_input0 = &cc->Inputs().Get(data_input_id); - data_input0->SetAny(); - ++data_input_id; - for (; data_input_id < cc->Inputs().EndId(kInputTag); ++data_input_id) { - cc->Inputs().Get(data_input_id).SetSameAs(data_input0); - } - RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); - cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0); - - cc->SetInputStreamHandler("MuxInputStreamHandler"); - MediaPipeOptions options; - cc->SetInputStreamHandlerOptions(options); - - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Open(CalculatorContext* cc) final { - use_side_packet_select_ = false; - if (cc->InputSidePackets().HasTag(kSelectTag)) { - use_side_packet_select_ = true; - selected_index_ = cc->InputSidePackets().Tag(kSelectTag).Get(); - } else { - select_input_ = cc->Inputs().GetId(kSelectTag, 0); - } - data_input_base_ = cc->Inputs().GetId(kInputTag, 0); - num_data_inputs_ = cc->Inputs().NumEntries(kInputTag); - output_ = cc->Outputs().GetId("OUTPUT", 0); - cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - int select = use_side_packet_select_ - ? 
selected_index_ - : cc->Inputs().Get(select_input_).Get(); - RET_CHECK(0 <= select && select < num_data_inputs_); - if (!cc->Inputs().Get(data_input_base_ + select).IsEmpty()) { - cc->Outputs().Get(output_).AddPacket( - cc->Inputs().Get(data_input_base_ + select).Value()); - } - return ::mediapipe::OkStatus(); - } - - private: - CollectionItemId select_input_; - CollectionItemId data_input_base_; - int num_data_inputs_ = 0; - CollectionItemId output_; - bool use_side_packet_select_; - int selected_index_; }; -REGISTER_CALCULATOR(MuxCalculator); +MEDIAPIPE_REGISTER_NODE(MuxCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index ac6f7d6ee..46cce74d0 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "mediapipe/calculators/core/split_vector_calculator.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" @@ -132,10 +134,9 @@ void RunGraph(const std::string& graph_config_proto, const std::string& input_stream_name, int num_input_packets, std::function input_fn, const std::string& output_stream_name, - std::function<::mediapipe::Status(const Packet&)> output_fn) { + std::function output_fn) { CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie( - graph_config_proto); + mediapipe::ParseTextProtoOrDie(graph_config_proto); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream(output_stream_name, output_fn)); @@ -164,9 +165,9 @@ TEST(MuxCalculatorTest, InputStreamSelector_DefaultInputStreamHandler) { // Output and handling. std::vector output; // This function collects the output from the packet. - auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status { + auto output_fn = [&output](const Packet& p) -> absl::Status { output.push_back(p.Get()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }; RunGraph(kTestGraphConfig1, {}, kInputName, input_packets.size(), input_fn, @@ -190,9 +191,9 @@ TEST(MuxCalculatorTest, InputSidePacketSelector_DefaultInputStreamHandler) { // Output and handling. std::vector output; // This function collects the output from the packet. - auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status { + auto output_fn = [&output](const Packet& p) -> absl::Status { output.push_back(p.Get()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }; RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket(0)}}, @@ -224,14 +225,80 @@ TEST(MuxCalculatorTest, InputStreamSelector_MuxInputStreamHandler) { // Output and handling. std::vector output; // This function collects the output from the packet. 
- auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status { + auto output_fn = [&output](const Packet& p) -> absl::Status { output.push_back(p.Get()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }; RunGraph(kTestGraphConfig3, {}, kInputName, input_packets.size(), input_fn, kOutputName, output_fn); EXPECT_EQ(output, input_packets); } + +constexpr char kDualInputGraphConfig[] = R"proto( + input_stream: "input_0" + input_stream: "input_1" + input_stream: "input_select" + output_stream: "test_output" + node { + calculator: "MuxCalculator" + input_stream: "INPUT:0:input_0" + input_stream: "INPUT:1:input_1" + input_stream: "SELECT:input_select" + output_stream: "OUTPUT:test_output" + } +)proto"; + +TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + kDualInputGraphConfig); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + std::shared_ptr output; + MP_ASSERT_OK( + graph.ObserveOutputStream("test_output", [&output](const Packet& p) { + output = p.Get>(); + return absl::OkStatus(); + })); + + MP_ASSERT_OK(graph.StartRun({})); + + auto one = std::make_shared(1); + auto two = std::make_shared(2); + auto three = std::make_shared(3); + std::weak_ptr one_weak = one; + std::weak_ptr two_weak = two; + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_0", + MakePacket>(std::move(one)).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_1", + MakePacket>(std::move(two)).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_1", + MakePacket>(std::move(three)).At(Timestamp(1)))); + EXPECT_EQ(one, nullptr); + EXPECT_EQ(two, nullptr); + EXPECT_EQ(three, nullptr); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_select", MakePacket(0).At(Timestamp(0)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(*output, 1); + EXPECT_NE(one_weak.lock(), nullptr); + EXPECT_EQ(two_weak.lock(), nullptr); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_select", MakePacket(1).At(Timestamp(1)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(*output, 3); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_cloner_calculator.cc b/mediapipe/calculators/core/packet_cloner_calculator.cc index 26044fc2c..41bddbfa7 100644 --- a/mediapipe/calculators/core/packet_cloner_calculator.cc +++ b/mediapipe/calculators/core/packet_cloner_calculator.cc @@ -45,17 +45,17 @@ namespace mediapipe { // packet_inner_join_calculator.cc: Don't output unless all inputs are new. class PacketClonerCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const int tick_signal_index = cc->Inputs().NumEntries() - 1; for (int i = 0; i < tick_signal_index; ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); } cc->Inputs().Index(tick_signal_index).SetAny(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { // Load options. 
const auto calculator_options = cc->Options(); @@ -71,10 +71,10 @@ class PacketClonerCalculator : public CalculatorBase { cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header()); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Store input signals. for (int i = 0; i < tick_signal_index_; ++i) { if (!cc->Inputs().Index(i).Value().IsEmpty()) { @@ -88,7 +88,7 @@ class PacketClonerCalculator : public CalculatorBase { // Return if one of the input is null. for (int i = 0; i < tick_signal_index_; ++i) { if (current_[i].IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } } @@ -103,7 +103,7 @@ class PacketClonerCalculator : public CalculatorBase { } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/packet_cloner_calculator.proto b/mediapipe/calculators/core/packet_cloner_calculator.proto index 7abb16163..e30672fab 100644 --- a/mediapipe/calculators/core/packet_cloner_calculator.proto +++ b/mediapipe/calculators/core/packet_cloner_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message PacketClonerCalculatorOptions { extend CalculatorOptions { optional PacketClonerCalculatorOptions ext = 258872085; diff --git a/mediapipe/calculators/core/packet_inner_join_calculator.cc b/mediapipe/calculators/core/packet_inner_join_calculator.cc index 2b93df3cf..6ffffb58b 100644 --- a/mediapipe/calculators/core/packet_inner_join_calculator.cc +++ b/mediapipe/calculators/core/packet_inner_join_calculator.cc @@ -34,10 +34,10 @@ namespace mediapipe { // packet_cloner_calculator.cc: Repeats last-seen packets from empty inputs. 
class PacketInnerJoinCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: int num_streams_; @@ -45,8 +45,7 @@ class PacketInnerJoinCalculator : public CalculatorBase { REGISTER_CALCULATOR(PacketInnerJoinCalculator); -::mediapipe::Status PacketInnerJoinCalculator::GetContract( - CalculatorContract* cc) { +absl::Status PacketInnerJoinCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == cc->Outputs().NumEntries()) << "The number of input and output streams must match."; const int num_streams = cc->Inputs().NumEntries(); @@ -54,25 +53,25 @@ REGISTER_CALCULATOR(PacketInnerJoinCalculator); cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) { +absl::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) { num_streams_ = cc->Inputs().NumEntries(); cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) { +absl::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < num_streams_; ++i) { if (cc->Inputs().Index(i).Value().IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } for (int i = 0; i < num_streams_; ++i) { cc->Outputs().Index(i).AddPacket(cc->Inputs().Index(i).Value()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_presence_calculator.cc b/mediapipe/calculators/core/packet_presence_calculator.cc index 468d31718..cb119a76d 100644 --- a/mediapipe/calculators/core/packet_presence_calculator.cc +++ b/mediapipe/calculators/core/packet_presence_calculator.cc @@ -57,26 +57,26 @@ namespace mediapipe { // } class PacketPresenceCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("PACKET").SetAny(); cc->Outputs().Tag("PRESENCE").Set(); // Process() function is invoked in response to input stream timestamp // bound updates. 
cc->SetProcessTimestampBounds(true); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->Outputs() .Tag("PRESENCE") .AddPacket(MakePacket(!cc->Inputs().Tag("PACKET").IsEmpty()) .At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(PacketPresenceCalculator); diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc index da9c26c15..32b1c850a 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator.cc @@ -47,8 +47,7 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) { } } // namespace -::mediapipe::Status PacketResamplerCalculator::GetContract( - CalculatorContract* cc) { +absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { const auto& resampler_options = cc->Options(); if (cc->InputSidePackets().HasTag("OPTIONS")) { @@ -78,10 +77,10 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) { RET_CHECK(cc->InputSidePackets().HasTag("SEED")); cc->InputSidePackets().Tag("SEED").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { +absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { const auto resampler_options = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), "OPTIONS"); @@ -156,8 +155,8 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) { const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); random_ = CreateSecureRandom(seed); if (random_ == nullptr) { - return ::mediapipe::Status( - ::mediapipe::StatusCode::kInvalidArgument, + return absl::Status( + absl::StatusCode::kInvalidArgument, "SecureRandom is not available. 
With \"jitter\" specified, " "PacketResamplerCalculator processing cannot proceed."); } @@ -165,17 +164,17 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) { } packet_reservoir_ = std::make_unique(packet_reservoir_random_.get()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { +absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream() && cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") && !cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) { video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get(); video_header_.frame_rate = frame_rate_; if (cc->Inputs().Get(input_data_id_).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } if (jitter_ != 0.0 && random_ != nullptr) { @@ -192,7 +191,7 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) { MP_RETURN_IF_ERROR(ProcessWithoutJitter(cc)); } last_packet_ = cc->Inputs().Get(input_data_id_).Value(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void PacketResamplerCalculator::InitializeNextOutputTimestampWithJitter() { @@ -229,7 +228,7 @@ void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() { ((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat()); } -::mediapipe::Status PacketResamplerCalculator::ProcessWithJitter( +absl::Status PacketResamplerCalculator::ProcessWithJitter( CalculatorContext* cc) { RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); RET_CHECK_NE(jitter_, 0.0); @@ -243,7 +242,7 @@ void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() { cc->Inputs().Get(input_data_id_).Value().At(next_output_timestamp_)); UpdateNextOutputTimestampWithJitter(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (frame_time_usec_ < @@ -266,11 +265,21 @@ void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() { : cc->Inputs().Get(input_data_id_).Value()) .At(next_output_timestamp_)); UpdateNextOutputTimestampWithJitter(); + // From now on every time a packet is emitted the timestamp of the next + // packet becomes known; that timestamp is stored in next_output_timestamp_. + // The only exception to this rule is the packet emitted from Close() which + // can only happen when jitter_with_reflection is enabled but in this case + // next_output_timestamp_min_ is a non-decreasing lower bound of any + // subsequent packet. + const Timestamp timestamp_bound = jitter_with_reflection_ + ? 
next_output_timestamp_min_ + : next_output_timestamp_; + cc->Outputs().Get(output_data_id_).SetNextTimestampBound(timestamp_bound); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketResamplerCalculator::ProcessWithoutJitter( +absl::Status PacketResamplerCalculator::ProcessWithoutJitter( CalculatorContext* cc) { RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); RET_CHECK_EQ(jitter_, 0.0); @@ -333,12 +342,12 @@ void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() { .Get(output_data_id_) .SetNextTimestampBound(PeriodIndexToTimestamp(period_count_)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketResamplerCalculator::Close(CalculatorContext* cc) { +absl::Status PacketResamplerCalculator::Close(CalculatorContext* cc) { if (!cc->GraphStatus().ok()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Emit the last packet received if we have at least one packet, but // haven't sent anything for its period. @@ -350,7 +359,7 @@ void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() { if (!packet_reservoir_->IsEmpty()) { OutputWithinLimits(cc, packet_reservoir_->GetSample()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const { diff --git a/mediapipe/calculators/core/packet_resampler_calculator.h b/mediapipe/calculators/core/packet_resampler_calculator.h index 95ef24cc2..4a1a3ffaa 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.h +++ b/mediapipe/calculators/core/packet_resampler_calculator.h @@ -99,11 +99,11 @@ class PacketReservoir { // packet_downsampler_calculator.cc: skips packets regardless of timestamps. class PacketResamplerCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Calculates the first sampled timestamp that incorporates a jittering @@ -113,10 +113,10 @@ class PacketResamplerCalculator : public CalculatorBase { void UpdateNextOutputTimestampWithJitter(); // Logic for Process() when jitter_ != 0.0. - ::mediapipe::Status ProcessWithJitter(CalculatorContext* cc); + absl::Status ProcessWithJitter(CalculatorContext* cc); // Logic for Process() when jitter_ == 0.0. 
- ::mediapipe::Status ProcessWithoutJitter(CalculatorContext* cc); + absl::Status ProcessWithoutJitter(CalculatorContext* cc); // Given the current count of periods that have passed, this returns // the next valid timestamp of the middle point of the next period: diff --git a/mediapipe/calculators/core/packet_resampler_calculator.proto b/mediapipe/calculators/core/packet_resampler_calculator.proto index f23ce1fdc..d037ee9de 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.proto +++ b/mediapipe/calculators/core/packet_resampler_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message PacketResamplerCalculatorOptions { extend CalculatorOptions { optional PacketResamplerCalculatorOptions ext = 95743844; diff --git a/mediapipe/calculators/core/packet_thinner_calculator.cc b/mediapipe/calculators/core/packet_thinner_calculator.cc index 417fafa31..d3d391b61 100644 --- a/mediapipe/calculators/core/packet_thinner_calculator.cc +++ b/mediapipe/calculators/core/packet_thinner_calculator.cc @@ -90,7 +90,7 @@ class PacketThinnerCalculator : public CalculatorBase { PacketThinnerCalculator() {} ~PacketThinnerCalculator() override {} - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { if (cc->InputSidePackets().HasTag(kOptionsTag)) { cc->InputSidePackets().Tag(kOptionsTag).Set(); } @@ -99,21 +99,21 @@ class PacketThinnerCalculator : public CalculatorBase { if (cc->InputSidePackets().HasTag(kPeriodTag)) { cc->InputSidePackets().Tag(kPeriodTag).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override { if (cc->InputTimestamp() < start_time_) { - return ::mediapipe::OkStatus(); // Drop packets before start_time_. + return absl::OkStatus(); // Drop packets before start_time_. } else if (cc->InputTimestamp() >= end_time_) { if (!cc->Outputs().Index(0).IsClosed()) { cc->Outputs() .Index(0) .Close(); // No more Packets will be output after end_time_. } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { return thinner_type_ == PacketThinnerCalculatorOptions::ASYNC ? AsyncThinnerProcess(cc) @@ -123,8 +123,8 @@ class PacketThinnerCalculator : public CalculatorBase { private: // Implementation of ASYNC and SYNC versions of thinner algorithm. - ::mediapipe::Status AsyncThinnerProcess(CalculatorContext* cc); - ::mediapipe::Status SyncThinnerProcess(CalculatorContext* cc); + absl::Status AsyncThinnerProcess(CalculatorContext* cc); + absl::Status SyncThinnerProcess(CalculatorContext* cc); // Cached option. PacketThinnerCalculatorOptions::ThinnerType thinner_type_; @@ -153,7 +153,7 @@ namespace { TimestampDiff abs(TimestampDiff t) { return t < 0 ? -t : t; } } // namespace -::mediapipe::Status PacketThinnerCalculator::Open(CalculatorContext* cc) { +absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) { PacketThinnerCalculatorOptions options = mediapipe::tool::RetrieveOptions( cc->Options(), cc->InputSidePackets(), kOptionsTag); @@ -224,10 +224,10 @@ TimestampDiff abs(TimestampDiff t) { return t < 0 ? 
-t : t; } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketThinnerCalculator::Close(CalculatorContext* cc) { +absl::Status PacketThinnerCalculator::Close(CalculatorContext* cc) { // Emit any saved packets before quitting. if (!saved_packet_.IsEmpty()) { // Only sync thinner should have saved packets. @@ -239,10 +239,10 @@ TimestampDiff abs(TimestampDiff t) { return t < 0 ? -t : t; } cc->Outputs().Index(0).AddPacket(saved_packet_); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketThinnerCalculator::AsyncThinnerProcess( +absl::Status PacketThinnerCalculator::AsyncThinnerProcess( CalculatorContext* cc) { if (cc->InputTimestamp() >= next_valid_timestamp_) { cc->Outputs().Index(0).AddPacket( @@ -251,10 +251,10 @@ TimestampDiff abs(TimestampDiff t) { return t < 0 ? -t : t; } // Guaranteed not to emit packets seen during refractory period. cc->Outputs().Index(0).SetNextTimestampBound(next_valid_timestamp_); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketThinnerCalculator::SyncThinnerProcess( +absl::Status PacketThinnerCalculator::SyncThinnerProcess( CalculatorContext* cc) { if (saved_packet_.IsEmpty()) { // If no packet has been saved, store the current packet. @@ -290,7 +290,7 @@ TimestampDiff abs(TimestampDiff t) { return t < 0 ? -t : t; } saved_packet_ = cc->Inputs().Index(0).Value(); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const { diff --git a/mediapipe/calculators/core/packet_thinner_calculator.proto b/mediapipe/calculators/core/packet_thinner_calculator.proto index 34fd9bc32..6c69f3afd 100644 --- a/mediapipe/calculators/core/packet_thinner_calculator.proto +++ b/mediapipe/calculators/core/packet_thinner_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message PacketThinnerCalculatorOptions { extend CalculatorOptions { optional PacketThinnerCalculatorOptions ext = 288533508; diff --git a/mediapipe/calculators/core/pass_through_calculator.cc b/mediapipe/calculators/core/pass_through_calculator.cc index d4e648037..197e1331a 100644 --- a/mediapipe/calculators/core/pass_through_calculator.cc +++ b/mediapipe/calculators/core/pass_through_calculator.cc @@ -28,9 +28,9 @@ namespace mediapipe { // ignored. 
class PassThroughCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { if (!cc->Inputs().TagMap()->SameAs(*cc->Outputs().TagMap())) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Input and output streams to PassThroughCalculator must use " "matching tags and indexes."); } @@ -46,7 +46,7 @@ class PassThroughCalculator : public CalculatorBase { if (cc->OutputSidePackets().NumEntries() != 0) { if (!cc->InputSidePackets().TagMap()->SameAs( *cc->OutputSidePackets().TagMap())) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Input and output side packets to PassThroughCalculator must use " "matching tags and indexes."); } @@ -56,10 +56,10 @@ class PassThroughCalculator : public CalculatorBase { &cc->InputSidePackets().Get(id)); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { if (!cc->Inputs().Get(id).Header().IsEmpty()) { @@ -73,10 +73,10 @@ class PassThroughCalculator : public CalculatorBase { } } cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->GetCounter("PassThrough")->Increment(); if (cc->Inputs().NumEntries() == 0) { return tool::StatusStop(); @@ -90,7 +90,7 @@ class PassThroughCalculator : public CalculatorBase { cc->Outputs().Get(id).AddPacket(cc->Inputs().Get(id).Value()); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(PassThroughCalculator); diff --git a/mediapipe/calculators/core/previous_loopback_calculator.cc b/mediapipe/calculators/core/previous_loopback_calculator.cc index 8cbf04410..d67e6c061 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator.cc @@ -14,12 +14,14 @@ #include +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/timestamp.h" namespace mediapipe { +namespace api2 { // PreviousLoopbackCalculator is useful when a graph needs to process an input // together with some previous output. @@ -51,79 +53,77 @@ namespace mediapipe { // input_stream: "PREV_TRACK:prev_output" // output_stream: "TRACK:output" // } -class PreviousLoopbackCalculator : public CalculatorBase { +class PreviousLoopbackCalculator : public Node { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Get("MAIN", 0).SetAny(); - cc->Inputs().Get("LOOP", 0).SetAny(); - cc->Outputs().Get("PREV_LOOP", 0).SetSameAs(&(cc->Inputs().Get("LOOP", 0))); - // TODO: an optional PREV_TIMESTAMP output could be added to - // carry the original timestamp of the packet on PREV_LOOP. - cc->SetInputStreamHandler("ImmediateInputStreamHandler"); + static constexpr Input kMain{"MAIN"}; + static constexpr Input kLoop{"LOOP"}; + static constexpr Output> kPrevLoop{"PREV_LOOP"}; + // TODO: an optional PREV_TIMESTAMP output could be added to + // carry the original timestamp of the packet on PREV_LOOP. 
+ + MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop, + StreamHandler("ImmediateInputStreamHandler"), + TimestampChange::Arbitrary()); + + static absl::Status UpdateContract(CalculatorContract* cc) { // Process() function is invoked in response to MAIN/LOOP stream timestamp // bound updates. cc->SetProcessTimestampBounds(true); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { - main_id_ = cc->Inputs().GetId("MAIN", 0); - loop_id_ = cc->Inputs().GetId("LOOP", 0); - prev_loop_id_ = cc->Outputs().GetId("PREV_LOOP", 0); - cc->Outputs() - .Get(prev_loop_id_) - .SetHeader(cc->Inputs().Get(loop_id_).Header()); - return ::mediapipe::OkStatus(); + absl::Status Open(CalculatorContext* cc) final { + kPrevLoop(cc).SetHeader(kLoop(cc).Header()); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Non-empty packets and empty packets indicating timestamp bound updates // are guaranteed to have timestamps greater than timestamps of previous // packets within the same stream. Calculator tracks and operates on such // packets. - const Packet& main_packet = cc->Inputs().Get(main_id_).Value(); - if (prev_main_ts_ < main_packet.Timestamp()) { + const PacketBase& main_packet = kMain(cc).packet(); + if (prev_main_ts_ < main_packet.timestamp()) { Timestamp loop_timestamp; if (!main_packet.IsEmpty()) { loop_timestamp = prev_non_empty_main_ts_; - prev_non_empty_main_ts_ = main_packet.Timestamp(); + prev_non_empty_main_ts_ = main_packet.timestamp(); } else { // Calculator advances PREV_LOOP timestamp bound in response to empty // MAIN packet, hence not caring about corresponding loop packet. loop_timestamp = Timestamp::Unset(); } - main_packet_specs_.push_back({main_packet.Timestamp(), loop_timestamp}); - prev_main_ts_ = main_packet.Timestamp(); + main_packet_specs_.push_back({main_packet.timestamp(), loop_timestamp}); + prev_main_ts_ = main_packet.timestamp(); } - const Packet& loop_packet = cc->Inputs().Get(loop_id_).Value(); - if (prev_loop_ts_ < loop_packet.Timestamp()) { + const PacketBase& loop_packet = kLoop(cc).packet(); + if (prev_loop_ts_ < loop_packet.timestamp()) { loop_packets_.push_back(loop_packet); - prev_loop_ts_ = loop_packet.Timestamp(); + prev_loop_ts_ = loop_packet.timestamp(); } - auto& prev_loop = cc->Outputs().Get(prev_loop_id_); while (!main_packet_specs_.empty() && !loop_packets_.empty()) { // The earliest MAIN packet. - const MainPacketSpec& main_spec = main_packet_specs_.front(); + MainPacketSpec main_spec = main_packet_specs_.front(); // The earliest LOOP packet. - const Packet& loop_candidate = loop_packets_.front(); + const PacketBase& loop_candidate = loop_packets_.front(); // Match LOOP and MAIN packets. - if (main_spec.loop_timestamp < loop_candidate.Timestamp()) { + if (main_spec.loop_timestamp < loop_candidate.timestamp()) { // No LOOP packet can match the MAIN packet under review. - prev_loop.SetNextTimestampBound(main_spec.timestamp + 1); + kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1); main_packet_specs_.pop_front(); - } else if (main_spec.loop_timestamp > loop_candidate.Timestamp()) { + } else if (main_spec.loop_timestamp > loop_candidate.timestamp()) { // No MAIN packet can match the LOOP packet under review. loop_packets_.pop_front(); } else { // Exact match found. if (loop_candidate.IsEmpty()) { // However, LOOP packet is empty. 
- prev_loop.SetNextTimestampBound(main_spec.timestamp + 1); + kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1); } else { - prev_loop.AddPacket(loop_candidate.At(main_spec.timestamp)); + kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp)); } loop_packets_.pop_front(); main_packet_specs_.pop_front(); @@ -135,11 +135,11 @@ class PreviousLoopbackCalculator : public CalculatorBase { // b) Empty MAIN packet has been received with Timestamp::Max() indicating // MAIN is done. if (main_spec.timestamp == Timestamp::Done().PreviousAllowedInStream()) { - prev_loop.Close(); + kPrevLoop(cc).Close(); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -150,10 +150,6 @@ class PreviousLoopbackCalculator : public CalculatorBase { Timestamp loop_timestamp; }; - CollectionItemId main_id_; - CollectionItemId loop_id_; - CollectionItemId prev_loop_id_; - // Contains specs for MAIN packets which only can be: // - non-empty packets // - empty packets indicating timestamp bound updates @@ -169,12 +165,13 @@ class PreviousLoopbackCalculator : public CalculatorBase { // - empty packets indicating timestamp bound updates // // Sorted according to packet timestamps. - std::deque loop_packets_; + std::deque loop_packets_; // Using "Timestamp::Unset" instead of "Timestamp::Unstarted" in order to // allow addition of the very first empty packet (which doesn't indicate // timestamp bound change necessarily). Timestamp prev_loop_ts_ = Timestamp::Unset(); }; -REGISTER_CALCULATOR(PreviousLoopbackCalculator); +MEDIAPIPE_REGISTER_NODE(PreviousLoopbackCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/previous_loopback_calculator_test.cc b/mediapipe/calculators/core/previous_loopback_calculator_test.cc index ef469b43a..54959edae 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator_test.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator_test.cc @@ -136,27 +136,27 @@ TEST(PreviousLoopbackCalculator, CorrectTimestamps) { // A Calculator that outputs a summary packet in CalculatorBase::Close(). class PacketOnCloseCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { sum_ += cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket( MakePacket(sum_).At(Timestamp::Max())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -700,19 +700,19 @@ TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest, // Similar to GateCalculator, but it doesn't propagate timestamp bound updates. 
class DroppingGateCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Inputs().Tag("DISALLOW").Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { if (!cc->Inputs().Index(0).IsEmpty() && !cc->Inputs().Tag("DISALLOW").Get()) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(DroppingGateCalculator); diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator.cc b/mediapipe/calculators/core/quantize_float_vector_calculator.cc index 76e635e5b..e95509298 100644 --- a/mediapipe/calculators/core/quantize_float_vector_calculator.cc +++ b/mediapipe/calculators/core/quantize_float_vector_calculator.cc @@ -43,32 +43,32 @@ namespace mediapipe { class QuantizeFloatVectorCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("FLOAT_VECTOR").Set>(); cc->Outputs().Tag("ENCODED").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { const auto options = cc->Options<::mediapipe::QuantizeFloatVectorCalculatorOptions>(); if (!options.has_max_quantized_value() || !options.has_min_quantized_value()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Both max_quantized_value and min_quantized_value must be provided " "in QuantizeFloatVectorCalculatorOptions."); } max_quantized_value_ = options.max_quantized_value(); min_quantized_value_ = options.min_quantized_value(); if (max_quantized_value_ < min_quantized_value_ + FLT_EPSILON) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "max_quantized_value must be greater than min_quantized_value."); } range_ = max_quantized_value_ - min_quantized_value_; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const std::vector& float_vector = cc->Inputs().Tag("FLOAT_VECTOR").Value().Get>(); int feature_size = float_vector.size(); @@ -88,7 +88,7 @@ class QuantizeFloatVectorCalculator : public CalculatorBase { } cc->Outputs().Tag("ENCODED").AddPacket( MakePacket(encoded_features).At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator.proto b/mediapipe/calculators/core/quantize_float_vector_calculator.proto index 3f6cfda21..0ccc3c0d9 100644 --- a/mediapipe/calculators/core/quantize_float_vector_calculator.proto +++ b/mediapipe/calculators/core/quantize_float_vector_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message QuantizeFloatVectorCalculatorOptions { extend CalculatorOptions { optional QuantizeFloatVectorCalculatorOptions ext = 259848061; diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc 
b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc new file mode 100644 index 000000000..277f83fe2 --- /dev/null +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -0,0 +1,199 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/header_util.h" + +namespace mediapipe { + +// RealTimeFlowLimiterCalculator is used to limit the number of pipelined +// processing operations in a section of the graph. +// +// Typical topology: +// +// in ->-[FLC]-[foo]-...-[bar]-+->- out +// ^_____________________| +// FINISHED +// +// By connecting the output of the graph section to this calculator's FINISHED +// input with a backwards edge, this allows FLC to keep track of how many +// timestamps are currently being processed. +// +// The limit defaults to 1, and can be overridden with the MAX_IN_FLIGHT side +// packet. +// +// As long as the number of timestamps being processed ("in flight") is below +// the limit, FLC allows input to pass through. When the limit is reached, +// FLC starts dropping input packets, keeping only the most recent. When the +// processing count decreases again, as signaled by the receipt of a packet on +// FINISHED, FLC allows packets to flow again, releasing the most recently +// queued packet, if any. +// +// If there are multiple input streams, packet dropping is synchronized. +// +// IMPORTANT: for each timestamp where FLC forwards a packet (or a set of +// packets, if using multiple data streams), a packet must eventually arrive on +// the FINISHED stream. Dropping packets in the section between FLC and +// FINISHED will make the in-flight count incorrect. +// +// TODO: Remove this comment when graph-level ISH has been removed. +// NOTE: this calculator should always use the ImmediateInputStreamHandler and +// uses it by default. However, if the graph specifies a graph-level +// InputStreamHandler, to override that setting, the InputStreamHandler must +// be explicitly specified as shown below. 
+// +// Example config: +// node { +// calculator: "RealTimeFlowLimiterCalculator" +// input_stream: "raw_frames" +// input_stream: "FINISHED:finished" +// input_stream_info: { +// tag_index: 'FINISHED' +// back_edge: true +// } +// input_stream_handler { +// input_stream_handler: 'ImmediateInputStreamHandler' +// } +// output_stream: "gated_frames" +// } +class RealTimeFlowLimiterCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + int num_data_streams = cc->Inputs().NumEntries(""); + RET_CHECK_GE(num_data_streams, 1); + RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams) + << "Output streams must correspond input streams except for the " + "finish indicator input stream."; + for (int i = 0; i < num_data_streams; ++i) { + cc->Inputs().Get("", i).SetAny(); + cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i))); + } + cc->Inputs().Get("FINISHED", 0).SetAny(); + if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { + cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set(); + } + if (cc->Outputs().HasTag("ALLOW")) { + cc->Outputs().Tag("ALLOW").Set(); + } + + cc->SetInputStreamHandler("ImmediateInputStreamHandler"); + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) final { + finished_id_ = cc->Inputs().GetId("FINISHED", 0); + max_in_flight_ = 1; + if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { + max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get(); + } + RET_CHECK_GE(max_in_flight_, 1); + num_in_flight_ = 0; + + allowed_id_ = cc->Outputs().GetId("ALLOW", 0); + allow_ctr_ts_ = Timestamp(0); + + num_data_streams_ = cc->Inputs().NumEntries(""); + data_stream_bound_ts_.resize(num_data_streams_); + RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); + return absl::OkStatus(); + } + + bool Allow() { return num_in_flight_ < max_in_flight_; } + + absl::Status Process(CalculatorContext* cc) final { + bool old_allow = Allow(); + Timestamp lowest_incomplete_ts = Timestamp::Done(); + + // Process FINISHED stream. + if (!cc->Inputs().Get(finished_id_).Value().IsEmpty()) { + RET_CHECK_GT(num_in_flight_, 0) + << "Received a FINISHED packet, but we had none in flight."; + --num_in_flight_; + } + + // Process data streams. + for (int i = 0; i < num_data_streams_; ++i) { + auto& stream = cc->Inputs().Get("", i); + auto& out = cc->Outputs().Get("", i); + Packet& packet = stream.Value(); + auto ts = packet.Timestamp(); + if (ts.IsRangeValue() && data_stream_bound_ts_[i] <= ts) { + data_stream_bound_ts_[i] = ts + 1; + // Note: it's ok to update the output bound here, before sending the + // packet, because updates are batched during the Process function. + out.SetNextTimestampBound(data_stream_bound_ts_[i]); + } + lowest_incomplete_ts = + std::min(lowest_incomplete_ts, data_stream_bound_ts_[i]); + + if (packet.IsEmpty()) { + // If the input stream is closed, close the corresponding output. + if (stream.IsDone() && !out.IsClosed()) { + out.Close(); + } + // TODO: if the packet is empty, the ts is unset, and we + // cannot read the timestamp bound, even though we'd like to propagate + // it. + } else if (mediapipe::ContainsKey(pending_ts_, ts)) { + // If we have already sent this timestamp (on another stream), send it + // on this stream too. 
+ out.AddPacket(std::move(packet)); + } else if (Allow() && (ts > last_dropped_ts_)) { + // If the in-flight is under the limit, and if we have not already + // dropped this or a later timestamp on another stream, then send + // the packet and add an in-flight timestamp. + out.AddPacket(std::move(packet)); + pending_ts_.insert(ts); + ++num_in_flight_; + } else { + // Otherwise, we'll drop the packet. + last_dropped_ts_ = std::max(last_dropped_ts_, ts); + } + } + + // Remove old pending_ts_ entries. + auto it = std::lower_bound(pending_ts_.begin(), pending_ts_.end(), + lowest_incomplete_ts); + pending_ts_.erase(pending_ts_.begin(), it); + + // Update ALLOW signal. + if ((old_allow != Allow()) && allowed_id_.IsValid()) { + cc->Outputs() + .Get(allowed_id_) + .AddPacket(MakePacket(Allow()).At(++allow_ctr_ts_)); + } + return absl::OkStatus(); + } + + private: + std::set pending_ts_; + Timestamp last_dropped_ts_; + int num_data_streams_; + int num_in_flight_; + int max_in_flight_; + CollectionItemId finished_id_; + CollectionItemId allowed_id_; + Timestamp allow_ctr_ts_; + std::vector data_stream_bound_ts_; +}; +REGISTER_CALCULATOR(RealTimeFlowLimiterCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc new file mode 100644 index 000000000..73c50e56d --- /dev/null +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc @@ -0,0 +1,495 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/sink.h" + +namespace mediapipe { + +namespace { +// A simple Semaphore for synchronizing test threads. +class AtomicSemaphore { + public: + AtomicSemaphore(int64_t supply) : supply_(supply) {} + void Acquire(int64_t amount) { + while (supply_.fetch_sub(amount) - amount < 0) { + Release(amount); + } + } + void Release(int64_t amount) { supply_.fetch_add(amount); } + + private: + std::atomic supply_; +}; + +// Returns the timestamp values for a vector of Packets. +std::vector TimestampValues(const std::vector& packets) { + std::vector result; + for (const Packet& packet : packets) { + result.push_back(packet.Timestamp().Value()); + } + return result; +} + +// Returns the packet values for a vector of Packets. 
+template +std::vector PacketValues(const std::vector& packets) { + std::vector result; + for (const Packet& packet : packets) { + result.push_back(packet.Get()); + } + return result; +} + +constexpr int kNumImageFrames = 5; +constexpr int kNumFinished = 3; +CalculatorGraphConfig::Node GetDefaultNode() { + return ParseTextProtoOrDie(R"( + calculator: "RealTimeFlowLimiterCalculator" + input_stream: "raw_frames" + input_stream: "FINISHED:finished" + input_stream_info: { tag_index: "FINISHED" back_edge: true } + output_stream: "gated_frames" + )"); +} + +// Simple test to make sure that the RealTimeFlowLimiterCalculator outputs just +// one packet when MAX_IN_FLIGHT is 1. +TEST(RealTimeFlowLimiterCalculator, OneOutputTest) { + // Setup the calculator runner and add only ImageFrame packets. + CalculatorRunner runner(GetDefaultNode()); + for (int i = 0; i < kNumImageFrames; ++i) { + Timestamp timestamp = Timestamp(i * Timestamp::kTimestampUnitsPerSecond); + runner.MutableInputs()->Index(0).packets.push_back( + MakePacket().At(timestamp)); + } + + // Run the calculator. + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& frame_output_packets = + runner.Outputs().Index(0).packets; + + EXPECT_EQ(frame_output_packets.size(), 1); +} + +// Simple test to make sure that the RealTimeFlowLimiterCalculator waits for all +// input streams to have at least one packet available before publishing. +TEST(RealTimeFlowLimiterCalculator, BasicTest) { + // Setup the calculator runner and add both ImageFrame and finish packets. + CalculatorRunner runner(GetDefaultNode()); + for (int i = 0; i < kNumImageFrames; ++i) { + Timestamp timestamp = Timestamp(i * Timestamp::kTimestampUnitsPerSecond); + runner.MutableInputs()->Index(0).packets.push_back( + MakePacket().At(timestamp)); + } + for (int i = 0; i < kNumFinished; ++i) { + Timestamp timestamp = + Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond); + runner.MutableInputs() + ->Tag("FINISHED") + .packets.push_back(MakePacket(true).At(timestamp)); + } + + // Run the calculator. + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& frame_output_packets = + runner.Outputs().Index(0).packets; + + // Only outputs packets if both input streams are available. + int expected_num_packets = std::min(kNumImageFrames, kNumFinished + 1); + EXPECT_EQ(frame_output_packets.size(), expected_num_packets); +} + +// A Calculator::Process callback function. +typedef std::function + ProcessFunction; + +// A testing callback function that passes through all packets. +absl::Status PassthroughFunction(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + for (int i = 0; i < inputs.NumEntries(); ++i) { + if (!inputs.Index(i).Value().IsEmpty()) { + outputs->Index(i).AddPacket(inputs.Index(i).Value()); + } + } + return absl::OkStatus(); +} + +// A Calculator that runs a testing callback function in Close. 
+class CloseCallbackCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + for (CollectionItemId id = cc->Inputs().BeginId(); + id < cc->Inputs().EndId(); ++id) { + cc->Inputs().Get(id).SetAny(); + } + for (CollectionItemId id = cc->Outputs().BeginId(); + id < cc->Outputs().EndId(); ++id) { + cc->Outputs().Get(id).SetAny(); + } + cc->InputSidePackets().Index(0).Set>(); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + return PassthroughFunction(cc->Inputs(), &(cc->Outputs())); + } + + absl::Status Close(CalculatorContext* cc) override { + const auto& callback = + cc->InputSidePackets().Index(0).Get>(); + return callback(); + } +}; +REGISTER_CALCULATOR(CloseCallbackCalculator); + +// Tests demostrating an RealTimeFlowLimiterCalculator operating in a cyclic +// graph. +// TODO: clean up these tests. +class RealTimeFlowLimiterCalculatorTest : public testing::Test { + public: + RealTimeFlowLimiterCalculatorTest() + : enter_semaphore_(0), exit_semaphore_(0) {} + + void SetUp() override { + graph_config_ = InflightGraphConfig(); + tool::AddVectorSink("out_1", &graph_config_, &out_1_packets_); + tool::AddVectorSink("out_2", &graph_config_, &out_2_packets_); + } + + void InitializeGraph(int max_in_flight) { + ProcessFunction semaphore_0_func = [&](const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + enter_semaphore_.Release(1); + return PassthroughFunction(inputs, outputs); + }; + ProcessFunction semaphore_1_func = [&](const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + exit_semaphore_.Acquire(1); + return PassthroughFunction(inputs, outputs); + }; + std::function close_func = [this]() { + close_count_++; + return absl::OkStatus(); + }; + MP_ASSERT_OK(graph_.Initialize( + graph_config_, { + {"max_in_flight", MakePacket(max_in_flight)}, + {"callback_0", Adopt(new auto(semaphore_0_func))}, + {"callback_1", Adopt(new auto(semaphore_1_func))}, + {"callback_2", Adopt(new auto(close_func))}, + })); + } + + // Adds a packet to a graph input stream. + void AddPacket(const std::string& input_name, int value) { + MP_EXPECT_OK(graph_.AddPacketToInputStream( + input_name, MakePacket(value).At(Timestamp(value)))); + } + + // A calculator graph starting with an RealTimeFlowLimiterCalculator and + // ending with a InFlightFinishCalculator. + // Back-edge "finished" limits processing to one frame in-flight. + // The two LambdaCalculators are used to keep certain packet sets in flight. 
+ CalculatorGraphConfig InflightGraphConfig() { + return ParseTextProtoOrDie(R"( + input_stream: 'in_1' + input_stream: 'in_2' + node { + calculator: 'RealTimeFlowLimiterCalculator' + input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' + input_stream: 'in_1' + input_stream: 'in_2' + input_stream: 'FINISHED:out_1' + input_stream_info: { tag_index: 'FINISHED' back_edge: true } + output_stream: 'in_1_sampled' + output_stream: 'in_2_sampled' + } + node { + calculator: 'LambdaCalculator' + input_side_packet: 'callback_0' + input_stream: 'in_1_sampled' + input_stream: 'in_2_sampled' + output_stream: 'queue_1' + output_stream: 'queue_2' + } + node { + calculator: 'LambdaCalculator' + input_side_packet: 'callback_1' + input_stream: 'queue_1' + input_stream: 'queue_2' + output_stream: 'close_1' + output_stream: 'close_2' + } + node { + calculator: 'CloseCallbackCalculator' + input_side_packet: 'callback_2' + input_stream: 'close_1' + input_stream: 'close_2' + output_stream: 'out_1' + output_stream: 'out_2' + } + )"); + } + + protected: + CalculatorGraphConfig graph_config_; + CalculatorGraph graph_; + AtomicSemaphore enter_semaphore_; + AtomicSemaphore exit_semaphore_; + std::vector out_1_packets_; + std::vector out_2_packets_; + int close_count_ = 0; +}; + +// A test demonstrating an RealTimeFlowLimiterCalculator operating in a cyclic +// graph. This test shows that: +// +// (1) Timestamps are passed through unaltered. +// (2) All output streams including the back_edge stream are closed when +// the first input stream is closed. +// +TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) { + InitializeGraph(1); + MP_ASSERT_OK(graph_.StartRun({})); + + auto send_packet = [this](const std::string& input_name, int64 n) { + MP_EXPECT_OK(graph_.AddPacketToInputStream( + input_name, MakePacket(n).At(Timestamp(n)))); + }; + + for (int i = 0; i < 10; i++) { + send_packet("in_1", i * 10); + // This next input should be dropped. + send_packet("in_1", i * 10 + 5); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + send_packet("in_2", i * 10); + exit_semaphore_.Release(1); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + } + MP_EXPECT_OK(graph_.CloseInputStream("in_1")); + MP_EXPECT_OK(graph_.CloseInputStream("in_2")); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + + // All output streams are closed and all output packets are delivered, + // with stream "in_1" and stream "in_2" closed. + EXPECT_EQ(10, out_1_packets_.size()); + EXPECT_EQ(10, out_2_packets_.size()); + + // Timestamps have not been messed with. + EXPECT_EQ(PacketValues(out_1_packets_), + TimestampValues(out_1_packets_)); + EXPECT_EQ(PacketValues(out_2_packets_), + TimestampValues(out_2_packets_)); + + // Extra inputs on in_1 have been dropped + EXPECT_EQ(TimestampValues(out_1_packets_), + (std::vector{0, 10, 20, 30, 40, 50, 60, 70, 80, 90})); + EXPECT_EQ(TimestampValues(out_1_packets_), TimestampValues(out_2_packets_)); + + // The closing of the stream has been propagated. + EXPECT_EQ(1, close_count_); +} + +// A test demonstrating that all output streams are closed when all +// input streams are closed after the last input packet has been processed. 
+TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) { + InitializeGraph(1); + MP_ASSERT_OK(graph_.StartRun({})); + + exit_semaphore_.Release(10); + for (int i = 0; i < 10; i++) { + AddPacket("in_1", i); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + AddPacket("in_2", i); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + } + MP_EXPECT_OK(graph_.CloseAllInputStreams()); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + + EXPECT_EQ(TimestampValues(out_1_packets_), TimestampValues(out_2_packets_)); + EXPECT_EQ(TimestampValues(out_1_packets_), + (std::vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})); + EXPECT_EQ(1, close_count_); +} + +TEST(RealTimeFlowLimiterCalculator, TwoStreams) { + std::vector a_passed; + std::vector b_passed; + CalculatorGraphConfig graph_config_ = + ParseTextProtoOrDie(R"( + input_stream: 'in_a' + input_stream: 'in_b' + input_stream: 'finished' + node { + name: 'input_dropper' + calculator: 'RealTimeFlowLimiterCalculator' + input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' + input_stream: 'in_a' + input_stream: 'in_b' + input_stream: 'FINISHED:finished' + input_stream_info: { tag_index: 'FINISHED' back_edge: true } + output_stream: 'in_a_sampled' + output_stream: 'in_b_sampled' + output_stream: 'ALLOW:allow' + } + )"); + std::string allow_cb_name; + tool::AddVectorSink("in_a_sampled", &graph_config_, &a_passed); + tool::AddVectorSink("in_b_sampled", &graph_config_, &b_passed); + tool::AddCallbackCalculator("allow", &graph_config_, &allow_cb_name, true); + + bool allow = true; + auto allow_cb = [&allow](const Packet& packet) { + allow = packet.Get(); + }; + + CalculatorGraph graph_; + MP_EXPECT_OK(graph_.Initialize( + graph_config_, + { + {"max_in_flight", MakePacket(1)}, + {allow_cb_name, + MakePacket>(allow_cb)}, + })); + + MP_EXPECT_OK(graph_.StartRun({})); + + auto send_packet = [&graph_](const std::string& input_name, int n) { + MP_EXPECT_OK(graph_.AddPacketToInputStream( + input_name, MakePacket(n).At(Timestamp(n)))); + }; + send_packet("in_a", 1); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(allow, false); + EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); + EXPECT_EQ(TimestampValues(b_passed), (std::vector{})); + + send_packet("in_a", 2); + send_packet("in_b", 1); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); + EXPECT_EQ(TimestampValues(b_passed), (std::vector{1})); + EXPECT_EQ(allow, false); + + send_packet("finished", 1); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); + EXPECT_EQ(TimestampValues(b_passed), (std::vector{1})); + EXPECT_EQ(allow, true); + + send_packet("in_b", 2); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); + EXPECT_EQ(TimestampValues(b_passed), (std::vector{1})); + EXPECT_EQ(allow, true); + + send_packet("in_b", 3); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); + EXPECT_EQ(TimestampValues(b_passed), (std::vector{1, 3})); + EXPECT_EQ(allow, false); + + send_packet("in_b", 4); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(a_passed), (std::vector{1})); + EXPECT_EQ(TimestampValues(b_passed), (std::vector{1, 3})); + EXPECT_EQ(allow, false); + + send_packet("in_a", 3); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(a_passed), (std::vector{1, 3})); + EXPECT_EQ(TimestampValues(b_passed), (std::vector{1, 3})); + EXPECT_EQ(allow, false); + + send_packet("finished", 3); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + 
EXPECT_EQ(TimestampValues(a_passed), (std::vector{1, 3})); + EXPECT_EQ(TimestampValues(b_passed), (std::vector{1, 3})); + EXPECT_EQ(allow, true); + + MP_EXPECT_OK(graph_.CloseAllInputStreams()); + MP_EXPECT_OK(graph_.WaitUntilDone()); +} + +TEST(RealTimeFlowLimiterCalculator, CanConsume) { + std::vector in_sampled_packets_; + CalculatorGraphConfig graph_config_ = + ParseTextProtoOrDie(R"( + input_stream: 'in' + input_stream: 'finished' + node { + name: 'input_dropper' + calculator: 'RealTimeFlowLimiterCalculator' + input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' + input_stream: 'in' + input_stream: 'FINISHED:finished' + input_stream_info: { tag_index: 'FINISHED' back_edge: true } + output_stream: 'in_sampled' + output_stream: 'ALLOW:allow' + } + )"); + std::string allow_cb_name; + tool::AddVectorSink("in_sampled", &graph_config_, &in_sampled_packets_); + tool::AddCallbackCalculator("allow", &graph_config_, &allow_cb_name, true); + + bool allow = true; + auto allow_cb = [&allow](const Packet& packet) { + allow = packet.Get(); + }; + + CalculatorGraph graph_; + MP_EXPECT_OK(graph_.Initialize( + graph_config_, + { + {"max_in_flight", MakePacket(1)}, + {allow_cb_name, + MakePacket>(allow_cb)}, + })); + + MP_EXPECT_OK(graph_.StartRun({})); + + auto send_packet = [&graph_](const std::string& input_name, int n) { + MP_EXPECT_OK(graph_.AddPacketToInputStream( + input_name, MakePacket(n).At(Timestamp(n)))); + }; + send_packet("in", 1); + MP_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(allow, false); + EXPECT_EQ(TimestampValues(in_sampled_packets_), (std::vector{1})); + + MP_EXPECT_OK(in_sampled_packets_[0].Consume()); + + MP_EXPECT_OK(graph_.CloseAllInputStreams()); + MP_EXPECT_OK(graph_.WaitUntilDone()); +} + +} // anonymous namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/round_robin_demux_calculator.cc b/mediapipe/calculators/core/round_robin_demux_calculator.cc index c84e08884..8c93bba71 100644 --- a/mediapipe/calculators/core/round_robin_demux_calculator.cc +++ b/mediapipe/calculators/core/round_robin_demux_calculator.cc @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" namespace mediapipe { +namespace api2 { // Forwards the input packet to one of the n output streams "OUTPUT:0", // "OUTPUT:1", ..., in round robin fashion. The index of the selected output @@ -71,50 +73,34 @@ namespace mediapipe { // output with MakePairCalculator, MakeVectorCalculator, or a similar variant to // use it with MuxCalculator and later unpack, or can create new variants of // MuxCalculator/MuxInputStreamHandler. 
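For orientation only, a minimal graph snippet wiring up the demuxer might look like the sketch below; the stream names (frames, frames_0, frames_1, select) are invented for illustration and are not part of this change:

CalculatorGraphConfig config =
    ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
      input_stream: "frames"
      node {
        calculator: "RoundRobinDemuxCalculator"
        input_stream: "frames"
        output_stream: "OUTPUT:0:frames_0"  # receives packets 0, 2, 4, ...
        output_stream: "OUTPUT:1:frames_1"  # receives packets 1, 3, 5, ...
        output_stream: "SELECT:select"      # index of the output chosen per packet
      }
    )");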
-class RoundRobinDemuxCalculator : public CalculatorBase { +class RoundRobinDemuxCalculator : public Node { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK_EQ(cc->Inputs().NumEntries(), 1); - cc->Inputs().Index(0).SetAny(); - if (cc->Outputs().HasTag("SELECT")) { - cc->Outputs().Tag("SELECT").Set(); - } - for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT"); - id < cc->Outputs().EndId("OUTPUT"); ++id) { - cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Index(0)); - } - return ::mediapipe::OkStatus(); - } + static constexpr Input kIn{""}; + static constexpr Output::Optional kSelect{"SELECT"}; + static constexpr Output>::Multiple kOut{"OUTPUT"}; - ::mediapipe::Status Open(CalculatorContext* cc) override { - select_output_ = cc->Outputs().GetId("SELECT", 0); + MEDIAPIPE_NODE_CONTRACT(kIn, kSelect, kOut); + + absl::Status Open(CalculatorContext* cc) override { output_data_stream_index_ = 0; - output_data_stream_base_ = cc->Outputs().GetId("OUTPUT", 0); - num_output_data_streams_ = cc->Outputs().NumEntries("OUTPUT"); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - cc->Outputs() - .Get(output_data_stream_base_ + output_data_stream_index_) - .AddPacket(cc->Inputs().Index(0).Value()); - if (select_output_.IsValid()) { - cc->Outputs() - .Get(select_output_) - .Add(new int(output_data_stream_index_), cc->InputTimestamp()); + absl::Status Process(CalculatorContext* cc) override { + kOut(cc)[output_data_stream_index_].Send(kIn(cc).packet()); + if (kSelect(cc).IsConnected()) { + kSelect(cc).Send(output_data_stream_index_); } output_data_stream_index_ = - (output_data_stream_index_ + 1) % num_output_data_streams_; - return ::mediapipe::OkStatus(); + (output_data_stream_index_ + 1) % kOut(cc).Count(); + return absl::OkStatus(); } private: - CollectionItemId select_output_; - CollectionItemId output_data_stream_base_; - int num_output_data_streams_; int output_data_stream_index_; }; -REGISTER_CALCULATOR(RoundRobinDemuxCalculator); +MEDIAPIPE_REGISTER_NODE(RoundRobinDemuxCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc index f2ab11025..66dbdef2e 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator.cc @@ -15,9 +15,11 @@ #include #include "mediapipe/calculators/core/sequence_shift_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" namespace mediapipe { +namespace api2 { // A Calculator that shifts the timestamps of packets along a stream. Packets on // the input stream are output with a timestamp of the packet given by packet @@ -28,17 +30,17 @@ namespace mediapipe { // of -1, the first packet on the stream will be dropped, the second will be // output with the timestamp of the first, the third with the timestamp of the // second, and so on. 
-class SequenceShiftCalculator : public CalculatorBase { +class SequenceShiftCalculator : public Node { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return ::mediapipe::OkStatus(); - } + static constexpr Input kIn{""}; + static constexpr SideInput::Optional kOffset{"PACKET_OFFSET"}; + static constexpr Output> kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOffset, kOut, TimestampChange::Arbitrary()); // Reads from options to set cache_size_ and packet_offset_. - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // A positive offset means we want a packet to be output with the timestamp of @@ -53,7 +55,7 @@ class SequenceShiftCalculator : public CalculatorBase { // Storage for packets waiting to be output when packet_offset > 0. When cache // is full, oldest packet is output with current timestamp. - std::deque packet_cache_; + std::deque packet_cache_; // Storage for previous timestamps used when packet_offset < 0. When cache is // full, oldest timestamp is used for current packet. @@ -65,50 +67,49 @@ class SequenceShiftCalculator : public CalculatorBase { // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset). int cache_size_; }; -REGISTER_CALCULATOR(SequenceShiftCalculator); +MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator); -::mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { - packet_offset_ = - cc->Options().packet_offset(); +absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { + packet_offset_ = kOffset(cc).GetOr( + cc->Options().packet_offset()); cache_size_ = abs(packet_offset_); // An offset of zero is a no-op, but someone might still request it. if (packet_offset_ == 0) { cc->Outputs().Index(0).SetOffset(0); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) { +absl::Status SequenceShiftCalculator::Process(CalculatorContext* cc) { if (packet_offset_ > 0) { ProcessPositiveOffset(cc); } else if (packet_offset_ < 0) { ProcessNegativeOffset(cc); } else { - cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + kOut(cc).Send(kIn(cc).packet()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) { if (packet_cache_.size() >= cache_size_) { // Ready to output oldest packet with current timestamp. - cc->Outputs().Index(0).AddPacket( - packet_cache_.front().At(cc->InputTimestamp())); + kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp())); packet_cache_.pop_front(); } // Store current packet for later output. - packet_cache_.push_back(cc->Inputs().Index(0).Value()); + packet_cache_.push_back(kIn(cc).packet()); } void SequenceShiftCalculator::ProcessNegativeOffset(CalculatorContext* cc) { if (timestamp_cache_.size() >= cache_size_) { // Ready to output current packet with oldest timestamp. - cc->Outputs().Index(0).AddPacket( - cc->Inputs().Index(0).Value().At(timestamp_cache_.front())); + kOut(cc).Send(kIn(cc).packet().At(timestamp_cache_.front())); timestamp_cache_.pop_front(); } // Store current timestamp for use by a future packet. 
timestamp_cache_.push_back(cc->InputTimestamp()); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/sequence_shift_calculator.proto b/mediapipe/calculators/core/sequence_shift_calculator.proto index 15b111d71..cdcd284ca 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.proto +++ b/mediapipe/calculators/core/sequence_shift_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + message SequenceShiftCalculatorOptions { extend CalculatorOptions { optional SequenceShiftCalculatorOptions ext = 107633927; diff --git a/mediapipe/calculators/core/sequence_shift_calculator_test.cc b/mediapipe/calculators/core/sequence_shift_calculator_test.cc index 1fee61daa..23ad57225 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator_test.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator_test.cc @@ -99,6 +99,35 @@ TEST(SequenceShiftCalculatorTest, NegativeShift) { } } +// Tests using a side packet to specify the offset. Shifting by -2, i.e., +// output input[i] with timestamp[i - 2]. The first two packets should be +// dropped. +TEST(SequenceShiftCalculatorTest, SidePacketOffset) { + CalculatorGraphConfig::Node node; + node.set_calculator("SequenceShiftCalculator"); + node.add_input_stream("input"); + node.add_output_stream("output"); + node.add_input_side_packet("PACKET_OFFSET:packet_offset"); + + CalculatorRunner runner(node); + AddPackets(&runner); + runner.MutableSidePackets()->Tag("PACKET_OFFSET") = Adopt(new int(-2)); + MP_ASSERT_OK(runner.Run()); + const std::vector& input_packets = + runner.MutableInputs()->Index(0).packets; + const std::vector& output_packets = runner.Outputs().Index(0).packets; + ASSERT_EQ(10, input_packets.size()); + // Input packet[i] should be output with the timestamp of input packet[i - 2]. + // The first two packets are dropped. This means timestamps match between + // input and output packets, but the data in the output packets come from + // input_packets[i + 2]. 
+ ASSERT_EQ(8, output_packets.size()); + for (int i = 0; i < output_packets.size(); ++i) { + EXPECT_EQ(input_packets[i].Timestamp(), output_packets[i].Timestamp()); + EXPECT_EQ(input_packets[i + 2].Get(), output_packets[i].Get()); + } +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc index 47c3f624b..ed89889df 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc @@ -89,10 +89,10 @@ class SidePacketToStreamCalculator : public CalculatorBase { SidePacketToStreamCalculator() = default; ~SidePacketToStreamCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: bool is_tick_processing_ = false; @@ -100,8 +100,7 @@ class SidePacketToStreamCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(SidePacketToStreamCalculator); -::mediapipe::Status SidePacketToStreamCalculator::GetContract( - CalculatorContract* cc) { +absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) { const auto& tags = cc->Outputs().GetTags(); RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1) << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " @@ -138,10 +137,10 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator); cc->Inputs().Tag(kTagTick).SetAny(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) { +absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) { output_tag_ = GetOutputTag(*cc); if (cc->Inputs().HasTag(kTagTick)) { is_tick_processing_ = true; @@ -149,11 +148,10 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator); // timestamp bound update. cc->SetOffset(TimestampDiff(0)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SidePacketToStreamCalculator::Process( - CalculatorContext* cc) { +absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { if (is_tick_processing_) { // TICK input is guaranteed to be non-empty, as it's the only input stream // for this calculator. 
@@ -164,13 +162,13 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator); .AddPacket(cc->InputSidePackets().Index(i).At(timestamp)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - return ::mediapipe::tool::StatusStop(); + return mediapipe::tool::StatusStop(); } -::mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { +absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { if (!cc->Outputs().HasTag(kTagAtTick) && !cc->Outputs().HasTag(kTagAtTimestamp)) { const auto& timestamp = kTimestampMap->at(output_tag_); @@ -188,7 +186,7 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator); .AddPacket(cc->InputSidePackets().Index(i).At(Timestamp(timestamp))); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc index e7195e03b..b6b3d4e5c 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc @@ -20,6 +20,7 @@ #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -30,6 +31,8 @@ namespace mediapipe { namespace { +using testing::HasSubstr; + TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( @@ -46,9 +49,10 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_PRED2( - absl::StrContains, status.message(), - "Either both of TICK and AT_TICK should be used or none of them."); + EXPECT_THAT( + status.message(), + HasSubstr( + "Either both of TICK and AT_TICK should be used or none of them.")); } TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) { @@ -67,9 +71,9 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_PRED2( - absl::StrContains, status.message(), - "Either both TIMESTAMP and AT_TIMESTAMP should be used or none of them."); + EXPECT_THAT(status.message(), + HasSubstr("Either both TIMESTAMP and AT_TIMESTAMP should be used " + "or none of them.")); } TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { @@ -88,10 +92,11 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_PRED2(absl::StrContains, status.message(), - "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " - "AT_TIMESTAMP tags is allowed and required to specify output " - "stream(s)."); + EXPECT_THAT( + status.message(), + HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " + "AT_TIMESTAMP tags is allowed and required to specify output " + "stream(s).")); } TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { @@ -112,10 +117,11 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_PRED2(absl::StrContains, 
status.message(), - "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " - "AT_TIMESTAMP tags is allowed and required to specify output " - "stream(s)."); + EXPECT_THAT( + status.message(), + HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " + "AT_TIMESTAMP tags is allowed and required to specify output " + "stream(s).")); } TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) { @@ -134,9 +140,10 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_PRED2( - absl::StrContains, status.message(), - "Same number of input side packets and output streams is required."); + EXPECT_THAT( + status.message(), + HasSubstr( + "Same number of input side packets and output streams is required.")); } TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughOutputStreams) { @@ -155,9 +162,10 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughOutputStreams) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_PRED2( - absl::StrContains, status.message(), - "Same number of input side packets and output streams is required."); + EXPECT_THAT( + status.message(), + HasSubstr( + "Same number of input side packets and output streams is required.")); } void DoTestNonAtTickOutputTag(absl::string_view tag, @@ -181,7 +189,7 @@ void DoTestNonAtTickOutputTag(absl::string_view tag, MP_ASSERT_OK(graph.ObserveOutputStream( "packet", [&output_packets](const Packet& packet) { output_packets.push_back(packet); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK( graph.StartRun({{"side_packet", MakePacket(expected_value)}})); diff --git a/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc b/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc index 85bac0e9b..d57cebe9c 100644 --- a/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc +++ b/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc @@ -35,7 +35,7 @@ namespace mediapipe { // NormalizedLandmarkList. 
class SplitNormalizedLandmarkListCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == 1); RET_CHECK(cc->Outputs().NumEntries() != 0); @@ -55,7 +55,7 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { range_0.begin() < range_1.end()) || (range_1.begin() >= range_0.begin() && range_1.begin() < range_0.end())) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Ranges must be non-overlapping when using combine_outputs " "option."); } @@ -63,7 +63,7 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { } } else { if (cc->Outputs().NumEntries() != options.ranges_size()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The number of output streams should match the number of ranges " "specified in the CalculatorOptions."); } @@ -72,13 +72,13 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 || options.ranges(i).begin() >= options.ranges(i).end()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Indices should be non-negative and begin index should be less " "than the end index."); } if (options.element_only()) { if (options.ranges(i).end() - options.ranges(i).begin() != 1) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Since element_only is true, all ranges should be of size 1."); } cc->Outputs().Index(i).Set(); @@ -88,10 +88,10 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -106,10 +106,10 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { total_elements_ += range.end() - range.begin(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const NormalizedLandmarkList& input = cc->Inputs().Index(0).Get(); RET_CHECK_GE(input.landmark_size(), max_range_end_) @@ -148,7 +148,7 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/split_normalized_landmark_list_calculator_test.cc b/mediapipe/calculators/core/split_normalized_landmark_list_calculator_test.cc index ce02dcd8a..202287208 100644 --- a/mediapipe/calculators/core/split_normalized_landmark_list_calculator_test.cc +++ b/mediapipe/calculators/core/split_normalized_landmark_list_calculator_test.cc @@ -121,7 +121,7 @@ TEST_F(SplitNormalizedLandmarkListCalculatorTest, SmokeTest) { // Prepare a graph to use the SplitNormalizedLandmarkListCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "landmarks_in" node { @@ -170,7 +170,7 @@ TEST_F(SplitNormalizedLandmarkListCalculatorTest, SmokeTest) { TEST_F(SplitNormalizedLandmarkListCalculatorTest, InvalidRangeTest) { // Prepare a graph to use the SplitNormalizedLandmarkListCalculator. 
CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "landmarks_in" node { @@ -195,7 +195,7 @@ TEST_F(SplitNormalizedLandmarkListCalculatorTest, InvalidOutputStreamCountTest) { // Prepare a graph to use the SplitNormalizedLandmarkListCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "landmarks_in" node { @@ -222,7 +222,7 @@ TEST_F(SplitNormalizedLandmarkListCalculatorTest, InvalidCombineOutputsMultipleOutputsTest) { // Prepare a graph to use the SplitNormalizedLandmarkListCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "landmarks_in" node { @@ -251,7 +251,7 @@ TEST_F(SplitNormalizedLandmarkListCalculatorTest, InvalidOverlappingRangesTest) { // Prepare a graph to use the SplitNormalizedLandmarkListCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "landmarks_in" node { @@ -280,7 +280,7 @@ TEST_F(SplitNormalizedLandmarkListCalculatorTest, SmokeTestElementOnly) { // Prepare a graph to use the SplitNormalizedLandmarkListCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "landmarks_in" node { @@ -333,7 +333,7 @@ TEST_F(SplitNormalizedLandmarkListCalculatorTest, SmokeTestCombiningOutputs) { // Prepare a graph to use the SplitNormalizedLandmarkListCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "landmarks_in" node { @@ -376,7 +376,7 @@ TEST_F(SplitNormalizedLandmarkListCalculatorTest, ElementOnlyDisablesVectorOutputs) { // Prepare a graph to use the SplitNormalizedLandmarkListCalculator. 
CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "landmarks_in" node { diff --git a/mediapipe/calculators/core/split_vector_calculator.cc b/mediapipe/calculators/core/split_vector_calculator.cc index 100507c99..c8f1177d5 100644 --- a/mediapipe/calculators/core/split_vector_calculator.cc +++ b/mediapipe/calculators/core/split_vector_calculator.cc @@ -16,15 +16,17 @@ #include +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" #include "tensorflow/lite/interpreter.h" #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) namespace mediapipe { @@ -46,15 +48,18 @@ typedef SplitVectorCalculator SplitTfLiteTensorVectorCalculator; REGISTER_CALCULATOR(SplitTfLiteTensorVectorCalculator); -typedef SplitVectorCalculator<::mediapipe::NormalizedLandmark, false> +typedef SplitVectorCalculator SplitTensorVectorCalculator; +REGISTER_CALCULATOR(SplitTensorVectorCalculator); + +typedef SplitVectorCalculator SplitLandmarkVectorCalculator; REGISTER_CALCULATOR(SplitLandmarkVectorCalculator); -typedef SplitVectorCalculator<::mediapipe::NormalizedLandmarkList, false> +typedef SplitVectorCalculator SplitNormalizedLandmarkListVectorCalculator; REGISTER_CALCULATOR(SplitNormalizedLandmarkListVectorCalculator); -typedef SplitVectorCalculator<::mediapipe::NormalizedRect, false> +typedef SplitVectorCalculator SplitNormalizedRectVectorCalculator; REGISTER_CALCULATOR(SplitNormalizedRectVectorCalculator); @@ -67,8 +72,12 @@ typedef SplitVectorCalculator<::tflite::gpu::gl::GlBuffer, true> REGISTER_CALCULATOR(MovableSplitGlBufferVectorCalculator); #endif -typedef SplitVectorCalculator<::mediapipe::Detection, false> +typedef SplitVectorCalculator SplitDetectionVectorCalculator; REGISTER_CALCULATOR(SplitDetectionVectorCalculator); +typedef SplitVectorCalculator + SplitClassificationListVectorCalculator; +REGISTER_CALCULATOR(SplitClassificationListVectorCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/core/split_vector_calculator.h b/mediapipe/calculators/core/split_vector_calculator.h index 4e257e3df..c77c6a40d 100644 --- a/mediapipe/calculators/core/split_vector_calculator.h +++ b/mediapipe/calculators/core/split_vector_calculator.h @@ -58,7 +58,7 @@ using IsNotMovable = template class SplitVectorCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == 1); RET_CHECK(cc->Outputs().NumEntries() != 0); @@ -79,7 +79,7 @@ class SplitVectorCalculator : public CalculatorBase { RET_CHECK_OK(checkRangesDontOverlap(options)); } else { if (cc->Outputs().NumEntries() != options.ranges_size()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The number of output streams should match the number of ranges " "specified in the CalculatorOptions."); } @@ -88,13 +88,13 @@ class SplitVectorCalculator : public CalculatorBase { for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 || options.ranges(i).begin() >= 
options.ranges(i).end()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Indices should be non-negative and begin index should be less " "than the end index."); } if (options.element_only()) { if (options.ranges(i).end() - options.ranges(i).begin() != 1) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Since element_only is true, all ranges should be of size 1."); } cc->Outputs().Index(i).Set(); @@ -104,10 +104,10 @@ class SplitVectorCalculator : public CalculatorBase { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -122,11 +122,11 @@ class SplitVectorCalculator : public CalculatorBase { total_elements_ += range.end() - range.begin(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - if (cc->Inputs().Index(0).IsEmpty()) return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + if (cc->Inputs().Index(0).IsEmpty()) return absl::OkStatus(); if (move_elements) { return ProcessMovableElements(cc); @@ -136,7 +136,7 @@ class SplitVectorCalculator : public CalculatorBase { } template = true> - ::mediapipe::Status ProcessCopyableElements(CalculatorContext* cc) { + absl::Status ProcessCopyableElements(CalculatorContext* cc) { // static_assert(std::is_copy_constructible::value, // "Cannot copy non-copyable elements"); const auto& input = cc->Inputs().Index(0).Get>(); @@ -167,21 +167,21 @@ class SplitVectorCalculator : public CalculatorBase { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } template = true> - ::mediapipe::Status ProcessCopyableElements(CalculatorContext* cc) { - return ::mediapipe::InternalError("Cannot copy non-copyable elements."); + absl::Status ProcessCopyableElements(CalculatorContext* cc) { + return absl::InternalError("Cannot copy non-copyable elements."); } template = true> - ::mediapipe::Status ProcessMovableElements(CalculatorContext* cc) { - ::mediapipe::StatusOr>> input_status = + absl::Status ProcessMovableElements(CalculatorContext* cc) { + absl::StatusOr>> input_status = cc->Inputs().Index(0).Value().Consume>(); if (!input_status.ok()) return input_status.status(); std::unique_ptr> input_vector = - std::move(input_status).ValueOrDie(); + std::move(input_status).value(); RET_CHECK_GE(input_vector->size(), max_range_end_); if (combine_outputs_) { @@ -214,16 +214,16 @@ class SplitVectorCalculator : public CalculatorBase { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } template = true> - ::mediapipe::Status ProcessMovableElements(CalculatorContext* cc) { - return ::mediapipe::InternalError("Cannot move non-movable elements."); + absl::Status ProcessMovableElements(CalculatorContext* cc) { + return absl::InternalError("Cannot move non-movable elements."); } private: - static ::mediapipe::Status checkRangesDontOverlap( + static absl::Status checkRangesDontOverlap( const ::mediapipe::SplitVectorCalculatorOptions& options) { for (int i = 0; i < options.ranges_size() - 1; ++i) { for (int j = i + 1; j < options.ranges_size(); ++j) { @@ -233,13 +233,13 @@ class SplitVectorCalculator : public CalculatorBase { range_0.begin() < range_1.end()) || (range_1.begin() >= range_0.begin() && range_1.begin() < range_0.end())) { - return ::mediapipe::InvalidArgumentError( + return 
absl::InvalidArgumentError( "Ranges must be non-overlapping when using combine_outputs " "option."); } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } std::vector> ranges_; diff --git a/mediapipe/calculators/core/split_vector_calculator.proto b/mediapipe/calculators/core/split_vector_calculator.proto index 53acbb7bf..40301f88b 100644 --- a/mediapipe/calculators/core/split_vector_calculator.proto +++ b/mediapipe/calculators/core/split_vector_calculator.proto @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option objc_class_prefix = "MediaPipe"; + // A Range {begin, end} specifies beginning ane ending indices to splice a // vector. A vector v is spliced to have elements v[begin:(end-1)], i.e., with // begin index inclusive and end index exclusive. diff --git a/mediapipe/calculators/core/split_vector_calculator_test.cc b/mediapipe/calculators/core/split_vector_calculator_test.cc index 5d1ea2a04..0b98940fe 100644 --- a/mediapipe/calculators/core/split_vector_calculator_test.cc +++ b/mediapipe/calculators/core/split_vector_calculator_test.cc @@ -162,7 +162,7 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTest) { // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "tensor_in" node { @@ -213,7 +213,7 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidRangeTest) { // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "tensor_in" node { @@ -239,7 +239,7 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOutputStreamCountTest) { // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "tensor_in" node { @@ -268,7 +268,7 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "tensor_in" node { @@ -298,7 +298,7 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOverlappingRangesTest) { // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "tensor_in" node { @@ -329,7 +329,7 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) { // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "tensor_in" node { @@ -384,7 +384,7 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestCombiningOutputs) { // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "tensor_in" node { @@ -427,7 +427,7 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, ElementOnlyDisablesVectorOutputs) { // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. 
CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "tensor_in" node { @@ -510,7 +510,7 @@ class MovableSplitUniqueIntPtrCalculatorTest : public ::testing::Test { TEST_F(MovableSplitUniqueIntPtrCalculatorTest, InvalidOverlappingRangesTest) { // Prepare a graph to use the TestMovableSplitUniqueIntPtrVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "input_vector" node { @@ -535,7 +535,7 @@ TEST_F(MovableSplitUniqueIntPtrCalculatorTest, InvalidOverlappingRangesTest) { TEST_F(MovableSplitUniqueIntPtrCalculatorTest, SmokeTest) { // Prepare a graph to use the TestMovableSplitUniqueIntPtrVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "input_vector" node { @@ -591,7 +591,7 @@ TEST_F(MovableSplitUniqueIntPtrCalculatorTest, SmokeTest) { TEST_F(MovableSplitUniqueIntPtrCalculatorTest, SmokeTestElementOnly) { // Prepare a graph to use the TestMovableSplitUniqueIntPtrVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "input_vector" node { @@ -645,7 +645,7 @@ TEST_F(MovableSplitUniqueIntPtrCalculatorTest, SmokeTestElementOnly) { TEST_F(MovableSplitUniqueIntPtrCalculatorTest, SmokeTestCombiningOutputs) { // Prepare a graph to use the TestMovableSplitUniqueIntPtrVectorCalculator. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( R"( input_stream: "input_vector" node { diff --git a/mediapipe/calculators/core/stream_to_side_packet_calculator.cc b/mediapipe/calculators/core/stream_to_side_packet_calculator.cc index 07bb8c852..9dc25142a 100644 --- a/mediapipe/calculators/core/stream_to_side_packet_calculator.cc +++ b/mediapipe/calculators/core/stream_to_side_packet_calculator.cc @@ -30,17 +30,17 @@ namespace mediapipe { // } class StreamToSidePacketCalculator : public mediapipe::CalculatorBase { public: - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) { + static absl::Status GetContract(mediapipe::CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->OutputSidePackets().Index(0).SetAny(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override { + absl::Status Process(mediapipe::CalculatorContext* cc) override { mediapipe::Packet& packet = cc->Inputs().Index(0).Value(); cc->OutputSidePackets().Index(0).Set( packet.At(mediapipe::Timestamp::Unset())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(StreamToSidePacketCalculator); diff --git a/mediapipe/calculators/core/stream_to_side_packet_calculator_test.cc b/mediapipe/calculators/core/stream_to_side_packet_calculator_test.cc index 12f417c58..606f0e352 100644 --- a/mediapipe/calculators/core/stream_to_side_packet_calculator_test.cc +++ b/mediapipe/calculators/core/stream_to_side_packet_calculator_test.cc @@ -44,7 +44,7 @@ class StreamToSidePacketCalculatorTest : public Test { TEST_F(StreamToSidePacketCalculatorTest, StreamToSidePacketCalculatorWithEmptyStreamFails) { - EXPECT_EQ(runner_->Run().code(), mediapipe::StatusCode::kUnavailable); + EXPECT_EQ(runner_->Run().code(), absl::StatusCode::kUnavailable); } TEST_F(StreamToSidePacketCalculatorTest, @@ -61,7 +61,7 @@ TEST_F(StreamToSidePacketCalculatorTest, Adopt(new 
std::string("test1")).At(Timestamp(1))); runner_->MutableInputs()->Index(0).packets.push_back( Adopt(new std::string("test2")).At(Timestamp(2))); - EXPECT_EQ(runner_->Run().code(), mediapipe::StatusCode::kAlreadyExists); + EXPECT_EQ(runner_->Run().code(), absl::StatusCode::kAlreadyExists); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/string_to_int_calculator.cc b/mediapipe/calculators/core/string_to_int_calculator.cc index 64600cde3..13a9a29e0 100644 --- a/mediapipe/calculators/core/string_to_int_calculator.cc +++ b/mediapipe/calculators/core/string_to_int_calculator.cc @@ -36,32 +36,32 @@ namespace mediapipe { template class StringToIntCalculatorTemplate : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).Set(); cc->OutputSidePackets().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { IntType number; if (!absl::SimpleAtoi(cc->InputSidePackets().Index(0).Get(), &number)) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The std::string could not be parsed as an integer."); } cc->OutputSidePackets().Index(0).Set(MakePacket(number)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; using StringToIntCalculator = StringToIntCalculatorTemplate; REGISTER_CALCULATOR(StringToIntCalculator); -using StringToUintCalculator = StringToIntCalculatorTemplate; +using StringToUintCalculator = StringToIntCalculatorTemplate; REGISTER_CALCULATOR(StringToUintCalculator); using StringToInt32Calculator = StringToIntCalculatorTemplate; diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index a14bd31d1..e94fb7ec7 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -12,154 +12,78 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -exports_files(["LICENSE"]) - -proto_library( +mediapipe_proto_library( name = "opencv_image_encoder_calculator_proto", srcs = ["opencv_image_encoder_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "scale_image_calculator_proto", srcs = ["scale_image_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:image_format_proto", ], ) -proto_library( +mediapipe_proto_library( name = "set_alpha_calculator_proto", srcs = ["set_alpha_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "image_cropping_calculator_proto", srcs = ["image_cropping_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "bilateral_filter_calculator_proto", srcs = ["bilateral_filter_calculator.proto"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "recolor_calculator_proto", srcs = ["recolor_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", ], ) -mediapipe_cc_proto_library( - name = "opencv_image_encoder_calculator_cc_proto", - srcs = ["opencv_image_encoder_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = [ - "//visibility:public", - ], - deps = [":opencv_image_encoder_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "opencv_encoded_image_to_image_frame_calculator_cc_proto", - srcs = ["opencv_encoded_image_to_image_frame_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":opencv_encoded_image_to_image_frame_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "mask_overlay_calculator_cc_proto", - srcs = ["mask_overlay_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":mask_overlay_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "scale_image_calculator_cc_proto", - srcs = ["scale_image_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/formats:image_format_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":scale_image_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "set_alpha_calculator_cc_proto", - srcs = ["set_alpha_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = 
[":set_alpha_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "image_cropping_calculator_cc_proto", - srcs = ["image_cropping_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":image_cropping_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "bilateral_filter_calculator_cc_proto", - srcs = ["bilateral_filter_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = [ - "//visibility:public", - ], - deps = [":bilateral_filter_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "recolor_calculator_cc_proto", - srcs = ["recolor_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":recolor_calculator_proto"], -) - cc_library( name = "color_convert_calculator", srcs = ["color_convert_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -193,9 +117,7 @@ cc_library( cc_library( name = "opencv_image_encoder_calculator", srcs = ["opencv_image_encoder_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ ":opencv_image_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -251,9 +173,7 @@ cc_library( cc_library( name = "bilateral_filter_calculator", srcs = ["bilateral_filter_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", @@ -279,27 +199,17 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "image_transformation_calculator_proto", srcs = ["image_transformation_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/gpu:scale_mode_proto", ], ) -mediapipe_cc_proto_library( - name = "image_transformation_calculator_cc_proto", - srcs = ["image_transformation_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/gpu:scale_mode_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":image_transformation_calculator_proto"], -) - cc_library( name = "image_transformation_calculator", srcs = ["image_transformation_calculator.cc"], @@ -528,7 +438,6 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/port:status", ], ) @@ -547,7 +456,6 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/port:status", ], ) @@ -560,32 +468,34 @@ cc_test( ], ) -proto_library( +mediapipe_proto_library( name = "mask_overlay_calculator_proto", srcs = ["mask_overlay_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "opencv_encoded_image_to_image_frame_calculator_proto", srcs = 
["opencv_encoded_image_to_image_frame_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "feature_detector_calculator_proto", srcs = ["feature_detector_calculator.proto"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "feature_detector_calculator_cc_proto", - srcs = ["feature_detector_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], visibility = ["//visibility:public"], - deps = [":feature_detector_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_library( @@ -607,7 +517,7 @@ cc_library( cc_library( name = "feature_detector_calculator", srcs = ["feature_detector_calculator.cc"], - visibility = ["//mediapipe:__subpackages__"], + visibility = ["//visibility:public"], deps = [ ":feature_detector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -658,6 +568,5 @@ cc_test( "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/port:status", ], ) diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc index b366caf7a..3d878bffc 100644 --- a/mediapipe/calculators/image/bilateral_filter_calculator.cc +++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc @@ -28,11 +28,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -82,18 +82,18 @@ class BilateralFilterCalculator : public CalculatorBase { BilateralFilterCalculator() = default; ~BilateralFilterCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // From Calculator. 
- ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status RenderGpu(CalculatorContext* cc); - ::mediapipe::Status RenderCpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); - ::mediapipe::Status GlSetup(CalculatorContext* cc); + absl::Status GlSetup(CalculatorContext* cc); void GlRender(CalculatorContext* cc); mediapipe::BilateralFilterCalculatorOptions options_; @@ -102,80 +102,79 @@ class BilateralFilterCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; GLuint vao_; GLuint vbo_[2]; // vertex storage -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(BilateralFilterCalculator); -::mediapipe::Status BilateralFilterCalculator::GetContract( - CalculatorContract* cc) { +absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) { CHECK_GE(cc->Inputs().NumEntries(), 1); if (cc->Inputs().HasTag(kInputFrameTag) && cc->Inputs().HasTag(kInputFrameTagGpu)) { - return ::mediapipe::InternalError("Cannot have multiple input images."); + return absl::InternalError("Cannot have multiple input images."); } if (cc->Inputs().HasTag(kInputFrameTagGpu) != cc->Outputs().HasTag(kOutputFrameTagGpu)) { - return ::mediapipe::InternalError("GPU output must have GPU input."); + return absl::InternalError("GPU output must have GPU input."); } bool use_gpu = false; // Input image to filter. -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); } // Input guide image mask (optional) -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputGuideTagGpu)) { cc->Inputs().Tag(kInputGuideTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputGuideTag)) { cc->Inputs().Tag(kInputGuideTag).Set(); } // Output image. 
-#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BilateralFilterCalculator::Open(CalculatorContext* cc) { +absl::Status BilateralFilterCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU use_gpu_ = true; #else RET_CHECK_FAIL() << "GPU processing not enabled."; @@ -189,36 +188,35 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); if (!use_gpu_) sigma_color_ *= 255.0; if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BilateralFilterCalculator::Process(CalculatorContext* cc) { +absl::Status BilateralFilterCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { - if (!gpu_initialized_) { - MP_RETURN_IF_ERROR(GlSetup(cc)); - gpu_initialized_ = true; - } - MP_RETURN_IF_ERROR(RenderGpu(cc)); - return ::mediapipe::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + if (!gpu_initialized_) { + MP_RETURN_IF_ERROR(GlSetup(cc)); + gpu_initialized_ = true; + } + MP_RETURN_IF_ERROR(RenderGpu(cc)); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status BilateralFilterCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); if (vao_) glDeleteVertexArrays(1, &vao_); @@ -228,15 +226,14 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); vbo_[0] = 0; vbo_[1] = 0; }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BilateralFilterCalculator::RenderCpu( - CalculatorContext* cc) { +absl::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get(); @@ -244,7 +241,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); // Only 1 or 3 channel images supported by OpenCV. 
if ((input_mat.channels() == 1 || input_mat.channels() == 3)) { - return ::mediapipe::InternalError( + return absl::InternalError( "CPU filtering supports only 1 or 3 channel input images."); } @@ -255,7 +252,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); if (has_guide_image) { // cv::jointBilateralFilter() is in contrib module 'ximgproc'. - return ::mediapipe::UnimplementedError( + return absl::UnimplementedError( "CPU joint filtering support is not implemented yet."); } else { auto output_mat = mediapipe::formats::MatView(output_frame.get()); @@ -267,15 +264,14 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); cc->Outputs() .Tag(kOutputFrameTag) .Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BilateralFilterCalculator::RenderGpu( - CalculatorContext* cc) { +absl::Status BilateralFilterCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); @@ -285,8 +281,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); // Setup textures and Update image in GPU shader. if (has_guide_image) { - if (cc->Inputs().Tag(kInputGuideTagGpu).IsEmpty()) - return mediapipe::OkStatus(); + if (cc->Inputs().Tag(kInputGuideTagGpu).IsEmpty()) return absl::OkStatus(); // joint bilateral filter glUseProgram(program_); const auto& guide_image = @@ -332,13 +327,13 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); // Cleanup input_texture.Release(); output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // bring back vao and vbo glBindVertexArray(vao_); @@ -347,11 +342,11 @@ void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { // cleanup glBindVertexArray(0); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -::mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -515,9 +510,9 @@ void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { glBindBuffer(GL_ARRAY_BUFFER, 0); glBindVertexArray(0); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc index f31586d9d..bdac932bb 100644 --- a/mediapipe/calculators/image/color_convert_calculator.cc +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -78,12 +78,12 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; class ColorConvertCalculator : public CalculatorBase { public: ~ColorConvertCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status 
Process(CalculatorContext* cc) override; - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -91,17 +91,16 @@ class ColorConvertCalculator : public CalculatorBase { // conversion. The ImageFrame on input_tag is converted using the // open_cv_convert_code provided and then output on the output_tag stream. // Note that the output_format must match the destination conversion code. - ::mediapipe::Status ConvertAndOutput(const std::string& input_tag, - const std::string& output_tag, - ImageFormat::Format output_format, - int open_cv_convert_code, - CalculatorContext* cc); + absl::Status ConvertAndOutput(const std::string& input_tag, + const std::string& output_tag, + ImageFormat::Format output_format, + int open_cv_convert_code, + CalculatorContext* cc); }; REGISTER_CALCULATOR(ColorConvertCalculator); -::mediapipe::Status ColorConvertCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is allowed."; RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) @@ -139,10 +138,10 @@ REGISTER_CALCULATOR(ColorConvertCalculator); cc->Outputs().Tag(kBgraOutTag).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ColorConvertCalculator::ConvertAndOutput( +absl::Status ColorConvertCalculator::ConvertAndOutput( const std::string& input_tag, const std::string& output_tag, ImageFormat::Format output_format, int open_cv_convert_code, CalculatorContext* cc) { @@ -161,10 +160,10 @@ REGISTER_CALCULATOR(ColorConvertCalculator); cc->Outputs() .Tag(output_tag) .Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ColorConvertCalculator::Process(CalculatorContext* cc) { +absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) { // RGBA -> RGB if (cc->Inputs().HasTag(kRgbaInTag) && cc->Outputs().HasTag(kRgbOutTag)) { return ConvertAndOutput(kRgbaInTag, kRgbOutTag, ImageFormat::SRGB, @@ -196,7 +195,7 @@ REGISTER_CALCULATOR(ColorConvertCalculator); cv::COLOR_RGBA2BGRA, cc); } - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Unsupported image format conversion."; } diff --git a/mediapipe/calculators/image/feature_detector_calculator.cc b/mediapipe/calculators/image/feature_detector_calculator.cc index 9f873740e..389a33696 100644 --- a/mediapipe/calculators/image/feature_detector_calculator.cc +++ b/mediapipe/calculators/image/feature_detector_calculator.cc @@ -50,15 +50,15 @@ class FeatureDetectorCalculator : public CalculatorBase { public: ~FeatureDetectorCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: FeatureDetectorCalculatorOptions options_; cv::Ptr feature_detector_; - std::unique_ptr<::mediapipe::ThreadPool> pool_; + std::unique_ptr pool_; // Create image pyramid based on input image. 
void ComputeImagePyramid(const cv::Mat& input_image, @@ -71,8 +71,7 @@ class FeatureDetectorCalculator : public CalculatorBase { REGISTER_CALCULATOR(FeatureDetectorCalculator); -::mediapipe::Status FeatureDetectorCalculator::GetContract( - CalculatorContract* cc) { +absl::Status FeatureDetectorCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("IMAGE")) { cc->Inputs().Tag("IMAGE").Set(); } @@ -85,26 +84,26 @@ REGISTER_CALCULATOR(FeatureDetectorCalculator); if (cc->Outputs().HasTag("PATCHES")) { cc->Outputs().Tag("PATCHES").Set>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FeatureDetectorCalculator::Open(CalculatorContext* cc) { +absl::Status FeatureDetectorCalculator::Open(CalculatorContext* cc) { options_ = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag) .GetExtension(FeatureDetectorCalculatorOptions::ext); feature_detector_ = cv::ORB::create( options_.max_features(), options_.scale_factor(), options_.pyramid_level(), kPatchSize - 1, 0, 2, cv::ORB::FAST_SCORE); - pool_ = absl::make_unique<::mediapipe::ThreadPool>("ThreadPool", kNumThreads); + pool_ = absl::make_unique("ThreadPool", kNumThreads); pool_->StartWorkers(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FeatureDetectorCalculator::Process(CalculatorContext* cc) { +absl::Status FeatureDetectorCalculator::Process(CalculatorContext* cc) { const Timestamp& timestamp = cc->InputTimestamp(); if (timestamp == Timestamp::PreStream()) { // Indicator packet. - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } InputStream* input_frame = &(cc->Inputs().Tag("IMAGE")); cv::Mat input_view = formats::MatView(&input_frame->Get()); @@ -176,7 +175,7 @@ REGISTER_CALCULATOR(FeatureDetectorCalculator); cc->Outputs().Tag("PATCHES").Add(patches.release(), timestamp); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void FeatureDetectorCalculator::ComputeImagePyramid( diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index b008a9e1e..e4b0b7218 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -24,11 +24,11 @@ #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -38,9 +38,9 @@ namespace mediapipe { namespace { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU constexpr char kRectTag[] = "RECT"; constexpr char kNormRectTag[] = "NORM_RECT"; @@ -53,8 +53,7 @@ constexpr char kWidthTag[] = "WIDTH"; REGISTER_CALCULATOR(ImageCroppingCalculator); -::mediapipe::Status ImageCroppingCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ImageCroppingCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kImageTag) ^ cc->Inputs().HasTag(kImageGpuTag)); RET_CHECK(cc->Outputs().HasTag(kImageTag) ^ cc->Outputs().HasTag(kImageGpuTag)); @@ -66,14 +65,14 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); cc->Inputs().Tag(kImageTag).Set(); 
cc->Outputs().Tag(kImageTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kImageGpuTag)) { RET_CHECK(cc->Outputs().HasTag(kImageGpuTag)); cc->Inputs().Tag(kImageGpuTag).Set(); cc->Outputs().Tag(kImageGpuTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU int flags = 0; if (cc->Inputs().HasTag(kRectTag)) { @@ -111,15 +110,15 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { +absl::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if (cc->Inputs().HasTag(kImageGpuTag)) { @@ -133,11 +132,11 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); options_.has_output_max_height() ? options_.output_max_height() : FLT_MAX; if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #else RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } // Validate border mode. @@ -147,56 +146,55 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); MP_RETURN_IF_ERROR(ValidateBorderModeForCPU(cc)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageCroppingCalculator::Process(CalculatorContext* cc) { +absl::Status ImageCroppingCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kRectTag) && cc->Inputs().Tag(kRectTag).IsEmpty()) { VLOG(1) << "RECT is empty for timestamp: " << cc->InputTimestamp(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (cc->Inputs().HasTag(kNormRectTag) && cc->Inputs().Tag(kNormRectTag).IsEmpty()) { VLOG(1) << "NORM_RECT is empty for timestamp: " << cc->InputTimestamp(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { - if (!gpu_initialized_) { - MP_RETURN_IF_ERROR(InitGpu(cc)); - gpu_initialized_ = true; - } - MP_RETURN_IF_ERROR(RenderGpu(cc)); - return ::mediapipe::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + if (!gpu_initialized_) { + MP_RETURN_IF_ERROR(InitGpu(cc)); + gpu_initialized_ = true; + } + MP_RETURN_IF_ERROR(RenderGpu(cc)); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageCroppingCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status ImageCroppingCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); gpu_initialized_ = false; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageCroppingCalculator::ValidateBorderModeForCPU( +absl::Status ImageCroppingCalculator::ValidateBorderModeForCPU( 
CalculatorContext* cc) { int border_mode; return GetBorderModeForOpenCV(cc, &border_mode); } -::mediapipe::Status ImageCroppingCalculator::ValidateBorderModeForGPU( +absl::Status ImageCroppingCalculator::ValidateBorderModeForGPU( CalculatorContext* cc) { mediapipe::ImageCroppingCalculatorOptions options = cc->Options(); @@ -213,12 +211,12 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); << options.border_mode(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) { +absl::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kImageTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_img = cc->Inputs().Tag(kImageTag).Get(); cv::Mat input_mat = formats::MatView(&input_img); @@ -268,14 +266,14 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); cropped_image.copyTo(output_mat); cc->Outputs().Tag(kImageTag).Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { +absl::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kImageGpuTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU const Packet& input_packet = cc->Inputs().Tag(kImageGpuTag).Value(); const auto& input_buffer = input_packet.Get(); auto src_tex = gpu_helper_.CreateSourceTexture(input_buffer); @@ -306,13 +304,13 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); // Cleanup src_tex.Release(); dst_tex.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void ImageCroppingCalculator::GlRender() { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -356,11 +354,11 @@ void ImageCroppingCalculator::GlRender() { glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -::mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -406,9 +404,9 @@ void ImageCroppingCalculator::GlRender() { // Parameters glUseProgram(program_); glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // For GPU only. 
@@ -534,7 +532,7 @@ RectSpec ImageCroppingCalculator::GetCropSpecs(const CalculatorContext* cc, return {crop_width, crop_height, x_center, y_center, rotation}; } -::mediapipe::Status ImageCroppingCalculator::GetBorderModeForOpenCV( +absl::Status ImageCroppingCalculator::GetBorderModeForOpenCV( CalculatorContext* cc, int* border_mode) { mediapipe::ImageCroppingCalculatorOptions options = cc->Options(); @@ -551,7 +549,7 @@ RectSpec ImageCroppingCalculator::GetCropSpecs(const CalculatorContext* cc, << options.border_mode(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/image_cropping_calculator.h b/mediapipe/calculators/image/image_cropping_calculator.h index 5d50b6647..39d99cc55 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.h +++ b/mediapipe/calculators/image/image_cropping_calculator.h @@ -6,9 +6,9 @@ #include "mediapipe/calculators/image/image_cropping_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU // Crops the input texture to the given rectangle region. The rectangle can // be at arbitrary location on the image with rotation. If there's rotation, the @@ -58,24 +58,23 @@ class ImageCroppingCalculator : public CalculatorBase { ImageCroppingCalculator() = default; ~ImageCroppingCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; static RectSpec GetCropSpecs(const CalculatorContext* cc, int src_width, int src_height); private: - ::mediapipe::Status ValidateBorderModeForCPU(CalculatorContext* cc); - ::mediapipe::Status ValidateBorderModeForGPU(CalculatorContext* cc); - ::mediapipe::Status RenderCpu(CalculatorContext* cc); - ::mediapipe::Status RenderGpu(CalculatorContext* cc); - ::mediapipe::Status InitGpu(CalculatorContext* cc); + absl::Status ValidateBorderModeForCPU(CalculatorContext* cc); + absl::Status ValidateBorderModeForGPU(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); void GlRender(); void GetOutputDimensions(CalculatorContext* cc, int src_width, int src_height, int* dst_width, int* dst_height); - ::mediapipe::Status GetBorderModeForOpenCV(CalculatorContext* cc, - int* border_mode); + absl::Status GetBorderModeForOpenCV(CalculatorContext* cc, int* border_mode); mediapipe::ImageCroppingCalculatorOptions options_; @@ -84,11 +83,11 @@ class ImageCroppingCalculator : public CalculatorBase { float transformed_points_[8]; float output_max_width_ = FLT_MAX; float output_max_height_ = FLT_MAX; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU bool gpu_initialized_ = false; mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; } // namespace mediapipe diff --git a/mediapipe/calculators/image/image_cropping_calculator_test.cc 
b/mediapipe/calculators/image/image_cropping_calculator_test.cc index c511014aa..bb75826a1 100644 --- a/mediapipe/calculators/image/image_cropping_calculator_test.cc +++ b/mediapipe/calculators/image/image_cropping_calculator_test.cc @@ -59,8 +59,8 @@ TEST(ImageCroppingCalculatorTest, GetCroppingDimensionsNormal) { auto calculator_state = absl::make_unique( "Node", 0, "Calculator", calculator_node, nullptr); auto cc = absl::make_unique( - calculator_state.get(), tool::CreateTagMap({}).ValueOrDie(), - tool::CreateTagMap({}).ValueOrDie()); + calculator_state.get(), tool::CreateTagMap({}).value(), + tool::CreateTagMap({}).value()); RectSpec expectRect = { .width = 60, @@ -99,8 +99,8 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecInOptions) { auto calculator_state = absl::make_unique( "Node", 0, "Calculator", calculator_node, nullptr); auto cc = absl::make_unique( - calculator_state.get(), tool::CreateTagMap({}).ValueOrDie(), - tool::CreateTagMap({}).ValueOrDie()); + calculator_state.get(), tool::CreateTagMap({}).value(), + tool::CreateTagMap({}).value()); RectSpec expectRect = { .width = 50, .height = 50, @@ -144,9 +144,9 @@ TEST(ImageCroppingCalculatorTest, RedundantSpectWithInputStream) { "HEIGHT:0:crop_height", "WIDTH:0:crop_width", }) - .ValueOrDie(); + .value(); auto cc = absl::make_unique( - calculator_state.get(), inputTags, tool::CreateTagMap({}).ValueOrDie()); + calculator_state.get(), inputTags, tool::CreateTagMap({}).value()); auto& inputs = cc->Inputs(); inputs.Tag(kHeightTag).Value() = MakePacket(1); inputs.Tag(kWidthTag).Value() = MakePacket(1); @@ -191,9 +191,9 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) { auto inputTags = tool::CreateTagMap({ "RECT:0:rect", }) - .ValueOrDie(); + .value(); auto cc = absl::make_unique( - calculator_state.get(), inputTags, tool::CreateTagMap({}).ValueOrDie()); + calculator_state.get(), inputTags, tool::CreateTagMap({}).value()); auto& inputs = cc->Inputs(); mediapipe::Rect rect = ParseTextProtoOrDie( R"( diff --git a/mediapipe/calculators/image/image_file_properties_calculator.cc b/mediapipe/calculators/image/image_file_properties_calculator.cc index 82af9ef8a..9c6d8caca 100644 --- a/mediapipe/calculators/image/image_file_properties_calculator.cc +++ b/mediapipe/calculators/image/image_file_properties_calculator.cc @@ -28,23 +28,24 @@ namespace { // sqrt(36^2 + 24^2). static const double SENSOR_DIAGONAL_35MM = std::sqrt(1872.0); -::mediapipe::StatusOr ComputeFocalLengthInPixels( - int image_width, int image_height, double focal_length_35mm, - double focal_length_mm) { +absl::StatusOr ComputeFocalLengthInPixels(int image_width, + int image_height, + double focal_length_35mm, + double focal_length_mm) { // TODO: Allow returning image file properties even when focal length // computation is not possible. 
if (image_width == 0 || image_height == 0) { - return ::mediapipe::InternalError( + return absl::InternalError( "Image dimensions should be non-zero to compute focal length in " "pixels."); } if (focal_length_mm == 0) { - return ::mediapipe::InternalError( + return absl::InternalError( "Focal length in mm should be non-zero to compute focal length in " "pixels."); } if (focal_length_35mm == 0) { - return ::mediapipe::InternalError( + return absl::InternalError( "Focal length in 35 mm should be non-zero to compute focal length in " "pixels."); } @@ -76,13 +77,13 @@ static const double SENSOR_DIAGONAL_35MM = std::sqrt(1872.0); return focal_length_pixels; } -::mediapipe::StatusOr GetImageFileProperites( +absl::StatusOr GetImageFileProperites( const std::string& image_bytes) { easyexif::EXIFInfo result; int code = result.parseFrom(image_bytes); if (code) { - return ::mediapipe::InternalError("Error parsing EXIF, code: " + - std::to_string(code)); + return absl::InternalError("Error parsing EXIF, code: " + + std::to_string(code)); } ImageFileProperties properties; @@ -125,7 +126,7 @@ static const double SENSOR_DIAGONAL_35MM = std::sqrt(1872.0); // } class ImageFilePropertiesCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { if (cc->Inputs().NumEntries() != 0) { RET_CHECK(cc->Inputs().NumEntries() == 1); cc->Inputs().Index(0).Set(); @@ -141,10 +142,10 @@ class ImageFilePropertiesCalculator : public CalculatorBase { cc->OutputSidePackets().Index(0).Set<::mediapipe::ImageFileProperties>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); if (cc->InputSidePackets().NumEntries() == 1) { @@ -159,13 +160,13 @@ class ImageFilePropertiesCalculator : public CalculatorBase { MakePacket(properties_)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().NumEntries() == 1) { if (cc->Inputs().Index(0).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const std::string& image_bytes = cc->Inputs().Index(0).Get(); ASSIGN_OR_RETURN(properties_, GetImageFileProperites(image_bytes)); @@ -179,11 +180,11 @@ class ImageFilePropertiesCalculator : public CalculatorBase { } else { cc->OutputSidePackets().Index(0).Set( MakePacket(properties_) - .At(::mediapipe::Timestamp::Unset())); + .At(mediapipe::Timestamp::Unset())); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/image/image_properties_calculator.cc b/mediapipe/calculators/image/image_properties_calculator.cc index be0a65e0d..5fbd64012 100644 --- a/mediapipe/calculators/image/image_properties_calculator.cc +++ b/mediapipe/calculators/image/image_properties_calculator.cc @@ -15,9 +15,9 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { constexpr char kImageFrameTag[] = "IMAGE"; @@ -44,31 +44,31 @@ namespace mediapipe { // } class ImagePropertiesCalculator : public CalculatorBase { public: - static 
::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kImageFrameTag) ^ cc->Inputs().HasTag(kGpuBufferTag)); if (cc->Inputs().HasTag(kImageFrameTag)) { cc->Inputs().Tag(kImageFrameTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set<::mediapipe::GpuBuffer>(); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag("SIZE")) { cc->Outputs().Tag("SIZE").Set>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { int width; int height; @@ -78,7 +78,7 @@ class ImagePropertiesCalculator : public CalculatorBase { width = image.Width(); height = image.Height(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag) && !cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { const auto& image = @@ -86,13 +86,13 @@ class ImagePropertiesCalculator : public CalculatorBase { width = image.width(); height = image.height(); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU cc->Outputs().Tag("SIZE").AddPacket( MakePacket>(width, height) .At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(ImagePropertiesCalculator); diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index 37539d814..bb98f14e0 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -22,12 +22,12 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/scale_mode.pb.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_quad_renderer.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU #if defined(__ANDROID__) // The size of Java arrays is dynamic, which makes it difficult to @@ -42,9 +42,9 @@ typedef int DimensionsPacketType[2]; namespace mediapipe { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { constexpr char kImageFrameTag[] = "IMAGE"; @@ -163,16 +163,16 @@ class ImageTransformationCalculator : public CalculatorBase { ImageTransformationCalculator() = default; ~ImageTransformationCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status RenderCpu(CalculatorContext* cc); - ::mediapipe::Status 
RenderGpu(CalculatorContext* cc); - ::mediapipe::Status GlSetup(); + absl::Status RenderCpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status GlSetup(); void ComputeOutputDimensions(int input_width, int input_height, int* output_width, int* output_height); @@ -189,17 +189,17 @@ class ImageTransformationCalculator : public CalculatorBase { bool flip_vertically_ = false; bool use_gpu_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU GlCalculatorHelper gpu_helper_; std::unique_ptr rgb_renderer_; std::unique_ptr yuv_renderer_; std::unique_ptr ext_rgb_renderer_; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(ImageTransformationCalculator); // static -::mediapipe::Status ImageTransformationCalculator::GetContract( +absl::Status ImageTransformationCalculator::GetContract( CalculatorContract* cc) { // Only one input can be set, and the output type must match. RET_CHECK(cc->Inputs().HasTag(kImageFrameTag) ^ @@ -212,14 +212,14 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); cc->Inputs().Tag(kImageFrameTag).Set(); cc->Outputs().Tag(kImageFrameTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag)); cc->Inputs().Tag(kGpuBufferTag).Set(); cc->Outputs().Tag(kGpuBufferTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag("ROTATION_DEGREES")) { cc->Inputs().Tag("ROTATION_DEGREES").Set(); @@ -249,15 +249,15 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { +absl::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. cc->SetOffset(TimestampDiff(0)); @@ -303,19 +303,18 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE); if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // Let the helper access the GL context information. MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #else RET_CHECK_FAIL() << "GPU processing not enabled."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageTransformationCalculator::Process( - CalculatorContext* cc) { +absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) { // Override values if specified so. 
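// The guard rewrite running through image_transformation_calculator.cc (and most
// GPU-aware calculators in this patch) replaces
//   #if !defined(MEDIAPIPE_DISABLE_GPU)
// with
//   #if !MEDIAPIPE_DISABLE_GPU
// so the macro is tested by value (0/1) rather than by mere presence, and the
// build is expected to always define it. A standalone sketch of that convention;
// the fallback default below is an illustrative assumption, not part of the patch.
#ifndef MEDIAPIPE_DISABLE_GPU
#define MEDIAPIPE_DISABLE_GPU 0  // assume GPU support unless the build says otherwise
#endif

#if !MEDIAPIPE_DISABLE_GPU
// GPU-only includes, members, and render paths are compiled in.
#else
// CPU-only build: the GPU branches above instead hit RET_CHECK_FAIL() at runtime.
#endif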
if (cc->Inputs().HasTag("ROTATION_DEGREES") && !cc->Inputs().Tag("ROTATION_DEGREES").IsEmpty()) { @@ -332,26 +331,25 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); } if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } return gpu_helper_.RunInGlContext( - [this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); }); -#endif // !MEDIAPIPE_DISABLE_GPU + [this, cc]() -> absl::Status { return RenderGpu(cc); }); +#endif // !MEDIAPIPE_DISABLE_GPU } else { if (cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } return RenderCpu(cc); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageTransformationCalculator::Close( - CalculatorContext* cc) { +absl::Status ImageTransformationCalculator::Close(CalculatorContext* cc) { if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU QuadRenderer* rgb_renderer = rgb_renderer_.release(); QuadRenderer* yuv_renderer = yuv_renderer_.release(); QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release(); @@ -369,14 +367,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); delete yuv_renderer; } }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageTransformationCalculator::RenderCpu( - CalculatorContext* cc) { +absl::Status ImageTransformationCalculator::RenderCpu(CalculatorContext* cc) { cv::Mat input_mat; mediapipe::ImageFormat::Format format; @@ -480,12 +477,11 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); .Tag(kImageFrameTag) .Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageTransformationCalculator::RenderGpu( - CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const auto& input = cc->Inputs().Tag(kGpuBufferTag).Get(); const int input_width = input.width(); const int input_height = input.height(); @@ -519,7 +515,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); renderer = yuv_renderer_.get(); src1 = gpu_helper_.CreateSourceTexture(input, 0); } else // NOLINT(readability/braces) -#endif // iOS +#endif // iOS { src1 = gpu_helper_.CreateSourceTexture(input); #if defined(TEXTURE_EXTERNAL_OES) @@ -531,7 +527,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); } renderer = ext_rgb_renderer_.get(); } else // NOLINT(readability/braces) -#endif // TEXTURE_EXTERNAL_OES +#endif // TEXTURE_EXTERNAL_OES { if (!rgb_renderer_) { rgb_renderer_ = absl::make_unique(); @@ -568,9 +564,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); auto output = dst.template GetFrame(); cc->Outputs().Tag(kGpuBufferTag).Add(output.release(), cc->InputTimestamp()); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void ImageTransformationCalculator::ComputeOutputDimensions( diff --git a/mediapipe/calculators/image/luminance_calculator.cc b/mediapipe/calculators/image/luminance_calculator.cc index 325745d99..d5122c7a4 100644 --- a/mediapipe/calculators/image/luminance_calculator.cc +++ b/mediapipe/calculators/image/luminance_calculator.cc @@ -26,10 +26,9 @@ namespace mediapipe { // See 
GlSimpleCalculatorBase for inputs, outputs and input side packets. class LuminanceCalculator : public GlSimpleCalculator { public: - ::mediapipe::Status GlSetup() override; - ::mediapipe::Status GlRender(const GlTexture& src, - const GlTexture& dst) override; - ::mediapipe::Status GlTeardown() override; + absl::Status GlSetup() override; + absl::Status GlRender(const GlTexture& src, const GlTexture& dst) override; + absl::Status GlTeardown() override; private: GLuint program_ = 0; @@ -37,7 +36,7 @@ class LuminanceCalculator : public GlSimpleCalculator { }; REGISTER_CALCULATOR(LuminanceCalculator); -::mediapipe::Status LuminanceCalculator::GlSetup() { +absl::Status LuminanceCalculator::GlSetup() { // Load vertex and fragment shaders const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, @@ -83,11 +82,11 @@ REGISTER_CALCULATOR(LuminanceCalculator); (const GLchar**)&attr_name[0], attr_location, &program_); RET_CHECK(program_) << "Problem initializing the program."; frame_ = glGetUniformLocation(program_, "video_frame"); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LuminanceCalculator::GlRender(const GlTexture& src, - const GlTexture& dst) { +absl::Status LuminanceCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -137,15 +136,15 @@ REGISTER_CALCULATOR(LuminanceCalculator); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LuminanceCalculator::GlTeardown() { +absl::Status LuminanceCalculator::GlTeardown() { if (program_) { glDeleteProgram(program_); program_ = 0; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/mask_overlay_calculator.cc b/mediapipe/calculators/image/mask_overlay_calculator.cc index eda3f4f91..5fbf9e4f4 100644 --- a/mediapipe/calculators/image/mask_overlay_calculator.cc +++ b/mediapipe/calculators/image/mask_overlay_calculator.cc @@ -52,14 +52,14 @@ class MaskOverlayCalculator : public CalculatorBase { MaskOverlayCalculator() {} ~MaskOverlayCalculator(); - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status GlSetup( + absl::Status GlSetup( const MaskOverlayCalculatorOptions::MaskChannel mask_channel); - ::mediapipe::Status GlRender(const float mask_const); + absl::Status GlRender(const float mask_const); private: GlCalculatorHelper helper_; @@ -73,7 +73,7 @@ class MaskOverlayCalculator : public CalculatorBase { REGISTER_CALCULATOR(MaskOverlayCalculator); // static -::mediapipe::Status MaskOverlayCalculator::GetContract(CalculatorContract* cc) { +absl::Status MaskOverlayCalculator::GetContract(CalculatorContract* cc) { MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); cc->Inputs().Get("VIDEO", 0).Set(); cc->Inputs().Get("VIDEO", 1).Set(); @@ -82,14 +82,13 @@ REGISTER_CALCULATOR(MaskOverlayCalculator); else if (cc->Inputs().HasTag("CONST_MASK")) cc->Inputs().Tag("CONST_MASK").Set(); else - return ::mediapipe::Status( - ::mediapipe::StatusCode::kNotFound, - "At least one mask input stream must be present."); + return 
absl::Status(absl::StatusCode::kNotFound, + "At least one mask input stream must be present."); cc->Outputs().Tag("OUTPUT").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status MaskOverlayCalculator::Open(CalculatorContext* cc) { +absl::Status MaskOverlayCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if (cc->Inputs().HasTag("MASK")) { use_mask_tex_ = true; @@ -97,8 +96,8 @@ REGISTER_CALCULATOR(MaskOverlayCalculator); return helper_.Open(cc); } -::mediapipe::Status MaskOverlayCalculator::Process(CalculatorContext* cc) { - return helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { +absl::Status MaskOverlayCalculator::Process(CalculatorContext* cc) { + return helper_.RunInGlContext([this, &cc]() -> absl::Status { if (!initialized_) { const auto& options = cc->Options(); const auto mask_channel = options.mask_channel(); @@ -116,7 +115,7 @@ REGISTER_CALCULATOR(MaskOverlayCalculator); if (mask_packet.IsEmpty()) { cc->Outputs().Tag("OUTPUT").AddPacket(input1_packet); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input0_buffer = cc->Inputs().Get("VIDEO", 0).Get(); @@ -173,11 +172,11 @@ REGISTER_CALCULATOR(MaskOverlayCalculator); dst.Release(); cc->Outputs().Tag("OUTPUT").Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); } -::mediapipe::Status MaskOverlayCalculator::GlSetup( +absl::Status MaskOverlayCalculator::GlSetup( const MaskOverlayCalculatorOptions::MaskChannel mask_channel) { // Load vertex and fragment shaders const GLint attr_location[NUM_ATTRIBUTES] = { @@ -248,10 +247,10 @@ REGISTER_CALCULATOR(MaskOverlayCalculator); unif_frame1_ = glGetUniformLocation(program_, "frame1"); unif_frame2_ = glGetUniformLocation(program_, "frame2"); unif_mask_ = glGetUniformLocation(program_, "mask"); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status MaskOverlayCalculator::GlRender(const float mask_const) { +absl::Status MaskOverlayCalculator::GlRender(const float mask_const) { glUseProgram(program_); glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, kBasicSquareVertices); glEnableVertexAttribArray(ATTRIB_VERTEX); @@ -267,7 +266,7 @@ REGISTER_CALCULATOR(MaskOverlayCalculator); glUniform1f(unif_mask_, mask_const); glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } MaskOverlayCalculator::~MaskOverlayCalculator() { diff --git a/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc index ae87d7511..21bc587f3 100644 --- a/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc +++ b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc @@ -34,29 +34,29 @@ namespace mediapipe { // } class OpenCvEncodedImageToImageFrameCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: mediapipe::OpenCvEncodedImageToImageFrameCalculatorOptions options_; }; -::mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::GetContract( +absl::Status 
OpenCvEncodedImageToImageFrameCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::Open( +absl::Status OpenCvEncodedImageToImageFrameCalculator::Open( CalculatorContext* cc) { options_ = cc->Options(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::Process( +absl::Status OpenCvEncodedImageToImageFrameCalculator::Process( CalculatorContext* cc) { const std::string& contents = cc->Inputs().Index(0).Get(); const std::vector contents_vector(contents.begin(), contents.end()); @@ -84,10 +84,11 @@ class OpenCvEncodedImageToImageFrameCalculator : public CalculatorBase { cv::cvtColor(decoded_mat, output_mat, cv::COLOR_BGR2RGB); break; case 4: - return ::mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) - << "4-channel image isn't supported yet"; + image_format = ImageFormat::SRGBA; + cv::cvtColor(decoded_mat, output_mat, cv::COLOR_BGR2RGBA); + break; default: - return ::mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) << "Unsupported number of channels: " << decoded_mat.channels(); } std::unique_ptr output_frame = absl::make_unique( @@ -95,7 +96,7 @@ class OpenCvEncodedImageToImageFrameCalculator : public CalculatorBase { ImageFrame::kGlDefaultAlignmentBoundary); output_mat.copyTo(formats::MatView(output_frame.get())); cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(OpenCvEncodedImageToImageFrameCalculator); diff --git a/mediapipe/calculators/image/opencv_image_encoder_calculator.cc b/mediapipe/calculators/image/opencv_image_encoder_calculator.cc index efe79d99c..93ec9435f 100644 --- a/mediapipe/calculators/image/opencv_image_encoder_calculator.cc +++ b/mediapipe/calculators/image/opencv_image_encoder_calculator.cc @@ -38,30 +38,28 @@ namespace mediapipe { // } class OpenCvImageEncoderCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: int encoding_quality_; }; -::mediapipe::Status OpenCvImageEncoderCalculator::GetContract( - CalculatorContract* cc) { +absl::Status OpenCvImageEncoderCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvImageEncoderCalculator::Open(CalculatorContext* cc) { +absl::Status OpenCvImageEncoderCalculator::Open(CalculatorContext* cc) { auto options = cc->Options(); encoding_quality_ = options.quality(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvImageEncoderCalculator::Process( - CalculatorContext* cc) { +absl::Status OpenCvImageEncoderCalculator::Process(CalculatorContext* cc) { const ImageFrame& image_frame = cc->Inputs().Index(0).Get(); CHECK_EQ(1, image_frame.ByteDepth()); @@ 
-85,10 +83,10 @@ class OpenCvImageEncoderCalculator : public CalculatorBase { encoded_result->set_colorspace(OpenCvImageEncoderCalculatorResults::RGB); break; case 4: - return ::mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) << "4-channel image isn't supported yet"; default: - return ::mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) << "Unsupported number of channels: " << original_mat.channels(); } @@ -101,19 +99,18 @@ class OpenCvImageEncoderCalculator : public CalculatorBase { // Check its JpegEncoder::write() in "imgcodecs/src/grfmt_jpeg.cpp" for more // info. if (!cv::imencode(".jpg", input_mat, encode_buffer, parameters)) { - return ::mediapipe::InternalErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InternalErrorBuilder(MEDIAPIPE_LOC) << "Fail to encode the image to be jpeg format."; } - encoded_result->set_encoded_image(std::string(absl::string_view( - reinterpret_cast(&encode_buffer[0]), encode_buffer.size()))); + encoded_result->set_encoded_image(&encode_buffer[0], encode_buffer.size()); cc->Outputs().Index(0).Add(encoded_result.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvImageEncoderCalculator::Close(CalculatorContext* cc) { - return ::mediapipe::OkStatus(); +absl::Status OpenCvImageEncoderCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); } REGISTER_CALCULATOR(OpenCvImageEncoderCalculator); diff --git a/mediapipe/calculators/image/opencv_image_encoder_calculator.proto b/mediapipe/calculators/image/opencv_image_encoder_calculator.proto index 43172b319..0564fb270 100644 --- a/mediapipe/calculators/image/opencv_image_encoder_calculator.proto +++ b/mediapipe/calculators/image/opencv_image_encoder_calculator.proto @@ -29,11 +29,13 @@ message OpenCvImageEncoderCalculatorOptions { // TODO: Consider renaming it to EncodedImage. message OpenCvImageEncoderCalculatorResults { - // Encoded image - optional string encoded_image = 1; + // Pixel data encoded as JPEG. + optional bytes encoded_image = 1; - // Dimensions of the encoded image + // Height of the image data under #1 once decoded. optional int32 height = 2; + + // Width of the image data under #1 once decoded. optional int32 width = 3; enum ColorSpace { diff --git a/mediapipe/calculators/image/opencv_put_text_calculator.cc b/mediapipe/calculators/image/opencv_put_text_calculator.cc index 07f6f0dbf..82a4b3a53 100644 --- a/mediapipe/calculators/image/opencv_put_text_calculator.cc +++ b/mediapipe/calculators/image/opencv_put_text_calculator.cc @@ -32,18 +32,17 @@ namespace mediapipe { // TODO: Generalize the calculator for other text use cases. 
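// The .proto hunk above retypes encoded_image from string to bytes, and the
// encoder calculator now writes the JPEG buffer into it without the
// string_view round trip. A standalone sketch of the protoc-generated setter
// that call relies on; the buffer contents and dimensions here are illustrative.
#include <vector>
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"

void FillResult(mediapipe::OpenCvImageEncoderCalculatorResults* result,
                const std::vector<unsigned char>& encode_buffer) {
  // For bytes fields protoc generates a (const void*, size_t) overload, so the
  // raw OpenCV encode buffer can be copied in directly.
  result->set_encoded_image(encode_buffer.data(), encode_buffer.size());
  result->set_width(640);   // dimensions of the image once decoded (illustrative)
  result->set_height(480);
}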
class OpenCvPutTextCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Process(CalculatorContext* cc) override; }; -::mediapipe::Status OpenCvPutTextCalculator::GetContract( - CalculatorContract* cc) { +absl::Status OpenCvPutTextCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) { +absl::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) { const std::string& text_content = cc->Inputs().Index(0).Get(); cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC4); cv::putText(mat, text_content, cv::Point(15, 70), cv::FONT_HERSHEY_PLAIN, 3, @@ -52,7 +51,7 @@ class OpenCvPutTextCalculator : public CalculatorBase { ImageFormat::SRGBA, mat.size().width, mat.size().height); mat.copyTo(formats::MatView(output_frame.get())); cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(OpenCvPutTextCalculator); diff --git a/mediapipe/calculators/image/recolor_calculator.cc b/mediapipe/calculators/image/recolor_calculator.cc index c8d3d1725..6a12025f6 100644 --- a/mediapipe/calculators/image/recolor_calculator.cc +++ b/mediapipe/calculators/image/recolor_calculator.cc @@ -24,11 +24,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/util/color.pb.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -84,17 +84,17 @@ class RecolorCalculator : public CalculatorBase { RecolorCalculator() = default; ~RecolorCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status LoadOptions(CalculatorContext* cc); - ::mediapipe::Status InitGpu(CalculatorContext* cc); - ::mediapipe::Status RenderGpu(CalculatorContext* cc); - ::mediapipe::Status RenderCpu(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); void GlRender(); bool initialized_ = false; @@ -102,46 +102,46 @@ class RecolorCalculator : public CalculatorBase { mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_; bool use_gpu_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(RecolorCalculator); // static -::mediapipe::Status 
RecolorCalculator::GetContract(CalculatorContract* cc) { +absl::Status RecolorCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); bool use_gpu = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kImageFrameTag)) { cc->Inputs().Tag(kImageFrameTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kMaskGpuTag)) { cc->Inputs().Tag(kMaskGpuTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kMaskCpuTag)) { cc->Inputs().Tag(kMaskCpuTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kGpuBufferTag)) { cc->Outputs().Tag(kGpuBufferTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kImageFrameTag)) { cc->Outputs().Tag(kImageFrameTag).Set(); } @@ -154,62 +154,62 @@ REGISTER_CALCULATOR(RecolorCalculator); cc->Outputs().HasTag(kGpuBufferTag)); if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RecolorCalculator::Open(CalculatorContext* cc) { +absl::Status RecolorCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if (cc->Inputs().HasTag(kGpuBufferTag)) { use_gpu_ = true; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } MP_RETURN_IF_ERROR(LoadOptions(cc)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) { +absl::Status RecolorCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { + gpu_helper_.RunInGlContext([this, &cc]() -> absl::Status { if (!initialized_) { MP_RETURN_IF_ERROR(InitGpu(cc)); initialized_ = true; } MP_RETURN_IF_ERROR(RenderGpu(cc)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status RecolorCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { +absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kMaskCpuTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Get inputs and setup output. 
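// RecolorCalculator::Process above now hands RunInGlContext a lambda typed as
// absl::Status instead of ::mediapipe::Status. A standalone sketch of that
// status-returning-lambda pattern; RunInContext is a stand-in helper, not the
// GlCalculatorHelper API.
#include <functional>
#include "absl/status/status.h"

absl::Status RunInContext(const std::function<absl::Status()>& task) {
  // ... make the (GL) context current ...
  absl::Status status = task();
  // ... restore the previous context ...
  return status;
}

absl::Status RenderOnce() { return absl::OkStatus(); }

absl::Status ProcessFrame() {
  return RunInContext([]() -> absl::Status {
    // Early-outs and errors propagate through the same return type.
    return RenderOnce();
  });
}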
const auto& input_img = cc->Inputs().Tag(kImageFrameTag).Get(); @@ -265,14 +265,14 @@ REGISTER_CALCULATOR(RecolorCalculator); .Tag(kImageFrameTag) .Add(output_img.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { +absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kMaskGpuTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // Get inputs and setup output. const Packet& input_packet = cc->Inputs().Tag(kGpuBufferTag).Value(); const Packet& mask_packet = cc->Inputs().Tag(kMaskGpuTag).Value(); @@ -311,13 +311,13 @@ REGISTER_CALCULATOR(RecolorCalculator); img_tex.Release(); mask_tex.Release(); dst_tex.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void RecolorCalculator::GlRender() { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -365,10 +365,10 @@ void RecolorCalculator::GlRender() { glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { +absl::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { const auto& options = cc->Options(); mask_channel_ = options.mask_channel(); @@ -379,11 +379,11 @@ void RecolorCalculator::GlRender() { color_.push_back(options.color().g()); color_.push_back(options.color().b()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -452,9 +452,9 @@ void RecolorCalculator::GlRender() { glUniform1i(glGetUniformLocation(program_, "mask"), 2); glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0] / 255.0, color_[1] / 255.0, color_[2] / 255.0); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/scale_image_calculator.cc b/mediapipe/calculators/image/scale_image_calculator.cc index ac441689e..575268da5 100644 --- a/mediapipe/calculators/image/scale_image_calculator.cc +++ b/mediapipe/calculators/image/scale_image_calculator.cc @@ -44,7 +44,7 @@ namespace { // Given an upscaling algorithm, determine which OpenCV interpolation algorithm // to use. 
-::mediapipe::Status FindInterpolationAlgorithm( +absl::Status FindInterpolationAlgorithm( ScaleImageCalculatorOptions::ScaleAlgorithm upscaling_algorithm, int* interpolation_algorithm) { switch (upscaling_algorithm) { @@ -70,7 +70,7 @@ namespace { RET_CHECK_FAIL() << absl::Substitute("Unknown upscaling algorithm: $0", upscaling_algorithm); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void CropImageFrame(const ImageFrame& original, int col_start, int row_start, @@ -147,7 +147,7 @@ class ScaleImageCalculator : public CalculatorBase { ScaleImageCalculator(); ~ScaleImageCalculator() override; - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { ScaleImageCalculatorOptions options = cc->Options(); @@ -184,35 +184,35 @@ class ScaleImageCalculator : public CalculatorBase { if (cc->Inputs().HasTag("OVERRIDE_OPTIONS")) { cc->Inputs().Tag("OVERRIDE_OPTIONS").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // From Calculator. - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Initialize some data members from options_. This can be called either from // Open or Process depending on whether OVERRIDE_OPTIONS is used. - ::mediapipe::Status InitializeFromOptions(); + absl::Status InitializeFromOptions(); // Initialize crop and output parameters based on set member variable // values. This function will also send the header information on // the VIDEO_HEADER stream if it hasn't been done yet. - ::mediapipe::Status InitializeFrameInfo(CalculatorContext* cc); + absl::Status InitializeFrameInfo(CalculatorContext* cc); // Validate that input_format_ and output_format_ are supported image // formats. - ::mediapipe::Status ValidateImageFormats() const; + absl::Status ValidateImageFormats() const; // Validate that the image frame has the proper format and dimensions. // If the dimensions and format weren't initialized by the header, // then the first frame on which this function is called is used // to initialize. - ::mediapipe::Status ValidateImageFrame(CalculatorContext* cc, - const ImageFrame& image_frame); + absl::Status ValidateImageFrame(CalculatorContext* cc, + const ImageFrame& image_frame); // Validate that the YUV image has the proper dimensions. If the // dimensions weren't initialized by the header, then the first image // on which this function is called is used to initialize. - ::mediapipe::Status ValidateYUVImage(CalculatorContext* cc, - const YUVImage& yuv_image); + absl::Status ValidateYUVImage(CalculatorContext* cc, + const YUVImage& yuv_image); bool has_header_; // True if the input stream has a header. 
int input_width_; @@ -251,8 +251,7 @@ ScaleImageCalculator::ScaleImageCalculator() {} ScaleImageCalculator::~ScaleImageCalculator() {} -::mediapipe::Status ScaleImageCalculator::InitializeFrameInfo( - CalculatorContext* cc) { +absl::Status ScaleImageCalculator::InitializeFrameInfo(CalculatorContext* cc) { MP_RETURN_IF_ERROR( scale_image::FindCropDimensions(input_width_, input_height_, // options_.min_aspect_ratio(), // @@ -299,10 +298,10 @@ ScaleImageCalculator::~ScaleImageCalculator() {} .Add(header.release(), Timestamp::PreStream()); cc->Outputs().Tag("VIDEO_HEADER").Close(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ScaleImageCalculator::Open(CalculatorContext* cc) { +absl::Status ScaleImageCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); input_data_id_ = cc->Inputs().GetId("FRAMES", 0); @@ -339,7 +338,7 @@ ScaleImageCalculator::~ScaleImageCalculator() {} // has a header. At this point in the code, the ScaleImageCalculator // config may be changed by the new options at PreStream, so the output // header can't be determined. - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "OVERRIDE_OPTIONS stream can't be used when the main input stream " "has a header."); } @@ -406,10 +405,10 @@ ScaleImageCalculator::~ScaleImageCalculator() {} } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ScaleImageCalculator::InitializeFromOptions() { +absl::Status ScaleImageCalculator::InitializeFromOptions() { if (options_.has_input_format()) { input_format_ = options_.input_format(); } else { @@ -423,10 +422,10 @@ ScaleImageCalculator::~ScaleImageCalculator() {} downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ScaleImageCalculator::ValidateImageFormats() const { +absl::Status ScaleImageCalculator::ValidateImageFormats() const { RET_CHECK_NE(input_format_, ImageFormat::UNKNOWN) << "The input image format was UNKNOWN."; RET_CHECK_NE(output_format_, ImageFormat::UNKNOWN) @@ -440,10 +439,10 @@ ScaleImageCalculator::~ScaleImageCalculator() {} input_format_ == ImageFormat::YCBCR420P) << "Conversion of the color space (except from " "YCbCr420P to SRGB) is not yet supported."; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ScaleImageCalculator::ValidateImageFrame( +absl::Status ScaleImageCalculator::ValidateImageFrame( CalculatorContext* cc, const ImageFrame& image_frame) { if (!has_header_) { if (input_width_ != image_frame.Width() || @@ -494,11 +493,11 @@ ScaleImageCalculator::~ScaleImageCalculator() {} image_frame_format_desc, " but expected ", input_format_desc)); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ScaleImageCalculator::ValidateYUVImage( - CalculatorContext* cc, const YUVImage& yuv_image) { +absl::Status ScaleImageCalculator::ValidateYUVImage(CalculatorContext* cc, + const YUVImage& yuv_image) { CHECK_EQ(input_format_, ImageFormat::YCBCR420P); if (!has_header_) { if (input_width_ != yuv_image.width() || @@ -528,14 +527,14 @@ ScaleImageCalculator::~ScaleImageCalculator() {} input_width_, "x", input_height_)); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ScaleImageCalculator::Process(CalculatorContext* cc) { +absl::Status ScaleImageCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream()) { if 
(cc->Inputs().HasTag("OVERRIDE_OPTIONS")) { if (cc->Inputs().Tag("OVERRIDE_OPTIONS").IsEmpty()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The OVERRIDE_OPTIONS input stream must be non-empty at PreStream " "time if used."); } @@ -549,7 +548,7 @@ ScaleImageCalculator::~ScaleImageCalculator() {} input_video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get(); } if (cc->Inputs().Get(input_data_id_).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -603,7 +602,7 @@ ScaleImageCalculator::~ScaleImageCalculator() {} cc->Outputs() .Get(output_data_id_) .Add(output_image.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } else { image_frame = &cc->Inputs().Get(input_data_id_).Get(); @@ -664,7 +663,7 @@ ScaleImageCalculator::~ScaleImageCalculator() {} .Add(output_frame.release(), cc->InputTimestamp()); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Rescale the image frame. @@ -698,7 +697,7 @@ ScaleImageCalculator::~ScaleImageCalculator() {} cc->Outputs() .Get(output_data_id_) .Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc index 3225521a5..738e83da0 100644 --- a/mediapipe/calculators/image/scale_image_utils.cc +++ b/mediapipe/calculators/image/scale_image_utils.cc @@ -35,11 +35,11 @@ double ParseRational(const std::string& rational) { } } // namespace -::mediapipe::Status FindCropDimensions(int input_width, int input_height, // - const std::string& min_aspect_ratio, // - const std::string& max_aspect_ratio, // - int* crop_width, int* crop_height, // - int* col_start, int* row_start) { +absl::Status FindCropDimensions(int input_width, int input_height, // + const std::string& min_aspect_ratio, // + const std::string& max_aspect_ratio, // + int* crop_width, int* crop_height, // + int* col_start, int* row_start) { CHECK(crop_width); CHECK(crop_height); CHECK(col_start); @@ -85,17 +85,16 @@ double ParseRational(const std::string& rational) { CHECK_LE(*crop_width, input_width); CHECK_LE(*crop_height, input_height); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FindOutputDimensions(int input_width, // - int input_height, // - int target_width, // - int target_height, // - bool preserve_aspect_ratio, // - int scale_to_multiple_of, // - int* output_width, - int* output_height) { +absl::Status FindOutputDimensions(int input_width, // + int input_height, // + int target_width, // + int target_height, // + bool preserve_aspect_ratio, // + int scale_to_multiple_of, // + int* output_width, int* output_height) { CHECK(output_width); CHECK(output_height); @@ -123,7 +122,7 @@ double ParseRational(const std::string& rational) { *output_width = target_width; *output_height = target_height; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (target_width > 0) { @@ -140,7 +139,7 @@ double ParseRational(const std::string& rational) { // was within the image, so use these dimensions. *output_width = try_width; *output_height = try_height; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -158,7 +157,7 @@ double ParseRational(const std::string& rational) { // was within the image, so use these dimensions. 
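// FindOutputDimensions above picks output dimensions for a target size while
// optionally preserving the aspect ratio and rounding down to a multiple. A
// standalone arithmetic sketch of that idea, not the function's actual logic:
#include <utility>

std::pair<int, int> FitPreservingAspect(int in_w, int in_h, int target_w,
                                        int scale_to_multiple_of) {
  // Scale width to the target, derive height from the input aspect ratio.
  int out_w = target_w;
  int out_h = static_cast<int>(static_cast<double>(in_h) * target_w / in_w);
  // Round both down to the requested multiple (2 keeps YUV-friendly sizes).
  out_w -= out_w % scale_to_multiple_of;
  out_h -= out_h % scale_to_multiple_of;
  return {out_w, out_h};
}

// Example: FitPreservingAspect(1920, 1080, 640, 2) yields {640, 360}.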
*output_width = try_width; *output_height = try_height; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } RET_CHECK_FAIL() diff --git a/mediapipe/calculators/image/scale_image_utils.h b/mediapipe/calculators/image/scale_image_utils.h index ea9dd3f0f..c2c0b8f7c 100644 --- a/mediapipe/calculators/image/scale_image_utils.h +++ b/mediapipe/calculators/image/scale_image_utils.h @@ -28,11 +28,11 @@ namespace scale_image { // is a centered, cropped portion of the image that falls within the min // and max aspect ratio. If either the min or max aspect ratio argument // is empty or has a 0 in the numerator or denominator then it is ignored. -::mediapipe::Status FindCropDimensions(int input_width, int input_height, // - const std::string& min_aspect_ratio, // - const std::string& max_aspect_ratio, // - int* crop_width, int* crop_height, // - int* col_start, int* row_start); +absl::Status FindCropDimensions(int input_width, int input_height, // + const std::string& min_aspect_ratio, // + const std::string& max_aspect_ratio, // + int* crop_width, int* crop_height, // + int* col_start, int* row_start); // Given an input width and height, a target width and height, whether to // preserve the aspect ratio, and whether to round-down to the multiple of a @@ -43,12 +43,12 @@ namespace scale_image { // output_height will be reduced as necessary to preserve_aspect_ratio if the // option is specified. If preserving the aspect ratio is desired, you must set // scale_to_multiple_of to 2. -::mediapipe::Status FindOutputDimensions(int input_width, int input_height, // - int target_width, - int target_height, // - bool preserve_aspect_ratio, // - int scale_to_multiple_of, // - int* output_width, int* output_height); +absl::Status FindOutputDimensions(int input_width, int input_height, // + int target_width, + int target_height, // + bool preserve_aspect_ratio, // + int scale_to_multiple_of, // + int* output_width, int* output_height); } // namespace scale_image } // namespace mediapipe diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index 31de1e21a..08c150d21 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -25,11 +25,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -87,18 +87,18 @@ class SetAlphaCalculator : public CalculatorBase { SetAlphaCalculator() = default; ~SetAlphaCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // From Calculator. 
- ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status RenderGpu(CalculatorContext* cc); - ::mediapipe::Status RenderCpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); - ::mediapipe::Status GlSetup(CalculatorContext* cc); + absl::Status GlSetup(CalculatorContext* cc); void GlRender(CalculatorContext* cc); mediapipe::SetAlphaCalculatorOptions options_; @@ -106,81 +106,81 @@ class SetAlphaCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(SetAlphaCalculator); -::mediapipe::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { +absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { CHECK_GE(cc->Inputs().NumEntries(), 1); bool use_gpu = false; if (cc->Inputs().HasTag(kInputFrameTag) && cc->Inputs().HasTag(kInputFrameTagGpu)) { - return ::mediapipe::InternalError("Cannot have multiple input images."); + return absl::InternalError("Cannot have multiple input images."); } if (cc->Inputs().HasTag(kInputFrameTagGpu) != cc->Outputs().HasTag(kOutputFrameTagGpu)) { - return ::mediapipe::InternalError("GPU output must have GPU input."); + return absl::InternalError("GPU output must have GPU input."); } // Input image to add/edit alpha channel. -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); } // Input alpha image mask (optional) -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputAlphaTagGpu)) { cc->Inputs().Tag(kInputAlphaTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputAlphaTag)) { cc->Inputs().Tag(kInputAlphaTag).Set(); } // RGBA output image. 
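// SetAlphaCalculator emits an RGBA frame whose fourth channel comes either from
// a mask image or from a constant in the options. A standalone OpenCV sketch of
// the CPU-side idea; this is not the calculator's actual implementation.
#include <vector>
#include <opencv2/core.hpp>

cv::Mat WithAlpha(const cv::Mat& rgb /* 8UC3 */, const cv::Mat& alpha /* 8UC1 */) {
  std::vector<cv::Mat> channels;
  cv::split(rgb, channels);   // R, G, B planes
  channels.push_back(alpha);  // append the alpha plane
  cv::Mat rgba;
  cv::merge(channels, rgba);  // 8UC4 output
  return rgba;
}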
-#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SetAlphaCalculator::Open(CalculatorContext* cc) { +absl::Status SetAlphaCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU use_gpu_ = true; #else RET_CHECK_FAIL() << "GPU processing not enabled."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } // Get global value from options (-1 if not set). @@ -193,48 +193,47 @@ REGISTER_CALCULATOR(SetAlphaCalculator); RET_CHECK_FAIL() << "Must use either image mask or options alpha value."; if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #endif } // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) { +absl::Status SetAlphaCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { - if (!gpu_initialized_) { - MP_RETURN_IF_ERROR(GlSetup(cc)); - gpu_initialized_ = true; - } - MP_RETURN_IF_ERROR(RenderGpu(cc)); - return ::mediapipe::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + if (!gpu_initialized_) { + MP_RETURN_IF_ERROR(GlSetup(cc)); + gpu_initialized_ = true; + } + MP_RETURN_IF_ERROR(RenderGpu(cc)); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status SetAlphaCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { +absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Setup source image @@ -294,14 +293,14 @@ REGISTER_CALCULATOR(SetAlphaCalculator); .Tag(kOutputFrameTag) .Add(output_frame.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { +absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { - return 
::mediapipe::OkStatus(); + return absl::OkStatus(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // Setup source texture. const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); @@ -354,13 +353,13 @@ REGISTER_CALCULATOR(SetAlphaCalculator); // Cleanup input_texture.Release(); output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void SetAlphaCalculator::GlRender(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -409,11 +408,11 @@ void SetAlphaCalculator::GlRender(CalculatorContext* cc) { glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -466,9 +465,9 @@ void SetAlphaCalculator::GlRender(CalculatorContext* cc) { glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2); glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/sobel_edges_calculator.cc b/mediapipe/calculators/image/sobel_edges_calculator.cc index e710a99f5..6154a246b 100644 --- a/mediapipe/calculators/image/sobel_edges_calculator.cc +++ b/mediapipe/calculators/image/sobel_edges_calculator.cc @@ -27,10 +27,9 @@ namespace mediapipe { // See GlSimpleCalculatorBase for inputs, outputs and input side packets. 
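// GlSetup above reports shader-initialization failures through RET_CHECK, which
// turns a false condition into an error status that propagates out of the
// calculator. A standalone sketch of that propagation; CompileProgram is an
// illustrative stand-in, not part of the patch.
#include "absl/status/status.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"

absl::Status CompileProgram(unsigned int* program) {
  *program = 0;  // pretend compilation failed for illustration
  RET_CHECK(*program) << "Problem initializing the program.";
  return absl::OkStatus();
}

absl::Status GlSetupLike() {
  unsigned int program = 0;
  MP_RETURN_IF_ERROR(CompileProgram(&program));  // early-returns the RET_CHECK error
  return absl::OkStatus();
}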
class SobelEdgesCalculator : public GlSimpleCalculator { public: - ::mediapipe::Status GlSetup() override; - ::mediapipe::Status GlRender(const GlTexture& src, - const GlTexture& dst) override; - ::mediapipe::Status GlTeardown() override; + absl::Status GlSetup() override; + absl::Status GlRender(const GlTexture& src, const GlTexture& dst) override; + absl::Status GlTeardown() override; private: GLuint program_ = 0; @@ -40,7 +39,7 @@ class SobelEdgesCalculator : public GlSimpleCalculator { }; REGISTER_CALCULATOR(SobelEdgesCalculator); -::mediapipe::Status SobelEdgesCalculator::GlSetup() { +absl::Status SobelEdgesCalculator::GlSetup() { // Load vertex and fragment shaders const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, @@ -166,11 +165,11 @@ REGISTER_CALCULATOR(SobelEdgesCalculator); frame_ = glGetUniformLocation(program_, "inputImage"); pixel_w_ = glGetUniformLocation(program_, "pixelW"); pixel_h_ = glGetUniformLocation(program_, "pixelH"); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SobelEdgesCalculator::GlRender(const GlTexture& src, - const GlTexture& dst) { +absl::Status SobelEdgesCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -225,15 +224,15 @@ REGISTER_CALCULATOR(SobelEdgesCalculator); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SobelEdgesCalculator::GlTeardown() { +absl::Status SobelEdgesCalculator::GlTeardown() { if (program_) { glDeleteProgram(program_); program_ = 0; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/internal/callback_packet_calculator.cc b/mediapipe/calculators/internal/callback_packet_calculator.cc index e9f85ee83..cc153483e 100644 --- a/mediapipe/calculators/internal/callback_packet_calculator.cc +++ b/mediapipe/calculators/internal/callback_packet_calculator.cc @@ -50,7 +50,7 @@ void DumpPostStreamPacket(Packet* post_stream_packet, const Packet& packet) { // while that pointer is still alive. 
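// Nearly every calculator touched by this patch migrates its GetContract /
// Open / Process / Close signatures from ::mediapipe::Status to absl::Status;
// the overrides only stay compatible because ::mediapipe::Status is an alias of
// absl::Status at this point, so the rewrite is mechanical. A minimal skeleton
// of the resulting shape (CallbackPacketCalculator below keeps the same shape
// inline); NoOpCalculator is illustrative, not part of the patch.
#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

class NoOpCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));  // output timestamp always matches input
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(NoOpCalculator);

}  // namespace mediapipe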
class CallbackPacketCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); switch (options.type()) { case CallbackPacketCalculatorOptions::VECTOR_PACKET: @@ -60,17 +60,17 @@ class CallbackPacketCalculator : public CalculatorBase { .Set>(); break; default: - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Invalid type of callback to produce."; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const auto& options = cc->Options(); void* ptr; if (sscanf(options.pointer().c_str(), "%p", &ptr) != 1) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Stored pointer value in options is invalid."; } switch (options.type()) { @@ -87,14 +87,14 @@ class CallbackPacketCalculator : public CalculatorBase { std::placeholders::_1))); break; default: - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Invalid type to dump into."; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD new file mode 100644 index 000000000..5a0631007 --- /dev/null +++ b/mediapipe/calculators/tensor/BUILD @@ -0,0 +1,752 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +load("@bazel_skylib//lib:selects.bzl", "selects") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") +load( + "//mediapipe/framework/tool:mediapipe_graph.bzl", + "mediapipe_binary_graph", +) +load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") +load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +selects.config_setting_group( + name = "compute_shader_unavailable", + match_any = [ + "//mediapipe/gpu:disable_gpu", + ], +) + +mediapipe_proto_library( + name = "inference_calculator_proto", + srcs = ["inference_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "inference_calculator_interface", + srcs = ["inference_calculator.cc"], + hdrs = ["inference_calculator.h"], + copts = select({ + # TODO: fix tensor.h not to require this, if possible + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + deps = [ + ":inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", + "//mediapipe/framework/tool:subgraph_expansion", + "//mediapipe/util/tflite:config", + "//mediapipe/util/tflite:tflite_model_loader", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_gl", + srcs = ["inference_calculator_gl.cc"], + tags = ["nomac"], # config problem with cpuinfo via TF + deps = [ + "inference_calculator_interface", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/util/tflite:tflite_gpu_runner", + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + ], + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_metal", + srcs = ["inference_calculator_metal.cc"], + copts = [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + linkopts = [ + "-framework CoreVideo", + "-framework MetalKit", + ], + tags = ["ios"], + deps = [ + "inference_calculator_interface", + "//mediapipe/gpu:MPPMetalHelper", + "//mediapipe/gpu:MPPMetalUtil", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/objc:mediapipe_framework_ios", + "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", + "@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert", + ], + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_cpu", + srcs = [ + "inference_calculator_cpu.cc", + ], + copts = select({ + # TODO: fix tensor.h not to require this, if possible + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + 
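    # Illustrative note (not part of this change): this CPU/XNNPACK
    # implementation is always linked through the umbrella
    # ":inference_calculator" target defined later in this file; that target's
    # select() additionally pulls in the Metal backend on iOS or the GL backend
    # where compute shaders are available. A hypothetical client rule would
    # look roughly like:
    #
    #   cc_library(
    #       name = "my_graph_deps",  # hypothetical name
    #       deps = ["//mediapipe/calculators/tensor:inference_calculator"],
    #   )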
"//conditions:default": [], + }), + deps = [ + ":inference_calculator_interface", + "@com_google_absl//absl/memory", + "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + ] + select({ + "//conditions:default": [ + "//mediapipe/util:cpu_util", + ], + }) + select({ + "//conditions:default": [], + "//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"], + }), + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_gl_if_compute_shader_available", + deps = select({ + ":compute_shader_unavailable": [], + "//conditions:default": [":inference_calculator_gl"], + }), +) + +cc_library( + name = "inference_calculator", + visibility = ["//visibility:public"], + deps = [ + ":inference_calculator_interface", + ":inference_calculator_cpu", + ] + select({ + "//conditions:default": [":inference_calculator_gl_if_compute_shader_available"], + "//mediapipe:ios": [":inference_calculator_metal"], + }), + alwayslink = 1, +) + +mediapipe_proto_library( + name = "tensor_converter_calculator_proto", + srcs = ["tensor_converter_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "tensor_converter_calculator", + srcs = ["tensor_converter_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + features = ["-layering_check"], # allow depending on tensor_converter_calculator_gpu_deps + linkopts = select({ + "//mediapipe:apple": [ + "-framework CoreVideo", + "-framework MetalKit", + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":tensor_converter_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework:port", + "//mediapipe/util:resource_util", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": ["tensor_converter_calculator_gpu_deps"], + }), + alwayslink = 1, +) + +cc_library( + name = "tensor_converter_calculator_gpu_deps", + deps = select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + ], + "//mediapipe:ios": [ + "//mediapipe/gpu:MPPMetalUtil", + "//mediapipe/gpu:MPPMetalHelper", + "//mediapipe/objc:mediapipe_framework_ios", + ], + "//mediapipe:macos": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:shader_util", + "//mediapipe/gpu:gpu_buffer", + ], + }), +) + +cc_test( + name = "tensor_converter_calculator_test", + srcs = ["tensor_converter_calculator_test.cc"], + deps = [ + ":tensor_converter_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/memory", + 
"@com_google_absl//absl/strings", + ], +) + +mediapipe_proto_library( + name = "tensors_to_detections_calculator_proto", + srcs = ["tensors_to_detections_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "tensors_to_detections_calculator", + srcs = ["tensors_to_detections_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + features = ["-layering_check"], # allow depending on tensors_to_detections_calculator_gpu_deps + linkopts = select({ + "//mediapipe:apple": [ + "-framework CoreVideo", + "-framework MetalKit", + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":tensors_to_detections_calculator_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "//mediapipe/framework/api2:node", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:port", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/framework/port:ret_check", + ] + select({ + ":compute_shader_unavailable": [], + "//conditions:default": [":tensors_to_detections_calculator_gpu_deps"], + }), + alwayslink = 1, +) + +cc_library( + name = "tensors_to_detections_calculator_gpu_deps", + deps = select({ + "//mediapipe:ios": [ + "//mediapipe/gpu:MPPMetalUtil", + "//mediapipe/gpu:MPPMetalHelper", + ], + "//mediapipe:macos": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + ], + }), +) + +mediapipe_proto_library( + name = "tensors_to_landmarks_calculator_proto", + srcs = ["tensors_to_landmarks_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "tensors_to_landmarks_calculator", + srcs = ["tensors_to_landmarks_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":tensors_to_landmarks_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +mediapipe_proto_library( + name = "tensors_to_floats_calculator_proto", + srcs = ["tensors_to_floats_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "tensors_to_floats_calculator", + srcs = ["tensors_to_floats_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":tensors_to_floats_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", 
+ ], + alwayslink = 1, +) + +cc_test( + name = "tensors_to_floats_calculator_test", + srcs = ["tensors_to_floats_calculator_test.cc"], + deps = [ + ":tensors_to_floats_calculator", + ":tensors_to_floats_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tensors_to_classification_calculator", + srcs = ["tensors_to_classification_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":tensors_to_classification_calculator_cc_proto", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/formats:tensor", + "//mediapipe/util:resource_util", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/util/android/file/base", + ], + "//mediapipe:ios": [ + "//mediapipe/util/android/file/base", + ], + "//conditions:default": [ + "//mediapipe/framework/port:file_helpers", + ], + }), + alwayslink = 1, +) + +mediapipe_proto_library( + name = "tensors_to_classification_calculator_proto", + srcs = ["tensors_to_classification_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_test( + name = "tensors_to_classification_calculator_test", + srcs = ["tensors_to_classification_calculator_test.cc"], + data = ["testdata/labelmap.txt"], + deps = [ + ":tensors_to_classification_calculator", + ":tensors_to_classification_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "image_to_tensor_calculator", + srcs = ["image_to_tensor_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + features = ["-layering_check"], # allow depending on image_to_tensor_calculator_gpu_deps + visibility = ["//visibility:public"], + deps = [ + ":image_to_tensor_calculator_cc_proto", + ":image_to_tensor_converter", + ":image_to_tensor_converter_opencv", + ":image_to_tensor_utils", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + 
"//mediapipe/framework:calculator_framework", + "//mediapipe/framework:port", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [":image_to_tensor_calculator_gpu_deps"], + }), + alwayslink = 1, +) + +cc_library( + name = "image_to_tensor_calculator_gpu_deps", + deps = select({ + "//mediapipe:android": [ + ":image_to_tensor_converter_gl_buffer", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + ], + "//mediapipe:apple": [ + ":image_to_tensor_converter_metal", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:MPPMetalHelper", + "//mediapipe/gpu:gpu_buffer", + ], + "//conditions:default": [ + ":image_to_tensor_converter_gl_buffer", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + ], + }), +) + +mediapipe_proto_library( + name = "image_to_tensor_calculator_proto", + srcs = ["image_to_tensor_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_test( + name = "image_to_tensor_calculator_test", + srcs = ["image_to_tensor_calculator_test.cc"], + data = [ + "testdata/image_to_tensor/input.jpg", + "testdata/image_to_tensor/large_sub_rect.png", + "testdata/image_to_tensor/large_sub_rect_border_zero.png", + "testdata/image_to_tensor/large_sub_rect_keep_aspect.png", + "testdata/image_to_tensor/large_sub_rect_keep_aspect_border_zero.png", + "testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation.png", + "testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation_border_zero.png", + "testdata/image_to_tensor/medium_sub_rect_keep_aspect.png", + "testdata/image_to_tensor/medium_sub_rect_keep_aspect_border_zero.png", + "testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation.png", + "testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png", + "testdata/image_to_tensor/medium_sub_rect_with_rotation.png", + "testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png", + "testdata/image_to_tensor/noop_except_range.png", + ], + deps = [ + ":image_to_tensor_calculator", + ":image_to_tensor_converter", + ":image_to_tensor_utils", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "image_to_tensor_converter", + hdrs = ["image_to_tensor_converter.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + deps = [ + ":image_to_tensor_utils", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:statusor", + ], +) + +cc_library( + name = "image_to_tensor_converter_opencv", + srcs = 
["image_to_tensor_converter_opencv.cc"], + hdrs = ["image_to_tensor_converter_opencv.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + deps = [ + ":image_to_tensor_converter", + ":image_to_tensor_utils", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + ], +) + +cc_library( + name = "image_to_tensor_converter_gl_buffer", + srcs = ["image_to_tensor_converter_gl_buffer.cc"], + hdrs = ["image_to_tensor_converter_gl_buffer.h"], + deps = ["//mediapipe/framework:port"] + select({ + "//mediapipe:apple": [], + "//conditions:default": [ + ":image_to_tensor_converter", + ":image_to_tensor_converter_gl_utils", + ":image_to_tensor_utils", + "@com_google_absl//absl/strings", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/framework/formats:image", + "//mediapipe/gpu:gpu_buffer_format", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:types", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:command_queue", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_call", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:request_gpu_info", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:variable", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util", + ], + }), +) + +cc_library( + name = "image_to_tensor_converter_gl_texture", + srcs = ["image_to_tensor_converter_gl_texture.cc"], + hdrs = ["image_to_tensor_converter_gl_texture.h"], + deps = ["//mediapipe/framework:port"] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + ":image_to_tensor_converter", + ":image_to_tensor_converter_gl_utils", + ":image_to_tensor_utils", + "@com_google_absl//absl/strings", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/framework/formats:image", + "//mediapipe/gpu:shader_util", + ], + }), +) + +cc_library( + name = "image_to_tensor_converter_gl_utils", + srcs = ["image_to_tensor_converter_gl_utils.cc"], + hdrs = ["image_to_tensor_converter_gl_utils.h"], + deps = ["//mediapipe/framework:port"] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_base", + "//mediapipe/gpu:gl_context", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + ], + }), +) + +cc_library( + name = "image_to_tensor_converter_metal", + srcs = ["image_to_tensor_converter_metal.cc"], + hdrs = ["image_to_tensor_converter_metal.h"], + copts = select({ + 
"//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + linkopts = select({ + "//mediapipe:apple": [ + "-framework CoreVideo", + "-framework MetalKit", + ], + "//conditions:default": [], + }), + deps = ["//mediapipe/framework:port"] + select({ + "//mediapipe:apple": [ + ":image_to_tensor_converter", + ":image_to_tensor_utils", + "//mediapipe/gpu:MPPMetalHelper", + "@com_google_absl//absl/strings", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/framework/formats:image", + "//mediapipe/gpu:gpu_buffer_format", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:types", + ], + "//conditions:default": [], + }), +) + +cc_library( + name = "image_to_tensor_utils", + srcs = ["image_to_tensor_utils.cc"], + hdrs = ["image_to_tensor_utils.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:statusor", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "image_to_tensor_utils_test", + srcs = ["image_to_tensor_utils_test.cc"], + deps = [ + ":image_to_tensor_utils", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest_main", + ], +) diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc new file mode 100644 index 000000000..91eba2de5 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -0,0 +1,322 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_buffer.h" + +#if MEDIAPIPE_METAL_ENABLED +#include "mediapipe/calculators/tensor/image_to_tensor_converter_metal.h" +#include "mediapipe/gpu/MPPMetalHelper.h" +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#else +#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#endif // MEDIAPIPE_METAL_ENABLED + +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { +namespace api2 { + +#if MEDIAPIPE_DISABLE_GPU +// Just a placeholder to not have to depend on mediapipe::GpuBuffer. +using GpuBuffer = AnyType; +#else +using GpuBuffer = mediapipe::GpuBuffer; +#endif // MEDIAPIPE_DISABLE_GPU + +// Converts image into Tensor, possibly with cropping, resizing and +// normalization, according to specified inputs and options. +// +// Inputs: +// IMAGE - Image[ImageFormat::SRGB / SRGBA, GpuBufferFormat::kBGRA32] or +// ImageFrame [ImageFormat::SRGB/SRGBA] (for backward compatibility +// with existing graphs that use IMAGE for ImageFrame input) +// IMAGE_GPU - GpuBuffer [GpuBufferFormat::kBGRA32] +// Image to extract from. +// +// Note: +// - One and only one of IMAGE and IMAGE_GPU should be specified. +// - IMAGE input of type Image is processed on GPU if the data is already on +// GPU (i.e., Image::UsesGpu() returns true), or otherwise processed on CPU. +// - IMAGE input of type ImageFrame is always processed on CPU. +// - IMAGE_GPU input (of type GpuBuffer) is always processed on GPU. +// +// NORM_RECT - NormalizedRect @Optional +// Describes region of image to extract. +// @Optional: rect covering the whole image is used if not specified. +// +// Outputs: +// TENSORS - std::vector +// Vector containing a single Tensor populated with an extrated RGB image. +// MATRIX - std::array @Optional +// An std::array representing a 4x4 row-major-order matrix which +// can be used to map a point on the output tensor to a point on the input +// image. +// LETTERBOX_PADDING - std::array @Optional +// An std::array representing the letterbox padding from the 4 +// sides ([left, top, right, bottom]) of the output image, normalized to +// [0.f, 1.f] by the output dimensions. The padding values are non-zero only +// when the "keep_aspect_ratio" is true. 
+// +// For instance, when the input image is 10x10 (width x height) and the +// output dimensions specified in the calculator option are 20x40 and +// "keep_aspect_ratio" is true, the calculator scales the input image to +// 20x20 and places it in the middle of the output image with an equal +// padding of 10 pixels at the top and the bottom. The resulting array is +// therefore [0.f, 0.25f, 0.f, 0.25f] (10/40 = 0.25f). +// +// Example: +// node { +// calculator: "ImageToTensorCalculator" +// input_stream: "IMAGE:image" # or "IMAGE_GPU:image" +// input_stream: "NORM_RECT:roi" +// output_stream: "TENSORS:tensors" +// output_stream: "MATRIX:matrix" +// options { +// [mediapipe.ImageToTensorCalculatorOptions.ext] { +// output_tensor_width: 256 +// output_tensor_height: 256 +// keep_aspect_ratio: false +// output_tensor_float_range { +// min: 0.0 +// max: 1.0 +// } +// # gpu_origin: CONVENTIONAL # or TOP_LEFT +// } +// } +// } +class ImageToTensorCalculator : public Node { + public: + static constexpr Input< + OneOf>::Optional kIn{"IMAGE"}; + static constexpr Input::Optional kInGpu{"IMAGE_GPU"}; + static constexpr Input::Optional kInNormRect{ + "NORM_RECT"}; + static constexpr Output> kOutTensors{"TENSORS"}; + static constexpr Output>::Optional kOutLetterboxPadding{ + "LETTERBOX_PADDING"}; + static constexpr Output>::Optional kOutMatrix{"MATRIX"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kInGpu, kInNormRect, kOutTensors, + kOutLetterboxPadding, kOutMatrix); + + static absl::Status UpdateContract(CalculatorContract* cc) { + const auto& options = + cc->Options(); + + RET_CHECK(options.has_output_tensor_float_range()) + << "Output tensor range is required."; + RET_CHECK_LT(options.output_tensor_float_range().min(), + options.output_tensor_float_range().max()) + << "Valid output tensor range is required."; + RET_CHECK_GT(options.output_tensor_width(), 0) + << "Valid output tensor width is required."; + RET_CHECK_GT(options.output_tensor_height(), 0) + << "Valid output tensor height is required."; + + RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected()) + << "One and only one of IMAGE and IMAGE_GPU input is expected."; + +#if MEDIAPIPE_DISABLE_GPU + if (kInGpu(cc).IsConnected()) { + return absl::UnimplementedError( + "GPU processing is disabled in build flags"); + } +#else // !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_METAL_ENABLED + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); +#else + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // MEDIAPIPE_METAL_ENABLED +#endif // MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) { + options_ = cc->Options(); + output_width_ = options_.output_tensor_width(); + output_height_ = options_.output_tensor_height(); + range_min_ = options_.output_tensor_float_range().min(); + range_max_ = options_.output_tensor_float_range().max(); + + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) { + if ((kIn(cc).IsConnected() && kIn(cc).IsEmpty()) || + (kInGpu(cc).IsConnected() && kInGpu(cc).IsEmpty())) { + // Timestamp bound update happens automatically. + return absl::OkStatus(); + } + + absl::optional norm_rect; + if (kInNormRect(cc).IsConnected()) { + if (kInNormRect(cc).IsEmpty()) { + // Timestamp bound update happens automatically. (See Open().) 
+ return absl::OkStatus(); + } + norm_rect = *kInNormRect(cc); + if (norm_rect->width() == 0 && norm_rect->height() == 0) { + // WORKAROUND: some existing graphs may use sentinel rects {width=0, + // height=0, ...} quite often and calculator has to handle them + // gracefully by updating timestamp bound instead of returning failure. + // Timestamp bound update happens automatically. (See Open().) + // NOTE: usage of sentinel rects should be avoided. + DLOG(WARNING) + << "Updating timestamp bound in response to a sentinel rect"; + return absl::OkStatus(); + } + } + + ASSIGN_OR_RETURN(auto image, GetInputImage(cc)); + const Size size{image->width(), image->height()}; + RotatedRect roi = GetRoi(size.width, size.height, norm_rect); + ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(), + options_.output_tensor_height(), + options_.keep_aspect_ratio(), &roi)); + if (kOutLetterboxPadding(cc).IsConnected()) { + kOutLetterboxPadding(cc).Send(padding); + } + if (kOutMatrix(cc).IsConnected()) { + std::array matrix; + GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height, + /*flip_horizontaly=*/false, + &matrix); + kOutMatrix(cc).Send(std::move(matrix)); + } + + // Lazy initialization of the GPU or CPU converter. + MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, image->UsesGpu())); + + ASSIGN_OR_RETURN(Tensor tensor, + (image->UsesGpu() ? gpu_converter_ : cpu_converter_) + ->Convert(*image, roi, {output_width_, output_height_}, + range_min_, range_max_)); + + auto result = std::make_unique>(); + result->push_back(std::move(tensor)); + kOutTensors(cc).Send(std::move(result)); + + return absl::OkStatus(); + } + + private: + bool DoesInputStartAtBottom() { + return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; + } + + BorderMode GetBorderMode() { + switch (options_.border_mode()) { + case mediapipe:: + ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED: + return BorderMode::kReplicate; + case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO: + return BorderMode::kZero; + case mediapipe:: + ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE: + return BorderMode::kReplicate; + } + } + + absl::StatusOr> GetInputImage( + CalculatorContext* cc) { + if (kIn(cc).IsConnected()) { + const auto& packet = kIn(cc).packet(); + return kIn(cc).Visit( + [&packet](const mediapipe::Image&) { + return SharedPtrWithPacket(packet); + }, + [&packet](const mediapipe::ImageFrame&) { + return std::make_shared( + std::const_pointer_cast( + SharedPtrWithPacket(packet))); + }); + } else { // if (kInGpu(cc).IsConnected()) +#if !MEDIAPIPE_DISABLE_GPU + const GpuBuffer& input = *kInGpu(cc); + // A shallow copy is okay since the resulting 'image' object is local in + // Process(), and thus never outlives 'input'. + return std::make_shared(input); +#else + return absl::UnimplementedError( + "GPU processing is disabled in build flags"); +#endif // !MEDIAPIPE_DISABLE_GPU + } + } + + absl::Status InitConverterIfNecessary(CalculatorContext* cc, bool use_gpu) { + // Lazy initialization of the GPU or CPU converter. 
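    // Only the converter matching where the input pixels live is created, and
    // only on first use: GPU inputs get the Metal or GL (buffer/texture)
    // converter picked by the preprocessor branches below, CPU inputs get the
    // OpenCV converter. A graph that only ever feeds CPU frames therefore
    // never touches the GPU path, and vice versa.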
+ if (use_gpu) { + if (!gpu_converter_) { +#if !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_METAL_ENABLED + ASSIGN_OR_RETURN(gpu_converter_, + CreateMetalConverter(cc, GetBorderMode())); +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + ASSIGN_OR_RETURN(gpu_converter_, + CreateImageToGlBufferTensorConverter( + cc, DoesInputStartAtBottom(), GetBorderMode())); +#else + ASSIGN_OR_RETURN(gpu_converter_, + CreateImageToGlTextureTensorConverter( + cc, DoesInputStartAtBottom(), GetBorderMode())); +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU + } + } else { + if (!cpu_converter_) { + ASSIGN_OR_RETURN(cpu_converter_, + CreateOpenCvConverter(cc, GetBorderMode())); + } + } + return absl::OkStatus(); + } + + std::unique_ptr gpu_converter_; + std::unique_ptr cpu_converter_; + mediapipe::ImageToTensorCalculatorOptions options_; + int output_width_ = 0; + int output_height_ = 0; + float range_min_ = 0.0f; + float range_max_ = 1.0f; +}; + +MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.proto b/mediapipe/calculators/tensor/image_to_tensor_calculator.proto new file mode 100644 index 000000000..77fb1eb46 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.proto @@ -0,0 +1,79 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message GpuOrigin { + enum Mode { + DEFAULT = 0; + + // OpenGL: bottom-left origin + // Metal : top-left origin + CONVENTIONAL = 1; + + // OpenGL: top-left origin + // Metal : top-left origin + TOP_LEFT = 2; + } +} + +message ImageToTensorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional ImageToTensorCalculatorOptions ext = 334361939; + } + + // Range of float values [min, max]. + // min, must be strictly less than max. + message FloatRange { + optional float min = 1; + optional float max = 2; + } + + // Pixel extrapolation methods. See @border_mode. + enum BorderMode { + BORDER_UNSPECIFIED = 0; + BORDER_ZERO = 1; + BORDER_REPLICATE = 2; + } + + optional int32 output_tensor_width = 1; + optional int32 output_tensor_height = 2; + + // If true, image region will be extracted and copied into tensor keeping + // region aspect ratio, which usually results in letterbox padding. Otherwise, + // if false, image region is stretched to fill output tensor fully. + optional bool keep_aspect_ratio = 3; + + // Output tensor element range/type image pixels are converted to. + oneof range { + FloatRange output_tensor_float_range = 4; + } + + // For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs + // to be flipped vertically as tensors are expected to start at top. + // (DEFAULT or unset interpreted as CONVENTIONAL.) + optional GpuOrigin.Mode gpu_origin = 5; + + // Pixel extrapolation method. 
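  // (BORDER_ZERO pads with zeros; BORDER_REPLICATE clamps to the nearest edge
  // pixel. See also BorderMode in image_to_tensor_converter.h.)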
+ // When converting image to tensor it may happen that tensor needs to read + // pixels outside image boundaries. Border mode helps to specify how such + // pixels will be calculated. + // + // BORDER_REPLICATE is used by default. + optional BorderMode border_mode = 6; +} diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc new file mode 100644 index 000000000..233424720 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -0,0 +1,437 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/commandlineflags.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +cv::Mat GetRgb(absl::string_view path) { + cv::Mat bgr = cv::imread(file::JoinPath("./", path)); + cv::Mat rgb; + cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB); + return rgb; +} + +cv::Mat GetRgba(absl::string_view path) { + cv::Mat bgr = cv::imread(file::JoinPath("./", path)); + cv::Mat rgb; + cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA); + return rgb; +} + +// Image to tensor test template. +// No processing/assertions should be done after the function is invoked. 
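// Each test below follows the same flow: build a small graph from a text
// proto (with Substitute() filling in the tensor size, float range, aspect
// and border-mode options), push one image packet and one NORM_RECT packet,
// read back the single output tensor, map it from [range_min, range_max]
// back to [0, 255] and compare against a golden PNG with a maximum absolute
// pixel difference of at most 5. For example, with range [-1, 1] the inverse
// mapping used below is pixel = tensor_value * 127.5 + 127.5; with range
// [0, 1] it is pixel = tensor_value * 255. RunTest() additionally repeats
// every case with both ImageFrame and Image input packets.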
+void RunTestWithInputImagePacket(const Packet& input_image_packet, + cv::Mat expected_result, float range_min, + float range_max, int tensor_width, + int tensor_height, bool keep_aspect, + absl::optional border_mode, + const mediapipe::NormalizedRect& roi) { + std::string border_mode_str; + if (border_mode) { + switch (*border_mode) { + case BorderMode::kReplicate: + border_mode_str = "border_mode: BORDER_REPLICATE"; + break; + case BorderMode::kZero: + border_mode_str = "border_mode: BORDER_ZERO"; + break; + } + } + auto graph_config = mediapipe::ParseTextProtoOrDie( + absl::Substitute(R"( + input_stream: "input_image" + input_stream: "roi" + node { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:input_image" + input_stream: "NORM_RECT:roi" + output_stream: "TENSORS:tensor" + options { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_width: $0 + output_tensor_height: $1 + keep_aspect_ratio: $4 + output_tensor_float_range { + min: $2 + max: $3 + } + $5 # border mode + } + } + } + )", + /*$0=*/tensor_width, + /*$1=*/tensor_height, + /*$2=*/range_min, + /*$3=*/range_max, + /*$4=*/keep_aspect ? "true" : "false", + /*$5=*/border_mode_str)); + + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream("input_image", input_image_packet)); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "roi", + MakePacket(std::move(roi)).At(Timestamp(0)))); + + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_THAT(output_packets, testing::SizeIs(1)); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + ASSERT_THAT(tensor_vec, testing::SizeIs(1)); + + const Tensor& tensor = tensor_vec[0]; + EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32); + + auto view = tensor.GetCpuReadView(); + cv::Mat tensor_mat(tensor_height, tensor_width, CV_32FC3, + const_cast(view.buffer())); + cv::Mat result_rgb; + auto transformation = + GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value(); + tensor_mat.convertTo(result_rgb, CV_8UC3, transformation.scale, + transformation.offset); + + cv::Mat diff; + cv::absdiff(result_rgb, expected_result, diff); + double max_val; + cv::minMaxLoc(diff, nullptr, &max_val); + // Expects the maximum absolute pixel-by-pixel difference is less than 5. + EXPECT_LE(max_val, 5); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("input_image")); + MP_ASSERT_OK(graph.CloseInputStream("roi")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +Packet MakeImageFramePacket(cv::Mat input) { + ImageFrame input_image( + input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB, + input.cols, input.rows, input.step, input.data, [](uint8*) {}); + return MakePacket(std::move(input_image)).At(Timestamp(0)); +} + +Packet MakeImagePacket(cv::Mat input) { + mediapipe::Image input_image(std::make_shared( + input.channels() == 4 ? 
ImageFormat::SRGBA : ImageFormat::SRGB, + input.cols, input.rows, input.step, input.data, [](uint8*) {})); + return MakePacket(std::move(input_image)).At(Timestamp(0)); +} + +enum class InputType { kImageFrame, kImage }; + +const std::vector kInputTypesToTest = {InputType::kImageFrame, + InputType::kImage}; + +void RunTest(cv::Mat input, cv::Mat expected_result, float range_min, + float range_max, int tensor_width, int tensor_height, + bool keep_aspect, absl::optional border_mode, + const mediapipe::NormalizedRect& roi) { + for (auto input_type : kInputTypesToTest) { + RunTestWithInputImagePacket( + input_type == InputType::kImageFrame ? MakeImageFramePacket(input) + : MakeImagePacket(input), + expected_result, range_min, range_max, tensor_width, tensor_height, + keep_aspect, border_mode, roi); + } +} + +TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(0); + RunTest( + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + /*border mode*/ {}, roi); +} + +TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(0); + RunTest(GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_border_zero.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + BorderMode::kZero, roi); +} + +TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * 90.0f / 180.0f); + RunTest(GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_with_rotation.png"), + /*range_min=*/0.0f, /*range_max=*/1.0f, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + BorderMode::kReplicate, roi); +} + +TEST(ImageToTensorCalculatorTest, + MediumSubRectKeepAspectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * 90.0f / 180.0f); + RunTest(GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), + /*range_min=*/0.0f, /*range_max=*/1.0f, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + BorderMode::kZero, roi); +} + +TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * -45.0f / 180.0f); + RunTest( + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb( 
+ "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), + /*range_min=*/-1.0f, + /*range_max=*/1.0f, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, + BorderMode::kReplicate, roi); +} + +TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * -45.0f / 180.0f); + RunTest(GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_with_rotation_border_zero.png"), + /*range_min=*/-1.0f, + /*range_max=*/1.0f, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, + BorderMode::kZero, roi); +} + +TEST(ImageToTensorCalculatorTest, LargeSubRect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + RunTest(GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, + BorderMode::kReplicate, roi); +} + +TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + RunTest( + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, + BorderMode::kZero, roi); +} + +TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + RunTest( + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kReplicate, roi); +} + +TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + RunTest(GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_border_zero.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kZero, roi); +} + +TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetRgba("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + 
"large_sub_rect_keep_aspect_with_rotation.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/{}, roi); +} + +TEST(ImageToTensorCalculatorTest, + LargeSubRectKeepAspectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetRgba("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/BorderMode::kZero, roi); +} + +TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.0f); + roi.set_height(1.0f); + roi.set_rotation(0); + RunTest(GetRgba("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/noop_except_range.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kReplicate, roi); +} + +TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.0f); + roi.set_height(1.0f); + roi.set_rotation(0); + RunTest(GetRgba("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetRgb("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/noop_except_range.png"), + /*range_min=*/0.0f, + /*range_max=*/1.0f, + /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kZero, roi); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter.h b/mediapipe/calculators/tensor/image_to_tensor_converter.h new file mode 100644 index 000000000..39fd1ee0d --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter.h @@ -0,0 +1,56 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_H_ + +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +struct Size { + int width; + int height; +}; + +// Pixel extrapolation method. +// When converting image to tensor it may happen that tensor needs to read +// pixels outside image boundaries. Border mode helps to specify how such pixels +// will be calculated. 
+enum class BorderMode { kZero, kReplicate }; + +// Converts image to tensor. +class ImageToTensorConverter { + public: + virtual ~ImageToTensorConverter() = default; + + // Converts image to tensor. + // @image contains image to extract from. + // @roi describes region of interest within the image to extract (absolute + // values). + // @output_dims dimensions of output tensor. + // @range_min/max describes output tensor range image pixels should converted + // to. + virtual absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, + float range_min, float range_max) = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc new file mode 100644 index 000000000..c6c9a19f4 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -0,0 +1,347 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h" + +#include "mediapipe/framework/port.h" + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/command_queue.h" +#include "tensorflow/lite/delegates/gpu/gl/converters/util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" +#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/variable.h" + +namespace mediapipe { + +namespace { + +// Implements a common pattern of extracting a subrect from RGBA input texture +// and resizing it into a buffer. +class SubRectExtractorGl { + public: + // Extracts a region defined by @sub_rect, removes A channel, transforms input + // pixels as alpha * x + beta and resizes result into destination. 
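  // Here x is the sampled texture value in [0, 1], so mapping pixels into an
  // output range [range_min, range_max] amounts to
  //   alpha = range_max - range_min,  beta = range_min
  // (for example alpha = 2, beta = -1 for a [-1, 1] tensor). This is only an
  // illustration of the parameters; the caller supplies the actual values.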
+ absl::Status ExtractSubRectToBuffer( + const tflite::gpu::gl::GlTexture& texture, + const tflite::gpu::HW& texture_size, const RotatedRect& sub_rect, + bool flip_horizontaly, float alpha, float beta, + const tflite::gpu::HW& destination_size, + tflite::gpu::gl::CommandQueue* command_queue, + tflite::gpu::gl::GlBuffer* destination); + + static absl::StatusOr Create( + const mediapipe::GlContext& gl_context, bool input_starts_at_bottom, + BorderMode border_mode); + + private: + explicit SubRectExtractorGl(tflite::gpu::gl::GlProgram program, + tflite::gpu::uint3 workgroup_size, + bool use_custom_zero_border, + BorderMode border_mode) + : program_(std::move(program)), + workgroup_size_(workgroup_size), + use_custom_zero_border_(use_custom_zero_border), + border_mode_(border_mode) {} + + tflite::gpu::gl::GlProgram program_; + tflite::gpu::uint3 workgroup_size_; + bool use_custom_zero_border_ = false; + BorderMode border_mode_ = BorderMode::kReplicate; +}; + +absl::Status SetMat4x4(const tflite::gpu::gl::GlProgram& program, + const std::string& name, float* data) { + GLint uniform_id; + MP_RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetUniformLocation, &uniform_id, + program.id(), name.c_str())); + return TFLITE_GPU_CALL_GL(glProgramUniformMatrix4fv, program.id(), uniform_id, + 1, GL_TRUE, data); +} + +constexpr char kShaderCode[] = R"( +layout(std430) buffer; + +precision highp float; + +// It is possible to use "vec3 elements[];" here, however due to alignment +// requirements it works only when "packed" layout is used. "packed" layout is +// determined by implementation and it's expected that OpenGL API is used to +// query the layout. Favoring float array over vec3, considering performance is +// comparable, layout is the same and no need for layout querying (even though +// it's not quite needed here as there's only one member). +layout(binding = 0) writeonly buffer B0 { + float elements[]; +} output_data; + +uniform ivec2 out_size; +uniform float alpha; +uniform float beta; +uniform mat4 transform_matrix; +uniform mediump sampler2D input_data; + +void main() { + int out_width = out_size.x; + int out_height = out_size.y; + + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= out_width || gid.y >= out_height) { + return; + } + + // transform from image.width, image.height range to [0, 1] + float normal_x = (float(gid.x) + 0.5f) / float(out_width); + float normal_y = (float(gid.y) + 0.5f) / float(out_height); + vec4 tc = vec4(normal_x, normal_y, 0.0, 1.0); + + // Apply transformation from roi coordinates to original image coordinates. + tc = transform_matrix * tc; +#ifdef INPUT_STARTS_AT_BOTTOM + // Opengl texture sampler has origin in lower left corner, + // so we invert y coordinate. + tc.y = 1.0f - tc.y; +#endif // INPUT_STARTS_AT_BOTTOM + vec4 src_value = alpha * texture(input_data, tc.xy) + beta; + +#ifdef CUSTOM_ZERO_BORDER_MODE + float out_of_bounds = + float(tc.x < 0.0 || tc.x > 1.0 || tc.y < 0.0 || tc.y > 1.0); + src_value = mix(src_value, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); +#endif + + int linear_index = gid.y * out_width + gid.x; + + // output_data.elements is populated as though it contains vec3 elements. 
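+  // E.g. with out_width = 128, the invocation at gid = (2, 1) has
+  // linear_index = 1 * 128 + 2 = 130 and writes its RGB into
+  // elements[390..392].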
+ int first_component_index = 3 * linear_index; + output_data.elements[first_component_index] = src_value.r; + output_data.elements[first_component_index + 1] = src_value.g; + output_data.elements[first_component_index + 2] = src_value.b; +} +)"; + +absl::Status SubRectExtractorGl::ExtractSubRectToBuffer( + const tflite::gpu::gl::GlTexture& texture, + const tflite::gpu::HW& texture_size, const RotatedRect& texture_sub_rect, + bool flip_horizontaly, float alpha, float beta, + const tflite::gpu::HW& destination_size, + tflite::gpu::gl::CommandQueue* command_queue, + tflite::gpu::gl::GlBuffer* destination) { + std::array transform_mat; + GetRotatedSubRectToRectTransformMatrix(texture_sub_rect, texture_size.w, + texture_size.h, flip_horizontaly, + &transform_mat); + MP_RETURN_IF_ERROR(texture.BindAsSampler2D(0)); + + // a) Filtering. + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + + // b) Clamping. + switch (border_mode_) { + case BorderMode::kReplicate: { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + break; + } + case BorderMode::kZero: { + if (!use_custom_zero_border_) { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_BORDER); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_BORDER); + glTexParameterfv(GL_TEXTURE_2D, GL_TEXTURE_BORDER_COLOR, + std::array{0.0f, 0.0f, 0.0f, 0.0f}.data()); + } + break; + } + } + + MP_RETURN_IF_ERROR(destination->BindToIndex(0)); + MP_RETURN_IF_ERROR(program_.SetParameter({"input_data", 0})); + MP_RETURN_IF_ERROR( + SetMat4x4(program_, "transform_matrix", transform_mat.data())); + MP_RETURN_IF_ERROR(program_.SetParameter( + {"out_size", tflite::gpu::int2(destination_size.w, destination_size.h)})); + MP_RETURN_IF_ERROR(program_.SetParameter({"alpha", alpha})); + MP_RETURN_IF_ERROR(program_.SetParameter({"beta", beta})); + tflite::gpu::uint3 num_workgroups = tflite::gpu::DivideRoundUp( + tflite::gpu::uint3{destination_size.w, destination_size.h, 1}, + workgroup_size_); + MP_RETURN_IF_ERROR(command_queue->Dispatch(program_, num_workgroups)); + + // Resetting to MediaPipe texture param defaults. 
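+  // (The wrap mode may have been set to GL_CLAMP_TO_BORDER above; the input
+  // texture can be shared with other calculators, so the parameters are
+  // restored to the defaults MediaPipe assumes elsewhere.)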
+ glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + + return absl::OkStatus(); +} + +absl::StatusOr SubRectExtractorGl::Create( + const mediapipe::GlContext& gl_context, bool input_starts_at_bottom, + BorderMode border_mode) { + bool use_custom_zero_border = border_mode == BorderMode::kZero && + !IsGlClampToBorderSupported(gl_context); + + const tflite::gpu::uint3 workgroup_size = {8, 8, 1}; + std::string starts_at_bottom_def; + if (input_starts_at_bottom) { + starts_at_bottom_def = R"( + #define INPUT_STARTS_AT_BOTTOM; + )"; + } + std::string custom_zero_border_mode_def; + if (use_custom_zero_border) { + custom_zero_border_mode_def = R"( + #define CUSTOM_ZERO_BORDER_MODE + )"; + } + const std::string full_shader_source = absl::StrCat( + tflite::gpu::gl::GetShaderHeader(workgroup_size), starts_at_bottom_def, + custom_zero_border_mode_def, kShaderCode); + + tflite::gpu::gl::GlShader shader; + MP_RETURN_IF_ERROR(tflite::gpu::gl::GlShader::CompileShader( + GL_COMPUTE_SHADER, full_shader_source, &shader)); + tflite::gpu::gl::GlProgram program; + MP_RETURN_IF_ERROR( + tflite::gpu::gl::GlProgram::CreateWithShader(shader, &program)); + + return SubRectExtractorGl(std::move(program), workgroup_size, + use_custom_zero_border, border_mode); +} + +class GlProcessor : public ImageToTensorConverter { + public: + absl::Status Init(CalculatorContext* cc, bool input_starts_at_bottom, + BorderMode border_mode) { + MP_RETURN_IF_ERROR(gl_helper_.Open(cc)); + return gl_helper_.RunInGlContext([this, input_starts_at_bottom, + border_mode]() -> absl::Status { + tflite::gpu::GpuInfo gpu_info; + MP_RETURN_IF_ERROR(tflite::gpu::gl::RequestGpuInfo(&gpu_info)); + RET_CHECK(gpu_info.IsApiOpenGl31OrAbove()) + << "OpenGL ES 3.1 is required."; + command_queue_ = tflite::gpu::gl::NewCommandQueue(gpu_info); + + ASSIGN_OR_RETURN( + auto extractor, + SubRectExtractorGl::Create(gl_helper_.GetGlContext(), + input_starts_at_bottom, border_mode)); + extractor_ = absl::make_unique(std::move(extractor)); + return absl::OkStatus(); + }); + } + + absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, float range_min, + float range_max) override { + if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { + return InvalidArgumentError( + absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ", + static_cast(input.format()))); + } + + constexpr int kNumChannels = 3; + Tensor tensor(Tensor::ElementType::kFloat32, + {1, output_dims.height, output_dims.width, kNumChannels}); + + MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi, + &output_dims, range_min, + range_max]() -> absl::Status { + constexpr int kRgbaNumChannels = 4; + auto source_texture = gl_helper_.CreateSourceTexture(input); + tflite::gpu::gl::GlTexture input_texture( + GL_TEXTURE_2D, source_texture.name(), GL_RGBA, + source_texture.width() * source_texture.height() * kRgbaNumChannels * + sizeof(uint8_t), + /*layer=*/0, + /*owned=*/false); + + constexpr float kInputImageRangeMin = 0.0f; + constexpr float kInputImageRangeMax = 1.0f; + ASSIGN_OR_RETURN( + auto transform, + GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, + range_min, range_max)); + + auto buffer_view = tensor.GetOpenGlBufferWriteView(); + tflite::gpu::gl::GlBuffer 
output(GL_SHADER_STORAGE_BUFFER, + buffer_view.name(), tensor.bytes(), + /*offset=*/0, + /*has_ownership=*/false); + MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer( + input_texture, + tflite::gpu::HW(source_texture.height(), source_texture.width()), roi, + /*flip_horizontaly=*/false, transform.scale, transform.offset, + tflite::gpu::HW(output_dims.height, output_dims.width), + command_queue_.get(), &output)); + + return absl::OkStatus(); + })); + + return tensor; + } + + ~GlProcessor() override { + gl_helper_.RunInGlContext([this]() { + // Release OpenGL resources. + extractor_ = nullptr; + command_queue_ = nullptr; + }); + } + + private: + std::unique_ptr command_queue_; + std::unique_ptr extractor_; + mediapipe::GlCalculatorHelper gl_helper_; +}; + +} // namespace + +absl::StatusOr> +CreateImageToGlBufferTensorConverter(CalculatorContext* cc, + bool input_starts_at_bottom, + BorderMode border_mode) { + auto result = absl::make_unique(); + MP_RETURN_IF_ERROR(result->Init(cc, input_starts_at_bottom, border_mode)); + + // Simply "return std::move(result)" failed to build on macOS with bazel. + return std::unique_ptr(std::move(result)); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h new file mode 100644 index 000000000..437b16b70 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h @@ -0,0 +1,42 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_BUFFER_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_BUFFER_H_ + +#include "mediapipe/framework/port.h" + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + +#include + +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +// Creates image to tensor (represented as OpenGL buffer) converter. +// NOTE: mediapipe::GlCalculatorHelper::UpdateContract invocation must precede +// converter creation. +absl::StatusOr> +CreateImageToGlBufferTensorConverter(CalculatorContext* cc, + bool input_starts_at_bottom, + BorderMode border_mode); + +} // namespace mediapipe + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_BUFFER_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc new file mode 100644 index 000000000..26c31eaf5 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -0,0 +1,344 @@ +// Copyright 2020 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h" + +#include "mediapipe/framework/port.h" + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" + +namespace mediapipe { + +namespace { + +constexpr int kAttribVertex = 0; +constexpr int kAttribTexturePosition = 1; +constexpr int kNumAttributes = 2; + +class GlProcessor : public ImageToTensorConverter { + public: + absl::Status Init(CalculatorContext* cc, bool input_starts_at_bottom, + BorderMode border_mode) { + MP_RETURN_IF_ERROR(gl_helper_.Open(cc)); + return gl_helper_.RunInGlContext([this, input_starts_at_bottom, + border_mode]() -> absl::Status { + use_custom_zero_border_ = + border_mode == BorderMode::kZero && + !IsGlClampToBorderSupported(gl_helper_.GetGlContext()); + border_mode_ = border_mode; + + const GLint attr_location[kNumAttributes] = { + kAttribVertex, + kAttribTexturePosition, + }; + const GLchar* attr_name[kNumAttributes] = { + "position", + "texture_coordinate", + }; + + constexpr GLchar kExtractSubRectVertexShader[] = R"( + in vec4 position; + in mediump vec4 texture_coordinate; + out mediump vec2 sample_coordinate; + uniform mat4 transform_matrix; + + void main() { + gl_Position = position; + // Apply transformation from roi coordinates to original image coordinates. + vec4 tc = transform_matrix * texture_coordinate; + #ifdef INPUT_STARTS_AT_BOTTOM + // Opengl texture sampler has origin in lower left corner, + // so we invert y coordinate. + tc.y = 1.0 - tc.y; + #endif // defined(INPUT_STARTS_AT_BOTTOM) + sample_coordinate = tc.xy; + } + )"; + + constexpr GLchar kExtractSubRectFragBody[] = R"( + DEFAULT_PRECISION(mediump, float) + + // Provided by kExtractSubRectVertexShader. 
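+          // sample_coordinate is the ROI-mapped texture coordinate computed by
+          // the vertex shader; values outside [0, 1] are out-of-bounds samples
+          // handled by the wrap mode or by CUSTOM_ZERO_BORDER_MODE below.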
+ in vec2 sample_coordinate; + + uniform sampler2D input_texture; + uniform float alpha; + uniform float beta; + + #ifdef GL_ES + #define fragColor gl_FragColor + #else + out vec4 fragColor; + #endif // defined(GL_ES); + + void main() { + vec4 color = texture2D(input_texture, sample_coordinate); + #ifdef CUSTOM_ZERO_BORDER_MODE + float out_of_bounds = + float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || + sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0); + color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); + #endif // defined(CUSTOM_ZERO_BORDER_MODE) + fragColor = alpha * color + beta; + } + )"; + + std::string starts_at_bottom_def; + if (input_starts_at_bottom) { + starts_at_bottom_def = R"( + #define INPUT_STARTS_AT_BOTTOM + )"; + } + + // Create program and set parameters. + const std::string extract_sub_rect_vertex_src = + absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, + starts_at_bottom_def, kExtractSubRectVertexShader); + + std::string custom_zero_border_mode_def; + if (use_custom_zero_border_) { + custom_zero_border_mode_def = R"( + #define CUSTOM_ZERO_BORDER_MODE + )"; + } + const std::string extract_sub_rect_frag_src = + absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, + custom_zero_border_mode_def, kExtractSubRectFragBody); + mediapipe::GlhCreateProgram(extract_sub_rect_vertex_src.c_str(), + extract_sub_rect_frag_src.c_str(), + kNumAttributes, &attr_name[0], attr_location, + &program_); + + RET_CHECK(program_) << "Problem initializing image to tensor program."; + glUseProgram(program_); + glUniform1i(glGetUniformLocation(program_, "input_texture"), 1); + alpha_id_ = glGetUniformLocation(program_, "alpha"); + beta_id_ = glGetUniformLocation(program_, "beta"); + matrix_id_ = glGetUniformLocation(program_, "transform_matrix"); + + glGenFramebuffers(1, &framebuffer_); + + // vertex storage + glGenBuffers(2, vbo_); + glGenVertexArrays(1, &vao_); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[0]); + glBufferData(GL_ARRAY_BUFFER, sizeof(mediapipe::kBasicSquareVertices), + mediapipe::kBasicSquareVertices, GL_STATIC_DRAW); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[1]); + glBufferData(GL_ARRAY_BUFFER, sizeof(mediapipe::kBasicTextureVertices), + mediapipe::kBasicTextureVertices, GL_STATIC_DRAW); + + glBindBuffer(GL_ARRAY_BUFFER, 0); + + return absl::OkStatus(); + }); + } + + absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, float range_min, + float range_max) override { + if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { + return InvalidArgumentError( + absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ", + static_cast(input.format()))); + } + + constexpr int kNumChannels = 3; + Tensor tensor( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels}); + + MP_RETURN_IF_ERROR( + gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims, + range_min, range_max]() -> absl::Status { + auto input_texture = gl_helper_.CreateSourceTexture(input); + + constexpr float kInputImageRangeMin = 0.0f; + constexpr float kInputImageRangeMax = 1.0f; + ASSIGN_OR_RETURN(auto transform, + GetValueRangeTransformation(kInputImageRangeMin, + kInputImageRangeMax, + range_min, range_max)); + auto tensor_view = tensor.GetOpenGlTexture2dWriteView(); + MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi, + /*flip_horizontaly=*/false, + transform.scale, transform.offset, + output_dims, &tensor_view)); + return 
absl::OkStatus(); + })); + + return tensor; + } + + absl::Status ExtractSubRect(const mediapipe::GlTexture& texture, + const RotatedRect& sub_rect, + bool flip_horizontaly, float alpha, float beta, + const Size& output_dims, + Tensor::OpenGlTexture2dView* output) { + std::array transform_mat; + + glDisable(GL_DEPTH_TEST); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glViewport(0, 0, output_dims.width, output_dims.height); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, output->name()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + output->name(), 0); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(texture.target(), texture.name()); + + // a) Filtering. + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + + // b) Clamping. + switch (border_mode_) { + case BorderMode::kReplicate: { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + break; + } + case BorderMode::kZero: { + if (!use_custom_zero_border_) { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_BORDER); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_BORDER); + glTexParameterfv(GL_TEXTURE_2D, GL_TEXTURE_BORDER_COLOR, + std::array{0.0f, 0.0f, 0.0f, 0.0f}.data()); + } + break; + } + } + + glUseProgram(program_); + glUniform1f(alpha_id_, alpha); + glUniform1f(beta_id_, beta); + + // If our context is ES2, then we must use GL_FALSE for our 'transpose' + // GLboolean in glUniformMatrix4fv, or else we'll get an INVALID_VALUE + // error. So in that case, we'll grab the transpose of our original matrix + // and send that instead. + const auto gl_context = mediapipe::GlContext::GetCurrent(); + LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread."; + if (gl_context->GetGlVersion() == mediapipe::GlVersion::kGLES2) { + GetTransposedRotatedSubRectToRectTransformMatrix( + sub_rect, texture.width(), texture.height(), flip_horizontaly, + &transform_mat); + glUniformMatrix4fv(matrix_id_, 1, GL_FALSE, transform_mat.data()); + } else { + GetRotatedSubRectToRectTransformMatrix(sub_rect, texture.width(), + texture.height(), flip_horizontaly, + &transform_mat); + glUniformMatrix4fv(matrix_id_, 1, GL_TRUE, transform_mat.data()); + } + + // vao + glBindVertexArray(vao_); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[0]); + glEnableVertexAttribArray(kAttribVertex); + glVertexAttribPointer(kAttribVertex, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[1]); + glEnableVertexAttribArray(kAttribTexturePosition); + glVertexAttribPointer(kAttribTexturePosition, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // Resetting to MediaPipe texture param defaults. 
+ glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + + glDisableVertexAttribArray(kAttribVertex); + glDisableVertexAttribArray(kAttribTexturePosition); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, 0); + + return absl::OkStatus(); + } + + ~GlProcessor() override { + gl_helper_.RunInGlContext([this]() { + // Release OpenGL resources. + if (framebuffer_ != 0) glDeleteFramebuffers(1, &framebuffer_); + if (program_ != 0) glDeleteProgram(program_); + if (vao_ != 0) glDeleteVertexArrays(1, &vao_); + glDeleteBuffers(2, vbo_); + }); + } + + private: + mediapipe::GlCalculatorHelper gl_helper_; + bool use_custom_zero_border_ = false; + BorderMode border_mode_ = BorderMode::kReplicate; + GLuint vao_ = 0; + GLuint vbo_[2] = {0, 0}; + GLuint program_ = 0; + GLuint framebuffer_ = 0; + GLint alpha_id_ = 0; + GLint beta_id_ = 0; + GLint matrix_id_ = 0; +}; + +} // namespace + +absl::StatusOr> +CreateImageToGlTextureTensorConverter(CalculatorContext* cc, + bool input_starts_at_bottom, + BorderMode border_mode) { + auto result = absl::make_unique(); + MP_RETURN_IF_ERROR(result->Init(cc, input_starts_at_bottom, border_mode)); + + // Simply "return std::move(result)" failed to build on macOS with bazel. + return std::unique_ptr(std::move(result)); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h new file mode 100644 index 000000000..269abf141 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h @@ -0,0 +1,42 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_TEXTURE_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_TEXTURE_H_ + +#include "mediapipe/framework/port.h" + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 + +#include + +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +// Creates image to tensor (represented as OpenGL texture) converter. +// NOTE: mediapipe::GlCalculatorHelper::UpdateContract invocation must precede +// converter creation. 
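+//
+// Illustrative call sequence (sketch):
+//   GetContract(): MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
+//   Open(cc):      ASSIGN_OR_RETURN(auto converter,
+//                      CreateImageToGlTextureTensorConverter(
+//                          cc, /*input_starts_at_bottom=*/false,
+//                          BorderMode::kReplicate));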
+absl::StatusOr> +CreateImageToGlTextureTensorConverter(CalculatorContext* cc, + bool input_starts_at_bottom, + BorderMode border_mode); + +} // namespace mediapipe + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_TEXTURE_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.cc new file mode 100644 index 000000000..6fb39e0c3 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.cc @@ -0,0 +1,88 @@ +#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" + +#include "mediapipe/framework/port.h" + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 + +#include +#include +#include + +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gl_context.h" + +namespace mediapipe { + +namespace { + +class GlNoOpOverride : public GlOverride {}; + +class GlTexParameteriOverride : public GlOverride { + public: + GlTexParameteriOverride(GLenum name, GLint old_value) + : name_(name), old_value_(old_value) {} + + ~GlTexParameteriOverride() override { + glTexParameteri(GL_TEXTURE_2D, name_, old_value_); + } + + private: + GLenum name_; + GLint old_value_; +}; + +template +class GlTexParameterfvOverride : public GlOverride { + public: + GlTexParameterfvOverride(GLenum name, + std::array old_values) + : name_(name), old_values_(std::move(old_values)) {} + + ~GlTexParameterfvOverride() { + glTexParameterfv(GL_TEXTURE_2D, name_, &old_values_[0]); + } + + private: + GLenum name_; + std::array old_values_; +}; + +} // namespace + +std::unique_ptr OverrideGlTexParametri(GLenum name, GLint value) { + GLint old_value; + glGetTexParameteriv(GL_TEXTURE_2D, name, &old_value); + if (value != old_value) { + glTexParameteri(GL_TEXTURE_2D, name, value); + return {absl::make_unique(name, old_value)}; + } + return {absl::make_unique()}; +} + +template +std::unique_ptr OverrideGlTexParameterfv( + GLenum name, std::array values) { + std::array old_values; + glGetTexParameterfv(GL_TEXTURE_2D, name, values.data()); + if (values != old_values) { + glTexParameterfv(GL_TEXTURE_2D, name, values.data()); + return {absl::make_unique>( + name, std::move(old_values))}; + } + return {absl::make_unique()}; +} + +template std::unique_ptr OverrideGlTexParameterfv<4>( + GLenum name, std::array values); + +bool IsGlClampToBorderSupported(const mediapipe::GlContext& gl_context) { + return gl_context.gl_major_version() > 3 || + (gl_context.gl_major_version() == 3 && + gl_context.gl_minor_version() >= 2); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h new file mode 100644 index 000000000..3105cfef1 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h @@ -0,0 +1,45 @@ +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_UTILS_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_UTILS_H_ + +#include "mediapipe/framework/port.h" + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 + +#include +#include +#include + +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gl_context.h" + 
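+// Small helpers shared by the GL converters: RAII-style overrides that set a
+// texture parameter and revert it on destruction, plus a check for
+// GL_CLAMP_TO_BORDER support.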
+namespace mediapipe { + +// Intended to override and automatically revert various OpenGL attributes. +// (e.g. overriding texture parameters like GL_TEXTURE_MIN_FILTER, +// GL_TEXTURE_MAG_FILTER, etc.) +class GlOverride { + public: + virtual ~GlOverride() = default; +}; + +// Creates an object that overrides attributes using `glTexParameteri` +// function during construction and reverts them during destruction. See +// `glTexParameteri` for details on @name and @value. +ABSL_MUST_USE_RESULT std::unique_ptr OverrideGlTexParametri( + GLenum name, GLint value); + +// Creates an object that overrides attributes using `glTexParameterfv` +// function during construction and reverts them during destruction. See +// `glTexParameterfv` for details on @name and @values. +template +ABSL_MUST_USE_RESULT std::unique_ptr OverrideGlTexParameterfv( + GLenum name, std::array values); + +bool IsGlClampToBorderSupported(const mediapipe::GlContext& gl_context); + +} // namespace mediapipe + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_UTILS_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc new file mode 100644 index 000000000..9482cfc2a --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc @@ -0,0 +1,49 @@ +#include "mediapipe/framework/port.h" + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 + +#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gl_context.h" + +namespace mediapipe { +namespace { + +TEST(ImageToTensorConverterGlUtilsTest, GlTexParameteriOverrider) { + auto status_or_context = mediapipe::GlContext::Create(nullptr, false); + MP_ASSERT_OK(status_or_context); + auto context = status_or_context.value(); + + std::vector min_filter_changes; + context->Run([&min_filter_changes]() { + GLuint texture = 0; + glGenTextures(1, &texture); + glBindTexture(GL_TEXTURE_2D, texture); + + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); + GLint value = 0; + glGetTexParameteriv(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, &value); + min_filter_changes.push_back(value); + + { + auto min_filter_linear = + OverrideGlTexParametri(GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glGetTexParameteriv(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, &value); + min_filter_changes.push_back(value); + + // reverter is destroyed automatically reverting previously set value + } + glGetTexParameteriv(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, &value); + min_filter_changes.push_back(value); + }); + + EXPECT_THAT(min_filter_changes, + testing::ElementsAre(GL_NEAREST, GL_LINEAR, GL_NEAREST)); +} + +} // namespace +} // namespace mediapipe + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc new file mode 100644 index 000000000..565dd85b9 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -0,0 +1,408 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/image_to_tensor_converter_metal.h" + +#if MEDIAPIPE_METAL_ENABLED + +#import + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/gpu/MPPMetalHelper.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace mediapipe { + +namespace { + +// clang-format off +// a square formed by 2 triangles +const float kBasicSquareVertices[] = { + -1, 1, 0, 1, + 1, 1, 0, 1, + 1, -1, 0, 1, + -1, 1, 0, 1, + 1, -1, 0, 1, + -1, -1, 0, 1, +}; + +// maps a texture to kBasicSquareVertices via aspect fill +const float kBasicTextureVertices[] = { + 0, 0, 0, 1, + 1, 0, 0, 1, + 1, 1, 0, 1, + 0, 0, 0, 1, + 1, 1, 0, 1, + 0, 1, 0, 1, +}; +// clang-format on + +constexpr char kShaderLibHeader[] = R"( + #include + + using namespace metal; + + struct TextureVertex + { + float4 position [[position]]; + float2 uv; + }; +)"; + +constexpr char kVertexShader[] = R"( + vertex TextureVertex vertexShader( + constant float4 *position [[buffer(0)]], + device float4* tex_coords [[buffer(1)]], + constant float4x4& transform_matrix [[buffer(2)]], + uint vid [[vertex_id]]) { + TextureVertex vert; + vert.position = position[vid]; + vert.uv = (tex_coords[vid] * transform_matrix).xy; + return vert; + } +)"; + +constexpr char kFragmentShader[] = R"( + #ifdef OUTPUT_F16C4 + #define Type4 half4 + #define Type half + #endif // OUTPUT_F16C4 + + #ifdef OUTPUT_F32C4 + #define Type4 float4 + #define Type float + #endif // OUTPUT_F32C4 + + fragment Type4 fragmentShader(TextureVertex vertex_output [[stage_in]], + texture2d texture [[texture(0)]], + constant float* parameters [[buffer(1)]]) + { + const float alpha = parameters[0]; + const float beta = parameters[1]; + + #ifdef CLAMP_TO_ZERO + constexpr sampler linear_sampler(address::clamp_to_zero, min_filter::linear, + mag_filter::linear); + #endif // CLAMP_TO_ZERO + + #ifdef CLAMP_TO_EDGE + constexpr sampler linear_sampler(address::clamp_to_edge, min_filter::linear, + mag_filter::linear); + #endif // CLAMP_TO_EDGE + + Type4 texture_pixel = texture.sample(linear_sampler, vertex_output.uv); + return Type4(alpha * texture_pixel.rgb + beta, 0); + } +)"; + +enum class OutputFormat { kF16C4, kF32C4 }; + +MTLPixelFormat GetPixelFormat(OutputFormat output_format) { + switch (output_format) { + case OutputFormat::kF16C4: + return MTLPixelFormatRGBA16Float; + case OutputFormat::kF32C4: + return MTLPixelFormatRGBA32Float; + } +} +int GetBytesPerRaw(OutputFormat output_format, const tflite::gpu::HW& size) { + std::size_t type_size; + switch (output_format) { + case 
OutputFormat::kF16C4: + type_size = sizeof(tflite::gpu::HalfBits); + break; + case OutputFormat::kF32C4: + type_size = sizeof(float); + break; + } + constexpr int kNumChannels = 4; + return size.w * kNumChannels * type_size; +} + +class SubRectExtractorMetal { + public: + static absl::StatusOr> Make( + id device, OutputFormat output_format, + BorderMode border_mode) { + id pipeline_state; + MP_RETURN_IF_ERROR(SubRectExtractorMetal::MakePipelineState( + device, output_format, border_mode, &pipeline_state)); + + return absl::make_unique(device, pipeline_state, + output_format); + } + + SubRectExtractorMetal(id device, + id pipeline_state, + OutputFormat output_format) + : device_(device), + pipeline_state_(pipeline_state), + output_format_(output_format) { + positions_buffer_ = + [device_ newBufferWithBytes:kBasicSquareVertices + length:sizeof(kBasicSquareVertices) + options:MTLResourceOptionCPUCacheModeDefault]; + + tex_coords_buffer_ = + [device_ newBufferWithBytes:kBasicTextureVertices + length:sizeof(kBasicTextureVertices) + options:MTLResourceOptionCPUCacheModeDefault]; + } + + absl::Status Execute(id input_texture, + const RotatedRect& sub_rect, bool flip_horizontaly, + float alpha, float beta, + const tflite::gpu::HW& destination_size, + id command_buffer, + id destination) { + auto output_texture = MTLTextureWithBuffer(destination_size, destination); + return InternalExecute(input_texture, sub_rect, flip_horizontaly, alpha, + beta, destination_size, command_buffer, + output_texture); + } + + private: + id MTLTextureWithBuffer(const tflite::gpu::HW& size, + id buffer) { + MTLTextureDescriptor* texture_desc = [MTLTextureDescriptor + texture2DDescriptorWithPixelFormat:GetPixelFormat(output_format_) + width:size.w + height:size.h + mipmapped:NO]; + texture_desc.usage = MTLTextureUsageRenderTarget; + + NSUInteger output_bytes_per_row = GetBytesPerRaw(output_format_, size); + + id texture = + [buffer newTextureWithDescriptor:texture_desc + offset:0 + bytesPerRow:output_bytes_per_row]; + return texture; + } + + absl::Status InternalExecute(id input_texture, + const RotatedRect& sub_rect, + bool flip_horizontaly, float alpha, float beta, + const tflite::gpu::HW& destination_size, + id command_buffer, + id output_texture) { + RET_CHECK(command_buffer != nil); + RET_CHECK(output_texture != nil); + + // Obtain texture mapping coordinates transformation matrix and copy its + // data to the buffer. + std::array transform_mat; + GetRotatedSubRectToRectTransformMatrix(sub_rect, input_texture.width, + input_texture.height, + flip_horizontaly, &transform_mat); + id transform_mat_buffer = + [device_ newBufferWithBytes:&transform_mat + length:sizeof(transform_mat) + options:MTLResourceOptionCPUCacheModeDefault]; + + // Create parameters wrapper. + float parameters[] = {alpha, beta}; + + // Now everything is ready to go! + // Setup render pass. + MTLRenderPassDescriptor* render_pass_desc = + [MTLRenderPassDescriptor renderPassDescriptor]; + render_pass_desc.colorAttachments[0].texture = output_texture; + render_pass_desc.colorAttachments[0].storeAction = MTLStoreActionStore; + render_pass_desc.colorAttachments[0].loadAction = MTLLoadActionClear; + + // Setup render command encoder. 
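+    // The encoder binds the render pipeline, the quad positions, texture
+    // coordinates and transform matrix (vertex buffers 0-2), the input texture
+    // and the {alpha, beta} parameters, then draws the quad's two triangles
+    // (6 vertices).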
+ id command_encoder = + [command_buffer renderCommandEncoderWithDescriptor:render_pass_desc]; + [command_encoder setRenderPipelineState:pipeline_state_]; + [command_encoder setVertexBuffer:positions_buffer_ offset:0 atIndex:0]; + [command_encoder setVertexBuffer:tex_coords_buffer_ offset:0 atIndex:1]; + [command_encoder setVertexBuffer:transform_mat_buffer offset:0 atIndex:2]; + [command_encoder setFragmentTexture:input_texture atIndex:0]; + [command_encoder setFragmentBytes:¶meters + length:sizeof(parameters) + atIndex:1]; + + [command_encoder drawPrimitives:MTLPrimitiveTypeTriangle + vertexStart:0 + vertexCount:6]; + [command_encoder endEncoding]; + + return absl::OkStatus(); + } + + static absl::Status MakePipelineState( + id device, OutputFormat output_format, BorderMode border_mode, + id* pipeline_state) { + RET_CHECK(pipeline_state != nil); + + std::string output_type_def; + MTLPixelFormat pixel_format; + switch (output_format) { + case OutputFormat::kF16C4: + output_type_def = R"( + #define OUTPUT_F16C4 + )"; + break; + case OutputFormat::kF32C4: + output_type_def = R"( + #define OUTPUT_F32C4 + )"; + break; + } + + std::string clamp_def; + switch (border_mode) { + case BorderMode::kReplicate: { + clamp_def = R"( + #define CLAMP_TO_EDGE + )"; + break; + } + case BorderMode::kZero: { + clamp_def = R"( + #define CLAMP_TO_ZERO + )"; + break; + } + } + + std::string shader_lib = + absl::StrCat(kShaderLibHeader, output_type_def, clamp_def, + kVertexShader, kFragmentShader); + NSError* error = nil; + NSString* library_source = + [NSString stringWithUTF8String:shader_lib.c_str()]; + + id library = + [device newLibraryWithSource:library_source options:nil error:&error]; + RET_CHECK(library != nil) << "Couldn't create a shader library" + << [[error localizedDescription] UTF8String]; + + id vertex_function = + [library newFunctionWithName:@"vertexShader"]; + RET_CHECK(vertex_function != nil) + << "Failed creating a new vertex function!"; + + id fragment_function = + [library newFunctionWithName:@"fragmentShader"]; + RET_CHECK(fragment_function != nil) + << "Failed creating a new fragment function!"; + + MTLRenderPipelineDescriptor* pipelineDescriptor = + [MTLRenderPipelineDescriptor new]; + pipelineDescriptor.vertexFunction = vertex_function; + pipelineDescriptor.fragmentFunction = fragment_function; + pipelineDescriptor.colorAttachments[0].pixelFormat = + GetPixelFormat(output_format); + + *pipeline_state = + [device newRenderPipelineStateWithDescriptor:pipelineDescriptor + error:&error]; + RET_CHECK(error == nil) << "Couldn't create a pipeline state" + << [[error localizedDescription] UTF8String]; + + return absl::OkStatus(); + } + + id positions_buffer_; + id tex_coords_buffer_; + id device_; + id pipeline_state_; + OutputFormat output_format_; +}; + +class MetalProcessor : public ImageToTensorConverter { + public: + absl::Status Init(CalculatorContext* cc, BorderMode border_mode) { + metal_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(metal_helper_); + ASSIGN_OR_RETURN(extractor_, SubRectExtractorMetal::Make( + metal_helper_.mtlDevice, + OutputFormat::kF32C4, border_mode)); + return absl::OkStatus(); + } + + absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, float range_min, + float range_max) override { + if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { + return InvalidArgumentError( + absl::StrCat("Only BGRA/RGBA textures are supported, passed " + "format: ", + 
static_cast(input.format()))); + } + + @autoreleasepool { + id texture = + [metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()]; + + constexpr int kNumChannels = 4; + Tensor tensor(Tensor::ElementType::kFloat32, + Tensor::Shape{1, output_dims.height, output_dims.width, + kNumChannels}); + + constexpr float kInputImageRangeMin = 0.0f; + constexpr float kInputImageRangeMax = 1.0f; + ASSIGN_OR_RETURN( + auto transform, + GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, + range_min, range_max)); + + id command_buffer = [metal_helper_ commandBuffer]; + const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer); + MP_RETURN_IF_ERROR(extractor_->Execute( + texture, roi, + /*flip_horizontaly=*/false, transform.scale, transform.offset, + tflite::gpu::HW(output_dims.height, output_dims.width), + command_buffer, buffer_view.buffer())); + [command_buffer commit]; + return tensor; + } + } + + private: + MPPMetalHelper* metal_helper_ = nil; + std::unique_ptr extractor_; +}; + +} // namespace + +absl::StatusOr> CreateMetalConverter( + CalculatorContext* cc, BorderMode border_mode) { + auto result = absl::make_unique(); + MP_RETURN_IF_ERROR(result->Init(cc, border_mode)); + + // Simply "return std::move(result)" failed to build on macOS with bazel. + return std::unique_ptr(std::move(result)); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_METAL_ENABLED diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.h b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.h new file mode 100644 index 000000000..0fe5a87d0 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.h @@ -0,0 +1,40 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_METAL_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_METAL_H_ + +#include "mediapipe/framework/port.h" + +#if MEDIAPIPE_METAL_ENABLED + +#include + +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +// Creates Metal image-to-tensor converter. +// NOTE: [MPPMetalHelper updateContract:...] invocation must precede +// converter creation. +absl::StatusOr> CreateMetalConverter( + CalculatorContext* cc, BorderMode border_mode); + +} // namespace mediapipe + +#endif // MEDIAPIPE_METAL_ENABLED + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_METAL_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc new file mode 100644 index 000000000..b8d1b0a8b --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -0,0 +1,123 @@ +// Copyright 2020 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h" + +#include +#include + +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +namespace { + +class OpenCvProcessor : public ImageToTensorConverter { + public: + OpenCvProcessor(BorderMode border_mode) { + switch (border_mode) { + case BorderMode::kReplicate: + border_mode_ = cv::BORDER_REPLICATE; + break; + case BorderMode::kZero: + border_mode_ = cv::BORDER_CONSTANT; + break; + } + } + + absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, float range_min, + float range_max) override { + if (input.image_format() != mediapipe::ImageFormat::SRGB && + input.image_format() != mediapipe::ImageFormat::SRGBA) { + return InvalidArgumentError( + absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", + static_cast(input.image_format()))); + } + cv::Mat src = mediapipe::formats::MatView(&input); + + constexpr int kNumChannels = 3; + Tensor tensor( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels}); + auto buffer_view = tensor.GetCpuWriteView(); + cv::Mat dst(output_dims.height, output_dims.width, CV_32FC3, + buffer_view.buffer()); + + const cv::RotatedRect rotated_rect(cv::Point2f(roi.center_x, roi.center_y), + cv::Size2f(roi.width, roi.height), + roi.rotation * 180.f / M_PI); + cv::Mat src_points; + cv::boxPoints(rotated_rect, src_points); + + const float dst_width = output_dims.width; + const float dst_height = output_dims.height; + /* clang-format off */ + float dst_corners[8] = {0.0f, dst_height, + 0.0f, 0.0f, + dst_width, 0.0f, + dst_width, dst_height}; + /* clang-format on */ + + cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners); + cv::Mat projection_matrix = + cv::getPerspectiveTransform(src_points, dst_points); + cv::Mat transformed; + cv::warpPerspective(src, transformed, projection_matrix, + cv::Size(dst_width, dst_height), + /*flags=*/cv::INTER_LINEAR, + /*borderMode=*/border_mode_); + + if (transformed.channels() > kNumChannels) { + cv::Mat proper_channels_mat; + cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB); + transformed = proper_channels_mat; + } + + constexpr float kInputImageRangeMin = 0.0f; + constexpr float kInputImageRangeMax = 255.0f; + ASSIGN_OR_RETURN( + auto transform, + GetValueRangeTransformation(kInputImageRangeMin, 
kInputImageRangeMax, + range_min, range_max)); + transformed.convertTo(dst, CV_32FC3, transform.scale, transform.offset); + return tensor; + } + + private: + enum cv::BorderTypes border_mode_; +}; + +} // namespace + +absl::StatusOr> CreateOpenCvConverter( + CalculatorContext* cc, BorderMode border_mode) { + // Simply "return absl::make_unique()" failed to build on + // macOS with bazel. + return std::unique_ptr( + absl::make_unique(border_mode)); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h new file mode 100644 index 000000000..3ccecc557 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h @@ -0,0 +1,32 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_OPENCV_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_OPENCV_H_ + +#include + +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +// Creates OpenCV image-to-tensor converter. +absl::StatusOr> CreateOpenCvConverter( + CalculatorContext* cc, BorderMode border_mode); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_OPENCV_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc new file mode 100644 index 000000000..6b3bf08cd --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc @@ -0,0 +1,217 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
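+
+// Worked examples (illustrative): GetValueRangeTransformation(0, 255, -1, 1)
+// yields scale = 2 / 255 and offset = -1, i.e. value * 2 / 255 - 1; PadRoi()
+// applied to a 100x200 ROI targeting a square tensor widens it to 200x200 and
+// returns horizontal letterbox padding of 0.25 on each side.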
+ +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" + +#include + +#include "absl/types/optional.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +RotatedRect GetRoi(int input_width, int input_height, + absl::optional norm_rect) { + if (norm_rect) { + return {/*center_x=*/norm_rect->x_center() * input_width, + /*center_y =*/norm_rect->y_center() * input_height, + /*width =*/norm_rect->width() * input_width, + /*height =*/norm_rect->height() * input_height, + /*rotation =*/norm_rect->rotation()}; + } + return {/*center_x=*/0.5f * input_width, + /*center_y =*/0.5f * input_height, + /*width =*/static_cast(input_width), + /*height =*/static_cast(input_height), + /*rotation =*/0}; +} + +absl::StatusOr> PadRoi(int input_tensor_width, + int input_tensor_height, + bool keep_aspect_ratio, + RotatedRect* roi) { + if (!keep_aspect_ratio) { + return std::array{0.0f, 0.0f, 0.0f, 0.0f}; + } + + RET_CHECK(input_tensor_width > 0 && input_tensor_height > 0) + << "Input tensor width and height must be > 0."; + const float tensor_aspect_ratio = + static_cast(input_tensor_height) / input_tensor_width; + + RET_CHECK(roi->width > 0 && roi->height > 0) + << "ROI width and height must be > 0."; + const float roi_aspect_ratio = roi->height / roi->width; + + float vertical_padding = 0.0f; + float horizontal_padding = 0.0f; + float new_width; + float new_height; + if (tensor_aspect_ratio > roi_aspect_ratio) { + new_width = roi->width; + new_height = roi->width * tensor_aspect_ratio; + vertical_padding = (1.0f - roi_aspect_ratio / tensor_aspect_ratio) / 2.0f; + } else { + new_width = roi->height / tensor_aspect_ratio; + new_height = roi->height; + horizontal_padding = (1.0f - tensor_aspect_ratio / roi_aspect_ratio) / 2.0f; + } + + roi->width = new_width; + roi->height = new_height; + + return std::array{horizontal_padding, vertical_padding, + horizontal_padding, vertical_padding}; +} + +absl::StatusOr GetValueRangeTransformation( + float from_range_min, float from_range_max, float to_range_min, + float to_range_max) { + RET_CHECK_LT(from_range_min, from_range_max) + << "Invalid FROM range: min >= max."; + RET_CHECK_LT(to_range_min, to_range_max) << "Invalid TO range: min >= max."; + const float scale = + (to_range_max - to_range_min) / (from_range_max - from_range_min); + const float offset = to_range_min - from_range_min * scale; + return ValueTransformation{scale, offset}; +} + +void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect, + int rect_width, int rect_height, + bool flip_horizontaly, + std::array* matrix_ptr) { + std::array& matrix = *matrix_ptr; + // The resulting matrix is multiplication of below commented out matrices: + // post_scale_matrix + // * translate_matrix + // * rotate_matrix + // * flip_matrix + // * scale_matrix + // * initial_translate_matrix + + // Matrix to convert X,Y to [-0.5, 0.5] range "initial_translate_matrix" + // { 1.0f, 0.0f, 0.0f, -0.5f} + // { 0.0f, 1.0f, 0.0f, -0.5f} + // { 0.0f, 0.0f, 1.0f, 0.0f} + // { 0.0f, 0.0f, 0.0f, 1.0f} + + const float a = sub_rect.width; + const float b = sub_rect.height; + // Matrix to scale X,Y,Z to sub rect "scale_matrix" + // Z has the same scale as X. + // { a, 0.0f, 0.0f, 0.0f} + // {0.0f, b, 0.0f, 0.0f} + // {0.0f, 0.0f, a, 0.0f} + // {0.0f, 0.0f, 0.0f, 1.0f} + + const float flip = flip_horizontaly ? -1 : 1; + // Matrix for optional horizontal flip around middle of output image. 
+ // { fl , 0.0f, 0.0f, 0.0f} + // { 0.0f, 1.0f, 0.0f, 0.0f} + // { 0.0f, 0.0f, 1.0f, 0.0f} + // { 0.0f, 0.0f, 0.0f, 1.0f} + + const float c = std::cos(sub_rect.rotation); + const float d = std::sin(sub_rect.rotation); + // Matrix to do rotation around Z axis "rotate_matrix" + // { c, -d, 0.0f, 0.0f} + // { d, c, 0.0f, 0.0f} + // { 0.0f, 0.0f, 1.0f, 0.0f} + // { 0.0f, 0.0f, 0.0f, 1.0f} + + const float e = sub_rect.center_x; + const float f = sub_rect.center_y; + // Matrix to do X,Y translation of sub rect within parent rect + // "translate_matrix" + // {1.0f, 0.0f, 0.0f, e } + // {0.0f, 1.0f, 0.0f, f } + // {0.0f, 0.0f, 1.0f, 0.0f} + // {0.0f, 0.0f, 0.0f, 1.0f} + + const float g = 1.0f / rect_width; + const float h = 1.0f / rect_height; + // Matrix to scale X,Y,Z to [0.0, 1.0] range "post_scale_matrix" + // {g, 0.0f, 0.0f, 0.0f} + // {0.0f, h, 0.0f, 0.0f} + // {0.0f, 0.0f, g, 0.0f} + // {0.0f, 0.0f, 0.0f, 1.0f} + + // row 1 + matrix[0] = a * c * flip * g; + matrix[1] = -b * d * g; + matrix[2] = 0.0f; + matrix[3] = (-0.5f * a * c * flip + 0.5f * b * d + e) * g; + + // row 2 + matrix[4] = a * d * flip * h; + matrix[5] = b * c * h; + matrix[6] = 0.0f; + matrix[7] = (-0.5f * b * c - 0.5f * a * d * flip + f) * h; + + // row 3 + matrix[8] = 0.0f; + matrix[9] = 0.0f; + matrix[10] = a * g; + matrix[11] = 0.0f; + + // row 4 + matrix[12] = 0.0f; + matrix[13] = 0.0f; + matrix[14] = 0.0f; + matrix[15] = 1.0f; +} + +void GetTransposedRotatedSubRectToRectTransformMatrix( + const RotatedRect& sub_rect, int rect_width, int rect_height, + bool flip_horizontaly, std::array* matrix_ptr) { + std::array& matrix = *matrix_ptr; + // See comments in GetRotatedSubRectToRectTransformMatrix for detailed + // calculations. + const float a = sub_rect.width; + const float b = sub_rect.height; + const float flip = flip_horizontaly ? -1 : 1; + const float c = std::cos(sub_rect.rotation); + const float d = std::sin(sub_rect.rotation); + const float e = sub_rect.center_x; + const float f = sub_rect.center_y; + const float g = 1.0f / rect_width; + const float h = 1.0f / rect_height; + + // row 1 (indices 0,4,8,12 from non-transposed fcn) + matrix[0] = a * c * flip * g; + matrix[1] = a * d * flip * h; + matrix[2] = 0.0f; + matrix[3] = 0.0f; + + // row 2 (indices 1,5,9,13 from non-transposed fcn) + matrix[4] = -b * d * g; + matrix[5] = b * c * h; + matrix[6] = 0.0f; + matrix[7] = 0.0f; + + // row 3 (indices 2,6,10,14 from non-transposed fcn) + matrix[8] = 0.0f; + matrix[9] = 0.0f; + matrix[10] = a * g; + matrix[11] = 0.0f; + + // row 4 (indices 3,7,11,15 from non-transposed fcn) + matrix[12] = (-0.5f * a * c * flip + 0.5f * b * d + e) * g; + matrix[13] = (-0.5f * b * c - 0.5f * a * d * flip + f) * h; + matrix[14] = 0.0f; + matrix[15] = 1.0f; +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.h b/mediapipe/calculators/tensor/image_to_tensor_utils.h new file mode 100644 index 000000000..f913875e3 --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.h @@ -0,0 +1,100 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_
+#define MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_
+
+#include <array>
+
+#include "absl/types/optional.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/port/statusor.h"
+
+namespace mediapipe {
+
+struct RotatedRect {
+  float center_x;
+  float center_y;
+  float width;
+  float height;
+  float rotation;
+};
+
+// Generates a new ROI or converts it from normalized rect.
+RotatedRect GetRoi(int input_width, int input_height,
+                   absl::optional<mediapipe::NormalizedRect> norm_rect);
+
+// Pads ROI, so extraction happens correctly if aspect ratio is to be kept.
+// Returns letterbox padding applied.
+absl::StatusOr<std::array<float, 4>> PadRoi(int input_tensor_width,
+                                            int input_tensor_height,
+                                            bool keep_aspect_ratio,
+                                            RotatedRect* roi);
+
+// Represents a transformation of value which involves scaling and offsetting.
+// To apply transformation:
+//   ValueTransformation transform = ...
+//   float transformed_value = transform.scale * value + transform.offset;
+struct ValueTransformation {
+  float scale;
+  float offset;
+};
+
+// Returns value transformation to apply to a value in order to convert it from
+// [from_range_min, from_range_max] into [to_range_min, to_range_max] range.
+// from_range_min must be less than from_range_max.
+// to_range_min must be less than to_range_max.
+absl::StatusOr<ValueTransformation> GetValueRangeTransformation(
+    float from_range_min, float from_range_max, float to_range_min,
+    float to_range_max);
+
+// Populates 4x4 "matrix" with row major order transformation matrix which
+// maps (x, y) in range [0, 1] (describing points of @sub_rect)
+// to (x', y') in range [0, 1]*** (describing points of a rect:
+// [0, @rect_width] x [0, @rect_height] = RECT).
+//
+// *** (x', y') will go out of the range for points from @sub_rect
+// which are not contained by RECT and it's expected behavior.
+//
+// @sub_rect - rotated sub rect in absolute coordinates
+// @rect_width - rect width
+// @rect_height - rect height
+// @flip_horizontaly - we need to flip the output buffer.
+// @matrix - 4x4 matrix (array of 16 elements) to populate
+void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
+                                            int rect_width, int rect_height,
+                                            bool flip_horizontaly,
+                                            std::array<float, 16>* matrix);
+
+// Returns the transpose of the matrix found with
+// "GetRotatedSubRectToRectTransformMatrix". That is to say, this populates a
+// 4x4 "matrix" with col major order transformation matrix which maps (x, y) in
+// range [0, 1] (describing points of @sub_rect) to (x', y') in range [0, 1]***
+// (describing points of a rect: [0, @rect_width] x [0, @rect_height] = RECT).
+//
+// *** (x', y') will go out of the range for points from @sub_rect
+// which are not contained by RECT and it's expected behavior.
+//
+// @sub_rect - rotated sub rect in absolute coordinates
+// @rect_width - rect width
+// @rect_height - rect height
+// @flip_horizontaly - we need to flip the output buffer.
+// @matrix - 4x4 matrix (array of 16 elements) to populate +void GetTransposedRotatedSubRectToRectTransformMatrix( + const RotatedRect& sub_rect, int rect_width, int rect_height, + bool flip_horizontaly, std::array* matrix); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils_test.cc b/mediapipe/calculators/tensor/image_to_tensor_utils_test.cc new file mode 100644 index 000000000..814b4c34f --- /dev/null +++ b/mediapipe/calculators/tensor/image_to_tensor_utils_test.cc @@ -0,0 +1,161 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" + +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +testing::Matcher EqRotatedRect(float width, float height, + float center_x, float center_y, + float rotation) { + return testing::AllOf( + testing::Field(&RotatedRect::width, testing::FloatEq(width)), + testing::Field(&RotatedRect::height, testing::FloatEq(height)), + testing::Field(&RotatedRect::center_x, testing::FloatEq(center_x)), + testing::Field(&RotatedRect::center_y, testing::FloatEq(center_y)), + testing::Field(&RotatedRect::rotation, testing::FloatEq(rotation))); +} + +TEST(GetRoi, NoNormRect) { + EXPECT_THAT(GetRoi(4, 4, {}), EqRotatedRect(4, 4, 2, 2, 0)); + EXPECT_THAT(GetRoi(25, 15, {}), EqRotatedRect(25, 15, 12.5f, 7.5f, 0)); +} + +TEST(GetRoi, WholeImageNormRect) { + mediapipe::NormalizedRect norm_rect; + norm_rect.set_width(1.0f); + norm_rect.set_height(1.0f); + norm_rect.set_x_center(0.5f); + norm_rect.set_y_center(0.5f); + norm_rect.set_rotation(0.0f); + EXPECT_THAT(GetRoi(4, 4, norm_rect), EqRotatedRect(4, 4, 2, 2, 0)); + EXPECT_THAT(GetRoi(25, 15, norm_rect), EqRotatedRect(25, 15, 12.5f, 7.5f, 0)); +} + +TEST(GetRoi, ExpandedNormRect) { + mediapipe::NormalizedRect norm_rect; + norm_rect.set_width(4.0f); + norm_rect.set_height(2.0f); + norm_rect.set_x_center(0.5f); + norm_rect.set_y_center(1.0f); + norm_rect.set_rotation(3.0f); + EXPECT_THAT(GetRoi(4, 4, norm_rect), EqRotatedRect(16, 8, 2, 4, 3)); + EXPECT_THAT(GetRoi(25, 15, norm_rect), EqRotatedRect(100, 30, 12.5f, 15, 3)); +} + +TEST(PadRoi, NoPadding) { + RotatedRect roi{.center_x = 20, + .center_y = 10, + .width = 100, + .height = 200, + .rotation = 5}; + auto status_or_value = PadRoi(10, 10, /*keep_aspect_ratio=*/false, &roi); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), + ElementsAreArray({0.0f, 0.0f, 0.0f, 0.0f})); + EXPECT_THAT(roi, EqRotatedRect(100, 200, 20, 10, 5)); +} + +TEST(PadRoi, HorizontalPadding) { + RotatedRect roi{.center_x = 20, + .center_y = 10, + .width = 100, + .height = 200, + .rotation = 5}; + auto status_or_value = PadRoi(10, 10, /*keep_aspect_ratio=*/true, 
&roi); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), + ElementsAreArray({0.25f, 0.0f, 0.25f, 0.0f})); + EXPECT_THAT(roi, EqRotatedRect(200, 200, 20, 10, 5)); +} + +TEST(PadRoi, VerticalPadding) { + RotatedRect roi{ + .center_x = 1, .center_y = 2, .width = 21, .height = 19, .rotation = 3}; + const float expected_horizontal_padding = (21 - 19) / 2.0f / 21; + auto status_or_value = PadRoi(10, 10, /*keep_aspect_ratio=*/true, &roi); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT( + status_or_value.value(), + ElementsAre(testing::FloatEq(0.0f), + testing::FloatNear(expected_horizontal_padding, 1e-6), + testing::FloatEq(0.0f), + testing::FloatNear(expected_horizontal_padding, 1e-6))); + EXPECT_THAT(roi, EqRotatedRect(21, 21, 1, 2, 3)); +} + +testing::Matcher EqValueTransformation(float scale, + float offset) { + return ::testing::AllOf( + testing::Field(&ValueTransformation::scale, testing::FloatEq(scale)), + testing::Field(&ValueTransformation::offset, testing::FloatEq(offset))); +} + +TEST(GetValueRangeTransformation, PixelToFloatZeroCenter) { + auto status_or_value = GetValueRangeTransformation( + /*from_range_min=*/0.0f, /*from_range_max=*/255.0f, + /*to_range_min=*/-1.0f, /*to_range_max=*/1.0f); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), + EqValueTransformation(/*scale=*/2 / 255.0f, + /*offset=*/-1.0f)); +} + +TEST(GetValueRangeTransformation, PixelToFloat) { + auto status_or_value = GetValueRangeTransformation( + /*from_range_min=*/0.0f, /*from_range_max=*/255.0f, + /*to_range_min=*/0.0f, /*to_range_max=*/1.0f); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), + EqValueTransformation(/*scale=*/1 / 255.0f, + /*offset=*/0.0f)); +} + +TEST(GetValueRangeTransformation, FloatToFloatNoOp) { + auto status_or_value = GetValueRangeTransformation( + /*from_range_min=*/0.0f, /*from_range_max=*/1.0f, + /*to_range_min=*/0.0f, /*to_range_max=*/1.0f); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), + EqValueTransformation(/*scale=*/1.0f, /*offset=*/0.0f)); +} + +TEST(GetValueRangeTransformation, PixelToPixelNoOp) { + auto status_or_value = GetValueRangeTransformation( + /*from_range_min=*/0.0f, /*from_range_max=*/255.0f, + /*to_range_min=*/0.0f, /*to_range_max=*/255.0f); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), + EqValueTransformation(/*scale=*/1.0f, /*offset=*/0.0f)); +} + +TEST(GetValueRangeTransformation, FloatToPixel) { + auto status_or_value = GetValueRangeTransformation( + /*from_range_min=*/0.0f, /*from_range_max=*/1.0f, + /*to_range_min=*/0.0f, /*to_range_max=*/255.0f); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), + EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc new file mode 100644 index 000000000..89a02b713 --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -0,0 +1,71 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/inference_calculator.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/tool/subgraph_expansion.h" + +namespace mediapipe { +namespace api2 { + +class InferenceCalculatorSelectorImpl + : public SubgraphImpl { + public: + absl::StatusOr GetConfig( + const CalculatorGraphConfig::Node& subgraph_node) { + const auto& options = + Subgraph::GetOptions<::mediapipe::InferenceCalculatorOptions>( + subgraph_node); + std::vector impls; + const bool should_use_gpu = + !options.has_delegate() || // Use GPU delegate if not specified + (options.has_delegate() && options.delegate().has_gpu()); + if (should_use_gpu) { + impls.emplace_back("Metal"); + impls.emplace_back("MlDrift"); + impls.emplace_back("Gl"); + } + impls.emplace_back("Cpu"); + for (const auto& suffix : impls) { + const auto impl = absl::StrCat("InferenceCalculator", suffix); + if (!mediapipe::CalculatorBaseRegistry::IsRegistered(impl)) continue; + CalculatorGraphConfig::Node impl_node = subgraph_node; + impl_node.set_calculator(impl); + return tool::MakeSingleNodeGraph(std::move(impl_node)); + } + return absl::UnimplementedError("no implementation available"); + } +}; + +absl::StatusOr> InferenceCalculator::GetModelAsPacket( + CalculatorContext* cc) { + const auto& options = cc->Options(); + if (!options.model_path().empty()) { + return TfLiteModelLoader::LoadFromPath(options.model_path()); + } + if (!kSideInModel(cc).IsEmpty()) return kSideInModel(cc); + return absl::Status(mediapipe::StatusCode::kNotFound, + "Must specify TFLite model as path or loaded model."); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h new file mode 100644 index 000000000..a746684ff --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -0,0 +1,136 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_H_ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/tflite/tflite_model_loader.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +namespace mediapipe { +namespace api2 { + +// Runs inference on the provided input Tensors and TFLite model. +// +// Creates an interpreter with given model and calls invoke(). +// Optionally run inference on CPU/GPU. +// +// This calculator can be used with TensorConverterCalculator to get the +// appropriate inputs. +// +// When the input tensors are on CPU, gpu inference is optional and can be +// specified in the calculator options. +// When the input tensors are on GPU, inference is GPU and output can be CPU or +// GPU. +// +// Input: +// TENSORS - Vector of Tensors +// +// Output: +// TENSORS - Vector of Tensors +// +// Input side packet: +// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver, +// instead of the builtin one. +// MODEL (optional) - Use to specify TfLite model +// (std::unique_ptr>) +// +// Example use: +// node { +// calculator: "InferenceCalculator" +// input_stream: "TENSORS:tensor_image" +// output_stream: "TENSORS:tensors" +// options: { +// [mediapipe.InferenceCalculatorOptions.ext] { +// model_path: "modelname.tflite" +// } +// } +// } +// +// or +// +// node { +// calculator: "InferenceCalculator" +// input_stream: "TENSORS:tensor_image" +// input_side_packet: "MODEL:model" +// output_stream: "TENSORS:tensors" +// options: { +// [mediapipe.InferenceCalculatorOptions.ext] { +// model_path: "modelname.tflite" +// delegate { gpu {} } +// } +// } +// } +// +// IMPORTANT Notes: +// Tensors are assumed to be ordered correctly (sequentially added to model). +// Input tensors are assumed to be of the correct size and already normalized. 
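+
+// "InferenceCalculator" itself resolves to a selector subgraph (see
+// inference_calculator.cc) that rewrites the node to the first registered
+// backend among the structs declared below -- Metal, MlDrift, Gl, then Cpu --
+// depending on the delegate options and what is available in the build.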
+ +class InferenceCalculator : public NodeIntf { + public: + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr SideInput::Optional + kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; + static constexpr SideInput::Optional kSideInModel{"MODEL"}; + static constexpr Output> kOutTensors{"TENSORS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel, + kOutTensors); + + protected: + using TfLiteDelegatePtr = + std::unique_ptr>; + + absl::StatusOr> GetModelAsPacket( + CalculatorContext* cc); +}; + +struct InferenceCalculatorSelector : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculator"; +}; + +struct InferenceCalculatorGl : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorGl"; +}; + +struct InferenceCalculatorMlDrift : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorMlDrift"; +}; + +struct InferenceCalculatorMetal : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorMetal"; +}; + +struct InferenceCalculatorCpu : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorCpu"; +}; + +} // namespace api2 +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_H_ diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto new file mode 100644 index 000000000..0efb61d4a --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -0,0 +1,116 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Full Example: +// +// node { +// calculator: "InferenceCalculator" +// input_stream: "TENSOR_IN:image_tensors" +// output_stream: "TENSOR_OUT:result_tensors" +// options { +// [mediapipe.InferenceCalculatorOptions.ext] { +// model_path: "model.tflite" +// delegate { gpu {} } +// } +// } +// } +// +message InferenceCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional InferenceCalculatorOptions ext = 336783863; + } + + message Delegate { + // Default inference provided by tflite. + message TfLite {} + // Delegate to run GPU inference depending on the device. + // (Can use OpenGl, OpenCl, Metal depending on the device.) + message Gpu { + // Experimental, Android/Linux only. Use TFLite GPU delegate API2 for + // the NN inference. + // example: + // delegate: { gpu { use_advanced_gpu_api: true } } + optional bool use_advanced_gpu_api = 1 [default = false]; + + // This option is valid for TFLite GPU delegate API2 only, + // Choose any of available APIs to force running inference using it. + enum API { + ANY = 0; + OPENGL = 1; + OPENCL = 2; + } + optional API api = 4 [default = ANY]; + + // This option is valid for TFLite GPU delegate API2 only, + // Set to true to use 16-bit float precision. 
If max precision is needed, + // set to false for 32-bit float calculations only. + optional bool allow_precision_loss = 3 [default = true]; + + // Load pre-compiled serialized binary cache to accelerate init process. + // Only available for OpenCL delegate on Android. + // Kernel caching will only be enabled if this path is set. + optional string cached_kernel_path = 2; + } + // Android only. + message Nnapi {} + message Xnnpack { + // Number of threads for XNNPACK delegate. (By default, calculator tries + // to choose optimal number of threads depending on the device.) + optional int32 num_threads = 1 [default = -1]; + } + + oneof delegate { + TfLite tflite = 1; + Gpu gpu = 2; + Nnapi nnapi = 3; + Xnnpack xnnpack = 4; + } + } + + // Path to the TF Lite model (ex: /path/to/modelname.tflite). + // On mobile, this is generally just modelname.tflite. + optional string model_path = 1; + + // Whether the TF Lite GPU or CPU backend should be used. Effective only when + // input tensors are on CPU. For input tensors on GPU, GPU backend is always + // used. + // DEPRECATED: configure "delegate" instead. + optional bool use_gpu = 2 [deprecated = true, default = false]; + + // Android only. When true, an NNAPI delegate will be used for inference. + // If NNAPI is not available, then the default CPU delegate will be used + // automatically. + // DEPRECATED: configure "delegate" instead. + optional bool use_nnapi = 3 [deprecated = true, default = false]; + + // The number of threads available to the interpreter. Effective only when + // input tensors are on CPU and 'use_gpu' is false. + optional int32 cpu_num_thread = 4 [default = -1]; + + // TfLite delegate to run inference. + // If not specified, TFLite GPU delegate is used by default (as if "gpu {}" + // is specified) unless GPU support is disabled in the build (i.e., with + // --define MEDIAPIPE_DISABLE_GPU=1), in which case regular TFLite on CPU is + // used (as if "tflite {}" is specified) except when building with emscripten + // where xnnpack is used. + // NOTE: use_gpu/use_nnapi are ignored if specified. (Delegate takes + // precedence over use_* deprecated options.) + optional Delegate delegate = 5; +} diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc new file mode 100644 index 000000000..d931b93fa --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -0,0 +1,205 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/inference_calculator.h" + +#if defined(MEDIAPIPE_ANDROID) +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#endif // ANDROID + +#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) +#include "mediapipe/util/cpu_util.h" +#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ + +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" + +namespace mediapipe { +namespace api2 { + +namespace { + +// Returns number of threads to configure XNNPACK delegate with. +// (Equal to user provided value if specified. Otherwise, it returns number of +// high cores (hard-coded to 1 for Emscripten without Threads extension)) +int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) { + static constexpr int kDefaultNumThreads = -1; + if (opts.has_delegate() && opts.delegate().has_xnnpack() && + opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) { + return opts.delegate().xnnpack().num_threads(); + } +#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) + return InferHigherCoreIds().size(); +#else + return 1; +#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ +} + +} // namespace + +class InferenceCalculatorCpuImpl + : public NodeImpl { + public: + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status LoadModel(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc); + + // TfLite requires us to keep the model alive as long as the interpreter is. + Packet model_packet_; + std::unique_ptr interpreter_; + TfLiteDelegatePtr delegate_; +}; + +absl::Status InferenceCalculatorCpuImpl::UpdateContract( + CalculatorContract* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) + << "Either model as side packet or model path in options is required."; + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadModel(cc)); + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + return absl::OkStatus(); + } + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + auto output_tensors = absl::make_unique>(); + + // Read CPU input into tensors. + for (int i = 0; i < input_tensors.size(); ++i) { + const Tensor* input_tensor = &input_tensors[i]; + auto input_tensor_view = input_tensor->GetCpuReadView(); + auto input_tensor_buffer = input_tensor_view.buffer(); + float* local_tensor_buffer = interpreter_->typed_input_tensor(i); + std::memcpy(local_tensor_buffer, input_tensor_buffer, + input_tensor->bytes()); + } + + // Run inference. + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + + // Output result tensors (CPU). 
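+  // Each TfLite output tensor is copied into a newly allocated float32
+  // mediapipe::Tensor of the same dimensions via its CPU write view.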
+ const auto& tensor_indexes = interpreter_->outputs(); + output_tensors->reserve(tensor_indexes.size()); + for (int i = 0; i < tensor_indexes.size(); ++i) { + TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]); + output_tensors->emplace_back( + Tensor::ElementType::kFloat32, + Tensor::Shape{std::vector{ + tensor->dims->data, tensor->dims->data + tensor->dims->size}}); + auto cpu_view = output_tensors->back().GetCpuWriteView(); + std::memcpy(cpu_view.buffer(), tensor->data.f, + output_tensors->back().bytes()); + } + kOutTensors(cc).Send(std::move(output_tensors)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) { + interpreter_ = nullptr; + delegate_ = nullptr; + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); + + tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + RET_CHECK(interpreter_); + +#if defined(__EMSCRIPTEN__) + interpreter_->SetNumThreads(1); +#else + interpreter_->SetNumThreads( + cc->Options().cpu_num_thread()); +#endif // __EMSCRIPTEN__ + + RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + // TODO: Support quantized tensors. + CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != + kTfLiteAffineQuantization); + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { + const auto& calculator_opts = + cc->Options(); + if (calculator_opts.has_delegate() && + calculator_opts.delegate().has_tflite()) { + // Default tflite inference requeqsted - no need to modify graph. + return absl::OkStatus(); + } + +#if defined(MEDIAPIPE_ANDROID) + const bool nnapi_requested = calculator_opts.has_delegate() + ? calculator_opts.delegate().has_nnapi() + : calculator_opts.use_nnapi(); + if (nnapi_requested) { + // Attempt to use NNAPI. + // If not supported, the default CPU delegate will be created and used. + interpreter_->SetAllowFp16PrecisionForFp32(1); + delegate_ = TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { + // No need to free according to tflite::NnApiDelegate() documentation. + }); + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + return absl::OkStatus(); + } +#endif // MEDIAPIPE_ANDROID + +#if defined(__EMSCRIPTEN__) + const bool use_xnnpack = true; +#else + const bool use_xnnpack = calculator_opts.has_delegate() && + calculator_opts.delegate().has_xnnpack(); +#endif // defined(__EMSCRIPTEN__) + + if (use_xnnpack) { + TfLiteXNNPackDelegateOptions xnnpack_opts{}; + xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts); + delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), + &TfLiteXNNPackDelegateDelete); + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + } + + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc new file mode 100644 index 000000000..5fb7c974a --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc @@ -0,0 +1,186 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/graph_test_base.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/subgraph_expansion.h" +#include "mediapipe/framework/tool/test_util.h" + +namespace mediapipe { +namespace api2 { +namespace { + +using mediapipe::Detection; +using mediapipe::InferenceCalculatorOptions_Delegate; +using testing::ElementsAre; +using testing::EqualsProto; +using testing::proto::Approximately; + +struct Param { + std::string name; // Appended to the test name. + std::string impl_suffix; // Expected InferenceCalculator backend. + InferenceCalculatorOptions_Delegate delegate; +}; + +const std::vector& GetParams() { + static auto all_params = [] { + static std::vector p; + p.push_back({"TfLite", "Cpu"}); + p.back().delegate.mutable_tflite(); +#if TARGET_OS_IPHONE && !TARGET_IPHONE_SIMULATOR + // Metal is not available on the iOS simulator. + p.push_back({"Metal", "Metal"}); + p.back().delegate.mutable_gpu(); +#endif // TARGET_IPHONE_SIMULATOR +#if __EMSCRIPTEN__ + p.push_back({"MlDrift", "MlDrift"}); + p.back().delegate.mutable_gpu(); +#endif // __EMSCRIPTEN__ +#if __ANDROID__ && 0 // Disabled for now since emulator can't go GLESv3 + p.push_back({"Gl", "Gl"}); + p.back().delegate.mutable_gpu(); + // This requires API level 27 + p.push_back({"NnApi", "Cpu"}); + p.back().delegate.mutable_nnapi(); +#endif // __ANDROID__ + p.push_back({"XnnPack", "Cpu"}); + p.back().delegate.mutable_xnnpack(); + return p; + }(); + return all_params; +} + +class InferenceCalculatorTest : public testing::TestWithParam { + protected: +#if __EMSCRIPTEN__ + // TODO: fix Tensor locking. + // The MlDrift backend currently fails in debug mode without this, + // because of Tensor locking issues. I am adding this temporarily since + // the calculator is already being used and it's better to have test + // coverage for it. Also, the issue doesn't apply to our Emscripten + // build in practice since it's single-threaded. 
+ void SetUp(void) override { + absl::SetMutexDeadlockDetectionMode(absl::OnDeadlockCycle::kIgnore); + } +#endif // __EMSCRIPTEN__ + + void SetDelegateForParam(mediapipe::CalculatorGraphConfig_Node* node) { + *node->mutable_options() + ->MutableExtension(mediapipe::InferenceCalculatorOptions::ext) + ->mutable_delegate() = GetParam().delegate; + } +}; + +TEST_P(InferenceCalculatorTest, TestBackendSelection) { + CalculatorGraphConfig config; + auto node = config.add_node(); + node->set_calculator("InferenceCalculator"); + SetDelegateForParam(node); + MP_ASSERT_OK(tool::ExpandSubgraphs(&config)); + EXPECT_EQ(config.node(0).calculator(), + absl::StrCat("InferenceCalculator", GetParam().impl_suffix)); +} + +TEST_P(InferenceCalculatorTest, TestFaceDetection) { + CalculatorGraphConfig config; + ASSERT_TRUE(LoadTestGraph( + &config, file::JoinPath(GetTestRootDir(), + "mediapipe/calculators/tensor/" + "testdata/face_detection_test.binarypb"))); + + // Expand subgraphs to find any nested instances of InferenceCalculator. + MP_ASSERT_OK(tool::ExpandSubgraphs(&config)); + int found = 0; + for (auto& node : *config.mutable_node()) { + // The InferenceCalculator subgraph itself will have expanded to a specific + // implementation. Replace it. + // TODO: make it possible to exclude it from expansion above. + if (absl::StartsWith(node.calculator(), "InferenceCalculator")) { + ++found; + node.set_calculator("InferenceCalculator"); + SetDelegateForParam(&node); + } + } + ASSERT_EQ(found, 1); + + std::vector detection_packets; + tool::AddVectorSink("detections", &config, &detection_packets); + std::vector rendering_packets; + tool::AddVectorSink("rendering", &config, &rendering_packets); + + // Load test image. + std::unique_ptr input_image = LoadTestPng( + file::JoinPath(GetTestRootDir(), "mediapipe/objc/testdata/sergey.png")); + ASSERT_THAT(input_image, testing::NotNull()); + + std::unique_ptr expected_image = + LoadTestPng(file::JoinPath(GetTestRootDir(), + "mediapipe/calculators/tensor/" + "testdata/face_detection_expected.png")); + ASSERT_THAT(expected_image, testing::NotNull()); + + std::string binary; + Detection expected_detection; + MP_ASSERT_OK( + file::GetContents(file::JoinPath(GetTestRootDir(), + "mediapipe/calculators/tensor/" + "testdata/expected_detection.binarypb"), + &binary)); + expected_detection.ParseFromArray(binary.data(), binary.size()); + + // Prepare test inputs. + std::unordered_map> input_streams; + input_streams.insert(std::make_pair("image", std::move(input_image))); + std::string output_stream = "rendering"; + + // Test graph with relaxed color difference tolerance. + // Compare with CPU generated image. + Timestamp ts0 = Timestamp(0); + TestGraphConfig(config, input_streams, output_stream, expected_image, {}, ts0, + 2.0, 2.0, 1.0); + + ASSERT_EQ(detection_packets.size(), 1); + std::vector dets = + detection_packets[0].Get>(); +#if !defined(MEDIAPIPE_PROTO_LITE) + // Approximately is not available with lite protos (b/178137094). 
+ EXPECT_THAT(dets, + ElementsAre(Approximately(EqualsProto(expected_detection)))); +#endif +} + +INSTANTIATE_TEST_SUITE_P(Implementation, InferenceCalculatorTest, + testing::ValuesIn(GetParams()), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +} // namespace +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc new file mode 100644 index 000000000..081b12d3c --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -0,0 +1,368 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/util/tflite/config.h" + +#if MEDIAPIPE_TFLITE_GL_INFERENCE +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/util/tflite/tflite_gpu_runner.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#endif // MEDIAPIPE_TFLITE_GL_INFERENCE + +#if defined(MEDIAPIPE_ANDROID) +#include "mediapipe/util/android/file/base/file.h" +#include "mediapipe/util/android/file/base/filesystem.h" +#include "mediapipe/util/android/file/base/helpers.h" +#endif // ANDROID + +namespace mediapipe { +namespace api2 { + +class InferenceCalculatorGlImpl + : public NodeImpl { + public: + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status ReadKernelsFromFile(); + absl::Status WriteKernelsToFile(); + absl::Status LoadModel(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status InitTFLiteGPURunner(CalculatorContext* cc); + + // TfLite requires us to keep the model alive as long as the interpreter is. 
+ Packet model_packet_; + std::unique_ptr interpreter_; + TfLiteDelegatePtr delegate_; + +#if MEDIAPIPE_TFLITE_GL_INFERENCE + mediapipe::GlCalculatorHelper gpu_helper_; + std::unique_ptr tflite_gpu_runner_; + bool allow_precision_loss_ = false; + mediapipe::InferenceCalculatorOptions::Delegate::Gpu::API + tflite_gpu_runner_api_; +#endif // MEDIAPIPE_TFLITE_GL_INFERENCE + +#if MEDIAPIPE_TFLITE_GPU_SUPPORTED + std::vector output_shapes_; + std::vector> gpu_buffers_in_; + std::vector> gpu_buffers_out_; +#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED + + bool use_advanced_gpu_api_ = false; + bool use_gpu_delegate_ = false; + + bool use_kernel_caching_ = false; + std::string cached_kernel_filename_; +}; + +absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) + << "Either model as side packet or model path in options is required."; + + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + use_advanced_gpu_api_ = options.has_delegate() && + options.delegate().has_gpu() && + options.delegate().gpu().use_advanced_gpu_api(); + allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); + tflite_gpu_runner_api_ = options.delegate().gpu().api(); + use_kernel_caching_ = use_advanced_gpu_api_ && + options.delegate().gpu().has_cached_kernel_path(); + use_gpu_delegate_ = !use_advanced_gpu_api_; + + if (use_kernel_caching_) { +#ifdef MEDIAPIPE_ANDROID + cached_kernel_filename_ = options.delegate().gpu().cached_kernel_path() + + mediapipe::File::Basename(options.model_path()) + + ".ker"; +#endif // MEDIAPIPE_ANDROID + } + + // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner + // for everything. + if (!use_advanced_gpu_api_) { + MP_RETURN_IF_ERROR(LoadModel(cc)); + } + + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, + &cc]() -> ::mediapipe::Status { + return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) : LoadDelegate(cc); + })); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + return absl::OkStatus(); + } + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + auto output_tensors = absl::make_unique>(); + + if (use_advanced_gpu_api_) { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors, &output_tensors]() -> ::mediapipe::Status { + for (int i = 0; i < input_tensors.size(); ++i) { + MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( + input_tensors[i].GetOpenGlBufferReadView().name(), i)); + } + output_tensors->reserve(output_shapes_.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + output_shapes_[i]); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor( + output_tensors->back().GetOpenGlBufferWriteView().name(), i)); + } + return absl::OkStatus(); + })); + } else { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors]() -> ::mediapipe::Status { + // Explicitly copy input. 
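+          // Each input tensor's GL buffer is copied into the matching
+          // gpu_buffers_in_ SSBO (bound to the delegate in LoadDelegate)
+          // with glCopyBufferSubData.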
+ for (int i = 0; i < input_tensors.size(); ++i) { + glBindBuffer(GL_COPY_READ_BUFFER, + input_tensors[i].GetOpenGlBufferReadView().name()); + glBindBuffer(GL_COPY_WRITE_BUFFER, + gpu_buffers_in_[i]->GetOpenGlBufferWriteView().name()); + glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, + input_tensors[i].bytes()); + } + return absl::OkStatus(); + })); + } + + // Run inference. + if (use_advanced_gpu_api_) { + RET_CHECK(tflite_gpu_runner_->Invoke().ok()); + } else { + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + } + + if (use_gpu_delegate_) { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &output_tensors]() -> ::mediapipe::Status { + output_tensors->reserve(output_shapes_.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + const auto& t = gpu_buffers_out_[i]; + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + gpu_buffers_out_[i]->shape()); + auto read_view = t->GetOpenGlBufferReadView(); + glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); + auto write_view = output_tensors->back().GetOpenGlBufferWriteView(); + glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); + glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, + t->bytes()); + } + return absl::OkStatus(); + })); + } + // Output tensors are already bound if use_advanced_gpu_api_ is true. + + kOutTensors(cc).Send(std::move(output_tensors)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() { +#ifdef MEDIAPIPE_ANDROID + if (use_kernel_caching_) { + // Save kernel file. + auto kernel_cache = absl::make_unique>( + tflite_gpu_runner_->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + MP_RETURN_IF_ERROR( + mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); + } +#endif // MEDIAPIPE_ANDROID + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(WriteKernelsToFile()); + if (use_gpu_delegate_) { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { + gpu_buffers_in_.clear(); + gpu_buffers_out_.clear(); + return absl::OkStatus(); + })); + } + + interpreter_ = nullptr; + delegate_ = nullptr; + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::ReadKernelsFromFile() { +#ifdef MEDIAPIPE_ANDROID + if (use_kernel_caching_) { + // Load pre-compiled kernel file. + if (mediapipe::File::Exists(cached_kernel_filename_)) { + std::string cache_str; + MP_RETURN_IF_ERROR( + mediapipe::file::GetContents(cached_kernel_filename_, &cache_str)); + std::vector cache_vec(cache_str.begin(), cache_str.end()); + tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec)); + } + } +#endif // MEDIAPIPE_ANDROID + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( + CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); + + // Create runner + tflite::gpu::InferenceOptions options; + options.priority1 = allow_precision_loss_ + ? 
tflite::gpu::InferencePriority::MIN_LATENCY + : tflite::gpu::InferencePriority::MAX_PRECISION; + options.priority2 = tflite::gpu::InferencePriority::AUTO; + options.priority3 = tflite::gpu::InferencePriority::AUTO; + options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + tflite_gpu_runner_ = std::make_unique(options); + switch (tflite_gpu_runner_api_) { + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: { + tflite_gpu_runner_->ForceOpenGL(); + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: { + tflite_gpu_runner_->ForceOpenCL(); + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: { + // Do not need to force any specific API. + break; + } + } + MP_RETURN_IF_ERROR( + tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); + + // Create and bind OpenGL buffers for outputs. + // The buffers are created once and their ids are passed to calculator outputs + output_shapes_.resize(tflite_gpu_runner_->outputs_size()); + for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) { + output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b, + tflite_gpu_runner_->GetOutputShapes()[i].h, + tflite_gpu_runner_->GetOutputShapes()[i].w, + tflite_gpu_runner_->GetOutputShapes()[i].c}; + } + + MP_RETURN_IF_ERROR(ReadKernelsFromFile()); + + MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); + + tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + RET_CHECK(interpreter_); + +#if defined(__EMSCRIPTEN__) + interpreter_->SetNumThreads(1); +#else + interpreter_->SetNumThreads( + cc->Options().cpu_num_thread()); +#endif // __EMSCRIPTEN__ + + RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + // TODO: Support quantized tensors. + CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != + kTfLiteAffineQuantization); + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { + // Configure and create the delegate. + TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); + options.compile_options.precision_loss_allowed = 1; + options.compile_options.preferred_gl_object_type = + TFLITE_GL_OBJECT_TYPE_FASTEST; + options.compile_options.dynamic_batch_enabled = 0; + options.compile_options.inline_parameters = 1; + delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options), + &TfLiteGpuDelegateDelete); + + // Get input image sizes. + const auto& input_indices = interpreter_->inputs(); + for (int i = 0; i < input_indices.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + gpu_buffers_in_.emplace_back(absl::make_unique( + Tensor::ElementType::kFloat32, + Tensor::Shape{std::vector{ + tensor->dims->data, tensor->dims->data + tensor->dims->size}})); + RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( + delegate_.get(), + gpu_buffers_in_.back()->GetOpenGlBufferWriteView().name(), + interpreter_->inputs()[i]), + kTfLiteOk); + } + interpreter_->SetAllowBufferHandleOutput(true); + // Get output image sizes. + const auto& output_indices = interpreter_->outputs(); + output_shapes_.resize(output_indices.size()); + // Create and bind output buffers. 
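+  // One float32 Tensor is created per model output and its OpenGL buffer is
+  // bound to the delegate, so inference results are written directly into
+  // these buffers.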
+ for (int i = 0; i < output_shapes_.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + gpu_buffers_out_.emplace_back(absl::make_unique( + Tensor::ElementType::kFloat32, + Tensor::Shape{std::vector{ + tensor->dims->data, tensor->dims->data + tensor->dims->size}})); + RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( + delegate_.get(), + gpu_buffers_out_.back()->GetOpenGlBufferWriteView().name(), + output_indices[i]), + kTfLiteOk); + } + + // Must call this last. + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc new file mode 100644 index 000000000..490189aec --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -0,0 +1,293 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import +#import + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/inference_calculator.h" +#import "mediapipe/gpu/MPPMetalHelper.h" +#include "mediapipe/gpu/MPPMetalUtil.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/util/tflite/config.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h" +#include "tensorflow/lite/delegates/gpu/metal_delegate.h" +#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h" + +namespace { + +// Round up n to next multiple of m. +template +T RoundUp(T n, T m) { + return ((n + m - T{1}) / m) * m; +} + +} // namespace + +namespace mediapipe { +namespace api2 { + +#if MEDIAPIPE_TFLITE_METAL_INFERENCE +namespace { +tflite::gpu::BHWC BhwcFromTensorShape(const Tensor::Shape& shape) { + tflite::gpu::BHWC result; + result.b = shape.dims[0]; + switch (shape.dims.size()) { + case 1: + // result.b is already filled. + break; + case 2: + result.h = 1; + result.w = 1; + result.c = shape.dims[1]; + break; + case 3: + result.h = 1; + result.w = shape.dims[1]; + result.c = shape.dims[2]; + break; + case 4: + result.h = shape.dims[1]; + result.w = shape.dims[2]; + result.c = shape.dims[3]; + break; + default: + // Handles 0 and >4. 
+ LOG(FATAL) + << "Dimensions size must be in range [1,4] for GPU inference, but " + << shape.dims.size() << " is provided"; + } + return result; +} +} // namespace +#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE + +class InferenceCalculatorMetalImpl + : public NodeImpl { + public: + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status LoadModel(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc); + + // TfLite requires us to keep the model alive as long as the interpreter is. + Packet model_packet_; + std::unique_ptr interpreter_; + TfLiteDelegatePtr delegate_; + +#if MEDIAPIPE_TFLITE_METAL_INFERENCE + MPPMetalHelper* gpu_helper_ = nullptr; + TFLBufferConvert* converter_to_BPHWC4_ = nil; + TFLBufferConvert* converter_from_BPHWC4_ = nil; +#endif // MEDIAPIPE_TFLITE_GL_INFERENCE + +#if MEDIAPIPE_TFLITE_GPU_SUPPORTED + std::vector output_shapes_; + std::vector> gpu_buffers_in_; + std::vector> gpu_buffers_out_; +#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED +}; + +absl::Status InferenceCalculatorMetalImpl::UpdateContract( + CalculatorContract* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) + << "Either model as side packet or model path in options is required."; + + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadModel(cc)); + + gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(gpu_helper_); + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + return absl::OkStatus(); + } + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + auto output_tensors = absl::make_unique>(); + + id command_buffer; + + command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"InferenceCalculator"; + // Explicit copy input with conversion float 32 bits to 16 bits. + for (int i = 0; i < input_tensors.size(); ++i) { + auto input_view = input_tensors[i].GetMtlBufferReadView(command_buffer); + // Reshape tensor. + tflite::gpu::BHWC shape = BhwcFromTensorShape(input_tensors[i].shape()); + auto gpu_buffer_view = + gpu_buffers_in_[i]->GetMtlBufferWriteView(command_buffer); + id input_encoder = + [command_buffer computeCommandEncoder]; + [converter_to_BPHWC4_ convertWithEncoder:input_encoder + shape:shape + sourceBuffer:input_view.buffer() + convertedBuffer:gpu_buffer_view.buffer()]; + [input_encoder endEncoding]; + } + + // Run inference. + RET_CHECK(TFLGpuDelegateSetCommandBuffer(delegate_.get(), command_buffer)); + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + + output_tensors->reserve(output_shapes_.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + output_shapes_[i]); + // Reshape tensor. 
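+    // Convert the MediaPipe tensor shape to tflite::gpu::BHWC so the Metal
+    // buffer converter knows the layout of the source data.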
+ tflite::gpu::BHWC shape = BhwcFromTensorShape(output_shapes_[i]); + auto read_view = gpu_buffers_out_[i]->GetMtlBufferReadView(command_buffer); + auto write_view = + output_tensors->at(i).GetMtlBufferWriteView(command_buffer); + id output_encoder = + [command_buffer computeCommandEncoder]; + [converter_from_BPHWC4_ convertWithEncoder:output_encoder + shape:shape + sourceBuffer:read_view.buffer() + convertedBuffer:write_view.buffer()]; + [output_encoder endEncoding]; + } + [command_buffer commit]; + + kOutTensors(cc).Send(std::move(output_tensors)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::Close(CalculatorContext* cc) { + converter_to_BPHWC4_ = nil; + converter_from_BPHWC4_ = nil; + gpu_buffers_in_.clear(); + gpu_buffers_out_.clear(); + interpreter_ = nullptr; + delegate_ = nullptr; + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); + + tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + RET_CHECK(interpreter_); + + interpreter_->SetNumThreads( + cc->Options().cpu_num_thread()); + + RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + // TODO: Support quantized tensors. + CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != + kTfLiteAffineQuantization); + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { + const auto& calculator_opts = + cc->Options(); + + // Configure and create the delegate. + TFLGpuDelegateOptions options; + options.allow_precision_loss = true; + options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait; + delegate_ = + TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + id device = gpu_helper_.mtlDevice; + + // Get input image sizes. + const auto& input_indices = interpreter_->inputs(); + for (int i = 0; i < input_indices.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + // Create and bind input buffer. + std::vector dims{tensor->dims->data, + tensor->dims->data + tensor->dims->size}; + dims.back() = RoundUp(dims.back(), 4); + gpu_buffers_in_.emplace_back(absl::make_unique( + Tensor::ElementType::kFloat16, Tensor::Shape{dims})); + auto buffer_view = + gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice); + RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( + delegate_.get(), input_indices[i], buffer_view.buffer()), + true); + } + + interpreter_->SetAllowBufferHandleOutput(true); + // Get output image sizes. + const auto& output_indices = interpreter_->outputs(); + output_shapes_.resize(output_indices.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + RET_CHECK(tensor->dims->size <= 4); + // Create and bind output buffers. + // Channels are always padded to multiple of 4. 
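+    // The channel dimension is rounded up to a multiple of 4 so the float16
+    // buffer matches the padded layout the Metal delegate writes into.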
+ std::vector dims{tensor->dims->data, + tensor->dims->data + tensor->dims->size}; + output_shapes_[i] = {dims}; + dims.back() = RoundUp(dims.back(), 4); + gpu_buffers_out_.emplace_back(absl::make_unique( + Tensor::ElementType::kFloat16, Tensor::Shape{dims})); + RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( + delegate_.get(), output_indices[i], + gpu_buffers_out_[i] + ->GetMtlBufferWriteView(gpu_helper_.mtlDevice) + .buffer()), + true); + } + + // Create converter for GPU input. + converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device + isFloat16:true + convertToPBHWC4:true]; + if (converter_to_BPHWC4_ == nil) { + return mediapipe::InternalError( + "Error initializating input buffer converter"); + } + // Create converter for GPU output. + converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device + isFloat16:true + convertToPBHWC4:false]; + if (converter_from_BPHWC4_ == nil) { + return absl::InternalError("Error initializating output buffer converter"); + } + + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_test.cc b/mediapipe/calculators/tensor/inference_calculator_test.cc new file mode 100644 index 000000000..882a5e81e --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_test.cc @@ -0,0 +1,162 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +#ifdef __APPLE__ +#include +#endif // defined(__APPLE__) + +namespace mediapipe { + +using ::tflite::Interpreter; + +void DoSmokeTest(const std::string& graph_proto) { + const int width = 8; + const int height = 8; + const int channels = 3; + // Prepare input tensor. + auto input_vec = absl::make_unique>(); + input_vec->emplace_back(Tensor::ElementType::kFloat32, + Tensor::Shape{1, height, width, channels}); + { + auto view1 = input_vec->back().GetCpuWriteView(); + auto tensor_buffer = view1.buffer(); + ASSERT_NE(tensor_buffer, nullptr); + for (int i = 0; i < width * height * channels - 1; i++) { + tensor_buffer[i] = 1; + } + } + + // Prepare single calculator graph to and wait for packets. 
+ CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(graph_proto); + std::vector output_packets; + tool::AddVectorSink("tensor_out", &graph_config, &output_packets); + CalculatorGraph graph(graph_config); + MP_ASSERT_OK(graph.StartRun({})); + + // Push the tensor into the graph. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "tensor_in", Adopt(input_vec.release()).At(Timestamp(0)))); + // Wait until the calculator done processing. + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, output_packets.size()); + + // Get and process results. + const std::vector& result_vec = + output_packets[0].Get>(); + ASSERT_EQ(1, result_vec.size()); + + const Tensor& result = result_vec[0]; + auto view = result.GetCpuReadView(); + auto result_buffer = view.buffer(); + ASSERT_NE(result_buffer, nullptr); + for (int i = 0; i < width * height * channels - 1; i++) { + ASSERT_EQ(3, result_buffer[i]); + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("tensor_in")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// Tests a simple add model that adds an input tensor to itself. +TEST(InferenceCalculatorTest, SmokeTest) { + std::string graph_proto = R"( + input_stream: "tensor_in" + node { + calculator: "InferenceCalculator" + input_stream: "TENSORS:tensor_in" + output_stream: "TENSORS:tensor_out" + options { + [mediapipe.InferenceCalculatorOptions.ext] { + model_path: "mediapipe/calculators/tensor/testdata/add.bin" + $delegate + } + } + } + )"; + // Test CPU inference only. + DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + graph_proto, {{"$delegate", "delegate { tflite {} }"}})); + DoSmokeTest(absl::StrReplaceAll(graph_proto, + {{"$delegate", "delegate { xnnpack {} }"}})); + DoSmokeTest(absl::StrReplaceAll( + graph_proto, + {{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}})); +} + +TEST(InferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) { + std::string graph_proto = R"( + input_stream: "tensor_in" + + node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:model_path" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { string_value: "mediapipe/calculators/tensor/testdata/add.bin" } + } + } + } + + node { + calculator: "LocalFileContentsCalculator" + input_side_packet: "FILE_PATH:model_path" + output_side_packet: "CONTENTS:model_blob" + } + + node { + calculator: "TfLiteModelCalculator" + input_side_packet: "MODEL_BLOB:model_blob" + output_side_packet: "MODEL:model" + } + + node { + calculator: "InferenceCalculator" + input_stream: "TENSORS:tensor_in" + output_stream: "TENSORS:tensor_out" + input_side_packet: "MODEL:model" + options { + [mediapipe.InferenceCalculatorOptions.ext] { + delegate { tflite {} } + } + } + } + )"; + DoSmokeTest(graph_proto); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc new file mode 100644 index 000000000..82180fe52 --- /dev/null +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc @@ -0,0 +1,671 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/resource_util.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_buffer.h" +#if MEDIAPIPE_METAL_ENABLED +#import +#import +#import + +#import "mediapipe/gpu/MPPMetalHelper.h" +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 +#include "mediapipe/gpu/gl_calculator_helper.h" +#if MEDIAPIPE_OPENGL_ES_VERSION < MEDIAPIPE_OPENGL_ES_31 +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" +#endif // MEDIAPIPE_OPENGL_ES_VERSION < MEDIAPIPE_OPENGL_ES_31 +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace { +constexpr int kWorkgroupSize = 8; // Block size for GPU shader. +// Commonly used to compute the number of blocks to launch in a kernel. +int NumGroups(const int size, const int group_size) { // NOLINT + return (size + group_size - 1) / group_size; +} + +typedef Eigen::Matrix + RowMajorMatrixXf; +typedef Eigen::Matrix + ColMajorMatrixXf; + +constexpr char kImageFrameTag[] = "IMAGE"; +constexpr char kGpuBufferTag[] = "IMAGE_GPU"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kMatrixTag[] = "MATRIX"; +} // namespace + +namespace mediapipe { + +// Calculator for normalizing and converting an ImageFrame, GpuBuffer or Matrix +// into a Tensor. +// +// This calculator is designed to be used with the TfLiteInferenceCalcualtor, +// as a pre-processing step for calculator inputs. +// +// IMAGE and IMAGE_GPU inputs are normalized to [-1,1] (default) or [0,1], +// specified by options (unless outputting a quantized tensor). +// +// Input: +// One of the following tags: +// IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data). +// IMAGE_GPU - GpuBuffer (assumed to be RGBA or RGB GL texture). +// MATRIX - Matrix. +// +// Output: +// One of the following tags: +// TENSORS - Vector of Tensors of type kFloat32. The resource type used: +// - MTLBuffer if Metal API is available +// - SSBO if Metal is unavailable and OpenGL ES 3.1 is available +// - Texture2D if Metal and GLES 3.1 are not available and GLES 3.0 is. +// +// Example use: +// node { +// calculator: "TensorConverterCalculator" +// input_stream: "IMAGE:input_image" +// output_stream: "TENSORS:image_tensor" +// options: { +// [mediapipe.TensorConverterCalculatorOptions.ext] { +// zero_center: true +// } +// } +// } +// +// IMPORTANT Notes: +// GPU tensors are currently only supported on mobile platforms. 
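+//
+// [Editor's note: illustrative sketch, not part of the original change.]
+// The normalization applied to IMAGE / IMAGE_GPU inputs reduces to a per-pixel
+// scale and bias. Assuming 8-bit input values in [0, 255]:
+//
+//   zero_center: true   ->  out = in / 255.0 * 2.0 - 1.0   (range [-1, 1])
+//   zero_center: false  ->  out = in / 255.0                (range [0, 1])
+//   output_tensor_float_range { min max }
+//                       ->  out = in / 255.0 * (max - min) + min
+//
+// e.g. a pixel value of 200 with output_tensor_float_range { min: -0.5
+// max: 0.5 } maps to 200 / 255 * 1.0 - 0.5, approximately 0.284.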
+ +class TensorConverterCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status InitGpu(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); + template + absl::Status NormalizeImage(const ImageFrame& image_frame, + bool flip_vertically, float* tensor_ptr); + absl::Status CopyMatrixToTensor(const Matrix& matrix, float* tensor_ptr); + absl::Status ProcessCPU(CalculatorContext* cc); + absl::Status ProcessGPU(CalculatorContext* cc); + +#if MEDIAPIPE_METAL_ENABLED + MPPMetalHelper* gpu_helper_ = nullptr; + id to_buffer_program_; +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + mediapipe::GlCalculatorHelper gpu_helper_; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + GLuint to_buffer_program_; +#else + enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; + GLuint to_tex2d_program_; + GLuint framebuffer_; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#endif // MEDIAPIPE_METAL_ENABLED + + bool initialized_ = false; + bool use_gpu_ = false; + absl::optional> output_range_; + bool flip_vertically_ = false; + bool row_major_matrix_ = false; + int max_num_channels_ = 3; +}; +REGISTER_CALCULATOR(TensorConverterCalculator); + +absl::Status TensorConverterCalculator::GetContract(CalculatorContract* cc) { + // Confirm only one of the input streams is present. + RET_CHECK(static_cast(cc->Inputs().HasTag(kImageFrameTag)) + + static_cast(cc->Inputs().HasTag(kGpuBufferTag)) + + static_cast(cc->Inputs().HasTag(kMatrixTag)) == + 1); + + if (cc->Inputs().HasTag(kImageFrameTag)) { + cc->Inputs().Tag(kImageFrameTag).Set(); + } + if (cc->Inputs().HasTag(kMatrixTag)) { + cc->Inputs().Tag(kMatrixTag).Set(); + } + +#if !MEDIAPIPE_DISABLE_GPU + if (cc->Inputs().HasTag(kGpuBufferTag)) { + cc->Inputs().Tag(kGpuBufferTag).Set(); +#if MEDIAPIPE_METAL_ENABLED + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // MEDIAPIPE_METAL_ENABLED + } +#endif // !MEDIAPIPE_DISABLE_GPU + + RET_CHECK(cc->Outputs().HasTag(kTensorsTag)); + cc->Outputs().Tag(kTensorsTag).Set>(); + return absl::OkStatus(); +} + +absl::Status TensorConverterCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + MP_RETURN_IF_ERROR(LoadOptions(cc)); + +#if !MEDIAPIPE_DISABLE_GPU + if (cc->Inputs().HasTag(kGpuBufferTag)) { + use_gpu_ = true; +#if MEDIAPIPE_METAL_ENABLED + gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(gpu_helper_); +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif // MEDIAPIPE_METAL_ENABLED + } +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status TensorConverterCalculator::Process(CalculatorContext* cc) { + if (use_gpu_) { + if (cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { + return absl::OkStatus(); + } + // Convert to GPU tensors type. + MP_RETURN_IF_ERROR(ProcessGPU(cc)); + } else { + // Convert to CPU tensors or Matrix type. 
+ MP_RETURN_IF_ERROR(ProcessCPU(cc)); + } + return absl::OkStatus(); +} + +absl::Status TensorConverterCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + if (use_gpu_) { +#if MEDIAPIPE_METAL_ENABLED + to_buffer_program_ = nil; +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + gpu_helper_.RunInGlContext([this] { +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + glDeleteProgram(to_buffer_program_); +#else + glDeleteFramebuffers(1, &framebuffer_); + glDeleteProgram(to_tex2d_program_); +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + }); +#endif // MEDIAPIPE_METAL_ENABLED + } +#endif // !MEDIAPIPE_DISABLE_GPU + return absl::OkStatus(); +} + +absl::Status TensorConverterCalculator::ProcessCPU(CalculatorContext* cc) { + auto output_tensors = absl::make_unique>(); + if (cc->Inputs().HasTag(kImageFrameTag)) { + if (cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { + return absl::OkStatus(); + } + const auto& image_frame = + cc->Inputs().Tag(kImageFrameTag).Get(); + const int height = image_frame.Height(); + const int width = image_frame.Width(); + const int channels = image_frame.NumberOfChannels(); + const int channels_preserved = std::min(channels, max_num_channels_); + const mediapipe::ImageFormat::Format format = image_frame.Format(); + + if (!(format == mediapipe::ImageFormat::SRGBA || + format == mediapipe::ImageFormat::SRGB || + format == mediapipe::ImageFormat::GRAY8 || + format == mediapipe::ImageFormat::VEC32F1)) + RET_CHECK_FAIL() << "Unsupported CPU input format."; + + output_tensors->emplace_back( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, height, width, channels_preserved}); + auto cpu_view = output_tensors->back().GetCpuWriteView(); + + // Copy image data into tensor. + if (image_frame.ByteDepth() == 1) { + MP_RETURN_IF_ERROR(NormalizeImage(image_frame, flip_vertically_, + cpu_view.buffer())); + } else if (image_frame.ByteDepth() == 4) { + MP_RETURN_IF_ERROR(NormalizeImage(image_frame, flip_vertically_, + cpu_view.buffer())); + } else { + return absl::InternalError( + "Only byte-based (8 bit) and float (32 bit) images supported."); + } + } else if (cc->Inputs().HasTag(kMatrixTag)) { + if (cc->Inputs().Tag(kMatrixTag).IsEmpty()) { + return absl::OkStatus(); + } + const auto& matrix = cc->Inputs().Tag(kMatrixTag).Get(); + const int height = matrix.rows(); + const int width = matrix.cols(); + const int channels = 1; + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + Tensor::Shape{1, height, width, channels}); + MP_RETURN_IF_ERROR(CopyMatrixToTensor( + matrix, output_tensors->back().GetCpuWriteView().buffer())); + } else { + return absl::OkStatus(); + } + cc->Outputs() + .Tag(kTensorsTag) + .Add(output_tensors.release(), cc->InputTimestamp()); + + return absl::OkStatus(); +} + +absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + if (!initialized_) { + MP_RETURN_IF_ERROR(InitGpu(cc)); + initialized_ = true; + } + const auto& input = + cc->Inputs().Tag(kGpuBufferTag).Get(); + int width = input.width(); + int height = input.height(); + int channels = max_num_channels_; + auto output_tensors = absl::make_unique>(); + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + Tensor::Shape{1, height, width, channels}); +#if MEDIAPIPE_METAL_ENABLED + id device = gpu_helper_.mtlDevice; + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TensorConverterCalculatorConvert"; + id compute_encoder = + [command_buffer 
computeCommandEncoder]; + [compute_encoder setComputePipelineState:to_buffer_program_]; + id src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input]; + [compute_encoder setTexture:src_texture atIndex:0]; + auto output_view = + output_tensors->at(0).GetMtlBufferWriteView(command_buffer); + [compute_encoder setBuffer:output_view.buffer() offset:0 atIndex:1]; + MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1); + MTLSize threadgroups = + MTLSizeMake(NumGroups(input.width(), kWorkgroupSize), + NumGroups(input.height(), kWorkgroupSize), 1); + [compute_encoder dispatchThreadgroups:threadgroups + threadsPerThreadgroup:threads_per_group]; + [compute_encoder endEncoding]; + [command_buffer commit]; +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &output_tensors, &input]() -> absl::Status { + auto src = gpu_helper_.CreateSourceTexture(input); +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + // Convert GL texture into SSBO. + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, src.name()); + auto output_view = output_tensors->back().GetOpenGlBufferWriteView(); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, output_view.name()); + glUseProgram(to_buffer_program_); + glDispatchCompute(NumGroups(input.width(), kWorkgroupSize), + NumGroups(input.height(), kWorkgroupSize), 1); + glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); + glBindTexture(GL_TEXTURE_2D, 0); +#else + // Texture2D -> Texture2D with OpenGL ES 3.0. + glUseProgram(to_tex2d_program_); + glDisable(GL_DEPTH_TEST); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glViewport(0, 0, src.width(), src.height()); + glActiveTexture(GL_TEXTURE0); + auto output_view = output_tensors->back().GetOpenGlTexture2dWriteView(); + glBindTexture(GL_TEXTURE_2D, output_view.name()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + GL_TEXTURE_2D, output_view.name(), 0); + glActiveTexture(GL_TEXTURE1); + glBindTexture(src.target(), src.name()); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, + mediapipe::kBasicSquareVertices); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, + mediapipe::kBasicTextureVertices); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, 0); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + src.Release(); + return absl::OkStatus(); + })); +#endif // MEDIAPIPE_METAL_ENABLED + cc->Outputs() + .Tag(kTensorsTag) + .Add(output_tensors.release(), cc->InputTimestamp()); +#else + RET_CHECK_FAIL() << "GPU processing is not enabled."; +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + // Get input image sizes. 
+ const auto& input = + cc->Inputs().Tag(kGpuBufferTag).Get(); + mediapipe::ImageFormat::Format format = + mediapipe::ImageFormatForGpuBufferFormat(input.format()); + const bool include_alpha = (max_num_channels_ == 4); + const bool single_channel = (max_num_channels_ == 1); + if (!(format == mediapipe::ImageFormat::GRAY8 || + format == mediapipe::ImageFormat::SRGB || + format == mediapipe::ImageFormat::SRGBA)) + RET_CHECK_FAIL() << "Unsupported GPU input format."; + if (include_alpha && (format != mediapipe::ImageFormat::SRGBA)) + RET_CHECK_FAIL() << "Num input channels is less than desired output."; + +#if MEDIAPIPE_METAL_ENABLED + id device = gpu_helper_.mtlDevice; + // Shader to convert GL Texture to Metal Buffer, + // with normalization to either: [0,1] or [-1,1]. + const std::string shader_source = absl::Substitute( + R"( + #include + + using namespace metal; + + kernel void convertKernel( + texture2d in_tex [[ texture(0) ]], + device float* out_buf [[ buffer(1) ]], + uint2 gid [[ thread_position_in_grid ]]) { + if (gid.x >= in_tex.get_width() || gid.y >= in_tex.get_height()) return; + constexpr sampler texture_sampler(coord::pixel, address::clamp_to_edge); + const float2 coord = float2(gid.x, gid.y); + half4 pixel = in_tex.sample(texture_sampler, coord); + $0 // normalize [-1,1] + const int linear_index = $1 * ($2 * in_tex.get_width() + gid.x); + out_buf[linear_index + 0] = pixel.x; + $3 // g & b channels + $4 // alpha channel + } + )", + /*$0=*/ + output_range_.has_value() + ? absl::Substitute("pixel = pixel * half($0) + half($1);", + (output_range_->second - output_range_->first), + output_range_->first) + : "", + /*$1=*/max_num_channels_, + /*$2=*/flip_vertically_ ? "(in_tex.get_height() - 1 - gid.y)" : "gid.y", + /*$3=*/ + single_channel ? "" : R"(out_buf[linear_index + 1] = pixel.y; + out_buf[linear_index + 2] = pixel.z;)", + /*$4=*/include_alpha ? "out_buf[linear_index + 3] = pixel.w;" : ""); + + NSString* library_source = + [NSString stringWithUTF8String:shader_source.c_str()]; + NSError* error = nil; + id library = + [device newLibraryWithSource:library_source options:nullptr error:&error]; + RET_CHECK(library != nil) << "Couldn't create shader library " + << [[error localizedDescription] UTF8String]; + id kernel_func = nil; + kernel_func = [library newFunctionWithName:@"convertKernel"]; + RET_CHECK(kernel_func != nil) << "Couldn't create kernel function."; + to_buffer_program_ = + [device newComputePipelineStateWithFunction:kernel_func error:&error]; + RET_CHECK(to_buffer_program_ != nil) << "Couldn't create pipeline state " << + [[error localizedDescription] UTF8String]; +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &include_alpha, +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + &input, +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + &single_channel]() + -> absl::Status { +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + // Shader to convert GL Texture to Shader Storage Buffer Object (SSBO), + // with normalization to either: [0,1] or [-1,1]. 
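+      // [Editor's note: illustrative comment, not part of the original
+      // change.] The generated shader writes each pixel's channels
+      // contiguously into the SSBO, roughly:
+      //   linear_index = max_num_channels * (y * width + x)
+      // with y replaced by (height - 1 - y) when flip_vertically is set. For
+      // example, with an 8x8 RGB input (max_num_channels = 3), the pixel at
+      // (x = 2, y = 1) starts at element 3 * (1 * 8 + 2) = 30 of the buffer.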
+ const std::string shader_source = absl::Substitute( + R"( #version 310 es + layout(local_size_x = $0, local_size_y = $0) in; + layout(binding = 0) uniform sampler2D input_texture; + layout(std430, binding = 1) buffer Output {float elements[];} output_data; + ivec2 width_height = ivec2($1, $2); + void main() { + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= width_height.x || gid.y >= width_height.y) return; + vec4 pixel = texelFetch(input_texture, gid, 0); + $3 // normalize [-1,1] + int linear_index = $7 * ($4 * width_height.x + gid.x); + output_data.elements[linear_index + 0] = pixel.x; // r channel + $5 // g & b channels + $6 // alpha channel + })", + /*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(), + /*$3=*/ + output_range_.has_value() + ? absl::Substitute("pixel = pixel * float($0) + float($1);", + (output_range_->second - output_range_->first), + output_range_->first) + : "", + /*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y", + /*$5=*/ + single_channel ? "" + : R"(output_data.elements[linear_index + 1] = pixel.y; + output_data.elements[linear_index + 2] = pixel.z;)", + /*$6=*/ + include_alpha ? "output_data.elements[linear_index + 3] = pixel.w;" + : "", + /*$7=*/max_num_channels_); + GLuint shader = glCreateShader(GL_COMPUTE_SHADER); + const GLchar* sources[] = {shader_source.c_str()}; + glShaderSource(shader, 1, sources, NULL); + glCompileShader(shader); + GLint compiled = GL_FALSE; + glGetShaderiv(shader, GL_COMPILE_STATUS, &compiled); + RET_CHECK(compiled == GL_TRUE); + to_buffer_program_ = glCreateProgram(); + glAttachShader(to_buffer_program_, shader); + glDeleteShader(shader); + glLinkProgram(to_buffer_program_); +#else + // OpenGL ES 3.0 fragment shader Texture2d -> Texture2d conversion. + const std::string shader_source = absl::Substitute( + R"( + #if __VERSION__ < 130 + #define in varying + #endif // __VERSION__ < 130 + + #ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; + #else + #define lowp + #define mediump + #define highp + #define texture2D texture + out $0 fragColor; + #endif // defined(GL_ES) + + in vec2 sample_coordinate; + uniform sampler2D frame; + + void main() { + $1 // flip + vec4 pixel = texture2D(frame, sample_coordinate); + $2 // normalize [-1,1] + fragColor.r = pixel.r; // r channel + $3 // g & b channels + $4 // alpha channel + })", + /*$0=*/single_channel ? "vec1" : "vec4", + /*$1=*/ + flip_vertically_ ? "sample_coordinate.y = 1.0 - sample_coordinate.y;" + : "", + /*$2=*/output_range_.has_value() + ? absl::Substitute("pixel = pixel * float($0) + float($1);", + (output_range_->second - output_range_->first), + output_range_->first) + : "", + /*$3=*/single_channel ? "" : R"(fragColor.g = pixel.g; + fragColor.b = pixel.b;)", + /*$4=*/ + include_alpha ? "fragColor.a = pixel.a;" + : (single_channel ? 
"" : "fragColor.a = 1.0;")); + + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + // shader program and params + mediapipe::GlhCreateProgram( + mediapipe::kBasicVertexShader, shader_source.c_str(), NUM_ATTRIBUTES, + &attr_name[0], attr_location, &to_tex2d_program_); + RET_CHECK(to_tex2d_program_) << "Problem initializing the program."; + glUseProgram(to_tex2d_program_); + glUniform1i(glGetUniformLocation(to_tex2d_program_, "frame"), 1); + glGenFramebuffers(1, &framebuffer_); + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + return absl::OkStatus(); + })); +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 +#endif // !MEDIAPIPE_DISABLE_GPU + return absl::OkStatus(); +} + +absl::Status TensorConverterCalculator::LoadOptions(CalculatorContext* cc) { + // Get calculator options specified in the graph. + const auto& options = + cc->Options<::mediapipe::TensorConverterCalculatorOptions>(); + + // if zero_center, set output float range to match [-1, 1] as specified in + // calculator proto. + if (options.zero_center()) { + output_range_.emplace(std::pair(-1.0, 1.0)); + } + + // Custom output_tensor_float_range values. + // If the float range is specified in pb text, use the specified values + // instead. + if (options.has_output_tensor_float_range()) { + output_range_.emplace(options.output_tensor_float_range().min(), + options.output_tensor_float_range().max()); + CHECK_GT(output_range_->second, output_range_->first); + } + + // Custom div and sub values. + if (options.use_custom_normalization()) { + output_range_.emplace(std::pair( + -options.custom_sub(), + -options.custom_sub() + 255.0 / options.custom_div())); + } + + // Get y-flip mode. + flip_vertically_ = options.flip_vertically(); + + // Get row_major_matrix mode. + row_major_matrix_ = options.row_major_matrix(); + + // Get desired way to handle input channels. + max_num_channels_ = options.max_num_channels(); + CHECK_GE(max_num_channels_, 1); + CHECK_LE(max_num_channels_, 4); + CHECK_NE(max_num_channels_, 2); + return absl::OkStatus(); +} + +template +absl::Status TensorConverterCalculator::NormalizeImage( + const ImageFrame& image_frame, bool flip_vertically, float* tensor_ptr) { + const int height = image_frame.Height(); + const int width = image_frame.Width(); + const int channels = image_frame.NumberOfChannels(); + const int channels_preserved = std::min(channels, max_num_channels_); + const int channels_ignored = channels - channels_preserved; + + if (output_range_.has_value()) { + // If the output float range is set and we are not using custom + // normalization, normalize the pixel values from [0, 255] to the specified + // output range. + RET_CHECK_NE(output_range_->first, output_range_->second); + const float scale = (output_range_->second - output_range_->first) / 255.0f; + const float bias = output_range_->first; + + for (int i = 0; i < height; ++i) { + const T* image_ptr = reinterpret_cast( + image_frame.PixelData() + + (flip_vertically ? 
height - 1 - i : i) * image_frame.WidthStep()); + for (int j = 0; j < width; ++j) { + for (int c = 0; c < channels_preserved; ++c) { + *tensor_ptr++ = *image_ptr++ * scale + bias; + } + image_ptr += channels_ignored; + } + } + } else { + // [0,1], scale only (bias == 0) + // Verified that there are no precision issues with 1.0f / 255.0f expression + const float scale = 1.0f / 255.0f; + for (int i = 0; i < height; ++i) { + const T* image_ptr = reinterpret_cast( + image_frame.PixelData() + + (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep()); + for (int j = 0; j < width; ++j) { + for (int c = 0; c < channels_preserved; ++c) { + *tensor_ptr++ = *image_ptr++ * scale; + } + image_ptr += channels_ignored; + } + } + } + + return absl::OkStatus(); +} + +absl::Status TensorConverterCalculator::CopyMatrixToTensor(const Matrix& matrix, + float* tensor_ptr) { + if (row_major_matrix_) { + auto matrix_map = + Eigen::Map(tensor_ptr, matrix.rows(), matrix.cols()); + matrix_map = matrix; + } else { + auto matrix_map = + Eigen::Map(tensor_ptr, matrix.rows(), matrix.cols()); + matrix_map = matrix; + } + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.proto b/mediapipe/calculators/tensor/tensor_converter_calculator.proto new file mode 100644 index 000000000..97c2154a0 --- /dev/null +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.proto @@ -0,0 +1,69 @@ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Full Example: +// +// node { +// calculator: "TensorConverterCalculator" +// input_stream: "IMAGE_IN:input_image" +// output_stream: "TENSOR_OUT:image_tensor" +// options { +// [mediapipe.TensorConverterCalculatorOptions.ext] { +// zero_center: true +// } +// } +// } +// +message TensorConverterCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorConverterCalculatorOptions ext = 335742637; + } + + // Choose normalization mode for output (not applied for Matrix inputs). + // true = [-1,1] + // false = [0,1] + // Ignored if using quantization. + optional bool zero_center = 1 [default = true]; + + // Custom settings to override the internal scaling factors `div` and `sub`. + // Both values must be set to non-negative values. Will only take effect on + // CPU AND when |use_custom_normalization| is set to true. When these custom + // values take effect, the |zero_center| setting above will be overriden, and + // the normalized_value will be calculated as: + // normalized_value = input / custom_div - custom_sub. + optional bool use_custom_normalization = 6 [default = false]; + optional float custom_div = 7 [default = -1.0]; + optional float custom_sub = 8 [default = -1.0]; + + // Whether the input image should be flipped vertically (along the + // y-direction). This is useful, for example, when the input image is defined + // with a coordinate system where the origin is at the bottom-left corner + // (e.g., in OpenGL) whereas the ML model expects an image with a top-left + // origin. + optional bool flip_vertically = 2 [default = false]; + + // Controls how many channels of the input image get passed through to the + // tensor. Valid values are 1,3,4 only. Ignored for iOS GPU. + optional int32 max_num_channels = 3 [default = 3]; + + // The calculator expects Matrix inputs to be in column-major order. Set + // row_major_matrix to true if the inputs are in row-major order. 
+ optional bool row_major_matrix = 4 [default = false]; + + // Quantization option (CPU only). + // When true, output kUint8 tensor instead of kFloat32. + optional bool use_quantized_tensors = 5 [default = false]; + + // Normalization option. + // Setting normalization_range results in the values normalized to + // the range [output_tensor_float_range.min, output_tensor_float_range.max]. + optional TensorFloatRange output_tensor_float_range = 9; + + message TensorFloatRange { + optional float min = 1; + optional float max = 2; + } +} diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc b/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc new file mode 100644 index 000000000..69eb7df77 --- /dev/null +++ b/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc @@ -0,0 +1,323 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { +namespace { + +constexpr char kTransposeOptionsString[] = + "[mediapipe.TensorConverterCalculatorOptions.ext]: {" + "row_major_matrix: True}"; + +} // namespace + +using RandomEngine = std::mt19937_64; +using testing::Eq; +const uint32 kSeed = 1234; +const int kNumSizes = 8; +const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, + {5, 3}, {7, 13}, {16, 32}, {101, 2}}; + +class TensorConverterCalculatorTest : public ::testing::Test { + protected: + // Adds a packet with a matrix filled with random values in [0,1]. 
+ void AddRandomMatrix(int num_rows, int num_columns, uint32 seed, + bool row_major_matrix = false) { + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + auto matrix = ::absl::make_unique(); + matrix->resize(num_rows, num_columns); + if (row_major_matrix) { + for (int y = 0; y < num_rows; ++y) { + for (int x = 0; x < num_columns; ++x) { + float value = uniform_dist(random); + (*matrix)(y, x) = value; + } + } + } else { + for (int x = 0; x < num_columns; ++x) { + for (int y = 0; y < num_rows; ++y) { + float value = uniform_dist(random); + (*matrix)(y, x) = value; + } + } + } + MP_ASSERT_OK(graph_->AddPacketToInputStream( + "matrix", Adopt(matrix.release()).At(Timestamp(0)))); + } + + std::unique_ptr graph_; +}; + +TEST_F(TensorConverterCalculatorTest, RandomMatrixColMajor) { + for (int size_index = 0; size_index < kNumSizes; ++size_index) { + const int num_rows = sizes[size_index][0]; + const int num_columns = sizes[size_index][1]; + + // Run the calculator and verify that one output is generated. + CalculatorGraphConfig graph_config = + mediapipe::ParseTextProtoOrDie(R"( + input_stream: "matrix" + node { + calculator: "TensorConverterCalculator" + input_stream: "MATRIX:matrix" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TensorConverterCalculatorOptions.ext] { + row_major_matrix: false + } + } + } + )"); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. + graph_ = absl::make_unique(); + MP_ASSERT_OK(graph_->Initialize(graph_config)); + MP_ASSERT_OK(graph_->StartRun({})); + + // Push the tensor into the graph. + AddRandomMatrix(num_rows, num_columns, kSeed, /*row_major_matrix=*/false); + + // Wait until the calculator done processing. + MP_ASSERT_OK(graph_->WaitUntilIdle()); + EXPECT_EQ(1, output_packets.size()); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + EXPECT_EQ(1, tensor_vec.size()); + + const Tensor* tensor = &tensor_vec[0]; + EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); + + // Verify that the data is correct. + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + auto view = tensor->GetCpuReadView(); + auto tensor_buffer = view.buffer(); + for (int i = 0; i < num_rows * num_columns; ++i) { + const float expected = uniform_dist(random); + EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i; + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph_->CloseInputStream("matrix")); + MP_ASSERT_OK(graph_->WaitUntilDone()); + + graph_.reset(); + } +} + +TEST_F(TensorConverterCalculatorTest, RandomMatrixRowMajor) { + for (int size_index = 0; size_index < kNumSizes; ++size_index) { + const int num_rows = sizes[size_index][0]; + const int num_columns = sizes[size_index][1]; + + // Run the calculator and verify that one output is generated. + CalculatorGraphConfig graph_config = + mediapipe::ParseTextProtoOrDie(R"( + input_stream: "matrix" + node { + calculator: "TensorConverterCalculator" + input_stream: "MATRIX:matrix" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TensorConverterCalculatorOptions.ext] { + row_major_matrix: true + } + } + } + )"); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. 
+ graph_ = absl::make_unique(); + MP_ASSERT_OK(graph_->Initialize(graph_config)); + MP_ASSERT_OK(graph_->StartRun({})); + + // Push the tensor into the graph. + AddRandomMatrix(num_rows, num_columns, kSeed, /*row_major_matrix=*/true); + + // Wait until the calculator done processing. + MP_ASSERT_OK(graph_->WaitUntilIdle()); + EXPECT_EQ(1, output_packets.size()); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + EXPECT_EQ(1, tensor_vec.size()); + + const Tensor* tensor = &tensor_vec[0]; + EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); + + // Verify that the data is correct. + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + auto view = tensor->GetCpuReadView(); + auto tensor_buffer = view.buffer(); + for (int i = 0; i < num_rows * num_columns; ++i) { + const float expected = uniform_dist(random); + EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i; + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph_->CloseInputStream("matrix")); + MP_ASSERT_OK(graph_->WaitUntilDone()); + + graph_.reset(); + } +} + +TEST_F(TensorConverterCalculatorTest, CustomDivAndSub) { + CalculatorGraph graph; + // Run the calculator and verify that one output is generated. + CalculatorGraphConfig graph_config = + mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_image" + node { + calculator: "TensorConverterCalculator" + input_stream: "IMAGE:input_image" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TensorConverterCalculatorOptions.ext] { + row_major_matrix: true + use_custom_normalization: true + custom_div: 2.0 + custom_sub: 33.0 + } + } + } + )"); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + auto input_image = absl::make_unique(ImageFormat::GRAY8, 1, 1); + cv::Mat mat = mediapipe::formats::MatView(input_image.get()); + mat.at(0, 0) = 200; + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_image", Adopt(input_image.release()).At(Timestamp(0)))); + + // Wait until the calculator done processing. + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + EXPECT_EQ(1, tensor_vec.size()); + + const Tensor* tensor = &tensor_vec[0]; + EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); + auto view = tensor->GetCpuReadView(); + EXPECT_FLOAT_EQ(67.0f, *view.buffer()); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). 
+ MP_ASSERT_OK(graph.CloseInputStream("input_image")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST_F(TensorConverterCalculatorTest, SetOutputRange) { + std::vector> range_values = { + std::make_pair(0.0, 1.0), std::make_pair(-1.0, 1.0), + std::make_pair(-0.5, 0.5)}; + for (std::pair range : range_values) { + CalculatorGraph graph; + CalculatorGraphConfig graph_config = + mediapipe::ParseTextProtoOrDie( + absl::Substitute(R"( + input_stream: "input_image" + node { + calculator: "TensorConverterCalculator" + input_stream: "IMAGE:input_image" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TensorConverterCalculatorOptions.ext] { + output_tensor_float_range { + min: $0 + max: $1 + } + } + } + } + )", + /*$0=*/range.first, + /*$1=*/range.second)); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + auto input_image = absl::make_unique(ImageFormat::GRAY8, 1, 1); + cv::Mat mat = mediapipe::formats::MatView(input_image.get()); + mat.at(0, 0) = 200; + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_image", Adopt(input_image.release()).At(Timestamp(0)))); + + // Wait until the calculator finishes processing. + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets.size(), Eq(1)); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + EXPECT_THAT(tensor_vec.size(), Eq(1)); + + const Tensor* tensor = &tensor_vec[0]; + + // Calculate the expected normalized value: + float normalized_value = + range.first + (200 * (range.second - range.first)) / 255.0; + + EXPECT_THAT(tensor->element_type(), Eq(Tensor::ElementType::kFloat32)); + auto view = tensor->GetCpuReadView(); + float dataf = *view.buffer(); + EXPECT_THAT( + normalized_value, + testing::FloatNear(dataf, 2.0f * std::abs(dataf) * + std::numeric_limits::epsilon())); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("input_image")); + MP_ASSERT_OK(graph.WaitUntilDone()); + } +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc new file mode 100644 index 000000000..c3b91de71 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc @@ -0,0 +1,178 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "absl/container/node_hash_map.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/resource_util.h" +#if defined(MEDIAPIPE_MOBILE) +#include "mediapipe/util/android/file/base/file.h" +#include "mediapipe/util/android/file/base/helpers.h" +#else +#include "mediapipe/framework/port/file_helpers.h" +#endif + +namespace mediapipe { +namespace api2 { + +// Convert result tensors from classification models into MediaPipe +// classifications. +// +// Input: +// TENSORS - Vector of Tensors of type kFloat32 containing one +// tensor, the size of which must be (1, * num_classes). +// Output: +// CLASSIFICATIONS - Result MediaPipe ClassificationList. The score and index +// fields of each classification are set, while the label +// field is only set if label_map_path is provided. +// +// Usage example: +// node { +// calculator: "TensorsToClassificationCalculator" +// input_stream: "TENSORS:tensors" +// output_stream: "CLASSIFICATIONS:classifications" +// options: { +// [mediapipe.TensorsToClassificationCalculatorOptions.ext] { +// num_classes: 1024 +// min_score_threshold: 0.1 +// label_map_path: "labelmap.txt" +// } +// } +// } +class TensorsToClassificationCalculator : public Node { + public: + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr Output kOutClassificationList{ + "CLASSIFICATIONS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kOutClassificationList); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::TensorsToClassificationCalculatorOptions options_; + int top_k_ = 0; + absl::node_hash_map label_map_; + bool label_map_loaded_ = false; +}; +MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator); + +absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) { + options_ = + cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>(); + + top_k_ = options_.top_k(); + if (options_.has_label_map_path()) { + std::string string_path; + ASSIGN_OR_RETURN(string_path, + PathToResourceAsFile(options_.label_map_path())); + std::string label_map_string; + MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); + + std::istringstream stream(label_map_string); + std::string line; + int i = 0; + while (std::getline(stream, line)) { + label_map_[i++] = line; + } + label_map_loaded_ = true; + } + + return absl::OkStatus(); +} + +absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) { + const auto& input_tensors = *kInTensors(cc); + RET_CHECK_EQ(input_tensors.size(), 1); + + int num_classes = input_tensors[0].shape().num_elements(); + + if (options_.binary_classification()) { + RET_CHECK_EQ(num_classes, 1); + // Number of classes for binary classification. 
+ num_classes = 2; + } + if (label_map_loaded_) { + RET_CHECK_EQ(num_classes, label_map_.size()); + } + auto view = input_tensors[0].GetCpuReadView(); + auto raw_scores = view.buffer(); + + auto classification_list = absl::make_unique(); + if (options_.binary_classification()) { + Classification* class_first = classification_list->add_classification(); + Classification* class_second = classification_list->add_classification(); + class_first->set_index(0); + class_second->set_index(1); + class_first->set_score(raw_scores[0]); + class_second->set_score(1. - raw_scores[0]); + + if (label_map_loaded_) { + class_first->set_label(label_map_[0]); + class_second->set_label(label_map_[1]); + } + } else { + for (int i = 0; i < num_classes; ++i) { + if (options_.has_min_score_threshold() && + raw_scores[i] < options_.min_score_threshold()) { + continue; + } + Classification* classification = + classification_list->add_classification(); + classification->set_index(i); + classification->set_score(raw_scores[i]); + + if (label_map_loaded_) { + classification->set_label(label_map_[i]); + } + } + } + + // Note that partial_sort will raise error when top_k_ > + // classification_list->classification_size(). + CHECK_GE(classification_list->classification_size(), top_k_); + auto raw_classification_list = classification_list->mutable_classification(); + if (top_k_ > 0 && classification_list->classification_size() >= top_k_) { + std::partial_sort(raw_classification_list->begin(), + raw_classification_list->begin() + top_k_, + raw_classification_list->end(), + [](const Classification a, const Classification b) { + return a.score() > b.score(); + }); + + // Resizes the underlying list to have only top_k_ classifications. + raw_classification_list->DeleteSubrange( + top_k_, raw_classification_list->size() - top_k_); + } + kOutClassificationList(cc).Send(std::move(classification_list)); + return absl::OkStatus(); +} + +absl::Status TensorsToClassificationCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto new file mode 100644 index 000000000..51f7f3f90 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto @@ -0,0 +1,41 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the TensorsToClassificationCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorsToClassificationCalculatorOptions { + extend .mediapipe.CalculatorOptions { + optional TensorsToClassificationCalculatorOptions ext = 335742638; + } + + // Score threshold for perserving the class. + optional float min_score_threshold = 1; + // Number of highest scoring labels to output. If top_k is not positive then + // all labels are used. 
+ optional int32 top_k = 2; + // Path to a label map file for getting the actual name of class ids. + optional string label_map_path = 3; + // Whether the input is a single float for binary classification. + // When true, only a single float is expected in the input tensor and the + // label map, if provided, is expected to have exactly two labels. + // The single score(float) represent the probability of first label, and + // 1 - score is the probabilility of the second label. + optional bool binary_classification = 4; +} diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc new file mode 100644 index 000000000..8f4877dad --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc @@ -0,0 +1,174 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +using mediapipe::ParseTextProtoOrDie; +using Node = ::mediapipe::CalculatorGraphConfig::Node; + +class TensorsToClassificationCalculatorTest : public ::testing::Test { + protected: + void BuildGraph(mediapipe::CalculatorRunner* runner, + const std::vector& scores) { + auto tensors = absl::make_unique>(); + tensors->emplace_back( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, 1, static_cast(scores.size()), 1}); + auto view = tensors->back().GetCpuWriteView(); + float* tensor_buffer = view.buffer(); + ASSERT_NE(tensor_buffer, nullptr); + for (int i = 0; i < scores.size(); ++i) { + tensor_buffer[i] = scores[i]; + } + + int64 stream_timestamp = 0; + auto& input_stream_packets = + runner->MutableInputs()->Tag("TENSORS").packets; + + input_stream_packets.push_back( + mediapipe::Adopt(tensors.release()) + .At(mediapipe::Timestamp(stream_timestamp++))); + } +}; + +TEST_F(TensorsToClassificationCalculatorTest, CorrectOutput) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TensorsToClassificationCalculator" + input_stream: "TENSORS:tensors" + output_stream: "CLASSIFICATIONS:classifications" + options { + [mediapipe.TensorsToClassificationCalculatorOptions.ext] {} + } + )")); + + BuildGraph(&runner, {0, 0.5, 1}); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets; + + EXPECT_EQ(1, output_packets_.size()); + + const auto& classification_list = + output_packets_[0].Get(); + EXPECT_EQ(3, 
classification_list.classification_size()); + + // Verify that the label_id and score fields are set correctly. + for (int i = 0; i < classification_list.classification_size(); ++i) { + EXPECT_EQ(i, classification_list.classification(i).index()); + EXPECT_EQ(i * 0.5, classification_list.classification(i).score()); + ASSERT_FALSE(classification_list.classification(i).has_label()); + } +} + +TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMapPath) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TensorsToClassificationCalculator" + input_stream: "TENSORS:tensors" + output_stream: "CLASSIFICATIONS:classifications" + options { + [mediapipe.TensorsToClassificationCalculatorOptions.ext] { + label_map_path: "mediapipe/calculators/tensor/testdata/labelmap.txt" + } + } + )")); + + BuildGraph(&runner, {0, 0.5, 1}); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets; + + EXPECT_EQ(1, output_packets_.size()); + + const auto& classification_list = + output_packets_[0].Get(); + EXPECT_EQ(3, classification_list.classification_size()); + + // Verify that the label field is set. + for (int i = 0; i < classification_list.classification_size(); ++i) { + EXPECT_EQ(i, classification_list.classification(i).index()); + EXPECT_EQ(i * 0.5, classification_list.classification(i).score()); + ASSERT_TRUE(classification_list.classification(i).has_label()); + } +} + +TEST_F(TensorsToClassificationCalculatorTest, + CorrectOutputWithLabelMinScoreThreshold) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TensorsToClassificationCalculator" + input_stream: "TENSORS:tensors" + output_stream: "CLASSIFICATIONS:classifications" + options { + [mediapipe.TensorsToClassificationCalculatorOptions.ext] { + min_score_threshold: 0.6 + } + } + )")); + + BuildGraph(&runner, {0, 0.5, 1}); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets; + + EXPECT_EQ(1, output_packets_.size()); + + const auto& classification_list = + output_packets_[0].Get(); + + // Verify that the low score labels are filtered out. + EXPECT_EQ(1, classification_list.classification_size()); + EXPECT_EQ(1, classification_list.classification(0).score()); +} + +TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithTopK) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TensorsToClassificationCalculator" + input_stream: "TENSORS:tensors" + output_stream: "CLASSIFICATIONS:classifications" + options { + [mediapipe.TensorsToClassificationCalculatorOptions.ext] { top_k: 2 } + } + )")); + + BuildGraph(&runner, {0, 0.5, 1}); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets; + + EXPECT_EQ(1, output_packets_.size()); + + const auto& classification_list = + output_packets_[0].Get(); + + // Verify that the only top2 labels are left. 
+ EXPECT_EQ(2, classification_list.classification_size()); + for (int i = 0; i < classification_list.classification_size(); ++i) { + EXPECT_EQ((classification_list.classification_size() - i) * 0.5, + classification_list.classification(i).score()); + } +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc new file mode 100644 index 000000000..1a27cafce --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -0,0 +1,1144 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/formats/object_detection/anchor.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/ret_check.h" + +// Note: On Apple platforms MEDIAPIPE_DISABLE_GL_COMPUTE is automatically +// defined in mediapipe/framework/port.h. Therefore, +// "#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE" and "#if MEDIAPIPE_METAL_ENABLED" +// below are mutually exclusive. +#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE +#include "mediapipe/gpu/gl_calculator_helper.h" +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) + +#if MEDIAPIPE_METAL_ENABLED +#import +#import +#import + +#import "mediapipe/gpu/MPPMetalHelper.h" +#include "mediapipe/gpu/MPPMetalUtil.h" +#endif // MEDIAPIPE_METAL_ENABLED + +namespace { +constexpr int kNumInputTensorsWithAnchors = 3; +constexpr int kNumCoordsPerBox = 4; + +bool CanUseGpu() { +#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || MEDIAPIPE_METAL_ENABLED + // TODO: Configure GPU usage policy in individual calculators. 
+ constexpr bool kAllowGpuProcessing = true; + return kAllowGpuProcessing; +#else + return false; +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || MEDIAPIPE_METAL_ENABLED +} +} // namespace + +namespace mediapipe { +namespace api2 { + +namespace { + +void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes, + std::vector* anchors) { + anchors->clear(); + for (int i = 0; i < num_boxes; ++i) { + Anchor new_anchor; + new_anchor.set_y_center(raw_anchors[i * kNumCoordsPerBox + 0]); + new_anchor.set_x_center(raw_anchors[i * kNumCoordsPerBox + 1]); + new_anchor.set_h(raw_anchors[i * kNumCoordsPerBox + 2]); + new_anchor.set_w(raw_anchors[i * kNumCoordsPerBox + 3]); + anchors->push_back(new_anchor); + } +} + +void ConvertAnchorsToRawValues(const std::vector& anchors, + int num_boxes, float* raw_anchors) { + CHECK_EQ(anchors.size(), num_boxes); + int box = 0; + for (const auto& anchor : anchors) { + raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center(); + raw_anchors[box * kNumCoordsPerBox + 1] = anchor.x_center(); + raw_anchors[box * kNumCoordsPerBox + 2] = anchor.h(); + raw_anchors[box * kNumCoordsPerBox + 3] = anchor.w(); + ++box; + } +} + +} // namespace + +// Convert result Tensors from object detection models into MediaPipe +// Detections. +// +// Input: +// TENSORS - Vector of Tensors of type kFloat32. The vector of tensors can have +// 2 or 3 tensors. First tensor is the predicted raw boxes/keypoints. +// The size of the values must be (num_boxes * num_predicted_values). +// Second tensor is the score tensor. The size of the valuse must be +// (num_boxes * num_classes). It's optional to pass in a third tensor +// for anchors (e.g. for SSD models) depend on the outputs of the +// detection model. The size of anchor tensor must be (num_boxes * +// 4). +// Output: +// DETECTIONS - Result MediaPipe detections. 
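For concreteness, a sketch of the tensor shapes this contract implies (they mirror the RET_CHECKs in ProcessCPU below; num_boxes, num_coords and num_classes stand in for the corresponding calculator options, and the Tensor API is the one already used in the tests in this change):

    std::vector<Tensor> tensors;
    // Raw boxes/keypoints: [1, num_boxes, num_coords].
    tensors.emplace_back(Tensor::ElementType::kFloat32,
                         Tensor::Shape{1, num_boxes, num_coords});
    // Raw scores: [1, num_boxes, num_classes].
    tensors.emplace_back(Tensor::ElementType::kFloat32,
                         Tensor::Shape{1, num_boxes, num_classes});
    // Optional anchors (e.g. for SSD models): [num_boxes, 4], i.e.
    // kNumCoordsPerBox values per box.
    tensors.emplace_back(Tensor::ElementType::kFloat32,
                         Tensor::Shape{num_boxes, 4});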
+// +// Usage example: +// node { +// calculator: "TensorsToDetectionsCalculator" +// input_stream: "TENSORS:tensors" +// input_side_packet: "ANCHORS:anchors" +// output_stream: "DETECTIONS:detections" +// options: { +// [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { +// num_classes: 91 +// num_boxes: 1917 +// num_coords: 4 +// ignore_classes: [0, 1, 2] +// x_scale: 10.0 +// y_scale: 10.0 +// h_scale: 5.0 +// w_scale: 5.0 +// } +// } +// } +class TensorsToDetectionsCalculator : public Node { + public: + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr SideInput>::Optional kInAnchors{ + "ANCHORS"}; + static constexpr Output> kOutDetections{"DETECTIONS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kInAnchors, kOutDetections); + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status ProcessCPU(CalculatorContext* cc, + std::vector* output_detections); + absl::Status ProcessGPU(CalculatorContext* cc, + std::vector* output_detections); + + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status GpuInit(CalculatorContext* cc); + absl::Status DecodeBoxes(const float* raw_boxes, + const std::vector& anchors, + std::vector* boxes); + absl::Status ConvertToDetections(const float* detection_boxes, + const float* detection_scores, + const int* detection_classes, + std::vector* output_detections); + Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, + float box_xmax, float score, int class_id, + bool flip_vertically); + + int num_classes_ = 0; + int num_boxes_ = 0; + int num_coords_ = 0; + std::set ignore_classes_; + + ::mediapipe::TensorsToDetectionsCalculatorOptions options_; + std::vector anchors_; + +#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE + mediapipe::GlCalculatorHelper gpu_helper_; + GLuint decode_program_; + GLuint score_program_; +#elif MEDIAPIPE_METAL_ENABLED + MPPMetalHelper* gpu_helper_ = nullptr; + id decode_program_; + id score_program_; +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) + std::unique_ptr raw_anchors_buffer_; + std::unique_ptr decoded_boxes_buffer_; + std::unique_ptr scored_boxes_buffer_; + + bool gpu_inited_ = false; + bool gpu_input_ = false; + bool anchors_init_ = false; +}; +MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator); + +absl::Status TensorsToDetectionsCalculator::UpdateContract( + CalculatorContract* cc) { + if (CanUseGpu()) { +#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#elif MEDIAPIPE_METAL_ENABLED + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) + } + + return absl::OkStatus(); +} + +absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadOptions(cc)); + + if (CanUseGpu()) { +#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#elif MEDIAPIPE_METAL_ENABLED + gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(gpu_helper_); +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) + } + + return absl::OkStatus(); +} + +absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) { + auto output_detections = absl::make_unique>(); + bool gpu_processing = false; + if (CanUseGpu()) { + // Use GPU processing only if at least one input tensor is already on GPU + // (to avoid CPU->GPU 
overhead). + for (const auto& tensor : *kInTensors(cc)) { + if (tensor.ready_on_gpu()) { + gpu_processing = true; + break; + } + } + } + + if (gpu_processing) { + if (!gpu_inited_) { + MP_RETURN_IF_ERROR(GpuInit(cc)); + gpu_inited_ = true; + } + MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get())); + } else { + MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get())); + } + + kOutDetections(cc).Send(std::move(output_detections)); + return absl::OkStatus(); +} + +absl::Status TensorsToDetectionsCalculator::ProcessCPU( + CalculatorContext* cc, std::vector* output_detections) { + const auto& input_tensors = *kInTensors(cc); + + if (input_tensors.size() == 2 || + input_tensors.size() == kNumInputTensorsWithAnchors) { + // Postprocessing on CPU for model without postprocessing op. E.g. output + // raw score tensor and box tensor. Anchor decoding will be handled below. + // TODO: Add flexible input tensor size handling. + auto raw_box_tensor = &input_tensors[0]; + RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3); + RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1); + RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_); + RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_); + auto raw_score_tensor = &input_tensors[1]; + RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3); + RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1); + RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_); + RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_); + auto raw_box_view = raw_box_tensor->GetCpuReadView(); + auto raw_boxes = raw_box_view.buffer(); + auto raw_scores_view = raw_score_tensor->GetCpuReadView(); + auto raw_scores = raw_scores_view.buffer(); + + // TODO: Support other options to load anchors. + if (!anchors_init_) { + if (input_tensors.size() == kNumInputTensorsWithAnchors) { + auto anchor_tensor = &input_tensors[2]; + RET_CHECK_EQ(anchor_tensor->shape().dims.size(), 2); + RET_CHECK_EQ(anchor_tensor->shape().dims[0], num_boxes_); + RET_CHECK_EQ(anchor_tensor->shape().dims[1], kNumCoordsPerBox); + auto anchor_view = anchor_tensor->GetCpuReadView(); + auto raw_anchors = anchor_view.buffer(); + ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_); + } else if (!kInAnchors(cc).IsEmpty()) { + anchors_ = *kInAnchors(cc); + } else { + return absl::UnavailableError("No anchor data available."); + } + anchors_init_ = true; + } + std::vector boxes(num_boxes_ * num_coords_); + MP_RETURN_IF_ERROR(DecodeBoxes(raw_boxes, anchors_, &boxes)); + + std::vector detection_scores(num_boxes_); + std::vector detection_classes(num_boxes_); + + // Filter classes by scores. + for (int i = 0; i < num_boxes_; ++i) { + int class_id = -1; + float max_score = -std::numeric_limits::max(); + // Find the top score for box i. + for (int score_idx = 0; score_idx < num_classes_; ++score_idx) { + if (ignore_classes_.find(score_idx) == ignore_classes_.end()) { + auto score = raw_scores[i * num_classes_ + score_idx]; + if (options_.sigmoid_score()) { + if (options_.has_score_clipping_thresh()) { + score = score < -options_.score_clipping_thresh() + ? -options_.score_clipping_thresh() + : score; + score = score > options_.score_clipping_thresh() + ? 
options_.score_clipping_thresh() + : score; + } + score = 1.0f / (1.0f + std::exp(-score)); + } + if (max_score < score) { + max_score = score; + class_id = score_idx; + } + } + } + detection_scores[i] = max_score; + detection_classes[i] = class_id; + } + + MP_RETURN_IF_ERROR( + ConvertToDetections(boxes.data(), detection_scores.data(), + detection_classes.data(), output_detections)); + } else { + // Postprocessing on CPU with postprocessing op (e.g. anchor decoding and + // non-maximum suppression) within the model. + RET_CHECK_EQ(input_tensors.size(), 4); + + auto num_boxes_tensor = &input_tensors[3]; + RET_CHECK_EQ(num_boxes_tensor->shape().dims.size(), 1); + RET_CHECK_EQ(num_boxes_tensor->shape().dims[0], 1); + + auto detection_boxes_tensor = &input_tensors[0]; + RET_CHECK_EQ(detection_boxes_tensor->shape().dims.size(), 3); + RET_CHECK_EQ(detection_boxes_tensor->shape().dims[0], 1); + const int max_detections = detection_boxes_tensor->shape().dims[1]; + RET_CHECK_EQ(detection_boxes_tensor->shape().dims[2], num_coords_); + + auto detection_classes_tensor = &input_tensors[1]; + RET_CHECK_EQ(detection_classes_tensor->shape().dims.size(), 2); + RET_CHECK_EQ(detection_classes_tensor->shape().dims[0], 1); + RET_CHECK_EQ(detection_classes_tensor->shape().dims[1], max_detections); + + auto detection_scores_tensor = &input_tensors[2]; + RET_CHECK_EQ(detection_scores_tensor->shape().dims.size(), 2); + RET_CHECK_EQ(detection_scores_tensor->shape().dims[0], 1); + RET_CHECK_EQ(detection_scores_tensor->shape().dims[1], max_detections); + + auto num_boxes_view = num_boxes_tensor->GetCpuReadView(); + auto num_boxes = num_boxes_view.buffer(); + num_boxes_ = num_boxes[0]; + + auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView(); + auto detection_boxes = detection_boxes_view.buffer(); + + auto detection_scores_view = detection_scores_tensor->GetCpuReadView(); + auto detection_scores = detection_scores_view.buffer(); + + auto detection_classes_view = detection_classes_tensor->GetCpuReadView(); + auto detection_classes_ptr = detection_classes_view.buffer(); + std::vector detection_classes(num_boxes_); + for (int i = 0; i < num_boxes_; ++i) { + detection_classes[i] = static_cast(detection_classes_ptr[i]); + } + MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores, + detection_classes.data(), + output_detections)); + } + return absl::OkStatus(); +} + +absl::Status TensorsToDetectionsCalculator::ProcessGPU( + CalculatorContext* cc, std::vector* output_detections) { + const auto& input_tensors = *kInTensors(cc); + RET_CHECK_GE(input_tensors.size(), 2); +#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE + + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &input_tensors, &cc, + &output_detections]() + -> absl::Status { + if (!anchors_init_) { + if (input_tensors.size() == kNumInputTensorsWithAnchors) { + auto read_view = input_tensors[2].GetOpenGlBufferReadView(); + glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); + auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView(); + glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); + glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, + input_tensors[2].bytes()); + } else if (!kInAnchors(cc).IsEmpty()) { + const auto& anchors = *kInAnchors(cc); + auto anchors_view = raw_anchors_buffer_->GetCpuWriteView(); + auto raw_anchors = anchors_view.buffer(); + ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors); + } else { + return absl::UnavailableError("No anchor data available."); + } + anchors_init_ = 
true; + } + // Use the scope to release the writable buffers' views before requesting + // the reading buffers' views. + { + // Decode boxes. + auto scored_boxes_view = scored_boxes_buffer_->GetOpenGlBufferWriteView(); + auto decoded_boxes_view = + decoded_boxes_buffer_->GetOpenGlBufferWriteView(); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, decoded_boxes_view.name()); + auto input0_view = input_tensors[0].GetOpenGlBufferReadView(); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input0_view.name()); + auto raw_anchors_view = raw_anchors_buffer_->GetOpenGlBufferReadView(); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, raw_anchors_view.name()); + glUseProgram(decode_program_); + glDispatchCompute(num_boxes_, 1, 1); + + // Score boxes. + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, scored_boxes_view.name()); + auto input1_view = input_tensors[1].GetOpenGlBufferReadView(); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input1_view.name()); + glUseProgram(score_program_); + glDispatchCompute(num_boxes_, 1, 1); + } + return absl::OkStatus(); + })); + + // TODO: b/138851969. Is it possible to output a float vector + // for score and an int vector for class so that we can avoid copying twice? + std::vector detection_scores(num_boxes_); + std::vector detection_classes(num_boxes_); + // The order of requesting of CpuViews must be the same as the order of + // requesting OpenGlViews above to avoid 'Potential mutex deadlock' message + // when compiled without '-c opt' option. + auto scored_boxes_view = scored_boxes_buffer_->GetCpuReadView(); + auto score_class_id_pairs = scored_boxes_view.buffer(); + for (int i = 0; i < num_boxes_; ++i) { + detection_scores[i] = score_class_id_pairs[i * 2]; + detection_classes[i] = static_cast(score_class_id_pairs[i * 2 + 1]); + } + auto decoded_boxes_view = decoded_boxes_buffer_->GetCpuReadView(); + auto boxes = decoded_boxes_view.buffer(); + MP_RETURN_IF_ERROR(ConvertToDetections(boxes, detection_scores.data(), + detection_classes.data(), + output_detections)); +#elif MEDIAPIPE_METAL_ENABLED + id device = gpu_helper_.mtlDevice; + if (!anchors_init_) { + if (input_tensors.size() == kNumInputTensorsWithAnchors) { + RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); + auto command_buffer = [gpu_helper_ commandBuffer]; + auto src_buffer = input_tensors[2].GetMtlBufferReadView(command_buffer); + auto dest_buffer = + raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer); + id blit_command = + [command_buffer blitCommandEncoder]; + [blit_command copyFromBuffer:src_buffer.buffer() + sourceOffset:0 + toBuffer:dest_buffer.buffer() + destinationOffset:0 + size:input_tensors[2].bytes()]; + [blit_command endEncoding]; + [command_buffer commit]; + } else if (!kInAnchors(cc).IsEmpty()) { + const auto& anchors = *kInAnchors(cc); + auto raw_anchors_view = raw_anchors_buffer_->GetCpuWriteView(); + ConvertAnchorsToRawValues(anchors, num_boxes_, + raw_anchors_view.buffer()); + } else { + return absl::UnavailableError("No anchor data available."); + } + anchors_init_ = true; + } + + // Use the scope to release the writable buffers' views before requesting the + // reading buffers' views. 
+ id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"DecodeAndScoreBoxes"; + id command_encoder = + [command_buffer computeCommandEncoder]; + [command_encoder setComputePipelineState:decode_program_]; + { + auto scored_boxes_view = + scored_boxes_buffer_->GetMtlBufferWriteView(command_buffer); + auto decoded_boxes_view = + decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer); + [command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0]; + auto input0_view = input_tensors[0].GetMtlBufferReadView(command_buffer); + [command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1]; + auto raw_anchors_view = + raw_anchors_buffer_->GetMtlBufferReadView(command_buffer); + [command_encoder setBuffer:raw_anchors_view.buffer() offset:0 atIndex:2]; + MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1); + MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1); + [command_encoder dispatchThreadgroups:decode_threadgroups + threadsPerThreadgroup:decode_threads_per_group]; + + [command_encoder setComputePipelineState:score_program_]; + [command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0]; + auto input1_view = input_tensors[1].GetMtlBufferReadView(command_buffer); + [command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1]; + MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1); + MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1); + [command_encoder dispatchThreadgroups:score_threadgroups + threadsPerThreadgroup:score_threads_per_group]; + [command_encoder endEncoding]; + [command_buffer commit]; + } + + // Output detections. + // TODO Adjust shader to avoid copying shader output twice. + std::vector detection_scores(num_boxes_); + std::vector detection_classes(num_boxes_); + { + auto scored_boxes_view = scored_boxes_buffer_->GetCpuReadView(); + auto score_class_id_pairs = scored_boxes_view.buffer(); + for (int i = 0; i < num_boxes_; ++i) { + detection_scores[i] = score_class_id_pairs[i * 2]; + detection_classes[i] = static_cast(score_class_id_pairs[i * 2 + 1]); + } + } + auto decoded_boxes_view = decoded_boxes_buffer_->GetCpuReadView(); + auto boxes = decoded_boxes_view.buffer(); + MP_RETURN_IF_ERROR(ConvertToDetections(boxes, detection_scores.data(), + detection_classes.data(), + output_detections)); + +#else + LOG(ERROR) << "GPU input on non-Android not supported yet."; +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) + return absl::OkStatus(); +} + +absl::Status TensorsToDetectionsCalculator::Close(CalculatorContext* cc) { +#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE + gpu_helper_.RunInGlContext([this] { + decoded_boxes_buffer_ = nullptr; + scored_boxes_buffer_ = nullptr; + raw_anchors_buffer_ = nullptr; + glDeleteProgram(decode_program_); + glDeleteProgram(score_program_); + }); +#elif MEDIAPIPE_METAL_ENABLED + decoded_boxes_buffer_ = nullptr; + scored_boxes_buffer_ = nullptr; + raw_anchors_buffer_ = nullptr; + decode_program_ = nil; + score_program_ = nil; +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) + + return absl::OkStatus(); +} + +absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) { + // Get calculator options specified in the graph. 
+ options_ = cc->Options<::mediapipe::TensorsToDetectionsCalculatorOptions>(); + RET_CHECK(options_.has_num_classes()); + RET_CHECK(options_.has_num_boxes()); + RET_CHECK(options_.has_num_coords()); + + num_classes_ = options_.num_classes(); + num_boxes_ = options_.num_boxes(); + num_coords_ = options_.num_coords(); + + // Currently only support 2D when num_values_per_keypoint equals to 2. + CHECK_EQ(options_.num_values_per_keypoint(), 2); + + // Check if the output size is equal to the requested boxes and keypoints. + CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() + + kNumCoordsPerBox, + num_coords_); + + for (int i = 0; i < options_.ignore_classes_size(); ++i) { + ignore_classes_.insert(options_.ignore_classes(i)); + } + + return absl::OkStatus(); +} + +absl::Status TensorsToDetectionsCalculator::DecodeBoxes( + const float* raw_boxes, const std::vector& anchors, + std::vector* boxes) { + for (int i = 0; i < num_boxes_; ++i) { + const int box_offset = i * num_coords_ + options_.box_coord_offset(); + + float y_center = raw_boxes[box_offset]; + float x_center = raw_boxes[box_offset + 1]; + float h = raw_boxes[box_offset + 2]; + float w = raw_boxes[box_offset + 3]; + if (options_.reverse_output_order()) { + x_center = raw_boxes[box_offset]; + y_center = raw_boxes[box_offset + 1]; + w = raw_boxes[box_offset + 2]; + h = raw_boxes[box_offset + 3]; + } + + x_center = + x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center(); + y_center = + y_center / options_.y_scale() * anchors[i].h() + anchors[i].y_center(); + + if (options_.apply_exponential_on_box_size()) { + h = std::exp(h / options_.h_scale()) * anchors[i].h(); + w = std::exp(w / options_.w_scale()) * anchors[i].w(); + } else { + h = h / options_.h_scale() * anchors[i].h(); + w = w / options_.w_scale() * anchors[i].w(); + } + + const float ymin = y_center - h / 2.f; + const float xmin = x_center - w / 2.f; + const float ymax = y_center + h / 2.f; + const float xmax = x_center + w / 2.f; + + (*boxes)[i * num_coords_ + 0] = ymin; + (*boxes)[i * num_coords_ + 1] = xmin; + (*boxes)[i * num_coords_ + 2] = ymax; + (*boxes)[i * num_coords_ + 3] = xmax; + + if (options_.num_keypoints()) { + for (int k = 0; k < options_.num_keypoints(); ++k) { + const int offset = i * num_coords_ + options_.keypoint_coord_offset() + + k * options_.num_values_per_keypoint(); + + float keypoint_y = raw_boxes[offset]; + float keypoint_x = raw_boxes[offset + 1]; + if (options_.reverse_output_order()) { + keypoint_x = raw_boxes[offset]; + keypoint_y = raw_boxes[offset + 1]; + } + + (*boxes)[offset] = keypoint_x / options_.x_scale() * anchors[i].w() + + anchors[i].x_center(); + (*boxes)[offset + 1] = + keypoint_y / options_.y_scale() * anchors[i].h() + + anchors[i].y_center(); + } + } + } + + return absl::OkStatus(); +} + +absl::Status TensorsToDetectionsCalculator::ConvertToDetections( + const float* detection_boxes, const float* detection_scores, + const int* detection_classes, std::vector* output_detections) { + for (int i = 0; i < num_boxes_; ++i) { + if (options_.has_min_score_thresh() && + detection_scores[i] < options_.min_score_thresh()) { + continue; + } + const int box_offset = i * num_coords_; + Detection detection = ConvertToDetection( + detection_boxes[box_offset + 0], detection_boxes[box_offset + 1], + detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], + detection_scores[i], detection_classes[i], options_.flip_vertically()); + const auto& bbox = detection.location_data().relative_bounding_box(); + if 
(bbox.width() < 0 || bbox.height() < 0) { + // Decoded detection boxes could have negative values for width/height due + // to model prediction. Filter out those boxes since some downstream + // calculators may assume non-negative values. (b/171391719) + continue; + } + // Add keypoints. + if (options_.num_keypoints() > 0) { + auto* location_data = detection.mutable_location_data(); + for (int kp_id = 0; kp_id < options_.num_keypoints() * + options_.num_values_per_keypoint(); + kp_id += options_.num_values_per_keypoint()) { + auto keypoint = location_data->add_relative_keypoints(); + const int keypoint_index = + box_offset + options_.keypoint_coord_offset() + kp_id; + keypoint->set_x(detection_boxes[keypoint_index + 0]); + keypoint->set_y(options_.flip_vertically() + ? 1.f - detection_boxes[keypoint_index + 1] + : detection_boxes[keypoint_index + 1]); + } + } + output_detections->emplace_back(detection); + } + return absl::OkStatus(); +} + +Detection TensorsToDetectionsCalculator::ConvertToDetection( + float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, + int class_id, bool flip_vertically) { + Detection detection; + detection.add_score(score); + detection.add_label_id(class_id); + + LocationData* location_data = detection.mutable_location_data(); + location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); + + LocationData::RelativeBoundingBox* relative_bbox = + location_data->mutable_relative_bounding_box(); + + relative_bbox->set_xmin(box_xmin); + relative_bbox->set_ymin(flip_vertically ? 1.f - box_ymax : box_ymin); + relative_bbox->set_width(box_xmax - box_xmin); + relative_bbox->set_height(box_ymax - box_ymin); + return detection; +} + +absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) { +#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { + // A shader to decode detection boxes. 
+ const std::string decode_src = absl::Substitute( + R"( #version 310 es + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(location = 0) uniform vec4 scale; + +layout(std430, binding = 0) writeonly buffer Output { + float data[]; +} boxes; + +layout(std430, binding = 1) readonly buffer Input0 { + float data[]; +} raw_boxes; + +layout(std430, binding = 2) readonly buffer Input1 { + float data[]; +} raw_anchors; + +uint num_coords = uint($0); +int reverse_output_order = int($1); +int apply_exponential = int($2); +int box_coord_offset = int($3); +int num_keypoints = int($4); +int keypt_coord_offset = int($5); +int num_values_per_keypt = int($6); + +void main() { + uint g_idx = gl_GlobalInvocationID.x; // box index + uint box_offset = g_idx * num_coords + uint(box_coord_offset); + uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox + + float y_center, x_center, h, w; + + if (reverse_output_order == int(0)) { + y_center = raw_boxes.data[box_offset + uint(0)]; + x_center = raw_boxes.data[box_offset + uint(1)]; + h = raw_boxes.data[box_offset + uint(2)]; + w = raw_boxes.data[box_offset + uint(3)]; + } else { + x_center = raw_boxes.data[box_offset + uint(0)]; + y_center = raw_boxes.data[box_offset + uint(1)]; + w = raw_boxes.data[box_offset + uint(2)]; + h = raw_boxes.data[box_offset + uint(3)]; + } + + float anchor_yc = raw_anchors.data[anchor_offset + uint(0)]; + float anchor_xc = raw_anchors.data[anchor_offset + uint(1)]; + float anchor_h = raw_anchors.data[anchor_offset + uint(2)]; + float anchor_w = raw_anchors.data[anchor_offset + uint(3)]; + + x_center = x_center / scale.x * anchor_w + anchor_xc; + y_center = y_center / scale.y * anchor_h + anchor_yc; + + if (apply_exponential == int(1)) { + h = exp(h / scale.w) * anchor_h; + w = exp(w / scale.z) * anchor_w; + } else { + h = (h / scale.w) * anchor_h; + w = (w / scale.z) * anchor_w; + } + + float ymin = y_center - h / 2.0; + float xmin = x_center - w / 2.0; + float ymax = y_center + h / 2.0; + float xmax = x_center + w / 2.0; + + boxes.data[box_offset + uint(0)] = ymin; + boxes.data[box_offset + uint(1)] = xmin; + boxes.data[box_offset + uint(2)] = ymax; + boxes.data[box_offset + uint(3)] = xmax; + + if (num_keypoints > int(0)){ + for (int k = 0; k < num_keypoints; ++k) { + int kp_offset = + int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; + float kp_y, kp_x; + if (reverse_output_order == int(0)) { + kp_y = raw_boxes.data[kp_offset + int(0)]; + kp_x = raw_boxes.data[kp_offset + int(1)]; + } else { + kp_x = raw_boxes.data[kp_offset + int(0)]; + kp_y = raw_boxes.data[kp_offset + int(1)]; + } + boxes.data[kp_offset + int(0)] = kp_x / scale.x * anchor_w + anchor_xc; + boxes.data[kp_offset + int(1)] = kp_y / scale.y * anchor_h + anchor_yc; + } + } +})", + options_.num_coords(), // box xywh + options_.reverse_output_order() ? 1 : 0, + options_.apply_exponential_on_box_size() ? 
1 : 0, + options_.box_coord_offset(), options_.num_keypoints(), + options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); + + // Shader program + GLuint shader = glCreateShader(GL_COMPUTE_SHADER); + const GLchar* sources[] = {decode_src.c_str()}; + glShaderSource(shader, 1, sources, NULL); + glCompileShader(shader); + GLint compiled = GL_FALSE; + glGetShaderiv(shader, GL_COMPILE_STATUS, &compiled); + RET_CHECK(compiled == GL_TRUE) << "Shader compilation error: " << [shader] { + GLint length; + glGetShaderiv(shader, GL_INFO_LOG_LENGTH, &length); + std::string str; + str.reserve(length); + glGetShaderInfoLog(shader, length, nullptr, str.data()); + return str; + }(); + decode_program_ = glCreateProgram(); + glAttachShader(decode_program_, shader); + glDeleteShader(shader); + glLinkProgram(decode_program_); + + // Outputs + decoded_boxes_buffer_ = + absl::make_unique(Tensor::ElementType::kFloat32, + Tensor::Shape{1, num_boxes_ * num_coords_}); + raw_anchors_buffer_ = absl::make_unique( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, num_boxes_ * kNumCoordsPerBox}); + // Parameters + glUseProgram(decode_program_); + glUniform4f(0, options_.x_scale(), options_.y_scale(), options_.w_scale(), + options_.h_scale()); + + // A shader to score detection boxes. + const std::string score_src = absl::Substitute( + R"( #version 310 es + +layout(local_size_x = 1, local_size_y = $0, local_size_z = 1) in; + +#define FLT_MAX 1.0e+37 + +shared float local_scores[$0]; + +layout(std430, binding = 0) writeonly buffer Output { + float data[]; +} scored_boxes; + +layout(std430, binding = 1) readonly buffer Input0 { + float data[]; +} raw_scores; + +uint num_classes = uint($0); +int apply_sigmoid = int($1); +int apply_clipping_thresh = int($2); +float clipping_thresh = float($3); +int ignore_class_0 = int($4); + +float optional_sigmoid(float x) { + if (apply_sigmoid == int(0)) return x; + if (apply_clipping_thresh == int(1)) { + x = clamp(x, -clipping_thresh, clipping_thresh); + } + x = 1.0 / (1.0 + exp(-x)); + return x; +} + +void main() { + uint g_idx = gl_GlobalInvocationID.x; // box idx + uint s_idx = gl_LocalInvocationID.y; // score/class idx + + // load all scores into shared memory + float score = raw_scores.data[g_idx * num_classes + s_idx]; + local_scores[s_idx] = optional_sigmoid(score); + memoryBarrierShared(); + barrier(); + + // find max score in shared memory + if (s_idx == uint(0)) { + float max_score = -FLT_MAX; + float max_class = -1.0; + for (int i=ignore_class_0; i max_score) { + max_score = local_scores[i]; + max_class = float(i); + } + } + scored_boxes.data[g_idx * uint(2) + uint(0)] = max_score; + scored_boxes.data[g_idx * uint(2) + uint(1)] = max_class; + } +})", + num_classes_, options_.sigmoid_score() ? 1 : 0, + options_.has_score_clipping_thresh() ? 1 : 0, + options_.has_score_clipping_thresh() ? options_.score_clipping_thresh() + : 0, + !ignore_classes_.empty() ? 1 : 0); + + // # filter classes supported is hardware dependent. + int max_wg_size; // typically <= 1024 + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, + &max_wg_size); // y-dim + CHECK_LT(num_classes_, max_wg_size) + << "# classes must be < " << max_wg_size; + // TODO support better filtering. 
+ CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; + + // Shader program + { + GLuint shader = glCreateShader(GL_COMPUTE_SHADER); + const GLchar* sources[] = {score_src.c_str()}; + glShaderSource(shader, 1, sources, NULL); + glCompileShader(shader); + GLint compiled = GL_FALSE; + glGetShaderiv(shader, GL_COMPILE_STATUS, &compiled); + RET_CHECK(compiled == GL_TRUE); + score_program_ = glCreateProgram(); + glAttachShader(score_program_, shader); + glDeleteShader(shader); + glLinkProgram(score_program_); + } + + // Outputs + scored_boxes_buffer_ = absl::make_unique( + Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2}); + + return absl::OkStatus(); + })); + +#elif MEDIAPIPE_METAL_ENABLED + id device = gpu_helper_.mtlDevice; + + // A shader to decode detection boxes. + std::string decode_src = absl::Substitute( + R"( +#include + +using namespace metal; + +kernel void decodeKernel( + device float* boxes [[ buffer(0) ]], + device float* raw_boxes [[ buffer(1) ]], + device float* raw_anchors [[ buffer(2) ]], + uint2 gid [[ thread_position_in_grid ]]) { + + uint num_coords = uint($0); + int reverse_output_order = int($1); + int apply_exponential = int($2); + int box_coord_offset = int($3); + int num_keypoints = int($4); + int keypt_coord_offset = int($5); + int num_values_per_keypt = int($6); +)", + options_.num_coords(), // box xywh + options_.reverse_output_order() ? 1 : 0, + options_.apply_exponential_on_box_size() ? 1 : 0, + options_.box_coord_offset(), options_.num_keypoints(), + options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); + decode_src += absl::Substitute( + R"( + float4 scale = float4(($0),($1),($2),($3)); +)", + options_.x_scale(), options_.y_scale(), options_.w_scale(), + options_.h_scale()); + decode_src += R"( + uint g_idx = gid.x; + uint box_offset = g_idx * num_coords + uint(box_coord_offset); + uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox + + float y_center, x_center, h, w; + + if (reverse_output_order == int(0)) { + y_center = raw_boxes[box_offset + uint(0)]; + x_center = raw_boxes[box_offset + uint(1)]; + h = raw_boxes[box_offset + uint(2)]; + w = raw_boxes[box_offset + uint(3)]; + } else { + x_center = raw_boxes[box_offset + uint(0)]; + y_center = raw_boxes[box_offset + uint(1)]; + w = raw_boxes[box_offset + uint(2)]; + h = raw_boxes[box_offset + uint(3)]; + } + + float anchor_yc = raw_anchors[anchor_offset + uint(0)]; + float anchor_xc = raw_anchors[anchor_offset + uint(1)]; + float anchor_h = raw_anchors[anchor_offset + uint(2)]; + float anchor_w = raw_anchors[anchor_offset + uint(3)]; + + x_center = x_center / scale.x * anchor_w + anchor_xc; + y_center = y_center / scale.y * anchor_h + anchor_yc; + + if (apply_exponential == int(1)) { + h = exp(h / scale.w) * anchor_h; + w = exp(w / scale.z) * anchor_w; + } else { + h = (h / scale.w) * anchor_h; + w = (w / scale.z) * anchor_w; + } + + float ymin = y_center - h / 2.0; + float xmin = x_center - w / 2.0; + float ymax = y_center + h / 2.0; + float xmax = x_center + w / 2.0; + + boxes[box_offset + uint(0)] = ymin; + boxes[box_offset + uint(1)] = xmin; + boxes[box_offset + uint(2)] = ymax; + boxes[box_offset + uint(3)] = xmax; + + if (num_keypoints > int(0)){ + for (int k = 0; k < num_keypoints; ++k) { + int kp_offset = + int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; + float kp_y, kp_x; + if (reverse_output_order == int(0)) { + kp_y = raw_boxes[kp_offset + int(0)]; + kp_x = raw_boxes[kp_offset + int(1)]; + } else { + kp_x = 
raw_boxes[kp_offset + int(0)]; + kp_y = raw_boxes[kp_offset + int(1)]; + } + boxes[kp_offset + int(0)] = kp_x / scale.x * anchor_w + anchor_xc; + boxes[kp_offset + int(1)] = kp_y / scale.y * anchor_h + anchor_yc; + } + } +})"; + + { + // Shader program + NSString* library_source = + [NSString stringWithUTF8String:decode_src.c_str()]; + NSError* error = nil; + id library = [device newLibraryWithSource:library_source + options:nullptr + error:&error]; + RET_CHECK(library != nil) << "Couldn't create shader library " + << [[error localizedDescription] UTF8String]; + id kernel_func = nil; + kernel_func = [library newFunctionWithName:@"decodeKernel"]; + RET_CHECK(kernel_func != nil) << "Couldn't create kernel function."; + decode_program_ = + [device newComputePipelineStateWithFunction:kernel_func error:&error]; + RET_CHECK(decode_program_ != nil) << "Couldn't create pipeline state " << + [[error localizedDescription] UTF8String]; + // Outputs + decoded_boxes_buffer_ = + absl::make_unique(Tensor::ElementType::kFloat32, + Tensor::Shape{1, num_boxes_ * num_coords_}); + // Inputs + raw_anchors_buffer_ = absl::make_unique( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, num_boxes_ * kNumCoordsPerBox}); + } + + // A shader to score detection boxes. + const std::string score_src = absl::Substitute( + R"( +#include + +using namespace metal; + +float optional_sigmoid(float x) { + int apply_sigmoid = int($1); + int apply_clipping_thresh = int($2); + float clipping_thresh = float($3); + if (apply_sigmoid == int(0)) return x; + if (apply_clipping_thresh == int(1)) { + x = clamp(x, -clipping_thresh, clipping_thresh); + } + x = 1.0 / (1.0 + exp(-x)); + return x; +} + +kernel void scoreKernel( + device float* scored_boxes [[ buffer(0) ]], + device float* raw_scores [[ buffer(1) ]], + uint2 tid [[ thread_position_in_threadgroup ]], + uint2 gid [[ thread_position_in_grid ]]) { + + uint num_classes = uint($0); + int apply_sigmoid = int($1); + int apply_clipping_thresh = int($2); + float clipping_thresh = float($3); + int ignore_class_0 = int($4); + + uint g_idx = gid.x; // box idx + uint s_idx = tid.y; // score/class idx + + // load all scores into shared memory + threadgroup float local_scores[$0]; + float score = raw_scores[g_idx * num_classes + s_idx]; + local_scores[s_idx] = optional_sigmoid(score); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // find max score in shared memory + if (s_idx == uint(0)) { + float max_score = -FLT_MAX; + float max_class = -1.0; + for (int i=ignore_class_0; i max_score) { + max_score = local_scores[i]; + max_class = float(i); + } + } + scored_boxes[g_idx * uint(2) + uint(0)] = max_score; + scored_boxes[g_idx * uint(2) + uint(1)] = max_class; + } +})", + num_classes_, options_.sigmoid_score() ? 1 : 0, + options_.has_score_clipping_thresh() ? 1 : 0, + options_.has_score_clipping_thresh() ? options_.score_clipping_thresh() + : 0, + ignore_classes_.size() ? 1 : 0); + + // TODO support better filtering. 
+ CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; + + { + // Shader program + NSString* library_source = + [NSString stringWithUTF8String:score_src.c_str()]; + NSError* error = nil; + id library = [device newLibraryWithSource:library_source + options:nullptr + error:&error]; + RET_CHECK(library != nil) << "Couldn't create shader library " + << [[error localizedDescription] UTF8String]; + id kernel_func = nil; + kernel_func = [library newFunctionWithName:@"scoreKernel"]; + RET_CHECK(kernel_func != nil) << "Couldn't create kernel function."; + score_program_ = + [device newComputePipelineStateWithFunction:kernel_func error:&error]; + RET_CHECK(score_program_ != nil) << "Couldn't create pipeline state " << + [[error localizedDescription] UTF8String]; + // Outputs + scored_boxes_buffer_ = absl::make_unique( + Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2}); + // # filter classes supported is hardware dependent. + int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup; + CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size; + } + +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) + + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto new file mode 100644 index 000000000..24c0a5053 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto @@ -0,0 +1,74 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the TensorsToDetectionsCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorsToDetectionsCalculatorOptions { + extend .mediapipe.CalculatorOptions { + optional TensorsToDetectionsCalculatorOptions ext = 335742639; + } + + // [Required] The number of output classes predicted by the detection model. + optional int32 num_classes = 1; + // [Required] The number of output boxes predicted by the detection model. + optional int32 num_boxes = 2; + // [Required] The number of output values per boxes predicted by the detection + // model. The values contain bounding boxes, keypoints, etc. + optional int32 num_coords = 3; + + // The offset of keypoint coordinates in the location tensor. + optional int32 keypoint_coord_offset = 9; + // The number of predicted keypoints. + optional int32 num_keypoints = 10 [default = 0]; + // The dimension of each keypoint, e.g. number of values predicted for each + // keypoint. + optional int32 num_values_per_keypoint = 11 [default = 2]; + // The offset of box coordinates in the location tensor. + optional int32 box_coord_offset = 12 [default = 0]; + + // Parameters for decoding SSD detection model. 
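These four scales drive the anchor-based decode in TensorsToDetectionsCalculator::DecodeBoxes above; condensed here to a single box (raw_x, raw_y, raw_h, raw_w stand for the model's raw outputs for that box, anchor for the matching Anchor proto):

    x_center = raw_x / options_.x_scale() * anchor.w() + anchor.x_center();
    y_center = raw_y / options_.y_scale() * anchor.h() + anchor.y_center();
    if (options_.apply_exponential_on_box_size()) {
      h = std::exp(raw_h / options_.h_scale()) * anchor.h();
      w = std::exp(raw_w / options_.w_scale()) * anchor.w();
    } else {
      h = raw_h / options_.h_scale() * anchor.h();
      w = raw_w / options_.w_scale() * anchor.w();
    }
    // The box corners are then y_center +/- h/2 and x_center +/- w/2.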
+ optional float x_scale = 4 [default = 0.0]; + optional float y_scale = 5 [default = 0.0]; + optional float w_scale = 6 [default = 0.0]; + optional float h_scale = 7 [default = 0.0]; + + optional bool apply_exponential_on_box_size = 13 [default = false]; + + // Whether to reverse the order of predicted x, y from output. + // If false, the order is [y_center, x_center, h, w], if true the order is + // [x_center, y_center, w, h]. + optional bool reverse_output_order = 14 [default = false]; + // The ids of classes that should be ignored during decoding the score for + // each predicted box. + repeated int32 ignore_classes = 8; + + optional bool sigmoid_score = 15 [default = false]; + optional float score_clipping_thresh = 16; + + // Whether the detection coordinates from the input tensors should be flipped + // vertically (along the y-direction). This is useful, for example, when the + // input tensors represent detections defined with a coordinate system where + // the origin is at the top-left corner, whereas the desired detection + // representation has a bottom-left origin (e.g., in OpenGL). + optional bool flip_vertically = 18 [default = false]; + + // Score threshold for perserving decoded detections. + optional float min_score_thresh = 19; +} diff --git a/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc b/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc new file mode 100644 index 000000000..5ec3b4dea --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc @@ -0,0 +1,106 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/tensors_to_floats_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); } + +} // namespace + +// A calculator for converting Tensors to to a float or a float vector. +// +// Input: +// TENSORS - Vector of Tensors of type kFloat32. Only the first +// tensor will be used. +// Output: +// FLOAT(optional) - Converted single float number. +// FLOATS(optional) - Converted float vector. +// +// Notes: To output FLOAT stream, the input tensor must have size 1, e.g. +// only 1 float number in the tensor. 
+// +// Usage example: +// node { +// calculator: "TensorsToFloatsCalculator" +// input_stream: "TENSORS:tensors" +// output_stream: "FLOATS:floats" +// } +namespace api2 { +class TensorsToFloatsCalculator : public Node { + public: + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr Output::Optional kOutFloat{"FLOAT"}; + static constexpr Output>::Optional kOutFloats{"FLOATS"}; + MEDIAPIPE_NODE_INTERFACE(TensorsToFloatsCalculator, kInTensors, kOutFloat, + kOutFloats); + + static absl::Status UpdateContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; + + private: + ::mediapipe::TensorsToFloatsCalculatorOptions options_; +}; +MEDIAPIPE_REGISTER_NODE(TensorsToFloatsCalculator); + +absl::Status TensorsToFloatsCalculator::UpdateContract(CalculatorContract* cc) { + // Only exactly a single output allowed. + RET_CHECK(kOutFloat(cc).IsConnected() ^ kOutFloats(cc).IsConnected()); + return absl::OkStatus(); +} + +absl::Status TensorsToFloatsCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options<::mediapipe::TensorsToFloatsCalculatorOptions>(); + return absl::OkStatus(); +} + +absl::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) { + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + // TODO: Add option to specify which tensor to take from. + auto view = input_tensors[0].GetCpuReadView(); + auto raw_floats = view.buffer(); + int num_values = input_tensors[0].shape().num_elements(); + auto output_floats = absl::make_unique>( + raw_floats, raw_floats + num_values); + + switch (options_.activation()) { + case TensorsToFloatsCalculatorOptions::SIGMOID: + std::transform(output_floats->begin(), output_floats->end(), + output_floats->begin(), Sigmoid); + break; + case TensorsToFloatsCalculatorOptions::NONE: + break; + } + + if (kOutFloat(cc).IsConnected()) { + // TODO: Could add an index in the option to specifiy returning + // one value of a float array. + RET_CHECK_EQ(num_values, 1); + kOutFloat(cc).Send(output_floats->at(0)); + } else { + kOutFloats(cc).Send(std::move(output_floats)); + } + return absl::OkStatus(); +} +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_floats_calculator.proto b/mediapipe/calculators/tensor/tensors_to_floats_calculator.proto new file mode 100644 index 000000000..694050190 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_floats_calculator.proto @@ -0,0 +1,33 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the TensorsToFloatsCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorsToFloatsCalculatorOptions { + extend .mediapipe.CalculatorOptions { + optional TensorsToFloatsCalculatorOptions ext = 343499115; + } + enum Activation { + NONE = 0; + SIGMOID = 1; + } + // Apply activation function to the floats. 
+ optional Activation activation = 1 [default = NONE]; +} diff --git a/mediapipe/calculators/tensor/tensors_to_floats_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_floats_calculator_test.cc new file mode 100644 index 000000000..9a564f564 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_floats_calculator_test.cc @@ -0,0 +1,144 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/tensors_to_floats_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +using mediapipe::ParseTextProtoOrDie; +using Node = ::mediapipe::CalculatorGraphConfig::Node; + +const float kErrorMargin = 1e-2f; + +class TensorsToFloatsCalculatorTest : public ::testing::Test { + protected: + void BuildGraph(mediapipe::CalculatorRunner* runner, + const std::vector& values) { + auto tensors = absl::make_unique>(); + tensors->emplace_back( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, 1, static_cast(values.size()), 1}); + auto view = tensors->back().GetCpuWriteView(); + float* tensor_buffer = view.buffer(); + ASSERT_NE(tensor_buffer, nullptr); + for (int i = 0; i < values.size(); ++i) { + tensor_buffer[i] = values[i]; + } + + int64 stream_timestamp = 0; + auto& input_stream_packets = + runner->MutableInputs()->Tag("TENSORS").packets; + + input_stream_packets.push_back( + mediapipe::Adopt(tensors.release()) + .At(mediapipe::Timestamp(stream_timestamp++))); + } +}; + +TEST_F(TensorsToFloatsCalculatorTest, SingleValue) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TensorsToFloatsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "FLOAT:float" + )")); + + const float single_value = 0.5; + BuildGraph(&runner, {single_value}); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("FLOAT").packets; + + EXPECT_EQ(1, output_packets_.size()); + + const auto& value = output_packets_[0].Get(); + EXPECT_EQ(single_value, value); +} + +TEST_F(TensorsToFloatsCalculatorTest, SingleValueAsVector) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TensorsToFloatsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "FLOATS:floats" + )")); + + const float single_value = 0.5; + BuildGraph(&runner, {single_value}); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("FLOATS").packets; + EXPECT_EQ(1, output_packets_.size()); + + const auto& values = output_packets_[0].Get>(); + EXPECT_EQ(1, values.size()); + EXPECT_EQ(single_value, values[0]); +} + +TEST_F(TensorsToFloatsCalculatorTest, FloatVector) { 
+ mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TensorsToFloatsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "FLOATS:floats" + )")); + + const std::vector input_values = {0.f, 0.5f, 1.0f}; + BuildGraph(&runner, input_values); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("FLOATS").packets; + EXPECT_EQ(1, output_packets_.size()); + + const auto& values = output_packets_[0].Get>(); + EXPECT_EQ(input_values.size(), values.size()); + for (int i = 0; i < values.size(); ++i) { + EXPECT_NEAR(values[i], input_values[i], kErrorMargin); + } +} + +TEST_F(TensorsToFloatsCalculatorTest, FloatVectorWithSigmoid) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TensorsToFloatsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "FLOATS:floats" + options { + [mediapipe.TensorsToFloatsCalculatorOptions.ext] { activation: SIGMOID } + } + )")); + + const std::vector input_values = {-1.f, 0.f, 1.0f}; + const std::vector expected_output_with_sigmoid = {0.269f, 0.5f, + 0.731f}; + BuildGraph(&runner, input_values); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("FLOATS").packets; + EXPECT_EQ(1, output_packets_.size()); + + const auto& values = output_packets_[0].Get>(); + EXPECT_EQ(expected_output_with_sigmoid.size(), values.size()); + for (int i = 0; i < values.size(); ++i) { + EXPECT_NEAR(values[i], expected_output_with_sigmoid[i], kErrorMargin); + } +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc new file mode 100644 index 000000000..8e4066bee --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc @@ -0,0 +1,219 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { +namespace api2 { + +namespace { + +inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); } + +float ApplyActivation( + ::mediapipe::TensorsToLandmarksCalculatorOptions::Activation activation, + float value) { + switch (activation) { + case ::mediapipe::TensorsToLandmarksCalculatorOptions::SIGMOID: + return Sigmoid(value); + break; + default: + return value; + } +} + +} // namespace + +// A calculator for converting Tensors from regression models into landmarks. +// Note that if the landmarks in the tensor has more than 5 dimensions, only the +// first 5 dimensions will be converted to [x,y,z, visibility, presence]. The +// latter two fields may also stay unset if such attributes are not supported in +// the model. 
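A condensed view of how that per-landmark layout is read (derived from Process() further below; num_values is the element count of the first input tensor, ld the landmark index, and raw_landmarks its flat float buffer):

    const int num_dimensions = num_values / num_landmarks;
    const int offset = ld * num_dimensions;
    const float x = raw_landmarks[offset + 0];
    const float y = raw_landmarks[offset + 1];           // only if num_dimensions > 1
    const float z = raw_landmarks[offset + 2];           // only if num_dimensions > 2
    const float visibility = raw_landmarks[offset + 3];  // only if num_dimensions > 3
    const float presence = raw_landmarks[offset + 4];    // only if num_dimensions > 4
    // visibility and presence are then passed through the configured activations.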
+// +// Input: +// TENSORS - Vector of Tensors of type kFloat32. Only the first tensor will be +// used. The size of the values must be (num_dimension x num_landmarks). +// +// FLIP_HORIZONTALLY (optional): Whether to flip landmarks horizontally or +// not. Overrides corresponding side packet and/or field in the calculator +// options. +// +// FLIP_VERTICALLY (optional): Whether to flip landmarks vertically or not. +// Overrides corresponding side packet and/or field in the calculator options. +// +// Input side packet: +// FLIP_HORIZONTALLY (optional): Whether to flip landmarks horizontally or +// not. Overrides the corresponding field in the calculator options. +// +// FLIP_VERTICALLY (optional): Whether to flip landmarks vertically or not. +// Overrides the corresponding field in the calculator options. +// +// Output: +// LANDMARKS(optional) - Result MediaPipe landmarks. +// NORM_LANDMARKS(optional) - Result MediaPipe normalized landmarks. +// +// Notes: +// To output normalized landmarks, user must provide the original input image +// size to the model using calculator option input_image_width and +// input_image_height. +// Usage example: +// node { +// calculator: "TensorsToLandmarksCalculator" +// input_stream: "TENSORS:landmark_tensors" +// output_stream: "LANDMARKS:landmarks" +// output_stream: "NORM_LANDMARKS:landmarks" +// options: { +// [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { +// num_landmarks: 21 +// +// input_image_width: 256 +// input_image_height: 256 +// } +// } +// } +class TensorsToLandmarksCalculator : public Node { + public: + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr Input::SideFallback::Optional kFlipHorizontally{ + "FLIP_HORIZONTALLY"}; + static constexpr Input::SideFallback::Optional kFlipVertically{ + "FLIP_VERTICALLY"}; + static constexpr Output::Optional kOutLandmarkList{"LANDMARKS"}; + static constexpr Output::Optional + kOutNormalizedLandmarkList{"NORM_LANDMARKS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kFlipHorizontally, kFlipVertically, + kOutLandmarkList, kOutNormalizedLandmarkList); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + absl::Status LoadOptions(CalculatorContext* cc); + int num_landmarks_ = 0; + ::mediapipe::TensorsToLandmarksCalculatorOptions options_; +}; +MEDIAPIPE_REGISTER_NODE(TensorsToLandmarksCalculator); + +absl::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadOptions(cc)); + + if (kOutNormalizedLandmarkList(cc).IsConnected()) { + RET_CHECK(options_.has_input_image_height() && + options_.has_input_image_width()) + << "Must provide input width/height for getting normalized landmarks."; + } + if (kOutLandmarkList(cc).IsConnected() && + (options_.flip_horizontally() || options_.flip_vertically() || + kFlipHorizontally(cc).IsConnected() || + kFlipVertically(cc).IsConnected())) { + RET_CHECK(options_.has_input_image_height() && + options_.has_input_image_width()) + << "Must provide input width/height for using flipping when outputing " + "landmarks in absolute coordinates."; + } + return absl::OkStatus(); +} + +absl::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + return absl::OkStatus(); + } + bool flip_horizontally = + kFlipHorizontally(cc).GetOr(options_.flip_horizontally()); + bool flip_vertically = kFlipVertically(cc).GetOr(options_.flip_vertically()); + + const auto& input_tensors = *kInTensors(cc); + int num_values = 
input_tensors[0].shape().num_elements(); + const int num_dimensions = num_values / num_landmarks_; + CHECK_GT(num_dimensions, 0); + + auto view = input_tensors[0].GetCpuReadView(); + auto raw_landmarks = view.buffer(); + + LandmarkList output_landmarks; + + for (int ld = 0; ld < num_landmarks_; ++ld) { + const int offset = ld * num_dimensions; + Landmark* landmark = output_landmarks.add_landmark(); + + if (flip_horizontally) { + landmark->set_x(options_.input_image_width() - raw_landmarks[offset]); + } else { + landmark->set_x(raw_landmarks[offset]); + } + if (num_dimensions > 1) { + if (flip_vertically) { + landmark->set_y(options_.input_image_height() - + raw_landmarks[offset + 1]); + } else { + landmark->set_y(raw_landmarks[offset + 1]); + } + } + if (num_dimensions > 2) { + landmark->set_z(raw_landmarks[offset + 2]); + } + if (num_dimensions > 3) { + landmark->set_visibility(ApplyActivation(options_.visibility_activation(), + raw_landmarks[offset + 3])); + } + if (num_dimensions > 4) { + landmark->set_presence(ApplyActivation(options_.presence_activation(), + raw_landmarks[offset + 4])); + } + } + + // Output normalized landmarks if required. + if (kOutNormalizedLandmarkList(cc).IsConnected()) { + NormalizedLandmarkList output_norm_landmarks; + for (int i = 0; i < output_landmarks.landmark_size(); ++i) { + const Landmark& landmark = output_landmarks.landmark(i); + NormalizedLandmark* norm_landmark = output_norm_landmarks.add_landmark(); + norm_landmark->set_x(landmark.x() / options_.input_image_width()); + norm_landmark->set_y(landmark.y() / options_.input_image_height()); + // Scale Z coordinate as X + allow additional uniform normalization. + norm_landmark->set_z(landmark.z() / options_.input_image_width() / + options_.normalize_z()); + if (landmark.has_visibility()) { // Set only if supported in the model. + norm_landmark->set_visibility(landmark.visibility()); + } + if (landmark.has_presence()) { // Set only if supported in the model. + norm_landmark->set_presence(landmark.presence()); + } + } + kOutNormalizedLandmarkList(cc).Send(std::move(output_norm_landmarks)); + } + + // Output absolute landmarks. + if (kOutLandmarkList(cc).IsConnected()) { + kOutLandmarkList(cc).Send(std::move(output_landmarks)); + } + + return absl::OkStatus(); +} + +absl::Status TensorsToLandmarksCalculator::LoadOptions(CalculatorContext* cc) { + // Get calculator options specified in the graph. + options_ = cc->Options<::mediapipe::TensorsToLandmarksCalculatorOptions>(); + RET_CHECK(options_.has_num_landmarks()); + num_landmarks_ = options_.num_landmarks(); + + return absl::OkStatus(); +} +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.proto b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.proto new file mode 100644 index 000000000..2608a1459 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.proto @@ -0,0 +1,65 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
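The normalized-landmark path in Process() above reduces to an optional flip in pixel space followed by division by the input image size, with Z scaled like X and further divided by normalize_z. Below is a minimal standalone sketch of that per-landmark math, using hypothetical plain structs rather than the LandmarkList/NormalizedLandmarkList protos:

#include <cstdio>

// Illustrative stand-ins for the proto types; these are not the real API.
struct RawLandmark { float x, y, z; };            // model output, in pixels
struct NormalizedLandmarkSketch { float x, y, z; };

// Mirrors the math in Process(): optional horizontal/vertical flip in pixel
// space, then normalization by the input image size.
NormalizedLandmarkSketch Normalize(RawLandmark lm, int image_width,
                                   int image_height, float normalize_z,
                                   bool flip_horizontally,
                                   bool flip_vertically) {
  if (flip_horizontally) lm.x = image_width - lm.x;
  if (flip_vertically) lm.y = image_height - lm.y;
  return {lm.x / image_width, lm.y / image_height,
          // Z is scaled like X, with the extra normalize_z divisor.
          lm.z / image_width / normalize_z};
}

int main() {
  // A landmark at pixel (64, 192) with z = 10 in a 256x256 input, no flipping.
  const NormalizedLandmarkSketch n =
      Normalize({64.f, 192.f, 10.f}, 256, 256, 1.0f, false, false);
  std::printf("x=%.3f y=%.3f z=%.4f\n", n.x, n.y, n.z);  // 0.250 0.750 0.0391
  return 0;
}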
+// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the TensorsToLandmarksCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorsToLandmarksCalculatorOptions { + extend .mediapipe.CalculatorOptions { + optional TensorsToLandmarksCalculatorOptions ext = 335742640; + } + + enum Activation { + NONE = 0; + SIGMOID = 1; + } + + // [Required] Number of landmarks from the output of the model. + optional int32 num_landmarks = 1; + + // Size of the input image for the model. These options are used only when + // normalized landmarks are needed. Z coordinate is scaled as X assuming + // a weak perspective projection camera model. + optional int32 input_image_width = 2; + optional int32 input_image_height = 3; + + // Whether the detection coordinates from the input tensors should be flipped + // vertically (along the y-direction). This is useful, for example, when the + // input tensors represent detections defined with a coordinate system where + // the origin is at the top-left corner, whereas the desired detection + // representation has a bottom-left origin (e.g., in OpenGL). + optional bool flip_vertically = 4 [default = false]; + + // Whether the detection coordinates from the input tensors should be flipped + // horizontally (along the x-direction). This is useful, for example, when the + // input image is horizontally flipped in ImageTransformationCalculator + // beforehand. + optional bool flip_horizontally = 6 [default = false]; + + // A value that Z coordinates should be divided by. This option is used only + // when normalized landmarks are needed. It is applied in addition to Z + // coordinate being re-scaled as X. + optional float normalize_z = 5 [default = 1.0]; + + // Apply activation function to the tensor representing landmark visibility. + optional Activation visibility_activation = 7 [default = NONE]; + + // Apply activation function to the tensor representing landmark presence. 
+ optional Activation presence_activation = 8 [default = NONE]; +} diff --git a/mediapipe/calculators/tensor/testdata/add.bin b/mediapipe/calculators/tensor/testdata/add.bin new file mode 100644 index 000000000..b4c02350c Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/add.bin differ diff --git a/mediapipe/calculators/tensor/testdata/expected_detection.pbtxt b/mediapipe/calculators/tensor/testdata/expected_detection.pbtxt new file mode 100644 index 000000000..f1739ce86 --- /dev/null +++ b/mediapipe/calculators/tensor/testdata/expected_detection.pbtxt @@ -0,0 +1,35 @@ +label_id: 0 +score: 0.92843366 +location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.21061149 + ymin: 0.29150677 + width: 0.5657704 + height: 0.5657307 + } + relative_keypoints { + x: 0.37730268 + y: 0.44038114 + } + relative_keypoints { + x: 0.6250565 + y: 0.44425336 + } + relative_keypoints { + x: 0.50687385 + y: 0.5767085 + } + relative_keypoints { + x: 0.50173956 + y: 0.6991459 + } + relative_keypoints { + x: 0.2383742 + y: 0.49879026 + } + relative_keypoints { + x: 0.7404449 + y: 0.50361776 + } +} diff --git a/mediapipe/calculators/tensor/testdata/face_detection_expected.png b/mediapipe/calculators/tensor/testdata/face_detection_expected.png new file mode 100644 index 000000000..df38abf70 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/face_detection_expected.png differ diff --git a/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt b/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt new file mode 100644 index 000000000..b0e00346c --- /dev/null +++ b/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt @@ -0,0 +1,31 @@ +input_stream: "image" +output_stream: "rendering" +output_stream: "detections" + +# Subgraph that detects faces. +node { + calculator: "FaceDetectionFrontCpu" + input_stream: "IMAGE:image" + output_stream: "DETECTIONS:detections" +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "RENDER_DATA:render_data" + options: { + [mediapipe.DetectionsToRenderDataCalculatorOptions.ext] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the input images. 
+node { + calculator: "AnnotationOverlayCalculator" + input_stream: "IMAGE:image" + input_stream: "render_data" + output_stream: "IMAGE:rendering" +} diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/input.jpg b/mediapipe/calculators/tensor/testdata/image_to_tensor/input.jpg new file mode 100644 index 000000000..37d6c4b20 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/input.jpg differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect.png new file mode 100644 index 000000000..38a13dabe Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png new file mode 100644 index 000000000..1a738a50d Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png new file mode 100644 index 000000000..254dc72ae Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_border_zero.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_border_zero.png new file mode 100644 index 000000000..5b096cb4d Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_border_zero.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation.png new file mode 100644 index 000000000..104cb6091 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation_border_zero.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation_border_zero.png new file mode 100644 index 000000000..c5512ec0d Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation_border_zero.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png new file mode 100644 index 000000000..aba8d2591 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_border_zero.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_border_zero.png new file mode 100644 index 000000000..bfb461546 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_border_zero.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation.png 
b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation.png new file mode 100644 index 000000000..5ce7c3ec3 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png new file mode 100644 index 000000000..ab14e5954 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png new file mode 100644 index 000000000..ecfb1e537 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png new file mode 100644 index 000000000..d55301146 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png differ diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/noop_except_range.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/noop_except_range.png new file mode 100644 index 000000000..1486d9f15 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/noop_except_range.png differ diff --git a/mediapipe/calculators/tensor/testdata/labelmap.txt b/mediapipe/calculators/tensor/testdata/labelmap.txt new file mode 100644 index 000000000..4291e3c6b --- /dev/null +++ b/mediapipe/calculators/tensor/testdata/labelmap.txt @@ -0,0 +1,3 @@ +classA +classB +classC diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 3daf3827f..4c2b90a59 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -106,6 +106,13 @@ proto_library( deps = ["//mediapipe/framework:calculator_proto"], ) +proto_library( + name = "vector_string_to_tensor_calculator_options_proto", + srcs = ["vector_string_to_tensor_calculator_options.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + mediapipe_cc_proto_library( name = "graph_tensors_packet_generator_cc_proto", srcs = ["graph_tensors_packet_generator.proto"], @@ -281,6 +288,14 @@ mediapipe_cc_proto_library( deps = [":vector_float_to_tensor_calculator_options_proto"], ) +mediapipe_cc_proto_library( + name = "vector_string_to_tensor_calculator_options_cc_proto", + srcs = ["vector_string_to_tensor_calculator_options.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":vector_string_to_tensor_calculator_options_proto"], +) + cc_library( name = "graph_tensors_packet_generator", srcs = ["graph_tensors_packet_generator.cc"], @@ -311,7 +326,6 @@ cc_library( "@org_tensorflow//tensorflow/core:framework", ], "//mediapipe:android": [ - "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", ], }), alwayslink = 1, @@ -397,6 +411,7 @@ cc_library( 
"//mediapipe/framework/port:status", "//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence_util", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:protos_all_cc", ], @@ -728,6 +743,20 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "vector_string_to_tensor_calculator", + srcs = ["vector_string_to_tensor_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":vector_string_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@org_tensorflow//tensorflow/core:framework", + ], + alwayslink = 1, +) + cc_library( name = "unpack_yt8m_sequence_example_calculator", srcs = ["unpack_yt8m_sequence_example_calculator.cc"], @@ -842,6 +871,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/util/sequence:media_sequence", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:protos_all_cc", @@ -857,11 +887,12 @@ cc_test( ":tensorflow_inference_calculator", ":tensorflow_session", ":tensorflow_session_from_frozen_graph_calculator", - "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", + ":tensorflow_session_from_frozen_graph_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:packet", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", @@ -892,6 +923,7 @@ cc_test( "//mediapipe/framework:packet", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", @@ -922,6 +954,7 @@ cc_test( "//mediapipe/framework:packet", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:tag_map_helper", @@ -948,6 +981,7 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework:packet", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:tag_map_helper", @@ -1074,6 +1108,21 @@ cc_test( ], ) +cc_test( + name = "vector_string_to_tensor_calculator_test", + srcs = ["vector_string_to_tensor_calculator_test.cc"], + deps = [ + ":vector_string_to_tensor_calculator", + ":vector_string_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + test_suite( name = "ios", tags = ["ios"], @@ -1096,13 +1145,13 @@ cc_test( ":tensorflow_session_from_frozen_graph_generator", 
":tensorflow_session_from_frozen_graph_generator_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:integral_types", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/tool:sink", "//mediapipe/framework/tool:validate_type", "//mediapipe/framework/port:gtest_main", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:ret_check", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:testlib", diff --git a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc index 54126cf1d..310d412bf 100644 --- a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc +++ b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc @@ -33,7 +33,7 @@ namespace tf = ::tensorflow; class GraphTensorsPacketGenerator : public PacketGenerator { public: - static ::mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { RET_CHECK(extendable_options.HasExtension( @@ -45,10 +45,10 @@ class GraphTensorsPacketGenerator : public PacketGenerator { /* "A map of tensor tags and tensors" */); RET_CHECK_EQ(options.tensor_tag_size(), options.tensor_num_nodes_size()); RET_CHECK_GT(options.tensor_tag_size(), 0); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - static ::mediapipe::Status Generate( + static absl::Status Generate( const PacketGeneratorOptions& packet_generator_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { const GraphTensorsPacketGeneratorOptions& options = @@ -65,7 +65,7 @@ class GraphTensorsPacketGenerator : public PacketGenerator { (*tensor_map)[tensor_tag].flat().setZero(); } output_side_packets->Index(0) = AdoptAsUniquePtr(tensor_map.release()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(GraphTensorsPacketGenerator); diff --git a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc index 77069c658..ef77fb918 100644 --- a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc +++ b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc @@ -72,7 +72,7 @@ TEST_F(GraphTensorsPacketGeneratorTest, VerifyTensorSizeShapeAndValue) { PacketSet inputs({}); PacketSet outputs(1); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "GraphTensorsPacketGenerator", extendable_options_, inputs, &outputs); MP_EXPECT_OK(run_status) << run_status.message(); VerifyTensorMap(&outputs); diff --git a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc index fd109a3bd..0db193bcc 100644 --- a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc @@ -78,18 +78,17 @@ std::unique_ptr ImageFrameToNormalizedTensor( // } class ImageFrameToTensorCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* 
cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: ImageFrameToTensorCalculatorOptions options_; }; REGISTER_CALCULATOR(ImageFrameToTensorCalculator); -::mediapipe::Status ImageFrameToTensorCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ImageFrameToTensorCalculator::GetContract(CalculatorContract* cc) { // Start with only one input packet. RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is supported."; @@ -101,19 +100,18 @@ REGISTER_CALCULATOR(ImageFrameToTensorCalculator); cc->Outputs().Index(0).Set( // Output TensorFlow Tensor. ); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageFrameToTensorCalculator::Open(CalculatorContext* cc) { +absl::Status ImageFrameToTensorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); // Inform the framework that we always output at the same timestamp // as we receive a packet at. cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageFrameToTensorCalculator::Process( - CalculatorContext* cc) { +absl::Status ImageFrameToTensorCalculator::Process(CalculatorContext* cc) { const Packet& input_item = cc->Inputs().Index(0).Value(); RET_CHECK(!input_item.IsEmpty()) << "Input cannot be empty."; @@ -147,7 +145,7 @@ REGISTER_CALCULATOR(ImageFrameToTensorCalculator); } else if (bytes_per_pixel == 4) { data_type = tf::DT_FLOAT; } else { - return ::mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Unsupported image format (", bytes_per_pixel, " bytes per pixel)")); } @@ -174,7 +172,7 @@ REGISTER_CALCULATOR(ImageFrameToTensorCalculator); } cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc index 36c3da7e7..a07b95ccc 100644 --- a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc +++ b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc @@ -42,16 +42,16 @@ namespace tf = tensorflow; // a flag controls whether a new first dimension is inserted before // concatenation. // -// Currently, the number of tensors output will be buffer_size less than the -// number of input tensors because no padding is implemented and only full -// buffers are output. +// The number of tensors output will be buffer_size less than the +// number of input tensors unless padding is set to a non-zero value in the +// options proto. // // The timestamp of the output batch will match the timestamp of the first // tensor in that batch by default. (e.g. when buffer_size frames are added, the // output tensor will have the timestamp of the first input.). This behavior can // be adjusted by the timestamp_offset option. 
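In other words, once the first window has filled, the calculator emits one output per buffer_size - overlap new inputs; with buffer_size = 3 and overlap = 1, inputs 0..6 yield the windows {0, 1, 2}, {2, 3, 4} and {4, 5, 6}. A non-zero padding additionally front-fills the first window with copies of the first input and flushes a final padded window in Close(). The sketch below is a hypothetical, stripped-down emulation of that windowing over frame indices only (no tensors, timestamps or padding):

#include <cstdio>
#include <deque>
#include <vector>

// Emulates the windowing over integer frame indices; illustrative only.
std::vector<std::vector<int>> Window(int num_inputs, int buffer_size,
                                     int overlap) {
  std::deque<int> buffer;  // stands in for the circular tensor buffer
  std::vector<std::vector<int>> outputs;
  int steps_until_output = buffer_size;
  for (int i = 0; i < num_inputs; ++i) {
    buffer.push_back(i);
    if (static_cast<int>(buffer.size()) > buffer_size) buffer.pop_front();
    if (--steps_until_output <= 0) {
      outputs.emplace_back(buffer.begin(), buffer.end());
      steps_until_output = buffer_size - overlap;
    }
  }
  return outputs;
}

int main() {
  // buffer_size = 3, overlap = 1 over 7 inputs: {0,1,2}, {2,3,4}, {4,5,6}.
  for (const auto& window : Window(/*num_inputs=*/7, /*buffer_size=*/3,
                                   /*overlap=*/1)) {
    for (int frame : window) std::printf("%d ", frame);
    std::printf("\n");
  }
  return 0;
}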
// -// Example config: +// Example config without padding: // node { // calculator: "LappedTensorBufferCalculator" // input_stream: "input_tensor" @@ -64,37 +64,60 @@ namespace tf = tensorflow; // } // } // } +// +// Example config with padding and timestamp output: +// node { +// calculator: "LappedTensorBufferCalculator" +// input_stream: "input_tensor" +// output_stream: "output_tensor" +// output_stream: "output_timestamp" +// options { +// [mediapipe.LappedTensorBufferCalculatorOptions.ext] { +// buffer_size: 100 +// overlap: 50 +// add_batch_dim_to_tensors: true +// timestamp_offset: 25 +// padding: 25 +// } +// } +// } + class LappedTensorBufferCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - // Adds a batch dimension to the input tensor if specified in the calculator - // options. - ::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor); + // Adds a batch dimension to the input tensor if specified in the + // calculator options. + absl::Status AddBatchDimension(tf::Tensor* input_tensor); + // Sends the current buffer downstream. + absl::Status ProcessBuffer(CalculatorContext* cc); int steps_until_output_; int buffer_size_; int overlap_; int timestamp_offset_; + int initialized_; + std::unique_ptr> timestamp_buffer_; std::unique_ptr> buffer_; LappedTensorBufferCalculatorOptions options_; }; + REGISTER_CALCULATOR(LappedTensorBufferCalculator); -::mediapipe::Status LappedTensorBufferCalculator::GetContract( - CalculatorContract* cc) { +absl::Status LappedTensorBufferCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is supported."; cc->Inputs().Index(0).Set( // tensorflow::Tensor stream. ); - RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) - << "Only one output stream is supported."; + RET_CHECK_LE(cc->Outputs().NumEntries(), 2) + << "Only one or two output stream(s) is/are supported."; if (cc->InputSidePackets().HasTag(kBufferSize)) { cc->InputSidePackets().Tag(kBufferSize).Set(); @@ -113,10 +136,14 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator); cc->Outputs().Index(0).Set( // Output tensorflow::Tensor stream with possibly overlapping steps. ); - return ::mediapipe::OkStatus(); + // Output timestamp stream with possibly overlapping steps. 
+ if (cc->Outputs().NumEntries() > 1) { + cc->Outputs().Index(1).Set>(); + } + return absl::OkStatus(); } -::mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) { +absl::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); if (cc->InputSidePackets().HasTag(kCalculatorOptions)) { options_ = cc->InputSidePackets() @@ -141,44 +168,59 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator); << "Negative timestamp_offset is not allowed."; RET_CHECK_LT(timestamp_offset_, buffer_size_) << "output_frame_num_offset has to be less than buffer_size."; + RET_CHECK_LT(options_.padding(), buffer_size_) + << "padding option must be smaller than buffer size."; timestamp_buffer_ = absl::make_unique>(buffer_size_); buffer_ = absl::make_unique>(buffer_size_); - steps_until_output_ = buffer_size_; - return ::mediapipe::OkStatus(); + steps_until_output_ = buffer_size_ - options_.padding(); + initialized_ = false; + return absl::OkStatus(); } -::mediapipe::Status LappedTensorBufferCalculator::Process( - CalculatorContext* cc) { +absl::Status LappedTensorBufferCalculator::Process(CalculatorContext* cc) { // These are cheap, shallow copies. tensorflow::Tensor input_tensor( cc->Inputs().Index(0).Get()); if (options_.add_batch_dim_to_tensors()) { RET_CHECK_OK(AddBatchDimension(&input_tensor)); } + // Pad frames at the beginning with the first frame. + if (!initialized_) { + for (int i = 0; i < options_.padding(); ++i) { + buffer_->push_back(input_tensor); + timestamp_buffer_->push_back(cc->InputTimestamp()); + } + initialized_ = true; + } buffer_->push_back(input_tensor); timestamp_buffer_->push_back(cc->InputTimestamp()); --steps_until_output_; - if (steps_until_output_ <= 0) { - auto concatenated = ::absl::make_unique(); - - const tf::Status concat_status = tf::tensor::Concat( - std::vector(buffer_->begin(), buffer_->end()), - concatenated.get()); - RET_CHECK(concat_status.ok()) << concat_status.ToString(); - - cc->Outputs().Index(0).Add(concatenated.release(), - timestamp_buffer_->Get(timestamp_offset_)); - - steps_until_output_ = buffer_size_ - overlap_; + MP_RETURN_IF_ERROR(ProcessBuffer(cc)); } - return ::mediapipe::OkStatus(); + + return absl::OkStatus(); +} + +absl::Status LappedTensorBufferCalculator::Close(CalculatorContext* cc) { + if (!initialized_ || options_.padding() == 0) { + return absl::OkStatus(); + } + int last_frame = buffer_size_ - steps_until_output_ - 1; + const auto& pad_frame = buffer_->Get(last_frame); + for (int i = 0; i < steps_until_output_ + options_.padding(); ++i) { + buffer_->push_back(pad_frame); + timestamp_buffer_->push_back(cc->InputTimestamp()); + } + MP_RETURN_IF_ERROR(ProcessBuffer(cc)); + + return absl::OkStatus(); } // Adds a batch dimension to the input tensor if specified in the calculator // options. -::mediapipe::Status LappedTensorBufferCalculator::AddBatchDimension( +absl::Status LappedTensorBufferCalculator::AddBatchDimension( tf::Tensor* input_tensor) { if (options_.add_batch_dim_to_tensors()) { tf::TensorShape new_shape(input_tensor->shape()); @@ -187,7 +229,32 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator); << "Could not add 0th dimension to tensor without changing its shape." 
         << " Current shape: " << input_tensor->shape().DebugString();
   }
-  return ::mediapipe::OkStatus();
+  return absl::OkStatus();
+}
+
+// Concatenates and outputs the current buffer contents.
+absl::Status LappedTensorBufferCalculator::ProcessBuffer(
+    CalculatorContext* cc) {
+  auto concatenated = ::absl::make_unique<tf::Tensor>();
+  const tf::Status concat_status = tf::tensor::Concat(
+      std::vector<tf::Tensor>(buffer_->begin(), buffer_->end()),
+      concatenated.get());
+  RET_CHECK(concat_status.ok()) << concat_status.ToString();
+  // Output concatenated tensor.
+  cc->Outputs().Index(0).Add(concatenated.release(),
+                             timestamp_buffer_->Get(timestamp_offset_));
+  if (cc->Outputs().NumEntries() > 1) {
+    auto output_timestamp = ::absl::make_unique<std::vector<Timestamp>>();
+    // Output timestamp vector.
+    *output_timestamp = std::vector<Timestamp>(timestamp_buffer_->begin(),
+                                               timestamp_buffer_->end());
+    RET_CHECK_EQ(output_timestamp->size(), buffer_size_)
+        << "Output timestamp size is not correct.";
+    cc->Outputs().Index(1).Add(output_timestamp.release(),
+                               timestamp_buffer_->Get(timestamp_offset_));
+  }
+  steps_until_output_ = buffer_size_ - overlap_;
+  return absl::OkStatus();
 }
 
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.proto b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.proto
index 543c65368..bcd14985b 100644
--- a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.proto
+++ b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.proto
@@ -45,4 +45,8 @@ message LappedTensorBufferCalculatorOptions {
   // This is useful for aligning the timestamp to be centered on the input
   // range.
   optional int32 timestamp_offset = 4 [default = 0];
+
+  // Amount of padding (repeating of first/last value) to add to the beginning
+  // and end of the input stream.
+ optional int32 padding = 5; } diff --git a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc index 71cc6d1da..e0e3000d2 100644 --- a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc @@ -31,11 +31,15 @@ namespace tf = ::tensorflow; class LappedTensorBufferCalculatorTest : public ::testing::Test { protected: void SetUpCalculator(int buffer_size, int overlap, bool add_dim, - int timestamp_offset) { + int timestamp_offset, int padding, + bool timestamp_output) { CalculatorGraphConfig::Node config; config.set_calculator("LappedTensorBufferCalculator"); config.add_input_stream("input_tensor"); config.add_output_stream("output_tensor"); + if (timestamp_output) { + config.add_output_stream("output_timestamp"); + } auto options = config.mutable_options()->MutableExtension( LappedTensorBufferCalculatorOptions::ext); options->set_buffer_size(buffer_size); @@ -44,13 +48,14 @@ class LappedTensorBufferCalculatorTest : public ::testing::Test { options->set_add_batch_dim_to_tensors(true); } options->set_timestamp_offset(timestamp_offset); + options->set_padding(padding); runner_ = ::absl::make_unique(config); } std::unique_ptr runner_; }; TEST_F(LappedTensorBufferCalculatorTest, OneToOne) { - SetUpCalculator(1, 0, false, 0); + SetUpCalculator(1, 0, false, 0, 0, false); int num_timesteps = 3; for (int i = 0; i < num_timesteps; ++i) { auto input = ::absl::make_unique( @@ -74,7 +79,7 @@ TEST_F(LappedTensorBufferCalculatorTest, OneToTwo) { int buffer_size = 2; int overlap = 1; bool add_dim = false; - SetUpCalculator(buffer_size, overlap, add_dim, 0); + SetUpCalculator(buffer_size, overlap, add_dim, 0, 0, false); int num_timesteps = 3; for (int i = 0; i < num_timesteps; ++i) { auto input = ::absl::make_unique( @@ -100,7 +105,7 @@ TEST_F(LappedTensorBufferCalculatorTest, OneToThree) { int buffer_size = 3; int overlap = 2; bool add_dim = false; - SetUpCalculator(buffer_size, overlap, add_dim, 0); + SetUpCalculator(buffer_size, overlap, add_dim, 0, 0, false); int num_timesteps = 3; for (int i = 0; i < num_timesteps; ++i) { auto input = ::absl::make_unique( @@ -126,7 +131,7 @@ TEST_F(LappedTensorBufferCalculatorTest, OneToThreeSkip) { int buffer_size = 3; int overlap = 1; bool add_dim = false; - SetUpCalculator(buffer_size, overlap, add_dim, 0); + SetUpCalculator(buffer_size, overlap, add_dim, 0, 0, false); int num_timesteps = 3; for (int i = 0; i < num_timesteps; ++i) { auto input = ::absl::make_unique( @@ -148,11 +153,40 @@ TEST_F(LappedTensorBufferCalculatorTest, OneToThreeSkip) { } } +TEST_F(LappedTensorBufferCalculatorTest, OneToThreeNegativeOverlap) { + int buffer_size = 3; + int overlap = -1; + bool add_dim = false; + SetUpCalculator(buffer_size, overlap, add_dim, 0, 0, false); + int num_timesteps = 7; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(2, output_packets.size()); + // The outputs in packet one should be {0, 1, 2}, and in packet two {4, 5, 6} + for (int i = 0; i < 3; ++i) { + float value_0 = output_packets[0].Get().tensor()(i); + ASSERT_NEAR(value_0, i, 
0.0001); + } + for (int i = 0; i < 3; ++i) { + float value_1 = output_packets[1].Get().tensor()(i); + ASSERT_NEAR(value_1, 4 + i, 0.0001); + } +} + TEST_F(LappedTensorBufferCalculatorTest, OneToThreeBatch) { int buffer_size = 3; int overlap = 2; bool add_dim = true; - SetUpCalculator(buffer_size, overlap, add_dim, 0); + SetUpCalculator(buffer_size, overlap, add_dim, 0, 0, false); int num_timesteps = 3; for (int i = 0; i < num_timesteps; ++i) { auto input = ::absl::make_unique( @@ -180,7 +214,7 @@ TEST_F(LappedTensorBufferCalculatorTest, NegativeTimestampOffsetFails) { int overlap = 15; bool add_dim = true; int timestamp_offset = -7; - SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset); + SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset, 0, false); int num_timesteps = 20; for (int i = 0; i < num_timesteps; ++i) { auto input = ::absl::make_unique( @@ -197,7 +231,7 @@ TEST_F(LappedTensorBufferCalculatorTest, OutOfRangeTimestampOffsetFails) { int overlap = 15; bool add_dim = true; int timestamp_offset = buffer_size; - SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset); + SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset, 0, false); int num_timesteps = 20; for (int i = 0; i < num_timesteps; ++i) { auto input = ::absl::make_unique( @@ -214,7 +248,7 @@ TEST_F(LappedTensorBufferCalculatorTest, OneToThreeBatchTimestampOffset) { int overlap = 15; bool add_dim = true; int timestamp_offset = 7; - SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset); + SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset, 0, false); int num_timesteps = 20; for (int i = 0; i < num_timesteps; ++i) { auto input = ::absl::make_unique( @@ -236,5 +270,37 @@ TEST_F(LappedTensorBufferCalculatorTest, OneToThreeBatchTimestampOffset) { } } +TEST_F(LappedTensorBufferCalculatorTest, + OneToThreeBatchTimestampOffsetPadding) { + int buffer_size = 12; + int overlap = 6; + bool add_dim = true; + int timestamp_offset = 3; + int padding = 0; + SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset, padding, + true); + int num_timesteps = 20; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const int output_size = num_timesteps / buffer_size + 1; + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(output_size, output_packets.size()); + for (int i = 0; i < output_size; ++i) { + int64 value = output_packets[i].Timestamp().Value(); + ASSERT_EQ(i * overlap + timestamp_offset, value); + } + const std::vector& output_timestamps = + runner_->Outputs().Index(1).packets; + ASSERT_EQ(output_size, output_timestamps.size()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc index ca704b793..32a0eb70b 100644 --- a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc @@ -26,20 +26,19 @@ namespace mediapipe { namespace { -::mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, - TimeSeriesHeader* header) { +absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header) { CHECK(header); if 
(header_packet.IsEmpty()) { - return ::mediapipe::UnknownError("No header found."); + return absl::UnknownError("No header found."); } if (!header_packet.ValidateAsType().ok()) { - return ::mediapipe::UnknownError( - "Packet does not contain TimeSeriesHeader."); + return absl::UnknownError("Packet does not contain TimeSeriesHeader."); } *header = header_packet.Get(); if (header->has_sample_rate() && header->sample_rate() >= 0 && header->has_num_channels() && header->num_channels() >= 0) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { std::string error_message = "TimeSeriesHeader is missing necessary fields: " @@ -48,7 +47,7 @@ namespace { absl::StrAppend(&error_message, "Got header:\n", header->ShortDebugString()); #endif - return ::mediapipe::InvalidArgumentError(error_message); + return absl::InvalidArgumentError(error_message); } } } // namespace @@ -78,18 +77,17 @@ typedef Eigen::Matrix // } class MatrixToTensorCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: MatrixToTensorCalculatorOptions options_; }; REGISTER_CALCULATOR(MatrixToTensorCalculator); -::mediapipe::Status MatrixToTensorCalculator::GetContract( - CalculatorContract* cc) { +absl::Status MatrixToTensorCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is supported."; cc->Inputs().Index(0).Set( @@ -102,15 +100,15 @@ REGISTER_CALCULATOR(MatrixToTensorCalculator); // TimeSeriesHeader as the input (or no header if the input has no // header). ); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status MatrixToTensorCalculator::Open(CalculatorContext* cc) { +absl::Status MatrixToTensorCalculator::Open(CalculatorContext* cc) { // If the input is part of a time series, then preserve the header so that // downstream consumers can access the sample rate if needed. options_ = cc->Options(); auto input_header = ::absl::make_unique(); - const ::mediapipe::Status header_status = FillTimeSeriesHeaderIfValid( + const absl::Status header_status = FillTimeSeriesHeaderIfValid( cc->Inputs().Index(0).Header(), input_header.get()); if (header_status.ok()) { cc->Outputs().Index(0).SetHeader(Adopt(input_header.release())); @@ -119,10 +117,10 @@ REGISTER_CALCULATOR(MatrixToTensorCalculator); // Inform the framework that we always output at the same timestamp // as we receive a packet at. 
cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status MatrixToTensorCalculator::Process(CalculatorContext* cc) { +absl::Status MatrixToTensorCalculator::Process(CalculatorContext* cc) { const Matrix& matrix = cc->Inputs().Index(0).Get(); tf::TensorShape tensor_shape; if (options_.transpose()) { @@ -151,7 +149,7 @@ REGISTER_CALCULATOR(MatrixToTensorCalculator); << " Current shape: " << tensor->shape().DebugString(); } cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc index fa4fd1035..a8abe10d9 100644 --- a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc @@ -93,7 +93,7 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { public: ObjectDetectionTensorsToDetectionsCalculator() = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kBoxes).Set(); cc->Inputs().Tag(kScores).Set(); @@ -114,7 +114,7 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { cc->Options(); float mask_threshold = calculator_options.mask_threshold(); if (!(mask_threshold >= 0.0 && mask_threshold <= 1.0)) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "mask_threshold must be in range [0.0, 1.0]"; } } @@ -126,10 +126,10 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { .Tag(kLabelMap) .Set>>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { if (cc->InputSidePackets().HasTag(kLabelMap)) { label_map_ = GetFromUniquePtr>( cc->InputSidePackets().Tag(kLabelMap)); @@ -141,10 +141,10 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { tensor_dim_to_squeeze_field.begin(), tensor_dim_to_squeeze_field.end()); std::sort(tensor_dims_to_squeeze_.rbegin(), tensor_dims_to_squeeze_.rend()); cc->SetOffset(0); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const auto& options = cc->Options(); @@ -205,15 +205,15 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { .Tag(kDetections) .Add(output_detections.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: std::map* label_map_; std::vector tensor_dims_to_squeeze_; - ::mediapipe::StatusOr MaybeSqueezeDims( - const std::string& tensor_tag, const tf::Tensor& input_tensor) { + absl::StatusOr MaybeSqueezeDims(const std::string& tensor_tag, + const tf::Tensor& input_tensor) { if (tensor_dims_to_squeeze_.empty()) { return input_tensor; } diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index cf5635d3a..fdf43dcfd 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc 
+++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -15,6 +15,7 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" @@ -42,7 +43,7 @@ const char kKeypointsTag[] = "KEYPOINTS"; const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION"; namespace tf = ::tensorflow; -namespace mpms = ::mediapipe::mediasequence; +namespace mpms = mediapipe::mediasequence; // Sink calculator to package streams into tf.SequenceExamples. // @@ -57,7 +58,7 @@ namespace mpms = ::mediapipe::mediasequence; // bounding boxes from vector, and streams with the // "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector's // associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints -// from unordered_map>>. "IMAGE_${NAME}", +// from flat_hash_map>>. "IMAGE_${NAME}", // "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of // each stream, which allows for multiple image streams to be included. However, // the default names are suppored by more tools. @@ -93,7 +94,7 @@ uint8 ConvertFloatToByte(const float float_value) { class PackMediaSequenceCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag)); cc->InputSidePackets().Tag(kSequenceExampleTag).Set(); @@ -131,8 +132,8 @@ class PackMediaSequenceCalculator : public CalculatorBase { } cc->Inputs() .Tag(tag) - .Set>>>(); + .Set>>>(); } if (absl::StartsWith(tag, kBBoxTag)) { std::string key = ""; @@ -166,10 +167,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { .Tag(kSequenceExampleTag) .Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { sequence_ = ::absl::make_unique( cc->InputSidePackets() .Tag(kSequenceExampleTag) @@ -184,6 +185,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { features_present_[tag] = false; } + replace_keypoints_ = false; if (cc->Options() .replace_data_instead_of_append()) { for (const auto& tag : cc->Inputs().GetTags()) { @@ -212,6 +214,15 @@ class PackMediaSequenceCalculator : public CalculatorBase { } mpms::ClearBBox(key, sequence_.get()); mpms::ClearBBoxTimestamp(key, sequence_.get()); + mpms::ClearBBoxIsAnnotated(key, sequence_.get()); + mpms::ClearBBoxNumRegions(key, sequence_.get()); + mpms::ClearBBoxLabelString(key, sequence_.get()); + mpms::ClearBBoxLabelIndex(key, sequence_.get()); + mpms::ClearBBoxClassString(key, sequence_.get()); + mpms::ClearBBoxClassIndex(key, sequence_.get()); + mpms::ClearBBoxTrackString(key, sequence_.get()); + mpms::ClearBBoxTrackIndex(key, sequence_.get()); + mpms::ClearUnmodifiedBBoxTimestamp(key, sequence_.get()); } if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / @@ -223,8 +234,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (absl::StartsWith(tag, kKeypointsTag)) { std::string key = tag.substr(sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1); - mpms::ClearBBoxPoint(key, sequence_.get()); - mpms::ClearBBoxTimestamp(key, sequence_.get()); + replace_keypoints_ = true; } } if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) { 
@@ -238,10 +248,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { .Tag(kSequenceExampleTag) .SetNextTimestampBound(Timestamp::Max()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status VerifySequence() { + absl::Status VerifySequence() { std::string error_msg = "Missing features - "; bool all_present = true; for (const auto& iter : features_present_) { @@ -251,13 +261,13 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } if (all_present) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { return ::mediapipe::NotFoundErrorBuilder(MEDIAPIPE_LOC) << error_msg; } } - ::mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { auto& options = cc->Options(); if (options.reconcile_metadata()) { RET_CHECK_OK(mpms::ReconcileMetadata( @@ -266,7 +276,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { } if (options.output_only_if_all_present()) { - ::mediapipe::Status status = VerifySequence(); + absl::Status status = VerifySequence(); if (!status.ok()) { cc->GetCounter(status.ToString())->Increment(); return status; @@ -285,10 +295,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { } sequence_.reset(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { int image_height = -1; int image_width = -1; // Because the tag order may vary, we need to loop through tags to get @@ -339,14 +349,28 @@ class PackMediaSequenceCalculator : public CalculatorBase { const auto& keypoints = cc->Inputs() .Tag(tag) - .Get>>>(); for (const auto& pair : keypoints) { - mpms::AddBBoxTimestamp(mpms::merge_prefix(key, pair.first), - cc->InputTimestamp().Value(), sequence_.get()); - mpms::AddBBoxPoint(mpms::merge_prefix(key, pair.first), pair.second, - sequence_.get()); + std::string prefix = mpms::merge_prefix(key, pair.first); + if (replace_keypoints_) { + mpms::ClearBBoxPoint(prefix, sequence_.get()); + mpms::ClearBBoxTimestamp(prefix, sequence_.get()); + mpms::ClearBBoxIsAnnotated(prefix, sequence_.get()); + mpms::ClearBBoxNumRegions(prefix, sequence_.get()); + mpms::ClearBBoxLabelString(prefix, sequence_.get()); + mpms::ClearBBoxLabelIndex(prefix, sequence_.get()); + mpms::ClearBBoxClassString(prefix, sequence_.get()); + mpms::ClearBBoxClassIndex(prefix, sequence_.get()); + mpms::ClearBBoxTrackString(prefix, sequence_.get()); + mpms::ClearBBoxTrackIndex(prefix, sequence_.get()); + mpms::ClearUnmodifiedBBoxTimestamp(prefix, sequence_.get()); + } + mpms::AddBBoxTimestamp(prefix, cc->InputTimestamp().Value(), + sequence_.get()); + mpms::AddBBoxPoint(prefix, pair.second, sequence_.get()); } + replace_keypoints_ = false; } if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag) && !cc->Inputs().Tag(tag).IsEmpty()) { @@ -465,16 +489,17 @@ class PackMediaSequenceCalculator : public CalculatorBase { sequence_.get()); already_has_mask = true; } else { - return ::mediapipe::UnimplementedError( + return absl::UnimplementedError( "Global detections and empty detections are not supported."); } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } std::unique_ptr sequence_; std::map features_present_; + bool replace_keypoints_; }; REGISTER_CALCULATOR(PackMediaSequenceCalculator); diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc 
b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index 15d92fe92..09c5a0f24 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -14,6 +14,7 @@ #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" @@ -36,7 +37,7 @@ namespace mediapipe { namespace { namespace tf = ::tensorflow; -namespace mpms = ::mediapipe::mediasequence; +namespace mpms = mediapipe::mediasequence; class PackMediaSequenceCalculatorTest : public ::testing::Test { protected: @@ -70,9 +71,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); encoded_image.set_height(1); @@ -100,7 +100,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { ASSERT_EQ(num_images, mpms::GetImageEncodedSize(output_sequence)); for (int i = 0; i < num_images; ++i) { ASSERT_EQ(i, mpms::GetImageTimestampAt(output_sequence, i)); - ASSERT_EQ(test_image_string, mpms::GetImageEncodedAt(output_sequence, i)); + ASSERT_EQ(encoded_image.encoded_image(), + mpms::GetImageEncodedAt(output_sequence, i)); } } @@ -113,9 +114,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); encoded_image.set_height(1); @@ -144,7 +144,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { ASSERT_EQ(num_images, mpms::GetImageEncodedSize(prefix, output_sequence)); for (int i = 0; i < num_images; ++i) { ASSERT_EQ(i, mpms::GetImageTimestampAt(prefix, output_sequence, i)); - ASSERT_EQ(test_image_string, + ASSERT_EQ(encoded_image.encoded_image(), mpms::GetImageEncodedAt(prefix, output_sequence, i)); } } @@ -238,9 +238,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); auto image_ptr = ::absl::make_unique(encoded_image); runner_->MutableInputs()->Tag("IMAGE").packets.push_back( @@ -433,7 +432,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) { Adopt(input_sequence.release()); auto status = runner_->Run(); - EXPECT_EQ(::mediapipe::StatusCode::kInvalidArgument, status.code()); + EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()); } TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { @@ -479,9 +478,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { cv::Mat image(height, width, CV_8UC3, 
cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(width); encoded_image.set_height(height); @@ -537,8 +535,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) { std::string test_video_id = "test_video_id"; mpms::SetClipMediaId(test_video_id, input_sequence.get()); - std::unordered_map>> points = - {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}}; + absl::flat_hash_map>> + points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}}; runner_->MutableInputs() ->Tag("KEYPOINTS_TEST") .packets.push_back(PointToForeign(&points).At(Timestamp(0))); @@ -693,7 +691,7 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) { runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = Adopt(input_sequence.release()); - ::mediapipe::Status status = runner_->Run(); + absl::Status status = runner_->Run(); EXPECT_FALSE(status.ok()); } @@ -793,9 +791,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); encoded_image.set_height(1); @@ -839,5 +836,58 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 4), 50); } +TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { + SetUpCalculator({"IMAGE:images", "BBOX:bbox"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + OpenCvImageEncoderCalculatorResults encoded_image; + encoded_image.set_encoded_image(bytes.data(), bytes.size()); + int height = 2; + int width = 2; + encoded_image.set_width(width); + encoded_image.set_height(height); + + int num_images = 5; // Timestamps: 10, 20, 30, 40, 50 + for (int i = 0; i < num_images; ++i) { + auto image_ptr = + ::absl::make_unique(encoded_image); + runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + Adopt(image_ptr.release()).At(Timestamp(i))); + } + + for (int i = 0; i < num_images; ++i) { + auto detections = ::absl::make_unique<::std::vector>(); + Detection detection; + detection = Detection(); + detection.add_label("relative bbox"); + detection.add_label_id(1); + detection.add_score(0.75); + Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5) + .ConvertToProto(detection.mutable_location_data()); + detections->push_back(detection); + runner_->MutableInputs()->Tag("BBOX").packets.push_back( + Adopt(detections.release()).At(Timestamp(i))); + } + + for (int i = 0; i < 10; ++i) { + mpms::AddBBoxTimestamp(-1, input_sequence.get()); + mpms::AddBBoxIsAnnotated(-1, input_sequence.get()); + mpms::AddBBoxNumRegions(-1, input_sequence.get()); + mpms::AddBBoxLabelString({"anything"}, input_sequence.get()); + mpms::AddBBoxLabelIndex({-1}, input_sequence.get()); + mpms::AddBBoxClassString({"anything"}, input_sequence.get()); + 
mpms::AddBBoxClassIndex({-1}, input_sequence.get()); + mpms::AddBBoxTrackString({"anything"}, input_sequence.get()); + mpms::AddBBoxTrackIndex({-1}, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + // If the all the previous values aren't cleared, this assert will fail. + MP_ASSERT_OK(runner_->Run()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc b/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc index 6693a0642..da85bed94 100644 --- a/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc +++ b/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc @@ -44,15 +44,15 @@ constexpr char kSequenceExample[] = "SEQUENCE_EXAMPLE"; class StringToSequenceExampleCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(StringToSequenceExampleCalculator); -::mediapipe::Status StringToSequenceExampleCalculator::GetContract( +absl::Status StringToSequenceExampleCalculator::GetContract( CalculatorContract* cc) { if (cc->InputSidePackets().HasTag(kString)) { cc->InputSidePackets().Tag(kString).Set(); @@ -62,38 +62,35 @@ REGISTER_CALCULATOR(StringToSequenceExampleCalculator); cc->InputSidePackets().Tag(kSequenceExample).Set(); cc->OutputSidePackets().Tag(kString).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status StringToSequenceExampleCalculator::Open( - CalculatorContext* cc) { +absl::Status StringToSequenceExampleCalculator::Open(CalculatorContext* cc) { if (cc->InputSidePackets().HasTag(kString)) { auto string_value = cc->InputSidePackets().Tag(kString).Get(); auto example = absl::make_unique(); example->ParseFromString(string_value); cc->OutputSidePackets() .Tag(kSequenceExample) - .Set(::mediapipe::Adopt(example.release())); + .Set(mediapipe::Adopt(example.release())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status StringToSequenceExampleCalculator::Process( - CalculatorContext* cc) { - return ::mediapipe::OkStatus(); +absl::Status StringToSequenceExampleCalculator::Process(CalculatorContext* cc) { + return absl::OkStatus(); } -::mediapipe::Status StringToSequenceExampleCalculator::Close( - CalculatorContext* cc) { +absl::Status StringToSequenceExampleCalculator::Close(CalculatorContext* cc) { if (cc->InputSidePackets().HasTag(kSequenceExample)) { const auto& example = cc->InputSidePackets().Tag(kSequenceExample).Get(); auto string_value = absl::make_unique(); example.SerializeToString(string_value.get()); cc->OutputSidePackets().Tag(kString).Set( - ::mediapipe::Adopt(string_value.release())); + mediapipe::Adopt(string_value.release())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc index b1e4f05f0..cbf494245 100644 --- 
a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc @@ -27,7 +27,7 @@ namespace tf = ::tensorflow; // containing identical data (example output dimensions [1024, 5]). class TensorSqueezeDimensionsCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Need one input"; cc->Inputs().Index(0).Set( // Input Tensor @@ -36,10 +36,10 @@ class TensorSqueezeDimensionsCalculator : public CalculatorBase { cc->Outputs().Index(0).Set( // Output Tensor Reduced Dimensions ); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { options_ = cc->Options(); RET_CHECK(options_.squeeze_all_single_dims() ^ (options_.dim_size() > 0)) << "Must specify dimensions to remove, or set squeeze_all_single_dims, " @@ -52,10 +52,10 @@ class TensorSqueezeDimensionsCalculator : public CalculatorBase { remove_dims_initialized_ = true; } cc->SetOffset(0); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const tf::Tensor& input_tensor = cc->Inputs().Index(0).Get(); tf::TensorShape tensor_shape = input_tensor.shape(); if (!remove_dims_initialized_) { @@ -78,11 +78,11 @@ class TensorSqueezeDimensionsCalculator : public CalculatorBase { std::unique_ptr output_tensor(new tf::Tensor); RET_CHECK(output_tensor->CopyFrom(input_tensor, tensor_shape)); cc->Outputs().Index(0).Add(output_tensor.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Close(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Close(CalculatorContext* cc) override { + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc index f6e4354d3..d72c75923 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc @@ -45,10 +45,10 @@ constexpr char kTensor[] = "TENSOR"; // Possible extensions: support other input ranges, maybe 4D tensors. class TensorToImageFrameCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: float scale_factor_; @@ -56,8 +56,7 @@ class TensorToImageFrameCalculator : public CalculatorBase { REGISTER_CALCULATOR(TensorToImageFrameCalculator); -::mediapipe::Status TensorToImageFrameCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is supported."; RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) @@ -70,18 +69,17 @@ REGISTER_CALCULATOR(TensorToImageFrameCalculator); cc->Outputs().Tag(kImage).Set( // Output ImageFrame. 
); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) { +absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) { scale_factor_ = cc->Options().scale_factor(); cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TensorToImageFrameCalculator::Process( - CalculatorContext* cc) { +absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) { const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get(); int32 depth = 1; if (input_tensor.dims() != 2) { // Depth is 1 for 2D tensors. @@ -114,11 +112,11 @@ REGISTER_CALCULATOR(TensorToImageFrameCalculator); ImageFormat::GRAY8, input_tensor.dim_size(1), input_tensor.dim_size(0), input_tensor.dim_size(1), buffer.release()); } else { - return ::mediapipe::InvalidArgumentError("Unrecognized image depth."); + return absl::InvalidArgumentError("Unrecognized image depth."); } cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc index 9c7f3458c..d52de7404 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc @@ -34,20 +34,19 @@ constexpr char kMatrix[] = "MATRIX"; constexpr char kTensor[] = "TENSOR"; constexpr char kReference[] = "REFERENCE"; -::mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, - TimeSeriesHeader* header) { +absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header) { CHECK(header); if (header_packet.IsEmpty()) { - return ::mediapipe::UnknownError("No header found."); + return absl::UnknownError("No header found."); } if (!header_packet.ValidateAsType().ok()) { - return ::mediapipe::UnknownError( - "Packet does not contain TimeSeriesHeader."); + return absl::UnknownError("Packet does not contain TimeSeriesHeader."); } *header = header_packet.Get(); if (header->has_sample_rate() && header->sample_rate() >= 0 && header->has_num_channels() && header->num_channels() >= 0) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { std::string error_message = "TimeSeriesHeader is missing necessary fields: " @@ -56,7 +55,7 @@ constexpr char kReference[] = "REFERENCE"; absl::StrAppend(&error_message, "Got header:\n", header->ShortDebugString()); #endif - return ::mediapipe::InvalidArgumentError(error_message); + return absl::InvalidArgumentError(error_message); } } @@ -110,18 +109,17 @@ constexpr char kReference[] = "REFERENCE"; // } class TensorToMatrixCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Store header information so that we can verify the inputs in process(). 
TimeSeriesHeader header_; }; REGISTER_CALCULATOR(TensorToMatrixCalculator); -::mediapipe::Status TensorToMatrixCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TensorToMatrixCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_LE(cc->Inputs().NumEntries(), 2) << "Only one or two input streams are supported."; RET_CHECK_GT(cc->Inputs().NumEntries(), 0) @@ -147,12 +145,12 @@ REGISTER_CALCULATOR(TensorToMatrixCalculator); cc->Outputs().Tag(kMatrix).Set( // Output Matrix. ); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) { +absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) { auto input_header = absl::make_unique(); - ::mediapipe::Status header_status; + absl::Status header_status; if (cc->Inputs().HasTag(kReference)) { header_status = FillTimeSeriesHeaderIfValid( cc->Inputs().Tag(kReference).Header(), input_header.get()); @@ -184,10 +182,10 @@ REGISTER_CALCULATOR(TensorToMatrixCalculator); cc->Outputs().Tag(kMatrix).SetHeader(Adopt(input_header.release())); } cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) { +absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) { // Daredevil requested CHECK for noisy failures rather than quieter RET_CHECK // failures. These are absolute conditions of the graph for the graph to be // valid, and if it is violated by any input anywhere, the graph will be @@ -221,7 +219,7 @@ REGISTER_CALCULATOR(TensorToMatrixCalculator); *output = Eigen::MatrixXf::Map(input_tensor.flat().data(), length, width); cc->Outputs().Tag(kMatrix).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc index 7b447f4d5..cd807b87b 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc @@ -28,17 +28,17 @@ namespace tf = ::tensorflow; class TensorToVectorFloatCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: TensorToVectorFloatCalculatorOptions options_; }; REGISTER_CALCULATOR(TensorToVectorFloatCalculator); -::mediapipe::Status TensorToVectorFloatCalculator::GetContract( +absl::Status TensorToVectorFloatCalculator::GetContract( CalculatorContract* cc) { // Start with only one input packet. RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) @@ -58,16 +58,22 @@ REGISTER_CALCULATOR(TensorToVectorFloatCalculator); // Output vector. ); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TensorToVectorFloatCalculator::Open(CalculatorContext* cc) { +absl::Status TensorToVectorFloatCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); - return ::mediapipe::OkStatus(); + + // Inform mediapipe that this calculator produces an output at time t for + // each input received at time t (i.e. 
this calculator does not buffer + // inputs). This enables mediapipe to propagate time of arrival estimates in + // mediapipe graphs through this calculator. + cc->SetOffset(/*offset=*/0); + + return absl::OkStatus(); } -::mediapipe::Status TensorToVectorFloatCalculator::Process( - CalculatorContext* cc) { +absl::Status TensorToVectorFloatCalculator::Process(CalculatorContext* cc) { const tf::Tensor& input_tensor = cc->Inputs().Index(0).Value().Get(); RET_CHECK(tf::DT_FLOAT == input_tensor.dtype()) @@ -103,7 +109,7 @@ REGISTER_CALCULATOR(TensorToVectorFloatCalculator); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc index 5b9a74a6d..d78a53053 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -98,547 +98,388 @@ class InferenceState { // This calculator performs inference on a trained TensorFlow model. // -// Additional documentation and examples at -// go/mediapipe/tensorflow_in_mediapipe. -// -// TensorFlow Sessions can be created from checkpoint paths, frozen models, or -// the SavedModel system (go/saved-model). See the TensorFlowSessionFrom* -// packet generators for details. Each of these methods defines a mapping -// between MediaPipe streams and TensorFlow tensors. All of this information is -// passed in as an input_side_packet. -// -// The input and output streams are TensorFlow tensors labeled by tags. The tags -// for the streams are matched to feeds and fetchs in a TensorFlow session using -// a named_signature.generic_signature in the ModelManifest. The -// generic_signature is used as key-value pairs between the MediaPipe tag and -// the TensorFlow tensor. The signature_name in the options proto determines -// which named_signature is used. The keys in the generic_signature must be -// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of -// the tensors corresponding to tags in the signature for input_streams are fed -// to the model and for output_streams the tensors are fetched from the model. -// -// Other calculators are used to convert data to and from tensors, this op only -// handles the TensorFlow session and batching. Batching occurs by concatenating -// input tensors along the 0th dimension across timestamps. If the 0th dimension -// is not a batch dimension, this calculator will add a 0th dimension by -// default. Setting add_batch_dim_to_tensors to false disables the dimension -// addition. Once batch_size inputs have been provided, the batch will be run -// and the output tensors sent out on the output streams with timestamps -// corresponding to the input stream packets. Setting the batch_size to 1 -// completely disables batching, but is indepdent of add_batch_dim_to_tensors. -// -// The TensorFlowInferenceCalculator also support feeding states recurrently for -// RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the -// recurrent tensors. Initializing the recurrent state can be handled by the -// GraphTensorsPacketGenerator. 
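The comment added to TensorToVectorFloatCalculator::Open() above explains the cc->SetOffset(0) contract: an input at time t always yields an output at the same time t, so the framework can propagate timestamp bounds downstream without waiting on this node. As a minimal, hypothetical illustration (the calculator name and float payload are invented for this sketch and are not part of the change), a pass-through calculator that makes the same promise:

#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

class PassThroughFloatCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<float>();
    cc->Outputs().Index(0).Set<float>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    // Offset 0: the output for the packet at time t is emitted at time t, so
    // downstream timestamp bounds can advance without waiting on this node.
    cc->SetOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(PassThroughFloatCalculator);

}  // namespace mediapipe

Both spellings used in these files, cc->SetOffset(0) and cc->SetOffset(TimestampDiff(0)), declare the same zero offset.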
-// -// The calculator updates two Counters to report timing information: -// ---TotalTimeUsecs = Total time spent running inference (in usecs), -// ---TotalProcessedTimestamps = # of instances processed -// (approximately batches processed * batch_size), -// where is replaced with CalculatorGraphConfig::Node::name() if it -// exists, or with TensorFlowInferenceCalculator if the name is not set. The -// name must be set for timing information to be instance-specific in graphs -// with multiple TensorFlowInferenceCalculators. -// -// Example config: -// packet_generator { -// packet_generator: "TensorFlowSessionFromSavedModelGenerator" -// output_side_packet: "tensorflow_session" -// options { -// [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: { -// saved_model_path: "/path/to/saved/model" -// signature_name: "mediapipe" -// } -// } -// } -// node { -// calculator: "TensorFlowInferenceCalculator" -// input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag" -// input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag" -// output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag" -// input_side_packet: "SESSION:tensorflow_session" -// } -// -// Where the input and output streams are treated as Packet and -// the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and -// "LABELS" and their respective tensors exported to /path/to/bundle. For an -// example of how this model was exported, see -// tensorflow_inference_test_graph_generator.py -// -// It is possible to use a GraphDef proto that was not exported by exporter (i.e -// without MetaGraph with bindings). Such GraphDef could contain all of its -// parameters in-lined (for example, it can be the output of freeze_graph.py). -// To instantiate a TensorFlow model from a GraphDef file, replace the -// packet_factory above with TensorFlowSessionFromFrozenGraphGenerator: -// -// packet_generator { -// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator" -// output_side_packet: "SESSION:tensorflow_session" -// options { -// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: { -// graph_proto_path: "[PATH]" -// tag_to_tensor_names { -// key: "JPG_STRING" -// value: "input:0" -// } -// tag_to_tensor_names { -// key: "SOFTMAX" -// value: "softmax:0" -// } -// } -// } -// } -// -// It is also possible to use a GraphDef proto and checkpoint file that have not -// been frozen. This can be used to load graphs directly as they have been -// written from training. However, it is more brittle and you are encouraged to -// use a one of the more perminent formats described above. To instantiate a -// TensorFlow model from a GraphDef file and checkpoint, replace the -// packet_factory above with TensorFlowSessionFromModelCheckpointGenerator: -// -// packet_generator { -// packet_generator: "TensorFlowSessionFromModelCheckpointGenerator" -// output_side_packet: "SESSION:tensorflow_session" -// options { -// [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: { -// graph_proto_path: "[PATH]" -// model_options { -// checkpoint_path: "[PATH2]" -// } -// tag_to_tensor_names { -// key: "JPG_STRING" -// value: "input:0" -// } -// tag_to_tensor_names { -// key: "SOFTMAX" -// value: "softmax:0" -// } -// } -// } -// } -class TensorFlowInferenceCalculator : public CalculatorBase { - public: - // Counters for recording timing information. The actual names have the value - // of CalculatorGraphConfig::Node::name() prepended. 
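The calculator documentation above describes feeding recurrent state through the recurrent_tag_pair option without giving a concrete configuration. A hedged sketch of such a node, built with ParseTextProtoOrDie as the tests in this change do; the stream and tag names (FEATURES, LOGITS, STATE_IN, STATE_OUT) are invented for illustration and must match keys in the model's signature:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical tag names; batch_size must be 1 whenever recurrent_tag_pair is
// set, per the RET_CHECK in Open().
const mediapipe::CalculatorGraphConfig::Node kRecurrentNode =
    mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
      calculator: "TensorFlowInferenceCalculator"
      input_side_packet: "SESSION:tensorflow_session"
      input_stream: "FEATURES:feature_tensors"
      output_stream: "LOGITS:logit_tensors"
      options {
        [mediapipe.TensorFlowInferenceCalculatorOptions.ext]: {
          batch_size: 1
          # Feed tag first, fetch tag second: the tensor fetched for STATE_OUT
          # is fed back as STATE_IN on the next Process() call.
          recurrent_tag_pair: "STATE_IN:STATE_OUT"
        }
      }
    )pb");

The recurrent state can be seeded through the optional RECURRENT_INIT_TENSORS input side packet handled in CreateInferenceState().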
- static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs"; - static constexpr char kTotalProcessedTimestampsCounterSuffix[] = - "TotalProcessedTimestamps"; - static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] = - "TotalSessionRunsTimeUsecs"; - static constexpr char kTotalNumSessionRunsCounterSuffix[] = - "TotalNumSessionRuns"; +// A mediapipe::TensorFlowSession with a model loaded and ready for use. +// For this calculator it must include a tag_to_tensor_map. +cc->InputSidePackets().Tag("SESSION").Set(); +if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { + cc->InputSidePackets() + .Tag("RECURRENT_INIT_TENSORS") + .Set>>(); +} +return absl::OkStatus(); +} - TensorFlowInferenceCalculator() : session_(nullptr) { - clock_ = std::unique_ptr( - mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); +std::unique_ptr CreateInferenceState(CalculatorContext* cc) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + std::unique_ptr inference_state = + absl::make_unique(); + if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && + !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { + std::map* init_tensor_map; + init_tensor_map = GetFromUniquePtr>( + cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); + for (const auto& p : *init_tensor_map) { + inference_state->input_tensor_batches_[p.first].emplace_back(p.second); + } + } + return inference_state; +} + +absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + + RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); + session_ = cc->InputSidePackets() + .Tag("SESSION") + .Get() + .session.get(); + tag_to_tensor_map_ = cc->InputSidePackets() + .Tag("SESSION") + .Get() + .tag_to_tensor_map; + + // Validate and store the recurrent tags + RET_CHECK(options_.has_batch_size()); + RET_CHECK(options_.batch_size() == 1 || options_.recurrent_tag_pair().empty()) + << "To use recurrent_tag_pairs, batch_size must be 1."; + for (const auto& tag_pair : options_.recurrent_tag_pair()) { + const std::vector tags = absl::StrSplit(tag_pair, ':'); + RET_CHECK_EQ(tags.size(), 2) + << "recurrent_tag_pair must be a colon " + "separated std::string with two components: " + << tag_pair; + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) + << "Can't find tag '" << tags[0] << "' in signature " + << options_.signature_name(); + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1])) + << "Can't find tag '" << tags[1] << "' in signature " + << options_.signature_name(); + recurrent_feed_tags_.insert(tags[0]); + recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0]; } - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - const auto& options = cc->Options(); - RET_CHECK(!cc->Inputs().GetTags().empty()); - for (const std::string& tag : cc->Inputs().GetTags()) { - // The tensorflow::Tensor with the tag equal to the graph node. May - // have a TimeSeriesHeader if all present TimeSeriesHeaders match. - if (!options.batched_input()) { - cc->Inputs().Tag(tag).Set(); + // Check that all tags are present in this signature bound to tensors. 
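Open() above validates each recurrent_tag_pair by splitting it on ':' and checking both halves against the signature's tag-to-tensor map. A standalone sketch of just the parsing step (the function name is chosen here for illustration):

#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"

// Splits "FEED_TAG:FETCH_TAG" into its two components, mirroring the
// absl::StrSplit + RET_CHECK_EQ logic in Open(). The feed tag names the tensor
// that receives the fed-back state; the fetch tag names the tensor whose value
// is carried over to the next invocation.
absl::Status ParseRecurrentTagPair(const std::string& tag_pair,
                                   std::string* feed_tag,
                                   std::string* fetch_tag) {
  const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
  if (tags.size() != 2) {
    return absl::InvalidArgumentError(absl::StrCat(
        "recurrent_tag_pair must be a colon separated string with two "
        "components: ",
        tag_pair));
  }
  *feed_tag = tags[0];
  *fetch_tag = tags[1];
  return absl::OkStatus();
}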
+ for (const std::string& tag : cc->Inputs().GetTags()) { + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) + << "Can't find tag '" << tag << "' in signature " + << options_.signature_name(); + } + for (const std::string& tag : cc->Outputs().GetTags()) { + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) + << "Can't find tag '" << tag << "' in signature " + << options_.signature_name(); + } + + { + absl::WriterMutexLock l(&mutex_); + inference_state_ = std::unique_ptr(); + } + + if (options_.batch_size() == 1 || options_.batched_input()) { + cc->SetOffset(0); + } + + return absl::OkStatus(); +} + +// Adds a batch dimension to the input tensor if specified in the calculator +// options. +absl::Status AddBatchDimension(tf::Tensor* input_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(input_tensor->shape()); + new_shape.InsertDim(0, 1); + RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape)) + << "Could not add 0th dimension to tensor without changing its shape." + << " Current shape: " << input_tensor->shape().DebugString(); + } + return absl::OkStatus(); +} + +absl::Status AggregateTensorPacket( + const std::string& tag_name, const Packet& packet, + std::map>* + input_tensors_by_tag_by_timestamp, + InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + tf::Tensor input_tensor(packet.Get()); + RET_CHECK_OK(AddBatchDimension(&input_tensor)); + if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) { + // If we receive an input on a recurrent tag, override the state. + // It's OK to override the global state because there is just one + // input stream allowed for recurrent tensors. + inference_state_->input_tensor_batches_[tag_name].clear(); + } + (*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert( + std::make_pair(tag_name, input_tensor)); + return absl::OkStatus(); +} + +// Removes the batch dimension of the output tensor if specified in the +// calculator options. +absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(output_tensor->shape()); + new_shape.RemoveDim(0); + RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape)) + << "Could not remove 0th dimension from tensor without changing its " + << "shape. Current shape: " << output_tensor->shape().DebugString() + << " (The expected first dimension is 1 for a batch element.)"; + } + return absl::OkStatus(); +} + +absl::Status Process(CalculatorContext* cc) override { + std::unique_ptr inference_state_to_process; + { + absl::WriterMutexLock l(&mutex_); + if (inference_state_ == nullptr) { + inference_state_ = CreateInferenceState(cc); + } + std::map> + input_tensors_by_tag_by_timestamp; + for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) { + if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) { + // Recurrent tensors can be empty. + if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) { + if (options_.skip_on_missing_features()) { + return absl::OkStatus(); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Tag ", tag_as_node_name, + " not present at timestamp: ", cc->InputTimestamp().Value())); + } + } + } else if (options_.batched_input()) { + const auto& tensor_packets = + cc->Inputs().Tag(tag_as_node_name).Get>(); + if (tensor_packets.size() > options_.batch_size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Batch for tag ", tag_as_node_name, + " has more packets than batch capacity. 
batch_size: ", + options_.batch_size(), " packets: ", tensor_packets.size())); + } + for (const auto& packet : tensor_packets) { + RET_CHECK_OK(AggregateTensorPacket(tag_as_node_name, packet, + &input_tensors_by_tag_by_timestamp, + inference_state_.get())); + } } else { - cc->Inputs().Tag(tag).Set>(); + RET_CHECK_OK(AggregateTensorPacket( + tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(), + &input_tensors_by_tag_by_timestamp, inference_state_.get())); } } - RET_CHECK(!cc->Outputs().GetTags().empty()); - for (const std::string& tag : cc->Outputs().GetTags()) { - // The tensorflow::Tensor with tag equal to the graph node to - // output. Any TimeSeriesHeader from the inputs will be forwarded - // with channels set to 0. - cc->Outputs().Tag(tag).Set(); - } - // A mediapipe::TensorFlowSession with a model loaded and ready for use. - // For this calculator it must include a tag_to_tensor_map. - cc->InputSidePackets().Tag("SESSION").Set(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { - cc->InputSidePackets() - .Tag("RECURRENT_INIT_TENSORS") - .Set>>(); - } - return ::mediapipe::OkStatus(); - } - - std::unique_ptr CreateInferenceState(CalculatorContext* cc) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { - std::unique_ptr inference_state = - absl::make_unique(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && - !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { - std::map* init_tensor_map; - init_tensor_map = GetFromUniquePtr>( - cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); - for (const auto& p : *init_tensor_map) { - inference_state->input_tensor_batches_[p.first].emplace_back(p.second); + for (const auto& timestamp_and_input_tensors_by_tag : + input_tensors_by_tag_by_timestamp) { + inference_state_->batch_timestamps_.emplace_back( + timestamp_and_input_tensors_by_tag.first); + for (const auto& input_tensor_and_tag : + timestamp_and_input_tensors_by_tag.second) { + inference_state_->input_tensor_batches_[input_tensor_and_tag.first] + .emplace_back(input_tensor_and_tag.second); } } - return inference_state; - } - - ::mediapipe::Status Open(CalculatorContext* cc) override { - options_ = cc->Options(); - - RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); - session_ = cc->InputSidePackets() - .Tag("SESSION") - .Get() - .session.get(); - tag_to_tensor_map_ = cc->InputSidePackets() - .Tag("SESSION") - .Get() - .tag_to_tensor_map; - - // Validate and store the recurrent tags - RET_CHECK(options_.has_batch_size()); - RET_CHECK(options_.batch_size() == 1 || - options_.recurrent_tag_pair().empty()) - << "To use recurrent_tag_pairs, batch_size must be 1."; - for (const auto& tag_pair : options_.recurrent_tag_pair()) { - const std::vector tags = absl::StrSplit(tag_pair, ':'); - RET_CHECK_EQ(tags.size(), 2) - << "recurrent_tag_pair must be a colon " - "separated std::string with two components: " - << tag_pair; - RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) - << "Can't find tag '" << tags[0] << "' in signature " - << options_.signature_name(); - RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tags[1])) - << "Can't find tag '" << tags[1] << "' in signature " - << options_.signature_name(); - recurrent_feed_tags_.insert(tags[0]); - recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0]; - } - - // Check that all tags are present in this signature bound to tensors. 
- for (const std::string& tag : cc->Inputs().GetTags()) { - RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tag)) - << "Can't find tag '" << tag << "' in signature " - << options_.signature_name(); - } - for (const std::string& tag : cc->Outputs().GetTags()) { - RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tag)) - << "Can't find tag '" << tag << "' in signature " - << options_.signature_name(); - } - - { - absl::WriterMutexLock l(&mutex_); + if (inference_state_->batch_timestamps_.size() == options_.batch_size() || + options_.batched_input()) { + inference_state_to_process = std::move(inference_state_); inference_state_ = std::unique_ptr(); } - - if (options_.batch_size() == 1 || options_.batched_input()) { - cc->SetOffset(0); - } - - return ::mediapipe::OkStatus(); } - // Adds a batch dimension to the input tensor if specified in the calculator - // options. - ::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor) { - if (options_.add_batch_dim_to_tensors()) { - tf::TensorShape new_shape(input_tensor->shape()); - new_shape.InsertDim(0, 1); - RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape)) - << "Could not add 0th dimension to tensor without changing its shape." - << " Current shape: " << input_tensor->shape().DebugString(); - } - return ::mediapipe::OkStatus(); + if (inference_state_to_process) { + MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process))); } - ::mediapipe::Status AggregateTensorPacket( - const std::string& tag_name, const Packet& packet, - std::map>* - input_tensors_by_tag_by_timestamp, - InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { - tf::Tensor input_tensor(packet.Get()); - RET_CHECK_OK(AddBatchDimension(&input_tensor)); - if (::mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) { - // If we receive an input on a recurrent tag, override the state. - // It's OK to override the global state because there is just one - // input stream allowed for recurrent tensors. - inference_state_->input_tensor_batches_[tag_name].clear(); - } - (*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert( - std::make_pair(tag_name, input_tensor)); - return ::mediapipe::OkStatus(); - } - - // Removes the batch dimension of the output tensor if specified in the - // calculator options. - ::mediapipe::Status RemoveBatchDimension(tf::Tensor* output_tensor) { - if (options_.add_batch_dim_to_tensors()) { - tf::TensorShape new_shape(output_tensor->shape()); - new_shape.RemoveDim(0); - RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape)) - << "Could not remove 0th dimension from tensor without changing its " - << "shape. Current shape: " << output_tensor->shape().DebugString() - << " (The expected first dimension is 1 for a batch element.)"; - } - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) override { - std::unique_ptr inference_state_to_process; - { - absl::WriterMutexLock l(&mutex_); - if (inference_state_ == nullptr) { - inference_state_ = CreateInferenceState(cc); - } - std::map> - input_tensors_by_tag_by_timestamp; - for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) { - if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) { - // Recurrent tensors can be empty. 
- if (!::mediapipe::ContainsKey(recurrent_feed_tags_, - tag_as_node_name)) { - if (options_.skip_on_missing_features()) { - return ::mediapipe::OkStatus(); - } else { - return ::mediapipe::InvalidArgumentError(absl::StrCat( - "Tag ", tag_as_node_name, - " not present at timestamp: ", cc->InputTimestamp().Value())); - } - } - } else if (options_.batched_input()) { - const auto& tensor_packets = - cc->Inputs().Tag(tag_as_node_name).Get>(); - if (tensor_packets.size() > options_.batch_size()) { - return ::mediapipe::InvalidArgumentError(absl::StrCat( - "Batch for tag ", tag_as_node_name, - " has more packets than batch capacity. batch_size: ", - options_.batch_size(), " packets: ", tensor_packets.size())); - } - for (const auto& packet : tensor_packets) { - RET_CHECK_OK(AggregateTensorPacket( - tag_as_node_name, packet, &input_tensors_by_tag_by_timestamp, - inference_state_.get())); - } - } else { - RET_CHECK_OK(AggregateTensorPacket( - tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(), - &input_tensors_by_tag_by_timestamp, inference_state_.get())); - } - } - for (const auto& timestamp_and_input_tensors_by_tag : - input_tensors_by_tag_by_timestamp) { - inference_state_->batch_timestamps_.emplace_back( - timestamp_and_input_tensors_by_tag.first); - for (const auto& input_tensor_and_tag : - timestamp_and_input_tensors_by_tag.second) { - inference_state_->input_tensor_batches_[input_tensor_and_tag.first] - .emplace_back(input_tensor_and_tag.second); - } - } - if (inference_state_->batch_timestamps_.size() == options_.batch_size() || - options_.batched_input()) { - inference_state_to_process = std::move(inference_state_); - inference_state_ = std::unique_ptr(); - } - } - - if (inference_state_to_process) { - MP_RETURN_IF_ERROR( - OutputBatch(cc, std::move(inference_state_to_process))); - } - - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Close(CalculatorContext* cc) override { - std::unique_ptr inference_state_to_process = nullptr; - { - absl::WriterMutexLock l(&mutex_); - if (cc->GraphStatus().ok() && inference_state_ != nullptr && - !inference_state_->batch_timestamps_.empty()) { - inference_state_to_process = std::move(inference_state_); - inference_state_ = std::unique_ptr(); - } - } - if (inference_state_to_process) { - MP_RETURN_IF_ERROR( - OutputBatch(cc, std::move(inference_state_to_process))); - } - return ::mediapipe::OkStatus(); - } - - // When a batch of input tensors is ready to be run, runs TensorFlow and - // outputs the output tensors. The output tensors have timestamps matching - // the input tensor that formed that batch element. Any requested - // batch_dimension is added and removed. This code takes advantage of the fact - // that copying a tensor shares the same reference-counted, heap allocated - // memory buffer. Therefore, copies are cheap and should not cause the memory - // buffer to fall out of scope. In contrast, concat is only used where - // necessary. - ::mediapipe::Status OutputBatch( - CalculatorContext* cc, std::unique_ptr inference_state) { - const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); - std::vector> input_tensors; - - for (auto& keyed_tensors : inference_state->input_tensor_batches_) { - if (options_.batch_size() == 1) { - // Short circuit to avoid the cost of deep copying tensors in concat. - if (!keyed_tensors.second.empty()) { - input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], - keyed_tensors.second[0]); - } else { - // The input buffer can be empty for recurrent tensors. 
- RET_CHECK(::mediapipe::ContainsKey(recurrent_feed_tags_, - keyed_tensors.first)) - << "A non-recurrent tensor does not have an input: " - << keyed_tensors.first; - } - } else { - // Pad by replicating the first tens or, then ignore the values. - keyed_tensors.second.resize(options_.batch_size()); - std::fill(keyed_tensors.second.begin() + - inference_state->batch_timestamps_.size(), - keyed_tensors.second.end(), keyed_tensors.second[0]); - tf::Tensor concated; - const tf::Status concat_status = - tf::tensor::Concat(keyed_tensors.second, &concated); - CHECK(concat_status.ok()) << concat_status.ToString(); - input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], - concated); - } - } - inference_state->input_tensor_batches_.clear(); - std::vector output_tensor_names; - std::vector output_name_in_signature; - for (const std::string& tag : cc->Outputs().GetTags()) { - output_tensor_names.emplace_back(tag_to_tensor_map_[tag]); - output_name_in_signature.emplace_back(tag); - } - for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { - // Ensure that we always fetch the recurrent state tensors. - if (std::find(output_name_in_signature.begin(), - output_name_in_signature.end(), - tag_pair.first) == output_name_in_signature.end()) { - output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]); - output_name_in_signature.emplace_back(tag_pair.first); - } - } - std::vector outputs; - - SimpleSemaphore* session_run_throttle = nullptr; - if (options_.max_concurrent_session_runs() > 0) { - session_run_throttle = - get_session_run_throttle(options_.max_concurrent_session_runs()); - session_run_throttle->Acquire(1); - } - const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); - tf::Status tf_status; - { -#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__) - tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName())); -#endif - tf_status = session_->Run(input_tensors, output_tensor_names, - {} /* target_node_names */, &outputs); - } - - if (session_run_throttle != nullptr) { - session_run_throttle->Release(1); - } - - // RET_CHECK on the tf::Status object itself in order to print an - // informative error message. - RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); - - const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); - cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) - ->IncrementBy(run_end_time - run_start_time); - cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); - - // Feed back the recurrent state. - for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { - int pos = std::find(output_name_in_signature.begin(), - output_name_in_signature.end(), tag_pair.first) - - output_name_in_signature.begin(); - inference_state->input_tensor_batches_[tag_pair.second].emplace_back( - outputs[pos]); - } + return absl::OkStatus(); +} +absl::Status Close(CalculatorContext* cc) override { + std::unique_ptr inference_state_to_process = nullptr; + { absl::WriterMutexLock l(&mutex_); - // Set that we want to split on each index of the 0th dimension. 
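At the core of OutputBatch() is a single tf::Session::Run() call that feeds the batched tensors by tensor name and fetches the requested outputs. A minimal sketch of that call, assuming a session already created by one of the TensorFlowSessionFrom* generators (the wrapper function name is invented here):

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/public/session.h"

namespace tf = ::tensorflow;

// Feeds are (tensor name, tensor) pairs resolved through tag_to_tensor_map_;
// fetch_names are the tensor names bound to the output stream tags. Any
// failure is reported through the returned tf::Status.
tf::Status RunOnce(
    tf::Session* session,
    const std::vector<std::pair<std::string, tf::Tensor>>& feeds,
    const std::vector<std::string>& fetch_names,
    std::vector<tf::Tensor>* outputs) {
  return session->Run(feeds, fetch_names, /*target_node_names=*/{}, outputs);
}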
- std::vector split_vector(options_.batch_size(), 1); - for (int i = 0; i < output_tensor_names.size(); ++i) { - if (options_.batch_size() == 1) { - if (cc->Outputs().HasTag(output_name_in_signature[i])) { - tf::Tensor output_tensor(outputs[i]); - RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); - cc->Outputs() - .Tag(output_name_in_signature[i]) - .Add(new tf::Tensor(output_tensor), - inference_state->batch_timestamps_[0]); - } + if (cc->GraphStatus().ok() && inference_state_ != nullptr && + !inference_state_->batch_timestamps_.empty()) { + inference_state_to_process = std::move(inference_state_); + inference_state_ = std::unique_ptr(); + } + } + if (inference_state_to_process) { + MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process))); + } + return absl::OkStatus(); +} + +// When a batch of input tensors is ready to be run, runs TensorFlow and +// outputs the output tensors. The output tensors have timestamps matching +// the input tensor that formed that batch element. Any requested +// batch_dimension is added and removed. This code takes advantage of the fact +// that copying a tensor shares the same reference-counted, heap allocated +// memory buffer. Therefore, copies are cheap and should not cause the memory +// buffer to fall out of scope. In contrast, concat is only used where +// necessary. +absl::Status OutputBatch(CalculatorContext* cc, + std::unique_ptr inference_state) { + const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); + std::vector> input_tensors; + + for (auto& keyed_tensors : inference_state->input_tensor_batches_) { + if (options_.batch_size() == 1) { + // Short circuit to avoid the cost of deep copying tensors in concat. + if (!keyed_tensors.second.empty()) { + input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], + keyed_tensors.second[0]); } else { - std::vector split_tensors; - const tf::Status split_status = - tf::tensor::Split(outputs[i], split_vector, &split_tensors); - CHECK(split_status.ok()) << split_status.ToString(); - // Loop over timestamps so that we don't copy the padding. - for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) { - tf::Tensor output_tensor(split_tensors[j]); - RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); - cc->Outputs() - .Tag(output_name_in_signature[i]) - .Add(new tf::Tensor(output_tensor), - inference_state->batch_timestamps_[j]); - } + // The input buffer can be empty for recurrent tensors. + RET_CHECK( + mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first)) + << "A non-recurrent tensor does not have an input: " + << keyed_tensors.first; + } + } else { + // Pad by replicating the first tens or, then ignore the values. 
+ keyed_tensors.second.resize(options_.batch_size()); + std::fill(keyed_tensors.second.begin() + + inference_state->batch_timestamps_.size(), + keyed_tensors.second.end(), keyed_tensors.second[0]); + tf::Tensor concated; + const tf::Status concat_status = + tf::tensor::Concat(keyed_tensors.second, &concated); + CHECK(concat_status.ok()) << concat_status.ToString(); + input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], + concated); + } + } + inference_state->input_tensor_batches_.clear(); + std::vector output_tensor_names; + std::vector output_name_in_signature; + for (const std::string& tag : cc->Outputs().GetTags()) { + output_tensor_names.emplace_back(tag_to_tensor_map_[tag]); + output_name_in_signature.emplace_back(tag); + } + for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { + // Ensure that we always fetch the recurrent state tensors. + if (std::find(output_name_in_signature.begin(), + output_name_in_signature.end(), + tag_pair.first) == output_name_in_signature.end()) { + output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]); + output_name_in_signature.emplace_back(tag_pair.first); + } + } + std::vector outputs; + + SimpleSemaphore* session_run_throttle = nullptr; + if (options_.max_concurrent_session_runs() > 0) { + session_run_throttle = + get_session_run_throttle(options_.max_concurrent_session_runs()); + session_run_throttle->Acquire(1); + } + const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); + tf::Status tf_status; + { +#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__) + tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName())); +#endif + tf_status = session_->Run(input_tensors, output_tensor_names, + {} /* target_node_names */, &outputs); + } + + if (session_run_throttle != nullptr) { + session_run_throttle->Release(1); + } + + // RET_CHECK on the tf::Status object itself in order to print an + // informative error message. + RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); + + const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); + cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) + ->IncrementBy(run_end_time - run_start_time); + cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); + + // Feed back the recurrent state. + for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { + int pos = std::find(output_name_in_signature.begin(), + output_name_in_signature.end(), tag_pair.first) - + output_name_in_signature.begin(); + inference_state->input_tensor_batches_[tag_pair.second].emplace_back( + outputs[pos]); + } + + absl::WriterMutexLock l(&mutex_); + // Set that we want to split on each index of the 0th dimension. + std::vector split_vector(options_.batch_size(), 1); + for (int i = 0; i < output_tensor_names.size(); ++i) { + if (options_.batch_size() == 1) { + if (cc->Outputs().HasTag(output_name_in_signature[i])) { + tf::Tensor output_tensor(outputs[i]); + RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); + cc->Outputs() + .Tag(output_name_in_signature[i]) + .Add(new tf::Tensor(output_tensor), + inference_state->batch_timestamps_[0]); + } + } else { + std::vector split_tensors; + const tf::Status split_status = + tf::tensor::Split(outputs[i], split_vector, &split_tensors); + CHECK(split_status.ok()) << split_status.ToString(); + // Loop over timestamps so that we don't copy the padding. 
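The padding step above fills the unused tail of a partial batch by replicating the first tensor so that Concat always sees batch_size elements; the padded rows never reach the output streams because packets are only emitted for the recorded batch timestamps. A standalone sketch of that step (function name invented here):

#include <algorithm>
#include <vector>

#include "tensorflow/core/framework/tensor.h"

namespace tf = ::tensorflow;

// Pads a partially filled batch up to `batch_size` by replicating its first
// tensor, mirroring the resize + std::fill above. Requires at least one real
// element in `batch`.
void PadBatchWithFirstElement(int batch_size, int num_real_elements,
                              std::vector<tf::Tensor>* batch) {
  batch->resize(batch_size);
  std::fill(batch->begin() + num_real_elements, batch->end(), (*batch)[0]);
}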
+ for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) { + tf::Tensor output_tensor(split_tensors[j]); + RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); + cc->Outputs() + .Tag(output_name_in_signature[i]) + .Add(new tf::Tensor(output_tensor), + inference_state->batch_timestamps_[j]); } } - - // Get end time and report. - const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); - cc->GetCounter(kTotalUsecsCounterSuffix) - ->IncrementBy(end_time - start_time); - cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) - ->IncrementBy(inference_state->batch_timestamps_.size()); - - // Make sure we hold on to the recursive state. - if (!options_.recurrent_tag_pair().empty()) { - inference_state_ = std::move(inference_state); - inference_state_->batch_timestamps_.clear(); - } - - return ::mediapipe::OkStatus(); } - private: - // The Session object is provided by a packet factory and is owned by the - // MediaPipe framework. Individual calls are thread-safe, but session state - // may be shared across threads. - tf::Session* session_; + // Get end time and report. + const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); + cc->GetCounter(kTotalUsecsCounterSuffix)->IncrementBy(end_time - start_time); + cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) + ->IncrementBy(inference_state->batch_timestamps_.size()); - // A mapping between stream tags and the tensor names they are bound to. - std::map tag_to_tensor_map_; - - absl::Mutex mutex_; - std::unique_ptr inference_state_ ABSL_GUARDED_BY(mutex_); - - // The options for the calculator. - TensorFlowInferenceCalculatorOptions options_; - - // Store the feed and fetch tags for feed/fetch recurrent networks. - std::set recurrent_feed_tags_; - std::map recurrent_fetch_tags_to_feed_tags_; - - // Clock used to measure the computation time in OutputBatch(). - std::unique_ptr clock_; - - // The static singleton semaphore to throttle concurrent session runs. - static SimpleSemaphore* get_session_run_throttle( - int32 max_concurrent_session_runs) { - static SimpleSemaphore* session_run_throttle = - new SimpleSemaphore(max_concurrent_session_runs); - return session_run_throttle; + // Make sure we hold on to the recursive state. + if (!options_.recurrent_tag_pair().empty()) { + inference_state_ = std::move(inference_state); + inference_state_->batch_timestamps_.clear(); } -}; + + return absl::OkStatus(); +} + +private: +// The Session object is provided by a packet factory and is owned by the +// MediaPipe framework. Individual calls are thread-safe, but session state may +// be shared across threads. +tf::Session* session_; + +// A mapping between stream tags and the tensor names they are bound to. +std::map tag_to_tensor_map_; + +absl::Mutex mutex_; +std::unique_ptr inference_state_ ABSL_GUARDED_BY(mutex_); + +// The options for the calculator. +TensorFlowInferenceCalculatorOptions options_; + +// Store the feed and fetch tags for feed/fetch recurrent networks. +std::set recurrent_feed_tags_; +std::map recurrent_fetch_tags_to_feed_tags_; + +// Clock used to measure the computation time in OutputBatch(). +std::unique_ptr clock_; + +// The static singleton semaphore to throttle concurrent session runs. 
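After the session runs, each batched output is cut back into per-timestamp slices along the 0th dimension with tf::tensor::Split and a vector of ones, as shown above. A hedged sketch of that step, assuming the TensorFlow tensor_util helpers:

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"

namespace tf = ::tensorflow;

// Splits a batched tensor into batch_size slices of size 1 along dimension 0.
// Only the leading slices that correspond to real input timestamps are used;
// slices coming from padding are skipped when emitting packets.
tf::Status SplitBatchOutputs(const tf::Tensor& batched, int batch_size,
                             std::vector<tf::Tensor>* per_element) {
  std::vector<tf::int64> sizes(batch_size, 1);
  return tf::tensor::Split(batched, sizes, per_element);
}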
+static SimpleSemaphore* get_session_run_throttle( + int32 max_concurrent_session_runs) { + static SimpleSemaphore* session_run_throttle = + new SimpleSemaphore(max_concurrent_session_runs); + return session_run_throttle; +} +} +; REGISTER_CALCULATOR(TensorFlowInferenceCalculator); constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[]; diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc index 2ec6cbe3b..20e80bf33 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc @@ -21,6 +21,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" @@ -47,15 +48,15 @@ std::string GetGraphDefPath() { CFURLGetFileSystemRepresentation( bundle_url, true, reinterpret_cast(path), sizeof(path)); CFRelease(bundle_url); - return ::mediapipe::file::JoinPath(path, "testdata/frozen_graph_def.pb"); + return mediapipe::file::JoinPath(path, "testdata/frozen_graph_def.pb"); #elif defined(__ANDROID__) char path[1024]; getcwd(path, sizeof(path)); - return ::mediapipe::file::JoinPath(path, - "mediapipe/calculators/tensorflow/" - "testdata/frozen_graph_def.pb"); + return mediapipe::file::JoinPath(path, + "mediapipe/calculators/tensorflow/" + "testdata/frozen_graph_def.pb"); #else - return ::mediapipe::file::JoinPath( + return mediapipe::file::JoinPath( "./", // This should match the path of the output files // of the genrule() that generates test model files. diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc index 1c34ee6ed..2c1d169bc 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc @@ -59,7 +59,7 @@ void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); bool has_exactly_one_model = @@ -89,10 +89,10 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { // a map from tags to tensor names. 
); RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { auto clock = std::unique_ptr( mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); const uint64 start_time = absl::ToUnixMicros(clock->TimeNow()); @@ -151,11 +151,11 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); LOG(INFO) << "Loaded frozen model in: " << end_time - start_time << " microseconds."; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; REGISTER_CALCULATOR(TensorFlowSessionFromFrozenGraphCalculator); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc index 5277eb348..8d3d3fdff 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -120,7 +121,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, ProducesPacketUsableByTensorFlowInferenceCalculator) { CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( absl::Substitute(R"( node { calculator: "TensorFlowInferenceCalculator" @@ -153,7 +154,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("multiplied_tensor"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc index cd46a9a9f..9f5b9e06b 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc @@ -55,7 +55,7 @@ void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { public: - static ::mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { RET_CHECK(extendable_options.HasExtension( @@ -87,10 +87,10 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { // a map from tags to tensor names. 
); RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - static ::mediapipe::Status Generate( + static absl::Status Generate( const PacketGeneratorOptions& packet_generator_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { auto clock = std::unique_ptr( @@ -151,7 +151,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); LOG(INFO) << "Loaded frozen model in: " << end_time - start_time << " microseconds."; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(TensorFlowSessionFromFrozenGraphGenerator); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc index e2b968217..c7f06bbc4 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -101,10 +102,10 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test { TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, CreatesPacketWithGraphAndBindings) { - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -116,7 +117,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, ProducesPacketUsableByTensorFlowInferenceCalculator) { CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( absl::Substitute(R"( node { calculator: "TensorFlowInferenceCalculator" @@ -149,7 +150,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("multiplied_tensor"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( @@ -171,16 +172,16 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, CreatesPacketWithGraphAndBindingsFromInputSidePacket) { PacketSet input_side_packets( - tool::CreateTagMap({"STRING_MODEL:model"}).ValueOrDie()); + tool::CreateTagMap({"STRING_MODEL:model"}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); std::string serialized_graph_contents; 
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); generator_options_->clear_graph_proto_path(); input_side_packets.Tag("STRING_MODEL") = Adopt(new std::string(serialized_graph_contents)); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -191,13 +192,13 @@ TEST_F( TensorFlowSessionFromFrozenGraphGeneratorTest, CreatesPacketWithGraphAndBindingsFromInputSidePacketStringModelFilePath) { PacketSet input_side_packets( - tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).ValueOrDie()); + tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); generator_options_->clear_graph_proto_path(); input_side_packets.Tag("STRING_MODEL_FILE_PATH") = Adopt(new std::string(GetGraphDefPath())); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -207,15 +208,15 @@ TEST_F( TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, CheckFailureForOptionsAndInputsProvideGraphDefProto) { PacketSet input_side_packets( - tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).ValueOrDie()); + tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); input_side_packets.Tag("STRING_MODEL_FILE_PATH") = Adopt(new std::string(GetGraphDefPath())); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); - EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal); + EXPECT_EQ(run_status.code(), absl::StatusCode::kInternal); EXPECT_THAT( run_status.message(), ::testing::HasSubstr("Must have exactly one of graph_proto_path")); @@ -226,9 +227,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, PacketSet input_side_packets( tool::CreateTagMap( {"STRING_MODEL_FILE_PATH:model_path", "STRING_MODEL:model"}) - .ValueOrDie()); + .value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); @@ -237,10 +238,10 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, input_side_packets.Tag("STRING_MODEL_FILE_PATH") = Adopt(new std::string(GetGraphDefPath())); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); - EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal); + EXPECT_EQ(run_status.code(), absl::StatusCode::kInternal); EXPECT_THAT( run_status.message(), ::testing::HasSubstr("Must have exactly one of 
graph_proto_path")); @@ -251,9 +252,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, PacketSet input_side_packets( tool::CreateTagMap( {"STRING_MODEL_FILE_PATH:model_path", "STRING_MODEL:model"}) - .ValueOrDie()); + .value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); @@ -263,10 +264,10 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, Adopt(new std::string(GetGraphDefPath())); generator_options_->clear_graph_proto_path(); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); - EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal); + EXPECT_EQ(run_status.code(), absl::StatusCode::kInternal); EXPECT_THAT( run_status.message(), ::testing::HasSubstr("Must have exactly one of graph_proto_path")); @@ -274,11 +275,11 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, CheckInitializationOpName) { - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); generator_options_->add_initialization_op_names("multiplied:0"); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index 55709bcd9..6aedb138f 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -35,9 +35,9 @@ static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; // Given the path to a directory containing multiple tensorflow saved models // in subdirectories, replaces path with the alphabetically last subdirectory. 
-::mediapipe::Status GetLatestDirectory(std::string* path) { +absl::Status GetLatestDirectory(std::string* path) { #if defined(__ANDROID__) - return ::mediapipe::UnimplementedError( + return absl::UnimplementedError( "GetLatestDirectory is not implemented on Android"); #else std::vector saved_models; @@ -47,7 +47,7 @@ static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; << "No exported bundles found in " << path; ::std::sort(saved_models.begin(), saved_models.end()); *path = std::string(file::Dirname(saved_models.back())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); #endif } @@ -75,10 +75,10 @@ const std::string MaybeConvertSignatureToTag( } // namespace // TensorFlowSessionFromSavedModelCalculator is a MediaPipe packet calculator -// that loads a trained TensorFlow model exported via SavedModel's exporter (see -// go/savedmodel) and returns a Packet containing a unique_ptr to a -// mediapipe::TensorFlowSession, which in turn contains a TensorFlow Session -// ready for execution and a map between tags and tensor names. +// that loads a trained TensorFlow model exported via SavedModel's exporter and +// returns a Packet containing a unique_ptr to a mediapipe::TensorFlowSession, +// which in turn contains a TensorFlow Session ready for execution and a map +// between tags and tensor names. // // Example usage: // node { @@ -93,7 +93,7 @@ const std::string MaybeConvertSignatureToTag( // } class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); const bool has_exactly_one_model = @@ -108,10 +108,10 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { } // A TensorFlow model loaded and ready for use along with tensor cc->OutputSidePackets().Tag("SESSION").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const auto& options = cc->Options(); std::string path = cc->InputSidePackets().HasTag(kStringSavedModelPath) @@ -140,9 +140,8 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { ::tensorflow::Status status = tensorflow::LoadSavedModel( session_options, run_options, path, tags_set, saved_model.get()); if (!status.ok()) { - return ::mediapipe::Status( - static_cast<::mediapipe::StatusCode>(status.code()), - status.ToString()); + return absl::Status(static_cast(status.code()), + status.ToString()); } auto session = absl::make_unique(); @@ -161,11 +160,11 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { } cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto index a8839ef52..927d3b51f 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto @@ -28,7 +28,7 @@ message 
TensorFlowSessionFromSavedModelCalculatorOptions { // SavedModels, include a flag to load the most recent model. // Path to a directory containing a trained TensorFlow model as prepared - // by SavedModel (go/saved-model). + // by SavedModel. optional string saved_model_path = 1; // The name of the generic signature to load into the mapping from tags to // tensor names. diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc index d6064d862..912d71600 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc @@ -20,6 +20,7 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -132,7 +133,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, ProducesPacketUsableByTensorFlowInferenceCalculator) { CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( absl::Substitute(R"( node { calculator: "TensorFlowInferenceCalculator" @@ -164,7 +165,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("multiplied_tensor"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index 73ffc6497..6489b0267 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -37,9 +37,9 @@ static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; // Given the path to a directory containing multiple tensorflow saved models // in subdirectories, replaces path with the alphabetically last subdirectory. 
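Both the calculator above and the generator below wrap the tensorflow::Status returned by tensorflow::LoadSavedModel into an absl::Status; the cast in the new lines targets absl::StatusCode. A hedged, self-contained sketch of that conversion (the ToAbslStatus helper name is illustrative, not part of this diff):

  #include "absl/status/status.h"
  #include "tensorflow/core/lib/core/status.h"

  // Hedged sketch: mirror a tensorflow::Status as an absl::Status, preserving
  // the error code and a printable form of the message.
  absl::Status ToAbslStatus(const tensorflow::Status& status) {
    if (status.ok()) return absl::OkStatus();
    return absl::Status(static_cast<absl::StatusCode>(status.code()),
                        status.ToString());
  }
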
-::mediapipe::Status GetLatestDirectory(std::string* path) { +absl::Status GetLatestDirectory(std::string* path) { #if defined(__ANDROID__) - return ::mediapipe::UnimplementedError( + return absl::UnimplementedError( "GetLatestDirectory is not implemented on Android"); #else std::vector saved_models; @@ -49,7 +49,7 @@ static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; << "No exported bundles found in " << path; ::std::sort(saved_models.begin(), saved_models.end()); *path = std::string(file::Dirname(saved_models.back())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); #endif } @@ -77,13 +77,13 @@ const std::string MaybeConvertSignatureToTag( } // namespace // TensorFlowSessionFromSavedModelGenerator is a MediaPipe packet generator -// that loads a trained TensorFlow model exported via SavedModel's exporter (see -// go/savedmodel) and returns a Packet containing a unique_ptr to a -// mediapipe::TensorFlowSession, which in turn contains a TensorFlow Session -// ready for execution and a map between tags and tensor names. +// that loads a trained TensorFlow model exported via SavedModel's exporter and +// returns a Packet containing a unique_ptr to a mediapipe::TensorFlowSession, +// which in turn contains a TensorFlow Session ready for execution and a map +// between tags and tensor names. class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { public: - static ::mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { const TensorFlowSessionFromSavedModelGeneratorOptions& options = @@ -101,12 +101,12 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { } // A TensorFlow model loaded and ready for use along with tensor output_side_packets->Tag("SESSION").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - static ::mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { const TensorFlowSessionFromSavedModelGeneratorOptions& options = extendable_options.GetExtension( TensorFlowSessionFromSavedModelGeneratorOptions::ext); @@ -135,9 +135,8 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { ::tensorflow::Status status = tensorflow::LoadSavedModel( session_options, run_options, path, tags_set, saved_model.get()); if (!status.ok()) { - return ::mediapipe::Status( - static_cast<::mediapipe::StatusCode>(status.code()), - status.ToString()); + return absl::Status(static_cast(status.code()), + status.ToString()); } auto session = absl::make_unique(); session->session = std::move(saved_model->session); @@ -155,7 +154,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { } output_side_packets->Tag("SESSION") = Adopt(session.release()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(TensorFlowSessionFromSavedModelGenerator); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto index 88ce93435..d24a1cd73 100644 --- 
a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto @@ -28,7 +28,7 @@ message TensorFlowSessionFromSavedModelGeneratorOptions { // SavedModels, include a flag to load the most recent model. // Path to a directory containing a trained TensorFlow model as prepared - // by SavedModel (go/saved-model). + // by SavedModel. optional string saved_model_path = 1; // The name of the generic signature to load into the mapping from tags to // tensor names. diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc index 792c3841b..92d0d5de4 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -66,10 +67,10 @@ class TensorFlowSessionFromSavedModelGeneratorTest : public ::testing::Test { TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, CreatesPacketWithGraphAndBindings) { - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromSavedModelGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -105,13 +106,12 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, CreateSessionFromSidePacket) { generator_options_->clear_saved_model_path(); PacketSet input_side_packets( - tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}) - .ValueOrDie()); + tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value()); input_side_packets.Tag("STRING_SAVED_MODEL_PATH") = Adopt(new std::string(GetSavedModelDir())); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromSavedModelGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -126,7 +126,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, ProducesPacketUsableByTensorFlowInferenceCalculator) { CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( absl::Substitute(R"( node { calculator: "TensorFlowInferenceCalculator" @@ -159,7 +159,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("multiplied_tensor"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = 
std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( @@ -184,10 +184,10 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, std::string(file::SplitPath(GetSavedModelDir()).first)); generator_options_->set_load_latest_model(true); - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromSavedModelGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -205,10 +205,10 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, generator_options_->mutable_session_config()->mutable_device_count()->insert( {"CPU", 10}); - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromSavedModelGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); diff --git a/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc b/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc index f3b0b485d..28271f3a7 100644 --- a/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc +++ b/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc @@ -49,14 +49,13 @@ const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; // } class TFRecordReaderCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; -::mediapipe::Status TFRecordReaderCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TFRecordReaderCalculator::GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag(kTFRecordPath).Set(); if (cc->InputSidePackets().HasTag(kRecordIndex)) { cc->InputSidePackets().Tag(kRecordIndex).Set(); @@ -73,10 +72,10 @@ class TFRecordReaderCalculator : public CalculatorBase { .Tag(kSequenceExampleTag) .Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) { +absl::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) { std::unique_ptr file; auto tf_status = tensorflow::Env::Default()->NewRandomAccessFile( cc->InputSidePackets().Tag(kTFRecordPath).Get(), &file); @@ -114,11 +113,11 @@ class TFRecordReaderCalculator : public CalculatorBase { ++current_idx; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TFRecordReaderCalculator::Process(CalculatorContext* cc) { - return ::mediapipe::OkStatus(); +absl::Status 
TFRecordReaderCalculator::Process(CalculatorContext* cc) { + return absl::OkStatus(); } REGISTER_CALCULATOR(TFRecordReaderCalculator); diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc index 86a2a4afa..1f4cda359 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc @@ -42,7 +42,7 @@ const char kImagesFrameRateTag[] = "IMAGE_FRAME_RATE"; const char kAudioDecoderOptions[] = "AUDIO_DECODER_OPTIONS"; namespace tf = ::tensorflow; -namespace mpms = ::mediapipe::mediasequence; +namespace mpms = mediapipe::mediasequence; // Source calculator to unpack side_packets and streams from tf.SequenceExamples // @@ -84,7 +84,7 @@ namespace mpms = ::mediapipe::mediasequence; // node { // calculator: "UnpackMediaSequenceCalculator" // input_side_packet: "SEQUENCE_EXAMPLE:example_input_side_packet" -// input_side_packet: "ROOT_DIRECTORY:path_to_dataset_root_directory" +// input_side_packet: "DATASET_ROOT:path_to_dataset_root_directory" // output_side_packet: "DATA_PATH:full_path_to_data_element" // output_side_packet: "RESAMPLER_OPTIONS:packet_resampler_options" // options { @@ -118,7 +118,7 @@ namespace mpms = ::mediapipe::mediasequence; // } class UnpackMediaSequenceCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag)); cc->InputSidePackets().Tag(kSequenceExampleTag).Set(); @@ -183,10 +183,10 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { cc->Outputs().Tag(tag).Set>(); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { // Copy the packet to copy the otherwise inaccessible shared ptr. example_packet_holder_ = cc->InputSidePackets().Tag(kSequenceExampleTag); sequence_ = &example_packet_holder_.Get(); @@ -335,10 +335,10 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { .Set(MakePacket(mpms::GetImageFrameRate(sequence))); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (timestamps_.empty()) { // This occurs when we only have metadata to unpack. 
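In the Process() hunks that follow, the unpacking calculators keep returning absl::OkStatus() while timestamps remain and finish by returning tool::StatusStop(), which the framework treats as a normal end of stream rather than an error. A hedged sketch of that source-calculator pattern (CountToThreeDemoCalculator is illustrative, not part of this diff):

  #include "mediapipe/framework/calculator_framework.h"
  #include "mediapipe/framework/tool/status_util.h"

  namespace mediapipe {

  // Emits the ints 0, 1, 2 and then stops the stream cleanly.
  class CountToThreeDemoCalculator : public CalculatorBase {
   public:
    static absl::Status GetContract(CalculatorContract* cc) {
      cc->Outputs().Index(0).Set<int>();
      return absl::OkStatus();
    }
    absl::Status Process(CalculatorContext* cc) override {
      if (count_ >= 3) return tool::StatusStop();  // normal end of stream
      cc->Outputs().Index(0).AddPacket(
          MakePacket<int>(count_).At(Timestamp(count_)));
      ++count_;
      return absl::OkStatus();
    }

   private:
    int count_ = 0;
  };
  REGISTER_CALCULATOR(CountToThreeDemoCalculator);

  }  // namespace mediapipe
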
LOG(INFO) << "only unpacking metadata because there are no timestamps."; @@ -435,7 +435,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { ++current_timestamp_index_; if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { return tool::StatusStop(); } diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index 185e2e186..dcbda224e 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -31,7 +31,7 @@ namespace mediapipe { namespace { namespace tf = ::tensorflow; -namespace mpms = ::mediapipe::mediasequence; +namespace mpms = mediapipe::mediasequence; class UnpackMediaSequenceCalculatorTest : public ::testing::Test { protected: diff --git a/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc b/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc index daf7f1117..efb3037f8 100644 --- a/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc @@ -64,7 +64,7 @@ std::string GetQuantizedFeature( // } class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets() .Tag(kYt8mSequenceExample) .Set(); @@ -84,10 +84,10 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { if (cc->OutputSidePackets().HasTag(kSegmentSize)) { cc->OutputSidePackets().Tag(kSegmentSize).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const tensorflow::SequenceExample& sequence_example = cc->InputSidePackets() .Tag(kYt8mSequenceExample) @@ -108,7 +108,7 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { .feature_size(); if (rgb_feature_list_length != audio_feature_list_length) { - return ::mediapipe::FailedPreconditionError(absl::StrCat( + return absl::FailedPreconditionError(absl::StrCat( "Data corruption: the length of audio features and rgb features are " "not equal. Please check the sequence example that contains yt8m " "id: ", @@ -151,12 +151,12 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { } LOG(INFO) << "Reading the sequence example that contains yt8m id: " << yt8m_id << ". 
Feature list length: " << feature_list_length_; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (current_index_ >= feature_list_length_) { - return ::mediapipe::tool::StatusStop(); + return mediapipe::tool::StatusStop(); } const tensorflow::SequenceExample& sequence_example = cc->InputSidePackets() @@ -179,7 +179,7 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { GetQuantizedFeature(sequence_example, kAudio, current_index_)) .At(timestamp)); ++current_index_; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc index f7c041788..96208b3e5 100644 --- a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc @@ -44,17 +44,17 @@ namespace tf = ::tensorflow; // } class VectorFloatToTensorCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: VectorFloatToTensorCalculatorOptions options_; }; REGISTER_CALCULATOR(VectorFloatToTensorCalculator); -::mediapipe::Status VectorFloatToTensorCalculator::GetContract( +absl::Status VectorFloatToTensorCalculator::GetContract( CalculatorContract* cc) { const auto& options = cc->Options(); // Start with only one input packet. @@ -75,16 +75,16 @@ REGISTER_CALCULATOR(VectorFloatToTensorCalculator); cc->Outputs().Index(0).Set( // Output stream with data as tf::Tensor and the same TimeSeriesHeader. 
); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status VectorFloatToTensorCalculator::Open(CalculatorContext* cc) { +absl::Status VectorFloatToTensorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); - return ::mediapipe::OkStatus(); + cc->SetOffset(0); + return absl::OkStatus(); } -::mediapipe::Status VectorFloatToTensorCalculator::Process( - CalculatorContext* cc) { +absl::Status VectorFloatToTensorCalculator::Process(CalculatorContext* cc) { tf::TensorShape tensor_shape; if (options_.input_size() == INPUT_2D) { const std::vector>& input = @@ -127,7 +127,7 @@ REGISTER_CALCULATOR(VectorFloatToTensorCalculator); } else { LOG(FATAL) << "input size not supported"; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc index 1269e2761..f5bf7661e 100644 --- a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc @@ -62,18 +62,17 @@ void AssignMatrixValue(int r, int c, int value, tf::Tensor* output_tensor) { // } class VectorIntToTensorCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: VectorIntToTensorCalculatorOptions options_; }; REGISTER_CALCULATOR(VectorIntToTensorCalculator); -::mediapipe::Status VectorIntToTensorCalculator::GetContract( - CalculatorContract* cc) { +absl::Status VectorIntToTensorCalculator::GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); // Start with only one input packet. 
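Besides the status rename, the converter's Open() above now calls cc->SetOffset(0), declaring that each output tensor carries the timestamp of the input packet it was built from, which lets downstream nodes be scheduled without waiting on this calculator's timestamp bound. For reference, the 1-D branch of these vector-to-tensor converters amounts to the following hedged sketch (the VectorToTensor helper name is illustrative):

  #include <vector>

  #include "tensorflow/core/framework/tensor.h"
  #include "tensorflow/core/framework/tensor_shape.h"
  #include "tensorflow/core/framework/types.h"

  namespace tf = ::tensorflow;

  // Copy a std::vector<float> into a rank-1 DT_FLOAT tensor of equal length.
  tf::Tensor VectorToTensor(const std::vector<float>& input) {
    tf::Tensor tensor(tf::DT_FLOAT,
                      tf::TensorShape({static_cast<tf::int64>(input.size())}));
    auto vec = tensor.vec<float>();
    for (int i = 0; i < static_cast<int>(input.size()); ++i) {
      vec(i) = input[i];
    }
    return tensor;
  }
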
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) @@ -92,20 +91,19 @@ REGISTER_CALCULATOR(VectorIntToTensorCalculator); RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) << "Only one output stream is supported."; cc->Outputs().Tag(kTensorOut).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status VectorIntToTensorCalculator::Open(CalculatorContext* cc) { +absl::Status VectorIntToTensorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); RET_CHECK(options_.tensor_data_type() == tf::DT_UINT8 || options_.tensor_data_type() == tf::DT_INT32 || options_.tensor_data_type() == tf::DT_INT64) << "Output tensor data type is not supported."; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status VectorIntToTensorCalculator::Process( - CalculatorContext* cc) { +absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) { tf::TensorShape tensor_shape; if (options_.input_size() == INPUT_2D) { const std::vector>& input = @@ -197,7 +195,7 @@ REGISTER_CALCULATOR(VectorIntToTensorCalculator); } else { LOG(FATAL) << "input size not supported"; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc new file mode 100644 index 000000000..0e579009b --- /dev/null +++ b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc @@ -0,0 +1,137 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Converts vector (or vector>) to 1D (or 2D) +// tf::Tensor. + +#include "mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace { +auto& INPUT_1D = VectorStringToTensorCalculatorOptions::INPUT_1D; +auto& INPUT_2D = VectorStringToTensorCalculatorOptions::INPUT_2D; +} // namespace + +namespace tf = ::tensorflow; + +// The calculator expects one input (a packet containing a vector +// or vector>) and generates one output (a packet containing +// a tf::Tensor containing the same data). The output tensor will be either 1D +// or 2D with dimensions corresponding to the input vector std::string. It will +// hold DT_STRING values. 
+// +// Example config: +// node { +// calculator: "VectorStringToTensorCalculator" +// input_stream: "vector_string_features" +// output_stream: "tensor_features" +// } +class VectorStringToTensorCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + VectorStringToTensorCalculatorOptions options_; +}; +REGISTER_CALCULATOR(VectorStringToTensorCalculator); + +absl::Status VectorStringToTensorCalculator::GetContract( + CalculatorContract* cc) { + const auto& options = cc->Options(); + // Start with only one input packet. + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + if (options.input_size() == INPUT_2D) { + cc->Inputs().Index(0).Set>>( + /* "Input vector>." */); + } else if (options.input_size() == INPUT_1D) { + cc->Inputs().Index(0).Set>( + // Input vector. + ); + } else { + LOG(FATAL) << "input size not supported"; + } + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + cc->Outputs().Index(0).Set( + // Output stream with data as tf::Tensor and the same TimeSeriesHeader. + ); + return absl::OkStatus(); +} + +absl::Status VectorStringToTensorCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + cc->SetOffset(0); + return absl::OkStatus(); +} + +absl::Status VectorStringToTensorCalculator::Process(CalculatorContext* cc) { + tf::TensorShape tensor_shape; + if (options_.input_size() == INPUT_2D) { + const std::vector>& input = + cc->Inputs() + .Index(0) + .Value() + .Get>>(); + + const int32 rows = input.size(); + RET_CHECK_GE(rows, 1); + const int32 cols = input[0].size(); + RET_CHECK_GE(cols, 1); + for (int i = 1; i < rows; ++i) { + RET_CHECK_EQ(input[i].size(), cols); + } + if (options_.transpose()) { + tensor_shape = tf::TensorShape({cols, rows}); + } else { + tensor_shape = tf::TensorShape({rows, cols}); + } + auto output = ::absl::make_unique(tf::DT_STRING, tensor_shape); + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + if (options_.transpose()) { + output->tensor()(c, r) = input[r][c]; + } else { + output->tensor()(r, c) = input[r][c]; + } + } + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else if (options_.input_size() == INPUT_1D) { + const std::vector& input = + cc->Inputs().Index(0).Value().Get>(); + RET_CHECK_GE(input.size(), 1); + const int32 length = input.size(); + tensor_shape = tf::TensorShape({length}); + auto output = ::absl::make_unique(tf::DT_STRING, tensor_shape); + for (int i = 0; i < length; ++i) { + output->tensor()(i) = input.at(i); + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else { + LOG(FATAL) << "input size not supported"; + } + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.proto b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.proto new file mode 100644 index 000000000..908d98dff --- /dev/null +++ b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.proto @@ -0,0 +1,40 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message VectorStringToTensorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional VectorStringToTensorCalculatorOptions ext = 357221188; + } + enum InputSize { + UNKNOWN = 0; + INPUT_1D = 1; + INPUT_2D = 2; + } + + // If input_size is INPUT_2D, unpack a vector> to a + // 2d tensor (matrix). If INPUT_1D, + // convert a vector into a 1d tensor (vector). + optional InputSize input_size = 1 [default = INPUT_1D]; + + // If true, the output tensor is transposed. + // Otherwise, the output tensor is not transposed. + // It will be ignored if input_size is INPUT_1D. + optional bool transpose = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_test.cc new file mode 100644 index 000000000..5921bd1b0 --- /dev/null +++ b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_test.cc @@ -0,0 +1,120 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +class VectorStringToTensorCalculatorTest : public ::testing::Test { + protected: + void SetUpRunner( + const VectorStringToTensorCalculatorOptions::InputSize input_size, + const bool transpose) { + CalculatorGraphConfig::Node config; + config.set_calculator("VectorStringToTensorCalculator"); + config.add_input_stream("input_string"); + config.add_output_stream("output_tensor"); + auto options = config.mutable_options()->MutableExtension( + VectorStringToTensorCalculatorOptions::ext); + options->set_input_size(input_size); + options->set_transpose(transpose); + runner_ = ::absl::make_unique(config); + } + + void TestConvertFromVectoVectorString(const bool transpose) { + SetUpRunner(VectorStringToTensorCalculatorOptions::INPUT_2D, transpose); + auto input = ::absl::make_unique>>( + 2, std::vector(2)); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + input->at(i).at(j) = absl::StrCat(i, j); + } + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const tf::Tensor& output_tensor = output_packets[0].Get(); + + EXPECT_EQ(2, output_tensor.dims()); + EXPECT_EQ(tf::DT_STRING, output_tensor.dtype()); + const auto matrix = output_tensor.matrix(); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + if (!transpose) { + EXPECT_EQ(absl::StrCat(i, j), matrix(i, j)); + } else { + EXPECT_EQ(absl::StrCat(j, i), matrix(i, j)); + } + } + } + } + + std::unique_ptr runner_; +}; + +TEST_F(VectorStringToTensorCalculatorTest, ConvertsFromVectorString) { + SetUpRunner(VectorStringToTensorCalculatorOptions::INPUT_1D, false); + auto input = ::absl::make_unique>(5); + for (int i = 0; i < 5; ++i) { + input->at(i) = absl::StrCat(i); + } + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const tf::Tensor& output_tensor = output_packets[0].Get(); + + EXPECT_EQ(1, output_tensor.dims()); + EXPECT_EQ(tf::DT_STRING, output_tensor.dtype()); + const auto vec = output_tensor.vec(); + + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(absl::StrCat(i), vec(i)); + } +} + +TEST_F(VectorStringToTensorCalculatorTest, ConvertsFromVectorVectorString) { + for (bool transpose : {false, true}) { + TestConvertFromVectoVectorString(transpose); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 4798eef26..3063b4fa2 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -13,131 +13,91 @@ # limitations under the License. 
# -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -proto_library( +mediapipe_proto_library( name = "ssd_anchors_calculator_proto", srcs = ["ssd_anchors_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "tflite_custom_op_resolver_calculator_proto", srcs = ["tflite_custom_op_resolver_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "tflite_inference_calculator_proto", srcs = ["tflite_inference_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "tflite_converter_calculator_proto", srcs = ["tflite_converter_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "tflite_tensors_to_segmentation_calculator_proto", srcs = ["tflite_tensors_to_segmentation_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "tflite_tensors_to_detections_calculator_proto", srcs = ["tflite_tensors_to_detections_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "tflite_tensors_to_classification_calculator_proto", srcs = ["tflite_tensors_to_classification_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "tflite_tensors_to_landmarks_calculator_proto", srcs = ["tflite_tensors_to_landmarks_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "ssd_anchors_calculator_cc_proto", - srcs = ["ssd_anchors_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":ssd_anchors_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tflite_custom_op_resolver_calculator_cc_proto", - srcs = ["tflite_custom_op_resolver_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = 
[":tflite_custom_op_resolver_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tflite_converter_calculator_cc_proto", - srcs = ["tflite_converter_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":tflite_converter_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tflite_tensors_to_segmentation_calculator_cc_proto", - srcs = ["tflite_tensors_to_segmentation_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":tflite_tensors_to_segmentation_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tflite_inference_calculator_cc_proto", - srcs = ["tflite_inference_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":tflite_inference_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tflite_tensors_to_detections_calculator_cc_proto", - srcs = ["tflite_tensors_to_detections_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":tflite_tensors_to_detections_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tflite_tensors_to_classification_calculator_cc_proto", - srcs = ["tflite_tensors_to_classification_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":tflite_tensors_to_classification_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tflite_tensors_to_landmarks_calculator_cc_proto", - srcs = ["tflite_tensors_to_landmarks_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":tflite_tensors_to_landmarks_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_library( @@ -183,16 +143,15 @@ cc_test( data = [":anchor_golden_files"], deps = [ ":ssd_anchors_calculator", - ":ssd_anchors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/tool:validate_type", ], ) @@ -232,8 +191,8 @@ cc_library( ":tflite_inference_calculator_cc_proto", "@com_google_absl//absl/memory", "//mediapipe/framework:calculator_framework", - "//mediapipe/util:resource_util", "//mediapipe/util/tflite:config", + "//mediapipe/util/tflite:tflite_model_loader", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", @@ -453,6 +412,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_classification_calculator_cc_proto", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/formats:classification_cc_proto", @@ -508,7 +468,10 @@ cc_library( # bazel test //mediapipe/calculators/tflite:tflite_inference_calculator_test --copt=-DTFLITE_GPU_EXTRA_GLES_DEPS 
--copt=-DMESA_EGL_NO_X11_HEADERS --copt=-DEGL_NO_X11 --config=grte_v5 --test_strategy=local cc_test( name = "tflite_inference_calculator_test", - srcs = ["tflite_inference_calculator_test.cc"], + srcs = [ + "tflite_inference_calculator_test.cc", + "tflite_inference_calculator_test_common.h", + ], data = ["testdata/add.bin"], linkstatic = 1, deps = [ @@ -528,7 +491,9 @@ cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:type_to_tflitetype", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", ], ) diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc index 90d35573e..f618b2f6a 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc @@ -71,12 +71,12 @@ float CalculateScale(float min_scale, float max_scale, int stride_index, // } class SsdAnchorsCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->OutputSidePackets().Index(0).Set>(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const SsdAnchorsCalculatorOptions& options = @@ -85,24 +85,24 @@ class SsdAnchorsCalculator : public CalculatorBase { auto anchors = absl::make_unique>(); MP_RETURN_IF_ERROR(GenerateAnchors(anchors.get(), options)); cc->OutputSidePackets().Index(0).Set(Adopt(anchors.release())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } private: - static ::mediapipe::Status GenerateAnchors( + static absl::Status GenerateAnchors( std::vector* anchors, const SsdAnchorsCalculatorOptions& options); }; REGISTER_CALCULATOR(SsdAnchorsCalculator); -::mediapipe::Status SsdAnchorsCalculator::GenerateAnchors( +absl::Status SsdAnchorsCalculator::GenerateAnchors( std::vector* anchors, const SsdAnchorsCalculatorOptions& options) { // Verify the options. if (!options.feature_map_height_size() && !options.strides_size()) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Both feature map shape and strides are missing. 
Must provide either " "one."); } @@ -206,7 +206,7 @@ REGISTER_CALCULATOR(SsdAnchorsCalculator); } layer_id = last_same_stride_layer; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc index 7a5b555db..906eeed21 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc @@ -16,6 +16,7 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/object_detection/anchor.pb.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/calculators/tflite/testdata/README.md b/mediapipe/calculators/tflite/testdata/README.md new file mode 100644 index 000000000..c0efdcf07 --- /dev/null +++ b/mediapipe/calculators/tflite/testdata/README.md @@ -0,0 +1,2 @@ +The model files add.bin, add_quantized.bin +(and corresponding metatada json files) come from tensorflow/lite/testdata/ diff --git a/mediapipe/calculators/tflite/testdata/add_quantized.bin b/mediapipe/calculators/tflite/testdata/add_quantized.bin new file mode 100644 index 000000000..07d48b93e Binary files /dev/null and b/mediapipe/calculators/tflite/testdata/add_quantized.bin differ diff --git a/mediapipe/calculators/tflite/testdata/add_quantized.json b/mediapipe/calculators/tflite/testdata/add_quantized.json new file mode 100644 index 000000000..f70ed8143 --- /dev/null +++ b/mediapipe/calculators/tflite/testdata/add_quantized.json @@ -0,0 +1,123 @@ +{ + version: 3, + operator_codes: [ + { + } + ], + subgraphs: [ + { + tensors: [ + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "add", + quantization: { + min: [ + 0.0 + ], + max: [ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + type: "UINT8", + name: "input", + quantization: { + min: [ + 0.0 + ], + max: [ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + type: "UINT8", + name: "output", + quantization: { + min: [ + 0.0 + ], + max: [ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + } + ], + inputs: [ + 1 + ], + outputs: [ + 2 + ], + operators: [ + { + inputs: [ + 1, + 1 + ], + outputs: [ + 0 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + }, + { + inputs: [ + 0, + 1 + ], + outputs: [ + 2 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + } + ] + } + ], + buffers: [ + { + data: [ + + ] + } + ] +} diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc index 75e4c0b7b..202d9ed84 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc @@ -26,9 +26,9 @@ #include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/interpreter.h" -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" -#endif // MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_TFLITE_GL_INFERENCE #include "mediapipe/gpu/gl_calculator_helper.h" @@ -135,22 +135,21 @@ struct GPUData { // class TfLiteConverterCalculator : public CalculatorBase { 
public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status InitGpu(CalculatorContext* cc); - ::mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); template - ::mediapipe::Status NormalizeImage(const ImageFrame& image_frame, - bool flip_vertically, float* tensor_ptr); - ::mediapipe::Status CopyMatrixToTensor(const Matrix& matrix, - float* tensor_ptr); - ::mediapipe::Status ProcessCPU(CalculatorContext* cc); - ::mediapipe::Status ProcessGPU(CalculatorContext* cc); + absl::Status NormalizeImage(const ImageFrame& image_frame, + bool flip_vertically, float* tensor_ptr); + absl::Status CopyMatrixToTensor(const Matrix& matrix, float* tensor_ptr); + absl::Status ProcessCPU(CalculatorContext* cc); + absl::Status ProcessGPU(CalculatorContext* cc); std::unique_ptr interpreter_ = nullptr; @@ -184,8 +183,7 @@ bool ShouldUseGpu(CC* cc) { } } // namespace -::mediapipe::Status TfLiteConverterCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TfLiteConverterCalculator::GetContract(CalculatorContract* cc) { // Confirm only one of the input streams is present. RET_CHECK(cc->Inputs().HasTag(kImageFrameTag) ^ cc->Inputs().HasTag(kGpuBufferTag) ^ @@ -201,11 +199,11 @@ bool ShouldUseGpu(CC* cc) { if (cc->Inputs().HasTag(kMatrixTag)) { cc->Inputs().Tag(kMatrixTag).Set(); } -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set(); } -#endif // MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kTensorsTag)) { cc->Outputs().Tag(kTensorsTag).Set>(); @@ -225,10 +223,10 @@ bool ShouldUseGpu(CC* cc) { // Assign this calculator's default InputStreamHandler. cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); MP_RETURN_IF_ERROR(LoadOptions(cc)); @@ -253,13 +251,13 @@ bool ShouldUseGpu(CC* cc) { interpreter_->SetInputs({0}); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { if (cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (!initialized_) { MP_RETURN_IF_ERROR(InitGpu(cc)); @@ -271,24 +269,23 @@ bool ShouldUseGpu(CC* cc) { // Convert to CPU tensors or Matrix type. 
MP_RETURN_IF_ERROR(ProcessCPU(cc)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { interpreter_.reset(); #if MEDIAPIPE_TFLITE_GL_INFERENCE gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE gpu_data_out_.reset(); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::ProcessCPU( - CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::ProcessCPU(CalculatorContext* cc) { if (cc->Inputs().HasTag(kImageFrameTag)) { if (cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // CPU ImageFrame to TfLiteTensor conversion. @@ -364,7 +361,7 @@ bool ShouldUseGpu(CC* cc) { MP_RETURN_IF_ERROR(NormalizeImage(image_frame, flip_vertically_, tensor_buffer)); } else { - return ::mediapipe::InternalError( + return absl::InternalError( "Only byte-based (8 bit) and float (32 bit) images supported."); } } @@ -377,7 +374,7 @@ bool ShouldUseGpu(CC* cc) { .Add(output_tensors.release(), cc->InputTimestamp()); } else if (cc->Inputs().HasTag(kMatrixTag)) { if (cc->Inputs().Tag(kMatrixTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // CPU Matrix to TfLiteTensor conversion. const auto& matrix = cc->Inputs().Tag(kMatrixTag).Get(); @@ -410,17 +407,16 @@ bool ShouldUseGpu(CC* cc) { .Add(output_tensors.release(), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::ProcessGPU( - CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::ProcessGPU(CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE // GpuBuffer to tflite::gpu::GlBuffer conversion. const auto& input = cc->Inputs().Tag(kGpuBufferTag).Get(); MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &input]() -> ::mediapipe::Status { + gpu_helper_.RunInGlContext([this, &input]() -> absl::Status { // Convert GL texture into TfLite GlBuffer (SSBO). auto src = gpu_helper_.CreateSourceTexture(input); glActiveTexture(GL_TEXTURE0 + 0); @@ -433,13 +429,13 @@ bool ShouldUseGpu(CC* cc) { glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); glBindTexture(GL_TEXTURE_2D, 0); src.Release(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); // Copy into outputs. auto output_tensors = absl::make_unique>(); - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &output_tensors]() -> ::mediapipe::Status { + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, &output_tensors]() -> absl::Status { output_tensors->resize(1); { // Thuan (2020-04-14: Fix bug output video not stable) @@ -449,7 +445,7 @@ bool ShouldUseGpu(CC* cc) { gpu_data_out_->elements, &tensor)); MP_RETURN_IF_ERROR(CopyBuffer(gpu_data_out_->buffer, tensor)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); cc->Outputs() .Tag(kTensorsGpuTag) @@ -495,10 +491,10 @@ bool ShouldUseGpu(CC* cc) { RET_CHECK_FAIL() << "GPU processing is not enabled."; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GPU_SUPPORTED // Get input image sizes. 
const auto& input = @@ -519,7 +515,7 @@ bool ShouldUseGpu(CC* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status { + [this, &include_alpha, &input, &single_channel]() -> absl::Status { // Device memory. MP_RETURN_IF_ERROR( ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( @@ -565,7 +561,7 @@ bool ShouldUseGpu(CC* cc) { GL_COMPUTE_SHADER, shader_source, &gpu_data_out_->shader)); MP_RETURN_IF_ERROR(GlProgram::CreateWithShader( gpu_data_out_->shader, &gpu_data_out_->program)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE @@ -632,11 +628,10 @@ bool ShouldUseGpu(CC* cc) { << [[error localizedDescription] UTF8String]; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::LoadOptions( - CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::LoadOptions(CalculatorContext* cc) { // Get calculator options specified in the graph. const auto& options = cc->Options<::mediapipe::TfLiteConverterCalculatorOptions>(); @@ -684,11 +679,11 @@ bool ShouldUseGpu(CC* cc) { // Get tensor type, float or quantized. use_quantized_tensors_ = options.use_quantized_tensors(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } template -::mediapipe::Status TfLiteConverterCalculator::NormalizeImage( +absl::Status TfLiteConverterCalculator::NormalizeImage( const ImageFrame& image_frame, bool flip_vertically, float* tensor_ptr) { const int height = image_frame.Height(); const int width = image_frame.Width(); @@ -732,11 +727,11 @@ template } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::CopyMatrixToTensor( - const Matrix& matrix, float* tensor_ptr) { +absl::Status TfLiteConverterCalculator::CopyMatrixToTensor(const Matrix& matrix, + float* tensor_ptr) { if (row_major_matrix_) { auto matrix_map = Eigen::Map(tensor_ptr, matrix.rows(), matrix.cols()); @@ -747,7 +742,7 @@ template matrix_map = matrix; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc b/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc index 69793c199..c99438f4e 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc @@ -86,7 +86,7 @@ TEST_F(TfLiteConverterCalculatorTest, RandomMatrixColMajor) { // Run the calculator and verify that one output is generated. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie(R"( + mediapipe::ParseTextProtoOrDie(R"( input_stream: "matrix" node { calculator: "TfLiteConverterCalculator" @@ -147,7 +147,7 @@ TEST_F(TfLiteConverterCalculatorTest, RandomMatrixRowMajor) { // Run the calculator and verify that one output is generated. CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie(R"( + mediapipe::ParseTextProtoOrDie(R"( input_stream: "matrix" node { calculator: "TfLiteConverterCalculator" @@ -205,7 +205,7 @@ TEST_F(TfLiteConverterCalculatorTest, CustomDivAndSub) { CalculatorGraph graph; // Run the calculator and verify that one output is generated. 
CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie(R"( + mediapipe::ParseTextProtoOrDie(R"( input_stream: "input_image" node { calculator: "TfLiteConverterCalculator" @@ -228,7 +228,7 @@ TEST_F(TfLiteConverterCalculatorTest, CustomDivAndSub) { MP_ASSERT_OK(graph.Initialize(graph_config)); MP_ASSERT_OK(graph.StartRun({})); auto input_image = absl::make_unique(ImageFormat::GRAY8, 1, 1); - cv::Mat mat = ::mediapipe::formats::MatView(input_image.get()); + cv::Mat mat = mediapipe::formats::MatView(input_image.get()); mat.at(0, 0) = 200; MP_ASSERT_OK(graph.AddPacketToInputStream( "input_image", Adopt(input_image.release()).At(Timestamp(0)))); @@ -258,7 +258,7 @@ TEST_F(TfLiteConverterCalculatorTest, SetOutputRange) { for (std::pair range : range_values) { CalculatorGraph graph; CalculatorGraphConfig graph_config = - ::mediapipe::ParseTextProtoOrDie( + mediapipe::ParseTextProtoOrDie( absl::Substitute(R"( input_stream: "input_image" node { @@ -284,7 +284,7 @@ TEST_F(TfLiteConverterCalculatorTest, SetOutputRange) { MP_ASSERT_OK(graph.Initialize(graph_config)); MP_ASSERT_OK(graph.StartRun({})); auto input_image = absl::make_unique(ImageFormat::GRAY8, 1, 1); - cv::Mat mat = ::mediapipe::formats::MatView(input_image.get()); + cv::Mat mat = mediapipe::formats::MatView(input_image.get()); mat.at(0, 0) = 200; MP_ASSERT_OK(graph.AddPacketToInputStream( "input_image", Adopt(input_image.release()).At(Timestamp(0)))); diff --git a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc index bce1b6076..11e27dff1 100644 --- a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc @@ -39,14 +39,14 @@ namespace mediapipe { // } class TfLiteCustomOpResolverCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->OutputSidePackets() .Index(0) .Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const TfLiteCustomOpResolverCalculatorOptions& options = @@ -54,17 +54,17 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase { std::unique_ptr op_resolver; if (options.use_gpu()) { - op_resolver = absl::make_unique<::mediapipe::OpResolver>(); + op_resolver = absl::make_unique(); } else { - op_resolver = absl::make_unique<::mediapipe::CpuOpResolver>(); + op_resolver = absl::make_unique(); } cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; REGISTER_CALCULATOR(TfLiteCustomOpResolverCalculator); diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index 31078c053..2bd18d5e6 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -28,7 +28,7 @@ #include "mediapipe/util/cpu_util.h" #endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ -#include "mediapipe/util/resource_util.h" +#include 
"mediapipe/util/tflite/tflite_model_loader.h" #include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -215,37 +215,33 @@ class TfLiteInferenceCalculator : public CalculatorBase { public: using TfLiteDelegatePtr = std::unique_ptr>; - using TfLiteModelPtr = - std::unique_ptr>; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status ReadKernelsFromFile(); - ::mediapipe::Status WriteKernelsToFile(); - ::mediapipe::Status LoadModel(CalculatorContext* cc); - ::mediapipe::StatusOr GetModelAsPacket(const CalculatorContext& cc); - ::mediapipe::Status LoadDelegate(CalculatorContext* cc); - ::mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc); - ::mediapipe::Status ProcessInputsCpu( - CalculatorContext* cc, std::vector* output_tensors_cpu); - ::mediapipe::Status ProcessOutputsCpu( + absl::Status ReadKernelsFromFile(); + absl::Status WriteKernelsToFile(); + absl::Status LoadModel(CalculatorContext* cc); + absl::StatusOr GetModelAsPacket(const CalculatorContext& cc); + absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status InitTFLiteGPURunner(CalculatorContext* cc); + absl::Status ProcessInputsCpu(CalculatorContext* cc, + std::vector* output_tensors_cpu); + absl::Status ProcessOutputsCpu( CalculatorContext* cc, std::unique_ptr> output_tensors_cpu); - ::mediapipe::Status ProcessInputsGpu( - CalculatorContext* cc, std::vector* output_tensors_gpu); - ::mediapipe::Status ProcessOutputsGpu( + absl::Status ProcessInputsGpu(CalculatorContext* cc, + std::vector* output_tensors_gpu); + absl::Status ProcessOutputsGpu( CalculatorContext* cc, std::unique_ptr> output_tensors_cpu, std::unique_ptr> output_tensors_gpu); - ::mediapipe::Status RunInContextIfNeeded( - std::function<::mediapipe::Status(void)> f) { + absl::Status RunInContextIfNeeded(std::function f) { if (gpu_inference_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE return gpu_helper_.RunInGlContext(std::move(f)); @@ -282,6 +278,9 @@ class TfLiteInferenceCalculator : public CalculatorBase { bool use_quantized_tensors_ = false; bool use_advanced_gpu_api_ = false; + bool allow_precision_loss_ = false; + mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::API + tflite_gpu_runner_api_; bool use_kernel_caching_ = false; std::string cached_kernel_filename_; @@ -306,8 +305,7 @@ bool ShouldUseGpu(CC* cc) { } } // namespace -::mediapipe::Status TfLiteInferenceCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TfLiteInferenceCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kTensorsTag) ^ cc->Inputs().HasTag(kTensorsGpuTag)); RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^ @@ -349,10 +347,10 @@ bool ShouldUseGpu(CC* cc) { // Assign this calculator's default InputStreamHandler. 
cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { +absl::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -366,14 +364,17 @@ bool ShouldUseGpu(CC* cc) { options.has_delegate() && options.delegate().has_gpu() && options.delegate().gpu().use_advanced_gpu_api(); + allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); + tflite_gpu_runner_api_ = options.delegate().gpu().api(); - use_kernel_caching_ = - use_advanced_gpu_api_ && options.delegate().gpu().use_kernel_caching(); + use_kernel_caching_ = use_advanced_gpu_api_ && + options.delegate().gpu().has_cached_kernel_path(); if (use_kernel_caching_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) - cached_kernel_filename_ = - "/sdcard/" + mediapipe::File::Basename(options.model_path()) + ".ker"; + cached_kernel_filename_ = options.delegate().gpu().cached_kernel_path() + + mediapipe::File::Basename(options.model_path()) + + ".ker"; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID } @@ -389,29 +390,23 @@ bool ShouldUseGpu(CC* cc) { if (gpu_inference_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { - return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) - : LoadDelegate(cc); - })); + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, + &cc]() -> absl::Status { + return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) : LoadDelegate(cc); + })); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); MP_RETURN_IF_ERROR(LoadDelegate(cc)); #endif } else { - // TODO: why only on these platforms? - // It seems that the XNNPACK delegate fails to load on Linux. -#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID) || \ - defined(MEDIAPIPE_IOS) MP_RETURN_IF_ERROR(LoadDelegate(cc)); -#endif // __EMSCRIPTEN__ || MEDIAPIPE_ANDROID || MEDIAPIPE_IOS } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { - return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status { +absl::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { + return RunInContextIfNeeded([this, cc]() -> absl::Status { // 0. Declare outputs auto output_tensors_gpu = absl::make_unique>(); auto output_tensors_cpu = absl::make_unique>(); @@ -430,7 +425,20 @@ bool ShouldUseGpu(CC* cc) { } else { RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); } -#else +#elif MEDIAPIPE_TFLITE_METAL_INFERENCE + // Metal delegate supports external command buffer only if all input and + // output buffers are on GPU. 
+ if (gpu_inference_ && gpu_input_ && gpu_output_) { + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteInferenceCalculator"; + RET_CHECK( + TFLGpuDelegateSetCommandBuffer(delegate_.get(), command_buffer)); + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + [command_buffer commit]; + } else { + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + } +#else // MEDIAPIPE_TFLITE_GL_INFERENCE RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE @@ -442,11 +450,11 @@ bool ShouldUseGpu(CC* cc) { MP_RETURN_IF_ERROR(ProcessOutputsCpu(cc, std::move(output_tensors_cpu))); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); } -::mediapipe::Status TfLiteInferenceCalculator::WriteKernelsToFile() { +absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Save kernel file. @@ -457,13 +465,13 @@ bool ShouldUseGpu(CC* cc) { mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } #endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { +absl::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { MP_RETURN_IF_ERROR(WriteKernelsToFile()); - return RunInContextIfNeeded([this]() -> ::mediapipe::Status { + return RunInContextIfNeeded([this]() -> absl::Status { if (delegate_) { interpreter_ = nullptr; delegate_ = nullptr; @@ -481,16 +489,16 @@ bool ShouldUseGpu(CC* cc) { #if defined(MEDIAPIPE_EDGE_TPU) edgetpu_context_.reset(); #endif - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); } // Calculator Auxiliary Section -::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu( +absl::Status TfLiteInferenceCalculator::ProcessInputsCpu( CalculatorContext* cc, std::vector* output_tensors_cpu) { if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Read CPU input into tensors. const auto& input_tensors = @@ -512,13 +520,13 @@ bool ShouldUseGpu(CC* cc) { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu( +absl::Status TfLiteInferenceCalculator::ProcessInputsGpu( CalculatorContext* cc, std::vector* output_tensors_gpu) { if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (use_advanced_gpu_api_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE @@ -584,10 +592,10 @@ bool ShouldUseGpu(CC* cc) { #endif // MEDIAPIPE_TFLITE_GL_INFERENCE } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu( +absl::Status TfLiteInferenceCalculator::ProcessOutputsCpu( CalculatorContext* cc, std::unique_ptr> output_tensors_cpu) { // Output result tensors (CPU). 
@@ -604,10 +612,10 @@ bool ShouldUseGpu(CC* cc) { .Tag(kTensorsTag) .Add(output_tensors_cpu.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu( +absl::Status TfLiteInferenceCalculator::ProcessOutputsGpu( CalculatorContext* cc, std::unique_ptr> output_tensors_cpu, std::unique_ptr> output_tensors_gpu) { @@ -676,10 +684,10 @@ bool ShouldUseGpu(CC* cc) { #endif // MEDIAPIPE_TFLITE_GL_INFERENCE } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::ReadKernelsFromFile() { +absl::Status TfLiteInferenceCalculator::ReadKernelsFromFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Load pre-compiled kernel file. @@ -692,10 +700,10 @@ bool ShouldUseGpu(CC* cc) { } } #endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( +absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); @@ -709,11 +717,27 @@ bool ShouldUseGpu(CC* cc) { // Create runner tflite::gpu::InferenceOptions options; - options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY; + options.priority1 = allow_precision_loss_ + ? tflite::gpu::InferencePriority::MIN_LATENCY + : tflite::gpu::InferencePriority::MAX_PRECISION; options.priority2 = tflite::gpu::InferencePriority::AUTO; options.priority3 = tflite::gpu::InferencePriority::AUTO; options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; tflite_gpu_runner_ = std::make_unique(options); + switch (tflite_gpu_runner_api_) { + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::OPENGL: { + tflite_gpu_runner_->ForceOpenGL(); + break; + } + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::OPENCL: { + tflite_gpu_runner_->ForceOpenCL(); + break; + } + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::ANY: { + // Do not need to force any specific API. + break; + } + } MP_RETURN_IF_ERROR( tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); @@ -730,12 +754,11 @@ bool ShouldUseGpu(CC* cc) { quant.type = kTfLiteNoQuantization; quant.params = nullptr; for (int i = 0; i < num_outputs; ++i) { - auto shape = tflite_gpu_runner_->GetOutputShapes()[i]; + auto shape = tflite_gpu_runner_->GetTFLiteOutputShapes()[i]; const int tensor_idx = interpreter_->inputs()[i]; interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "", - {shape.c}, quant); - CHECK(interpreter_->ResizeInputTensor( - tensor_idx, {shape.h, shape.w, shape.c}) == kTfLiteOk); + shape, quant); + CHECK(interpreter_->ResizeInputTensor(tensor_idx, shape) == kTfLiteOk); } CHECK(interpreter_->AllocateTensors() == kTfLiteOk); } @@ -758,14 +781,13 @@ bool ShouldUseGpu(CC* cc) { MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteInferenceCalculator::LoadModel( - CalculatorContext* cc) { +absl::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) { if (use_advanced_gpu_api_) { // Use InitTFLiteGPURunner for everything. 
- return ::mediapipe::OkStatus(); + return absl::OkStatus(); } ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); @@ -804,39 +826,30 @@ bool ShouldUseGpu(CC* cc) { if (use_quantized_tensors_) gpu_inference_ = false; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::StatusOr TfLiteInferenceCalculator::GetModelAsPacket( +absl::StatusOr TfLiteInferenceCalculator::GetModelAsPacket( const CalculatorContext& cc) { const auto& options = cc.Options(); if (!options.model_path().empty()) { - std::string model_path = options.model_path(); - - ASSIGN_OR_RETURN(model_path, mediapipe::PathToResourceAsFile(model_path)); - - auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str()); - RET_CHECK(model) << "Failed to load model from path."; - return MakePacket(TfLiteModelPtr( - model.release(), [](tflite::FlatBufferModel* model) { delete model; })); + return TfLiteModelLoader::LoadFromPath(options.model_path()); } if (cc.InputSidePackets().HasTag("MODEL")) { return cc.InputSidePackets().Tag("MODEL"); } - return ::mediapipe::Status( - ::mediapipe::StatusCode::kNotFound, - "Must specify TFLite model as path or loaded model."); + return absl::Status(absl::StatusCode::kNotFound, + "Must specify TFLite model as path or loaded model."); } -::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( - CalculatorContext* cc) { +absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { const auto& calculator_opts = cc->Options(); if (calculator_opts.has_delegate() && calculator_opts.delegate().has_tflite()) { // Default tflite inference requeqsted - no need to modify graph. - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (!gpu_inference_) { @@ -855,30 +868,32 @@ bool ShouldUseGpu(CC* cc) { }); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } #endif // MEDIAPIPE_ANDROID #if defined(__EMSCRIPTEN__) - const bool xnnpack_requested = true; + const bool use_xnnpack = true; #else - const bool xnnpack_requested = calculator_opts.has_delegate() && - calculator_opts.delegate().has_xnnpack(); -#endif // __EMSCRIPTEN__ + const bool use_xnnpack = calculator_opts.has_delegate() && + calculator_opts.delegate().has_xnnpack(); +#endif // defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_EDGE_TPU) - if (xnnpack_requested) { + if (use_xnnpack) { TfLiteXNNPackDelegateOptions xnnpack_opts{}; xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts); delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), &TfLiteXNNPackDelegateDelete); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); + return absl::OkStatus(); } #endif // !EDGETPU - // Return, no need for GPU delegate below. - return ::mediapipe::OkStatus(); + // Return and use default tflite infernece (on CPU). No need for GPU + // delegate below. + return absl::OkStatus(); } #if MEDIAPIPE_TFLITE_GL_INFERENCE @@ -949,7 +964,7 @@ bool ShouldUseGpu(CC* cc) { // Configure and create the delegate. 
TFLGpuDelegateOptions options; options.allow_precision_loss = true; - options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive; + options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive; if (!delegate_) delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); @@ -1044,7 +1059,7 @@ bool ShouldUseGpu(CC* cc) { gpu_data_out_[i]->shape.c = tensor->dims->data[3]; break; default: - return mediapipe::InternalError("Unsupported tensor shape."); + return absl::InternalError("Unsupported tensor shape."); } } // Create and bind output buffers. @@ -1064,13 +1079,13 @@ bool ShouldUseGpu(CC* cc) { isFloat16:true convertToPBHWC4:false]; if (converter_from_BPHWC4_ == nil) { - return mediapipe::InternalError( + return absl::InternalError( "Error initializating output buffer converter"); } } #endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.proto b/mediapipe/calculators/tflite/tflite_inference_calculator.proto index bd83fea45..862de8b0b 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.proto @@ -49,9 +49,24 @@ message TfLiteInferenceCalculatorOptions { // delegate: { gpu { use_advanced_gpu_api: true } } optional bool use_advanced_gpu_api = 1 [default = false]; + // This option is valid for TFLite GPU delegate API2 only, + // Choose any of available APIs to force running inference using it. + enum API { + ANY = 0; + OPENGL = 1; + OPENCL = 2; + } + optional API api = 4 [default = ANY]; + + // This option is valid for TFLite GPU delegate API2 only, + // Set to true to use 16-bit float precision. If max precision is needed, + // set to false for 32-bit float calculations only. + optional bool allow_precision_loss = 3 [default = true]; + // Load pre-compiled serialized binary cache to accelerate init process. // Only available for OpenCL delegate on Android. - optional bool use_kernel_caching = 2 [default = false]; + // Kernel caching will only be enabled if this path is set. + optional string cached_kernel_path = 2; } // Android only. message Nnapi {} @@ -90,7 +105,10 @@ message TfLiteInferenceCalculatorOptions { optional int32 cpu_num_thread = 4 [default = -1]; // TfLite delegate to run inference. - // NOTE: calculator is free to choose delegate if not specified explicitly. + // If not specified, when any of the input and output is on GPU (i.e, using + // the TENSORS_GPU tag) TFLite GPU delegate is used (as if "gpu {}" is + // specified), or otherwise regular TFLite on CPU is used (as if "tflite {}" + // is specified) except when building with emscripten where xnnpack is used. // NOTE: use_gpu/use_nnapi are ignored if specified. (Delegate takes // precedence over use_* deprecated options.) optional Delegate delegate = 5; diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc b/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc index ac53c223d..ec16d1842 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc @@ -12,96 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include -#include -#include - #include "absl/strings/str_replace.h" -#include "absl/strings/string_view.h" -#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" -#include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/calculator_runner.h" -#include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/port/gmock.h" -#include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/parse_text_proto.h" -#include "mediapipe/framework/port/status_matchers.h" // NOLINT -#include "mediapipe/framework/tool/validate_type.h" -#include "tensorflow/lite/error_reporter.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" - -#ifdef __APPLE__ -#include -#endif // defined(__APPLE__) +#include "mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h" namespace mediapipe { -using ::tflite::Interpreter; - -void DoSmokeTest(const std::string& graph_proto) { - const int width = 8; - const int height = 8; - const int channels = 3; - - // Prepare input tensor. - std::unique_ptr interpreter(new Interpreter); - ASSERT_NE(interpreter, nullptr); - - interpreter->AddTensors(1); - interpreter->SetInputs({0}); - interpreter->SetOutputs({0}); - interpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, - TfLiteQuantization()); - int t = interpreter->inputs()[0]; - TfLiteTensor* tensor = interpreter->tensor(t); - interpreter->ResizeInputTensor(t, {width, height, channels}); - interpreter->AllocateTensors(); - - float* tensor_buffer = tensor->data.f; - ASSERT_NE(tensor_buffer, nullptr); - for (int i = 0; i < width * height * channels - 1; i++) { - tensor_buffer[i] = 1; - } - - auto input_vec = absl::make_unique>(); - input_vec->emplace_back(*tensor); - - // Prepare single calculator graph to and wait for packets. - CalculatorGraphConfig graph_config = - ParseTextProtoOrDie(graph_proto); - std::vector output_packets; - tool::AddVectorSink("tensor_out", &graph_config, &output_packets); - CalculatorGraph graph(graph_config); - MP_ASSERT_OK(graph.StartRun({})); - - // Push the tensor into the graph. - MP_ASSERT_OK(graph.AddPacketToInputStream( - "tensor_in", Adopt(input_vec.release()).At(Timestamp(0)))); - // Wait until the calculator done processing. - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(1, output_packets.size()); - - // Get and process results. - const std::vector & result_vec = - output_packets[0].Get>(); - ASSERT_EQ(1, result_vec.size()); - - const TfLiteTensor* result = &(result_vec[0].getTensor()); - float* result_buffer = result->data.f; - ASSERT_NE(result_buffer, nullptr); - for (int i = 0; i < width * height * channels - 1; i++) { - ASSERT_EQ(3, result_buffer[i]); - } - - // Fully close graph at end, otherwise calculator+tensors are destroyed - // after calling WaitUntilDone(). - MP_ASSERT_OK(graph.CloseInputStream("tensor_in")); - MP_ASSERT_OK(graph.WaitUntilDone()); -} - // Tests a simple add model that adds an input tensor to itself. TEST(TfLiteInferenceCalculatorTest, SmokeTest) { std::string graph_proto = R"( @@ -118,13 +33,12 @@ TEST(TfLiteInferenceCalculatorTest, SmokeTest) { } } )"; - DoSmokeTest( - /*graph_proto=*/absl::StrReplaceAll(graph_proto, {{"$delegate", ""}})); - DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + // Test CPU inference only. 
+ DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( graph_proto, {{"$delegate", "delegate { tflite {} }"}})); - DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + DoSmokeTest(absl::StrReplaceAll( graph_proto, {{"$delegate", "delegate { xnnpack {} }"}})); - DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + DoSmokeTest(absl::StrReplaceAll( graph_proto, {{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}})); } @@ -163,11 +77,12 @@ TEST(TfLiteInferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) { options { [mediapipe.TfLiteInferenceCalculatorOptions.ext] { use_gpu: false + delegate { tflite {} } } } } )"; - DoSmokeTest(graph_proto); + DoSmokeTest(graph_proto); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h b/mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h new file mode 100644 index 000000000..cf995f47b --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h @@ -0,0 +1,128 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TFLITE_TFLITE_INFERENCE_CALCULATOR_TEST_H_ +#define MEDIAPIPE_CALCULATORS_TFLITE_TFLITE_INFERENCE_CALCULATOR_TEST_H_ + +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" + +#ifdef __APPLE__ +#include +#endif // defined(__APPLE__) + +namespace mediapipe { + +using ::tflite::Interpreter; + +template +void DoSmokeTest(const std::string& graph_proto) { + const int width = 8; + const int height = 8; + const int channels = 3; + + static_assert(std::is_same_v || std::is_same_v, + "Only float & uint8 currently supported."); + + // Prepare interpreter and input tensor. 
+ std::unique_ptr interpreter(new Interpreter); + ASSERT_NE(interpreter, nullptr); + + interpreter->AddTensors(1); + interpreter->SetInputs({0}); + interpreter->SetOutputs({0}); + TfLiteQuantization quant; + if (std::is_integral_v) { + auto* affine_quant = static_cast( + malloc(sizeof(TfLiteAffineQuantization))); + affine_quant->scale = TfLiteFloatArrayCreate(1); + affine_quant->zero_point = TfLiteIntArrayCreate(1); + affine_quant->scale->data[0] = 1.0; + affine_quant->zero_point->data[0] = 0; + quant.type = kTfLiteAffineQuantization; + quant.params = affine_quant; + } + interpreter->SetTensorParametersReadWrite(0, tflite::typeToTfLiteType(), + "", {3}, quant); + + int t = interpreter->inputs()[0]; + TfLiteTensor* input_tensor = interpreter->tensor(t); + interpreter->ResizeInputTensor(t, {width, height, channels}); + interpreter->AllocateTensors(); + + T* input_tensor_buffer = tflite::GetTensorData(input_tensor); + ASSERT_NE(input_tensor_buffer, nullptr); + for (int i = 0; i < width * height * channels - 1; i++) { + input_tensor_buffer[i] = 1; + } + + auto input_vec = absl::make_unique>(); + input_vec->emplace_back(*input_tensor); + + // Prepare single calculator graph to and wait for packets. + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(graph_proto); + std::vector output_packets; + tool::AddVectorSink("tensor_out", &graph_config, &output_packets); + CalculatorGraph graph(graph_config); + MP_ASSERT_OK(graph.StartRun({})); + + // Push the tensor into the graph. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "tensor_in", Adopt(input_vec.release()).At(Timestamp(0)))); + // Wait until the calculator done processing. + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, output_packets.size()); + + // Get and process results. + const std::vector& result_vec = + output_packets[0].Get>(); + ASSERT_EQ(1, result_vec.size()); + + const TfLiteTensor* result = &result_vec[0]; + const T* result_buffer = tflite::GetTensorData(result); + ASSERT_NE(result_buffer, nullptr); + for (int i = 0; i < width * height * channels - 1; i++) { + ASSERT_EQ(3, result_buffer[i]); + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("tensor_in")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TFLITE_TFLITE_INFERENCE_CALCULATOR_TEST_H_ diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator_tpu_test.cc b/mediapipe/calculators/tflite/tflite_inference_calculator_tpu_test.cc new file mode 100644 index 000000000..eac0d361c --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_inference_calculator_tpu_test.cc @@ -0,0 +1,42 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/str_replace.h" +#include "mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h" + +namespace mediapipe { + +// Tests a simple add model that adds an input tensor to itself. 
+TEST(TfLiteInferenceCalculatorTpuTest, SmokeTest) { + std::string graph_proto = R"( + input_stream: "tensor_in" + node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:tensor_in" + output_stream: "TENSORS:tensor_out" + options { + [mediapipe.TfLiteInferenceCalculatorOptions.ext] { + model_path: "mediapipe/calculators/tflite/testdata/add_quantized.bin" + $delegate + } + } + } + )"; + DoSmokeTest( + /*graph_proto=*/absl::StrReplaceAll(graph_proto, {{"$delegate", ""}})); + DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + graph_proto, {{"$delegate", "delegate { tflite {} }"}})); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_model_calculator.cc b/mediapipe/calculators/tflite/tflite_model_calculator.cc index d24c55b14..ca28910e5 100644 --- a/mediapipe/calculators/tflite/tflite_model_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_model_calculator.cc @@ -51,13 +51,13 @@ class TfLiteModelCalculator : public CalculatorBase { std::unique_ptr>; - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("MODEL_BLOB").Set(); cc->OutputSidePackets().Tag("MODEL").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const Packet& model_packet = cc->InputSidePackets().Tag("MODEL_BLOB"); const std::string& model_blob = model_packet.Get(); std::unique_ptr model = @@ -74,11 +74,11 @@ class TfLiteModelCalculator : public CalculatorBase { delete model; }))); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; REGISTER_CALCULATOR(TfLiteModelCalculator); diff --git a/mediapipe/calculators/tflite/tflite_model_calculator_test.cc b/mediapipe/calculators/tflite/tflite_model_calculator_test.cc index fed3743a5..a76d322ee 100644 --- a/mediapipe/calculators/tflite/tflite_model_calculator_test.cc +++ b/mediapipe/calculators/tflite/tflite_model_calculator_test.cc @@ -58,7 +58,7 @@ TEST(TfLiteModelCalculatorTest, SmokeTest) { MP_ASSERT_OK(graph.WaitUntilIdle()); auto status_or_packet = graph.GetOutputSidePacket("model"); MP_ASSERT_OK(status_or_packet); - auto model_packet = status_or_packet.ValueOrDie(); + auto model_packet = status_or_packet.value(); const auto& model = model_packet.Get< std::unique_ptr>>(); diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc index e9c09169b..4d28b91e9 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc @@ -16,6 +16,7 @@ #include #include +#include "absl/container/node_hash_map.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.pb.h" @@ -59,21 +60,21 @@ namespace mediapipe { // } class TfLiteTensorsToClassificationCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status 
Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: ::mediapipe::TfLiteTensorsToClassificationCalculatorOptions options_; int top_k_ = 0; - std::unordered_map label_map_; + absl::node_hash_map label_map_; bool label_map_loaded_ = false; }; REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator); -::mediapipe::Status TfLiteTensorsToClassificationCalculator::GetContract( +absl::Status TfLiteTensorsToClassificationCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -86,10 +87,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator); cc->Outputs().Tag("CLASSIFICATIONS").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToClassificationCalculator::Open( +absl::Status TfLiteTensorsToClassificationCalculator::Open( CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); @@ -113,10 +114,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator); label_map_loaded_ = true; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToClassificationCalculator::Process( +absl::Status TfLiteTensorsToClassificationCalculator::Process( CalculatorContext* cc) { const auto& input_tensors = cc->Inputs().Tag("TENSORS").Get>(); @@ -189,12 +190,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator); .Tag("CLASSIFICATIONS") .Add(classification_list.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToClassificationCalculator::Close( +absl::Status TfLiteTensorsToClassificationCalculator::Close( CalculatorContext* cc) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator_test.cc b/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator_test.cc index a7290b112..ab66d3077 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator_test.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator_test.cc @@ -27,7 +27,7 @@ namespace mediapipe { -using ::mediapipe::ParseTextProtoOrDie; +using mediapipe::ParseTextProtoOrDie; using ::tflite::Interpreter; using Node = ::mediapipe::CalculatorGraphConfig::Node; diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc index 7d02e8c00..f29a9524a 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc @@ -147,26 +147,27 @@ void ConvertAnchorsToRawValues(const std::vector& anchors, // } class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + 
absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status ProcessCPU(CalculatorContext* cc, - std::vector* output_detections); - ::mediapipe::Status ProcessGPU(CalculatorContext* cc, - std::vector* output_detections); + absl::Status ProcessCPU(CalculatorContext* cc, + std::vector* output_detections); + absl::Status ProcessGPU(CalculatorContext* cc, + std::vector* output_detections); - ::mediapipe::Status LoadOptions(CalculatorContext* cc); - ::mediapipe::Status GpuInit(CalculatorContext* cc); - ::mediapipe::Status DecodeBoxes(const float* raw_boxes, - const std::vector& anchors, - std::vector* boxes); - ::mediapipe::Status ConvertToDetections( - const float* detection_boxes, const float* detection_scores, - const int* detection_classes, std::vector* output_detections); + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status GpuInit(CalculatorContext* cc); + absl::Status DecodeBoxes(const float* raw_boxes, + const std::vector& anchors, + std::vector* boxes); + absl::Status ConvertToDetections(const float* detection_boxes, + const float* detection_scores, + const int* detection_classes, + std::vector* output_detections); Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, int class_id, bool flip_vertically); @@ -193,7 +194,7 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GetContract( +absl::Status TfLiteTensorsToDetectionsCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -227,11 +228,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Open( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToDetectionsCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if (cc->Inputs().HasTag(kTensorsGpuTag)) { @@ -251,14 +251,14 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); MP_RETURN_IF_ERROR(GpuInit(cc)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Process( +absl::Status TfLiteTensorsToDetectionsCalculator::Process( CalculatorContext* cc) { if ((!gpu_input_ && cc->Inputs().Tag(kTensorsTag).IsEmpty()) || (gpu_input_ && cc->Inputs().Tag(kTensorsGpuTag).IsEmpty())) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } auto output_detections = absl::make_unique>(); @@ -276,10 +276,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); .Add(output_detections.release(), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU( +absl::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU( CalculatorContext* cc, std::vector* output_detections) { const auto& input_tensors = cc->Inputs().Tag(kTensorsTag).Get>(); @@ -324,7 +324,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); anchors_ = cc->InputSidePackets().Tag("ANCHORS").Get>(); } else { - return ::mediapipe::UnavailableError("No anchor data available."); + return absl::UnavailableError("No anchor data available."); } anchors_init_ = true; } @@ -401,9 +401,9 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); 
detection_classes.data(), output_detections)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( +absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( CalculatorContext* cc, std::vector* output_detections) { #if MEDIAPIPE_TFLITE_GL_INFERENCE const auto& input_tensors = @@ -412,7 +412,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &input_tensors, &cc, &output_detections]() - -> ::mediapipe::Status { + -> absl::Status { // Copy inputs. MP_RETURN_IF_ERROR( CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer)); @@ -469,7 +469,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); ConvertToDetections(boxes.data(), detection_scores.data(), detection_classes.data(), output_detections)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE @@ -554,21 +554,20 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); #else LOG(ERROR) << "GPU input on non-Android not supported yet."; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToDetectionsCalculator::Close(CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); }); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE gpu_data_.reset(); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::LoadOptions( +absl::Status TfLiteTensorsToDetectionsCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. options_ = @@ -590,10 +589,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); ignore_classes_.insert(options_.ignore_classes(i)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::DecodeBoxes( +absl::Status TfLiteTensorsToDetectionsCalculator::DecodeBoxes( const float* raw_boxes, const std::vector& anchors, std::vector* boxes) { for (int i = 0; i < num_boxes_; ++i) { @@ -654,10 +653,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections( +absl::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections( const float* detection_boxes, const float* detection_scores, const int* detection_classes, std::vector* output_detections) { for (int i = 0; i < num_boxes_; ++i) { @@ -670,6 +669,14 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); detection_boxes[box_offset + 0], detection_boxes[box_offset + 1], detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], detection_scores[i], detection_classes[i], options_.flip_vertically()); + const auto& bbox = detection.location_data().relative_bounding_box(); + if (bbox.width() < 0 || bbox.height() < 0) { + // Decoded detection boxes could have negative values for width/height due + // to model prediction. Filter out those boxes since some downstream + // calculators may assume non-negative values. (b/171391719) + continue; + } + // Add keypoints. 
if (options_.num_keypoints() > 0) { auto* location_data = detection.mutable_location_data(); @@ -687,7 +694,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); } output_detections->emplace_back(detection); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection( @@ -710,11 +717,10 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection( return detection; } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit( +absl::Status TfLiteTensorsToDetectionsCalculator::GpuInit( CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() - -> ::mediapipe::Status { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { gpu_data_ = absl::make_unique(); // A shader to decode detection boxes. @@ -922,7 +928,7 @@ void main() { MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( raw_scores_length, &gpu_data_->raw_scores_buffer)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE @@ -1157,7 +1163,7 @@ kernel void scoreKernel( #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc index 72dd60a0b..ef2946c32 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc @@ -38,15 +38,15 @@ namespace mediapipe { // } class TfLiteTensorsToFloatsCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator); -::mediapipe::Status TfLiteTensorsToFloatsCalculator::GetContract( +absl::Status TfLiteTensorsToFloatsCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("TENSORS")); RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT")); @@ -59,18 +59,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator); cc->Outputs().Tag("FLOAT").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToFloatsCalculator::Open( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToFloatsCalculator::Process( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) { RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty()); const auto& input_tensors = @@ -98,6 +96,6 @@ REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator); cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc index bc911efb6..1be83bbe1 100644 --- 
a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc @@ -20,10 +20,30 @@ namespace mediapipe { +namespace { + +inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); } + +float ApplyActivation( + ::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions::Activation + activation, + float value) { + switch (activation) { + case ::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions::SIGMOID: + return Sigmoid(value); + break; + default: + return value; + } +} + +} // namespace + // A calculator for converting TFLite tensors from regression models into -// landmarks. Note that if the landmarks in the tensor has more than 4 -// dimensions, only the first 4 dimensions will be converted to -// [x,y,z, visibility]. +// landmarks. Note that if the landmarks in the tensor has more than 5 +// dimensions, only the first 5 dimensions will be converted to +// [x,y,z, visibility, presence]. The latter two fields may also stay unset if +// such attributes are not supported in the model. // // Input: // TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first @@ -69,13 +89,13 @@ namespace mediapipe { // } class TfLiteTensorsToLandmarksCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - ::mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); int num_landmarks_ = 0; bool flip_vertically_ = false; bool flip_horizontally_ = false; @@ -84,7 +104,7 @@ class TfLiteTensorsToLandmarksCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); -::mediapipe::Status TfLiteTensorsToLandmarksCalculator::GetContract( +absl::Status TfLiteTensorsToLandmarksCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -117,11 +137,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); cc->Outputs().Tag("NORM_LANDMARKS").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToLandmarksCalculator::Open( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToLandmarksCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); MP_RETURN_IF_ERROR(LoadOptions(cc)); @@ -129,7 +148,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); if (cc->Outputs().HasTag("NORM_LANDMARKS")) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) - << "Must provide input with/height for getting normalized landmarks."; + << "Must provide input width/height for getting normalized landmarks."; } if (cc->Outputs().HasTag("LANDMARKS") && (options_.flip_vertically() || options_.flip_horizontally() || @@ -137,7 +156,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); cc->InputSidePackets().HasTag("FLIP_VERTICALLY"))) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) - << "Must provide input with/height for using flip_vertically option " + << "Must provide input width/height for using flip_vertically option " "when outputing landmarks in absolute 
coordinates."; } @@ -151,10 +170,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); ? cc->InputSidePackets().Tag("FLIP_VERTICALLY").Get() : options_.flip_vertically(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToLandmarksCalculator::Process( +absl::Status TfLiteTensorsToLandmarksCalculator::Process( CalculatorContext* cc) { // Override values if specified so. if (cc->Inputs().HasTag("FLIP_HORIZONTALLY") && @@ -167,7 +186,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); } if (cc->Inputs().Tag("TENSORS").IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_tensors = @@ -207,7 +226,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); landmark->set_z(raw_landmarks[offset + 2]); } if (num_dimensions > 3) { - landmark->set_visibility(raw_landmarks[offset + 3]); + landmark->set_visibility(ApplyActivation(options_.visibility_activation(), + raw_landmarks[offset + 3])); + } + if (num_dimensions > 4) { + landmark->set_presence(ApplyActivation(options_.presence_activation(), + raw_landmarks[offset + 4])); } } @@ -222,7 +246,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); // Scale Z coordinate as X + allow additional uniform normalization. norm_landmark->set_z(landmark.z() / options_.input_image_width() / options_.normalize_z()); - norm_landmark->set_visibility(landmark.visibility()); + if (landmark.has_visibility()) { // Set only if supported in the model. + norm_landmark->set_visibility(landmark.visibility()); + } + if (landmark.has_presence()) { // Set only if supported in the model. + norm_landmark->set_presence(landmark.presence()); + } } cc->Outputs() .Tag("NORM_LANDMARKS") @@ -238,16 +267,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); .At(cc->InputTimestamp())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToLandmarksCalculator::LoadOptions( +absl::Status TfLiteTensorsToLandmarksCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. options_ = cc->Options<::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions>(); num_landmarks_ = options_.num_landmarks(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto index cbf30c181..793639a53 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto @@ -25,6 +25,11 @@ message TfLiteTensorsToLandmarksCalculatorOptions { optional TfLiteTensorsToLandmarksCalculatorOptions ext = 257405002; } + enum Activation { + NONE = 0; + SIGMOID = 1; + } + // Number of landmarks from the output of the model. required int32 num_landmarks = 1; @@ -51,4 +56,10 @@ message TfLiteTensorsToLandmarksCalculatorOptions { // when normalized landmarks are needed. It is applied in addition to Z // coordinate being re-scaled as X. optional float normalize_z = 5 [default = 1.0]; + + // Apply activation function to the tensor representing landmark visibility. + optional Activation visibility_activation = 7 [default = NONE]; + + // Apply activation function to the tensor representing landmark presence. 
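With `visibility_activation` or `presence_activation` set to `SIGMOID`, the raw tensor values are treated as logits and `ApplyActivation` above maps them into [0, 1] before they are written to the landmark. A small self-contained sketch of that mapping; the enum here is a stand-in for the proto enum, and the options plumbing through `CalculatorOptions` is omitted:

```cpp
#include <cmath>

enum class Activation { kNone, kSigmoid };  // Mirrors the proto's NONE/SIGMOID.

inline float Sigmoid(float v) { return 1.0f / (1.0f + std::exp(-v)); }

// Raw tensor value -> visibility/presence score, as ApplyActivation does above.
inline float ApplyActivation(Activation a, float v) {
  return a == Activation::kSigmoid ? Sigmoid(v) : v;
}

// Example: a raw logit of 2.0 becomes ~0.88 with kSigmoid and stays 2.0 with kNone.
```

In a graph these are configured through `TfLiteTensorsToLandmarksCalculatorOptions`, e.g. `visibility_activation: SIGMOID`, matching the proto fields added below.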
+ optional Activation presence_activation = 8 [default = NONE]; } diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc index 3369840e4..ec4945201 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc @@ -36,7 +36,7 @@ #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { constexpr int kWorkgroupSize = 8; // Block size for GPU shader. @@ -69,7 +69,7 @@ using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlShader; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU // Converts TFLite tensors from a tflite segmentation model to an image mask. // @@ -121,17 +121,17 @@ using ::tflite::gpu::gl::GlShader; // class TfLiteTensorsToSegmentationCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status LoadOptions(CalculatorContext* cc); - ::mediapipe::Status InitGpu(CalculatorContext* cc); - ::mediapipe::Status ProcessGpu(CalculatorContext* cc); - ::mediapipe::Status ProcessCpu(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status ProcessGpu(CalculatorContext* cc); + absl::Status ProcessCpu(CalculatorContext* cc); void GlRender(); ::mediapipe::TfLiteTensorsToSegmentationCalculatorOptions options_; @@ -147,12 +147,12 @@ class TfLiteTensorsToSegmentationCalculator : public CalculatorBase { std::unique_ptr mask_program_no_prev_; std::unique_ptr tensor_buffer_; GLuint upsample_program_; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); // static -::mediapipe::Status TfLiteTensorsToSegmentationCalculator::GetContract( +absl::Status TfLiteTensorsToSegmentationCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -184,7 +184,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); cc->Inputs().Tag(kSizeImageGpuTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU // Outputs. 
if (cc->Outputs().HasTag(kMaskTag)) { @@ -195,17 +195,17 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); cc->Outputs().Tag(kMaskGpuTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (use_gpu) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Open( +absl::Status TfLiteTensorsToSegmentationCalculator::Open( CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); @@ -213,44 +213,42 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); use_gpu_ = true; #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } MP_RETURN_IF_ERROR(LoadOptions(cc)); if (use_gpu_) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { - MP_RETURN_IF_ERROR(InitGpu(cc)); - return ::mediapipe::OkStatus(); - })); + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + MP_RETURN_IF_ERROR(InitGpu(cc)); + return absl::OkStatus(); + })); #else RET_CHECK_FAIL() << "GPU processing not enabled."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process( +absl::Status TfLiteTensorsToSegmentationCalculator::Process( CalculatorContext* cc) { if (use_gpu_) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { - MP_RETURN_IF_ERROR(ProcessGpu(cc)); - return ::mediapipe::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + MP_RETURN_IF_ERROR(ProcessGpu(cc)); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(ProcessCpu(cc)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close( +absl::Status TfLiteTensorsToSegmentationCalculator::Close( CalculatorContext* cc) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) gpu_helper_.RunInGlContext([this] { @@ -260,15 +258,15 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); mask_program_no_prev_.reset(); tensor_buffer_.reset(); }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu( +absl::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu( CalculatorContext* cc) { if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Get input streams. @@ -366,17 +364,17 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); large_mask_mat.copyTo(output_mat); cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Steps: // 1. receive tensor and optional previous mask // 2. process segmentation tensor into small mask // 3. 
upsample small mask into output mask to be same size as input image -::mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu( +absl::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu( CalculatorContext* cc) { if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) // Get input streams. @@ -458,9 +456,9 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); // Cleanup input_mask_texture.Release(); output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void TfLiteTensorsToSegmentationCalculator::GlRender() { @@ -512,10 +510,10 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() { glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -::mediapipe::Status TfLiteTensorsToSegmentationCalculator::LoadOptions( +absl::Status TfLiteTensorsToSegmentationCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. options_ = @@ -531,14 +529,13 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() { RET_CHECK_EQ(tensor_channels_, 2) << "Only 2 channel segmentation tensor currently supported"; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu( +absl::Status TfLiteTensorsToSegmentationCalculator::InitGpu( CalculatorContext* cc) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() - -> ::mediapipe::Status { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { // A shader to process a segmentation tensor into an output mask, // and use an optional previous mask as input. // Currently uses 4 channels for output, @@ -698,11 +695,11 @@ void main() { glUseProgram(upsample_program_); glUniform1i(glGetUniformLocation(upsample_program_, "input_data"), 1); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 9e327511f..df6d5c6d6 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
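For reference on the segmentation calculator above: `LoadOptions` enforces a two-channel tensor (per-pixel background/foreground scores), and the CPU path turns it into a mask before upsampling. A rough per-pixel sketch of one plausible conversion, a softmax over the two channels; the HxWx2 layout, the handling of the previous mask, and the output format are assumptions, not taken from this diff:

```cpp
#include <cmath>
#include <vector>

// Hypothetical: raw tensor laid out HxWx2 (background, foreground), row-major.
// Returns an HxW foreground-probability mask in [0, 1].
std::vector<float> TwoChannelTensorToMask(const float* tensor,
                                          int height, int width) {
  std::vector<float> mask(height * width);
  for (int i = 0; i < height * width; ++i) {
    const float bg = tensor[i * 2 + 0];
    const float fg = tensor[i * 2 + 1];
    // Softmax over two channels reduces to a sigmoid of their difference.
    mask[i] = 1.0f / (1.0f + std::exp(bg - fg));
  }
  return mask;
}
```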
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -exports_files(["LICENSE"]) - cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], @@ -38,186 +36,95 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "annotation_overlay_calculator_proto", srcs = ["annotation_overlay_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", ], ) -proto_library( +mediapipe_proto_library( name = "detection_label_id_to_text_calculator_proto", srcs = ["detection_label_id_to_text_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "timed_box_list_id_to_label_calculator_proto", srcs = ["timed_box_list_id_to_label_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "latency_proto", srcs = ["latency.proto"], ) -proto_library( +mediapipe_proto_library( name = "non_max_suppression_calculator_proto", srcs = ["non_max_suppression_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "packet_frequency_proto", srcs = ["packet_frequency.proto"], ) -proto_library( +mediapipe_proto_library( name = "packet_frequency_calculator_proto", srcs = ["packet_frequency_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "packet_latency_calculator_proto", srcs = ["packet_latency_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "collection_has_min_size_calculator_proto", srcs = ["collection_has_min_size_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "association_calculator_proto", srcs = ["association_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "annotation_overlay_calculator_cc_proto", - srcs = ["annotation_overlay_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":annotation_overlay_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "detection_label_id_to_text_calculator_cc_proto", - srcs = ["detection_label_id_to_text_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [ - 
":detection_label_id_to_text_calculator_proto", - ], -) - -mediapipe_cc_proto_library( - name = "timed_box_list_id_to_label_calculator_cc_proto", - srcs = ["timed_box_list_id_to_label_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [ - ":timed_box_list_id_to_label_calculator_proto", - ], -) - -mediapipe_cc_proto_library( - name = "latency_cc_proto", - srcs = ["latency.proto"], - visibility = ["//mediapipe:__subpackages__"], - deps = [":latency_proto"], -) - -mediapipe_cc_proto_library( - name = "non_max_suppression_calculator_cc_proto", - srcs = ["non_max_suppression_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":non_max_suppression_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "packet_frequency_cc_proto", - srcs = ["packet_frequency.proto"], - visibility = ["//mediapipe:__subpackages__"], - deps = [":packet_frequency_proto"], -) - -mediapipe_cc_proto_library( - name = "packet_frequency_calculator_cc_proto", - srcs = ["packet_frequency_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//mediapipe:__subpackages__"], - deps = [ - ":packet_frequency_calculator_proto", - ], -) - -mediapipe_cc_proto_library( - name = "packet_latency_calculator_cc_proto", - srcs = ["packet_latency_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//mediapipe:__subpackages__"], - deps = [ - ":packet_latency_calculator_proto", - ], -) - -mediapipe_cc_proto_library( - name = "collection_has_min_size_calculator_cc_proto", - srcs = ["collection_has_min_size_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//mediapipe:__subpackages__"], - deps = [":collection_has_min_size_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "association_calculator_cc_proto", - srcs = ["association_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//mediapipe:__subpackages__"], - deps = [":association_calculator_proto"], -) - cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], @@ -341,9 +248,11 @@ cc_library( "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", "//mediapipe/util:annotation_renderer", @@ -367,6 +276,7 @@ cc_library( deps = [ ":detection_label_id_to_text_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", + "@com_google_absl//absl/container:node_hash_map", "//mediapipe/framework/port:status", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -434,16 +344,6 @@ cc_library( alwayslink = 1, ) -mediapipe_cc_proto_library( - name = "thresholding_calculator_cc_proto", - srcs = ["thresholding_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":thresholding_calculator_proto"], -) - cc_library( name = "thresholding_calculator", srcs = ["thresholding_calculator.cc"], @@ -457,14 
+357,19 @@ cc_library( alwayslink = 1, ) -mediapipe_cc_proto_library( - name = "landmarks_to_detection_calculator_cc_proto", - srcs = ["landmarks_to_detection_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], +cc_library( + name = "detection_to_landmarks_calculator", + srcs = ["detection_to_landmarks_calculator.cc"], visibility = ["//visibility:public"], - deps = [":landmarks_to_detection_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, ) cc_library( @@ -483,46 +388,6 @@ cc_library( alwayslink = 1, ) -mediapipe_cc_proto_library( - name = "detections_to_rects_calculator_cc_proto", - srcs = ["detections_to_rects_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":detections_to_rects_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "landmark_projection_calculator_cc_proto", - srcs = ["landmark_projection_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":landmark_projection_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "landmarks_to_floats_calculator_cc_proto", - srcs = ["landmarks_to_floats_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":landmarks_to_floats_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "rect_transformation_calculator_cc_proto", - srcs = ["rect_transformation_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":rect_transformation_calculator_proto"], -) - cc_library( name = "detections_to_rects_calculator", srcs = [ @@ -593,162 +458,164 @@ cc_test( ], ) -proto_library( +mediapipe_proto_library( name = "rect_to_render_data_calculator_proto", srcs = ["rect_to_render_data_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -proto_library( +mediapipe_proto_library( name = "rect_to_render_scale_calculator_proto", srcs = ["rect_to_render_scale_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "detections_to_render_data_calculator_proto", srcs = ["detections_to_render_data_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -proto_library( +mediapipe_proto_library( name = "landmarks_to_render_data_calculator_proto", srcs = ["landmarks_to_render_data_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -proto_library( +mediapipe_proto_library( name = 
"timed_box_list_to_render_data_calculator_proto", srcs = ["timed_box_list_to_render_data_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -proto_library( +mediapipe_proto_library( name = "labels_to_render_data_calculator_proto", srcs = ["labels_to_render_data_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -proto_library( +mediapipe_proto_library( name = "thresholding_calculator_proto", srcs = ["thresholding_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -proto_library( +mediapipe_proto_library( name = "detections_to_rects_calculator_proto", srcs = ["detections_to_rects_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "landmark_projection_calculator_proto", srcs = ["landmark_projection_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +cc_library( + name = "landmark_visibility_calculator", + srcs = ["landmark_visibility_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +cc_library( + name = "set_landmark_visibility_calculator", + srcs = ["set_landmark_visibility_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +mediapipe_proto_library( name = "landmarks_to_floats_calculator_proto", srcs = ["landmarks_to_floats_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "rect_transformation_calculator_proto", srcs = ["rect_transformation_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -proto_library( +mediapipe_proto_library( name = "landmarks_to_detection_calculator_proto", srcs = ["landmarks_to_detection_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -mediapipe_cc_proto_library( - name = "rect_to_render_data_calculator_cc_proto", - srcs = ["rect_to_render_data_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - "//mediapipe/util:render_data_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":rect_to_render_data_calculator_proto"], -) - 
-mediapipe_cc_proto_library( - name = "rect_to_render_scale_calculator_cc_proto", - srcs = ["rect_to_render_scale_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":rect_to_render_scale_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "detections_to_render_data_calculator_cc_proto", - srcs = ["detections_to_render_data_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - "//mediapipe/util:render_data_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":detections_to_render_data_calculator_proto"], -) - cc_library( name = "detections_to_render_data_calculator", srcs = ["detections_to_render_data_calculator.cc"], @@ -768,18 +635,6 @@ cc_library( alwayslink = 1, ) -mediapipe_cc_proto_library( - name = "landmarks_to_render_data_calculator_cc_proto", - srcs = ["landmarks_to_render_data_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - "//mediapipe/util:render_data_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":landmarks_to_render_data_calculator_proto"], -) - cc_library( name = "landmarks_to_render_data_calculator", srcs = ["landmarks_to_render_data_calculator.cc"], @@ -800,18 +655,6 @@ cc_library( alwayslink = 1, ) -mediapipe_cc_proto_library( - name = "timed_box_list_to_render_data_calculator_cc_proto", - srcs = ["timed_box_list_to_render_data_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - "//mediapipe/util:render_data_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":timed_box_list_to_render_data_calculator_proto"], -) - cc_library( name = "timed_box_list_to_render_data_calculator", srcs = ["timed_box_list_to_render_data_calculator.cc"], @@ -916,6 +759,39 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "detection_projection_calculator", + srcs = ["detection_projection_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:point", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "detection_projection_calculator_test", + srcs = ["detection_projection_calculator_test.cc"], + deps = [ + ":detection_projection_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_utils", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:point", + ], +) + cc_library( name = "landmark_letterbox_removal_calculator", srcs = ["landmark_letterbox_removal_calculator.cc"], @@ -945,6 +821,32 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "landmarks_smoothing_calculator_proto", + srcs = ["landmarks_smoothing_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = 
"landmarks_smoothing_calculator", + srcs = ["landmarks_smoothing_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":landmarks_smoothing_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util/filtering:relative_velocity_filter", + "@com_google_absl//absl/algorithm:container", + ], + alwayslink = 1, +) + cc_library( name = "landmarks_to_floats_calculator", srcs = ["landmarks_to_floats_calculator.cc"], @@ -992,25 +894,16 @@ cc_test( ], ) -proto_library( +mediapipe_proto_library( name = "top_k_scores_calculator_proto", srcs = ["top_k_scores_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "top_k_scores_calculator_cc_proto", - srcs = ["top_k_scores_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":top_k_scores_calculator_proto"], -) - cc_library( name = "top_k_scores_calculator", srcs = ["top_k_scores_calculator.cc"], @@ -1056,15 +949,14 @@ cc_test( ], ) -mediapipe_cc_proto_library( - name = "labels_to_render_data_calculator_cc_proto", - srcs = ["labels_to_render_data_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - ], +mediapipe_proto_library( + name = "local_file_contents_calculator_proto", + srcs = ["local_file_contents_calculator.proto"], visibility = ["//visibility:public"], - deps = [":labels_to_render_data_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_library( @@ -1072,6 +964,7 @@ cc_library( srcs = ["local_file_contents_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":local_file_contents_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1103,6 +996,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/strings", @@ -1118,6 +1012,7 @@ cc_library( deps = [ ":collection_has_min_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", @@ -1232,3 +1127,70 @@ cc_library( ], alwayslink = 1, ) + +mediapipe_proto_library( + name = "logic_calculator_proto", + srcs = ["logic_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "logic_calculator", + srcs = ["logic_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":logic_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "to_image_calculator", + srcs = 
["to_image_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "from_image_calculator", + srcs = ["from_image_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + ], + }), + alwayslink = 1, +) diff --git a/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc b/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc index 49768eae7..edfa4196a 100644 --- a/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc +++ b/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc @@ -40,7 +40,7 @@ namespace {} // namespace // } class AlignmentPointsRectsCalculator : public DetectionsToRectsCalculator { public: - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { RET_CHECK_OK(DetectionsToRectsCalculator::Open(cc)); // Make sure that start and end keypoints are provided. 
@@ -52,18 +52,18 @@ class AlignmentPointsRectsCalculator : public DetectionsToRectsCalculator { RET_CHECK(options_.has_rotation_vector_end_keypoint_index()) << "End keypoint is required to calculate rect size and rotation"; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: - ::mediapipe::Status DetectionToNormalizedRect( + absl::Status DetectionToNormalizedRect( const ::mediapipe::Detection& detection, const DetectionSpec& detection_spec, ::mediapipe::NormalizedRect* rect) override; }; REGISTER_CALCULATOR(AlignmentPointsRectsCalculator); -::mediapipe::Status AlignmentPointsRectsCalculator::DetectionToNormalizedRect( +absl::Status AlignmentPointsRectsCalculator::DetectionToNormalizedRect( const Detection& detection, const DetectionSpec& detection_spec, NormalizedRect* rect) { const auto& location_data = detection.location_data(); @@ -96,7 +96,7 @@ REGISTER_CALCULATOR(AlignmentPointsRectsCalculator); rect->set_width(box_size / image_size->first); rect->set_height(box_size / image_size->second); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index 13dcabc7e..7c5aadc55 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -20,33 +20,31 @@ #include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" #include "mediapipe/util/annotation_renderer.h" #include "mediapipe/util/color.pb.h" #include "mediapipe/util/render_data.pb.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { namespace { -constexpr char kInputFrameTag[] = "IMAGE"; -constexpr char kOutputFrameTag[] = "IMAGE"; - -constexpr char kInputVectorTag[] = "VECTOR"; - -constexpr char kInputFrameTagGpu[] = "IMAGE_GPU"; -constexpr char kOutputFrameTagGpu[] = "IMAGE_GPU"; +constexpr char kVectorTag[] = "VECTOR"; +constexpr char kGpuBufferTag[] = "IMAGE_GPU"; +constexpr char kImageFrameTag[] = "IMAGE"; enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -57,12 +55,15 @@ size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT // merges the annotation overlay with the image frame. As a result, drawing in // this color is not supported and it should be set to something unlikely used. constexpr uchar kAnnotationBackgroundColor = 2; // Grayscale value. + +// Future Image type. +inline bool HasImageTag(mediapipe::CalculatorContext* cc) { return false; } } // namespace // A calculator for rendering data on images. // // Inputs: -// 1. IMAGE or IMAGE_GPU (optional): An ImageFrame (or GpuBuffer) +// 1. IMAGE or IMAGE_GPU (optional): An ImageFrame (or GpuBuffer), // containing the input image. 
// If output is CPU, and input isn't provided, the renderer creates a // blank canvas with the width, height and color provided in the options. @@ -74,7 +75,8 @@ constexpr uchar kAnnotationBackgroundColor = 2; // Grayscale value. // input vector items. These input streams are tagged with "VECTOR". // // Output: -// 1. IMAGE or IMAGE_GPU: A rendered ImageFrame (or GpuBuffer). +// 1. IMAGE or IMAGE_GPU: A rendered ImageFrame (or GpuBuffer), +// Note: Output types should match their corresponding input stream type. // // For CPU input frames, only SRGBA, SRGB and GRAY8 format are supported. The // output format is the same as input except for GRAY8 where the output is in @@ -122,26 +124,29 @@ class AnnotationOverlayCalculator : public CalculatorBase { AnnotationOverlayCalculator() = default; ~AnnotationOverlayCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // From Calculator. - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status CreateRenderTargetCpu(CalculatorContext* cc, - std::unique_ptr& image_mat, - ImageFormat::Format* target_format); - ::mediapipe::Status CreateRenderTargetGpu( - CalculatorContext* cc, std::unique_ptr& image_mat); - ::mediapipe::Status RenderToGpu(CalculatorContext* cc, uchar* overlay_image); - ::mediapipe::Status RenderToCpu(CalculatorContext* cc, - const ImageFormat::Format& target_format, - uchar* data_image); + absl::Status CreateRenderTargetCpu(CalculatorContext* cc, + std::unique_ptr& image_mat, + ImageFormat::Format* target_format); + template + absl::Status CreateRenderTargetGpu(CalculatorContext* cc, + std::unique_ptr& image_mat); + template + absl::Status RenderToGpu(CalculatorContext* cc, uchar* overlay_image); + absl::Status RenderToCpu(CalculatorContext* cc, + const ImageFormat::Format& target_format, + uchar* data_image); - ::mediapipe::Status GlRender(CalculatorContext* cc); - ::mediapipe::Status GlSetup(CalculatorContext* cc); + absl::Status GlRender(CalculatorContext* cc); + template + absl::Status GlSetup(CalculatorContext* cc); // Options for the calculator. AnnotationOverlayCalculatorOptions options_; @@ -154,7 +159,7 @@ class AnnotationOverlayCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; GLuint image_mat_tex_ = 0; // Overlay drawing image for GPU. @@ -162,34 +167,35 @@ class AnnotationOverlayCalculator : public CalculatorBase { int height_ = 0; int width_canvas_ = 0; // Size of overlay drawing texture canvas. 
int height_canvas_ = 0; -#endif // MEDIAPIPE_DISABLE_GPU +#endif // MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(AnnotationOverlayCalculator); -::mediapipe::Status AnnotationOverlayCalculator::GetContract( - CalculatorContract* cc) { +absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) { CHECK_GE(cc->Inputs().NumEntries(), 1); bool use_gpu = false; - if (cc->Inputs().HasTag(kInputFrameTag) && - cc->Inputs().HasTag(kInputFrameTagGpu)) { - return ::mediapipe::InternalError("Cannot have multiple input images."); + if (cc->Inputs().HasTag(kImageFrameTag) && + cc->Inputs().HasTag(kGpuBufferTag)) { + return absl::InternalError("Cannot have multiple input images."); } - if (cc->Inputs().HasTag(kInputFrameTagGpu) != - cc->Outputs().HasTag(kOutputFrameTagGpu)) { - return ::mediapipe::InternalError("GPU output must have GPU input."); + if (cc->Inputs().HasTag(kGpuBufferTag) != + cc->Outputs().HasTag(kGpuBufferTag)) { + return absl::InternalError("GPU output must have GPU input."); } - // Input image to render onto copy of. -#if !defined(MEDIAPIPE_DISABLE_GPU) - if (cc->Inputs().HasTag(kInputFrameTagGpu)) { - cc->Inputs().Tag(kInputFrameTagGpu).Set(); - use_gpu |= true; + // Input image to render onto copy of. Should be same type as output. +#if !MEDIAPIPE_DISABLE_GPU + if (cc->Inputs().HasTag(kGpuBufferTag)) { + cc->Inputs().Tag(kGpuBufferTag).Set(); + CHECK(cc->Outputs().HasTag(kGpuBufferTag)); + use_gpu = true; } -#endif // !MEDIAPIPE_DISABLE_GPU - if (cc->Inputs().HasTag(kInputFrameTag)) { - cc->Inputs().Tag(kInputFrameTag).Set(); +#endif // !MEDIAPIPE_DISABLE_GPU + if (cc->Inputs().HasTag(kImageFrameTag)) { + cc->Inputs().Tag(kImageFrameTag).Set(); + CHECK(cc->Outputs().HasTag(kImageFrameTag)); } // Data streams to render. @@ -197,7 +203,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); ++id) { auto tag_and_index = cc->Inputs().TagAndIndexFromId(id); std::string tag = tag_and_index.first; - if (tag == kInputVectorTag) { + if (tag == kVectorTag) { cc->Inputs().Get(id).Set>(); } else if (tag.empty()) { // Empty tag defaults to accepting a single object of RenderData type. @@ -205,44 +211,39 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } } - // Rendered image. -#if !defined(MEDIAPIPE_DISABLE_GPU) - if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { - cc->Outputs().Tag(kOutputFrameTagGpu).Set(); - use_gpu |= true; + // Rendered image. Should be same type as input. 
+#if !MEDIAPIPE_DISABLE_GPU + if (cc->Outputs().HasTag(kGpuBufferTag)) { + cc->Outputs().Tag(kGpuBufferTag).Set(); } -#endif // !MEDIAPIPE_DISABLE_GPU - if (cc->Outputs().HasTag(kOutputFrameTag)) { - cc->Outputs().Tag(kOutputFrameTag).Set(); +#endif // !MEDIAPIPE_DISABLE_GPU + if (cc->Outputs().HasTag(kImageFrameTag)) { + cc->Outputs().Tag(kImageFrameTag).Set(); } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { +absl::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - if (cc->Inputs().HasTag(kInputFrameTagGpu) && - cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if !defined(MEDIAPIPE_DISABLE_GPU) + if (cc->Inputs().HasTag(kGpuBufferTag) || HasImageTag(cc)) { +#if !MEDIAPIPE_DISABLE_GPU use_gpu_ = true; -#else - RET_CHECK_FAIL() << "GPU processing not enabled."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - if (cc->Inputs().HasTag(kInputFrameTagGpu) || - cc->Inputs().HasTag(kInputFrameTag)) { + if (cc->Inputs().HasTag(kGpuBufferTag) || + cc->Inputs().HasTag(kImageFrameTag) || HasImageTag(cc)) { image_frame_available_ = true; } else { - image_frame_available_ = false; RET_CHECK(options_.has_canvas_width_px()); RET_CHECK(options_.has_canvas_height_px()); } @@ -253,44 +254,46 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); if (use_gpu_) renderer_->SetScaleFactor(options_.gpu_scale_factor()); // Set the output header based on the input header (if present). - const char* input_tag = use_gpu_ ? kInputFrameTagGpu : kInputFrameTag; - const char* output_tag = use_gpu_ ? kOutputFrameTagGpu : kOutputFrameTag; - if (image_frame_available_ && - !cc->Inputs().Tag(input_tag).Header().IsEmpty()) { + const char* tag = use_gpu_ ? kGpuBufferTag : kImageFrameTag; + if (image_frame_available_ && !cc->Inputs().Tag(tag).Header().IsEmpty()) { const auto& input_header = - cc->Inputs().Tag(input_tag).Header().Get(); + cc->Inputs().Tag(tag).Header().Get(); auto* output_video_header = new VideoHeader(input_header); - cc->Outputs().Tag(output_tag).SetHeader(Adopt(output_video_header)); + cc->Outputs().Tag(tag).SetHeader(Adopt(output_video_header)); } if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::Process( - CalculatorContext* cc) { +absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { // Initialize render target, drawn with OpenCV. 
std::unique_ptr image_mat; ImageFormat::Format target_format; if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (!gpu_initialized_) { MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { - MP_RETURN_IF_ERROR(GlSetup(cc)); - return ::mediapipe::OkStatus(); + gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + return GlSetup(cc); })); gpu_initialized_ = true; } -#endif // !MEDIAPIPE_DISABLE_GPU - MP_RETURN_IF_ERROR(CreateRenderTargetGpu(cc, image_mat)); + if (cc->Inputs().HasTag(kGpuBufferTag)) { + MP_RETURN_IF_ERROR( + (CreateRenderTargetGpu( + cc, image_mat))); + } +#endif // !MEDIAPIPE_DISABLE_GPU } else { - MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); + if (cc->Outputs().HasTag(kImageFrameTag)) { + MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); + } } // Reset the renderer with the image_mat. No copy here. @@ -301,7 +304,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); ++id) { auto tag_and_index = cc->Inputs().TagAndIndexFromId(id); std::string tag = tag_and_index.first; - if (!tag.empty() && tag != kInputVectorTag) { + if (!tag.empty() && tag != kVectorTag) { continue; } if (cc->Inputs().Get(id).IsEmpty()) { @@ -312,7 +315,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); const RenderData& render_data = cc->Inputs().Get(id).Get(); renderer_->RenderDataOnImage(render_data); } else { - RET_CHECK_EQ(kInputVectorTag, tag); + RET_CHECK_EQ(kVectorTag, tag); const std::vector& render_data_vec = cc->Inputs().Get(id).Get>(); for (const RenderData& render_data : render_data_vec) { @@ -322,44 +325,44 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // Overlay rendered image in OpenGL, onto a copy of input. uchar* image_mat_ptr = image_mat->data; - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, cc, image_mat_ptr]() -> ::mediapipe::Status { - MP_RETURN_IF_ERROR(RenderToGpu(cc, image_mat_ptr)); - return ::mediapipe::OkStatus(); + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc, image_mat_ptr]() -> absl::Status { + return RenderToGpu( + cc, image_mat_ptr); })); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } else { // Copy the rendered image to output. 
uchar* image_mat_ptr = image_mat->data; MP_RETURN_IF_ERROR(RenderToCpu(cc, target_format, image_mat_ptr)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; if (image_mat_tex_) glDeleteTextures(1, &image_mat_tex_); image_mat_tex_ = 0; }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::RenderToCpu( +absl::Status AnnotationOverlayCalculator::RenderToCpu( CalculatorContext* cc, const ImageFormat::Format& target_format, uchar* data_image) { auto output_frame = absl::make_unique( target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight()); -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight(), data_image, ImageFrame::kGlDefaultAlignmentBoundary); @@ -367,21 +370,23 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight(), data_image, ImageFrame::kDefaultAlignmentBoundary); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - cc->Outputs() - .Tag(kOutputFrameTag) - .Add(output_frame.release(), cc->InputTimestamp()); + if (cc->Outputs().HasTag(kImageFrameTag)) { + cc->Outputs() + .Tag(kImageFrameTag) + .Add(output_frame.release(), cc->InputTimestamp()); + } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::RenderToGpu( - CalculatorContext* cc, uchar* overlay_image) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +template +absl::Status AnnotationOverlayCalculator::RenderToGpu(CalculatorContext* cc, + uchar* overlay_image) { +#if !MEDIAPIPE_DISABLE_GPU // Source and destination textures. - const auto& input_frame = - cc->Inputs().Tag(kInputFrameTagGpu).Get(); + const auto& input_frame = cc->Inputs().Tag(Tag).Get(); auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); auto output_texture = gpu_helper_.CreateDestinationTexture( @@ -414,25 +419,23 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } // Send out blended image as GPU packet. 
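On the GPU path, `RenderToGpu` blends the CPU-drawn annotation canvas over a copy of the input texture; per the `kAnnotationBackgroundColor` comment earlier in this file, canvas pixels left at the reserved background value are treated as transparent. A CPU-style sketch of that merge rule (the real path does this in a GL shader; the RGB overlay / RGBA frame layout here is an assumption):

```cpp
#include <cstddef>
#include <cstdint>

// Reserved grayscale value marking "nothing drawn here" on the overlay canvas,
// as with kAnnotationBackgroundColor above.
constexpr uint8_t kBackgroundValue = 2;

// Merge an RGB overlay onto an RGBA frame in place: keep the input pixel where
// the overlay is untouched background, otherwise take the drawn overlay color.
void MergeOverlay(const uint8_t* overlay_rgb, uint8_t* frame_rgba,
                  std::size_t num_pixels) {
  for (std::size_t i = 0; i < num_pixels; ++i) {
    const uint8_t r = overlay_rgb[i * 3 + 0];
    const uint8_t g = overlay_rgb[i * 3 + 1];
    const uint8_t b = overlay_rgb[i * 3 + 2];
    const bool is_background =
        (r == kBackgroundValue && g == kBackgroundValue && b == kBackgroundValue);
    if (is_background) continue;  // Transparent: leave the input pixel as-is.
    frame_rgba[i * 4 + 0] = r;
    frame_rgba[i * 4 + 1] = g;
    frame_rgba[i * 4 + 2] = b;
    // Alpha channel left untouched.
  }
}
```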
- auto output_frame = output_texture.GetFrame(); - cc->Outputs() - .Tag(kOutputFrameTagGpu) - .Add(output_frame.release(), cc->InputTimestamp()); + auto output_frame = output_texture.GetFrame(); + cc->Outputs().Tag(Tag).Add(output_frame.release(), cc->InputTimestamp()); // Cleanup input_texture.Release(); output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( +absl::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( CalculatorContext* cc, std::unique_ptr& image_mat, ImageFormat::Format* target_format) { if (image_frame_available_) { const auto& input_frame = - cc->Inputs().Tag(kInputFrameTag).Get(); + cc->Inputs().Tag(kImageFrameTag).Get(); int target_mat_type; switch (input_frame.Format()) { @@ -449,45 +452,38 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); target_mat_type = CV_8UC3; break; default: - return ::mediapipe::UnknownError("Unexpected image frame format."); + return absl::UnknownError("Unexpected image frame format."); break; } image_mat = absl::make_unique( input_frame.Height(), input_frame.Width(), target_mat_type); + + auto input_mat = formats::MatView(&input_frame); if (input_frame.Format() == ImageFormat::GRAY8) { - const int target_num_channels = - ImageFrame::NumberOfChannelsForFormat(*target_format); - for (int i = 0; i < input_frame.PixelDataSize(); i++) { - const auto& pix = input_frame.PixelData()[i]; - for (int c = 0; c < target_num_channels; c++) { - image_mat->data[i * target_num_channels + c] = pix; - } - } + cv::Mat rgb_mat; + cv::cvtColor(input_mat, rgb_mat, CV_GRAY2RGB); + rgb_mat.copyTo(*image_mat); } else { - // Make of a copy since the input frame may be consumed by other nodes. 
- const int buffer_size = - input_frame.Height() * input_frame.Width() * - ImageFrame::NumberOfChannelsForFormat(*target_format); - input_frame.CopyToBuffer(image_mat->data, buffer_size); + input_mat.copyTo(*image_mat); } } else { image_mat = absl::make_unique( options_.canvas_height_px(), options_.canvas_width_px(), CV_8UC3, cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), options_.canvas_color().b())); + *target_format = ImageFormat::SRGB; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( +template +absl::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( CalculatorContext* cc, std::unique_ptr& image_mat) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (image_frame_available_) { - const auto& input_frame = - cc->Inputs().Tag(kInputFrameTagGpu).Get(); - + const auto& input_frame = cc->Inputs().Tag(Tag).Get(); const mediapipe::ImageFormat::Format format = mediapipe::ImageFormatForGpuBufferFormat(input_frame.format()); if (format != mediapipe::ImageFormat::SRGBA && @@ -503,14 +499,13 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), options_.canvas_color().b())); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::GlRender( - CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status AnnotationOverlayCalculator::GlRender(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -558,14 +553,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AnnotationOverlayCalculator::GlSetup( - CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +template +absl::Status AnnotationOverlayCalculator::GlSetup(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -638,8 +633,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); const float alignment = ImageFrame::kGlDefaultAlignmentBoundary; const float scale_factor = options_.gpu_scale_factor(); if (image_frame_available_) { - const auto& input_frame = - cc->Inputs().Tag(kInputFrameTagGpu).Get(); + const auto& input_frame = cc->Inputs().Tag(Tag).Get(); width_ = RoundUp(input_frame.width(), alignment); height_ = RoundUp(input_frame.height(), alignment); } else { @@ -663,9 +657,9 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); glBindTexture(GL_TEXTURE_2D, 0); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/association_calculator.h b/mediapipe/calculators/util/association_calculator.h index a16de4977..6e5b480ce 100644 --- a/mediapipe/calculators/util/association_calculator.h +++ b/mediapipe/calculators/util/association_calculator.h @@ -56,7 +56,7 @@ inline float OverlapSimilarity(const Rectangle_f& rect1, template class AssociationCalculator : public CalculatorBase { 
public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { // Atmost one input stream can be tagged with "PREV". RET_CHECK_LE(cc->Inputs().NumEntries("PREV"), 1); @@ -71,10 +71,10 @@ class AssociationCalculator : public CalculatorBase { cc->Outputs().Index(0).Set>(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); has_prev_input_stream_ = cc->Inputs().HasTag("PREV"); @@ -84,15 +84,15 @@ class AssociationCalculator : public CalculatorBase { options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>(); CHECK_GE(options_.min_similarity_threshold(), 0); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { auto get_non_overlapping_elements = GetNonOverlappingElements(cc); if (!get_non_overlapping_elements.ok()) { return get_non_overlapping_elements.status(); } - std::list result = get_non_overlapping_elements.ValueOrDie(); + std::list result = get_non_overlapping_elements.value(); if (has_prev_input_stream_ && !cc->Inputs().Get(prev_input_stream_id_).IsEmpty()) { @@ -114,7 +114,7 @@ class AssociationCalculator : public CalculatorBase { } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } protected: @@ -123,8 +123,8 @@ class AssociationCalculator : public CalculatorBase { bool has_prev_input_stream_; CollectionItemId prev_input_stream_id_; - virtual ::mediapipe::StatusOr GetRectangle(const T& input) { - return ::mediapipe::OkStatus(); + virtual absl::StatusOr GetRectangle(const T& input) { + return absl::OkStatus(); } virtual std::pair GetId(const T& input) { return {false, -1}; } @@ -134,7 +134,7 @@ class AssociationCalculator : public CalculatorBase { private: // Get a list of non-overlapping elements from all input streams, with // increasing order of priority based on input stream index. - mediapipe::StatusOr> GetNonOverlappingElements( + absl::StatusOr> GetNonOverlappingElements( CalculatorContext* cc) { std::list result; @@ -176,7 +176,7 @@ class AssociationCalculator : public CalculatorBase { return result; } - ::mediapipe::Status AddElementToList(T element, std::list* current) { + absl::Status AddElementToList(T element, std::list* current) { // Compare this element with elements of the input collection. If this // element has high overlap with elements of the collection, remove // those elements from the collection and add this element. @@ -207,20 +207,20 @@ class AssociationCalculator : public CalculatorBase { } current->push_back(element); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Compare elements of the current list with elements in from the collection // of elements from the previous input stream, and propagate IDs from the // previous input stream as appropriate. 
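A quick illustration, not part of the patch: AssociationCalculator keeps only mutually non-overlapping elements per timestamp and copies an ID forward from the "PREV" stream whenever a current element overlaps a previous one above min_similarity_threshold. Below is a minimal standalone sketch of that ID-propagation step, assuming OverlapSimilarity() returns the intersection-over-union of two normalized rectangles; SimpleBox, IoU and PropagateIds are hypothetical names used only for illustration.

#include <algorithm>
#include <vector>

struct SimpleBox { float xmin, ymin, width, height; int id; };

// Intersection-over-union of two axis-aligned boxes.
float IoU(const SimpleBox& a, const SimpleBox& b) {
  const float x0 = std::max(a.xmin, b.xmin);
  const float y0 = std::max(a.ymin, b.ymin);
  const float x1 = std::min(a.xmin + a.width, b.xmin + b.width);
  const float y1 = std::min(a.ymin + a.height, b.ymin + b.height);
  const float inter = std::max(0.0f, x1 - x0) * std::max(0.0f, y1 - y0);
  const float uni = a.width * a.height + b.width * b.height - inter;
  return uni > 0.0f ? inter / uni : 0.0f;
}

// Give each current box the ID of any previous box it overlaps strongly enough.
void PropagateIds(const std::vector<SimpleBox>& prev,
                  std::vector<SimpleBox>* current, float min_similarity) {
  for (SimpleBox& cur : *current) {
    for (const SimpleBox& p : prev) {
      if (IoU(cur, p) > min_similarity) cur.id = p.id;
    }
  }
}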
- ::mediapipe::Status PropagateIdsFromPreviousToCurrent( + absl::Status PropagateIdsFromPreviousToCurrent( const std::vector& prev_input_vec, std::list* current) { for (auto vit = current->begin(); vit != current->end(); ++vit) { auto get_cur_rectangle = GetRectangle(*vit); if (!get_cur_rectangle.ok()) { return get_cur_rectangle.status(); } - const Rectangle_f& cur_rect = get_cur_rectangle.ValueOrDie(); + const Rectangle_f& cur_rect = get_cur_rectangle.value(); bool change_id = false; int id_for_vi = -1; @@ -230,7 +230,7 @@ class AssociationCalculator : public CalculatorBase { if (!get_prev_rectangle.ok()) { return get_prev_rectangle.status(); } - const Rectangle_f& prev_rect = get_prev_rectangle.ValueOrDie(); + const Rectangle_f& prev_rect = get_prev_rectangle.value(); if (OverlapSimilarity(cur_rect, prev_rect) > options_.min_similarity_threshold()) { @@ -250,7 +250,7 @@ class AssociationCalculator : public CalculatorBase { *vit = element; } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/util/association_detection_calculator.cc b/mediapipe/calculators/util/association_detection_calculator.cc index 125e8c4ff..35112aee7 100644 --- a/mediapipe/calculators/util/association_detection_calculator.cc +++ b/mediapipe/calculators/util/association_detection_calculator.cc @@ -37,27 +37,27 @@ namespace mediapipe { class AssociationDetectionCalculator : public AssociationCalculator<::mediapipe::Detection> { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { return AssociationCalculator<::mediapipe::Detection>::GetContract(cc); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::Detection>::Open(cc); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::Detection>::Process(cc); } - ::mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::Detection>::Close(cc); } protected: - ::mediapipe::StatusOr GetRectangle( + absl::StatusOr GetRectangle( const ::mediapipe::Detection& input) override { if (!input.has_location_data()) { - return ::mediapipe::InternalError("Missing location_data in Detection"); + return absl::InternalError("Missing location_data in Detection"); } const Location location(input.location_data()); return location.GetRelativeBBox(); diff --git a/mediapipe/calculators/util/association_norm_rect_calculator.cc b/mediapipe/calculators/util/association_norm_rect_calculator.cc index 4069eda60..a9194604a 100644 --- a/mediapipe/calculators/util/association_norm_rect_calculator.cc +++ b/mediapipe/calculators/util/association_norm_rect_calculator.cc @@ -36,29 +36,28 @@ namespace mediapipe { class AssociationNormRectCalculator : public AssociationCalculator<::mediapipe::NormalizedRect> { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { return AssociationCalculator<::mediapipe::NormalizedRect>::GetContract(cc); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::NormalizedRect>::Open(cc); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + 
absl::Status Process(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::NormalizedRect>::Process(cc); } - ::mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::NormalizedRect>::Close(cc); } protected: - ::mediapipe::StatusOr GetRectangle( + absl::StatusOr GetRectangle( const ::mediapipe::NormalizedRect& input) override { if (!input.has_x_center() || !input.has_y_center() || !input.has_width() || !input.has_height()) { - return ::mediapipe::InternalError( - "Missing dimensions in NormalizedRect."); + return absl::InternalError("Missing dimensions in NormalizedRect."); } const float xmin = input.x_center() - input.width() / 2.0; const float ymin = input.y_center() - input.height() / 2.0; diff --git a/mediapipe/calculators/util/clock_latency_calculator.cc b/mediapipe/calculators/util/clock_latency_calculator.cc index 768abb2a4..5c5711731 100644 --- a/mediapipe/calculators/util/clock_latency_calculator.cc +++ b/mediapipe/calculators/util/clock_latency_calculator.cc @@ -60,18 +60,17 @@ class ClockLatencyCalculator : public CalculatorBase { public: ClockLatencyCalculator() {} - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: int64 num_packet_streams_ = -1; }; REGISTER_CALCULATOR(ClockLatencyCalculator); -::mediapipe::Status ClockLatencyCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ClockLatencyCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_GT(cc->Inputs().NumEntries(), 1); int64 num_packet_streams = cc->Inputs().NumEntries() - 1; @@ -83,17 +82,17 @@ REGISTER_CALCULATOR(ClockLatencyCalculator); } cc->Inputs().Tag(kReferenceTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ClockLatencyCalculator::Open(CalculatorContext* cc) { +absl::Status ClockLatencyCalculator::Open(CalculatorContext* cc) { // Direct passthrough, as far as timestamp and bounds are concerned. cc->SetOffset(TimestampDiff(0)); num_packet_streams_ = cc->Inputs().NumEntries() - 1; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ClockLatencyCalculator::Process(CalculatorContext* cc) { +absl::Status ClockLatencyCalculator::Process(CalculatorContext* cc) { // Get reference time. 
RET_CHECK(!cc->Inputs().Tag(kReferenceTag).IsEmpty()); const absl::Time& reference_time = @@ -110,7 +109,7 @@ REGISTER_CALCULATOR(ClockLatencyCalculator); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/clock_timestamp_calculator.cc b/mediapipe/calculators/util/clock_timestamp_calculator.cc index ea715f8ae..4ba56cfd0 100644 --- a/mediapipe/calculators/util/clock_timestamp_calculator.cc +++ b/mediapipe/calculators/util/clock_timestamp_calculator.cc @@ -52,10 +52,10 @@ class ClockTimestampCalculator : public CalculatorBase { public: ClockTimestampCalculator() {} - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Clock object. @@ -63,8 +63,7 @@ class ClockTimestampCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(ClockTimestampCalculator); -::mediapipe::Status ClockTimestampCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ClockTimestampCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1); RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); @@ -78,10 +77,10 @@ REGISTER_CALCULATOR(ClockTimestampCalculator); .Set>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ClockTimestampCalculator::Open(CalculatorContext* cc) { +absl::Status ClockTimestampCalculator::Open(CalculatorContext* cc) { // Direct passthrough, as far as timestamp and bounds are concerned. cc->SetOffset(TimestampDiff(0)); @@ -95,14 +94,14 @@ REGISTER_CALCULATOR(ClockTimestampCalculator); ::mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ClockTimestampCalculator::Process(CalculatorContext* cc) { +absl::Status ClockTimestampCalculator::Process(CalculatorContext* cc) { // Push the Time packet to output. 
auto timestamp_packet = MakePacket(clock_->TimeNow()); cc->Outputs().Index(0).AddPacket(timestamp_packet.At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/collection_has_min_size_calculator.cc b/mediapipe/calculators/util/collection_has_min_size_calculator.cc index 22bfb9c4c..956818c87 100644 --- a/mediapipe/calculators/util/collection_has_min_size_calculator.cc +++ b/mediapipe/calculators/util/collection_has_min_size_calculator.cc @@ -17,18 +17,24 @@ #include +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" namespace mediapipe { -typedef CollectionHasMinSizeCalculator> +typedef CollectionHasMinSizeCalculator> NormalizedRectVectorHasMinSizeCalculator; REGISTER_CALCULATOR(NormalizedRectVectorHasMinSizeCalculator); typedef CollectionHasMinSizeCalculator< - std::vector<::mediapipe::NormalizedLandmarkList>> + std::vector> NormalizedLandmarkListVectorHasMinSizeCalculator; REGISTER_CALCULATOR(NormalizedLandmarkListVectorHasMinSizeCalculator); +typedef CollectionHasMinSizeCalculator< + std::vector> + ClassificationListVectorHasMinSizeCalculator; +REGISTER_CALCULATOR(ClassificationListVectorHasMinSizeCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/util/collection_has_min_size_calculator.h b/mediapipe/calculators/util/collection_has_min_size_calculator.h index 80b4556c5..4d4b6a678 100644 --- a/mediapipe/calculators/util/collection_has_min_size_calculator.h +++ b/mediapipe/calculators/util/collection_has_min_size_calculator.h @@ -42,7 +42,7 @@ namespace mediapipe { template class CollectionHasMinSizeCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("ITERABLE")); RET_CHECK_EQ(1, cc->Inputs().NumEntries()); @@ -60,10 +60,10 @@ class CollectionHasMinSizeCalculator : public CalculatorBase { if (cc->InputSidePackets().NumEntries() > 0) { cc->InputSidePackets().Index(0).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); min_size_ = cc->Options<::mediapipe::CollectionHasMinSizeCalculatorOptions>() @@ -73,17 +73,17 @@ class CollectionHasMinSizeCalculator : public CalculatorBase { !cc->InputSidePackets().Index(0).IsEmpty()) { min_size_ = cc->InputSidePackets().Index(0).Get(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const IterableT& input = cc->Inputs().Tag("ITERABLE").Get(); bool has_min_size = input.size() >= min_size_; cc->Outputs().Index(0).AddPacket( MakePacket(has_min_size).At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc index 7d9cd5740..0de1e53b2 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under 
the License. +#include "absl/container/node_hash_map.h" #include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" @@ -46,26 +47,25 @@ namespace mediapipe { // } class DetectionLabelIdToTextCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - std::unordered_map label_map_; + absl::node_hash_map label_map_; }; REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); -::mediapipe::Status DetectionLabelIdToTextCalculator::GetContract( +absl::Status DetectionLabelIdToTextCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set>(); cc->Outputs().Index(0).Set>(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status DetectionLabelIdToTextCalculator::Open( - CalculatorContext* cc) { +absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -89,11 +89,10 @@ REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); label_map_[i] = options.label(i); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status DetectionLabelIdToTextCalculator::Process( - CalculatorContext* cc) { +absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) { std::vector output_detections; for (const auto& input_detection : cc->Inputs().Index(0).Get>()) { @@ -114,7 +113,7 @@ REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); cc->Outputs().Index(0).AddPacket( MakePacket>(output_detections) .At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc index cf3761010..8f8025576 100644 --- a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc @@ -70,7 +70,7 @@ constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; // } class DetectionLetterboxRemovalCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionsTag) && cc->Inputs().HasTag(kLetterboxPaddingTag)) << "Missing one or more input streams."; @@ -80,19 +80,19 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase { cc->Outputs().Tag(kDetectionsTag).Set>(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { // Only process if there's input detections. 
if (cc->Inputs().Tag(kDetectionsTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_detections = @@ -146,7 +146,7 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase { cc->Outputs() .Tag("DETECTIONS") .Add(output_detections.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(DetectionLetterboxRemovalCalculator); diff --git a/mediapipe/calculators/util/detection_projection_calculator.cc b/mediapipe/calculators/util/detection_projection_calculator.cc new file mode 100644 index 000000000..211fd204c --- /dev/null +++ b/mediapipe/calculators/util/detection_projection_calculator.cc @@ -0,0 +1,178 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/point2.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// Projects detections to a different coordinate system using a provided +// projection matrix. +// +// Input: +// DETECTIONS - std::vector +// Detections to project using the provided projection matrix. +// PROJECTION_MATRIX - std::array +// A 4x4 row-major-order matrix that maps data from one coordinate system to +// another. +// +// Output: +// DETECTIONS - std::vector +// Projected detections. +// +// Example: +// node { +// calculator: "DetectionProjectionCalculator" +// input_stream: "DETECTIONS:detections" +// input_stream: "PROJECTION_MATRIX:matrix" +// output_stream: "DETECTIONS:projected_detections" +// } +class DetectionProjectionCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; +}; +REGISTER_CALCULATOR(DetectionProjectionCalculator); + +namespace { + +constexpr char kDetections[] = "DETECTIONS"; +constexpr char kProjectionMatrix[] = "PROJECTION_MATRIX"; + +absl::Status ProjectDetection( + const std::function& project_fn, + Detection* detection) { + auto* location_data = detection->mutable_location_data(); + RET_CHECK_EQ(location_data->format(), LocationData::RELATIVE_BOUNDING_BOX); + + // Project keypoints. + for (int i = 0; i < location_data->relative_keypoints_size(); ++i) { + auto* kp = location_data->mutable_relative_keypoints(i); + const auto point = project_fn({kp->x(), kp->y()}); + kp->set_x(point.x()); + kp->set_y(point.y()); + } + + // Project bounding box. 
+ auto* box = location_data->mutable_relative_bounding_box(); + + const float xmin = box->xmin(); + const float ymin = box->ymin(); + const float width = box->width(); + const float height = box->height(); + // a) Define and project box points. + std::array<Point2_f, 4> box_coordinates = { + Point2_f{xmin, ymin}, Point2_f{xmin + width, ymin}, + Point2_f{xmin + width, ymin + height}, Point2_f{xmin, ymin + height}}; + std::transform(box_coordinates.begin(), box_coordinates.end(), + box_coordinates.begin(), project_fn); + // b) Find new left top and right bottom points for a box which encompasses + // non-projected (rotated) box. + constexpr float kFloatMax = std::numeric_limits<float>::max(); + constexpr float kFloatMin = std::numeric_limits<float>::lowest(); + Point2_f left_top = {kFloatMax, kFloatMax}; + Point2_f right_bottom = {kFloatMin, kFloatMin}; + std::for_each(box_coordinates.begin(), box_coordinates.end(), + [&left_top, &right_bottom](const Point2_f& p) { + left_top.set_x(std::min(left_top.x(), p.x())); + left_top.set_y(std::min(left_top.y(), p.y())); + right_bottom.set_x(std::max(right_bottom.x(), p.x())); + right_bottom.set_y(std::max(right_bottom.y(), p.y())); + }); + box->set_xmin(left_top.x()); + box->set_ymin(left_top.y()); + box->set_width(right_bottom.x() - left_top.x()); + box->set_height(right_bottom.y() - left_top.y()); + + return absl::OkStatus(); +} + +} // namespace + +absl::Status DetectionProjectionCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kDetections) && + cc->Inputs().HasTag(kProjectionMatrix)) + << "Missing one or more input streams."; + + RET_CHECK_EQ(cc->Inputs().NumEntries(kDetections), + cc->Outputs().NumEntries(kDetections)) + << "Same number of DETECTIONS input and output is required."; + + for (CollectionItemId id = cc->Inputs().BeginId(kDetections); + id != cc->Inputs().EndId(kDetections); ++id) { + cc->Inputs().Get(id).Set<std::vector<Detection>>(); + } + cc->Inputs().Tag(kProjectionMatrix).Set<std::array<float, 16>>(); + + for (CollectionItemId id = cc->Outputs().BeginId(kDetections); + id != cc->Outputs().EndId(kDetections); ++id) { + cc->Outputs().Get(id).Set<std::vector<Detection>>(); + } + + return absl::OkStatus(); +} + +absl::Status DetectionProjectionCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + return absl::OkStatus(); +} + +absl::Status DetectionProjectionCalculator::Process(CalculatorContext* cc) { + if (cc->Inputs().Tag(kProjectionMatrix).IsEmpty()) { + return absl::OkStatus(); + } + const auto& project_mat = + cc->Inputs().Tag(kProjectionMatrix).Get<std::array<float, 16>>(); + auto project_fn = [project_mat](const Point2_f& p) -> Point2_f { + return {p.x() * project_mat[0] + p.y() * project_mat[1] + project_mat[3], + p.x() * project_mat[4] + p.y() * project_mat[5] + project_mat[7]}; + }; + + CollectionItemId input_id = cc->Inputs().BeginId(kDetections); + CollectionItemId output_id = cc->Outputs().BeginId(kDetections); + // Number of inputs and outputs is the same according to the contract.
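A quick illustration, not part of the patch: the project_fn lambda above treats PROJECTION_MATRIX as a row-major 4x4 matrix and applies only its 2D-affine part to normalized points, while ProjectDetection() replaces the relative bounding box with the axis-aligned box enclosing its four projected corners. Below is a minimal standalone sketch of that math; Pt, ProjectPoint and EnclosingBox are hypothetical names used only for illustration.

#include <algorithm>
#include <array>

struct Pt { float x, y; };

// Apply the 2D-affine part of a row-major 4x4 matrix:
//   x' = m[0]*x + m[1]*y + m[3],  y' = m[4]*x + m[5]*y + m[7].
Pt ProjectPoint(const std::array<float, 16>& m, Pt p) {
  return {p.x * m[0] + p.y * m[1] + m[3], p.x * m[4] + p.y * m[5] + m[7]};
}

// Axis-aligned box enclosing the four projected corners of (xmin, ymin, w, h).
void EnclosingBox(const std::array<float, 16>& m, float xmin, float ymin,
                  float w, float h, float* out_xmin, float* out_ymin,
                  float* out_w, float* out_h) {
  const Pt corners[4] = {ProjectPoint(m, {xmin, ymin}),
                         ProjectPoint(m, {xmin + w, ymin}),
                         ProjectPoint(m, {xmin + w, ymin + h}),
                         ProjectPoint(m, {xmin, ymin + h})};
  float x0 = corners[0].x, y0 = corners[0].y, x1 = x0, y1 = y0;
  for (const Pt& c : corners) {
    x0 = std::min(x0, c.x);
    y0 = std::min(y0, c.y);
    x1 = std::max(x1, c.x);
    y1 = std::max(y1, c.y);
  }
  *out_xmin = x0;
  *out_ymin = y0;
  *out_w = x1 - x0;
  *out_h = y1 - y0;
}

The tests further below build such a matrix with GetRotatedSubRectToRectTransformMatrix() from an ROI rect and the image size.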
+ for (; input_id != cc->Inputs().EndId(kDetections); ++input_id, ++output_id) { + const auto& input_packet = cc->Inputs().Get(input_id); + if (input_packet.IsEmpty()) { + continue; + } + + std::vector output_detections; + for (const auto& detection : input_packet.Get>()) { + Detection output_detection = detection; + MP_RETURN_IF_ERROR(ProjectDetection(project_fn, &output_detection)); + output_detections.push_back(std::move(output_detection)); + } + + cc->Outputs().Get(output_id).AddPacket( + MakePacket>(std::move(output_detections)) + .At(cc->InputTimestamp())); + } + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_projection_calculator_test.cc b/mediapipe/calculators/util/detection_projection_calculator_test.cc new file mode 100644 index 000000000..0437e6f96 --- /dev/null +++ b/mediapipe/calculators/util/detection_projection_calculator_test.cc @@ -0,0 +1,309 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/point2.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::testing::ElementsAre; +using ::testing::FloatNear; + +constexpr float kMaxError = 1e-4; + +MATCHER_P2(PointEq, x, y, "") { + bool result = testing::Value(arg.x(), FloatNear(x, kMaxError)) && + testing::Value(arg.y(), FloatNear(y, kMaxError)); + if (!result) { + *result_listener << "actual: {" << arg.x() << ", " << arg.y() + << "}, expected: {" << x << ", " << y << "}"; + } + return result; +} + +MATCHER_P4(BoundingBoxEq, xmin, ymin, width, height, "") { + return testing::Value(arg.xmin(), FloatNear(xmin, kMaxError)) && + testing::Value(arg.ymin(), FloatNear(ymin, kMaxError)) && + testing::Value(arg.width(), FloatNear(width, kMaxError)) && + testing::Value(arg.height(), FloatNear(height, kMaxError)); +} + +std::vector GetPoints(const Detection& detection) { + std::vector points; + const auto& location_data = detection.location_data(); + for (int i = 0; i < location_data.relative_keypoints_size(); ++i) { + const auto& kp = location_data.relative_keypoints(i); + points.push_back({kp.x(), kp.y()}); + } + return points; +} + +// Test helper function to run "DetectionProjectionCalculator". 
+absl::StatusOr RunProjectionCalculator( + Detection detection, std::array project_mat) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionProjectionCalculator" + input_stream: "DETECTIONS:detections" + input_stream: "PROJECTION_MATRIX:matrix" + output_stream: "DETECTIONS:projected_detections" + )")); + + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back(MakePacket>( + std::vector({std::move(detection)})) + .At(Timestamp::PostStream())); + runner.MutableInputs() + ->Tag("PROJECTION_MATRIX") + .packets.push_back( + MakePacket>(std::move(project_mat)) + .At(Timestamp::PostStream())); + + MP_RETURN_IF_ERROR(runner.Run()); + const std::vector& output = + runner.Outputs().Tag("DETECTIONS").packets; + RET_CHECK_EQ(output.size(), 1); + const auto& output_detections = output[0].Get>(); + + RET_CHECK_EQ(output_detections.size(), 1); + return output_detections[0]; +} + +TEST(DetectionProjectionCalculatorTest, ProjectionFullRoiNoOp) { + Detection detection; + auto* location_data = detection.mutable_location_data(); + location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); + location_data->mutable_relative_bounding_box()->set_xmin(0.0f); + location_data->mutable_relative_bounding_box()->set_ymin(0.0f); + location_data->mutable_relative_bounding_box()->set_width(0.5f); + location_data->mutable_relative_bounding_box()->set_height(0.5f); + + auto* kp = location_data->add_relative_keypoints(); + kp->set_x(0.25f); + kp->set_y(0.25f); + + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.0f); + roi.set_height(1.0f); + roi.set_rotation(0.0f); + + constexpr int kImageWidth = 100; + constexpr int kImageHeight = 100; + + RotatedRect rect; + rect.center_x = roi.x_center() * kImageWidth; + rect.center_y = roi.y_center() * kImageHeight; + rect.width = roi.width() * kImageWidth; + rect.height = roi.height() * kImageHeight; + rect.rotation = roi.rotation(); + + std::array projection_matrix; + GetRotatedSubRectToRectTransformMatrix(rect, kImageWidth, kImageHeight, + /*flip_horizontaly=*/false, + &projection_matrix); + + auto status_or_result = RunProjectionCalculator(std::move(detection), + std::move(projection_matrix)); + MP_ASSERT_OK(status_or_result); + const auto& result = status_or_result.value(); + ASSERT_EQ(result.location_data().format(), + LocationData::RELATIVE_BOUNDING_BOX); + EXPECT_THAT(result.location_data().relative_bounding_box(), + BoundingBoxEq(0.0f, 0.0f, 0.5f, 0.5f)); + EXPECT_THAT(GetPoints(result), testing::ElementsAre(PointEq(0.25f, 0.25f))); +} + +TEST(DetectionProjectionCalculatorTest, ProjectionFullRoi90Rotation) { + Detection detection; + auto* location_data = detection.mutable_location_data(); + location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); + location_data->mutable_relative_bounding_box()->set_xmin(0.0f); + location_data->mutable_relative_bounding_box()->set_ymin(0.0f); + location_data->mutable_relative_bounding_box()->set_width(0.5f); + location_data->mutable_relative_bounding_box()->set_height(0.5f); + + auto* kp = location_data->add_relative_keypoints(); + kp->set_x(0.25f); + kp->set_y(0.25f); + + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.0f); + roi.set_height(1.0f); + roi.set_rotation(90 * M_PI / 180.0f); + + constexpr int kImageWidth = 100; + constexpr int kImageHeight = 100; + + RotatedRect rect; + rect.center_x = roi.x_center() * kImageWidth; + rect.center_y = roi.y_center() * kImageHeight; + rect.width = 
roi.width() * kImageWidth; + rect.height = roi.height() * kImageHeight; + rect.rotation = roi.rotation(); + + std::array projection_matrix; + GetRotatedSubRectToRectTransformMatrix(rect, kImageWidth, kImageHeight, + /*flip_horizontaly=*/false, + &projection_matrix); + + auto status_or_result = RunProjectionCalculator(std::move(detection), + std::move(projection_matrix)); + MP_ASSERT_OK(status_or_result); + const auto& result = status_or_result.value(); + ASSERT_EQ(result.location_data().format(), + LocationData::RELATIVE_BOUNDING_BOX); + EXPECT_THAT(result.location_data().relative_bounding_box(), + BoundingBoxEq(0.5f, 0.0f, 0.5f, 0.5f)); + EXPECT_THAT(GetPoints(result), ElementsAre(PointEq(0.75f, 0.25f))); +} + +TEST(DetectionProjectionCalculatorTest, ProjectionSmallerRoi) { + Detection detection; + auto* location_data = detection.mutable_location_data(); + location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); + location_data->mutable_relative_bounding_box()->set_xmin(0.5f); + location_data->mutable_relative_bounding_box()->set_ymin(0.0f); + location_data->mutable_relative_bounding_box()->set_width(0.5f); + location_data->mutable_relative_bounding_box()->set_height(0.5f); + + auto* kp = location_data->add_relative_keypoints(); + kp->set_x(0.5f); + kp->set_y(0.5f); + + mediapipe::NormalizedRect roi; + roi.set_x_center(0.75f); + roi.set_y_center(0.75f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(0.0f); + + constexpr int kImageWidth = 100; + constexpr int kImageHeight = 100; + + RotatedRect rect; + rect.center_x = roi.x_center() * kImageWidth; + rect.center_y = roi.y_center() * kImageHeight; + rect.width = roi.width() * kImageWidth; + rect.height = roi.height() * kImageHeight; + rect.rotation = roi.rotation(); + + std::array projection_matrix; + GetRotatedSubRectToRectTransformMatrix(rect, kImageWidth, kImageHeight, + /*flip_horizontaly=*/false, + &projection_matrix); + + auto status_or_result = RunProjectionCalculator(std::move(detection), + std::move(projection_matrix)); + MP_ASSERT_OK(status_or_result); + const auto& result = status_or_result.value(); + ASSERT_EQ(result.location_data().format(), + LocationData::RELATIVE_BOUNDING_BOX); + EXPECT_THAT(result.location_data().relative_bounding_box(), + BoundingBoxEq(0.75f, 0.5f, 0.25f, 0.25f)); + EXPECT_THAT(GetPoints(result), ElementsAre(PointEq(0.75f, 0.75f))); +} + +TEST(DetectionProjectionCalculatorTest, ProjectionSmallerRoi30Rotation) { + constexpr float kImageWidth = 80; + constexpr float kImageHeight = 120; + constexpr float kRectWidth = 50; + constexpr float kRectHeight = 30; + constexpr float kRectXCenter = 65; + constexpr float kRectYCenter = 85; + constexpr float kRectRotation = 30 * M_PI / 180.0f; + + Detection detection; + auto* location_data = detection.mutable_location_data(); + location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); + location_data->mutable_relative_bounding_box()->set_xmin(0.0f); + location_data->mutable_relative_bounding_box()->set_ymin(0.0f); + location_data->mutable_relative_bounding_box()->set_width(1.0f); + location_data->mutable_relative_bounding_box()->set_height(1.0f); + // Expected box values were calculated manually from image. 
+ constexpr float kExpectedBoxXMin = 35.849f / kImageWidth; + constexpr float kExpectedBoxYMin = 59.510f / kImageHeight; + constexpr float kExpectedBoxWidth = 58.301f / kImageWidth; + constexpr float kExpectedBoxHeight = 50.981f / kImageHeight; + + auto* kp1 = location_data->add_relative_keypoints(); + kp1->set_x(0.0f); + kp1->set_y(0.0f); + auto* kp2 = location_data->add_relative_keypoints(); + kp2->set_x(0.5f); + kp2->set_y(0.5f); + auto* kp3 = location_data->add_relative_keypoints(); + kp3->set_x(1.0f); + kp3->set_y(0.0f); + // Expected key points were calculated manually from image. + constexpr float kExpectedPoint1X = 50.85f / kImageWidth; + constexpr float kExpectedPoint1Y = 59.52f / kImageHeight; + constexpr float kExpectedPoint2X = kRectXCenter / kImageWidth; + constexpr float kExpectedPoint2Y = kRectYCenter / kImageHeight; + constexpr float kExpectedPoint3X = 94.15f / kImageWidth; + constexpr float kExpectedPoint3Y = 84.51f / kImageHeight; + + mediapipe::NormalizedRect roi; + roi.set_x_center(kRectXCenter / kImageWidth); + roi.set_y_center(kRectYCenter / kImageHeight); + roi.set_width(kRectWidth / kImageWidth); + roi.set_height(kRectHeight / kImageHeight); + roi.set_rotation(kRectRotation); + + RotatedRect rect; + rect.center_x = roi.x_center() * kImageWidth; + rect.center_y = roi.y_center() * kImageHeight; + rect.width = roi.width() * kImageWidth; + rect.height = roi.height() * kImageHeight; + rect.rotation = roi.rotation(); + + std::array projection_matrix; + GetRotatedSubRectToRectTransformMatrix(rect, kImageWidth, kImageHeight, + /*flip_horizontaly=*/false, + &projection_matrix); + + auto status_or_result = RunProjectionCalculator(std::move(detection), + std::move(projection_matrix)); + MP_ASSERT_OK(status_or_result); + const auto& result = status_or_result.value(); + ASSERT_EQ(result.location_data().format(), + LocationData::RELATIVE_BOUNDING_BOX); + EXPECT_THAT(result.location_data().relative_bounding_box(), + BoundingBoxEq(kExpectedBoxXMin, kExpectedBoxYMin, + kExpectedBoxWidth, kExpectedBoxHeight)); + EXPECT_THAT(GetPoints(result), + ElementsAre(PointEq(kExpectedPoint1X, kExpectedPoint1Y), + PointEq(kExpectedPoint2X, kExpectedPoint2Y), + PointEq(kExpectedPoint3X, kExpectedPoint3Y))); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_to_landmarks_calculator.cc b/mediapipe/calculators/util/detection_to_landmarks_calculator.cc new file mode 100644 index 000000000..549298bad --- /dev/null +++ b/mediapipe/calculators/util/detection_to_landmarks_calculator.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" + +namespace mediapipe { + +namespace { + +constexpr char kDetectionTag[] = "DETECTION"; +constexpr char kLandmarksTag[] = "LANDMARKS"; + +absl::Status ConvertDetectionToLandmarks(const Detection& detection, + NormalizedLandmarkList* landmarks) { + const auto& location_data = detection.location_data(); + for (int i = 0; i < location_data.relative_keypoints_size(); ++i) { + const auto& keypoint = location_data.relative_keypoints(i); + + auto* landmark = landmarks->add_landmark(); + landmark->set_x(keypoint.x()); + landmark->set_y(keypoint.y()); + } + + return absl::OkStatus(); +} + +} // namespace + +// Converts a detection into a normalized landmark list by extracting the +// location data relative keypoints as landmarks. +// +// Input: +// DETECTION - `Detection` +// A detection to be converted. +// +// Output: +// LANDMARKS - `NormalizedLandmarkList` +// A converted normalized landmark list. +// +// Example: +// +// node { +// calculator: "DetectionToLandmarksCalculator" +// input_stream: "DETECTION:detection" +// output_stream: "LANDMARKS:landmarks" +// } +// +class DetectionToLandmarksCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kDetectionTag)); + RET_CHECK(cc->Outputs().HasTag(kLandmarksTag)); + + cc->Inputs().Tag(kDetectionTag).Set(); + cc->Outputs().Tag(kLandmarksTag).Set(); + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + const auto& detection = cc->Inputs().Tag(kDetectionTag).Get(); + + auto landmarks = absl::make_unique(); + MP_RETURN_IF_ERROR(ConvertDetectionToLandmarks(detection, landmarks.get())); + + cc->Outputs() + .Tag(kLandmarksTag) + .Add(landmarks.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } +}; + +REGISTER_CALCULATOR(DetectionToLandmarksCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_unique_id_calculator.cc b/mediapipe/calculators/util/detection_unique_id_calculator.cc index 2069f1677..ac8889ffb 100644 --- a/mediapipe/calculators/util/detection_unique_id_calculator.cc +++ b/mediapipe/calculators/util/detection_unique_id_calculator.cc @@ -44,7 +44,7 @@ inline int GetNextDetectionId() { return ++detection_id; } // } class DetectionUniqueIdCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) || cc->Inputs().HasTag(kDetectionsTag)) << "None of the input streams are provided."; @@ -60,25 +60,24 @@ class DetectionUniqueIdCalculator : public CalculatorBase { cc->Outputs().Tag(kDetectionsTag).Set>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { - cc->SetOffset(::mediapipe::TimestampDiff(0)); - return ::mediapipe::OkStatus(); + absl::Status Open(CalculatorContext* cc) override { + cc->SetOffset(mediapipe::TimestampDiff(0)); + return 
absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(DetectionUniqueIdCalculator); -::mediapipe::Status DetectionUniqueIdCalculator::Process( - CalculatorContext* cc) { +absl::Status DetectionUniqueIdCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kDetectionListTag) && !cc->Inputs().Tag(kDetectionListTag).IsEmpty()) { auto result = cc->Inputs().Tag(kDetectionListTag).Value().Consume(); if (result.ok()) { - auto detection_list = std::move(result).ValueOrDie(); + auto detection_list = std::move(result).value(); for (Detection& detection : *detection_list->mutable_detection()) { detection.set_detection_id(GetNextDetectionId()); } @@ -95,7 +94,7 @@ REGISTER_CALCULATOR(DetectionUniqueIdCalculator); .Value() .Consume>(); if (result.ok()) { - auto detections = std::move(result).ValueOrDie(); + auto detections = std::move(result).value(); for (Detection& detection : *detections) { detection.set_detection_id(GetNextDetectionId()); } @@ -104,7 +103,7 @@ REGISTER_CALCULATOR(DetectionUniqueIdCalculator); .Add(detections.release(), cc->InputTimestamp()); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc index 52ba9dd7a..29836cb59 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc @@ -14,6 +14,7 @@ #include "mediapipe/calculators/util/detections_to_rects_calculator.h" #include +#include #include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -36,40 +37,97 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr float kMinFloat = std::numeric_limits::lowest(); +constexpr float kMaxFloat = std::numeric_limits::max(); + +absl::Status NormRectFromKeyPoints(const LocationData& location_data, + NormalizedRect* rect) { + RET_CHECK_GT(location_data.relative_keypoints_size(), 1) + << "2 or more key points required to calculate a rect."; + float xmin = kMaxFloat; + float ymin = kMaxFloat; + float xmax = kMinFloat; + float ymax = kMinFloat; + for (int i = 0; i < location_data.relative_keypoints_size(); ++i) { + const auto& kp = location_data.relative_keypoints(i); + xmin = std::min(xmin, kp.x()); + ymin = std::min(ymin, kp.y()); + xmax = std::max(xmax, kp.x()); + ymax = std::max(ymax, kp.y()); + } + rect->set_x_center((xmin + xmax) / 2); + rect->set_y_center((ymin + ymax) / 2); + rect->set_width(xmax - xmin); + rect->set_height(ymax - ymin); + return absl::OkStatus(); +} + +template +void RectFromBox(B box, R* rect) { + rect->set_x_center(box.xmin() + box.width() / 2); + rect->set_y_center(box.ymin() + box.height() / 2); + rect->set_width(box.width()); + rect->set_height(box.height()); +} + } // namespace -::mediapipe::Status DetectionsToRectsCalculator::DetectionToRect( +absl::Status DetectionsToRectsCalculator::DetectionToRect( const Detection& detection, const DetectionSpec& detection_spec, Rect* rect) { const LocationData location_data = detection.location_data(); - RET_CHECK(location_data.format() == LocationData::BOUNDING_BOX) - << "Only Detection with formats of BOUNDING_BOX can be converted to Rect"; - const LocationData::BoundingBox bounding_box = 
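A quick illustration, not part of the patch: in the new USE_KEYPOINTS mode, NormRectFromKeyPoints() above derives the rect as the tight bounding box of all relative keypoints (the calculator RET_CHECKs that at least two keypoints are present); for the absolute RECT output the normalized values are then scaled by the image size and rounded. Below is a minimal standalone sketch of that conversion; NormRect and RectFromKeyPoints are hypothetical names used only for illustration.

#include <algorithm>
#include <utility>
#include <vector>

struct NormRect { float x_center, y_center, width, height; };

// Tight normalized bounding rect around relative keypoints. The sketch assumes
// a non-empty vector; the calculator itself requires at least two keypoints.
NormRect RectFromKeyPoints(const std::vector<std::pair<float, float>>& kps) {
  float xmin = kps[0].first, xmax = xmin, ymin = kps[0].second, ymax = ymin;
  for (const auto& kp : kps) {
    xmin = std::min(xmin, kp.first);
    xmax = std::max(xmax, kp.first);
    ymin = std::min(ymin, kp.second);
    ymax = std::max(ymax, kp.second);
  }
  return {(xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin};
}

// For an absolute Rect the same values are scaled by the image size and
// rounded, e.g. x_center_px = std::round(r.x_center * image_width).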
location_data.bounding_box(); - rect->set_x_center(bounding_box.xmin() + bounding_box.width() / 2); - rect->set_y_center(bounding_box.ymin() + bounding_box.height() / 2); - rect->set_width(bounding_box.width()); - rect->set_height(bounding_box.height()); - return ::mediapipe::OkStatus(); + switch (options_.conversion_mode()) { + case mediapipe::DetectionsToRectsCalculatorOptions_ConversionMode_DEFAULT: + case mediapipe:: + DetectionsToRectsCalculatorOptions_ConversionMode_USE_BOUNDING_BOX: { + RET_CHECK(location_data.format() == LocationData::BOUNDING_BOX) + << "Only Detection with formats of BOUNDING_BOX can be converted to " + "Rect"; + RectFromBox(location_data.bounding_box(), rect); + break; + } + case mediapipe:: + DetectionsToRectsCalculatorOptions_ConversionMode_USE_KEYPOINTS: { + RET_CHECK(detection_spec.image_size.has_value()) + << "Rect with absolute coordinates calculation requires image size."; + const int width = detection_spec.image_size->first; + const int height = detection_spec.image_size->second; + NormalizedRect norm_rect; + MP_RETURN_IF_ERROR(NormRectFromKeyPoints(location_data, &norm_rect)); + rect->set_x_center(std::round(norm_rect.x_center() * width)); + rect->set_y_center(std::round(norm_rect.y_center() * height)); + rect->set_width(std::round(norm_rect.width() * width)); + rect->set_height(std::round(norm_rect.height() * height)); + break; + } + } + return absl::OkStatus(); } -::mediapipe::Status DetectionsToRectsCalculator::DetectionToNormalizedRect( +absl::Status DetectionsToRectsCalculator::DetectionToNormalizedRect( const Detection& detection, const DetectionSpec& detection_spec, NormalizedRect* rect) { const LocationData location_data = detection.location_data(); - RET_CHECK(location_data.format() == LocationData::RELATIVE_BOUNDING_BOX) - << "Only Detection with formats of RELATIVE_BOUNDING_BOX can be " - "converted to NormalizedRect"; - const LocationData::RelativeBoundingBox bounding_box = - location_data.relative_bounding_box(); - rect->set_x_center(bounding_box.xmin() + bounding_box.width() / 2); - rect->set_y_center(bounding_box.ymin() + bounding_box.height() / 2); - rect->set_width(bounding_box.width()); - rect->set_height(bounding_box.height()); - return ::mediapipe::OkStatus(); + switch (options_.conversion_mode()) { + case mediapipe::DetectionsToRectsCalculatorOptions_ConversionMode_DEFAULT: + case mediapipe:: + DetectionsToRectsCalculatorOptions_ConversionMode_USE_BOUNDING_BOX: { + RET_CHECK(location_data.format() == LocationData::RELATIVE_BOUNDING_BOX) + << "Only Detection with formats of RELATIVE_BOUNDING_BOX can be " + "converted to NormalizedRect"; + RectFromBox(location_data.relative_bounding_box(), rect); + break; + } + case mediapipe:: + DetectionsToRectsCalculatorOptions_ConversionMode_USE_KEYPOINTS: { + MP_RETURN_IF_ERROR(NormRectFromKeyPoints(location_data, rect)); + break; + } + } + return absl::OkStatus(); } -::mediapipe::Status DetectionsToRectsCalculator::GetContract( - CalculatorContract* cc) { +absl::Status DetectionsToRectsCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionTag) ^ cc->Inputs().HasTag(kDetectionsTag)) << "Exactly one of DETECTION or DETECTIONS input stream should be " @@ -105,10 +163,10 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; cc->Outputs().Tag(kNormRectsTag).Set>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status DetectionsToRectsCalculator::Open(CalculatorContext* cc) { +absl::Status 
DetectionsToRectsCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); @@ -133,18 +191,17 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; output_zero_rect_for_empty_detections_ = options_.output_zero_rect_for_empty_detections(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status DetectionsToRectsCalculator::Process( - CalculatorContext* cc) { +absl::Status DetectionsToRectsCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kDetectionTag) && cc->Inputs().Tag(kDetectionTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (cc->Inputs().HasTag(kDetectionsTag) && cc->Inputs().Tag(kDetectionsTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } std::vector detections; @@ -172,7 +229,7 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; .Add(rect_vector.release(), cc->InputTimestamp()); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -239,10 +296,10 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; .Add(output_rects.release(), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status DetectionsToRectsCalculator::ComputeRotation( +absl::Status DetectionsToRectsCalculator::ComputeRotation( const Detection& detection, const DetectionSpec& detection_spec, float* rotation) { const auto& location_data = detection.location_data(); @@ -260,7 +317,7 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; *rotation = NormalizeRadians(target_angle_ - std::atan2(-(y1 - y0), x1 - x0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } DetectionSpec DetectionsToRectsCalculator::GetDetectionSpec( diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.h b/mediapipe/calculators/util/detections_to_rects_calculator.h index 7fb26895e..e91441bc6 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.h +++ b/mediapipe/calculators/util/detections_to_rects_calculator.h @@ -83,21 +83,21 @@ struct DetectionSpec { // } class DetectionsToRectsCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; protected: - virtual ::mediapipe::Status DetectionToRect( - const ::mediapipe::Detection& detection, - const DetectionSpec& detection_spec, ::mediapipe::Rect* rect); - virtual ::mediapipe::Status DetectionToNormalizedRect( + virtual absl::Status DetectionToRect(const ::mediapipe::Detection& detection, + const DetectionSpec& detection_spec, + ::mediapipe::Rect* rect); + virtual absl::Status DetectionToNormalizedRect( const ::mediapipe::Detection& detection, const DetectionSpec& detection_spec, ::mediapipe::NormalizedRect* rect); - virtual ::mediapipe::Status ComputeRotation( - const ::mediapipe::Detection& detection, - const DetectionSpec& detection_spec, float* rotation); + virtual absl::Status ComputeRotation(const ::mediapipe::Detection& detection, + const DetectionSpec& detection_spec, + float* rotation); virtual DetectionSpec GetDetectionSpec(const CalculatorContext* cc); static inline float NormalizeRadians(float angle) { diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.proto 
b/mediapipe/calculators/util/detections_to_rects_calculator.proto index 8d1a49a1e..d49eb6c52 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.proto +++ b/mediapipe/calculators/util/detections_to_rects_calculator.proto @@ -35,4 +35,12 @@ message DetectionsToRectsCalculatorOptions { // Whether to output a zero-rect (with origin and size both zero) when the // input detection vector is empty. optional bool output_zero_rect_for_empty_detections = 5; + + enum ConversionMode { + DEFAULT = 0; + USE_BOUNDING_BOX = 1; + USE_KEYPOINTS = 2; + } + + optional ConversionMode conversion_mode = 6; } diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index 7281847ca..85c2bd72f 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include + #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" @@ -26,6 +30,21 @@ #include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { +namespace { + +MATCHER_P4(RectEq, x_center, y_center, width, height, "") { + return testing::Value(arg.x_center(), testing::Eq(x_center)) && + testing::Value(arg.y_center(), testing::Eq(y_center)) && + testing::Value(arg.width(), testing::Eq(width)) && + testing::Value(arg.height(), testing::Eq(height)); +} + +MATCHER_P4(NormRectEq, x_center, y_center, width, height, "") { + return testing::Value(arg.x_center(), testing::FloatEq(x_center)) && + testing::Value(arg.y_center(), testing::FloatEq(y_center)) && + testing::Value(arg.width(), testing::FloatEq(width)) && + testing::Value(arg.height(), testing::FloatEq(height)); +} Detection DetectionWithLocationData(int32 xmin, int32 ymin, int32 width, int32 height) { @@ -39,6 +58,19 @@ Detection DetectionWithLocationData(int32 xmin, int32 ymin, int32 width, return detection; } +Detection DetectionWithKeyPoints( + const std::vector>& key_points) { + Detection detection; + LocationData* location_data = detection.mutable_location_data(); + std::for_each(key_points.begin(), key_points.end(), + [location_data](std::pair kp) { + auto* new_kp = location_data->add_relative_keypoints(); + new_kp->set_x(kp.first); + new_kp->set_y(kp.second); + }); + return detection; +} + Detection DetectionWithRelativeLocationData(double xmin, double ymin, double width, double height) { Detection detection; @@ -70,10 +102,61 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRect) { const std::vector& output = runner.Outputs().Tag("RECT").packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); - EXPECT_EQ(rect.width(), 300); - EXPECT_EQ(rect.height(), 400); - EXPECT_EQ(rect.x_center(), 250); - EXPECT_EQ(rect.y_center(), 400); + EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); +} + +absl::StatusOr RunDetectionKeyPointsToRectCalculation( + Detection detection, std::pair image_size) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:detection" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "RECT:rect" + options: { + [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + conversion_mode: USE_KEYPOINTS + } + } + )")); + + runner.MutableInputs() + ->Tag("DETECTION") 
+ .packets.push_back(MakePacket(std::move(detection)) + .At(Timestamp::PostStream())); + runner.MutableInputs() + ->Tag("IMAGE_SIZE") + .packets.push_back(MakePacket>(image_size) + .At(Timestamp::PostStream())); + + MP_RETURN_IF_ERROR(runner.Run()); + const std::vector& output = runner.Outputs().Tag("RECT").packets; + RET_CHECK_EQ(output.size(), 1); + return output[0].Get(); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionKeyPointsToRect) { + auto status_or_value = RunDetectionKeyPointsToRectCalculation( + /*detection=*/DetectionWithKeyPoints({{0.0f, 0.0f}, {1.0f, 1.0f}}), + /*image_size=*/{640, 480}); + EXPECT_THAT(status_or_value.value(), RectEq(320, 240, 640, 480)); + + status_or_value = RunDetectionKeyPointsToRectCalculation( + /*detection=*/DetectionWithKeyPoints({{0.25f, 0.25f}, {0.75f, 0.75f}}), + /*image_size=*/{640, 480}); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), RectEq(320, 240, 320, 240)); + + status_or_value = RunDetectionKeyPointsToRectCalculation( + /*detection=*/DetectionWithKeyPoints({{0.0f, 0.0f}, {0.5f, 0.5f}}), + /*image_size=*/{640, 480}); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), RectEq(160, 120, 320, 240)); + + status_or_value = RunDetectionKeyPointsToRectCalculation( + /*detection=*/DetectionWithKeyPoints({{0.5f, 0.5f}, {1.0f, 1.0f}}), + /*image_size=*/{640, 480}); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), RectEq(480, 360, 320, 240)); } TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) { @@ -95,10 +178,56 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) { const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); - EXPECT_FLOAT_EQ(rect.width(), 0.3); - EXPECT_FLOAT_EQ(rect.height(), 0.4); - EXPECT_FLOAT_EQ(rect.x_center(), 0.25); - EXPECT_FLOAT_EQ(rect.y_center(), 0.4); + EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); +} + +absl::StatusOr RunDetectionKeyPointsToNormRectCalculation( + Detection detection) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:detection" + output_stream: "NORM_RECT:rect" + options: { + [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + conversion_mode: USE_KEYPOINTS + } + } + )")); + + runner.MutableInputs() + ->Tag("DETECTION") + .packets.push_back(MakePacket(std::move(detection)) + .At(Timestamp::PostStream())); + + MP_RETURN_IF_ERROR(runner.Run()); + const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + RET_CHECK_EQ(output.size(), 1); + return output[0].Get(); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionKeyPointsToNormalizedRect) { + NormalizedRect rect; + + auto status_or_value = RunDetectionKeyPointsToNormRectCalculation( + /*detection=*/DetectionWithKeyPoints( + {{0.0f, 0.0f}, {0.5f, 0.5f}, {1.0f, 1.0f}})); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), RectEq(0.5f, 0.5f, 1.0f, 1.0f)); + + status_or_value = RunDetectionKeyPointsToNormRectCalculation( + /*detection=*/DetectionWithKeyPoints( + {{0.25f, 0.25f}, {0.75f, 0.25f}, {0.75f, 0.75f}})); + EXPECT_THAT(status_or_value.value(), RectEq(0.5f, 0.5f, 0.5f, 0.5f)); + + status_or_value = RunDetectionKeyPointsToNormRectCalculation( + /*detection=*/DetectionWithKeyPoints({{0.0f, 0.0f}, {0.5f, 0.5f}})); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), RectEq(0.25f, 0.25f, 0.5f, 0.5f)); + + status_or_value = 
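Note: the USE_KEYPOINTS expectations in the tests above follow from a small amount of geometry. A minimal sketch of that conversion, with a hypothetical helper name and struct (the real logic lives inside DetectionsToRectsCalculator):

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Illustrative only: the rect is the axis-aligned bounding box of the
// detection's relative keypoints (values in [0, 1]), scaled by the image size.
struct SimpleRect {
  int x_center;
  int y_center;
  int width;
  int height;
};

SimpleRect KeyPointsToRect(const std::vector<std::pair<float, float>>& kps,
                           int image_width, int image_height) {
  float x_min = 1.0f, x_max = 0.0f, y_min = 1.0f, y_max = 0.0f;
  for (const auto& kp : kps) {
    x_min = std::min(x_min, kp.first);
    x_max = std::max(x_max, kp.first);
    y_min = std::min(y_min, kp.second);
    y_max = std::max(y_max, kp.second);
  }
  SimpleRect rect;
  rect.width = static_cast<int>((x_max - x_min) * image_width);
  rect.height = static_cast<int>((y_max - y_min) * image_height);
  rect.x_center = static_cast<int>((x_min + x_max) / 2 * image_width);
  rect.y_center = static_cast<int>((y_min + y_max) / 2 * image_height);
  return rect;
}
// Keypoints {{0.0, 0.0}, {0.5, 0.5}} on a 640x480 image give x_center=160,
// y_center=120, width=320, height=240, matching the test expectation above.
```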
RunDetectionKeyPointsToNormRectCalculation( + /*detection=*/DetectionWithKeyPoints({{0.5f, 0.5f}, {1.0f, 1.0f}})); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), RectEq(0.75f, 0.75f, 0.5f, 0.5f)); } TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) { @@ -121,10 +250,7 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) { const std::vector& output = runner.Outputs().Tag("RECT").packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); - EXPECT_EQ(rect.width(), 300); - EXPECT_EQ(rect.height(), 400); - EXPECT_EQ(rect.x_center(), 250); - EXPECT_EQ(rect.y_center(), 400); + EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); } TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) { @@ -147,10 +273,7 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) { const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); - EXPECT_FLOAT_EQ(rect.width(), 0.3); - EXPECT_FLOAT_EQ(rect.height(), 0.4); - EXPECT_FLOAT_EQ(rect.x_center(), 0.25); - EXPECT_FLOAT_EQ(rect.y_center(), 0.4); + EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); } TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) { @@ -173,15 +296,9 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) { const std::vector& output = runner.Outputs().Tag("RECTS").packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); - EXPECT_EQ(rects.size(), 2); - EXPECT_EQ(rects[0].width(), 300); - EXPECT_EQ(rects[0].height(), 400); - EXPECT_EQ(rects[0].x_center(), 250); - EXPECT_EQ(rects[0].y_center(), 400); - EXPECT_EQ(rects[1].width(), 400); - EXPECT_EQ(rects[1].height(), 500); - EXPECT_EQ(rects[1].x_center(), 400); - EXPECT_EQ(rects[1].y_center(), 550); + ASSERT_EQ(rects.size(), 2); + EXPECT_THAT(rects[0], RectEq(250, 400, 300, 400)); + EXPECT_THAT(rects[1], RectEq(400, 550, 400, 500)); } TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) { @@ -205,15 +322,9 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) { runner.Outputs().Tag("NORM_RECTS").packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); - EXPECT_EQ(rects.size(), 2); - EXPECT_FLOAT_EQ(rects[0].width(), 0.3); - EXPECT_FLOAT_EQ(rects[0].height(), 0.4); - EXPECT_FLOAT_EQ(rects[0].x_center(), 0.25); - EXPECT_FLOAT_EQ(rects[0].y_center(), 0.4); - EXPECT_FLOAT_EQ(rects[1].width(), 0.4); - EXPECT_FLOAT_EQ(rects[1].height(), 0.5); - EXPECT_FLOAT_EQ(rects[1].x_center(), 0.4); - EXPECT_FLOAT_EQ(rects[1].y_center(), 0.55); + ASSERT_EQ(rects.size(), 2); + EXPECT_THAT(rects[0], NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); + EXPECT_THAT(rects[1], NormRectEq(0.4f, 0.55f, 0.4f, 0.5f)); } TEST(DetectionsToRectsCalculatorTest, DetectionToRects) { @@ -236,10 +347,7 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRects) { ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); EXPECT_EQ(rects.size(), 1); - EXPECT_EQ(rects[0].width(), 300); - EXPECT_EQ(rects[0].height(), 400); - EXPECT_EQ(rects[0].x_center(), 250); - EXPECT_EQ(rects[0].y_center(), 400); + EXPECT_THAT(rects[0], RectEq(250, 400, 300, 400)); } TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) { @@ -262,11 +370,8 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) { runner.Outputs().Tag("NORM_RECTS").packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); - EXPECT_EQ(rects.size(), 1); - EXPECT_FLOAT_EQ(rects[0].width(), 0.3); - 
EXPECT_FLOAT_EQ(rects[0].height(), 0.4); - EXPECT_FLOAT_EQ(rects[0].x_center(), 0.25); - EXPECT_FLOAT_EQ(rects[0].y_center(), 0.4); + ASSERT_EQ(rects.size(), 1); + EXPECT_THAT(rects[0], NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); } TEST(DetectionsToRectsCalculatorTest, WrongInputToRect) { @@ -309,4 +414,5 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToNormalizedRect) { "Only Detection with formats of RELATIVE_BOUNDING_BOX")); } +} // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator.cc b/mediapipe/calculators/util/detections_to_render_data_calculator.cc index 5082cd363..25d74ba68 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator.cc @@ -82,11 +82,11 @@ class DetectionsToRenderDataCalculator : public CalculatorBase { DetectionsToRenderDataCalculator& operator=( const DetectionsToRenderDataCalculator&) = delete; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // These utility methods are supposed to be used only by this class. No @@ -122,7 +122,7 @@ class DetectionsToRenderDataCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); -::mediapipe::Status DetectionsToRenderDataCalculator::GetContract( +absl::Status DetectionsToRenderDataCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) || cc->Inputs().HasTag(kDetectionsTag) || @@ -139,18 +139,16 @@ REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); cc->Inputs().Tag(kDetectionsTag).Set>(); } cc->Outputs().Tag(kRenderDataTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status DetectionsToRenderDataCalculator::Open( - CalculatorContext* cc) { +absl::Status DetectionsToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status DetectionsToRenderDataCalculator::Process( - CalculatorContext* cc) { +absl::Status DetectionsToRenderDataCalculator::Process(CalculatorContext* cc) { const auto& options = cc->Options(); const bool has_detection_from_list = cc->Inputs().HasTag(kDetectionListTag) && !cc->Inputs() @@ -165,7 +163,7 @@ REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); !cc->Inputs().Tag(kDetectionTag).IsEmpty(); if (!options.produce_empty_packet() && !has_detection_from_list && !has_detection_from_vector && !has_single_detection) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // TODO: Add score threshold to @@ -191,7 +189,7 @@ REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); cc->Outputs() .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void DetectionsToRenderDataCalculator::SetRenderAnnotationColorThickness( diff --git a/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc b/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc index b0a177e58..4b4742b18 100644 --- a/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc +++ 
b/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc @@ -42,7 +42,7 @@ constexpr char kBoxesTag[] = "BOXES"; // } class DetectionsToTimedBoxListCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) || cc->Inputs().HasTag(kDetectionsTag)) << "None of the input streams are provided."; @@ -53,14 +53,14 @@ class DetectionsToTimedBoxListCalculator : public CalculatorBase { cc->Inputs().Tag(kDetectionsTag).Set>(); } cc->Outputs().Tag(kBoxesTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: void ConvertDetectionToTimedBox(const Detection& detection, @@ -68,7 +68,7 @@ class DetectionsToTimedBoxListCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(DetectionsToTimedBoxListCalculator); -::mediapipe::Status DetectionsToTimedBoxListCalculator::Process( +absl::Status DetectionsToTimedBoxListCalculator::Process( CalculatorContext* cc) { auto output_timed_box_list = absl::make_unique(); @@ -91,7 +91,7 @@ REGISTER_CALCULATOR(DetectionsToTimedBoxListCalculator); cc->Outputs().Tag(kBoxesTag).Add(output_timed_box_list.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void DetectionsToTimedBoxListCalculator::ConvertDetectionToTimedBox( diff --git a/mediapipe/calculators/util/filter_collection_calculator.cc b/mediapipe/calculators/util/filter_collection_calculator.cc index 356b03dd6..690ca2a93 100644 --- a/mediapipe/calculators/util/filter_collection_calculator.cc +++ b/mediapipe/calculators/util/filter_collection_calculator.cc @@ -20,9 +20,14 @@ #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/integral_types.h" namespace mediapipe { +typedef FilterCollectionCalculator> + FilterUInt64CollectionCalculator; +REGISTER_CALCULATOR(FilterUInt64CollectionCalculator); + typedef FilterCollectionCalculator> FilterNormalizedRectCollectionCalculator; REGISTER_CALCULATOR(FilterNormalizedRectCollectionCalculator); diff --git a/mediapipe/calculators/util/filter_collection_calculator.h b/mediapipe/calculators/util/filter_collection_calculator.h index 5f08dd982..60a6255c9 100644 --- a/mediapipe/calculators/util/filter_collection_calculator.h +++ b/mediapipe/calculators/util/filter_collection_calculator.h @@ -42,7 +42,7 @@ namespace mediapipe { template class FilterCollectionCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("ITERABLE")); RET_CHECK(cc->Inputs().HasTag("CONDITION")); RET_CHECK(cc->Outputs().HasTag("ITERABLE")); @@ -52,20 +52,20 @@ class FilterCollectionCalculator : public CalculatorBase { cc->Outputs().Tag("ITERABLE").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { 
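Note: the recurring change across these calculator files is a spelling migration of the status type; mediapipe::Status is an alias of absl::Status, so behavior is unchanged. A minimal skeleton with the post-migration signatures (PassThroughBoolCalculator is a made-up name for illustration, not part of this patch):

```cpp
#include "mediapipe/framework/calculator_framework.h"

// Hypothetical minimal calculator showing the signatures after the
// ::mediapipe::Status -> absl::Status rename; behavior is a bool pass-through.
class PassThroughBoolCalculator : public mediapipe::CalculatorBase {
 public:
  static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<bool>();
    cc->Outputs().Index(0).Set<bool>();
    return absl::OkStatus();
  }
  absl::Status Open(mediapipe::CalculatorContext* cc) override {
    cc->SetOffset(mediapipe::TimestampDiff(0));
    return absl::OkStatus();
  }
  absl::Status Process(mediapipe::CalculatorContext* cc) override {
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return absl::OkStatus();
  }
};
```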
cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().Tag("ITERABLE").IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (cc->Inputs().Tag("CONDITION").IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const std::vector& filter_by = @@ -77,11 +77,11 @@ class FilterCollectionCalculator : public CalculatorBase { } template - ::mediapipe::Status FilterCollection(std::true_type, CalculatorContext* cc, - const std::vector& filter_by) { + absl::Status FilterCollection(std::true_type, CalculatorContext* cc, + const std::vector& filter_by) { const IterableU& input = cc->Inputs().Tag("ITERABLE").Get(); if (input.size() != filter_by.size()) { - return ::mediapipe::InternalError(absl::StrCat( + return absl::InternalError(absl::StrCat( "Input vector size: ", input.size(), " doesn't mach condition vector size: ", filter_by.size())); } @@ -93,14 +93,13 @@ class FilterCollectionCalculator : public CalculatorBase { } } cc->Outputs().Tag("ITERABLE").Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } template - ::mediapipe::Status FilterCollection(std::false_type, CalculatorContext* cc, - const std::vector& filter_by) { - return ::mediapipe::InternalError( - "Cannot copy input collection to filter it."); + absl::Status FilterCollection(std::false_type, CalculatorContext* cc, + const std::vector& filter_by) { + return absl::InternalError("Cannot copy input collection to filter it."); } }; diff --git a/mediapipe/calculators/util/from_image_calculator.cc b/mediapipe/calculators/util/from_image_calculator.cc new file mode 100644 index 000000000..7484d9257 --- /dev/null +++ b/mediapipe/calculators/util/from_image_calculator.cc @@ -0,0 +1,164 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/vector.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_calculator_helper.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +namespace { +constexpr char kImageFrameTag[] = "IMAGE_CPU"; +constexpr char kGpuBufferTag[] = "IMAGE_GPU"; +constexpr char kImageTag[] = "IMAGE"; +} // namespace + +// A calculator for converting the unified image container into +// legacy MediaPipe datatypes. +// +// Inputs: +// IMAGE: An Image containing input image. +// +// Output: +// One of the following two tags: +// IMAGE_CPU: An ImageFrame containing output image. 
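Note: FilterCollectionCalculator above keeps element i of the ITERABLE input only when the parallel CONDITION vector holds true at i, and rejects mismatched lengths. A standalone restatement of that loop (hypothetical helper, not the calculator itself):

```cpp
#include <cstddef>
#include <vector>

// Illustrative equivalent of FilterCollectionCalculator's core loop: both
// vectors must have the same length, and element i survives iff filter_by[i].
template <typename T>
std::vector<T> FilterByCondition(const std::vector<T>& input,
                                 const std::vector<bool>& filter_by) {
  std::vector<T> output;
  if (input.size() != filter_by.size()) {
    return output;  // the real calculator returns an InternalError here
  }
  for (std::size_t i = 0; i < input.size(); ++i) {
    if (filter_by[i]) output.push_back(input[i]);
  }
  return output;
}
```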
+// IMAGE_GPU: A GpuBuffer containing output image. +// +// Note: +// Data is automatically transferred to/from the CPU or GPU +// depending on output type. +// +class FromImageCalculator : public CalculatorBase { + public: + FromImageCalculator() = default; + ~FromImageCalculator() override = default; + + static absl::Status GetContract(CalculatorContract* cc); + + // From Calculator. + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); + + bool gpu_output_ = false; + bool gpu_initialized_ = false; +#if !MEDIAPIPE_DISABLE_GPU + mediapipe::GlCalculatorHelper gpu_helper_; +#endif // !MEDIAPIPE_DISABLE_GPU +}; +REGISTER_CALCULATOR(FromImageCalculator); + +absl::Status FromImageCalculator::GetContract(CalculatorContract* cc) { + cc->Inputs().Tag(kImageTag).Set(); + + bool gpu_output = false; + + if (cc->Outputs().HasTag(kImageFrameTag) && + cc->Outputs().HasTag(kGpuBufferTag)) { + return absl::InternalError("Cannot have multiple outputs."); + } + + if (cc->Outputs().HasTag(kGpuBufferTag)) { +#if !MEDIAPIPE_DISABLE_GPU + cc->Outputs().Tag(kGpuBufferTag).Set(); + gpu_output = true; +#else + RET_CHECK_FAIL() << "GPU is disabled. Cannot use IMAGE_GPU stream."; +#endif // !MEDIAPIPE_DISABLE_GPU + } + if (cc->Outputs().HasTag(kImageFrameTag)) { + cc->Outputs().Tag(kImageFrameTag).Set(); + } + + if (gpu_output) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } + + return absl::OkStatus(); +} + +absl::Status FromImageCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + if (cc->Outputs().HasTag(kGpuBufferTag)) { + gpu_output_ = true; + } + + if (gpu_output_) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif + } // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status FromImageCalculator::Process(CalculatorContext* cc) { + if (gpu_output_) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([&cc]() -> absl::Status { + auto& input = cc->Inputs().Tag(kImageTag).Get(); + // Unwrap texture pointer; shallow copy. + auto output = + std::make_unique(input.GetGpuBuffer()); + cc->Outputs() + .Tag(kGpuBufferTag) + .Add(output.release(), cc->InputTimestamp()); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU + } else { + // The input Image. + auto& input = cc->Inputs().Tag(kImageTag).Get(); + // Make a copy of the input packet to co-own the input Image. + Packet* packet_copy_ptr = new Packet(cc->Inputs().Tag(kImageTag).Value()); + // Create an output ImageFrame that points to the same pixel data as the + // input Image and also owns the packet copy. As a result, the output + // ImageFrame indirectly co-owns the input Image. This ensures a correct + // life span of the shared pixel data. 
+ std::unique_ptr output = + std::make_unique( + input.image_format(), input.width(), input.height(), input.step(), + const_cast(input.GetImageFrameSharedPtr()->PixelData()), + [packet_copy_ptr](uint8*) { delete packet_copy_ptr; }); + cc->Outputs() + .Tag(kImageFrameTag) + .Add(output.release(), cc->InputTimestamp()); + } + + return absl::OkStatus(); +} + +absl::Status FromImageCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index fafedba5b..cf448cff1 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -59,9 +59,9 @@ constexpr float kFontHeightScale = 1.25f; // } class LabelsToRenderDataCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: LabelsToRenderDataCalculatorOptions options_; @@ -73,8 +73,7 @@ class LabelsToRenderDataCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(LabelsToRenderDataCalculator); -::mediapipe::Status LabelsToRenderDataCalculator::GetContract( - CalculatorContract* cc) { +absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("CLASSIFICATIONS")) { cc->Inputs().Tag("CLASSIFICATIONS").Set(); } else { @@ -89,26 +88,25 @@ REGISTER_CALCULATOR(LabelsToRenderDataCalculator); cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); } cc->Outputs().Tag("RENDER_DATA").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) { +absl::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); num_colors_ = options_.color_size(); label_height_px_ = std::ceil(options_.font_height_px() * kFontHeightScale); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LabelsToRenderDataCalculator::Process( - CalculatorContext* cc) { +absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag("VIDEO_PRESTREAM") && cc->InputTimestamp() == Timestamp::PreStream()) { const VideoHeader& video_header = cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); video_width_ = video_header.width; video_height_ = video_header.height; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { CHECK_EQ(options_.location(), LabelsToRenderDataCalculatorOptions::TOP_LEFT) << "Only TOP_LEFT is supported without VIDEO_PRESTREAM."; @@ -180,6 +178,6 @@ REGISTER_CALCULATOR(LabelsToRenderDataCalculator); .Tag("RENDER_DATA") .AddPacket(MakePacket(render_data).At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc b/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc index 925272230..d3c7a6453 100644 --- a/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc +++ b/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc @@ -64,7 
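Note: a possible wiring for the new FromImageCalculator, sketched as a parsed graph config; the stream names and the surrounding function are made up for illustration:

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical graph snippet: convert the unified Image container back to a
// legacy ImageFrame for downstream calculators that still expect CPU frames.
mediapipe::CalculatorGraphConfig MakeFromImageDemoConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "image"
    output_stream: "image_frame"
    node {
      calculator: "FromImageCalculator"
      input_stream: "IMAGE:image"
      output_stream: "IMAGE_CPU:image_frame"
    }
  )pb");
}
```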
+64,7 @@ constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; // } class LandmarkLetterboxRemovalCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) && cc->Inputs().HasTag(kLetterboxPaddingTag)) << "Missing one or more input streams."; @@ -84,18 +84,18 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase { cc->Outputs().Get(id).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().Tag(kLetterboxPaddingTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& letterbox_padding = cc->Inputs().Tag(kLetterboxPaddingTag).Get>(); @@ -124,20 +124,17 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase { const float new_y = (landmark.y() - top) / (1.0f - top_and_bottom); const float new_z = landmark.z() / (1.0f - left_and_right); // Scale Z coordinate as X. - + *new_landmark = landmark; new_landmark->set_x(new_x); new_landmark->set_y(new_y); - // Keep z-coord as is. new_landmark->set_z(new_z); - // Keep visibility as is. - new_landmark->set_visibility(landmark.visibility()); } cc->Outputs().Get(output_id).AddPacket( MakePacket(output_landmarks) .At(cc->InputTimestamp())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(LandmarkLetterboxRemovalCalculator); diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc index 0309c530a..59b7c020c 100644 --- a/mediapipe/calculators/util/landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/landmark_projection_calculator.cc @@ -60,7 +60,7 @@ constexpr char kRectTag[] = "NORM_RECT"; // } class LandmarkProjectionCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) && cc->Inputs().HasTag(kRectTag)) << "Missing one or more input streams."; @@ -80,18 +80,18 @@ class LandmarkProjectionCalculator : public CalculatorBase { cc->Outputs().Get(id).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().Tag(kRectTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_rect = cc->Inputs().Tag(kRectTag).Get(); @@ -126,18 +126,17 @@ class LandmarkProjectionCalculator : public CalculatorBase { const float new_z = landmark.z() * input_rect.width(); // Scale Z coordinate as X. + *new_landmark = landmark; new_landmark->set_x(new_x); new_landmark->set_y(new_y); new_landmark->set_z(new_z); - // Keep visibility as is. 
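Note: the coordinate mapping applied per landmark in LandmarkLetterboxRemovalCalculator above can be read as a standalone function. The helper and struct names are made up, and the {left, top, right, bottom} ordering of the LETTERBOX_PADDING array is an assumption consistent with how the fractions are combined in the calculator:

```cpp
#include <array>

// Illustrative restatement of the per-landmark mapping: x and y are stretched
// back into the unpadded frame, and z is rescaled with the horizontal factor,
// exactly as in the calculator's Process() above.
struct Point3 {
  float x;
  float y;
  float z;
};

Point3 RemoveLetterbox(const Point3& lm, const std::array<float, 4>& padding) {
  const float left = padding[0];
  const float top = padding[1];
  const float left_and_right = padding[0] + padding[2];
  const float top_and_bottom = padding[1] + padding[3];
  return {(lm.x - left) / (1.0f - left_and_right),
          (lm.y - top) / (1.0f - top_and_bottom),
          lm.z / (1.0f - left_and_right)};
}
```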
- new_landmark->set_visibility(landmark.visibility()); } cc->Outputs().Get(output_id).AddPacket( MakePacket(output_landmarks) .At(cc->InputTimestamp())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(LandmarkProjectionCalculator); diff --git a/mediapipe/calculators/util/landmark_visibility_calculator.cc b/mediapipe/calculators/util/landmark_visibility_calculator.cc new file mode 100644 index 000000000..f22d2ac57 --- /dev/null +++ b/mediapipe/calculators/util/landmark_visibility_calculator.cc @@ -0,0 +1,85 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kVisibilityTag[] = "VISIBILITY"; + +} // namespace + +// A calculator to extract visibility from the landmark. +// +// Inputs: +// NORM_LANDMARKS: A NormalizedLandmarkList with only a single landmark to +// take visibility from. It's a list and not single landmark as +// split/concatenate calculators work with lists. +// +// Outputs: +// VISIBILITY: Float visibility of the given landmark. +// +// Example config: +// node { +// calculator: "LandmarkVisibilityCalculator" +// input_stream: "NORM_LANDMARKS:landmarks" +// output_stream: "VISIBILITY:visibility" +// } +// +class LandmarkVisibilityCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; +}; +REGISTER_CALCULATOR(LandmarkVisibilityCalculator); + +absl::Status LandmarkVisibilityCalculator::GetContract(CalculatorContract* cc) { + cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); + cc->Outputs().Tag(kVisibilityTag).Set(); + + return absl::OkStatus(); +} + +absl::Status LandmarkVisibilityCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + return absl::OkStatus(); +} + +absl::Status LandmarkVisibilityCalculator::Process(CalculatorContext* cc) { + // Check that landmark is not empty. + // Don't emit an empty packet for this timestamp. 
+ if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) { + return absl::OkStatus(); + } + + const auto& landmarks = + cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); + RET_CHECK_EQ(landmarks.landmark_size(), 1); + float visibility = landmarks.landmark(0).visibility(); + + cc->Outputs() + .Tag(kVisibilityTag) + .AddPacket(MakePacket(visibility).At(cc->InputTimestamp())); + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/graphs/pose_tracking/calculators/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc similarity index 75% rename from mediapipe/graphs/pose_tracking/calculators/landmarks_smoothing_calculator.cc rename to mediapipe/calculators/util/landmarks_smoothing_calculator.cc index 4b9b0f87b..4f1d4a608 100644 --- a/mediapipe/graphs/pose_tracking/calculators/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -13,12 +13,12 @@ // limitations under the License. #include "absl/algorithm/container.h" +#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/timestamp.h" -#include "mediapipe/graphs/pose_tracking/calculators/landmarks_smoothing_calculator.pb.h" -#include "mediapipe/graphs/pose_tracking/calculators/relative_velocity_filter.h" +#include "mediapipe/util/filtering/relative_velocity_filter.h" namespace mediapipe { @@ -28,7 +28,7 @@ constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; -using ::mediapipe::RelativeVelocityFilter; +using mediapipe::RelativeVelocityFilter; // Estimate object scale to use its inverse value as velocity scale for // RelativeVelocityFilter. If value will be too small (less than @@ -38,17 +38,17 @@ using ::mediapipe::RelativeVelocityFilter; // with sides parallel to axis. 
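Note: because LandmarkVisibilityCalculator expects a single-landmark list, a typical pairing is with a split calculator upstream. A sketch of that wiring, assuming the existing SplitNormalizedLandmarkListCalculator and SplitVectorCalculatorOptions; stream names and the wrapper function are illustrative:

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical wiring: carve one landmark out of a full landmark list and
// read its visibility with the new LandmarkVisibilityCalculator.
mediapipe::CalculatorGraphConfig MakeVisibilityDemoConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "pose_landmarks"
    output_stream: "nose_visibility"
    node {
      calculator: "SplitNormalizedLandmarkListCalculator"
      input_stream: "pose_landmarks"
      output_stream: "nose_landmark"
      options: {
        [mediapipe.SplitVectorCalculatorOptions.ext] {
          ranges: { begin: 0 end: 1 }
        }
      }
    }
    node {
      calculator: "LandmarkVisibilityCalculator"
      input_stream: "NORM_LANDMARKS:nose_landmark"
      output_stream: "VISIBILITY:nose_visibility"
    }
  )pb");
}
```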
float GetObjectScale(const NormalizedLandmarkList& landmarks, int image_width, int image_height) { - const auto& [lm_min_x, lm_max_x] = absl::c_minmax_element( + const auto& lm_minmax_x = absl::c_minmax_element( landmarks.landmark(), [](const auto& a, const auto& b) { return a.x() < b.x(); }); - const float x_min = lm_min_x->x(); - const float x_max = lm_max_x->x(); + const float x_min = lm_minmax_x.first->x(); + const float x_max = lm_minmax_x.second->x(); - const auto& [lm_min_y, lm_max_y] = absl::c_minmax_element( + const auto& lm_minmax_y = absl::c_minmax_element( landmarks.landmark(), [](const auto& a, const auto& b) { return a.y() < b.y(); }); - const float y_min = lm_min_y->y(); - const float y_max = lm_max_y->y(); + const float y_min = lm_minmax_y.first->y(); + const float y_max = lm_minmax_y.second->y(); const float object_width = (x_max - x_min) * image_width; const float object_height = (y_max - y_min) * image_height; @@ -61,23 +61,23 @@ class LandmarksFilter { public: virtual ~LandmarksFilter() = default; - virtual ::mediapipe::Status Reset() { return ::mediapipe::OkStatus(); } + virtual absl::Status Reset() { return absl::OkStatus(); } - virtual ::mediapipe::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, - const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) = 0; + virtual absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const std::pair& image_size, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) = 0; }; // Returns landmarks as is without smoothing. class NoFilter : public LandmarksFilter { public: - ::mediapipe::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, - const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) override { + absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const std::pair& image_size, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) override { *out_landmarks = in_landmarks; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } }; @@ -90,17 +90,17 @@ class VelocityFilter : public LandmarksFilter { velocity_scale_(velocity_scale), min_allowed_object_scale_(min_allowed_object_scale) {} - ::mediapipe::Status Reset() override { + absl::Status Reset() override { x_filters_.clear(); y_filters_.clear(); z_filters_.clear(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, - const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) override { + absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const std::pair& image_size, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) override { // Get image size. 
int image_width; int image_height; @@ -113,7 +113,7 @@ class VelocityFilter : public LandmarksFilter { GetObjectScale(in_landmarks, image_width, image_height); if (object_scale < min_allowed_object_scale_) { *out_landmarks = in_landmarks; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const float value_scale = 1.0f / object_scale; @@ -125,6 +125,7 @@ class VelocityFilter : public LandmarksFilter { const NormalizedLandmark& in_landmark = in_landmarks.landmark(i); NormalizedLandmark* out_landmark = out_landmarks->add_landmark(); + *out_landmark = in_landmark; out_landmark->set_x(x_filters_[i].Apply(timestamp, value_scale, in_landmark.x() * image_width) / image_width); @@ -135,22 +136,20 @@ class VelocityFilter : public LandmarksFilter { out_landmark->set_z(z_filters_[i].Apply(timestamp, value_scale, in_landmark.z() * image_width) / image_width); - // Keep visibility as is. - out_landmark->set_visibility(in_landmark.visibility()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: // Initializes filters for the first time or after Reset. If initialized then // check the size. - ::mediapipe::Status InitializeFiltersIfEmpty(const int n_landmarks) { + absl::Status InitializeFiltersIfEmpty(const int n_landmarks) { if (!x_filters_.empty()) { RET_CHECK_EQ(x_filters_.size(), n_landmarks); RET_CHECK_EQ(y_filters_.size(), n_landmarks); RET_CHECK_EQ(z_filters_.size(), n_landmarks); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } x_filters_.resize(n_landmarks, @@ -160,7 +159,7 @@ class VelocityFilter : public LandmarksFilter { z_filters_.resize(n_landmarks, RelativeVelocityFilter(window_size_, velocity_scale_)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } int window_size_; @@ -191,8 +190,8 @@ class VelocityFilter : public LandmarksFilter { // input_stream: "NORM_LANDMARKS:pose_landmarks" // input_stream: "IMAGE_SIZE:image_size" // output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_filtered" -// node_options: { -// [type.googleapis.com/mediapipe.LandmarksSmoothingCalculatorOptions] { +// options: { +// [mediapipe.LandmarksSmoothingCalculatorOptions.ext] { // velocity_filter: { // window_size: 5 // velocity_scale: 10.0 @@ -203,27 +202,26 @@ class VelocityFilter : public LandmarksFilter { // class LandmarksSmoothingCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: LandmarksFilter* landmarks_filter_; }; REGISTER_CALCULATOR(LandmarksSmoothingCalculator); -::mediapipe::Status LandmarksSmoothingCalculator::GetContract( - CalculatorContract* cc) { +absl::Status LandmarksSmoothingCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs() .Tag(kNormalizedFilteredLandmarksTag) .Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { +absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); // Pick landmarks filter. 
@@ -240,16 +238,15 @@ REGISTER_CALCULATOR(LandmarksSmoothingCalculator); << "Landmarks filter is either not specified or not supported"; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LandmarksSmoothingCalculator::Process( - CalculatorContext* cc) { +absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { // Check that landmarks are not empty and reset the filter if so. // Don't emit an empty packet for this timestamp. if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) { MP_RETURN_IF_ERROR(landmarks_filter_->Reset()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& in_landmarks = @@ -267,7 +264,7 @@ REGISTER_CALCULATOR(LandmarksSmoothingCalculator); .Tag(kNormalizedFilteredLandmarksTag) .Add(out_landmarks.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/graphs/pose_tracking/calculators/landmarks_smoothing_calculator.proto b/mediapipe/calculators/util/landmarks_smoothing_calculator.proto similarity index 96% rename from mediapipe/graphs/pose_tracking/calculators/landmarks_smoothing_calculator.proto rename to mediapipe/calculators/util/landmarks_smoothing_calculator.proto index 9c7dd502b..aca539cab 100644 --- a/mediapipe/graphs/pose_tracking/calculators/landmarks_smoothing_calculator.proto +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.proto @@ -16,7 +16,7 @@ syntax = "proto2"; package mediapipe; -import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; message LandmarksSmoothingCalculatorOptions { extend CalculatorOptions { diff --git a/mediapipe/calculators/util/landmarks_to_detection_calculator.cc b/mediapipe/calculators/util/landmarks_to_detection_calculator.cc index 64a7a8cc6..ffa359877 100644 --- a/mediapipe/calculators/util/landmarks_to_detection_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_detection_calculator.cc @@ -80,17 +80,17 @@ Detection ConvertLandmarksToDetection(const NormalizedLandmarkList& landmarks) { // } class LandmarksToDetectionCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: ::mediapipe::LandmarksToDetectionCalculatorOptions options_; }; REGISTER_CALCULATOR(LandmarksToDetectionCalculator); -::mediapipe::Status LandmarksToDetectionCalculator::GetContract( +absl::Status LandmarksToDetectionCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kNormalizedLandmarksTag)); RET_CHECK(cc->Outputs().HasTag(kDetectionTag)); @@ -98,19 +98,17 @@ REGISTER_CALCULATOR(LandmarksToDetectionCalculator); cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); cc->Outputs().Tag(kDetectionTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LandmarksToDetectionCalculator::Open( - CalculatorContext* cc) { +absl::Status LandmarksToDetectionCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options<::mediapipe::LandmarksToDetectionCalculatorOptions>(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LandmarksToDetectionCalculator::Process( 
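Note: the smoothing strength of VelocityFilter adapts to how large the tracked object appears on screen. A condensed sketch of that scale logic; the helper name is made up, and the mean-of-sides formula is an assumption consistent with how GetObjectScale is used above:

```cpp
// Illustrative: the object scale is the landmark bounding-box size in pixels
// (taken here as the mean of its sides). Its inverse becomes the value scale
// handed to every per-coordinate RelativeVelocityFilter, so larger on-screen
// objects get gentler smoothing.
float ObjectScaleFromBounds(float x_min, float x_max, float y_min, float y_max,
                            int image_width, int image_height) {
  const float object_width = (x_max - x_min) * image_width;
  const float object_height = (y_max - y_min) * image_height;
  return (object_width + object_height) / 2.0f;
}
// In Apply() above: if this scale falls below min_allowed_object_scale_, the
// input landmarks are passed through unfiltered; otherwise
// value_scale = 1.0f / object_scale and each coordinate is filtered in pixel
// space before being normalized back.
```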
- CalculatorContext* cc) { +absl::Status LandmarksToDetectionCalculator::Process(CalculatorContext* cc) { const auto& landmarks = cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); RET_CHECK_GT(landmarks.landmark_size(), 0) @@ -134,7 +132,7 @@ REGISTER_CALCULATOR(LandmarksToDetectionCalculator); .Tag(kDetectionTag) .Add(detection.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_to_floats_calculator.cc b/mediapipe/calculators/util/landmarks_to_floats_calculator.cc index b86542dd5..fe8dd3ab1 100644 --- a/mediapipe/calculators/util/landmarks_to_floats_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_floats_calculator.cc @@ -62,7 +62,7 @@ constexpr char kMatrixTag[] = "MATRIX"; // } class LandmarksToFloatsCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kLandmarksTag).Set(); RET_CHECK(cc->Outputs().HasTag(kFloatsTag) || cc->Outputs().HasTag(kMatrixTag)); @@ -73,10 +73,10 @@ class LandmarksToFloatsCalculator : public CalculatorBase { cc->Outputs().Tag(kMatrixTag).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const auto& options = cc->Options<::mediapipe::LandmarksToFloatsCalculatorOptions>(); @@ -84,13 +84,13 @@ class LandmarksToFloatsCalculator : public CalculatorBase { // Currently number of dimensions must be within [1, 3]. RET_CHECK_GE(num_dimensions_, 1); RET_CHECK_LE(num_dimensions_, 3); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { // Only process if there's input landmarks. 
if (cc->Inputs().Tag(kLandmarksTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_landmarks = @@ -128,7 +128,7 @@ class LandmarksToFloatsCalculator : public CalculatorBase { .Tag(kMatrixTag) .Add(output_matrix.release(), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc index 6d8ee3fed..7818ad8cd 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc @@ -33,7 +33,6 @@ constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kRenderScaleTag[] = "RENDER_SCALE"; constexpr char kRenderDataTag[] = "RENDER_DATA"; constexpr char kLandmarkLabel[] = "KEYPOINT"; -constexpr int kMaxLandmarkThickness = 18; inline void SetColor(RenderAnnotation* annotation, const Color& color) { annotation->mutable_color()->set_r(color.r()); @@ -59,15 +58,18 @@ inline void GetMinMaxZ(const LandmarkListType& landmarks, float* z_min, } void SetColorSizeValueFromZ(float z, float z_min, float z_max, - RenderAnnotation* render_annotation) { + RenderAnnotation* render_annotation, + float min_depth_circle_thickness, + float max_depth_circle_thickness) { const int color_value = 255 - static_cast(Remap(z, z_min, z_max, 255)); ::mediapipe::Color color; color.set_r(color_value); color.set_g(color_value); color.set_b(color_value); SetColor(render_annotation, color); - const int thickness = static_cast((1.f - Remap(z, z_min, z_max, 1)) * - kMaxLandmarkThickness); + const float scale = max_depth_circle_thickness - min_depth_circle_thickness; + const int thickness = static_cast( + min_depth_circle_thickness + (1.f - Remap(z, z_min, z_max, 1)) * scale); render_annotation->set_thickness(thickness); } @@ -97,14 +99,21 @@ template void AddConnectionsWithDepth(const LandmarkListType& landmarks, const std::vector& landmark_connections, bool utilize_visibility, - float visibility_threshold, float thickness, + float visibility_threshold, bool utilize_presence, + float presence_threshold, float thickness, bool normalized, float min_z, float max_z, RenderData* render_data) { for (int i = 0; i < landmark_connections.size(); i += 2) { const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); - if (visibility_threshold && (ld0.visibility() < visibility_threshold || - ld1.visibility() < visibility_threshold)) { + if (utilize_visibility && + ((ld0.has_visibility() && ld0.visibility() < visibility_threshold) || + (ld1.has_visibility() && ld1.visibility() < visibility_threshold))) { + continue; + } + if (utilize_presence && + ((ld0.has_presence() && ld0.presence() < presence_threshold) || + (ld1.has_presence() && ld1.presence() < presence_threshold))) { continue; } const int gray_val1 = @@ -136,13 +145,20 @@ template void AddConnections(const LandmarkListType& landmarks, const std::vector& landmark_connections, bool utilize_visibility, float visibility_threshold, + bool utilize_presence, float presence_threshold, const Color& connection_color, float thickness, bool normalized, RenderData* render_data) { for (int i = 0; i < landmark_connections.size(); i += 2) { const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); - if (visibility_threshold && (ld0.visibility() 
< visibility_threshold || - ld1.visibility() < visibility_threshold)) { + if (utilize_visibility && + ((ld0.has_visibility() && ld0.visibility() < visibility_threshold) || + (ld1.has_visibility() && ld1.visibility() < visibility_threshold))) { + continue; + } + if (utilize_presence && + ((ld0.has_presence() && ld0.presence() < presence_threshold) || + (ld1.has_presence() && ld1.presence() < presence_threshold))) { continue; } AddConnectionToRenderData(ld0, ld1, connection_color, @@ -161,7 +177,7 @@ RenderAnnotation* AddPointRenderData(const Color& landmark_color, } // namespace -::mediapipe::Status LandmarksToRenderDataCalculator::GetContract( +absl::Status LandmarksToRenderDataCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) || cc->Inputs().HasTag(kNormLandmarksTag)) @@ -181,11 +197,10 @@ RenderAnnotation* AddPointRenderData(const Color& landmark_color, cc->Inputs().Tag(kRenderScaleTag).Set(); } cc->Outputs().Tag(kRenderDataTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LandmarksToRenderDataCalculator::Open( - CalculatorContext* cc) { +absl::Status LandmarksToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); @@ -197,20 +212,19 @@ RenderAnnotation* AddPointRenderData(const Color& landmark_color, landmark_connections_.push_back(options_.landmark_connections(i)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LandmarksToRenderDataCalculator::Process( - CalculatorContext* cc) { +absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { // Check that landmarks are not empty and skip rendering if so. // Don't emit an empty packet for this timestamp. 
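Note: after this change a connection endpoint is only skipped when the corresponding option is enabled and the landmark actually carries the field, instead of keying off a non-zero threshold. A condensed form of the new predicate (illustrative helper, not part of the calculator):

```cpp
#include "mediapipe/framework/formats/landmark.pb.h"

// Condensed gating applied to each connection endpoint and to each rendered
// landmark after this change.
bool ShouldSkipLandmark(const mediapipe::NormalizedLandmark& lm,
                        bool utilize_visibility, float visibility_threshold,
                        bool utilize_presence, float presence_threshold) {
  if (utilize_visibility && lm.has_visibility() &&
      lm.visibility() < visibility_threshold) {
    return true;
  }
  if (utilize_presence && lm.has_presence() &&
      lm.presence() < presence_threshold) {
    return true;
  }
  return false;
}
```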
if (cc->Inputs().HasTag(kLandmarksTag) && cc->Inputs().Tag(kLandmarksTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (cc->Inputs().HasTag(kNormLandmarksTag) && cc->Inputs().Tag(kNormLandmarksTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } auto render_data = absl::make_unique(); @@ -238,27 +252,35 @@ RenderAnnotation* AddPointRenderData(const Color& landmark_color, if (visualize_depth) { AddConnectionsWithDepth( landmarks, landmark_connections_, options_.utilize_visibility(), - options_.visibility_threshold(), thickness, /*normalized=*/false, - z_min, z_max, render_data.get()); + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold(), thickness, /*normalized=*/false, z_min, + z_max, render_data.get()); } else { AddConnections( landmarks, landmark_connections_, options_.utilize_visibility(), - options_.visibility_threshold(), options_.connection_color(), - thickness, /*normalized=*/false, render_data.get()); + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold(), options_.connection_color(), thickness, + /*normalized=*/false, render_data.get()); } for (int i = 0; i < landmarks.landmark_size(); ++i) { const Landmark& landmark = landmarks.landmark(i); - if (options_.utilize_visibility() && + if (options_.utilize_visibility() && landmark.has_visibility() && landmark.visibility() < options_.visibility_threshold()) { continue; } + if (options_.utilize_presence() && landmark.has_presence() && + landmark.presence() < options_.presence_threshold()) { + continue; + } + auto* landmark_data_render = AddPointRenderData( options_.landmark_color(), thickness, render_data.get()); if (visualize_depth) { - SetColorSizeValueFromZ(landmark.z(), z_min, z_max, - landmark_data_render); + SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render, + options_.min_depth_circle_thickness(), + options_.max_depth_circle_thickness()); } auto* landmark_data = landmark_data_render->mutable_point(); landmark_data->set_normalized(false); @@ -279,27 +301,34 @@ RenderAnnotation* AddPointRenderData(const Color& landmark_color, if (visualize_depth) { AddConnectionsWithDepth( landmarks, landmark_connections_, options_.utilize_visibility(), - options_.visibility_threshold(), thickness, /*normalized=*/true, - z_min, z_max, render_data.get()); + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold(), thickness, /*normalized=*/true, z_min, + z_max, render_data.get()); } else { AddConnections( landmarks, landmark_connections_, options_.utilize_visibility(), - options_.visibility_threshold(), options_.connection_color(), - thickness, /*normalized=*/true, render_data.get()); + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold(), options_.connection_color(), thickness, + /*normalized=*/true, render_data.get()); } for (int i = 0; i < landmarks.landmark_size(); ++i) { const NormalizedLandmark& landmark = landmarks.landmark(i); - if (options_.utilize_visibility() && + if (options_.utilize_visibility() && landmark.has_visibility() && landmark.visibility() < options_.visibility_threshold()) { continue; } + if (options_.utilize_presence() && landmark.has_presence() && + landmark.presence() < options_.presence_threshold()) { + continue; + } auto* landmark_data_render = AddPointRenderData( options_.landmark_color(), thickness, render_data.get()); if (visualize_depth) { - 
SetColorSizeValueFromZ(landmark.z(), z_min, z_max, - landmark_data_render); + SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render, + options_.min_depth_circle_thickness(), + options_.max_depth_circle_thickness()); } auto* landmark_data = landmark_data_render->mutable_point(); landmark_data->set_normalized(true); @@ -311,7 +340,7 @@ RenderAnnotation* AddPointRenderData(const Color& landmark_color, cc->Outputs() .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(LandmarksToRenderDataCalculator); diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.h b/mediapipe/calculators/util/landmarks_to_render_data_calculator.h index 8f45955f4..0fbe9700c 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.h +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.h @@ -54,11 +54,11 @@ class LandmarksToRenderDataCalculator : public CalculatorBase { LandmarksToRenderDataCalculator& operator=( const LandmarksToRenderDataCalculator&) = delete; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; protected: ::mediapipe::LandmarksToRenderDataCalculatorOptions options_; diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto index ff1e8fdc5..34f073f26 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto @@ -49,4 +49,19 @@ message LandmarksToRenderDataCalculatorOptions { // Threshold to determine visibility of the landmark. Landmark with visibility // greater or equal than threshold is considered visible. optional double visibility_threshold = 7 [default = 0.0]; + + // Use landmarks presence while rendering landmarks and connections. If + // landmark is not present, neither it nor adjacent connections will be + // rendered. + optional bool utilize_presence = 8 [default = false]; + + // Threshold to determine presence of the landmark. Landmark with presence + // greater or equal than threshold is considered present. + optional double presence_threshold = 9 [default = 0.0]; + + // Min thickness of the drawing for landmark circle. + optional double min_depth_circle_thickness = 10 [default = 0.0]; + + // Max thickness of the drawing for landmark circle. 
+ optional double max_depth_circle_thickness = 11 [default = 18.0]; } diff --git a/mediapipe/calculators/util/local_file_contents_calculator.cc b/mediapipe/calculators/util/local_file_contents_calculator.cc index 254a552c9..4ad066f69 100644 --- a/mediapipe/calculators/util/local_file_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_contents_calculator.cc @@ -15,6 +15,7 @@ #include #include +#include "mediapipe/calculators/util/local_file_contents_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/util/resource_util.h" @@ -52,7 +53,7 @@ constexpr char kContentsTag[] = "CONTENTS"; // } class LocalFileContentsCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->InputSidePackets().HasTag(kFilePathTag)) << "Missing PATH input side packet(s)"; RET_CHECK(cc->OutputSidePackets().HasTag(kContentsTag)) @@ -72,12 +73,14 @@ class LocalFileContentsCalculator : public CalculatorBase { cc->OutputSidePackets().Get(id).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { CollectionItemId input_id = cc->InputSidePackets().BeginId(kFilePathTag); CollectionItemId output_id = cc->OutputSidePackets().BeginId(kContentsTag); + auto options = cc->Options(); + // Number of inputs and outpus is the same according to the contract. for (; input_id != cc->InputSidePackets().EndId(kFilePathTag); ++input_id, ++output_id) { @@ -86,15 +89,16 @@ class LocalFileContentsCalculator : public CalculatorBase { ASSIGN_OR_RETURN(file_path, PathToResourceAsFile(file_path)); std::string contents; - MP_RETURN_IF_ERROR(GetResourceContents(file_path, &contents)); + MP_RETURN_IF_ERROR(GetResourceContents( + file_path, &contents, /*read_as_binary=*/!options.text_mode())); cc->OutputSidePackets().Get(output_id).Set( MakePacket(std::move(contents))); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { - return ::mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/graphs/object_detection_3d/calculators/lift_2d_frame_annotation_to_3d_calculator.proto b/mediapipe/calculators/util/local_file_contents_calculator.proto similarity index 68% rename from mediapipe/graphs/object_detection_3d/calculators/lift_2d_frame_annotation_to_3d_calculator.proto rename to mediapipe/calculators/util/local_file_contents_calculator.proto index ccbdf2ee4..17876c89f 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/lift_2d_frame_annotation_to_3d_calculator.proto +++ b/mediapipe/calculators/util/local_file_contents_calculator.proto @@ -12,19 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -// The option proto for the Lift2DFrameAnnotationTo3DCalculatorOptions. 
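Note: with the defaults above (min 0.0, max 18.0) the new depth-to-thickness mapping reproduces the previous fixed kMaxLandmarkThickness behavior. A sketch of that interpolation; the exact form of Remap (a linear rescale with a small epsilon) is an assumption, and the helper names are illustrative:

```cpp
// Illustrative depth-to-thickness mapping used by SetColorSizeValueFromZ:
// closer landmarks (smaller z) get thicker circles, interpolated between the
// configured min and max thickness.
float Remap(float x, float lo, float hi, float scale) {
  return (x - lo) / (hi - lo + 1e-6f) * scale;
}

int DepthCircleThickness(float z, float z_min, float z_max,
                         float min_thickness, float max_thickness) {
  const float scale = max_thickness - min_thickness;
  return static_cast<int>(min_thickness +
                          (1.f - Remap(z, z_min, z_max, 1)) * scale);
}
```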
- syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; -import "mediapipe/graphs/object_detection_3d/calculators/belief_decoder_config.proto"; -message Lift2DFrameAnnotationTo3DCalculatorOptions { +message LocalFileContentsCalculatorOptions { extend CalculatorOptions { - optional Lift2DFrameAnnotationTo3DCalculatorOptions ext = 290166284; + optional LocalFileContentsCalculatorOptions ext = 346849340; } - optional BeliefDecoderConfig decoder_config = 1; + // By default, set the file open mode to 'rb'. Otherwise, set the mode to 'r'. + optional bool text_mode = 1; } diff --git a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc index 04fe3ac1c..fcba83a49 100644 --- a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc @@ -34,22 +34,22 @@ namespace mediapipe { // } class LocalFilePatternContentsCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("FILE_DIRECTORY").Set(); cc->InputSidePackets().Tag("FILE_SUFFIX").Set(); cc->Outputs().Tag("CONTENTS").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { - MP_RETURN_IF_ERROR(::mediapipe::file::MatchFileTypeInDirectory( + absl::Status Open(CalculatorContext* cc) override { + MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory( cc->InputSidePackets().Tag("FILE_DIRECTORY").Get(), cc->InputSidePackets().Tag("FILE_SUFFIX").Get(), &filenames_)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (current_output_ < filenames_.size()) { auto contents = absl::make_unique(); LOG(INFO) << filenames_[current_output_]; @@ -62,7 +62,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase { } else { return tool::StatusStop(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/logic_calculator.cc b/mediapipe/calculators/util/logic_calculator.cc new file mode 100644 index 000000000..d9bb9281a --- /dev/null +++ b/mediapipe/calculators/util/logic_calculator.cc @@ -0,0 +1,105 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/util/logic_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +using mediapipe::LogicCalculatorOptions; + +// A calculator to compute logical functions of bool inputs. +// With just one input, the output equals the input as expected. 
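// Note: with the LocalFileContentsCalculatorOptions message above, a node that
// reads a label map as text rather than binary could be configured roughly as
// in the sketch below; the FILE_PATH tag and side-packet names are assumptions
// for illustration. Setting text_mode: true makes the calculator call
// GetResourceContents with read_as_binary=false.
//
// node {
//   calculator: "LocalFileContentsCalculator"
//   input_side_packet: "FILE_PATH:labels_path"
//   output_side_packet: "CONTENTS:labels_content"
//   options {
//     [mediapipe.LocalFileContentsCalculatorOptions.ext] {
//       text_mode: true
//     }
//   }
// }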
+// +// Inputs: One or more bool inputs, which may be input-stream-packets, +// input-side-packets, or options input-values. +// +// Outputs: One bool stream. +// +// Example config: +// node { +// calculator: "LogicCalculator" +// input_stream: "has_data" +// input_side_packet: "enable" +// input_stream: "is_valid" +// output_stream: "process_data" +// options { +// [mediapipe.LogicCalculatorOptions.ext] { +// op: AND +// input_value: true +// } +// } +// } +class LogicCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + for (int k = 0; k < cc->Inputs().NumEntries(""); ++k) { + cc->Inputs().Index(k).Set(); + } + for (int k = 0; k < cc->InputSidePackets().NumEntries(""); ++k) { + cc->InputSidePackets().Index(k).Set(); + } + RET_CHECK_GE(cc->Inputs().NumEntries("") + + cc->InputSidePackets().NumEntries("") + + cc->Options().input_value_size(), + 1); + RET_CHECK_EQ(cc->Outputs().NumEntries(""), 1); + cc->Outputs().Index(0).Set(); + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + cc->SetOffset(TimestampDiff(0)); + return absl::OkStatus(); + } + + bool LogicalOp(bool b1, bool b2) { + switch (options_.op()) { + case LogicCalculatorOptions::AND: + return b1 && b2; + case LogicCalculatorOptions::OR: + return b1 || b2; + case LogicCalculatorOptions::XOR: + return b1 ^ b2; + } + return false; + } + + absl::Status Process(CalculatorContext* cc) override { + bool result = options_.op() == LogicCalculatorOptions::AND ? true : false; + for (int k = 0; k < options_.input_value_size(); ++k) { + result = LogicalOp(result, options_.input_value(k)); + } + for (int k = 0; k < cc->Inputs().NumEntries(""); ++k) { + result = LogicalOp(result, cc->Inputs().Index(k).Value().Get()); + } + for (int k = 0; k < cc->InputSidePackets().NumEntries(""); ++k) { + result = LogicalOp(result, cc->InputSidePackets().Index(k).Get()); + } + if (options_.negate()) { + result = !result; + } + cc->Outputs().Index(0).Add(new bool(result), cc->InputTimestamp()); + return absl::OkStatus(); + } + + private: + LogicCalculatorOptions options_; +}; +REGISTER_CALCULATOR(LogicCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/logic_calculator.proto b/mediapipe/calculators/util/logic_calculator.proto new file mode 100644 index 000000000..fe00a2d9b --- /dev/null +++ b/mediapipe/calculators/util/logic_calculator.proto @@ -0,0 +1,38 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message LogicCalculatorOptions { + extend CalculatorOptions { + optional LogicCalculatorOptions ext = 338731246; + } + // The logical operation to apply. + enum Operation { + AND = 0; + OR = 1; + XOR = 2; + } + optional Operation op = 1; + + // Whether to negate the result. + optional bool negate = 2; + + // Optional bool input values. 
+ repeated bool input_value = 3; +} diff --git a/mediapipe/calculators/util/non_max_suppression_calculator.cc b/mediapipe/calculators/util/non_max_suppression_calculator.cc index 1ea1b3d6b..535e2a719 100644 --- a/mediapipe/calculators/util/non_max_suppression_calculator.cc +++ b/mediapipe/calculators/util/non_max_suppression_calculator.cc @@ -52,6 +52,7 @@ bool RetainMaxScoringLabelOnly(Detection* detection) { << "Number of scores must be equal to number of detections."; std::vector> indexed_scores; + indexed_scores.reserve(detection->score_size()); for (int k = 0; k < detection->score_size(); ++k) { indexed_scores.push_back(std::make_pair(k, detection->score(k))); } @@ -154,7 +155,7 @@ class NonMaxSuppressionCalculator : public CalculatorBase { NonMaxSuppressionCalculator() = default; ~NonMaxSuppressionCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); if (cc->Inputs().HasTag(kImageTag)) { cc->Inputs().Tag(kImageTag).Set(); @@ -163,10 +164,10 @@ class NonMaxSuppressionCalculator : public CalculatorBase { cc->Inputs().Index(k).Set(); } cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); @@ -176,10 +177,10 @@ class NonMaxSuppressionCalculator : public CalculatorBase { << "max_num_detections=0 is not a valid value. Please choose a " << "positive number of you want to limit the number of output " << "detections, or set -1 if you do not want any limit."; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { // Add all input detections to the same vector. Detections input_detections; for (int i = 0; i < options_.num_detection_streams(); ++i) { @@ -199,7 +200,7 @@ class NonMaxSuppressionCalculator : public CalculatorBase { if (options_.return_empty_detections()) { cc->Outputs().Index(0).Add(new Detections(), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Remove all but the maximum scoring label from each input detection. This @@ -244,7 +245,7 @@ class NonMaxSuppressionCalculator : public CalculatorBase { cc->Outputs().Index(0).Add(retained_detections, cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/packet_frequency_calculator.cc b/mediapipe/calculators/util/packet_frequency_calculator.cc index f63c72fdc..19ffae70e 100644 --- a/mediapipe/calculators/util/packet_frequency_calculator.cc +++ b/mediapipe/calculators/util/packet_frequency_calculator.cc @@ -70,27 +70,25 @@ class PacketFrequencyCalculator : public CalculatorBase { public: PacketFrequencyCalculator() {} - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Outputs the given framerate on the specified output stream as a // PacketFrequency proto. 
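// Note: the negate field in LogicCalculatorOptions above also gives the
// complements of the basic ops. For example, a NOR over two bool streams could
// be configured as in this sketch (stream names are illustrative only):
//
// node {
//   calculator: "LogicCalculator"
//   input_stream: "has_face"
//   input_stream: "has_hand"
//   output_stream: "neither_present"
//   options {
//     [mediapipe.LogicCalculatorOptions.ext] {
//       op: OR
//       negate: true
//     }
//   }
// }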
- ::mediapipe::Status OutputPacketFrequency(CalculatorContext* cc, - int stream_id, double framerate_hz, - const std::string& label, - const Timestamp& input_timestamp); + absl::Status OutputPacketFrequency(CalculatorContext* cc, int stream_id, + double framerate_hz, + const std::string& label, + const Timestamp& input_timestamp); // Adds the input timestamp in the particular stream's timestamp buffer. - ::mediapipe::Status AddPacketTimestampForStream(int stream_id, - int64 timestamp); + absl::Status AddPacketTimestampForStream(int stream_id, int64 timestamp); // For the specified input stream, clears timestamps from buffer that are // older than the configured time_window_sec. - ::mediapipe::Status ClearOldpacketTimestamps(int stream_id, - int64 current_timestamp); + absl::Status ClearOldpacketTimestamps(int stream_id, int64 current_timestamp); // Options for the calculator. PacketFrequencyCalculatorOptions options_; @@ -106,17 +104,16 @@ class PacketFrequencyCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(PacketFrequencyCalculator); -::mediapipe::Status PacketFrequencyCalculator::GetContract( - CalculatorContract* cc) { +absl::Status PacketFrequencyCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Outputs().NumEntries(), cc->Inputs().NumEntries()); for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketFrequencyCalculator::Open(CalculatorContext* cc) { +absl::Status PacketFrequencyCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); RET_CHECK_EQ(options_.label_size(), cc->Inputs().NumEntries()); RET_CHECK_GT(options_.time_window_sec(), 0); @@ -128,10 +125,10 @@ REGISTER_CALCULATOR(PacketFrequencyCalculator); previous_timestamps_for_stream_id_[i] = {}; first_timestamp_for_stream_id_usec_[i] = -1; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketFrequencyCalculator::Process(CalculatorContext* cc) { +absl::Status PacketFrequencyCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { if (cc->Inputs().Index(i).IsEmpty()) { continue; @@ -165,26 +162,26 @@ REGISTER_CALCULATOR(PacketFrequencyCalculator); options_.label(i), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketFrequencyCalculator::AddPacketTimestampForStream( +absl::Status PacketFrequencyCalculator::AddPacketTimestampForStream( int stream_id, int64 timestamp_usec) { if (previous_timestamps_for_stream_id_.find(stream_id) == previous_timestamps_for_stream_id_.end()) { - return ::mediapipe::InvalidArgumentError("Input stream id is invalid"); + return absl::InvalidArgumentError("Input stream id is invalid"); } previous_timestamps_for_stream_id_[stream_id].push_back(timestamp_usec); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketFrequencyCalculator::ClearOldpacketTimestamps( +absl::Status PacketFrequencyCalculator::ClearOldpacketTimestamps( int stream_id, int64 current_timestamp_usec) { if (previous_timestamps_for_stream_id_.find(stream_id) == previous_timestamps_for_stream_id_.end()) { - return ::mediapipe::InvalidArgumentError("Input stream id is invalid"); + return absl::InvalidArgumentError("Input stream id is invalid"); } auto& timestamps_buffer = previous_timestamps_for_stream_id_[stream_id]; @@ -199,10 +196,10 @@ 
REGISTER_CALCULATOR(PacketFrequencyCalculator); }), timestamps_buffer.end()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketFrequencyCalculator::OutputPacketFrequency( +absl::Status PacketFrequencyCalculator::OutputPacketFrequency( CalculatorContext* cc, int stream_id, double framerate_hz, const std::string& label, const Timestamp& input_timestamp) { auto packet_frequency = absl::make_unique(); @@ -212,7 +209,7 @@ REGISTER_CALCULATOR(PacketFrequencyCalculator); cc->Outputs().Index(stream_id).Add(packet_frequency.release(), input_timestamp); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/packet_latency_calculator.cc b/mediapipe/calculators/util/packet_latency_calculator.cc index 162cc9356..35e415505 100644 --- a/mediapipe/calculators/util/packet_latency_calculator.cc +++ b/mediapipe/calculators/util/packet_latency_calculator.cc @@ -101,10 +101,10 @@ class PacketLatencyCalculator : public CalculatorBase { public: PacketLatencyCalculator() {} - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Resets the histogram and running average variables by initializing them to @@ -139,8 +139,7 @@ class PacketLatencyCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(PacketLatencyCalculator); -::mediapipe::Status PacketLatencyCalculator::GetContract( - CalculatorContract* cc) { +absl::Status PacketLatencyCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_GT(cc->Inputs().NumEntries(), 1); // Input and output streams. @@ -161,7 +160,7 @@ REGISTER_CALCULATOR(PacketLatencyCalculator); .Set>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void PacketLatencyCalculator::ResetStatistics() { @@ -178,7 +177,7 @@ void PacketLatencyCalculator::ResetStatistics() { } } -::mediapipe::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { +absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); num_packet_streams_ = cc->Inputs().NumEntries() - 1; @@ -225,10 +224,10 @@ void PacketLatencyCalculator::ResetStatistics() { ::mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { +absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { // Record first process timestamp if this is the first call. 
if (first_process_time_usec_ < 0 && !cc->Inputs().Tag(kReferenceSignalTag).IsEmpty()) { @@ -239,7 +238,7 @@ void PacketLatencyCalculator::ResetStatistics() { if (first_process_time_usec_ < 0) { LOG(WARNING) << "No reference packet received."; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (options_.reset_duration_usec() > 0) { @@ -293,7 +292,7 @@ void PacketLatencyCalculator::ResetStatistics() { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_projection_calculator.cc b/mediapipe/calculators/util/rect_projection_calculator.cc index 0b0ec9468..dcc6e7391 100644 --- a/mediapipe/calculators/util/rect_projection_calculator.cc +++ b/mediapipe/calculators/util/rect_projection_calculator.cc @@ -47,29 +47,28 @@ constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT"; // class RectProjectionCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(RectProjectionCalculator); -::mediapipe::Status RectProjectionCalculator::GetContract( - CalculatorContract* cc) { +absl::Status RectProjectionCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormRectTag).Set(); cc->Inputs().Tag(kNormReferenceRectTag).Set(); cc->Outputs().Tag(kNormRectTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RectProjectionCalculator::Open(CalculatorContext* cc) { +absl::Status RectProjectionCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RectProjectionCalculator::Process(CalculatorContext* cc) { +absl::Status RectProjectionCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kNormRectTag).IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& rect = cc->Inputs().Tag(kNormRectTag).Get(); @@ -101,7 +100,7 @@ REGISTER_CALCULATOR(RectProjectionCalculator); cc->Outputs().Tag(kNormRectTag).Add(new_rect.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 365d364dc..3b395818f 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -37,7 +37,12 @@ RenderAnnotation::Rectangle* NewRect( annotation->mutable_color()->set_b(options.color().b()); annotation->set_thickness(options.thickness()); - return options.filled() + return options.oval() ? options.filled() + ? annotation->mutable_filled_oval() + ->mutable_oval() + ->mutable_rectangle() + : annotation->mutable_oval()->mutable_rectangle() + : options.filled() ? 
annotation->mutable_filled_rectangle()->mutable_rectangle() : annotation->mutable_rectangle(); } @@ -89,19 +94,18 @@ void SetRect(bool normalized, double xmin, double ymin, double width, // } class RectToRenderDataCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: RectToRenderDataCalculatorOptions options_; }; REGISTER_CALCULATOR(RectToRenderDataCalculator); -::mediapipe::Status RectToRenderDataCalculator::GetContract( - CalculatorContract* cc) { +absl::Status RectToRenderDataCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ((cc->Inputs().HasTag(kNormRectTag) ? 1 : 0) + (cc->Inputs().HasTag(kRectTag) ? 1 : 0) + (cc->Inputs().HasTag(kNormRectsTag) ? 1 : 0) + @@ -125,18 +129,18 @@ REGISTER_CALCULATOR(RectToRenderDataCalculator); } cc->Outputs().Tag(kRenderDataTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RectToRenderDataCalculator::Open(CalculatorContext* cc) { +absl::Status RectToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) { +absl::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) { auto render_data = absl::make_unique(); if (cc->Inputs().HasTag(kNormRectTag) && @@ -180,7 +184,7 @@ REGISTER_CALCULATOR(RectToRenderDataCalculator); .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.proto b/mediapipe/calculators/util/rect_to_render_data_calculator.proto index badc8df44..9b6d5e6ee 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.proto @@ -32,4 +32,7 @@ message RectToRenderDataCalculatorOptions { // Thickness of the line (applicable when the rectangle is not filled). optional double thickness = 3 [default = 1.0]; + + // Whether the rendered rectangle should be an oval. 
+ optional bool oval = 4 [default = false]; } diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index d55063aa4..79a740315 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -51,39 +51,37 @@ constexpr char kRenderScaleTag[] = "RENDER_SCALE"; // } class RectToRenderScaleCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: RectToRenderScaleCalculatorOptions options_; }; REGISTER_CALCULATOR(RectToRenderScaleCalculator); -::mediapipe::Status RectToRenderScaleCalculator::GetContract( - CalculatorContract* cc) { +absl::Status RectToRenderScaleCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormRectTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs().Tag(kRenderScaleTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RectToRenderScaleCalculator::Open(CalculatorContext* cc) { +absl::Status RectToRenderScaleCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RectToRenderScaleCalculator::Process( - CalculatorContext* cc) { +absl::Status RectToRenderScaleCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kNormRectTag).IsEmpty()) { cc->Outputs() .Tag(kRenderScaleTag) .AddPacket( MakePacket(options_.multiplier()).At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Get image size. @@ -105,7 +103,7 @@ REGISTER_CALCULATOR(RectToRenderScaleCalculator); .Tag(kRenderScaleTag) .AddPacket(MakePacket(render_scale).At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index 5b42a3499..7c71dd5a1 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -57,10 +57,10 @@ inline float NormalizeRadians(float angle) { // } class RectTransformationCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: RectTransformationCalculatorOptions options_; @@ -72,8 +72,7 @@ class RectTransformationCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(RectTransformationCalculator); -::mediapipe::Status RectTransformationCalculator::GetContract( - CalculatorContract* cc) { +absl::Status RectTransformationCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ((cc->Inputs().HasTag(kNormRectTag) ? 1 : 0) + (cc->Inputs().HasTag(kNormRectsTag) ? 1 : 0) + (cc->Inputs().HasTag(kRectTag) ? 
1 : 0) + @@ -100,21 +99,20 @@ REGISTER_CALCULATOR(RectTransformationCalculator); cc->Outputs().Index(0).Set>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RectTransformationCalculator::Open(CalculatorContext* cc) { +absl::Status RectTransformationCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); RET_CHECK(!(options_.has_rotation() && options_.has_rotation_degrees())); RET_CHECK(!(options_.has_square_long() && options_.has_square_short())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RectTransformationCalculator::Process( - CalculatorContext* cc) { +absl::Status RectTransformationCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kRectTag) && !cc->Inputs().Tag(kRectTag).IsEmpty()) { auto rect = cc->Inputs().Tag(kRectTag).Get(); TransformRect(&rect); @@ -157,7 +155,7 @@ REGISTER_CALCULATOR(RectTransformationCalculator); cc->Outputs().Index(0).Add(output_rects.release(), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } float RectTransformationCalculator::ComputeNewRotation(float rotation) { diff --git a/mediapipe/calculators/util/set_landmark_visibility_calculator.cc b/mediapipe/calculators/util/set_landmark_visibility_calculator.cc new file mode 100644 index 000000000..233c3a0cb --- /dev/null +++ b/mediapipe/calculators/util/set_landmark_visibility_calculator.cc @@ -0,0 +1,102 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kVisibilityTag[] = "VISIBILITY"; + +} // namespace + +// A calculator to set landmark visibility. +// +// Inputs: +// NORM_LANDMARKS: A NormalizedLandmarkList with only a single landmark to set +// visibility to. It's a list and not single landmark as split/concatenate +// calculators work with lists. +// +// VISIBILITY: Float visibility of the given landmark. +// +// Outputs: +// NORM_LANDMARKS: A NormalizedLandmarkList with only single landmark with +// updated visibility. 
+// +// Example config: +// node { +// calculator: "SetLandmarkVisibility" +// input_stream: "NORM_LANDMARKS:landmarks" +// input_stream: "VISIBILITY:visibility" +// output_stream: "NORM_LANDMARKS:landmarks_with_visibility" +// } +// +class SetLandmarkVisibilityCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; +}; +REGISTER_CALCULATOR(SetLandmarkVisibilityCalculator); + +absl::Status SetLandmarkVisibilityCalculator::GetContract( + CalculatorContract* cc) { + cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); + cc->Inputs().Tag(kVisibilityTag).Set(); + cc->Outputs().Tag(kNormalizedLandmarksTag).Set(); + + return absl::OkStatus(); +} + +absl::Status SetLandmarkVisibilityCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + return absl::OkStatus(); +} + +absl::Status SetLandmarkVisibilityCalculator::Process(CalculatorContext* cc) { + // Check that landmark and visibility are not empty. + // Don't emit an empty packet for this timestamp. + if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty() || + cc->Inputs().Tag(kVisibilityTag).IsEmpty()) { + return absl::OkStatus(); + } + + const auto& in_landmarks = + cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); + RET_CHECK_EQ(in_landmarks.landmark_size(), 1); + const NormalizedLandmark& in_landmark = in_landmarks.landmark(0); + + const auto& visibility = cc->Inputs().Tag(kVisibilityTag).Get(); + + auto out_landmarks = absl::make_unique(); + NormalizedLandmark* out_landmark = out_landmarks->add_landmark(); + *out_landmark = in_landmark; + // Update visibility. + out_landmark->set_visibility(visibility); + + cc->Outputs() + .Tag(kNormalizedLandmarksTag) + .Add(out_landmarks.release(), cc->InputTimestamp()); + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/thresholding_calculator.cc b/mediapipe/calculators/util/thresholding_calculator.cc index 1d7b5476b..65876c075 100644 --- a/mediapipe/calculators/util/thresholding_calculator.cc +++ b/mediapipe/calculators/util/thresholding_calculator.cc @@ -50,18 +50,17 @@ namespace mediapipe { // } class ThresholdingCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: double threshold_{}; }; REGISTER_CALCULATOR(ThresholdingCalculator); -::mediapipe::Status ThresholdingCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("FLOAT")); cc->Inputs().Tag("FLOAT").Set(); @@ -84,10 +83,10 @@ REGISTER_CALCULATOR(ThresholdingCalculator); "supported."; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ThresholdingCalculator::Open(CalculatorContext* cc) { +absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -104,10 +103,10 @@ REGISTER_CALCULATOR(ThresholdingCalculator); if (cc->InputSidePackets().HasTag("THRESHOLD")) { threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get(); } - return ::mediapipe::OkStatus(); + return 
absl::OkStatus(); } -::mediapipe::Status ThresholdingCalculator::Process(CalculatorContext* cc) { +absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag("THRESHOLD") && !cc->Inputs().Tag("THRESHOLD").IsEmpty()) { threshold_ = cc->Inputs().Tag("THRESHOLD").Get(); @@ -132,6 +131,6 @@ REGISTER_CALCULATOR(ThresholdingCalculator); MakePacket(false).At(cc->InputTimestamp())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc b/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc index 5d81a7af3..790b426de 100644 --- a/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc +++ b/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc @@ -48,26 +48,25 @@ using mediapipe::TimedBoxProtoList; // } class TimedBoxListIdToLabelCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: absl::node_hash_map label_map_; }; REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator); -::mediapipe::Status TimedBoxListIdToLabelCalculator::GetContract( +absl::Status TimedBoxListIdToLabelCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TimedBoxListIdToLabelCalculator::Open( - CalculatorContext* cc) { +absl::Status TimedBoxListIdToLabelCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -84,11 +83,10 @@ REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator); while (std::getline(stream, line)) { label_map_[i++] = line; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TimedBoxListIdToLabelCalculator::Process( - CalculatorContext* cc) { +absl::Status TimedBoxListIdToLabelCalculator::Process(CalculatorContext* cc) { const auto& input_list = cc->Inputs().Index(0).Get(); auto output_list = absl::make_unique(); for (const auto& input_box : input_list.box()) { @@ -100,7 +98,7 @@ REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator); } } cc->Outputs().Index(0).Add(output_list.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/timed_box_list_to_render_data_calculator.cc b/mediapipe/calculators/util/timed_box_list_to_render_data_calculator.cc index 3979d14fc..53c2ffa2f 100644 --- a/mediapipe/calculators/util/timed_box_list_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/timed_box_list_to_render_data_calculator.cc @@ -120,35 +120,34 @@ class TimedBoxListToRenderDataCalculator : public CalculatorBase { TimedBoxListToRenderDataCalculator& operator=( const TimedBoxListToRenderDataCalculator&) = delete; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) 
override; private: TimedBoxListToRenderDataCalculatorOptions options_; }; REGISTER_CALCULATOR(TimedBoxListToRenderDataCalculator); -::mediapipe::Status TimedBoxListToRenderDataCalculator::GetContract( +absl::Status TimedBoxListToRenderDataCalculator::GetContract( CalculatorContract* cc) { if (cc->Inputs().HasTag(kTimedBoxListTag)) { cc->Inputs().Tag(kTimedBoxListTag).Set(); } cc->Outputs().Tag(kRenderDataTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TimedBoxListToRenderDataCalculator::Open( - CalculatorContext* cc) { +absl::Status TimedBoxListToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TimedBoxListToRenderDataCalculator::Process( +absl::Status TimedBoxListToRenderDataCalculator::Process( CalculatorContext* cc) { auto render_data = absl::make_unique(); @@ -164,7 +163,7 @@ REGISTER_CALCULATOR(TimedBoxListToRenderDataCalculator); cc->Outputs() .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/to_image_calculator.cc b/mediapipe/calculators/util/to_image_calculator.cc new file mode 100644 index 000000000..5e119fca7 --- /dev/null +++ b/mediapipe/calculators/util/to_image_calculator.cc @@ -0,0 +1,160 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/vector.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_calculator_helper.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +namespace { +constexpr char kImageFrameTag[] = "IMAGE_CPU"; +constexpr char kGpuBufferTag[] = "IMAGE_GPU"; +constexpr char kImageTag[] = "IMAGE"; +} // namespace + +// A calculator for converting from legacy MediaPipe datatypes into a +// unified image container. +// +// Inputs: +// One of the following two tags: +// IMAGE_CPU: An ImageFrame containing input image. +// IMAGE_GPU: A GpuBuffer containing input image. +// +// Output: +// IMAGE: An Image containing output image. +// +// Note: +// No CPU/GPU conversion is done. +// +class ToImageCalculator : public CalculatorBase { + public: + ToImageCalculator() = default; + ~ToImageCalculator() override = default; + + static absl::Status GetContract(CalculatorContract* cc); + + // From Calculator. 
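// Note: a minimal CPU-path node for the ToImageCalculator above could be
// configured as in this sketch (stream names are illustrative only); the GPU
// path is identical except that the input uses the IMAGE_GPU tag.
//
// node {
//   calculator: "ToImageCalculator"
//   input_stream: "IMAGE_CPU:input_image_frame"
//   output_stream: "IMAGE:unified_image"
// }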
+ absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); + + bool gpu_input_ = false; + bool gpu_initialized_ = false; +#if !MEDIAPIPE_DISABLE_GPU + mediapipe::GlCalculatorHelper gpu_helper_; +#endif // !MEDIAPIPE_DISABLE_GPU +}; +REGISTER_CALCULATOR(ToImageCalculator); + +absl::Status ToImageCalculator::GetContract(CalculatorContract* cc) { + cc->Outputs().Tag(kImageTag).Set(); + + bool gpu_input = false; + + if (cc->Inputs().HasTag(kImageFrameTag) && + cc->Inputs().HasTag(kGpuBufferTag)) { + return absl::InternalError("Cannot have multiple inputs."); + } + + if (cc->Inputs().HasTag(kGpuBufferTag)) { +#if !MEDIAPIPE_DISABLE_GPU + cc->Inputs().Tag(kGpuBufferTag).Set(); + gpu_input = true; +#else + RET_CHECK_FAIL() << "GPU is disabled. Cannot use IMAGE_GPU stream."; +#endif // !MEDIAPIPE_DISABLE_GPU + } + if (cc->Inputs().HasTag(kImageFrameTag)) { + cc->Inputs().Tag(kImageFrameTag).Set(); + } + + if (gpu_input) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } + + return absl::OkStatus(); +} + +absl::Status ToImageCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + if (cc->Inputs().HasTag(kGpuBufferTag)) { + gpu_input_ = true; + } + + if (gpu_input_) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif + } // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status ToImageCalculator::Process(CalculatorContext* cc) { + if (gpu_input_) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([&cc]() -> absl::Status { + auto& input = cc->Inputs().Tag(kGpuBufferTag).Get(); + // Wrap texture pointer; shallow copy. + auto output = std::make_unique(input); + cc->Outputs().Tag(kImageTag).Add(output.release(), cc->InputTimestamp()); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU + } else { + // The input ImageFrame. + auto& input = cc->Inputs().Tag(kImageFrameTag).Get(); + // Make a copy of the input packet to co-own the input ImageFrame. + Packet* packet_copy_ptr = + new Packet(cc->Inputs().Tag(kImageFrameTag).Value()); + // Create an output Image that (co-)owns a new ImageFrame that points to + // the same pixel data as the input ImageFrame and also owns the packet + // copy. As a result, the output Image indirectly co-owns the input + // ImageFrame. This ensures a correct life span of the shared pixel data. 
+ std::unique_ptr output = + std::make_unique( + std::make_shared( + input.Format(), input.Width(), input.Height(), + input.WidthStep(), const_cast(input.PixelData()), + [packet_copy_ptr](uint8*) { delete packet_copy_ptr; })); + cc->Outputs().Tag(kImageTag).Add(output.release(), cc->InputTimestamp()); + } + + return absl::OkStatus(); +} + +absl::Status ToImageCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/top_k_scores_calculator.cc b/mediapipe/calculators/util/top_k_scores_calculator.cc index 1d5a8fede..37d1b2ab2 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator.cc @@ -62,14 +62,14 @@ namespace mediapipe { // } class TopKScoresCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - ::mediapipe::Status LoadLabelmap(std::string label_map_path); + absl::Status LoadLabelmap(std::string label_map_path); int top_k_ = -1; float threshold_ = 0.0; @@ -78,7 +78,7 @@ class TopKScoresCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TopKScoresCalculator); -::mediapipe::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { +absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("SCORES")); cc->Inputs().Tag("SCORES").Set>(); if (cc->Outputs().HasTag("TOP_K_INDEXES")) { @@ -96,10 +96,10 @@ REGISTER_CALCULATOR(TopKScoresCalculator); if (cc->Outputs().HasTag("SUMMARY")) { cc->Outputs().Tag("SUMMARY").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TopKScoresCalculator::Open(CalculatorContext* cc) { +absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) { const auto& options = cc->Options<::mediapipe::TopKScoresCalculatorOptions>(); RET_CHECK(options.has_top_k() || options.has_threshold()) << "Must specify at least one of the top_k and threshold fields in " @@ -117,10 +117,10 @@ REGISTER_CALCULATOR(TopKScoresCalculator); if (cc->Outputs().HasTag("TOP_K_LABELS")) { RET_CHECK(!label_map_.empty()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TopKScoresCalculator::Process(CalculatorContext* cc) { +absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { const std::vector& input_vector = cc->Inputs().Tag("SCORES").Get>(); std::vector top_k_indexes; @@ -213,11 +213,10 @@ REGISTER_CALCULATOR(TopKScoresCalculator); } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TopKScoresCalculator::LoadLabelmap( - std::string label_map_path) { +absl::Status TopKScoresCalculator::LoadLabelmap(std::string label_map_path) { std::string string_path; ASSIGN_OR_RETURN(string_path, PathToResourceAsFile(label_map_path)); std::string label_map_string; @@ -230,7 +229,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator); label_map_[i++] = line; } label_map_loaded_ = true; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 2930c488a..af526044a 100644 --- a/mediapipe/calculators/video/BUILD +++ 
b/mediapipe/calculators/video/BUILD @@ -317,6 +317,7 @@ cc_library( "//mediapipe/util/tracking:box_tracker", "//mediapipe/util/tracking:tracking_visualization_utilities", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", ], diff --git a/mediapipe/calculators/video/box_detector_calculator.cc b/mediapipe/calculators/video/box_detector_calculator.cc index d9afdd333..b7b91d253 100644 --- a/mediapipe/calculators/video/box_detector_calculator.cc +++ b/mediapipe/calculators/video/box_detector_calculator.cc @@ -92,11 +92,11 @@ class BoxDetectorCalculator : public CalculatorBase { public: ~BoxDetectorCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: BoxDetectorCalculatorOptions options_; @@ -109,7 +109,7 @@ class BoxDetectorCalculator : public CalculatorBase { REGISTER_CALCULATOR(BoxDetectorCalculator); -::mediapipe::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) { +absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("TRACKING")) { cc->Inputs().Tag("TRACKING").Set(); } @@ -172,10 +172,10 @@ REGISTER_CALCULATOR(BoxDetectorCalculator); cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { +absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); box_detector_ = BoxDetectorInterface::Create(options_.detector_options()); @@ -210,10 +210,10 @@ REGISTER_CALCULATOR(BoxDetectorCalculator); frame_alignment_ = cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Get(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { +absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { const Timestamp timestamp = cc->InputTimestamp(); const int64 timestamp_msec = timestamp.Value() / 1000; @@ -246,7 +246,7 @@ REGISTER_CALCULATOR(BoxDetectorCalculator); } if (!detector_switch_) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } InputStream* track_stream = cc->Inputs().HasTag("TRACKING") @@ -274,7 +274,7 @@ REGISTER_CALCULATOR(BoxDetectorCalculator); if (track_stream != nullptr) { // Detect from tracking data if (track_stream->IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const TrackingData& tracking_data = track_stream->Get(); @@ -289,7 +289,7 @@ REGISTER_CALCULATOR(BoxDetectorCalculator); } else if (video_stream != nullptr) { // Detect from input frame if (video_stream->IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } TimedBoxProtoList tracked_boxes; @@ -305,7 +305,7 @@ REGISTER_CALCULATOR(BoxDetectorCalculator); detected_boxes.get()); } else { if (feature_stream->IsEmpty() || descriptor_stream->IsEmpty()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& image_size = @@ -377,17 +377,17 @@ 
REGISTER_CALCULATOR(BoxDetectorCalculator); cc->Outputs().Tag("BOXES").Add(detected_boxes.release(), timestamp); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BoxDetectorCalculator::Close(CalculatorContext* cc) { +absl::Status BoxDetectorCalculator::Close(CalculatorContext* cc) { if (write_index_) { BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex(); MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents( cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Get(), index.SerializeAsString())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/video/box_tracker_calculator.cc b/mediapipe/calculators/video/box_tracker_calculator.cc index a56392ee3..7d04d9765 100644 --- a/mediapipe/calculators/video/box_tracker_calculator.cc +++ b/mediapipe/calculators/video/box_tracker_calculator.cc @@ -19,6 +19,7 @@ #include #include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" #include "absl/container/node_hash_set.h" #include "absl/strings/numbers.h" #include "mediapipe/calculators/video/box_tracker_calculator.pb.h" @@ -122,10 +123,10 @@ class BoxTrackerCalculator : public CalculatorBase { public: ~BoxTrackerCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; protected: void RenderStates(const std::vector& states, cv::Mat* mat); @@ -166,7 +167,7 @@ class BoxTrackerCalculator : public CalculatorBase { }; // MotionBoxPath per unique id that we are tracking. - typedef std::unordered_map MotionBoxMap; + typedef absl::node_hash_map MotionBoxMap; // Performs tracking of all MotionBoxes in box_map by one frame forward or // backward to or from data_frame_num using passed TrackingData. @@ -207,7 +208,7 @@ class BoxTrackerCalculator : public CalculatorBase { // Boxes that are tracked in streaming mode. MotionBoxMap streaming_motion_boxes_; - std::unordered_map> last_tracked_boxes_; + absl::node_hash_map> last_tracked_boxes_; int frame_num_since_reset_ = 0; // Cache used during streaming mode for fast forward tracking. @@ -372,7 +373,7 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec, } // namespace. 
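// Note: MotionBoxMap and last_tracked_boxes_ above now use absl::node_hash_map
// (hence the new "absl/container/node_hash_map.h" include and BUILD dep). It is
// a drop-in replacement for std::unordered_map whose node-based storage keeps
// pointers and references to stored values valid across rehashing; a minimal
// sketch, assuming the map is keyed by the unique box id:
//
//   absl::node_hash_map<int, MotionBoxPath> box_map;
//   MotionBoxPath& path = box_map[7];   // reference to the stored value
//   box_map[42] = MotionBoxPath();      // may rehash; `path` remains valid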
-::mediapipe::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) { +absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("TRACKING")) { cc->Inputs().Tag("TRACKING").Set(); } @@ -451,10 +452,10 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec, cc->InputSidePackets().Tag(kOptionsTag).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { +absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { options_ = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag); @@ -514,10 +515,10 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec, << "Streaming mode not compatible with cache dir."; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { +absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { // Batch mode, issue tracking requests. if (box_tracker_ && !tracking_issued_) { for (const auto& pos : initial_pos_.box()) { @@ -529,7 +530,7 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec, const Timestamp& timestamp = cc->InputTimestamp(); if (timestamp == Timestamp::PreStream()) { // Indicator packet. - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } InputStream* track_stream = cc->Inputs().HasTag("TRACKING") @@ -891,7 +892,7 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec, cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void BoxTrackerCalculator::AddSmoothTransitionToOutputBox( diff --git a/mediapipe/calculators/video/flow_packager_calculator.cc b/mediapipe/calculators/video/flow_packager_calculator.cc index f871433eb..a57105928 100644 --- a/mediapipe/calculators/video/flow_packager_calculator.cc +++ b/mediapipe/calculators/video/flow_packager_calculator.cc @@ -59,11 +59,11 @@ class FlowPackagerCalculator : public CalculatorBase { public: ~FlowPackagerCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; // Writes passed chunk to disk. 
void WriteChunk(const TrackingDataChunk& chunk) const; @@ -90,8 +90,7 @@ class FlowPackagerCalculator : public CalculatorBase { REGISTER_CALCULATOR(FlowPackagerCalculator); -::mediapipe::Status FlowPackagerCalculator::GetContract( - CalculatorContract* cc) { +absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) { if (!cc->Inputs().HasTag("FLOW")) { return tool::StatusFail("No input flow was specified."); } @@ -115,10 +114,10 @@ REGISTER_CALCULATOR(FlowPackagerCalculator); cc->InputSidePackets().Tag("CACHE_DIR").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FlowPackagerCalculator::Open(CalculatorContext* cc) { +absl::Status FlowPackagerCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); flow_packager_.reset(new FlowPackager(options_.flow_packager_options())); @@ -129,10 +128,10 @@ REGISTER_CALCULATOR(FlowPackagerCalculator); cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { +absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { InputStream* flow_stream = &(cc->Inputs().Tag("FLOW")); const RegionFlowFeatureList& flow = flow_stream->Get(); @@ -194,10 +193,10 @@ REGISTER_CALCULATOR(FlowPackagerCalculator); prev_timestamp_ = timestamp; ++frame_idx_; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { +absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { if (frame_idx_ > 0) { tracking_chunk_.set_last_chunk(true); if (cc->Outputs().HasTag("TRACKING_CHUNK")) { @@ -216,7 +215,7 @@ REGISTER_CALCULATOR(FlowPackagerCalculator); cc->Outputs().Tag("COMPLETE").Add(new bool(true), Timestamp::PreStream()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void FlowPackagerCalculator::WriteChunk(const TrackingDataChunk& chunk) const { diff --git a/mediapipe/calculators/video/flow_to_image_calculator.cc b/mediapipe/calculators/video/flow_to_image_calculator.cc index d32319c6f..6a078ee72 100644 --- a/mediapipe/calculators/video/flow_to_image_calculator.cc +++ b/mediapipe/calculators/video/flow_to_image_calculator.cc @@ -56,27 +56,27 @@ class FlowToImageCalculator : public CalculatorBase { public: FlowToImageCalculator() {} ~FlowToImageCalculator() override {} - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: FlowQuantizerModel model_; }; -::mediapipe::Status FlowToImageCalculator::GetContract(CalculatorContract* cc) { +absl::Status FlowToImageCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); // Model sanity check const auto& options = cc->Options(); if (options.min_value() >= options.max_value()) { - return ::mediapipe::InvalidArgumentError("Invalid quantizer model."); + return absl::InvalidArgumentError("Invalid quantizer model."); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FlowToImageCalculator::Open(CalculatorContext* cc) { +absl::Status FlowToImageCalculator::Open(CalculatorContext* cc) { const auto& options = cc->Options(); // Fill 
the the model_data, ideally we want to train the model, but we omit // the step for now, and takes the (min, max) range from protobuf. @@ -86,10 +86,10 @@ class FlowToImageCalculator : public CalculatorBase { options.min_value(), options.min_value(), options.max_value(), options.max_value())); model_.LoadFromProto(model_data); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FlowToImageCalculator::Process(CalculatorContext* cc) { +absl::Status FlowToImageCalculator::Process(CalculatorContext* cc) { const auto& input = cc->Inputs().Index(0).Get(); // Input flow is 2-channel with x-dim flow and y-dim flow. // Convert it to a ImageFrame in SRGB space, the 3rd channel is not used (0). @@ -106,7 +106,7 @@ class FlowToImageCalculator : public CalculatorBase { } } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(FlowToImageCalculator); diff --git a/mediapipe/calculators/video/motion_analysis_calculator.cc b/mediapipe/calculators/video/motion_analysis_calculator.cc index 6746bc3a9..59673108c 100644 --- a/mediapipe/calculators/video/motion_analysis_calculator.cc +++ b/mediapipe/calculators/video/motion_analysis_calculator.cc @@ -95,11 +95,11 @@ class MotionAnalysisCalculator : public CalculatorBase { public: ~MotionAnalysisCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: // Outputs results to Outputs() if MotionAnalysis buffered sufficient results. @@ -107,8 +107,8 @@ class MotionAnalysisCalculator : public CalculatorBase { void OutputMotionAnalyzedFrames(bool flush, CalculatorContext* cc); // Lazy init function to be called on Process. - ::mediapipe::Status InitOnProcess(InputStream* video_stream, - InputStream* selection_stream); + absl::Status InitOnProcess(InputStream* video_stream, + InputStream* selection_stream); // Parses CSV file contents to homographies. bool ParseModelCSV(const std::string& contents, @@ -189,8 +189,7 @@ class MotionAnalysisCalculator : public CalculatorBase { REGISTER_CALCULATOR(MotionAnalysisCalculator); -::mediapipe::Status MotionAnalysisCalculator::GetContract( - CalculatorContract* cc) { +absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("VIDEO")) { cc->Inputs().Tag("VIDEO").Set(); } @@ -246,10 +245,10 @@ REGISTER_CALCULATOR(MotionAnalysisCalculator); cc->InputSidePackets().Tag(kOptionsTag).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { +absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { options_ = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag); @@ -364,7 +363,7 @@ REGISTER_CALCULATOR(MotionAnalysisCalculator); // If no video header is provided, just return and initialize on the first // Process() call. 
if (video_header == nullptr) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } ////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE /////////////// @@ -397,12 +396,12 @@ REGISTER_CALCULATOR(MotionAnalysisCalculator); .SetHeader(Adopt(new VideoHeader(*video_header))); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { +absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { if (options_.bypass_mode()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } InputStream* video_stream = @@ -441,7 +440,7 @@ REGISTER_CALCULATOR(MotionAnalysisCalculator); } ++frame_idx_; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (motion_analysis_ == nullptr) { @@ -491,7 +490,7 @@ REGISTER_CALCULATOR(MotionAnalysisCalculator); cc->Outputs().Tag("VIDEO_OUT").AddPacket(video_stream->Value()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (use_frame) { @@ -520,7 +519,7 @@ REGISTER_CALCULATOR(MotionAnalysisCalculator); selected_motions_.push_back(frame_selection_result->camera_motion()); switch (options_.selection_analysis()) { case MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION: - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Should not reach this point!"; case MotionAnalysisCalculatorOptions::ANALYSIS_FROM_FEATURES: @@ -574,10 +573,10 @@ REGISTER_CALCULATOR(MotionAnalysisCalculator); OutputMotionAnalyzedFrames(false, cc); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status MotionAnalysisCalculator::Close(CalculatorContext* cc) { +absl::Status MotionAnalysisCalculator::Close(CalculatorContext* cc) { // Guard against empty videos. 
if (motion_analysis_) { OutputMotionAnalyzedFrames(true, cc); @@ -588,7 +587,7 @@ REGISTER_CALCULATOR(MotionAnalysisCalculator); << meta_motions_.size(); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( @@ -688,7 +687,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( } } -::mediapipe::Status MotionAnalysisCalculator::InitOnProcess( +absl::Status MotionAnalysisCalculator::InitOnProcess( InputStream* video_stream, InputStream* selection_stream) { if (video_stream) { frame_width_ = video_stream->Get().Width(); @@ -761,7 +760,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( motion_options->set_filter_initialized_irls_weights(true); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } bool MotionAnalysisCalculator::ParseModelCSV( diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc index c75e58620..bf7ed3e8a 100644 --- a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc @@ -86,7 +86,7 @@ ImageFormat::Format GetImageFormat(int num_channels) { // class OpenCvVideoDecoderCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set(); cc->Outputs().Tag("VIDEO").Set(); if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { @@ -95,15 +95,15 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) { cc->OutputSidePackets().Tag("SAVED_AUDIO_PATH").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const std::string& input_file_path = cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get(); cap_ = absl::make_unique(input_file_path); if (!cap_->isOpened()) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Fail to open video file at " << input_file_path; } width_ = static_cast(cap_->get(cv::CAP_PROP_FRAME_WIDTH)); @@ -116,19 +116,19 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { cv::Mat frame; cap_->read(frame); if (frame.empty()) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Fail to read any frames from the video file at " << input_file_path; } format_ = GetImageFormat(frame.channels()); if (format_ == ImageFormat::UNKNOWN) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Unsupported video format of the video file at " << input_file_path; } if (fps <= 0 || frame_count_ <= 0 || width_ <= 0 || height_ <= 0) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Fail to make video header due to the incorrect metadata from " "the video file at " << input_file_path; @@ -170,17 +170,17 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { .Set(MakePacket(std::string())); } #else - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << 
"OpenCVVideoDecoderCalculator can't save the audio file " "because FFmpeg is not installed. Please remove " "output_side_packet: \"SAVED_AUDIO_PATH\" from the node " "config."; #endif } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { auto image_frame = absl::make_unique(format_, width_, height_, /*alignment_boundary=*/1); // Use microsecond as the unit of time. @@ -213,10 +213,10 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { decoded_frames_++; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { if (cap_ && cap_->isOpened()) { cap_->release(); } @@ -225,7 +225,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { << frame_count_ << " vs decoded frames: " << decoded_frames_ << ")."; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc index 43d71f059..9a74fb710 100644 --- a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc @@ -76,21 +76,20 @@ namespace mediapipe { // class OpenCvVideoEncoderCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - ::mediapipe::Status SetUpVideoWriter(float frame_rate, int width, int height); + absl::Status SetUpVideoWriter(float frame_rate, int width, int height); std::string output_file_path_; int four_cc_; std::unique_ptr writer_; }; -::mediapipe::Status OpenCvVideoEncoderCalculator::GetContract( - CalculatorContract* cc) { +absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("VIDEO")); cc->Inputs().Tag("VIDEO").Set(); if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { @@ -101,10 +100,10 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) { cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { +absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { OpenCvVideoEncoderCalculatorOptions options = cc->Options(); RET_CHECK(options.has_codec() && options.codec().length() == 4) @@ -128,13 +127,12 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { // from the video header directly. The calculator will receive the video // header packet at timestamp prestream. 
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } return SetUpVideoWriter(options.fps(), options.width(), options.height()); } -::mediapipe::Status OpenCvVideoEncoderCalculator::Process( - CalculatorContext* cc) { +absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream()) { const VideoHeader& video_header = cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); @@ -149,7 +147,7 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { if (format == ImageFormat::GRAY8) { frame = formats::MatView(&image_frame); if (frame.empty()) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Receive empty frame at timestamp " << cc->Inputs().Tag("VIDEO").Value().Timestamp() << " in OpenCvVideoEncoderCalculator::Process()"; @@ -157,7 +155,7 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { } else { cv::Mat tmp_frame = formats::MatView(&image_frame); if (tmp_frame.empty()) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Receive empty frame at timestamp " << cc->Inputs().Tag("VIDEO").Value().Timestamp() << " in OpenCvVideoEncoderCalculator::Process()"; @@ -167,15 +165,15 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { } else if (format == ImageFormat::SRGBA) { cv::cvtColor(tmp_frame, frame, cv::COLOR_RGBA2BGR); } else { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Unsupported image format: " << format; } } writer_->write(frame); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { +absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { if (writer_ && writer_->isOpened()) { writer_->release(); } @@ -199,28 +197,29 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { } #else - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "OpenCVVideoEncoderCalculator can't attach the audio tracks to " "the video because FFmpeg is not installed. 
Please remove " "input_side_packet: \"AUDIO_FILE_PATH\" from the node " "config."; #endif } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OpenCvVideoEncoderCalculator::SetUpVideoWriter( - float frame_rate, int width, int height) { +absl::Status OpenCvVideoEncoderCalculator::SetUpVideoWriter(float frame_rate, + int width, + int height) { RET_CHECK(frame_rate > 0 && width > 0 && height > 0) << "Invalid video metadata: frame_rate=" << frame_rate << ", width=" << width << ", height=" << height; writer_ = absl::make_unique( output_file_path_, four_cc_, frame_rate, cv::Size(width, height)); if (!writer_->isOpened()) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Fail to open file at " << output_file_path_; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(OpenCvVideoEncoderCalculator); diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc index faf693c60..1a1530331 100644 --- a/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc @@ -70,7 +70,7 @@ TEST(OpenCvVideoEncoderCalculatorTest, DISABLED_TestMp4Avc720pVideo) { StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("video_prestream"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); Packet packet; @@ -129,7 +129,7 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestFlvH264Video) { StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("video_prestream"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); Packet packet; @@ -190,7 +190,7 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestMkvVp8Video) { StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("video_prestream"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); Packet packet; diff --git a/mediapipe/calculators/video/tool/BUILD b/mediapipe/calculators/video/tool/BUILD index 3d3ed2f86..408461d2f 100644 --- a/mediapipe/calculators/video/tool/BUILD +++ b/mediapipe/calculators/video/tool/BUILD @@ -19,8 +19,6 @@ licenses(["notice"]) package(default_visibility = ["//mediapipe/calculators/video:__subpackages__"]) -exports_files(["LICENSE"]) - proto_library( name = "flow_quantizer_model_proto", srcs = ["flow_quantizer_model.proto"], diff --git a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc index 7e6ba6749..c416fa9b0 100644 --- a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc +++ b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc @@ -106,7 +106,15 @@ Detection GetAxisAlignedDetectionFromTrackedDetection( } else { detection.set_detection_id(tracked_detection.unique_id()); } + + // Sort the labels by descending scores. 
+ std::vector> labels_and_scores; for (const auto& label_and_score : tracked_detection.label_to_score_map()) { + labels_and_scores.push_back(label_and_score); + } + std::sort(labels_and_scores.begin(), labels_and_scores.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + for (const auto& label_and_score : labels_and_scores) { detection.add_label(label_and_score.first); detection.add_score(label_and_score.second); } @@ -139,10 +147,10 @@ Detection GetAxisAlignedDetectionFromTrackedDetection( // } class TrackedDetectionManagerCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Adds new list of detections to |waiting_for_update_detections_|. @@ -161,7 +169,7 @@ class TrackedDetectionManagerCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TrackedDetectionManagerCalculator); -::mediapipe::Status TrackedDetectionManagerCalculator::GetContract( +absl::Status TrackedDetectionManagerCalculator::GetContract( CalculatorContract* cc) { if (cc->Inputs().HasTag(kDetectionsTag)) { cc->Inputs().Tag(kDetectionsTag).Set>(); @@ -183,20 +191,18 @@ REGISTER_CALCULATOR(TrackedDetectionManagerCalculator); cc->Outputs().Tag(kDetectionBoxesTag).Set>(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TrackedDetectionManagerCalculator::Open( - CalculatorContext* cc) { +absl::Status TrackedDetectionManagerCalculator::Open(CalculatorContext* cc) { mediapipe::TrackedDetectionManagerCalculatorOptions options = cc->Options(); tracked_detection_manager_.SetConfig( options.tracked_detection_manager_options()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status TrackedDetectionManagerCalculator::Process( - CalculatorContext* cc) { +absl::Status TrackedDetectionManagerCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kTrackingBoxesTag) && !cc->Inputs().Tag(kTrackingBoxesTag).IsEmpty()) { const TimedBoxProtoList& tracked_boxes = @@ -296,7 +302,7 @@ REGISTER_CALCULATOR(TrackedDetectionManagerCalculator); AddDetectionList(detection_list, cc); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void TrackedDetectionManagerCalculator::AddDetectionList( diff --git a/mediapipe/calculators/video/tracking_graph_test.cc b/mediapipe/calculators/video/tracking_graph_test.cc index fc04ee6e8..e446e155c 100644 --- a/mediapipe/calculators/video/tracking_graph_test.cc +++ b/mediapipe/calculators/video/tracking_graph_test.cc @@ -52,14 +52,14 @@ std::string GetTestDir() { CFURLGetFileSystemRepresentation( bundle_url, true, reinterpret_cast(path), sizeof(path)); CFRelease(bundle_url); - return ::mediapipe::file::JoinPath(path, "testdata"); + return mediapipe::file::JoinPath(path, "testdata"); #elif defined(__ANDROID__) char path[1024]; getcwd(path, sizeof(path)); - return ::mediapipe::file::JoinPath(path, - "mediapipe/calculators/video/testdata"); + return mediapipe::file::JoinPath(path, + "mediapipe/calculators/video/testdata"); #else - return ::mediapipe::file::JoinPath( + return mediapipe::file::JoinPath( "./", // This should match the path of the output files // of the genrule() that generates test model files. 
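Note on the API migration running through the calculator files above: the pattern is mechanical — `::mediapipe::Status` becomes `absl::Status`, `::mediapipe::OkStatus()` becomes `absl::OkStatus()`, canonical error constructors such as `InvalidArgumentError` move to the `absl::` namespace while the `*ErrorBuilder(MEDIAPIPE_LOC)` helpers keep the `mediapipe::` prefix, and `StatusOr::ValueOrDie()` is replaced by `StatusOr::value()` in the tests. A minimal sketch of what a calculator looks like after the migration follows; `PassThroughExampleCalculator` is a hypothetical name used only for illustration and is not part of this patch.

    #include "mediapipe/framework/calculator_framework.h"

    namespace mediapipe {

    // Hypothetical pass-through calculator illustrating the post-migration
    // signatures: every framework entry point returns absl::Status directly.
    class PassThroughExampleCalculator : public CalculatorBase {
     public:
      static absl::Status GetContract(CalculatorContract* cc) {
        cc->Inputs().Index(0).SetAny();
        cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
        return absl::OkStatus();  // formerly ::mediapipe::OkStatus()
      }

      absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }

      absl::Status Process(CalculatorContext* cc) override {
        // Forward the input packet unchanged at the input timestamp.
        cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
        return absl::OkStatus();
      }
    };
    REGISTER_CALCULATOR(PassThroughExampleCalculator);

    }  // namespace mediapipe
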
diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc index c774cfeb1..cf00da1f7 100644 --- a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc +++ b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc @@ -74,14 +74,14 @@ cv::Mat ConvertToGrayscale(const cv::Mat& image) { // num_threads: 10 class Tvl1OpticalFlowCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - ::mediapipe::Status CalculateOpticalFlow(const ImageFrame& current_frame, - const ImageFrame& next_frame, - OpticalFlowField* flow); + absl::Status CalculateOpticalFlow(const ImageFrame& current_frame, + const ImageFrame& next_frame, + OpticalFlowField* flow); bool forward_requested_ = false; bool backward_requested_ = false; // Stores the idle DenseOpticalFlow objects. @@ -93,11 +93,10 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase { absl::Mutex mutex_; }; -::mediapipe::Status Tvl1OpticalFlowCalculator::GetContract( - CalculatorContract* cc) { +absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) { if (!cc->Inputs().HasTag("FIRST_FRAME") || !cc->Inputs().HasTag("SECOND_FRAME")) { - return ::mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Missing required input streams. Both FIRST_FRAME and SECOND_FRAME " "must be specified."); } @@ -109,10 +108,10 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase { if (cc->Outputs().HasTag("BACKWARD_FLOW")) { cc->Outputs().Tag("BACKWARD_FLOW").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { +absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { { absl::MutexLock lock(&mutex_); tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1()); @@ -124,10 +123,10 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase { backward_requested_ = true; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { +absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { const ImageFrame& first_frame = cc->Inputs().Tag("FIRST_FRAME").Value().Get(); const ImageFrame& second_frame = @@ -148,10 +147,10 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase { .Tag("BACKWARD_FLOW") .Add(backward_optical_flow_field.release(), cc->InputTimestamp()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status Tvl1OpticalFlowCalculator::CalculateOpticalFlow( +absl::Status Tvl1OpticalFlowCalculator::CalculateOpticalFlow( const ImageFrame& current_frame, const ImageFrame& next_frame, OpticalFlowField* flow) { CHECK(flow); @@ -184,7 +183,7 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase { absl::MutexLock lock(&mutex_); tvl1_computers_.push_back(tvl1_computer); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(Tvl1OpticalFlowCalculator); diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc index b226dfd87..c9d30b73d 100644 --- 
a/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc +++ b/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc @@ -78,11 +78,11 @@ void RunTest(int num_input_packets, int max_in_flight) { StatusOrPoller status_or_poller1 = graph.AddOutputStreamPoller("forward_flow"); ASSERT_TRUE(status_or_poller1.ok()); - OutputStreamPoller poller1 = std::move(status_or_poller1.ValueOrDie()); + OutputStreamPoller poller1 = std::move(status_or_poller1.value()); StatusOrPoller status_or_poller2 = graph.AddOutputStreamPoller("backward_flow"); ASSERT_TRUE(status_or_poller2.ok()); - OutputStreamPoller poller2 = std::move(status_or_poller2.ValueOrDie()); + OutputStreamPoller poller2 = std::move(status_or_poller2.value()); MP_ASSERT_OK(graph.StartRun({})); AddInputPackets(num_input_packets, &graph); diff --git a/mediapipe/calculators/video/video_pre_stream_calculator.cc b/mediapipe/calculators/video/video_pre_stream_calculator.cc index 69c76ec36..ab9cd22a4 100644 --- a/mediapipe/calculators/video/video_pre_stream_calculator.cc +++ b/mediapipe/calculators/video/video_pre_stream_calculator.cc @@ -45,13 +45,13 @@ namespace mediapipe { // } class VideoPreStreamCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - ::mediapipe::Status ProcessWithFrameRateInPreStream(CalculatorContext* cc); - ::mediapipe::Status ProcessWithFrameRateInOptions(CalculatorContext* cc); + absl::Status ProcessWithFrameRateInPreStream(CalculatorContext* cc); + absl::Status ProcessWithFrameRateInOptions(CalculatorContext* cc); std::unique_ptr header_; bool frame_rate_in_prestream_ = false; @@ -60,8 +60,7 @@ class VideoPreStreamCalculator : public CalculatorBase { REGISTER_CALCULATOR(VideoPreStreamCalculator); -::mediapipe::Status VideoPreStreamCalculator::GetContract( - CalculatorContract* cc) { +absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) { if (!cc->Inputs().UsesTags()) { cc->Inputs().Index(0).Set(); } else { @@ -69,17 +68,17 @@ REGISTER_CALCULATOR(VideoPreStreamCalculator); cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); } cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) { +absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) { frame_rate_in_prestream_ = cc->Inputs().UsesTags() && cc->Inputs().HasTag("FRAME") && cc->Inputs().HasTag("VIDEO_PRESTREAM"); header_ = absl::make_unique(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream( +absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream( CalculatorContext* cc) { cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment(); if (cc->InputTimestamp() == Timestamp::PreStream()) { @@ -99,13 +98,13 @@ REGISTER_CALCULATOR(VideoPreStreamCalculator); cc->Outputs().Index(0).Add(header_.release(), Timestamp::PreStream()); emitted_ = true; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status VideoPreStreamCalculator::Process(CalculatorContext* cc) { +absl::Status VideoPreStreamCalculator::Process(CalculatorContext* cc) { 
cc->GetCounter("Process")->Increment(); if (emitted_) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (frame_rate_in_prestream_) { return ProcessWithFrameRateInPreStream(cc); @@ -114,7 +113,7 @@ REGISTER_CALCULATOR(VideoPreStreamCalculator); } } -::mediapipe::Status VideoPreStreamCalculator::ProcessWithFrameRateInOptions( +absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInOptions( CalculatorContext* cc) { cc->GetCounter("ProcessWithFrameRateInOptions")->Increment(); RET_CHECK_NE(cc->InputTimestamp(), Timestamp::PreStream()); @@ -136,7 +135,7 @@ REGISTER_CALCULATOR(VideoPreStreamCalculator); RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero"; cc->Outputs().Index(0).Add(header_.release(), Timestamp::PreStream()); emitted_ = true; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/video/video_pre_stream_calculator_test.cc b/mediapipe/calculators/video/video_pre_stream_calculator_test.cc index 1f0ccb7f2..38f132e9e 100644 --- a/mediapipe/calculators/video/video_pre_stream_calculator_test.cc +++ b/mediapipe/calculators/video/video_pre_stream_calculator_test.cc @@ -39,7 +39,7 @@ TEST(VideoPreStreamCalculatorTest, ProcessesWithFrameRateInOptions) { MP_ASSERT_OK(graph.Initialize(config)); auto poller_status = graph.AddOutputStreamPoller("output"); MP_ASSERT_OK(poller_status.status()); - OutputStreamPoller& poller = poller_status.ValueOrDie(); + OutputStreamPoller& poller = poller_status.value(); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( "input", @@ -79,7 +79,7 @@ TEST(VideoPreStreamCalculatorTest, ProcessesWithFrameRateInPreStream) { MP_ASSERT_OK(graph.Initialize(config)); auto poller_status = graph.AddOutputStreamPoller("output_header"); MP_ASSERT_OK(poller_status.status()); - OutputStreamPoller& poller = poller_status.ValueOrDie(); + OutputStreamPoller& poller = poller_status.value(); MP_ASSERT_OK(graph.StartRun({})); auto input_header = absl::make_unique(); input_header->frame_rate = 3.0; @@ -118,7 +118,7 @@ TEST(VideoPreStreamCalculatorTest, FailsWithoutFrameRateInOptions) { "frame", Adopt(new ImageFrame(ImageFormat::SRGB, 1, 2)).At(Timestamp(0)))); MP_ASSERT_OK(graph.CloseInputStream("frame")); - ::mediapipe::Status status = graph.WaitUntilDone(); + absl::Status status = graph.WaitUntilDone(); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.ToString(), testing::HasSubstr("frame rate should be non-zero")); @@ -144,7 +144,7 @@ TEST(VideoPreStreamCalculatorTest, FailsWithoutFrameRateInPreStream1) { Adopt(new ImageFrame(ImageFormat::SRGB, 1, 2)).At(Timestamp(0)))); MP_ASSERT_OK(graph.CloseInputStream("frame")); MP_ASSERT_OK(graph.CloseInputStream("input_header")); - ::mediapipe::Status status = graph.WaitUntilDone(); + absl::Status status = graph.WaitUntilDone(); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.ToString(), testing::HasSubstr("frame rate should be non-zero")); @@ -177,7 +177,7 @@ TEST(VideoPreStreamCalculatorTest, FailsWithoutFrameRateInPreStream2) { "frame", Adopt(new ImageFrame(ImageFormat::SRGB, 1, 2)).At(Timestamp(0)))); MP_ASSERT_OK(graph.CloseInputStream("frame")); - ::mediapipe::Status status = graph.WaitUntilDone(); + absl::Status status = graph.WaitUntilDone(); EXPECT_FALSE(status.ok()); } } diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/AndroidManifest.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/AndroidManifest.xml index 99288624c..f7218c97c 100644 --- 
a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/AndroidManifest.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/AndroidManifest.xml @@ -32,5 +32,6 @@ + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/BUILD index d0ff4e8cb..ae4652dba 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/BUILD @@ -74,6 +74,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/MainActivity.java index 8a4924756..952132cdf 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/MainActivity.java @@ -46,6 +46,14 @@ public class MainActivity extends AppCompatActivity { // NOTE: use "flipFramesVertically" in manifest metadata to override this behavior. private static final boolean FLIP_FRAMES_VERTICALLY = true; + // Number of output frames allocated in ExternalTextureConverter. + // NOTE: use "converterNumBuffers" in manifest metadata to override number of buffers. For + // example, when there is a FlowLimiterCalculator in the graph, number of buffers should be at + // least `max_in_flight + max_in_queue + 1` (where max_in_flight and max_in_queue are used in + // FlowLimiterCalculator options). That's because we need buffers for all the frames that are in + // flight/queue plus one for the next frame from the camera. + private static final int NUM_BUFFERS = 2; + static { // Load all native libraries needed by the app. System.loadLibrary("mediapipe_jni"); @@ -103,7 +111,6 @@ public class MainActivity extends AppCompatActivity { applicationInfo.metaData.getString("binaryGraphName"), applicationInfo.metaData.getString("inputVideoStreamName"), applicationInfo.metaData.getString("outputVideoStreamName")); - processor .getVideoSurfaceOutput() .setFlipY( @@ -121,7 +128,10 @@ public class MainActivity extends AppCompatActivity { @Override protected void onResume() { super.onResume(); - converter = new ExternalTextureConverter(eglManager.getContext()); + converter = + new ExternalTextureConverter( + eglManager.getContext(), + applicationInfo.metaData.getInt("converterNumBuffers", NUM_BUFFERS)); converter.setFlipY( applicationInfo.metaData.getBoolean("flipFramesVertically", FLIP_FRAMES_VERTICALLY)); converter.setConsumer(processor); @@ -134,6 +144,9 @@ public class MainActivity extends AppCompatActivity { protected void onPause() { super.onPause(); converter.close(); + + // Hide preview display until we re-open the camera again. + previewDisplayView.setVisibility(View.GONE); } @Override @@ -165,7 +178,7 @@ public class MainActivity extends AppCompatActivity { ? 
CameraHelper.CameraFacing.FRONT : CameraHelper.CameraFacing.BACK; cameraHelper.startCamera( - this, cameraFacing, /*surfaceTexture=*/ null, cameraTargetResolution()); + this, cameraFacing, /*unusedSurfaceTexture=*/ null, cameraTargetResolution()); } protected Size computeViewSize(int width, int height) { diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD index 7536be08b..279d29b74 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD @@ -36,9 +36,8 @@ android_binary( name = "facedetectioncpu", srcs = glob(["*.java"]), assets = [ - "//mediapipe/graphs/face_detection:mobile_cpu.binarypb", - "//mediapipe/models:face_detection_front.tflite", - "//mediapipe/models:face_detection_front_labelmap.txt", + "//mediapipe/graphs/face_detection:face_detection_mobile_cpu.binarypb", + "//mediapipe/modules/face_detection:face_detection_front.tflite", ], assets_dir = "", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", @@ -47,10 +46,11 @@ android_binary( "appName": "Face Detection (CPU)", "mainActivity": "com.google.mediapipe.apps.basic.MainActivity", "cameraFacingFront": "True", - "binaryGraphName": "mobile_cpu.binarypb", + "binaryGraphName": "face_detection_mobile_cpu.binarypb", "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD index 46a758ab6..11351fc56 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD @@ -36,9 +36,8 @@ android_binary( name = "facedetectiongpu", srcs = glob(["*.java"]), assets = [ - "//mediapipe/graphs/face_detection:mobile_gpu.binarypb", - "//mediapipe/models:face_detection_front.tflite", - "//mediapipe/models:face_detection_front_labelmap.txt", + "//mediapipe/graphs/face_detection:face_detection_mobile_gpu.binarypb", + "//mediapipe/modules/face_detection:face_detection_front.tflite", ], assets_dir = "", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", @@ -47,10 +46,11 @@ android_binary( "appName": "Face Detection", "mainActivity": "com.google.mediapipe.apps.basic.MainActivity", "cameraFacingFront": "True", - "binaryGraphName": "mobile_gpu.binarypb", + "binaryGraphName": "face_detection_mobile_gpu.binarypb", "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/BUILD new file mode 100644 index 000000000..8bf6c0a54 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/BUILD @@ -0,0 +1,71 @@ +# Copyright 2020 The MediaPipe Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/face_effect:face_effect_gpu_deps", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +android_binary( + name = "faceeffect", + srcs = glob(["*.java"]), + assets = [ + "//mediapipe/graphs/face_effect/data:axis.binarypb", + "//mediapipe/graphs/face_effect/data:axis.pngblob", + "//mediapipe/graphs/face_effect/data:facepaint.pngblob", + "//mediapipe/graphs/face_effect/data:glasses.binarypb", + "//mediapipe/graphs/face_effect/data:glasses.pngblob", + "//mediapipe/graphs/face_effect:face_effect_gpu.binarypb", + "//mediapipe/modules/face_detection:face_detection_front.tflite", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_detection.binarypb", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_landmarks.binarypb", + "//mediapipe/modules/face_landmark:face_landmark.tflite", + ], + assets_dir = "", + manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.apps.faceeffect", + "appName": "Face Effect", + "mainActivity": ".MainActivity", + "cameraFacingFront": "True", + "binaryGraphName": "face_effect_gpu.binarypb", + "inputVideoStreamName": "input_video", + "outputVideoStreamName": "output_video", + "flipFramesVertically": "True", + "converterNumBuffers": "2", + }, + multidex = "native", + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", + "//mediapipe/framework/formats:matrix_data_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/modules/face_geometry/protos:face_geometry_java_proto_lite", + ], +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/MainActivity.java new file mode 100644 index 000000000..78c220aae --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/MainActivity.java @@ -0,0 +1,216 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.apps.faceeffect; + +import android.graphics.Color; +import android.os.Bundle; +import android.util.Log; +import android.view.GestureDetector; +import android.view.Gravity; +import android.view.MotionEvent; +import android.view.View; +import android.view.ViewGroup; +import android.view.ViewGroup.LayoutParams; +import android.widget.RelativeLayout; +import android.widget.TextView; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.modules.facegeometry.FaceGeometryProto.FaceGeometry; +import com.google.mediapipe.formats.proto.MatrixDataProto.MatrixData; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Main activity of MediaPipe face mesh app. */ +public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { + private static final String TAG = "MainActivity"; + + // Side packet / stream names. + private static final String USE_FACE_DETECTION_INPUT_SOURCE_INPUT_SIDE_PACKET_NAME = + "use_face_detection_input_source"; + private static final String SELECTED_EFFECT_ID_INPUT_STREAM_NAME = "selected_effect_id"; + private static final String OUTPUT_FACE_GEOMETRY_STREAM_NAME = "multi_face_geometry"; + + private static final String EFFECT_SWITCHING_HINT_TEXT = "Tap to switch between effects!"; + + private static final boolean USE_FACE_DETECTION_INPUT_SOURCE = false; + private static final int MATRIX_TRANSLATION_Z_INDEX = 14; + + private static final int SELECTED_EFFECT_ID_AXIS = 0; + private static final int SELECTED_EFFECT_ID_FACEPAINT = 1; + private static final int SELECTED_EFFECT_ID_GLASSES = 2; + + private final Object effectSelectionLock = new Object(); + private int selectedEffectId; + + private View effectSwitchingHintView; + private GestureDetector tapGestureDetector; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + // Add an effect switching hint view to the preview layout. + effectSwitchingHintView = createEffectSwitchingHintView(); + effectSwitchingHintView.setVisibility(View.INVISIBLE); + ViewGroup viewGroup = findViewById(R.id.preview_display_layout); + viewGroup.addView(effectSwitchingHintView); + + // By default, render the axis effect for the face detection input source and the glasses effect + // for the face landmark input source. + if (USE_FACE_DETECTION_INPUT_SOURCE) { + selectedEffectId = SELECTED_EFFECT_ID_AXIS; + } else { + selectedEffectId = SELECTED_EFFECT_ID_GLASSES; + } + + // Pass the USE_FACE_DETECTION_INPUT_SOURCE flag value as an input side packet into the graph. + Map inputSidePackets = new HashMap<>(); + inputSidePackets.put( + USE_FACE_DETECTION_INPUT_SOURCE_INPUT_SIDE_PACKET_NAME, + processor.getPacketCreator().createBool(USE_FACE_DETECTION_INPUT_SOURCE)); + processor.setInputSidePackets(inputSidePackets); + + // This callback demonstrates how the output face geometry packet can be obtained and used + // in an Android app. As an example, the Z-translation component of the face pose transform + // matrix is logged for each face being equal to the approximate distance away from the camera + // in centimeters. + processor.addPacketCallback( + OUTPUT_FACE_GEOMETRY_STREAM_NAME, + (packet) -> { + effectSwitchingHintView.post( + () -> + effectSwitchingHintView.setVisibility( + USE_FACE_DETECTION_INPUT_SOURCE ? 
View.INVISIBLE : View.VISIBLE)); + + Log.d(TAG, "Received a multi face geometry packet."); + List multiFaceGeometry = + PacketGetter.getProtoVector(packet, FaceGeometry.parser()); + + StringBuilder approxDistanceAwayFromCameraLogMessage = new StringBuilder(); + for (FaceGeometry faceGeometry : multiFaceGeometry) { + if (approxDistanceAwayFromCameraLogMessage.length() > 0) { + approxDistanceAwayFromCameraLogMessage.append(' '); + } + MatrixData poseTransformMatrix = faceGeometry.getPoseTransformMatrix(); + approxDistanceAwayFromCameraLogMessage.append( + -poseTransformMatrix.getPackedData(MATRIX_TRANSLATION_Z_INDEX)); + } + + Log.d( + TAG, + "[TS:" + + packet.getTimestamp() + + "] size = " + + multiFaceGeometry.size() + + "; approx. distance away from camera in cm for faces = [" + + approxDistanceAwayFromCameraLogMessage + + "]"); + }); + + // Alongside the input camera frame, we also send the `selected_effect_id` int32 packet to + // indicate which effect should be rendered on this frame. + processor.setOnWillAddFrameListener( + (timestamp) -> { + Packet selectedEffectIdPacket = null; + try { + synchronized (effectSelectionLock) { + selectedEffectIdPacket = processor.getPacketCreator().createInt32(selectedEffectId); + } + + processor + .getGraph() + .addPacketToInputStream( + SELECTED_EFFECT_ID_INPUT_STREAM_NAME, selectedEffectIdPacket, timestamp); + } catch (RuntimeException e) { + Log.e( + TAG, "Exception while adding packet to input stream while switching effects: " + e); + } finally { + if (selectedEffectIdPacket != null) { + selectedEffectIdPacket.release(); + } + } + }); + + // We use the tap gesture detector to switch between face effects. This allows users to try + // multiple pre-bundled face effects without a need to recompile the app. + tapGestureDetector = + new GestureDetector( + this, + new GestureDetector.SimpleOnGestureListener() { + @Override + public void onLongPress(MotionEvent event) { + switchEffect(); + } + + @Override + public boolean onSingleTapUp(MotionEvent event) { + switchEffect(); + return true; + } + + private void switchEffect() { + // Avoid switching the Axis effect for the face detection input source. + if (USE_FACE_DETECTION_INPUT_SOURCE) { + return; + } + + // Looped effect order: glasses -> facepaint -> axis -> glasses -> ... 
+ synchronized (effectSelectionLock) { + switch (selectedEffectId) { + case SELECTED_EFFECT_ID_AXIS: + { + selectedEffectId = SELECTED_EFFECT_ID_GLASSES; + break; + } + + case SELECTED_EFFECT_ID_FACEPAINT: + { + selectedEffectId = SELECTED_EFFECT_ID_AXIS; + break; + } + + case SELECTED_EFFECT_ID_GLASSES: + { + selectedEffectId = SELECTED_EFFECT_ID_FACEPAINT; + break; + } + + default: + break; + } + } + } + }); + } + + @Override + public boolean onTouchEvent(MotionEvent event) { + return tapGestureDetector.onTouchEvent(event); + } + + private View createEffectSwitchingHintView() { + TextView effectSwitchingHintView = new TextView(getApplicationContext()); + effectSwitchingHintView.setLayoutParams( + new RelativeLayout.LayoutParams(LayoutParams.FILL_PARENT, LayoutParams.FILL_PARENT)); + effectSwitchingHintView.setText(EFFECT_SWITCHING_HINT_TEXT); + effectSwitchingHintView.setGravity(Gravity.CENTER_HORIZONTAL | Gravity.BOTTOM); + effectSwitchingHintView.setPadding(0, 0, 0, 480); + effectSwitchingHintView.setTextColor(Color.parseColor("#ffffff")); + effectSwitchingHintView.setTextSize((float) 24); + + return effectSwitchingHintView; + } +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/BUILD index 2de32b36f..26406e77b 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/BUILD @@ -51,6 +51,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD index 284dcd9a0..df58f2713 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD @@ -50,6 +50,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handdetectiongpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handdetectiongpu/BUILD index d7841b6fa..2d9813301 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handdetectiongpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handdetectiongpu/BUILD @@ -37,8 +37,7 @@ android_binary( srcs = glob(["*.java"]), assets = [ "//mediapipe/graphs/hand_tracking:hand_detection_mobile_gpu.binarypb", - "//mediapipe/models:palm_detection.tflite", - "//mediapipe/models:palm_detection_labelmap.txt", + "//mediapipe/modules/palm_detection:palm_detection.tflite", ], assets_dir = "", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", @@ -51,6 +50,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/BUILD 
b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/BUILD index 546ce9aa0..7b3bfe847 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/BUILD @@ -37,10 +37,9 @@ android_binary( srcs = glob(["*.java"]), assets = [ "//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb", - "//mediapipe/models:handedness.txt", - "//mediapipe/models:hand_landmark.tflite", - "//mediapipe/models:palm_detection.tflite", - "//mediapipe/models:palm_detection_labelmap.txt", + "//mediapipe/modules/hand_landmark:handedness.txt", + "//mediapipe/modules/hand_landmark:hand_landmark.tflite", + "//mediapipe/modules/palm_detection:palm_detection.tflite", ], assets_dir = "", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", @@ -53,6 +52,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java index e45510c1c..445431bc4 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java @@ -18,76 +18,75 @@ import android.os.Bundle; import android.util.Log; import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.framework.AndroidPacketCreator; +import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; -import com.google.protobuf.InvalidProtocolBufferException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** Main activity of MediaPipe hand tracking app. */ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { private static final String TAG = "MainActivity"; - private static final String OUTPUT_HAND_PRESENCE_STREAM_NAME = "hand_presence"; + private static final String INPUT_NUM_HANDS_SIDE_PACKET_NAME = "num_hands"; private static final String OUTPUT_LANDMARKS_STREAM_NAME = "hand_landmarks"; + // Max number of hands to detect/process. 
+ private static final int NUM_HANDS = 2; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); - processor.addPacketCallback( - OUTPUT_HAND_PRESENCE_STREAM_NAME, - (packet) -> { - Boolean handPresence = PacketGetter.getBool(packet); - if (!handPresence) { - Log.d( - TAG, - "[TS:" + packet.getTimestamp() + "] Hand presence is false, no hands detected."); - } - }); + AndroidPacketCreator packetCreator = processor.getPacketCreator(); + Map inputSidePackets = new HashMap<>(); + inputSidePackets.put(INPUT_NUM_HANDS_SIDE_PACKET_NAME, packetCreator.createInt32(NUM_HANDS)); + processor.setInputSidePackets(inputSidePackets); // To show verbose logging, run: // adb shell setprop log.tag.MainActivity VERBOSE if (Log.isLoggable(TAG, Log.VERBOSE)) { processor.addPacketCallback( - OUTPUT_LANDMARKS_STREAM_NAME, - (packet) -> { - byte[] landmarksRaw = PacketGetter.getProtoBytes(packet); - try { - NormalizedLandmarkList landmarks = NormalizedLandmarkList.parseFrom(landmarksRaw); - if (landmarks == null) { - Log.v(TAG, "[TS:" + packet.getTimestamp() + "] No hand landmarks."); - return; - } - // Note: If hand_presence is false, these landmarks are useless. + OUTPUT_LANDMARKS_STREAM_NAME, + (packet) -> { + Log.v(TAG, "Received multi-hand landmarks packet."); + List multiHandLandmarks = + PacketGetter.getProtoVector(packet, NormalizedLandmarkList.parser()); Log.v( TAG, "[TS:" + packet.getTimestamp() - + "] #Landmarks for hand: " - + landmarks.getLandmarkCount()); - Log.v(TAG, getLandmarksDebugString(landmarks)); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Couldn't Exception received - " + e); - return; - } - }); + + "] " + + getMultiHandLandmarksDebugString(multiHandLandmarks)); + }); } } - private static String getLandmarksDebugString(NormalizedLandmarkList landmarks) { - int landmarkIndex = 0; - String landmarksString = ""; - for (NormalizedLandmark landmark : landmarks.getLandmarkList()) { - landmarksString += - "\t\tLandmark[" - + landmarkIndex - + "]: (" - + landmark.getX() - + ", " - + landmark.getY() - + ", " - + landmark.getZ() - + ")\n"; - ++landmarkIndex; + private String getMultiHandLandmarksDebugString(List multiHandLandmarks) { + if (multiHandLandmarks.isEmpty()) { + return "No hand landmarks"; } - return landmarksString; + String multiHandLandmarksStr = "Number of hands detected: " + multiHandLandmarks.size() + "\n"; + int handIndex = 0; + for (NormalizedLandmarkList landmarks : multiHandLandmarks) { + multiHandLandmarksStr += + "\t#Hand landmarks for hand[" + handIndex + "]: " + landmarks.getLandmarkCount() + "\n"; + int landmarkIndex = 0; + for (NormalizedLandmark landmark : landmarks.getLandmarkList()) { + multiHandLandmarksStr += + "\t\tLandmark [" + + landmarkIndex + + "]: (" + + landmark.getX() + + ", " + + landmark.getY() + + ", " + + landmark.getZ() + + ")\n"; + ++landmarkIndex; + } + ++handIndex; + } + return multiHandLandmarksStr; } } diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu/BUILD new file mode 100644 index 000000000..44a6d6428 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/holistictrackinggpu/BUILD @@ -0,0 +1,69 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu_deps", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +android_binary( + name = "holistictrackinggpu", + srcs = glob(["*.java"]), + assets = [ + "//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb", + "//mediapipe/modules/face_detection:face_detection_front.tflite", + "//mediapipe/modules/face_landmark:face_landmark.tflite", + "//mediapipe/modules/hand_landmark:hand_landmark.tflite", + "//mediapipe/modules/hand_landmark:handedness.txt", + "//mediapipe/modules/holistic_landmark:hand_recrop.tflite", + "//mediapipe/modules/pose_detection:pose_detection.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", + ], + assets_dir = "", + manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.apps.holistictrackinggpu", + "appName": "Holistic Tracking", + "mainActivity": "com.google.mediapipe.apps.basic.MainActivity", + "cameraFacingFront": "False", + "binaryGraphName": "holistic_tracking_gpu.binarypb", + "inputVideoStreamName": "input_video", + "outputVideoStreamName": "output_video", + "flipFramesVertically": "True", + "converterNumBuffers": "3", + }, + multidex = "native", + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + ], +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/BUILD index 784221084..3dea64053 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/BUILD @@ -89,6 +89,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/BUILD index 473404fdd..f629951df 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/BUILD @@ -52,6 +52,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + 
"converterNumBuffers": "2", }, multidex = "native", deps = [ @@ -59,5 +60,6 @@ android_binary( "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", "//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "@com_google_protobuf//:protobuf_javalite", ], ) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/MainActivity.java index fc4c67755..8079daa75 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/MainActivity.java @@ -32,16 +32,22 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { private static final String FOCAL_LENGTH_STREAM_NAME = "focal_length_pixel"; private static final String OUTPUT_LANDMARKS_STREAM_NAME = "face_landmarks_with_iris"; + private boolean haveAddedSidePackets = false; + @Override protected void onCameraStarted(SurfaceTexture surfaceTexture) { super.onCameraStarted(surfaceTexture); - float focalLength = cameraHelper.getFocalLengthPixels(); - if (focalLength != Float.MIN_VALUE) { - Packet focalLengthSidePacket = processor.getPacketCreator().createFloat32(focalLength); - Map inputSidePackets = new HashMap<>(); - inputSidePackets.put(FOCAL_LENGTH_STREAM_NAME, focalLengthSidePacket); - processor.setInputSidePackets(inputSidePackets); + // onCameraStarted gets called each time the activity resumes, but we only want to do this once. + if (!haveAddedSidePackets) { + float focalLength = cameraHelper.getFocalLengthPixels(); + if (focalLength != Float.MIN_VALUE) { + Packet focalLengthSidePacket = processor.getPacketCreator().createFloat32(focalLength); + Map inputSidePackets = new HashMap<>(); + inputSidePackets.put(FOCAL_LENGTH_STREAM_NAME, focalLengthSidePacket); + processor.setInputSidePackets(inputSidePackets); + } + haveAddedSidePackets = true; } } diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/MainActivity.java deleted file mode 100644 index 0d4dfde7f..000000000 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/MainActivity.java +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.apps.multihandtrackinggpu; - -import android.os.Bundle; -import android.util.Log; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; -import com.google.mediapipe.framework.PacketGetter; -import java.util.List; - -/** Main activity of MediaPipe multi-hand tracking app. 
*/ -public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { - private static final String TAG = "MainActivity"; - - private static final String OUTPUT_LANDMARKS_STREAM_NAME = "multi_hand_landmarks"; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - // To show verbose logging, run: - // adb shell setprop log.tag.MainActivity VERBOSE - if (Log.isLoggable(TAG, Log.VERBOSE)) { - processor.addPacketCallback( - OUTPUT_LANDMARKS_STREAM_NAME, - (packet) -> { - Log.v(TAG, "Received multi-hand landmarks packet."); - List multiHandLandmarks = - PacketGetter.getProtoVector(packet, NormalizedLandmarkList.parser()); - Log.v( - TAG, - "[TS:" - + packet.getTimestamp() - + "] " - + getMultiHandLandmarksDebugString(multiHandLandmarks)); - }); - } - } - - private String getMultiHandLandmarksDebugString(List multiHandLandmarks) { - if (multiHandLandmarks.isEmpty()) { - return "No hand landmarks"; - } - String multiHandLandmarksStr = "Number of hands detected: " + multiHandLandmarks.size() + "\n"; - int handIndex = 0; - for (NormalizedLandmarkList landmarks : multiHandLandmarks) { - multiHandLandmarksStr += - "\t#Hand landmarks for hand[" + handIndex + "]: " + landmarks.getLandmarkCount() + "\n"; - int landmarkIndex = 0; - for (NormalizedLandmark landmark : landmarks.getLandmarkList()) { - multiHandLandmarksStr += - "\t\tLandmark [" - + landmarkIndex - + "]: (" - + landmark.getX() - + ", " - + landmark.getY() - + ", " - + landmark.getZ() - + ")\n"; - ++landmarkIndex; - } - ++handIndex; - } - return multiHandLandmarksStr; - } -} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD index f07bc8ebc..783ae200e 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2020 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("@bazel_skylib//lib:selects.bzl", "selects") +load(":build_defs.bzl", "generate_manifest_values") + licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +config_setting( + name = "use_chair", + define_values = { + "chair": "true", + }, +) + +config_setting( + name = "use_cup", + define_values = { + "cup": "true", + }, +) + +config_setting( + name = "use_camera", + define_values = { + "camera": "true", + }, +) + +config_setting( + name = "use_shoe_1stage", + define_values = { + "shoe_1stage": "true", + }, +) + +config_setting( + name = "use_chair_1stage", + define_values = { + "chair_1stage": "true", + }, +) + +selects.config_setting_group( + name = "1stage", + match_any = [ + ":use_shoe_1stage", + ":use_chair_1stage", + ], +) + cc_binary( name = "libmediapipe_jni.so", linkshared = 1, linkstatic = 1, - deps = [ - "//mediapipe/graphs/object_detection_3d:mobile_calculators", + deps = select({ + "//conditions:default": ["//mediapipe/graphs/object_detection_3d:mobile_calculators"], + ":1stage": ["//mediapipe/graphs/object_detection_3d:mobile_calculators_1stage"], + }) + [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", ], ) @@ -32,67 +80,108 @@ cc_library( alwayslink = 1, ) -# To use the "chair" model instead of the default "shoes" model, -# add "--define chair=true" to the bazel build command. -config_setting( - name = "use_chair_model", - define_values = { - "chair": "true", - }, -) - genrule( name = "binary_graph", srcs = select({ - "//conditions:default": ["//mediapipe/graphs/object_detection_3d:mobile_gpu_binary_graph_shoe"], - ":use_chair_model": ["//mediapipe/graphs/object_detection_3d:mobile_gpu_binary_graph_chair"], + "//conditions:default": ["//mediapipe/graphs/object_detection_3d:mobile_gpu_binary_graph"], + ":1stage": ["//mediapipe/graphs/object_detection_3d:mobile_gpu_1stage_binary_graph"], }), outs = ["object_detection_3d.binarypb"], cmd = "cp $< $@", ) +MODELS_DIR = "//mediapipe/modules/objectron" + genrule( name = "model", srcs = select({ - "//conditions:default": ["//mediapipe/models:object_detection_3d_sneakers.tflite"], - ":use_chair_model": ["//mediapipe/models:object_detection_3d_chair.tflite"], + "//conditions:default": [MODELS_DIR + ":object_detection_3d_sneakers.tflite"], + ":use_chair": [MODELS_DIR + ":object_detection_3d_chair.tflite"], + ":use_cup": [MODELS_DIR + ":object_detection_3d_cup.tflite"], + ":use_camera": [MODELS_DIR + ":object_detection_3d_camera.tflite"], + ":use_shoe_1stage": [MODELS_DIR + ":object_detection_3d_sneakers_1stage.tflite"], + ":use_chair_1stage": [MODELS_DIR + ":object_detection_3d_chair_1stage.tflite"], }), outs = ["object_detection_3d.tflite"], cmd = "cp $< $@", ) +MANIFESTS_DIR = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests" + +android_library( + name = "manifest_lib", + exports_manifest = 1, + manifest = select({ + "//conditions:default": MANIFESTS_DIR + ":AndroidManifestSneaker.xml", + ":use_chair": MANIFESTS_DIR + ":AndroidManifestChair.xml", + ":use_cup": MANIFESTS_DIR + ":AndroidManifestCup.xml", + ":use_camera": MANIFESTS_DIR + ":AndroidManifestCamera.xml", + ":use_shoe_1stage": MANIFESTS_DIR + ":AndroidManifestSneaker.xml", + ":use_chair_1stage": MANIFESTS_DIR + ":AndroidManifestChair.xml", + }), + deps = [ + "//third_party:opencv", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:com_google_guava_guava", + ], +) + +ASSETS_DIR = 
"//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets" + +genrule( + name = "mesh", + srcs = select({ + "//conditions:default": [ASSETS_DIR + "/sneaker:model.obj.uuu"], + ":use_chair": [ASSETS_DIR + "/chair:model.obj.uuu"], + ":use_cup": [ASSETS_DIR + "/cup:model.obj.uuu"], + ":use_camera": [ASSETS_DIR + "/camera:model.obj.uuu"], + ":use_shoe_1stage": [ASSETS_DIR + "/sneaker:model.obj.uuu"], + ":use_chair_1stage": [ASSETS_DIR + "/chair:model.obj.uuu"], + }), + outs = ["model.obj.uuu"], + cmd = "cp $< $@", +) + +genrule( + name = "texture", + srcs = select({ + "//conditions:default": [ASSETS_DIR + "/sneaker:texture.jpg"], + ":use_chair": [ASSETS_DIR + "/chair:texture.jpg"], + ":use_cup": [ASSETS_DIR + "/cup:texture.jpg"], + ":use_camera": [ASSETS_DIR + "/camera:texture.jpg"], + ":use_shoe_1stage": [ASSETS_DIR + "/sneaker:texture.jpg"], + ":use_chair_1stage": [ASSETS_DIR + "/chair:texture.jpg"], + }), + outs = ["texture.jpg"], + cmd = "cp $< $@", +) + android_binary( name = "objectdetection3d", srcs = glob(["*.java"]), assets = [ ":binary_graph", ":model", - "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets:box.obj.uuu", - "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets:classic_colors.png", - ] + select({ - "//conditions:default": [ - "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/sneaker:model.obj.uuu", - "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/sneaker:texture.jpg", - ], - ":use_chair_model": [ - "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/chair:model.obj.uuu", - "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/chair:texture.jpg", - ], - }), + ":mesh", + ":texture", + MODELS_DIR + ":object_detection_ssd_mobilenetv2_oidv4_fp16.tflite", + MODELS_DIR + ":object_detection_oidv4_labelmap.txt", + ASSETS_DIR + ":box.obj.uuu", + ASSETS_DIR + ":classic_colors.png", + ], assets_dir = "", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", - manifest_values = { - "applicationId": "com.google.mediapipe.apps.objectdetection3d", - "appName": "Objectron", - "mainActivity": ".MainActivity", - "cameraFacingFront": "False", - "binaryGraphName": "object_detection_3d.binarypb", - "inputVideoStreamName": "input_video", - "outputVideoStreamName": "output_video", - "flipFramesVertically": "True", - }, + manifest_values = select({ + "//conditions:default": generate_manifest_values("com.google.mediapipe.apps.objectdetection3d_shoe", "Shoe Objectron"), + ":use_chair": generate_manifest_values("com.google.mediapipe.apps.objectdetection3d_chair", "Chair Objectron"), + ":use_cup": generate_manifest_values("com.google.mediapipe.apps.objectdetection3d_cup", "Cup Objectron"), + ":use_camera": generate_manifest_values("com.google.mediapipe.apps.objectdetection3d_camera", "Camera Objectron"), + ":use_shoe_1stage": generate_manifest_values("com.google.mediapipe.apps.objectdetection3d_shoe_1stage", "Single Stage Shoe Objectron"), + ":use_chair_1stage": generate_manifest_values("com.google.mediapipe.apps.objectdetection3d_chair_1stage", "Single Stage Chair Objectron"), + }), multidex = "native", deps = [ + ":manifest_lib", ":mediapipe_jni_lib", "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", "//mediapipe/framework/formats:landmark_java_proto_lite", 
diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/MainActivity.java index 92f9f55bb..b3a6dfeea 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/MainActivity.java @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2020 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,9 @@ package com.google.mediapipe.apps.objectdetection3d; +import android.content.pm.ApplicationInfo; +import android.content.pm.PackageManager; +import android.content.pm.PackageManager.NameNotFoundException; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.os.Bundle; @@ -40,10 +43,27 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { private Bitmap objTexture = null; private Bitmap boxTexture = null; + // ApplicationInfo for retrieving metadata defined in the manifest. + private ApplicationInfo applicationInfo; + @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); + try { + applicationInfo = + getPackageManager().getApplicationInfo(getPackageName(), PackageManager.GET_META_DATA); + } catch (NameNotFoundException e) { + Log.e(TAG, "Cannot find application info: " + e); + } + // Get allowed object category. + String categoryName = applicationInfo.metaData.getString("categoryName"); + // Get maximum allowed number of objects. + int maxNumObjects = applicationInfo.metaData.getInt("maxNumObjects"); + float[] modelScale = parseFloatArrayFromString( + applicationInfo.metaData.getString("modelScale")); + float[] modelTransform = parseFloatArrayFromString( + applicationInfo.metaData.getString("modelTransformation")); prepareDemoAssets(); AndroidPacketCreator packetCreator = processor.getPacketCreator(); Map inputSidePackets = new HashMap<>(); @@ -51,6 +71,10 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { inputSidePackets.put("box_asset_name", packetCreator.createString(BOX_FILE)); inputSidePackets.put("obj_texture", packetCreator.createRgbaImageFrame(objTexture)); inputSidePackets.put("box_texture", packetCreator.createRgbaImageFrame(boxTexture)); + inputSidePackets.put("allowed_labels", packetCreator.createString(categoryName)); + inputSidePackets.put("max_num_objects", packetCreator.createInt32(maxNumObjects)); + inputSidePackets.put("model_scale", packetCreator.createFloat32Array(modelScale)); + inputSidePackets.put("model_transformation", packetCreator.createFloat32Array(modelTransform)); processor.setInputSidePackets(inputSidePackets); } @@ -97,8 +121,8 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { } catch (RuntimeException e) { Log.e( TAG, - "MediaPipeException encountered adding packets to width and height" - + " input streams."); + "MediaPipeException encountered adding packets to input_width and input_height" + + " input streams.", e); } widthPacket.release(); heightPacket.release(); @@ -134,4 +158,13 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { throw new RuntimeException(e); } } + + private static float[] parseFloatArrayFromString(String string) { + String[] elements = 
string.split(",", -1); + float[] array = new float[elements.length]; + for (int i = 0; i < elements.length; ++i) { + array[i] = Float.parseFloat(elements[i]); + } + return array; + } } diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/BUILD index 46d164040..a8bb9124c 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2020 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/examples/python/__init__.py b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/BUILD similarity index 83% rename from mediapipe/examples/python/__init__.py rename to mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/BUILD index 5d9133833..a8bb9124c 100644 --- a/mediapipe/examples/python/__init__.py +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/BUILD @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""MediaPipe Python Examples.""" -from mediapipe.examples.python.upper_body_pose_tracker import UpperBodyPoseTracker +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +exports_files( + srcs = glob(["**"]), +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/model.obj.uuu b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/model.obj.uuu new file mode 100644 index 000000000..0280d5dd0 Binary files /dev/null and b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/model.obj.uuu differ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/texture.jpg b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/texture.jpg new file mode 100644 index 000000000..4a19534dd Binary files /dev/null and b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/camera/texture.jpg differ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/chair/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/chair/BUILD index 46d164040..a8bb9124c 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/chair/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/chair/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2020 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/BUILD new file mode 100644 index 000000000..a8bb9124c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/BUILD @@ -0,0 +1,21 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +exports_files( + srcs = glob(["**"]), +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/model.obj.uuu b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/model.obj.uuu new file mode 100644 index 000000000..167e134eb Binary files /dev/null and b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/model.obj.uuu differ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/texture.jpg b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/texture.jpg new file mode 100644 index 000000000..f3aea3568 Binary files /dev/null and b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/cup/texture.jpg differ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/sneaker/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/sneaker/BUILD index 46d164040..a8bb9124c 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/sneaker/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/sneaker/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2020 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/build_defs.bzl b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/build_defs.bzl new file mode 100644 index 000000000..9c30dd58c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/build_defs.bzl @@ -0,0 +1,15 @@ +"""Build defs for Objectron.""" + +def generate_manifest_values(application_id, app_name): + manifest_values = { + "applicationId": application_id, + "appName": app_name, + "mainActivity": "com.google.mediapipe.apps.objectdetection3d.MainActivity", + "cameraFacingFront": "False", + "binaryGraphName": "object_detection_3d.binarypb", + "inputVideoStreamName": "input_video", + "outputVideoStreamName": "output_video", + "flipFramesVertically": "True", + "converterNumBuffers": "2", + } + return manifest_values diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCamera.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCamera.xml new file mode 100644 index 000000000..10f8492ef --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCamera.xml @@ -0,0 +1,18 @@ + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestChair.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestChair.xml new file mode 100644 index 000000000..cd6502fa2 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestChair.xml @@ -0,0 +1,18 @@ + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCup.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCup.xml new file mode 100644 index 000000000..a06694563 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCup.xml @@ -0,0 +1,18 @@ + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestSneaker.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestSneaker.xml new file mode 100644 index 000000000..a1fd1143b --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestSneaker.xml @@ -0,0 +1,18 @@ + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/BUILD new file mode 100644 index 000000000..a8bb9124c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/BUILD @@ -0,0 +1,21 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
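The four AndroidManifest*.xml files added above carry the per-category settings that MainActivity reads from ApplicationInfo.metaData: categoryName, maxNumObjects, modelScale, and modelTransformation. A schematic, purely illustrative sketch of the shape such a manifest takes (key names from MainActivity; all values here are invented for illustration):

  <manifest xmlns:android="http://schemas.android.com/apk/res/android"
      package="com.google.mediapipe.apps.objectdetection3d">
    <application>
      <meta-data android:name="categoryName" android:value="Chair" />
      <meta-data android:name="maxNumObjects" android:value="2" />
      <meta-data android:name="modelScale" android:value="0.1,0.1,0.1" />
      <meta-data android:name="modelTransformation" android:value="1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1" />
    </application>
  </manifest>
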
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +exports_files( + srcs = glob(["**"]), +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD index 080fe4ced..9bb054936 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD @@ -51,6 +51,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD index 56e70c2b6..81f2ed3e6 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD @@ -51,6 +51,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objecttrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objecttrackinggpu/BUILD index 220d48067..50ea70f89 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objecttrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objecttrackinggpu/BUILD @@ -51,6 +51,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD similarity index 71% rename from mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/BUILD rename to mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD index 7d4d7418c..5eff6a833 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2020 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ cc_binary( linkshared = 1, linkstatic = 1, deps = [ - "//mediapipe/graphs/hand_tracking:multi_hand_mobile_calculators", + "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", ], ) @@ -33,26 +33,25 @@ cc_library( ) android_binary( - name = "multihandtrackinggpu", + name = "posetrackinggpu", srcs = glob(["*.java"]), assets = [ - "//mediapipe/graphs/hand_tracking:multi_hand_tracking_mobile_gpu.binarypb", - "//mediapipe/models:handedness.txt", - "//mediapipe/models:hand_landmark.tflite", - "//mediapipe/models:palm_detection.tflite", - "//mediapipe/models:palm_detection_labelmap.txt", + "//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb", + "//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", + "//mediapipe/modules/pose_detection:pose_detection.tflite", ], assets_dir = "", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", manifest_values = { - "applicationId": "com.google.mediapipe.apps.multihandtrackinggpu", - "appName": "Multi-hand Tracking", + "applicationId": "com.google.mediapipe.apps.posetrackinggpu", + "appName": "Pose Tracking", "mainActivity": ".MainActivity", - "cameraFacingFront": "True", - "binaryGraphName": "multi_hand_tracking_mobile_gpu.binarypb", + "cameraFacingFront": "False", + "binaryGraphName": "pose_tracking_gpu.binarypb", "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ @@ -60,5 +59,6 @@ android_binary( "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", "//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "@com_google_protobuf//:protobuf_javalite", ], ) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/MainActivity.java new file mode 100644 index 000000000..730aa6e1f --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/MainActivity.java @@ -0,0 +1,75 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.apps.posetrackinggpu; + +import android.os.Bundle; +import android.util.Log; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.framework.PacketGetter; +import com.google.protobuf.InvalidProtocolBufferException; + +/** Main activity of MediaPipe pose tracking app. 
*/ +public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { + private static final String TAG = "MainActivity"; + + private static final String OUTPUT_LANDMARKS_STREAM_NAME = "pose_landmarks"; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + // To show verbose logging, run: + // adb shell setprop log.tag.MainActivity VERBOSE + if (Log.isLoggable(TAG, Log.VERBOSE)) { + processor.addPacketCallback( + OUTPUT_LANDMARKS_STREAM_NAME, + (packet) -> { + Log.v(TAG, "Received pose landmarks packet."); + try { + NormalizedLandmarkList poseLandmarks = + PacketGetter.getProto(packet, NormalizedLandmarkList.class); + Log.v( + TAG, + "[TS:" + + packet.getTimestamp() + + "] " + + getPoseLandmarksDebugString(poseLandmarks)); + } catch (InvalidProtocolBufferException exception) { + Log.e(TAG, "Failed to get proto.", exception); + } + }); + } + } + + private static String getPoseLandmarksDebugString(NormalizedLandmarkList poseLandmarks) { + String poseLandmarkStr = "Pose landmarks: " + poseLandmarks.getLandmarkCount() + "\n"; + int landmarkIndex = 0; + for (NormalizedLandmark landmark : poseLandmarks.getLandmarkList()) { + poseLandmarkStr += + "\tLandmark [" + + landmarkIndex + + "]: (" + + landmark.getX() + + ", " + + landmark.getY() + + ", " + + landmark.getZ() + + ")\n"; + ++landmarkIndex; + } + return poseLandmarkStr; + } +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/templatematchingcpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/templatematchingcpu/BUILD index 0ceeeee1b..ed3a63a70 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/templatematchingcpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/templatematchingcpu/BUILD @@ -52,6 +52,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD index fe2da982c..50f9d643a 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD @@ -51,6 +51,7 @@ android_binary( "inputVideoStreamName": "input_video", "outputVideoStreamName": "output_video", "flipFramesVertically": "True", + "converterNumBuffers": "2", }, multidex = "native", deps = [ @@ -58,5 +59,6 @@ android_binary( "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", "//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "@com_google_protobuf//:protobuf_javalite", ], ) diff --git a/mediapipe/examples/coral/BUILD b/mediapipe/examples/coral/BUILD index 03d4027e7..ec747573b 100644 --- a/mediapipe/examples/coral/BUILD +++ b/mediapipe/examples/coral/BUILD @@ -51,6 +51,6 @@ cc_binary( name = "face_detection_tpu", deps = [ "//mediapipe/examples/coral:demo_run_graph_main", - "//mediapipe/graphs/face_detection:desktop_tflite_calculators", + "//mediapipe/graphs/face_detection:desktop_live_calculators", ], ) diff --git a/mediapipe/examples/coral/Dockerfile b/mediapipe/examples/coral/Dockerfile index de5d3a909..bc655c580 100644 --- 
a/mediapipe/examples/coral/Dockerfile +++ b/mediapipe/examples/coral/Dockerfile @@ -62,7 +62,7 @@ COPY . /mediapipe/ # Install bazel # Please match the current MediaPipe Bazel requirements according to docs. -ARG BAZEL_VERSION=2.0.0 +ARG BAZEL_VERSION=3.4.1 RUN mkdir /bazel && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ wget --no-check-certificate -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ diff --git a/mediapipe/examples/coral/demo_run_graph_main.cc b/mediapipe/examples/coral/demo_run_graph_main.cc index 0755ecb60..698955472 100644 --- a/mediapipe/examples/coral/demo_run_graph_main.cc +++ b/mediapipe/examples/coral/demo_run_graph_main.cc @@ -40,10 +40,11 @@ DEFINE_string(output_video_path, "", "Full path of where to save result (.mp4 only). " "If not provided, show result in a window."); -::mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = @@ -56,22 +57,22 @@ DEFINE_string(output_video_path, "", LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; - const bool load_video = !FLAGS_input_video_path.empty(); + const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { - capture.open(FLAGS_input_video_path); + capture.open(absl::GetFlag(FLAGS_input_video_path)); } else { capture.open(0); } RET_CHECK(capture.isOpened()); cv::VideoWriter writer; - const bool save_video = !FLAGS_output_video_path.empty(); + const bool save_video = !absl::GetFlag(FLAGS_output_video_path).empty(); if (save_video) { LOG(INFO) << "Prepare video writer."; cv::Mat test_frame; capture.read(test_frame); // Consume first frame. capture.set(cv::CAP_PROP_POS_AVI_RATIO, 0); // Rewind to beginning. 
- writer.open(FLAGS_output_video_path, + writer.open(absl::GetFlag(FLAGS_output_video_path), mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 capture.get(cv::CAP_PROP_FPS), test_frame.size()); RET_CHECK(writer.isOpened()); @@ -143,7 +144,7 @@ DEFINE_string(output_video_path, "", int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/autoflip/calculators/BUILD b/mediapipe/examples/desktop/autoflip/calculators/BUILD index 688084062..99b9d6fff 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/BUILD +++ b/mediapipe/examples/desktop/autoflip/calculators/BUILD @@ -368,6 +368,7 @@ cc_test( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgcodecs", diff --git a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc index 0dc672208..caaa368a7 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc @@ -68,9 +68,9 @@ class BorderDetectionCalculator : public CalculatorBase { BorderDetectionCalculator& operator=(const BorderDetectionCalculator&) = delete; - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: // Given a color and image direction, check to see if a border of that color @@ -83,7 +83,7 @@ class BorderDetectionCalculator : public CalculatorBase { double ColorCount(const Color& mask_color, const cv::Mat& image) const; // Set member vars (image size) and confirm no changes frame-to-frame. - mediapipe::Status SetAndCheckInputs(const cv::Mat& frame); + absl::Status SetAndCheckInputs(const cv::Mat& frame); // Find the dominant color for a input image. 
double FindDominantColor(const cv::Mat& image, Color* dominant_color); @@ -97,15 +97,14 @@ class BorderDetectionCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(BorderDetectionCalculator); -::mediapipe::Status BorderDetectionCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status BorderDetectionCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); RET_CHECK_LT(options_.vertical_search_distance(), 0.5) << "Search distance must be less than half the full image."; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BorderDetectionCalculator::SetAndCheckInputs( +absl::Status BorderDetectionCalculator::SetAndCheckInputs( const cv::Mat& frame) { if (frame_width_ < 0) { frame_width_ = frame.cols; @@ -118,14 +117,14 @@ mediapipe::Status BorderDetectionCalculator::SetAndCheckInputs( RET_CHECK_EQ(frame.rows, frame_height_) << "Input frame dimensions must remain constant throughout the video."; RET_CHECK_EQ(frame.channels(), 3) << "Input video type must be 3-channel"; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BorderDetectionCalculator::Process( +absl::Status BorderDetectionCalculator::Process( mediapipe::CalculatorContext* cc) { if (!cc->Inputs().HasTag(kVideoInputTag) || cc->Inputs().Tag(kVideoInputTag).Value().IsEmpty()) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Input tag VIDEO not set or empty at timestamp: " << cc->InputTimestamp().Value(); } @@ -173,7 +172,7 @@ mediapipe::Status BorderDetectionCalculator::Process( .Tag(kDetectedBorders) .AddPacket(Adopt(features.release()).At(cc->InputTimestamp())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Find the dominant color within an image. @@ -291,11 +290,11 @@ void BorderDetectionCalculator::DetectBorder( } } -::mediapipe::Status BorderDetectionCalculator::GetContract( +absl::Status BorderDetectionCalculator::GetContract( mediapipe::CalculatorContract* cc) { cc->Inputs().Tag(kVideoInputTag).Set(); cc->Outputs().Tag(kDetectedBorders).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc index cba751057..c2ee6b0ff 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc @@ -35,7 +35,7 @@ constexpr char kDetectedBorders[] = "BORDERS"; constexpr char kCropRect[] = "CROP_RECT"; // Field-of-view (degrees) of the camera's x-axis (width). // TODO: Parameterize FOV based on camera specs. 
-constexpr float kWidthFieldOfView = 60; +constexpr float kFieldOfView = 60; namespace mediapipe { namespace autoflip { @@ -55,23 +55,25 @@ class ContentZoomingCalculator : public CalculatorBase { ContentZoomingCalculator(const ContentZoomingCalculator&) = delete; ContentZoomingCalculator& operator=(const ContentZoomingCalculator&) = delete; - static ::mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - ::mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - ::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: // Converts bounds to tilt offset, pan offset and height. - ::mediapipe::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin, - float ymax, int* tilt_offset, - int* pan_offset, int* height); + absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin, + float ymax, int* tilt_offset, + int* pan_offset, int* height); + // Sets max_frame_value_ and target_aspect_ + absl::Status UpdateAspectAndMax(); ContentZoomingCalculatorOptions options_; // Detection frame width/height. int frame_height_; int frame_width_; // Path solver used to smooth top/bottom border crop values. - std::unique_ptr path_solver_height_; - std::unique_ptr path_solver_width_; - std::unique_ptr path_solver_offset_; + std::unique_ptr path_solver_zoom_; + std::unique_ptr path_solver_pan_; + std::unique_ptr path_solver_tilt_; // Are parameters initialized. bool initialized_; // Stores the time of the last "only_required" input. @@ -89,7 +91,7 @@ class ContentZoomingCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(ContentZoomingCalculator); -::mediapipe::Status ContentZoomingCalculator::GetContract( +absl::Status ContentZoomingCalculator::GetContract( mediapipe::CalculatorContract* cc) { RET_CHECK( !(cc->Inputs().HasTag(kVideoFrame) && cc->Inputs().HasTag(kVideoSize))) @@ -99,7 +101,7 @@ REGISTER_CALCULATOR(ContentZoomingCalculator); } else if (cc->Inputs().HasTag(kVideoSize)) { cc->Inputs().Tag(kVideoSize).Set>(); } else { - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Input VIDEO or VIDEO_SIZE must be provided."; } if (cc->Inputs().HasTag(kSalientRegions)) { @@ -114,27 +116,26 @@ REGISTER_CALCULATOR(ContentZoomingCalculator); if (cc->Outputs().HasTag(kCropRect)) { cc->Outputs().Tag(kCropRect).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ContentZoomingCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status ContentZoomingCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); if (options_.has_kinematic_options()) { - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Deprecated kinematic_options was set, please set " "kinematic_options_zoom and kinematic_options_tilt."; } if (options_.has_min_motion_to_reframe()) { - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Deprecated min_motion_to_reframe was set, please set " "in kinematic_options_zoom and kinematic_options_tilt " "directly."; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom( +absl::Status 
ContentZoomingCalculator::ConvertToPanTiltZoom( float xmin, float xmax, float ymin, float ymax, int* tilt_offset, int* pan_offset, int* height) { // Find center of the y-axis offset (for tilt control). @@ -142,10 +143,11 @@ REGISTER_CALCULATOR(ContentZoomingCalculator); // Find center of the x-axis offset (for pan control). float x_center = xmin + (xmax - xmin) / 2; // Find size and apply scale factor to y-axis. - float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin); + float fit_size_raw = + fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin); // Apply max frame for cases where the target size is different than input // frame size. - fit_size = fmin(max_frame_value_, fit_size); + float fit_size = fmin(max_frame_value_, fit_size_raw); // Prevent box from extending beyond the image. if (y_center - fit_size / 2 < 0) { y_center = fit_size / 2; @@ -160,8 +162,8 @@ REGISTER_CALCULATOR(ContentZoomingCalculator); // Scale to pixel coordinates. *tilt_offset = frame_height_ * y_center; *pan_offset = frame_width_ * x_center; - *height = frame_height_ * fit_size; - return ::mediapipe::OkStatus(); + *height = frame_height_ * fit_size_raw; + return absl::OkStatus(); } namespace { @@ -185,12 +187,12 @@ mediapipe::autoflip::RectF ShiftDetection( relative_bounding_box.width() * x_offset_percent); return shifted_bb; } -::mediapipe::Status UpdateRanges(const SalientRegion& region, - const float shift_vertical, - const float shift_horizontal, float* xmin, - float* xmax, float* ymin, float* ymax) { +absl::Status UpdateRanges(const SalientRegion& region, + const float shift_vertical, + const float shift_horizontal, float* xmin, + float* xmax, float* ymin, float* ymax) { if (!region.has_location_normalized()) { - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "SalientRegion did not have location normalized set."; } auto location = ShiftDetection(region.location_normalized(), shift_vertical, @@ -200,12 +202,12 @@ mediapipe::autoflip::RectF ShiftDetection( *ymin = fmin(*ymin, location.y()); *ymax = fmax(*ymax, location.y() + location.height()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status UpdateRanges(const mediapipe::Detection& detection, - const float shift_vertical, - const float shift_horizontal, float* xmin, - float* xmax, float* ymin, float* ymax) { +absl::Status UpdateRanges(const mediapipe::Detection& detection, + const float shift_vertical, + const float shift_horizontal, float* xmin, + float* xmax, float* ymin, float* ymax) { RET_CHECK(detection.location_data().format() == mediapipe::LocationData::RELATIVE_BOUNDING_BOX) << "Face detection input is lacking required relative_bounding_box()"; @@ -217,7 +219,7 @@ mediapipe::autoflip::RectF ShiftDetection( *ymin = fmin(*ymin, location.ymin()); *ymax = fmax(*ymax, location.ymin() + location.height()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void MakeStaticFeatures(const int top_border, const int bottom_border, const int frame_width, const int frame_height, @@ -236,54 +238,97 @@ void MakeStaticFeatures(const int top_border, const int bottom_border, border_bottom->mutable_border_position()->set_width(frame_width); border_bottom->mutable_border_position()->set_height(bottom_border); } -} // namespace - -::mediapipe::Status ContentZoomingCalculator::Process( - mediapipe::CalculatorContext* cc) { +absl::Status GetVideoResolution(mediapipe::CalculatorContext* cc, + int* frame_width, int* frame_height) { if 
(cc->Inputs().HasTag(kVideoFrame)) {
-    frame_width_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Width();
-    frame_height_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Height();
+    *frame_width = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Width();
+    *frame_height = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Height();
   } else if (cc->Inputs().HasTag(kVideoSize)) {
-    frame_width_ =
+    *frame_width =
         cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().first;
-    frame_height_ =
+    *frame_height =
         cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().second;
   } else {
-    return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
+    return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
            << "Input VIDEO or VIDEO_SIZE must be provided.";
   }
+  return absl::OkStatus();
+}
+}  // namespace
+absl::Status ContentZoomingCalculator::UpdateAspectAndMax() {
+  max_frame_value_ = 1.0;
+  target_aspect_ = frame_width_ / static_cast<float>(frame_height_);
+  // If target size is set and wider than input aspect, make sure to always
+  // crop the min required amount.
+  if (options_.has_target_size()) {
+    RET_CHECK_GT(options_.target_size().width(), 0)
+        << "Provided target width not valid.";
+    RET_CHECK_GT(options_.target_size().height(), 0)
+        << "Provided target height not valid.";
+    float input_aspect = frame_width_ / static_cast<float>(frame_height_);
+    target_aspect_ = options_.target_size().width() /
+                     static_cast<float>(options_.target_size().height());
+    max_frame_value_ =
+        std::min(input_aspect / target_aspect_, target_aspect_ / input_aspect);
+  }
+  return absl::OkStatus();
+}
+
+absl::Status ContentZoomingCalculator::Process(
+    mediapipe::CalculatorContext* cc) {
+  // For async subgraph support, return on empty video size packets.
+  if (cc->Inputs().HasTag(kVideoSize) &&
+      cc->Inputs().Tag(kVideoSize).IsEmpty()) {
+    return absl::OkStatus();
+  }
+  int frame_width, frame_height;
+  MP_RETURN_IF_ERROR(GetVideoResolution(cc, &frame_width, &frame_height));
+
+  // Init on first call.
   if (!initialized_) {
-    path_solver_height_ = std::make_unique<KinematicPathSolver>(
-        options_.kinematic_options_zoom(), 0, frame_height_,
-        static_cast<float>(frame_width_) / kWidthFieldOfView);
-    path_solver_width_ = std::make_unique<KinematicPathSolver>(
+    frame_width_ = frame_width;
+    frame_height_ = frame_height;
+    path_solver_pan_ = std::make_unique<KinematicPathSolver>(
         options_.kinematic_options_pan(), 0, frame_width_,
-        static_cast<float>(frame_width_) / kWidthFieldOfView);
-    path_solver_offset_ = std::make_unique<KinematicPathSolver>(
+        static_cast<float>(frame_width_) / kFieldOfView);
+    path_solver_tilt_ = std::make_unique<KinematicPathSolver>(
         options_.kinematic_options_tilt(), 0, frame_height_,
-        static_cast<float>(frame_width_) / kWidthFieldOfView);
-    max_frame_value_ = 1.0;
-    target_aspect_ = frame_width_ / static_cast<float>(frame_height_);
-    // If target size is set and wider than input aspect, make sure to always
-    // crop the min required amount.
-    if (options_.has_target_size()) {
-      RET_CHECK_GT(options_.target_size().width(), 0)
-          << "Provided target width not valid.";
-      RET_CHECK_GT(options_.target_size().height(), 0)
-          << "Provided target height not valid.";
-      float input_aspect = frame_width_ / static_cast<float>(frame_height_);
-      target_aspect_ = options_.target_size().width() /
-                       static_cast<float>(options_.target_size().height());
-      max_frame_value_ = std::min(input_aspect / target_aspect_,
-                                  target_aspect_ / input_aspect);
-    }
+        static_cast<float>(frame_height_) / kFieldOfView);
+    MP_RETURN_IF_ERROR(UpdateAspectAndMax());
+    int min_zoom_size = frame_height_ * (options_.max_zoom_value_deg() /
+                                         static_cast<float>(kFieldOfView));
+    path_solver_zoom_ = std::make_unique<KinematicPathSolver>(
+        options_.kinematic_options_zoom(), min_zoom_size,
+        max_frame_value_ * frame_height_,
+        static_cast<float>(frame_height_) / kFieldOfView);
     last_measured_height_ = max_frame_value_ * frame_height_;
     last_measured_x_offset_ = target_aspect_ * frame_width_;
     last_measured_y_offset_ = frame_width_ / 2;
     initialized_ = true;
   }
+  // Update state for change in input resolution.
+  if (frame_width_ != frame_width || frame_height_ != frame_height) {
+    double width_scale = frame_width / static_cast<double>(frame_width_);
+    double height_scale = frame_height / static_cast<double>(frame_height_);
+    last_measured_height_ = last_measured_height_ * height_scale;
+    last_measured_y_offset_ = last_measured_y_offset_ * height_scale;
+    last_measured_x_offset_ = last_measured_x_offset_ * width_scale;
+    frame_width_ = frame_width;
+    frame_height_ = frame_height;
+    MP_RETURN_IF_ERROR(UpdateAspectAndMax());
+    MP_RETURN_IF_ERROR(path_solver_pan_->UpdateMinMaxLocation(0, frame_width_));
+    MP_RETURN_IF_ERROR(
+        path_solver_tilt_->UpdateMinMaxLocation(0, frame_height_));
+    int min_zoom_size = frame_height_ * (options_.max_zoom_value_deg() /
+                                         static_cast<float>(kFieldOfView));
+    MP_RETURN_IF_ERROR(path_solver_zoom_->UpdateMinMaxLocation(
+        min_zoom_size, max_frame_value_ * frame_height_));
+    MP_RETURN_IF_ERROR(path_solver_zoom_->UpdatePixelsPerDegree(
+        static_cast<float>(frame_height_) / kFieldOfView));
+  }
+
   bool only_required_found = false;

   // Compute the box that contains all "is_required" detections.
@@ -302,6 +347,16 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
   }

   if (cc->Inputs().HasTag(kDetections)) {
+    if (cc->Inputs().Tag(kDetections).IsEmpty()) {
+      auto default_rect = absl::make_unique<mediapipe::Rect>();
+      default_rect->set_x_center(frame_width_ / 2);
+      default_rect->set_y_center(frame_height_ / 2);
+      default_rect->set_width(frame_width_);
+      default_rect->set_height(frame_height_);
+      cc->Outputs().Tag(kCropRect).Add(default_rect.release(),
+                                       Timestamp(cc->InputTimestamp()));
+      return absl::OkStatus();
+    }
     auto raw_detections =
         cc->Inputs().Tag(kDetections).Get<std::vector<mediapipe::Detection>>();
     for (const auto& detection : raw_detections) {
@@ -339,19 +394,46 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
     offset_y = last_measured_y_offset_;
   }

-  // Compute smoothed camera paths.
-  MP_RETURN_IF_ERROR(path_solver_height_->AddObservation(
+  // Check if the camera is changing in pan, tilt or zoom. If the camera is in
+  // motion disable temporal filtering.
+ bool pan_state, tilt_state, zoom_state; + MP_RETURN_IF_ERROR(path_solver_pan_->PredictMotionState( + offset_x, cc->InputTimestamp().Microseconds(), &pan_state)); + MP_RETURN_IF_ERROR(path_solver_tilt_->PredictMotionState( + offset_y, cc->InputTimestamp().Microseconds(), &tilt_state)); + MP_RETURN_IF_ERROR(path_solver_zoom_->PredictMotionState( + height, cc->InputTimestamp().Microseconds(), &zoom_state)); + if (pan_state || tilt_state || zoom_state) { + path_solver_pan_->ClearHistory(); + path_solver_tilt_->ClearHistory(); + path_solver_zoom_->ClearHistory(); + } + + // Compute smoothed zoom camera path. + MP_RETURN_IF_ERROR(path_solver_zoom_->AddObservation( height, cc->InputTimestamp().Microseconds())); - MP_RETURN_IF_ERROR(path_solver_width_->AddObservation( - offset_x, cc->InputTimestamp().Microseconds())); - MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation( - offset_y, cc->InputTimestamp().Microseconds())); int path_height; - MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_height)); + MP_RETURN_IF_ERROR(path_solver_zoom_->GetState(&path_height)); + int path_width = path_height * target_aspect_; + + // Update pixel-per-degree value for pan/tilt. + int target_height; + MP_RETURN_IF_ERROR(path_solver_zoom_->GetTargetPosition(&target_height)); + int target_width = target_height * target_aspect_; + MP_RETURN_IF_ERROR(path_solver_pan_->UpdatePixelsPerDegree( + static_cast(target_width) / kFieldOfView)); + MP_RETURN_IF_ERROR(path_solver_tilt_->UpdatePixelsPerDegree( + static_cast(target_height) / kFieldOfView)); + + // Compute smoothed pan/tilt paths. + MP_RETURN_IF_ERROR(path_solver_pan_->AddObservation( + offset_x, cc->InputTimestamp().Microseconds())); + MP_RETURN_IF_ERROR(path_solver_tilt_->AddObservation( + offset_y, cc->InputTimestamp().Microseconds())); int path_offset_x; - MP_RETURN_IF_ERROR(path_solver_width_->GetState(&path_offset_x)); + MP_RETURN_IF_ERROR(path_solver_pan_->GetState(&path_offset_x)); int path_offset_y; - MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y)); + MP_RETURN_IF_ERROR(path_solver_tilt_->GetState(&path_offset_y)); // Prevent box from extending beyond the image after camera smoothing. 
if (path_offset_y - ceil(path_height / 2.0) < 0) { @@ -359,7 +441,7 @@ void MakeStaticFeatures(const int top_border, const int bottom_border, } else if (path_offset_y + ceil(path_height / 2.0) > frame_height_) { path_offset_y = frame_height_ - ceil(path_height / 2.0); } - int path_width = path_height * target_aspect_; + if (path_offset_x - ceil(path_width / 2.0) < 0) { path_offset_x = ceil(path_width / 2.0); } else if (path_offset_x + ceil(path_width / 2.0) > frame_width_) { @@ -392,7 +474,7 @@ void MakeStaticFeatures(const int top_border, const int bottom_border, Timestamp(cc->InputTimestamp())); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto index 2634a4afe..c0d4dd78b 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto @@ -19,7 +19,7 @@ package mediapipe.autoflip; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/framework/calculator.proto"; -// NextTag: 13 +// NextTag: 14 message ContentZoomingCalculatorOptions { extend mediapipe.CalculatorOptions { optional ContentZoomingCalculatorOptions ext = 313091992; @@ -52,6 +52,9 @@ message ContentZoomingCalculatorOptions { optional float detection_shift_vertical = 11 [default = 0.0]; optional float detection_shift_horizontal = 12 [default = 0.0]; + // Defines the smallest value in degrees the camera is permitted to zoom. + optional float max_zoom_value_deg = 13 [default = 35]; + // Deprecated parameters optional KinematicOptions kinematic_options = 2 [deprecated = true]; optional int64 min_motion_to_reframe = 4 [deprecated = true]; diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc index 818e6b4a1..0db252fec 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc @@ -42,6 +42,20 @@ const char kConfigA[] = R"( input_stream: "VIDEO:camera_frames" input_stream: "SALIENT_REGIONS:detection_set" output_stream: "BORDERS:borders" + options: { + [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } + } + } )"; const char kConfigB[] = R"( @@ -55,6 +69,16 @@ const char kConfigB[] = R"( width: 1000 height: 500 } + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } } } )"; @@ -64,6 +88,20 @@ const char kConfigC[] = R"( input_stream: "VIDEO_SIZE:size" input_stream: "SALIENT_REGIONS:detection_set" output_stream: "BORDERS:borders" + options: { + [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } + } + } )"; const char kConfigD[] = R"( @@ -71,6 +109,20 @@ 
const char kConfigD[] = R"( input_stream: "VIDEO_SIZE:size" input_stream: "DETECTIONS:detections" output_stream: "CROP_RECT:rect" + options: { + [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } + } + } )"; void CheckBorder(const StaticFeatures& static_features, int width, int height, @@ -91,8 +143,9 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height, EXPECT_EQ(Border::BOTTOM, part.relative_position()); } -void AddDetection(const cv::Rect_& position, const int64 time, - CalculatorRunner* runner) { +void AddDetectionFrameSize(const cv::Rect_& position, const int64 time, + const int width, const int height, + CalculatorRunner* runner) { auto detections = std::make_unique>(); mediapipe::Detection detection; detection.mutable_location_data()->set_format( @@ -111,12 +164,17 @@ void AddDetection(const cv::Rect_& position, const int64 time, ->Tag("DETECTIONS") .packets.push_back(Adopt(detections.release()).At(Timestamp(time))); - auto input_size = ::absl::make_unique>(1000, 1000); + auto input_size = ::absl::make_unique>(width, height); runner->MutableInputs() ->Tag("VIDEO_SIZE") .packets.push_back(Adopt(input_size.release()).At(Timestamp(time))); } +void AddDetection(const cv::Rect_& position, const int64 time, + CalculatorRunner* runner) { + AddDetectionFrameSize(position, time, 1000, 1000, runner); +} + void CheckCropRect(const int x_center, const int y_center, const int width, const int height, const int frame_number, const std::vector& output_packets) { @@ -174,15 +232,15 @@ TEST(ContentZoomingCalculatorTest, PanConfig) { ContentZoomingCalculatorOptions::ext); options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0); options->mutable_kinematic_options_pan()->set_update_rate_seconds(2); - options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0); - options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0); + options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(50.0); + options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(50.0); auto runner = ::absl::make_unique(config); AddDetection(cv::Rect_(.4, .5, .1, .1), 0, runner.get()); AddDetection(cv::Rect_(.45, .55, .15, .15), 1000000, runner.get()); MP_ASSERT_OK(runner->Run()); CheckCropRect(450, 550, 111, 111, 0, runner->Outputs().Tag("CROP_RECT").packets); - CheckCropRect(488, 550, 111, 111, 1, + CheckCropRect(483, 550, 111, 111, 1, runner->Outputs().Tag("CROP_RECT").packets); } @@ -190,17 +248,17 @@ TEST(ContentZoomingCalculatorTest, TiltConfig) { auto config = ParseTextProtoOrDie(kConfigD); auto* options = config.mutable_options()->MutableExtension( ContentZoomingCalculatorOptions::ext); - options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0); + options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(50.0); options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(0.0); options->mutable_kinematic_options_tilt()->set_update_rate_seconds(2); - options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0); + options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(50.0); auto runner = ::absl::make_unique(config); AddDetection(cv::Rect_(.4, .5, .1, .1), 0, runner.get()); AddDetection(cv::Rect_(.45, .55, .15, .15), 1000000, runner.get()); MP_ASSERT_OK(runner->Run()); 
CheckCropRect(450, 550, 111, 111, 0, runner->Outputs().Tag("CROP_RECT").packets); - CheckCropRect(450, 588, 111, 111, 1, + CheckCropRect(450, 583, 111, 111, 1, runner->Outputs().Tag("CROP_RECT").packets); } @@ -208,8 +266,8 @@ TEST(ContentZoomingCalculatorTest, ZoomConfig) { auto config = ParseTextProtoOrDie(kConfigD); auto* options = config.mutable_options()->MutableExtension( ContentZoomingCalculatorOptions::ext); - options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0); - options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0); + options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(50.0); + options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(50.0); options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0); options->mutable_kinematic_options_zoom()->set_update_rate_seconds(2); auto runner = ::absl::make_unique(config); @@ -418,6 +476,71 @@ TEST(ContentZoomingCalculatorTest, ShiftOutsideBounds) { runner->Outputs().Tag("CROP_RECT").packets); } +TEST(ContentZoomingCalculatorTest, EmptySize) { + auto config = ParseTextProtoOrDie(kConfigD); + auto runner = ::absl::make_unique(config); + MP_ASSERT_OK(runner->Run()); + ASSERT_EQ(runner->Outputs().Tag("CROP_RECT").packets.size(), 0); +} + +TEST(ContentZoomingCalculatorTest, EmptyDetections) { + auto config = ParseTextProtoOrDie(kConfigD); + auto runner = ::absl::make_unique(config); + auto input_size = ::absl::make_unique>(1000, 1000); + runner->MutableInputs() + ->Tag("VIDEO_SIZE") + .packets.push_back(Adopt(input_size.release()).At(Timestamp(0))); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 1000, 1000, 0, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, ResolutionChangeStationary) { + auto config = ParseTextProtoOrDie(kConfigD); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1, 500, 500, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 222, 222, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500 * 0.5, 500 * 0.5, 222 * 0.5, 222 * 0.5, 1, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) { + auto config = ParseTextProtoOrDie(kConfigD); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.1, .1, .8, .8), 0, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1000000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 2000000, 500, 500, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 888, 888, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 588, 588, 1, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500 * 0.5, 500 * 0.5, 288 * 0.5, 288 * 0.5, 2, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, MaxZoomValue) { + auto config = ParseTextProtoOrDie(kConfigD); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->set_max_zoom_value_deg(55); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get()); + MP_ASSERT_OK(runner->Run()); + // 55/60 * 1000 = 916 + CheckCropRect(500, 500, 916, 916, 0, + runner->Outputs().Tag("CROP_RECT").packets); +} + } // namespace } // 
namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto new file mode 100644 index 000000000..b00755d36 --- /dev/null +++ b/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto @@ -0,0 +1,50 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe.autoflip; + +import "mediapipe/framework/calculator.proto"; + +message FaceBoxAdjusterCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional FaceBoxAdjusterCalculatorOptions ext = 347462240; + } + + // When faces are detected in a given frame, we check these number of frames + // in the past. We include only those faces in auto framing that have been + // seen in this past history. This helps reduce False Positives and also + // handles some of the edge cases. Setting the value to 0 disables the + // feature. + optional int32 num_frame_history = 1 [default = 0]; + + // IOU threshold for matching detected faces with the faces in the frame + // history buffer. + optional float iou_threshold = 2 [default = 0.2]; + + // If true, the face boxes are adjusted based on their face pose. This is done + // to correct for extreme poses that can cause the detected face boxes to be + // either too big or too small. + optional bool adjust_for_pose = 3 [default = true]; + + // There are DEPRECATED fields. Do not use. + optional float box_area_change_per_up_tilt_degree = 4 [deprecated = true]; + optional float box_area_change_per_down_tilt_degree = 5 [deprecated = true]; + + // The ratios of the face-pose corrected IPD to the face bounding box's width + // and height respectively. 
+ optional float ipd_face_box_width_ratio = 6 [default = 0.5566]; + optional float ipd_face_box_height_ratio = 7 [default = 0.3131]; +} diff --git a/mediapipe/examples/desktop/autoflip/calculators/face_to_region_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/face_to_region_calculator.cc index 86f03cf7a..3c9aeb4c8 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/face_to_region_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/face_to_region_calculator.cc @@ -55,9 +55,9 @@ class FaceToRegionCalculator : public CalculatorBase { FaceToRegionCalculator(const FaceToRegionCalculator&) = delete; FaceToRegionCalculator& operator=(const FaceToRegionCalculator&) = delete; - static ::mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - ::mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - ::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: double NormalizeX(const int pixel); @@ -78,18 +78,17 @@ REGISTER_CALCULATOR(FaceToRegionCalculator); FaceToRegionCalculator::FaceToRegionCalculator() {} -::mediapipe::Status FaceToRegionCalculator::GetContract( +absl::Status FaceToRegionCalculator::GetContract( mediapipe::CalculatorContract* cc) { if (cc->Inputs().HasTag("VIDEO")) { cc->Inputs().Tag("VIDEO").Set(); } cc->Inputs().Tag("FACES").Set>(); cc->Outputs().Tag("REGIONS").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FaceToRegionCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status FaceToRegionCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); if (!cc->Inputs().HasTag("VIDEO")) { RET_CHECK(!options_.use_visual_scorer()) @@ -105,7 +104,7 @@ FaceToRegionCalculator::FaceToRegionCalculator() {} scorer_ = absl::make_unique(options_.scorer_options()); frame_width_ = -1; frame_height_ = -1; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } inline double FaceToRegionCalculator::NormalizeX(const int pixel) { @@ -146,11 +145,10 @@ void FaceToRegionCalculator::ExtendSalientRegionWithPoint( } } -::mediapipe::Status FaceToRegionCalculator::Process( - mediapipe::CalculatorContext* cc) { +absl::Status FaceToRegionCalculator::Process(mediapipe::CalculatorContext* cc) { if (cc->Inputs().HasTag("VIDEO") && cc->Inputs().Tag("VIDEO").Value().IsEmpty()) { - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "No VIDEO input at time " << cc->InputTimestamp().Seconds(); } @@ -280,7 +278,7 @@ void FaceToRegionCalculator::ExtendSalientRegionWithPoint( } cc->Outputs().Tag("REGIONS").Add(region_set.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc index 572d80998..80f0f4552 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc @@ -38,9 +38,9 @@ class LocalizationToRegionCalculator : public mediapipe::CalculatorBase { LocalizationToRegionCalculator& operator=( const 
LocalizationToRegionCalculator&) = delete; - static ::mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - ::mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - ::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: // Calculator options. @@ -84,21 +84,21 @@ void FillSalientRegion(const mediapipe::Detection& detection, } // namespace -::mediapipe::Status LocalizationToRegionCalculator::GetContract( +absl::Status LocalizationToRegionCalculator::GetContract( mediapipe::CalculatorContract* cc) { cc->Inputs().Tag("DETECTIONS").Set>(); cc->Outputs().Tag("REGIONS").Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LocalizationToRegionCalculator::Open( +absl::Status LocalizationToRegionCalculator::Open( mediapipe::CalculatorContext* cc) { options_ = cc->Options(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status LocalizationToRegionCalculator::Process( +absl::Status LocalizationToRegionCalculator::Process( mediapipe::CalculatorContext* cc) { const auto& annotations = cc->Inputs().Tag("DETECTIONS").Get>(); @@ -119,7 +119,7 @@ void FillSalientRegion(const mediapipe::Detection& detection, } cc->Outputs().Tag("REGIONS").Add(regions.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc index a8ba3eeb9..885753d63 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc @@ -68,8 +68,8 @@ constexpr char kOutputSummary[] = "CROPPING_SUMMARY"; constexpr char kExternalRenderingPerFrame[] = "EXTERNAL_RENDERING_PER_FRAME"; constexpr char kExternalRenderingFullVid[] = "EXTERNAL_RENDERING_FULL_VID"; -::mediapipe::Status SceneCroppingCalculator::GetContract( - ::mediapipe::CalculatorContract* cc) { +absl::Status SceneCroppingCalculator::GetContract( + mediapipe::CalculatorContract* cc) { if (cc->InputSidePackets().HasTag(kInputExternalSettings)) { cc->InputSidePackets().Tag(kInputExternalSettings).Set(); } @@ -136,10 +136,10 @@ constexpr char kExternalRenderingFullVid[] = "EXTERNAL_RENDERING_FULL_VID"; cc->Outputs().HasTag(kExternalRenderingFullVid) || cc->Outputs().HasTag(kOutputCroppedFrames)) << "At leaset one output stream must be specified"; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SceneCroppingCalculator::Open(CalculatorContext* cc) { +absl::Status SceneCroppingCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); RET_CHECK_GT(options_.max_scene_size(), 0) << "Maximum scene size is non-positive."; @@ -175,17 +175,17 @@ constexpr char kExternalRenderingFullVid[] = "EXTERNAL_RENDERING_FULL_VID"; should_perform_frame_cropping_ = cc->Outputs().HasTag(kOutputCroppedFrames); scene_camera_motion_analyzer_ = absl::make_unique( options_.scene_camera_motion_analyzer_options()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } namespace { -::mediapipe::Status ParseAspectRatioString( - const std::string& aspect_ratio_string, double* aspect_ratio) { +absl::Status 
ParseAspectRatioString(const std::string& aspect_ratio_string, + double* aspect_ratio) { std::string error_msg = "Aspect ratio std::string must be in the format of 'width:height', e.g. " "'1:1' or '5:4', your input was " + aspect_ratio_string; - auto pos = aspect_ratio_string.find(":"); + auto pos = aspect_ratio_string.find(':'); RET_CHECK(pos != std::string::npos) << error_msg; double width_ratio; RET_CHECK(absl::SimpleAtod(aspect_ratio_string.substr(0, pos), &width_ratio)) @@ -196,7 +196,7 @@ namespace { &height_ratio)) << error_msg; *aspect_ratio = width_ratio / height_ratio; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void ConstructExternalRenderMessage( const cv::Rect& crop_from_location, const cv::Rect& render_to_location, @@ -235,8 +235,8 @@ int RoundToEven(float value) { } // namespace -::mediapipe::Status SceneCroppingCalculator::InitializeSceneCroppingCalculator( - ::mediapipe::CalculatorContext* cc) { +absl::Status SceneCroppingCalculator::InitializeSceneCroppingCalculator( + mediapipe::CalculatorContext* cc) { if (cc->Inputs().HasTag(kInputVideoFrames)) { const auto& frame = cc->Inputs().Tag(kInputVideoFrames).Get(); frame_width_ = frame.Width(); @@ -248,7 +248,7 @@ int RoundToEven(float value) { frame_height_ = cc->Inputs().Tag(kInputVideoSize).Get>().second; } else { - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Input VIDEO or VIDEO_SIZE must be provided."; } RET_CHECK_GT(frame_height_, 0) << "Input frame height is non-positive."; @@ -302,8 +302,7 @@ int RoundToEven(float value) { target_height_ = frame_height_; break; case SceneCroppingCalculatorOptions::UNKNOWN: - return mediapipe::InvalidArgumentError( - "target_size_type not set properly."); + return absl::InvalidArgumentError("target_size_type not set properly."); } target_aspect_ratio_ = GetRatio(target_width_, target_height_); @@ -337,18 +336,18 @@ int RoundToEven(float value) { scene_cropper_ = absl::make_unique( options_.camera_motion_options(), frame_width_, frame_height_); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -bool HasFrameSignal(::mediapipe::CalculatorContext* cc) { +bool HasFrameSignal(mediapipe::CalculatorContext* cc) { if (cc->Inputs().HasTag(kInputVideoFrames)) { return !cc->Inputs().Tag(kInputVideoFrames).Value().IsEmpty(); } return !cc->Inputs().Tag(kInputVideoSize).Value().IsEmpty(); } -::mediapipe::Status SceneCroppingCalculator::Process( - ::mediapipe::CalculatorContext* cc) { +absl::Status SceneCroppingCalculator::Process( + mediapipe::CalculatorContext* cc) { // Sets frame dimension and initializes scenecroppingcalculator on first video // frame. 
if (frame_width_ < 0) { @@ -417,11 +416,10 @@ bool HasFrameSignal(::mediapipe::CalculatorContext* cc) { continue_last_scene_ = true; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SceneCroppingCalculator::Close( - ::mediapipe::CalculatorContext* cc) { +absl::Status SceneCroppingCalculator::Close(mediapipe::CalculatorContext* cc) { if (!scene_frame_timestamps_.empty()) { MP_RETURN_IF_ERROR(ProcessScene(/* is_end_of_scene = */ true, cc)); } @@ -435,12 +433,12 @@ bool HasFrameSignal(::mediapipe::CalculatorContext* cc) { .Tag(kExternalRenderingFullVid) .Add(external_render_list_.release(), Timestamp::PostStream()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // TODO: split this function into two, one for calculating the border // sizes, the other for the actual removal of borders from the frames. -::mediapipe::Status SceneCroppingCalculator::RemoveStaticBorders( +absl::Status SceneCroppingCalculator::RemoveStaticBorders( CalculatorContext* cc, int* top_border_size, int* bottom_border_size) { *top_border_size = 0; *bottom_border_size = 0; @@ -492,11 +490,10 @@ bool HasFrameSignal(::mediapipe::CalculatorContext* cc) { *key_frame_infos_[i].mutable_detections() = adjusted_detections; } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status -SceneCroppingCalculator::InitializeFrameCropRegionComputer() { +absl::Status SceneCroppingCalculator::InitializeFrameCropRegionComputer() { key_frame_crop_options_ = options_.key_frame_crop_options(); MP_RETURN_IF_ERROR( SetKeyFrameCropTarget(frame_width_, effective_frame_height_, @@ -505,7 +502,7 @@ SceneCroppingCalculator::InitializeFrameCropRegionComputer() { VLOG(1) << "Target height " << key_frame_crop_options_.target_height(); frame_crop_region_computer_ = absl::make_unique(key_frame_crop_options_); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void SceneCroppingCalculator::FilterKeyFrameInfo() { @@ -531,8 +528,8 @@ void SceneCroppingCalculator::FilterKeyFrameInfo() { } } -::mediapipe::Status SceneCroppingCalculator::ProcessScene( - const bool is_end_of_scene, CalculatorContext* cc) { +absl::Status SceneCroppingCalculator::ProcessScene(const bool is_end_of_scene, + CalculatorContext* cc) { // Removes detections under special circumstances. FilterKeyFrameInfo(); @@ -654,10 +651,10 @@ void SceneCroppingCalculator::FilterKeyFrameInfo() { is_key_frames_.clear(); static_features_.clear(); static_features_timestamps_.clear(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames( +absl::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames( const int crop_width, const int crop_height, const int num_frames, std::vector* render_to_locations, bool* apply_padding, std::vector* padding_colors, float* vertical_fill_percent, @@ -730,7 +727,7 @@ void SceneCroppingCalculator::FilterKeyFrameInfo() { padding_colors->push_back(padding_color_to_add); } if (!cropped_frames_ptr) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Resizes cropped frames, pads frames, and output frames. 
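Note on the ParseAspectRatioString helper updated earlier in this file's diff: it turns a "width:height" string into a single double ratio. A minimal usage sketch, assuming it is called from a status-returning function in the same translation unit:

  double aspect_ratio = 0.0;
  MP_RETURN_IF_ERROR(ParseAspectRatioString("16:9", &aspect_ratio));
  // aspect_ratio is now 16.0 / 9.0 (~1.78). Inputs without a ':' separator,
  // e.g. "1.78" or "16x9", fail the RET_CHECK and return the descriptive
  // error message instead.
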
@@ -773,10 +770,10 @@ void SceneCroppingCalculator::FilterKeyFrameInfo() { .Add(scaled_frame.release(), timestamp); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCroppingCalculator::OutputVizFrames( +absl::Status SceneCroppingCalculator::OutputVizFrames( const std::vector& key_frame_crop_results, const std::vector& focus_point_frames, const std::vector& crop_from_locations, @@ -816,7 +813,7 @@ mediapipe::Status SceneCroppingCalculator::OutputVizFrames( .Add(viz_frames[i].release(), Timestamp(scene_frame_timestamps_[i])); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(SceneCroppingCalculator); diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.h b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.h index 1c00e6210..61b7b53d6 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.h +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.h @@ -125,35 +125,34 @@ namespace autoflip { // fields are optional with default settings. class SceneCroppingCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // Validates calculator options and initializes SceneCameraMotionAnalyzer and // SceneCropper. - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; // Buffers each scene frame and its timestamp. Packs and stores KeyFrameInfo // for key frames (a.k.a. frames with detection features). When a shot // boundary is encountered or when the buffer is full, calls ProcessScene() // to process the scene at once, and clears buffers. - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Calls ProcessScene() on remaining buffered frames. Optionally outputs a // VideoCroppingSummary if the output stream CROPPING_SUMMARY is present. - ::mediapipe::Status Close(::mediapipe::CalculatorContext* cc) override; + absl::Status Close(mediapipe::CalculatorContext* cc) override; private: // Removes any static borders from the scene frames before cropping. The // arguments |top_border_size| and |bottom_border_size| report the size of the // removed borders. - ::mediapipe::Status RemoveStaticBorders(CalculatorContext* cc, - int* top_border_size, - int* bottom_border_size); + absl::Status RemoveStaticBorders(CalculatorContext* cc, int* top_border_size, + int* bottom_border_size); // Sets up autoflip after first frame is received and input size is known. - ::mediapipe::Status InitializeSceneCroppingCalculator( - ::mediapipe::CalculatorContext* cc); + absl::Status InitializeSceneCroppingCalculator( + mediapipe::CalculatorContext* cc); // Initializes a FrameCropRegionComputer given input and target frame sizes. - ::mediapipe::Status InitializeFrameCropRegionComputer(); + absl::Status InitializeFrameCropRegionComputer(); // Processes a scene using buffered scene frames and KeyFrameInfos: // 1. Computes key frame crop regions using a FrameCropRegionComputer. @@ -165,8 +164,7 @@ class SceneCroppingCalculator : public CalculatorBase { // to force flush). // 6. Optionally outputs visualization frames. // 7. Optionally updates cropping summary. 
- ::mediapipe::Status ProcessScene(const bool is_end_of_scene, - CalculatorContext* cc); + absl::Status ProcessScene(const bool is_end_of_scene, CalculatorContext* cc); // Formats and outputs the cropped frames passed in through // |cropped_frames_ptr|. Scales them to be at least as big as the target @@ -177,14 +175,14 @@ class SceneCroppingCalculator : public CalculatorBase { // cropped frames. This is useful when the calculator is only used for // computing the cropping metadata rather than doing the actual cropping // operation. - ::mediapipe::Status FormatAndOutputCroppedFrames( + absl::Status FormatAndOutputCroppedFrames( const int crop_width, const int crop_height, const int num_frames, std::vector* render_to_locations, bool* apply_padding, std::vector* padding_colors, float* vertical_fill_percent, const std::vector* cropped_frames_ptr, CalculatorContext* cc); // Draws and outputs visualization frames if those streams are present. - ::mediapipe::Status OutputVizFrames( + absl::Status OutputVizFrames( const std::vector& key_frame_crop_results, const std::vector& focus_point_frames, const std::vector& crop_from_locations, diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc index 6cc9217e3..27867d31b 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc @@ -803,6 +803,7 @@ TEST(SceneCroppingCalculatorTest, OutputsCropMessageKinematicPath) { SceneCroppingCalculatorOptions::ext); auto* kinematic_options = options->mutable_camera_motion_options()->mutable_kinematic_options(); + kinematic_options->set_min_motion_to_reframe(1.2); kinematic_options->set_max_velocity(200); auto runner = absl::make_unique(config); @@ -875,6 +876,7 @@ TEST(SceneCroppingCalculatorTest, OutputsCropMessageKinematicPathNoVideo) { SceneCroppingCalculatorOptions::ext); auto* kinematic_options = options->mutable_camera_motion_options()->mutable_kinematic_options(); + kinematic_options->set_min_motion_to_reframe(1.2); kinematic_options->set_max_velocity(2.0); auto runner = absl::make_unique(config); diff --git a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc index 8d8e2570a..299f60b10 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc @@ -60,9 +60,9 @@ class ShotBoundaryCalculator : public mediapipe::CalculatorBase { ShotBoundaryCalculator(const ShotBoundaryCalculator&) = delete; ShotBoundaryCalculator& operator=(const ShotBoundaryCalculator&) = delete; - static ::mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: // Computes the histogram of an image. 
@@ -98,12 +98,11 @@ void ShotBoundaryCalculator::ComputeHistogram(const cv::Mat& image, kHistogramBinNum, kHistogramRange, true, false); } -mediapipe::Status ShotBoundaryCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status ShotBoundaryCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); last_shot_timestamp_ = Timestamp(0); init_ = false; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void ShotBoundaryCalculator::Transmit(mediapipe::CalculatorContext* cc, @@ -127,8 +126,7 @@ void ShotBoundaryCalculator::Transmit(mediapipe::CalculatorContext* cc, } } -::mediapipe::Status ShotBoundaryCalculator::Process( - mediapipe::CalculatorContext* cc) { +absl::Status ShotBoundaryCalculator::Process(mediapipe::CalculatorContext* cc) { // Connect to input frame and make a mutable copy. cv::Mat frame_org = mediapipe::formats::MatView( &cc->Inputs().Tag(kVideoInputTag).Get()); @@ -142,7 +140,7 @@ void ShotBoundaryCalculator::Transmit(mediapipe::CalculatorContext* cc, last_histogram_ = current_histogram; init_ = true; Transmit(cc, false); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } double current_motion_estimate = @@ -152,7 +150,7 @@ void ShotBoundaryCalculator::Transmit(mediapipe::CalculatorContext* cc, if (motion_history_.size() != options_.window_size()) { Transmit(cc, false); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Shot detection algorithm is a mixture of adaptive (controlled with @@ -176,14 +174,14 @@ void ShotBoundaryCalculator::Transmit(mediapipe::CalculatorContext* cc, // Store histogram for next frame. last_histogram_ = current_histogram; motion_history_.pop_back(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ShotBoundaryCalculator::GetContract( +absl::Status ShotBoundaryCalculator::GetContract( mediapipe::CalculatorContract* cc) { cc->Inputs().Tag(kVideoInputTag).Set(); cc->Outputs().Tag(kShotChangeTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc index 06e5e768b..e2b4f659d 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/opencv_core_inc.h" diff --git a/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc index 703932938..37643b5d1 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc @@ -105,13 +105,13 @@ class SignalFusingCalculator : public mediapipe::CalculatorBase { SignalFusingCalculator(const SignalFusingCalculator&) = delete; SignalFusingCalculator& operator=(const SignalFusingCalculator&) = delete; - static ::mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) 
override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Close(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; + absl::Status Close(mediapipe::CalculatorContext* cc) override; private: - mediapipe::Status ProcessScene(mediapipe::CalculatorContext* cc); + absl::Status ProcessScene(mediapipe::CalculatorContext* cc); std::vector GetSignalPackets(mediapipe::CalculatorContext* cc); SignalFusingCalculatorOptions options_; std::map settings_by_type_; @@ -154,8 +154,7 @@ void SetupOrderedInput(mediapipe::CalculatorContract* cc) { } } // namespace -mediapipe::Status SignalFusingCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status SignalFusingCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); for (const auto& setting : options_.signal_settings()) { settings_by_type_[CreateSettingsKey(setting.type())] = setting; @@ -166,19 +165,18 @@ mediapipe::Status SignalFusingCalculator::Open( process_by_scene_ = false; } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SignalFusingCalculator::Close( - mediapipe::CalculatorContext* cc) { +absl::Status SignalFusingCalculator::Close(mediapipe::CalculatorContext* cc) { if (!scene_frames_.empty()) { MP_RETURN_IF_ERROR(ProcessScene(cc)); scene_frames_.clear(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SignalFusingCalculator::ProcessScene( +absl::Status SignalFusingCalculator::ProcessScene( mediapipe::CalculatorContext* cc) { std::map detection_count; std::map multiframe_score; @@ -240,7 +238,7 @@ mediapipe::Status SignalFusingCalculator::ProcessScene( } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } std::vector SignalFusingCalculator::GetSignalPackets( @@ -260,8 +258,7 @@ std::vector SignalFusingCalculator::GetSignalPackets( return signal_packets; } -mediapipe::Status SignalFusingCalculator::Process( - mediapipe::CalculatorContext* cc) { +absl::Status SignalFusingCalculator::Process(mediapipe::CalculatorContext* cc) { bool is_boundary = false; if (process_by_scene_) { const auto& shot_tag = (tag_input_interface_) @@ -302,17 +299,17 @@ mediapipe::Status SignalFusingCalculator::Process( scene_frames_.clear(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SignalFusingCalculator::GetContract( +absl::Status SignalFusingCalculator::GetContract( mediapipe::CalculatorContract* cc) { if (cc->Inputs().NumEntries(kSignalInputsTag) > 0) { SetupTagInput(cc); } else { SetupOrderedInput(cc); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator.cc index b8af09ce8..8d67eb8f0 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator.cc @@ -57,20 +57,19 @@ class VideoFilteringCalculator : public CalculatorBase { VideoFilteringCalculator() = default; ~VideoFilteringCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status 
Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(VideoFilteringCalculator); -::mediapipe::Status VideoFilteringCalculator::GetContract( - CalculatorContract* cc) { +absl::Status VideoFilteringCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kInputFrameTag).Set(); cc->Outputs().Tag(kOutputFrameTag).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status VideoFilteringCalculator::Process(CalculatorContext* cc) { +absl::Status VideoFilteringCalculator::Process(CalculatorContext* cc) { const auto& options = cc->Options(); const Packet& input_packet = cc->Inputs().Tag(kInputFrameTag).Value(); @@ -84,7 +83,7 @@ REGISTER_CALCULATOR(VideoFilteringCalculator); if (filter_type == VideoFilteringCalculatorOptions::AspectRatioFilter::NO_FILTERING) { cc->Outputs().Tag(kOutputFrameTag).AddPacket(input_packet); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const int target_width = options.aspect_ratio_filter().target_width(); const int target_height = options.aspect_ratio_filter().target_height(); @@ -92,7 +91,7 @@ REGISTER_CALCULATOR(VideoFilteringCalculator); RET_CHECK_GT(target_height, 0); bool should_pass = false; - cv::Mat frame_mat = ::mediapipe::formats::MatView(&frame); + cv::Mat frame_mat = mediapipe::formats::MatView(&frame); const double ratio = static_cast(frame_mat.cols) / frame_mat.rows; const double target_ratio = static_cast(target_width) / target_height; if (filter_type == VideoFilteringCalculatorOptions::AspectRatioFilter:: @@ -106,16 +105,16 @@ REGISTER_CALCULATOR(VideoFilteringCalculator); } if (should_pass) { cc->Outputs().Tag(kOutputFrameTag).AddPacket(input_packet); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (options.fail_if_any()) { - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << absl::Substitute( + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << absl::Substitute( "Failing due to aspect ratio. Target aspect ratio: $0. 
Frame " "width: $1, height: $2.", target_ratio, frame.Width(), frame.Height()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip } // namespace mediapipe diff --git a/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator_test.cc index 4f907d001..758193832 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator_test.cc @@ -166,8 +166,8 @@ TEST(VerticalFrameRemovalCalculatorTest, OutputError) { runner->MutableInputs() ->Tag("INPUT_FRAMES") .packets.push_back(Adopt(input_frame.release()).At(Timestamp(1000))); - ::mediapipe::Status status = runner->Run(); - EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnknown); + absl::Status status = runner->Run(); + EXPECT_EQ(status.code(), absl::StatusCode::kUnknown); EXPECT_THAT(status.ToString(), ::testing::HasSubstr("Failing due to aspect ratio")); } diff --git a/mediapipe/examples/desktop/autoflip/quality/BUILD b/mediapipe/examples/desktop/autoflip/quality/BUILD index a6e79c3a3..4a5ac3b7a 100644 --- a/mediapipe/examples/desktop/autoflip/quality/BUILD +++ b/mediapipe/examples/desktop/autoflip/quality/BUILD @@ -249,6 +249,7 @@ cc_test( ":scene_camera_motion_analyzer", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", diff --git a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc index 0b57cd0da..5916d1829 100644 --- a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc @@ -22,7 +22,7 @@ namespace mediapipe { namespace autoflip { -::mediapipe::Status FrameCropRegionComputer::ExpandSegmentUnderConstraint( +absl::Status FrameCropRegionComputer::ExpandSegmentUnderConstraint( const Segment& segment_to_add, const Segment& base_segment, const int max_length, Segment* combined_segment, CoverType* cover_type) const { @@ -75,10 +75,10 @@ namespace autoflip { *combined_segment = std::make_pair(combined_segment_left, combined_segment_right); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FrameCropRegionComputer::ExpandRectUnderConstraints( +absl::Status FrameCropRegionComputer::ExpandRectUnderConstraints( const Rect& rect_to_add, const int max_width, const int max_height, Rect* base_rect, CoverType* cover_type) const { RET_CHECK(base_rect != nullptr) << "Base rect is null."; @@ -129,7 +129,7 @@ namespace autoflip { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void FrameCropRegionComputer::UpdateCropRegionScore( @@ -167,7 +167,7 @@ void FrameCropRegionComputer::UpdateCropRegionScore( } } -::mediapipe::Status FrameCropRegionComputer::ComputeFrameCropRegion( +absl::Status FrameCropRegionComputer::ComputeFrameCropRegion( const KeyFrameInfo& frame_info, KeyFrameCropResult* crop_result) const { RET_CHECK(crop_result != nullptr) << "KeyFrameCropResult is null."; @@ -254,7 +254,7 @@ void FrameCropRegionComputer::UpdateCropRegionScore( crop_result->set_region_is_empty(crop_region_is_empty); 
crop_result->set_region_score(crop_region_score); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.h b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.h index d7adbc41c..b2be9e28c 100644 --- a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.h +++ b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.h @@ -43,8 +43,8 @@ class FrameCropRegionComputer { // consider static features, and simply tries to fit the detected features // within the target frame size. The score of the crop region is aggregated // from individual feature scores given the score aggregation type. - ::mediapipe::Status ComputeFrameCropRegion( - const KeyFrameInfo& frame_info, KeyFrameCropResult* crop_result) const; + absl::Status ComputeFrameCropRegion(const KeyFrameInfo& frame_info, + KeyFrameCropResult* crop_result) const; protected: // A segment is a 1-d object defined by its left and right point. @@ -75,10 +75,11 @@ class FrameCropRegionComputer { // fraction of the new segment exceeds the maximum length. // In this case the combined segment is the base segment, and cover // type is NOT_COVERED. - ::mediapipe::Status ExpandSegmentUnderConstraint( - const Segment& segment_to_add, const Segment& base_segment, - const int max_length, Segment* combined_segment, - CoverType* cover_type) const; + absl::Status ExpandSegmentUnderConstraint(const Segment& segment_to_add, + const Segment& base_segment, + const int max_length, + Segment* combined_segment, + CoverType* cover_type) const; // Expands a base rectangle to cover a new rectangle to be added under width // and height constraints. The operation is best-effort. It considers @@ -87,11 +88,10 @@ class FrameCropRegionComputer { // FULLY_COVERED if the new rectangle is fully covered in both directions, // PARTIALLY_COVERED if it is at least partially covered in both directions, // and NOT_COVERED if it is not covered in either direction. - ::mediapipe::Status ExpandRectUnderConstraints(const Rect& rect_to_add, - const int max_width, - const int max_height, - Rect* base_rect, - CoverType* cover_type) const; + absl::Status ExpandRectUnderConstraints(const Rect& rect_to_add, + const int max_width, + const int max_height, Rect* base_rect, + CoverType* cover_type) const; // Updates crop region score given current feature score, whether the feature // is required, and the score aggregation type. Ignores negative scores. 
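The comments in frame_crop_region_computer.h above describe a best-effort "expand under constraint" operation with three cover outcomes. The following standalone sketch illustrates that contract under simplified assumptions; the names Segment, Cover and ExpandSegmentSketch are invented here for illustration, and the real FrameCropRegionComputer applies additional rules before reporting PARTIALLY_COVERED.

  #include <algorithm>
  #include <utility>

  using Segment = std::pair<int, int>;  // [left, right] endpoints in pixels.
  enum class Cover { kFullyCovered, kPartiallyCovered, kNotCovered };

  // Grow |base| toward |to_add| without exceeding |max_length| and report how
  // much of |to_add| ended up covered.
  Cover ExpandSegmentSketch(const Segment& to_add, const Segment& base,
                            int max_length, Segment* combined) {
    const int lo = std::min(to_add.first, base.first);
    const int hi = std::max(to_add.second, base.second);
    if (hi - lo <= max_length) {
      *combined = {lo, hi};  // The union fits: the new segment is fully covered.
      return Cover::kFullyCovered;
    }
    int budget = max_length - (base.second - base.first);
    if (budget <= 0) {
      *combined = base;  // No room to grow: keep the base segment unchanged.
      return Cover::kNotCovered;
    }
    // Spend the remaining budget growing toward the new segment on each side.
    int left = base.first, right = base.second;
    const int grow_left = std::min(budget, std::max(0, left - to_add.first));
    left -= grow_left;
    budget -= grow_left;
    right += std::min(budget, std::max(0, to_add.second - right));
    if (right < to_add.first || left > to_add.second) {
      *combined = base;  // Could not reach the new segment at all.
      return Cover::kNotCovered;
    }
    *combined = {left, right};
    return Cover::kPartiallyCovered;
  }
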
diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc index 573c990d7..899724921 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc @@ -2,11 +2,83 @@ namespace mediapipe { namespace autoflip { +namespace { +int Median(const std::deque>& positions_raw) { + std::deque positions; + for (const auto& position : positions_raw) { + positions.push_back(position.second); + } -::mediapipe::Status KinematicPathSolver::AddObservation(int position, - const uint64 time_us) { + size_t n = positions.size() / 2; + nth_element(positions.begin(), positions.begin() + n, positions.end()); + return positions[n]; +} +} // namespace +bool KinematicPathSolver::IsMotionTooSmall(double delta_degs) { + if (options_.has_min_motion_to_reframe()) { + return abs(delta_degs) < options_.min_motion_to_reframe(); + } else if (delta_degs > 0) { + return delta_degs < options_.min_motion_to_reframe_upper(); + } else { + return abs(delta_degs) < options_.min_motion_to_reframe_lower(); + } +} +void KinematicPathSolver::ClearHistory() { raw_positions_at_time_.clear(); } +absl::Status KinematicPathSolver::PredictMotionState(int position, + const uint64 time_us, + bool* state) { if (!initialized_) { - current_position_px_ = position; + *state = false; + return absl::OkStatus(); + } + + auto raw_positions_at_time_copy = raw_positions_at_time_; + + raw_positions_at_time_copy.push_front( + std::pair(time_us, position)); + while (raw_positions_at_time_copy.size() > 1) { + if (static_cast(raw_positions_at_time_copy.back().first) < + static_cast(time_us) - options_.filtering_time_window_us()) { + raw_positions_at_time_copy.pop_back(); + } else { + break; + } + } + + int filtered_position = Median(raw_positions_at_time_copy); + double delta_degs = + (filtered_position - current_position_px_) / pixels_per_degree_; + + // If the motion is smaller than the min_motion_to_reframe and camera is + // stationary, don't use the update. + if (IsMotionTooSmall(delta_degs) && !motion_state_) { + *state = false; + } else if (abs(delta_degs) < options_.reframe_window() && motion_state_) { + // If the motion is smaller than the reframe_window and camera is moving, + // don't use the update. + *state = false; + } else { + // Apply new position, plus the reframe window size. 
+ *state = true; + } + + return absl::OkStatus(); +} +absl::Status KinematicPathSolver::AddObservation(int position, + const uint64 time_us) { + if (!initialized_) { + if (position < min_location_) { + current_position_px_ = min_location_; + } else if (position > max_location_) { + current_position_px_ = max_location_; + } else { + current_position_px_ = position; + } + target_position_px_ = position; + motion_state_ = false; + mean_delta_t_ = -1; + raw_positions_at_time_.push_front( + std::pair(time_us, position)); current_time_ = time_us; initialized_ = true; current_velocity_deg_per_s_ = 0; @@ -14,28 +86,70 @@ namespace autoflip { << "pixels_per_degree must be larger than 0."; RET_CHECK_GE(options_.update_rate_seconds(), 0) << "update_rate_seconds must be greater than 0."; - RET_CHECK_GE(options_.min_motion_to_reframe(), options_.reframe_window()) - << "Reframe window cannot exceed min_motion_to_reframe."; - return ::mediapipe::OkStatus(); + RET_CHECK_GE(options_.filtering_time_window_us(), 0) + << "update_rate_seconds must be greater than 0."; + RET_CHECK_GE(options_.mean_period_update_rate(), 0) + << "mean_period_update_rate must be greater than 0."; + RET_CHECK(options_.has_min_motion_to_reframe() ^ + (options_.has_min_motion_to_reframe_upper() && + options_.has_min_motion_to_reframe_lower())) + << "Must set min_motion_to_reframe or min_motion_to_reframe_upper and " + "min_motion_to_reframe_lower."; + if (options_.has_min_motion_to_reframe()) { + RET_CHECK_GE(options_.min_motion_to_reframe(), options_.reframe_window()) + << "Reframe window cannot exceed min_motion_to_reframe."; + } else { + RET_CHECK_GE(options_.min_motion_to_reframe_upper(), + options_.reframe_window()) + << "Reframe window cannot exceed min_motion_to_reframe."; + RET_CHECK_GE(options_.min_motion_to_reframe_lower(), + options_.reframe_window()) + << "Reframe window cannot exceed min_motion_to_reframe."; + } + return absl::OkStatus(); } RET_CHECK(current_time_ < time_us) << "Observation added before a prior observations."; - double delta_degs = (position - current_position_px_) / pixels_per_degree_; + raw_positions_at_time_.push_front(std::pair(time_us, position)); + while (raw_positions_at_time_.size() > 1) { + if (static_cast(raw_positions_at_time_.back().first) < + static_cast(time_us) - options_.filtering_time_window_us()) { + raw_positions_at_time_.pop_back(); + } else { + break; + } + } - // If the motion is smaller than the min, don't use the update. - if (abs(delta_degs) < options_.min_motion_to_reframe()) { - position = current_position_px_; + int filtered_position = Median(raw_positions_at_time_); + double delta_degs = + (filtered_position - current_position_px_) / pixels_per_degree_; + + // If the motion is smaller than the min_motion_to_reframe and camera is + // stationary, don't use the update. + if (IsMotionTooSmall(delta_degs) && !motion_state_) { delta_degs = 0; + motion_state_ = false; + } else if (abs(delta_degs) < options_.reframe_window() && motion_state_) { + // If the motion is smaller than the reframe_window and camera is moving, + // don't use the update. + delta_degs = 0; + motion_state_ = false; } else if (delta_degs > 0) { // Apply new position, less the reframe window size. 
- position = position - pixels_per_degree_ * options_.reframe_window(); - delta_degs = (position - current_position_px_) / pixels_per_degree_; + target_position_px_ = + filtered_position - pixels_per_degree_ * options_.reframe_window(); + delta_degs = + (target_position_px_ - current_position_px_) / pixels_per_degree_; + motion_state_ = true; } else { // Apply new position, plus the reframe window size. - position = position + pixels_per_degree_ * options_.reframe_window(); - delta_degs = (position - current_position_px_) / pixels_per_degree_; + target_position_px_ = + filtered_position + pixels_per_degree_ * options_.reframe_window(); + delta_degs = + (target_position_px_ - current_position_px_) / pixels_per_degree_; + motion_state_ = true; } // Time and position updates. @@ -56,35 +170,80 @@ namespace autoflip { return UpdatePrediction(time_us); } -::mediapipe::Status KinematicPathSolver::UpdatePrediction(const int64 time_us) { +absl::Status KinematicPathSolver::UpdatePrediction(const int64 time_us) { RET_CHECK(current_time_ < time_us) << "Prediction time added before a prior observation or prediction."; - // Time since last state/prediction update. + + // Time since last state/prediction update, smoothed by + // mean_period_update_rate. double delta_t = (time_us - current_time_) / 1000000.0; + if (mean_delta_t_ < 0) { + mean_delta_t_ = delta_t; + } else { + mean_delta_t_ = mean_delta_t_ * (1 - options_.mean_period_update_rate()) + + delta_t * options_.mean_period_update_rate(); + } // Position update limited by min/max. - - const double update_position_px = + double update_position_px = current_position_px_ + - current_velocity_deg_per_s_ * delta_t * pixels_per_degree_; + current_velocity_deg_per_s_ * mean_delta_t_ * pixels_per_degree_; + if (update_position_px < min_location_) { current_position_px_ = min_location_; current_velocity_deg_per_s_ = 0; + motion_state_ = false; } else if (update_position_px > max_location_) { current_position_px_ = max_location_; current_velocity_deg_per_s_ = 0; + motion_state_ = false; } else { current_position_px_ = update_position_px; } current_time_ = time_us; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status KinematicPathSolver::GetState(int* position) { +absl::Status KinematicPathSolver::GetState(int* position) { RET_CHECK(initialized_) << "GetState called before first observation added."; *position = round(current_position_px_); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); +} + +absl::Status KinematicPathSolver::GetTargetPosition(int* target_position) { + RET_CHECK(initialized_) + << "GetTargetPosition called before first observation added."; + *target_position = round(target_position_px_); + return absl::OkStatus(); +} + +absl::Status KinematicPathSolver::UpdatePixelsPerDegree( + const float pixels_per_degree) { + RET_CHECK_GT(pixels_per_degree_, 0) + << "pixels_per_degree must be larger than 0."; + pixels_per_degree_ = pixels_per_degree; + return absl::OkStatus(); +} + +absl::Status KinematicPathSolver::UpdateMinMaxLocation(const int min_location, + const int max_location) { + RET_CHECK(initialized_) + << "UpdateMinMaxLocation called before first observation added."; + double prior_distance = max_location_ - min_location_; + double updated_distance = max_location - min_location; + double scale_change = updated_distance / prior_distance; + current_position_px_ = current_position_px_ * scale_change; + target_position_px_ = target_position_px_ * scale_change; + max_location_ = max_location; + 
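The UpdatePrediction() hunk above now advances the camera using a smoothed frame period rather than the instantaneous delta. The update is a plain exponential moving average with `mean_period_update_rate` as the blend factor; a small self-contained sketch of that step (the 0.25 default comes from the proto change further down):

```cpp
#include <initializer_list>
#include <iostream>

// Sketch of the smoothed frame-period update used by UpdatePrediction():
// mean_dt <- mean_dt * (1 - rate) + dt * rate, seeded with the first delta.
double UpdateMeanPeriod(double mean_dt, double dt, double rate) {
  if (mean_dt < 0) return dt;  // Mirrors the mean_delta_t_ = -1 initialization.
  return mean_dt * (1.0 - rate) + dt * rate;
}

int main() {
  const double rate = 0.25;  // mean_period_update_rate default.
  double mean_dt = -1.0;
  // A ~30 fps stream with one dropped frame (a 66 ms gap).
  for (double dt : {0.033, 0.033, 0.066, 0.033}) {
    mean_dt = UpdateMeanPeriod(mean_dt, dt, rate);
    std::cout << mean_dt << "\n";
  }
  return 0;
}
```

Damping the period this way keeps a single late frame from producing a large position jump, at the cost of reacting slightly slower to a genuine frame-rate change.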
min_location_ = min_location; + auto original_positions_at_time = raw_positions_at_time_; + raw_positions_at_time_.clear(); + for (auto position_at_time : original_positions_at_time) { + position_at_time.second = position_at_time.second * scale_change; + raw_positions_at_time_.push_front(position_at_time); + } + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h index 2dcd9e520..4f4b896e2 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h @@ -15,6 +15,8 @@ #ifndef MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_QUALITY_UNIFORM_ACCELERATION_PATH_SOLVER_H_ #define MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_QUALITY_UNIFORM_ACCELERATION_PATH_SOLVER_H_ +#include + #include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.pb.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/ret_check.h" @@ -41,24 +43,49 @@ class KinematicPathSolver { initialized_(false), pixels_per_degree_(pixels_per_degree) {} // Add an observation (detection) at a position and time. - ::mediapipe::Status AddObservation(int position, const uint64 time_us); + absl::Status AddObservation(int position, const uint64 time_us); // Get the predicted position at a time. - ::mediapipe::Status UpdatePrediction(const int64 time_us); + absl::Status UpdatePrediction(const int64 time_us); // Get the state at a time. - ::mediapipe::Status GetState(int* position); + absl::Status GetState(int* position); + // Update PixelPerDegree value. + absl::Status UpdatePixelsPerDegree(const float pixels_per_degree); + // Provide the current target position of the reframe action. + absl::Status GetTargetPosition(int* target_position); + // Change min/max location and update state based on new scaling. + absl::Status UpdateMinMaxLocation(const int min_location, + const int max_location); + // Check if motion is within the reframe window, return false if not. + bool IsMotionTooSmall(double delta_degs); + // Check if a position measurement will cause the camera to be in motion + // without updating the internal state. + absl::Status PredictMotionState(int position, const uint64 time_us, + bool* state); + // Clear any history buffer of positions that are used when + // filtering_time_window_us is set to a non-zero value. + void ClearHistory(); private: // Tuning options. KinematicOptions options_; // Min and max value the state can be. - const int min_location_; - const int max_location_; + int min_location_; + int max_location_; bool initialized_; float pixels_per_degree_; // Current state values. double current_position_px_; double current_velocity_deg_per_s_; uint64 current_time_; + // History of observations (second) and their time (first). + std::deque> raw_positions_at_time_; + // Current target position. + double target_position_px_; + // Defines if the camera is moving to a target (true) or reached a target + // within a tolerance (false). + bool motion_state_; + // Average period of incoming frames. 
+ double mean_delta_t_; }; } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto index ac2595328..9f481db6d 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto @@ -8,8 +8,14 @@ message KinematicOptions { optional double update_rate = 1 [default = 0.5, deprecated = true]; // Max velocity (degrees per second) that the camera can move. optional double max_velocity = 2 [default = 18]; - // Min motion (in degrees) to react in pixels. - optional float min_motion_to_reframe = 3 [default = 1.8]; + // Min motion (in degrees) to react for both upper and lower directions. Must + // not be set if using min_motion_to_reframe_lower and + // min_motion_to_reframe_upper. + optional float min_motion_to_reframe = 3; + // Min motion (in degrees) for upper and lower direction to react. Both must + // be set and min_motion_to_reframe cannot be set if these are specified. + optional float min_motion_to_reframe_lower = 9; + optional float min_motion_to_reframe_upper = 10; // When motion exceeds min_motion_to_reframe, move within this distance of the // camera from the starting direction. Setting this value non-zero reduces // total reframe distance on average. Value cannot exceed @@ -20,4 +26,8 @@ message KinematicOptions { // where delta_time_s is the time since the last frame. optional double update_rate_seconds = 5 [default = 0.20]; optional double max_update_rate = 6 [default = 0.8]; + // History time window of observations to be median filtered. + optional int64 filtering_time_window_us = 7 [default = 0]; + // Weighted update of average period, used for motion updates. + optional float mean_period_update_rate = 8 [default = 0.25]; } diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc index 0bdfb50d2..d6f14cce4 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc @@ -81,6 +81,46 @@ TEST(KinematicPathSolverTest, PassNotEnoughMotionSmallImg) { EXPECT_EQ(state, 400); } +TEST(KinematicPathSolverTest, PassEnoughMotionFiltered) { + KinematicOptions options; + // Set min motion to 2deg + options.set_min_motion_to_reframe(1.0); + options.set_update_rate(1); + options.set_max_velocity(1000); + options.set_filtering_time_window_us(3000000); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + int state; + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + // Move target by 20px / 16.6 = 1.2deg + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 2)); + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 3)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to not move. 
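The KinematicOptions additions above introduce a median-filter window, a blend rate for the smoothed frame period, and direction-dependent reframe thresholds; the new RET_CHECK in AddObservation() accepts either `min_motion_to_reframe` or both directional fields, never a mix, and each threshold must still be at least `reframe_window`. A hedged configuration sketch in the style of the tests (it assumes the usual autoflip build deps for the solver and its proto):

```cpp
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"

namespace mediapipe {
namespace autoflip {

void ConfigureAsymmetricSolverSketch() {
  KinematicOptions options;
  // Direction-dependent thresholds; min_motion_to_reframe must stay unset
  // when both of these are provided.
  options.set_min_motion_to_reframe_upper(1.3);
  options.set_min_motion_to_reframe_lower(1.0);
  options.set_reframe_window(0.75);
  // Median-filter observations over the last half second and keep the
  // default 0.25 blend rate for the smoothed frame period.
  options.set_filtering_time_window_us(500000);
  options.set_mean_period_update_rate(0.25);
  // 0..1000 px frame, with 1000 px spanning a 60 degree field of view,
  // matching the tests' 1000.0 / kWidthFieldOfView.
  KinematicPathSolver solver(options, /*min_location=*/0,
                             /*max_location=*/1000,
                             /*pixels_per_degree=*/1000.0 / 60.0);
  (void)solver;
}

}  // namespace autoflip
}  // namespace mediapipe
```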
+ EXPECT_EQ(state, 500); +} + +TEST(KinematicPathSolverTest, PassEnoughMotionNotFiltered) { + KinematicOptions options; + // Set min motion to 2deg + options.set_min_motion_to_reframe(1.0); + options.set_update_rate(1); + options.set_max_velocity(1000); + options.set_filtering_time_window_us(0); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + int state; + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + // Move target by 20px / 16.6 = 1.2deg + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 2)); + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 3)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to not move. + EXPECT_EQ(state, 506); +} + TEST(KinematicPathSolverTest, PassEnoughMotionLargeImg) { KinematicOptions options; // Set min motion to 1deg @@ -147,7 +187,51 @@ TEST(KinematicPathSolverTest, PassReframeWindow) { MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1)); MP_ASSERT_OK(solver.GetState(&state)); // Expect cam to move 1.2-.75 deg, * 16.6 = 7.47px + 500 = - EXPECT_EQ(state, 507); + EXPECT_EQ(state, 508); +} + +TEST(KinematicPathSolverTest, PassReframeWindowLowerUpper) { + KinematicOptions options; + // Set min motion to 1deg + options.set_min_motion_to_reframe_upper(1.3); + options.set_min_motion_to_reframe_lower(1.0); + options.set_update_rate_seconds(.0000001); + options.set_max_update_rate(1.0); + options.set_max_velocity(1000); + // Set reframe window size to .75 for test. + options.set_reframe_window(0.75); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + int state; + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + // Move target by 20px / 16.6 = 1.2deg + MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to not move + EXPECT_EQ(state, 500); + MP_ASSERT_OK(solver.AddObservation(480, kMicroSecInSec * 2)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to move + EXPECT_EQ(state, 493); +} + +TEST(KinematicPathSolverTest, PassCheckState) { + KinematicOptions options; + // Set min motion to 1deg + options.set_min_motion_to_reframe(1.0); + options.set_update_rate_seconds(.0000001); + options.set_max_update_rate(1.0); + options.set_max_velocity(1000); + // Set reframe window size to .75 for test. 
+ options.set_reframe_window(0.75); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + // Move target by 20px / 16.6 = 1.2deg + bool motion_state; + MP_ASSERT_OK( + solver.PredictMotionState(520, kMicroSecInSec * 1, &motion_state)); + EXPECT_TRUE(motion_state); } TEST(KinematicPathSolverTest, PassUpdateRate30FPS) { @@ -187,13 +271,37 @@ TEST(KinematicPathSolverTest, PassUpdateRate) { options.set_max_update_rate(1.0); options.set_max_velocity(18); KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); - int state; + int state, target_position; MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + MP_ASSERT_OK(solver.GetTargetPosition(&target_position)); + EXPECT_EQ(target_position, 500); MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.GetTargetPosition(&target_position)); + EXPECT_EQ(target_position, 520); MP_ASSERT_OK(solver.GetState(&state)); EXPECT_EQ(state, 505); } +TEST(KinematicPathSolverTest, PassUpdateRateResolutionChange) { + KinematicOptions options; + options.set_min_motion_to_reframe(1.0); + options.set_update_rate_seconds(4); + options.set_max_update_rate(1.0); + options.set_max_velocity(18); + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + int state, target_position; + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + MP_ASSERT_OK(solver.GetTargetPosition(&target_position)); + EXPECT_EQ(target_position, 500); + MP_ASSERT_OK(solver.UpdateMinMaxLocation(0, 500)); + MP_ASSERT_OK(solver.UpdatePixelsPerDegree(500.0 / kWidthFieldOfView)); + MP_ASSERT_OK(solver.AddObservation(520 * 0.5, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.GetTargetPosition(&target_position)); + EXPECT_EQ(target_position, 520 * 0.5); + MP_ASSERT_OK(solver.GetState(&state)); + EXPECT_EQ(state, 253); +} + TEST(KinematicPathSolverTest, PassMaxVelocity) { KinematicOptions options; options.set_min_motion_to_reframe(1.0); @@ -207,6 +315,28 @@ TEST(KinematicPathSolverTest, PassMaxVelocity) { EXPECT_EQ(state, 600); } +TEST(KinematicPathSolverTest, PassDegPerPxChange) { + KinematicOptions options; + // Set min motion to 2deg + options.set_min_motion_to_reframe(2.0); + options.set_update_rate(1); + options.set_max_velocity(1000); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + int state; + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + // Move target by 20px / 16.6 = 1.2deg + MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to not move. + EXPECT_EQ(state, 500); + MP_ASSERT_OK(solver.UpdatePixelsPerDegree(500.0 / kWidthFieldOfView)); + MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 2)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to move. 
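The PassReframeWindowLowerUpper expectations above follow from the same degrees-per-pixel arithmetic the test comments use: at 1000 px per 60 degrees (about 16.7 px/deg), a 20 px jump is about 1.2 deg, which stays under the 1.3 deg upper threshold but clears the 1.0 deg lower one. A small sanity-check sketch of that arithmetic, not the solver itself, and assuming the test's update-rate settings let the camera reach its target within a single step:

```cpp
#include <cmath>
#include <iostream>

int main() {
  const double pixels_per_degree = 1000.0 / 60.0;  // ~16.7 px/deg
  const double reframe_window_deg = 0.75;

  // Upward move: 520 px observed from a 500 px camera position.
  const double up_deg = (520.0 - 500.0) / pixels_per_degree;  // ~1.2 deg
  std::cout << (up_deg < 1.3) << "\n";  // under the upper threshold: no move

  // Downward move: 480 px observed from 500 px, ~1.2 deg, which clears the
  // 1.0 deg lower threshold. The target stops a reframe window short of the
  // observation: 480 + 0.75 deg * 16.7 px/deg ~= 492.5, rounding to 493.
  std::cout << std::round(480.0 + reframe_window_deg * pixels_per_degree)
            << "\n";
  return 0;
}
```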
+ EXPECT_EQ(state, 516); +} + } // namespace } // namespace autoflip } // namespace mediapipe diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.cc index 3da821f08..3d489c395 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.cc +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.cc @@ -45,7 +45,7 @@ PaddingEffectGenerator::PaddingEffectGenerator(const int input_width, } } -::mediapipe::Status PaddingEffectGenerator::Process( +absl::Status PaddingEffectGenerator::Process( const ImageFrame& input_frame, const float background_contrast, const int blur_cv_size, const float overlay_opacity, ImageFrame* output_frame, const cv::Scalar* background_color_in_rgb) { @@ -170,7 +170,7 @@ PaddingEffectGenerator::PaddingEffectGenerator(const int input_width, output_frame->CopyPixelData(input_frame.Format(), canvas.cols, canvas.rows, canvas.data, ImageFrame::kDefaultAlignmentBoundary); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } cv::Rect PaddingEffectGenerator::ComputeOutputLocation() { diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h index 679f01a68..2d33593b6 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h @@ -49,11 +49,10 @@ class PaddingEffectGenerator { // the opacity of the black layer. // - background_color_in_rgb: If not null, uses this solid color as background // instead of blurring the image, and does not adjust contrast or opacity. - ::mediapipe::Status Process( - const ImageFrame& input_frame, const float background_contrast, - const int blur_cv_size, const float overlay_opacity, - ImageFrame* output_frame, - const cv::Scalar* background_color_in_rgb = nullptr); + absl::Status Process(const ImageFrame& input_frame, + const float background_contrast, const int blur_cv_size, + const float overlay_opacity, ImageFrame* output_frame, + const cv::Scalar* background_color_in_rgb = nullptr); // Compute the "render location" on the output frame where the "crop from" // location is to be placed. For use with external rendering soutions. 
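The remainder of this diff is dominated by one mechanical change, already visible in the padding-effect hunks above: spelling return types as `absl::Status` instead of `::mediapipe::Status` and returning `absl::OkStatus()`. Assuming `mediapipe::Status` is (or aliases) `absl::Status` during this migration, the pattern in isolation is just the following; `DoSomething` is a hypothetical function for illustration, not MediaPipe API:

```cpp
#include <string>

#include "absl/status/status.h"

// Illustrative only.
absl::Status DoSomething(const std::string& input) {
  if (input.empty()) {
    return absl::InvalidArgumentError("input must not be empty");
  }
  // ... real work would go here ...
  return absl::OkStatus();
}
```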
diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc index 0bf5c0960..fcdcf4b09 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc @@ -48,12 +48,14 @@ const cv::Scalar kRed = cv::Scalar(255, 0, 0); void TestWithAspectRatio(const double aspect_ratio, const cv::Scalar* background_color_in_rgb = nullptr) { std::string test_image; - const bool process_arbitrary_image = !FLAGS_input_image.empty(); + const bool process_arbitrary_image = + !absl::GetFlag(FLAGS_input_image).empty(); if (!process_arbitrary_image) { std::string test_image_path = mediapipe::file::JoinPath("./", kTestImage); MP_ASSERT_OK(mediapipe::file::GetContents(test_image_path, &test_image)); } else { - MP_ASSERT_OK(mediapipe::file::GetContents(FLAGS_input_image, &test_image)); + MP_ASSERT_OK(mediapipe::file::GetContents(absl::GetFlag(FLAGS_input_image), + &test_image)); } const std::vector contents_vector(test_image.begin(), test_image.end()); @@ -72,11 +74,11 @@ void TestWithAspectRatio(const double aspect_ratio, cv::cvtColor(decoded_mat, output_mat, cv::COLOR_BGR2RGB); break; case 4: - MP_ASSERT_OK(::mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) + MP_ASSERT_OK(mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) << "4-channel image isn't supported yet"); break; default: - MP_ASSERT_OK(::mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) + MP_ASSERT_OK(mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) << "Unsupported number of channels: " << decoded_mat.channels()); } @@ -101,11 +103,11 @@ void TestWithAspectRatio(const double aspect_ratio, cv::cvtColor(original_mat, input_mat, cv::COLOR_RGB2BGR); break; case 4: - MP_ASSERT_OK(::mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) + MP_ASSERT_OK(mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) << "4-channel image isn't supported yet"); break; default: - MP_ASSERT_OK(::mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) + MP_ASSERT_OK(mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) << "Unsupported number of channels: " << original_mat.channels()); } @@ -120,7 +122,7 @@ void TestWithAspectRatio(const double aspect_ratio, // Check its JpegEncoder::write() in "imgcodecs/src/grfmt_jpeg.cpp" for more // info. if (!cv::imencode(".jpg", input_mat, encode_buffer, parameters)) { - MP_ASSERT_OK(::mediapipe::InternalErrorBuilder(MEDIAPIPE_LOC) + MP_ASSERT_OK(mediapipe::InternalErrorBuilder(MEDIAPIPE_LOC) << "Fail to encode the image to be jpeg format."); } @@ -138,7 +140,7 @@ void TestWithAspectRatio(const double aspect_ratio, EXPECT_EQ(result_image, output_string); } else { std::string output_string_path = mediapipe::file::JoinPath( - FLAGS_output_folder, + absl::GetFlag(FLAGS_output_folder), absl::StrCat("result_", aspect_ratio, background_color_in_rgb ? 
"_solid_background" : "", ".jpg")); diff --git a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc index b038b0f3c..dd30566c2 100644 --- a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc +++ b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc @@ -91,7 +91,7 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem( problem->AddResidualBlock(cost_function, new CauchyLoss(0.5), a, b, c, d, k); } -::mediapipe::Status PolynomialRegressionPathSolver::ComputeCameraPath( +absl::Status PolynomialRegressionPathSolver::ComputeCameraPath( const std::vector& focus_point_frames, const std::vector& prior_focus_point_frames, const int original_width, const int original_height, const int output_width, @@ -163,7 +163,7 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem( } all_transforms->push_back(transform); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h index 514f8760d..a510169db 100644 --- a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h +++ b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h @@ -42,7 +42,7 @@ class PolynomialRegressionPathSolver { // y-axis, such that focus points can be preserved as much as possible. The // returned |all_transforms| hold the camera location at each timestamp // corresponding to each input frame. - ::mediapipe::Status ComputeCameraPath( + absl::Status ComputeCameraPath( const std::vector& focus_point_frames, const std::vector& prior_focus_point_frames, const int original_width, const int original_height, diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc index 34f1a4ee6..0bfe72548 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc @@ -30,8 +30,7 @@ namespace mediapipe { namespace autoflip { -::mediapipe::Status -SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( +absl::Status SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, @@ -67,7 +66,7 @@ SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( scene_frame_timestamps, focus_point_frames); } -::mediapipe::Status SceneCameraMotionAnalyzer::ToUseSteadyMotion( +absl::Status SceneCameraMotionAnalyzer::ToUseSteadyMotion( const float look_at_center_x, const float look_at_center_y, const int crop_window_width, const int crop_window_height, SceneKeyFrameCropSummary* scene_summary, @@ -77,10 +76,10 @@ SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( auto* steady_motion = scene_camera_motion->mutable_steady_motion(); steady_motion->set_steady_look_at_center_x(look_at_center_x); steady_motion->set_steady_look_at_center_y(look_at_center_y); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SceneCameraMotionAnalyzer::ToUseSweepingMotion( +absl::Status 
SceneCameraMotionAnalyzer::ToUseSweepingMotion( const float start_x, const float start_y, const float end_x, const float end_y, const int crop_window_width, const int crop_window_height, const double time_duration_in_sec, @@ -99,10 +98,10 @@ SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( scene_summary->frame_success_rate(), start_x, start_y, end_x, end_y, time_duration_in_sec); VLOG(1) << sweeping_log; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( +absl::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( const KeyFrameCropOptions& key_frame_crop_options, const double scene_span_sec, const int64 end_time_us, SceneKeyFrameCropSummary* scene_summary, @@ -131,7 +130,7 @@ SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( no_salient_position_x, no_salient_position_y, scene_summary->crop_window_width(), scene_summary->crop_window_height(), scene_summary, scene_camera_motion)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Sweep across the scene when 1) success rate is too low, AND 2) the current @@ -164,7 +163,7 @@ SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( start_x, start_y, end_x, end_y, key_frame_crop_options.target_width(), key_frame_crop_options.target_height(), scene_span_sec, scene_summary, scene_camera_motion)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // If scene motion is small, then look at a steady point in the scene. @@ -179,14 +178,14 @@ SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( // Otherwise, tracks the focus regions. scene_camera_motion->mutable_tracking_motion(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // If there is no required focus region, looks at the middle of the center // range, and snaps to the scene center if close. Otherwise, look at the center // of the union of the required focus regions, and ensures the crop region // covers this union. 
-::mediapipe::Status SceneCameraMotionAnalyzer::DecideSteadyLookAtRegion( +absl::Status SceneCameraMotionAnalyzer::DecideSteadyLookAtRegion( const KeyFrameCropOptions& key_frame_crop_options, SceneKeyFrameCropSummary* scene_summary, SceneCameraMotion* scene_camera_motion) const { @@ -252,11 +251,10 @@ SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( MP_RETURN_IF_ERROR(ToUseSteadyMotion(center_x, center_y, crop_width, crop_height, scene_summary, scene_camera_motion)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status -SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( +absl::Status SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( const float center_x, const float center_y, const int frame_width, const int frame_height, const FocusPointFrameType type, const float weight, const float bound, FocusPointFrame* focus_point_frame) const { @@ -294,10 +292,10 @@ SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( } else { RET_CHECK_FAIL() << absl::StrCat("Invalid FocusPointFrameType ", type); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( +absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( const SceneKeyFrameCropSummary& scene_summary, const SceneCameraMotion& scene_camera_motion, const std::vector& scene_frame_timestamps, @@ -340,7 +338,7 @@ SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( options_.salient_point_bound(), &focus_point_frame)); focus_point_frames->push_back(focus_point_frame); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else if (scene_camera_motion.has_sweeping_motion()) { // Camera sweeps across the frame. const auto& sweeping_motion = scene_camera_motion.sweeping_motion(); @@ -361,7 +359,7 @@ SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( options_.salient_point_bound(), &focus_point_frame)); focus_point_frames->push_back(focus_point_frame); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else if (scene_camera_motion.has_tracking_motion()) { // Camera tracks crop regions. RET_CHECK_GT(scene_summary.num_key_frames(), 0) << "No key frames."; @@ -369,8 +367,7 @@ SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( scene_summary, focus_point_frame_type, scene_frame_timestamps, focus_point_frames); } else { - return ::mediapipe::Status(StatusCode::kInvalidArgument, - "Unknown motion type."); + return absl::Status(StatusCode::kInvalidArgument, "Unknown motion type."); } } @@ -380,8 +377,7 @@ SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( // The weight for the focus point is proportional to the interpolated score // and scaled so that the maximum weight is equal to // maximum_focus_point_weight in the SceneCameraMotionAnalyzerOptions. 
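The comment above describes how tracking-mode focus point weights are rescaled so that the largest weight equals `maximum_focus_point_weight` from the options. In isolation that is a single multiplicative scale over the interpolated scores; a minimal sketch with illustrative names, not the MediaPipe API:

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// Scale weights so the maximum equals max_weight, preserving their ratios.
void ScaleFocusPointWeights(std::vector<float>* weights, float max_weight) {
  if (weights->empty()) return;
  const float current_max =
      *std::max_element(weights->begin(), weights->end());
  if (current_max <= 0.0f) return;
  const float scale = max_weight / current_max;
  for (float& w : *weights) w *= scale;
}

int main() {
  std::vector<float> weights = {0.2f, 0.5f, 1.5f};
  ScaleFocusPointWeights(&weights, /*max_weight=*/0.9f);
  for (float w : weights) std::cout << w << " ";  // 0.12 0.3 0.9
  std::cout << "\n";
  return 0;
}
```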
-::mediapipe::Status -SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( +absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( const SceneKeyFrameCropSummary& scene_summary, const FocusPointFrameType focus_point_frame_type, const std::vector& scene_frame_timestamps, @@ -440,7 +436,7 @@ SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( focus_point->set_weight(scale * focus_point->weight()); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h index 4aca2108c..8688e16ed 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h @@ -62,7 +62,7 @@ class SceneCameraMotionAnalyzer { // Aggregates information from KeyFrameInfos and KeyFrameCropResults into // SceneKeyFrameCropSummary, and populates FocusPointFrames given scene // frame timestamps. Optionally returns SceneCameraMotion. - ::mediapipe::Status AnalyzeSceneAndPopulateFocusPointFrames( + absl::Status AnalyzeSceneAndPopulateFocusPointFrames( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, @@ -75,7 +75,7 @@ class SceneCameraMotionAnalyzer { protected: // Decides SceneCameraMotion based on SceneKeyFrameCropSummary. Updates the // crop window in SceneKeyFrameCropSummary in the case of steady motion. - ::mediapipe::Status DecideCameraMotionType( + absl::Status DecideCameraMotionType( const KeyFrameCropOptions& key_frame_crop_options, const double scene_span_sec, const int64 end_time_us, SceneKeyFrameCropSummary* scene_summary, @@ -83,7 +83,7 @@ class SceneCameraMotionAnalyzer { // Populates the FocusPointFrames for each scene frame based on // SceneKeyFrameCropSummary, SceneCameraMotion, and scene frame timestamps. - ::mediapipe::Status PopulateFocusPointFrames( + absl::Status PopulateFocusPointFrames( const SceneKeyFrameCropSummary& scene_summary, const SceneCameraMotion& scene_camera_motion, const std::vector& scene_frame_timestamps, @@ -91,7 +91,7 @@ class SceneCameraMotionAnalyzer { private: // Decides the look-at region when camera is steady. - ::mediapipe::Status DecideSteadyLookAtRegion( + absl::Status DecideSteadyLookAtRegion( const KeyFrameCropOptions& key_frame_crop_options, SceneKeyFrameCropSummary* scene_summary, SceneCameraMotion* scene_camera_motion) const; @@ -105,7 +105,7 @@ class SceneCameraMotionAnalyzer { // Adds FocusPoint(s) to given FocusPointFrame given center location, // frame size, FocusPointFrameType, weight, and bound. - ::mediapipe::Status AddFocusPointsFromCenterTypeAndWeight( + absl::Status AddFocusPointsFromCenterTypeAndWeight( const float center_x, const float center_y, const int frame_width, const int frame_height, const FocusPointFrameType type, const float weight, const float bound, @@ -114,21 +114,22 @@ class SceneCameraMotionAnalyzer { // Populates the FocusPointFrames for each scene frame based on // SceneKeyFrameCropSummary and scene frame timestamps in the case where // camera is tracking the crop regions. 
- ::mediapipe::Status PopulateFocusPointFramesForTracking( + absl::Status PopulateFocusPointFramesForTracking( const SceneKeyFrameCropSummary& scene_summary, const FocusPointFrameType focus_point_frame_type, const std::vector& scene_frame_timestamps, std::vector* focus_point_frames) const; // Decide to use steady motion. - ::mediapipe::Status ToUseSteadyMotion( - const float look_at_center_x, const float look_at_center_y, - const int crop_window_width, const int crop_window_height, - SceneKeyFrameCropSummary* scene_summary, - SceneCameraMotion* scene_camera_motion) const; + absl::Status ToUseSteadyMotion(const float look_at_center_x, + const float look_at_center_y, + const int crop_window_width, + const int crop_window_height, + SceneKeyFrameCropSummary* scene_summary, + SceneCameraMotion* scene_camera_motion) const; // Decide to use sweeping motion. - ::mediapipe::Status ToUseSweepingMotion( + absl::Status ToUseSweepingMotion( const float start_x, const float start_y, const float end_x, const float end_y, const int crop_window_width, const int crop_window_height, const double time_duration_in_sec, diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc index f24a2f22d..1e8805b09 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc @@ -24,6 +24,7 @@ #include "mediapipe/examples/desktop/autoflip/quality/focus_point.pb.h" #include "mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.cc b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.cc index 420cb8146..a3c6f17c6 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.cc @@ -29,7 +29,7 @@ constexpr float kWidthFieldOfView = 60; namespace mediapipe { namespace autoflip { -::mediapipe::Status SceneCropper::ProcessKinematicPathSolver( +absl::Status SceneCropper::ProcessKinematicPathSolver( const SceneKeyFrameCropSummary& scene_summary, const std::vector& scene_timestamps, const std::vector& is_key_frames, @@ -77,10 +77,10 @@ namespace autoflip { -(x_path - scene_summary.crop_window_width() / 2); all_xforms->push_back(transform); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SceneCropper::CropFrames( +absl::Status SceneCropper::CropFrames( const SceneKeyFrameCropSummary& scene_summary, const std::vector& scene_timestamps, const std::vector& is_key_frames, @@ -151,7 +151,7 @@ namespace autoflip { // If no cropped_frames is passed in, return directly. 
if (!cropped_frames) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } RET_CHECK(!scene_frames_or_empty.empty()) << "If |cropped_frames| != nullptr, scene_frames_or_empty must not be " diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h index c99ae59e7..0e5c332db 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h @@ -60,7 +60,7 @@ class SceneCropper { // on the transform matrix if |cropped_frames| is not nullptr and // |scene_frames_or_empty| isn't empty. // TODO: split this function into two separate functions. - ::mediapipe::Status CropFrames( + absl::Status CropFrames( const SceneKeyFrameCropSummary& scene_summary, const std::vector& scene_timestamps, const std::vector& is_key_frames, @@ -71,7 +71,7 @@ class SceneCropper { const bool continue_last_scene, std::vector* crop_from_location, std::vector* cropped_frames); - ::mediapipe::Status ProcessKinematicPathSolver( + absl::Status ProcessKinematicPathSolver( const SceneKeyFrameCropSummary& scene_summary, const std::vector& scene_timestamps, const std::vector& is_key_frames, diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc index e2be36c08..d99292fa3 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc @@ -46,7 +46,7 @@ const cv::Scalar kOrange = cv::Scalar(255.0, 165.0, 0.0); // ica object detector const cv::Scalar kWhite = cv::Scalar(255.0, 255.0, 255.0); // others -::mediapipe::Status DrawDetectionsAndCropRegions( +absl::Status DrawDetectionsAndCropRegions( const std::vector& scene_frames, const std::vector& is_key_frames, const std::vector& key_frame_infos, @@ -130,7 +130,7 @@ const cv::Scalar kWhite = cv::Scalar(255.0, 255.0, 255.0); // others } viz_frames->push_back(std::move(viz_frame)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } namespace { @@ -147,7 +147,7 @@ cv::Rect LimitBounds(const cv::Rect& rect, const int max_width, } } // namespace -::mediapipe::Status DrawDetectionAndFramingWindow( +absl::Status DrawDetectionAndFramingWindow( const std::vector& org_scene_frames, const std::vector& crop_from_locations, const ImageFormat::Format image_format, const float overlay_opacity, @@ -166,10 +166,10 @@ cv::Rect LimitBounds(const cv::Rect& rect, const int max_width, scene_frame(crop_from_bounded).copyTo(darkened(crop_from_bounded)); viz_frames->push_back(std::move(viz_frame)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status DrawFocusPointAndCropWindow( +absl::Status DrawFocusPointAndCropWindow( const std::vector& scene_frames, const std::vector& focus_point_frames, const float overlay_opacity, const int crop_window_width, @@ -215,7 +215,7 @@ cv::Rect LimitBounds(const cv::Rect& rect, const int max_width, } viz_frames->push_back(std::move(viz_frame)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.h b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.h index e951f2df7..01f8c5de5 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.h @@ -36,7 +36,7 @@ namespace autoflip { // magenta, 
logos are red, ocrs are yellow (foreground) and light yellow // (background), brain objects are cyan, ica objects are orange, and the rest // are white. -::mediapipe::Status DrawDetectionsAndCropRegions( +absl::Status DrawDetectionsAndCropRegions( const std::vector& scene_frames, const std::vector& is_key_frames, const std::vector& key_frame_infos, @@ -47,7 +47,7 @@ namespace autoflip { // Draws the focus point from the given FocusPointFrame and the crop window // centered around it on the scene frame in red. This helps visualize the input // to the retargeter. -::mediapipe::Status DrawFocusPointAndCropWindow( +absl::Status DrawFocusPointAndCropWindow( const std::vector& scene_frames, const std::vector& focus_point_frames, const float overlay_opacity, const int crop_window_width, @@ -57,7 +57,7 @@ namespace autoflip { // Draws the final smoothed path of the camera retargeter by darkening the // removed areas. -::mediapipe::Status DrawDetectionAndFramingWindow( +absl::Status DrawDetectionAndFramingWindow( const std::vector& org_scene_frames, const std::vector& crop_from_locations, const ImageFormat::Format image_format, const float overlay_opacity, diff --git a/mediapipe/examples/desktop/autoflip/quality/utils.cc b/mediapipe/examples/desktop/autoflip/quality/utils.cc index 68db4aa11..7b25930fc 100644 --- a/mediapipe/examples/desktop/autoflip/quality/utils.cc +++ b/mediapipe/examples/desktop/autoflip/quality/utils.cc @@ -53,13 +53,12 @@ void NormalizedRectToRect(const RectF& normalized_location, const int width, ScaleRect(normalized_location, width, height, location); } -::mediapipe::Status ClampRect(const int width, const int height, - Rect* location) { +absl::Status ClampRect(const int width, const int height, Rect* location) { return ClampRect(0, 0, width, height, location); } -::mediapipe::Status ClampRect(const int x0, const int y0, const int x1, - const int y1, Rect* location) { +absl::Status ClampRect(const int x0, const int y0, const int x1, const int y1, + Rect* location) { RET_CHECK(!(location->x() >= x1 || location->x() + location->width() <= x0 || location->y() >= y1 || location->y() + location->height() <= y0)); @@ -74,7 +73,7 @@ void NormalizedRectToRect(const RectF& normalized_location, const int width, location->set_y(clamped_top); location->set_width(std::max(0, clamped_right - clamped_left)); location->set_height(std::max(0, clamped_bottom - clamped_top)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void RectUnion(const Rect& rect_to_add, Rect* rect) { @@ -90,13 +89,13 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { rect->set_height(y2 - y1); } -::mediapipe::Status PackKeyFrameInfo(const int64 frame_timestamp_ms, - const DetectionSet& detections, - const int original_frame_width, - const int original_frame_height, - const int feature_frame_width, - const int feature_frame_height, - KeyFrameInfo* key_frame_info) { +absl::Status PackKeyFrameInfo(const int64 frame_timestamp_ms, + const DetectionSet& detections, + const int original_frame_width, + const int original_frame_height, + const int feature_frame_width, + const int feature_frame_height, + KeyFrameInfo* key_frame_info) { RET_CHECK(key_frame_info != nullptr) << "KeyFrameInfo is null"; RET_CHECK(original_frame_width > 0 && original_frame_height > 0 && feature_frame_width > 0 && feature_frame_height > 0) @@ -136,13 +135,12 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SortDetections( - const 
DetectionSet& detections, - std::vector* required_regions, - std::vector* non_required_regions) { +absl::Status SortDetections(const DetectionSet& detections, + std::vector* required_regions, + std::vector* non_required_regions) { required_regions->clear(); non_required_regions->clear(); @@ -175,13 +173,13 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { non_required_regions->push_back(detections.detections(original_idx)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status SetKeyFrameCropTarget(const int frame_width, - const int frame_height, - const double target_aspect_ratio, - KeyFrameCropOptions* crop_options) { +absl::Status SetKeyFrameCropTarget(const int frame_width, + const int frame_height, + const double target_aspect_ratio, + KeyFrameCropOptions* crop_options) { RET_CHECK_NE(crop_options, nullptr) << "KeyFrameCropOptions is null."; RET_CHECK_GT(frame_width, 0) << "Frame width is non-positive."; RET_CHECK_GT(frame_height, 0) << "Frame height is non-positive."; @@ -199,10 +197,10 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { : std::round(frame_width / target_aspect_ratio); crop_options->set_target_width(crop_target_width); crop_options->set_target_height(crop_target_height); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AggregateKeyFrameResults( +absl::Status AggregateKeyFrameResults( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, @@ -232,7 +230,7 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { // Handles the corner case of no key frames. if (num_key_frames == 0) { scene_summary->set_has_salient_region(false); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } scene_summary->set_num_key_frames(num_key_frames); @@ -328,10 +326,10 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { scene_summary->key_frame_center_min_y()) / scene_frame_height; scene_summary->set_vertical_motion_amount(motion_y); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ComputeSceneStaticBordersSize( +absl::Status ComputeSceneStaticBordersSize( const std::vector& static_features, int* top_border_size, int* bottom_border_size) { RET_CHECK(top_border_size) << "Output top border size is null."; @@ -375,10 +373,10 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { *top_border_size = std::max(0, *top_border_size); *bottom_border_size = std::max(0, *bottom_border_size); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FindSolidBackgroundColor( +absl::Status FindSolidBackgroundColor( const std::vector& static_features, const std::vector& static_features_timestamps, const double min_fraction_solid_background_color, @@ -423,13 +421,13 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { min_fraction_solid_background_color) { *has_solid_background = true; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status AffineRetarget( - const cv::Size& output_size, const std::vector& frames, - const std::vector& affine_projection, - std::vector* cropped_frames) { +absl::Status AffineRetarget(const cv::Size& output_size, + const std::vector& frames, + const std::vector& affine_projection, + std::vector* cropped_frames) { RET_CHECK(frames.size() == affine_projection.size()) << "number of frames and retarget offsets must be the same."; RET_CHECK(cropped_frames->size() == frames.size()) @@ -443,7 +441,7 
@@ void RectUnion(const Rect& rect_to_add, Rect* rect) { RET_CHECK(affine.rows == 2) << "Affine matrix must be 2x3"; cv::warpAffine(frames[i], (*cropped_frames)[i], affine, output_size); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip } // namespace mediapipe diff --git a/mediapipe/examples/desktop/autoflip/quality/utils.h b/mediapipe/examples/desktop/autoflip/quality/utils.h index ec1373ae4..7285e8265 100644 --- a/mediapipe/examples/desktop/autoflip/quality/utils.h +++ b/mediapipe/examples/desktop/autoflip/quality/utils.h @@ -29,31 +29,30 @@ namespace autoflip { // Packs detected features and timestamp (ms) into a KeyFrameInfo object. Scales // features back to the original frame size if features have been detected on a // different frame size. -::mediapipe::Status PackKeyFrameInfo(const int64 frame_timestamp_ms, - const DetectionSet& detections, - const int original_frame_width, - const int original_frame_height, - const int feature_frame_width, - const int feature_frame_height, - KeyFrameInfo* key_frame_info); +absl::Status PackKeyFrameInfo(const int64 frame_timestamp_ms, + const DetectionSet& detections, + const int original_frame_width, + const int original_frame_height, + const int feature_frame_width, + const int feature_frame_height, + KeyFrameInfo* key_frame_info); // Sorts required and non-required salient regions given a detection set. -::mediapipe::Status SortDetections( - const DetectionSet& detections, - std::vector* required_regions, - std::vector* non_required_regions); +absl::Status SortDetections(const DetectionSet& detections, + std::vector* required_regions, + std::vector* non_required_regions); // Sets the target crop size in KeyFrameCropOptions based on frame size and // target aspect ratio so that the target crop size covers the biggest area // possible in the frame. -::mediapipe::Status SetKeyFrameCropTarget(const int frame_width, - const int frame_height, - const double target_aspect_ratio, - KeyFrameCropOptions* crop_options); +absl::Status SetKeyFrameCropTarget(const int frame_width, + const int frame_height, + const double target_aspect_ratio, + KeyFrameCropOptions* crop_options); // Aggregates information from KeyFrameInfos and KeyFrameCropResults into // SceneKeyFrameCropSummary. -::mediapipe::Status AggregateKeyFrameResults( +absl::Status AggregateKeyFrameResults( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, @@ -61,7 +60,7 @@ namespace autoflip { // Computes the static top and border size across a scene given a vector of // StaticFeatures over frames. -::mediapipe::Status ComputeSceneStaticBordersSize( +absl::Status ComputeSceneStaticBordersSize( const std::vector& static_features, int* top_border_size, int* bottom_border_size); @@ -70,7 +69,7 @@ namespace autoflip { // background color exceeds given threshold, i.e., // min_fraction_solid_background_color. Builds the background color // interpolation functions in Lab space using input timestamps. -::mediapipe::Status FindSolidBackgroundColor( +absl::Status FindSolidBackgroundColor( const std::vector& static_features, const std::vector& static_features_timestamps, const double min_fraction_solid_background_color, @@ -93,13 +92,12 @@ void NormalizedRectToRect(const RectF& normalized_location, const int width, // Clamps a rectangle to lie within [x0, y0] and [x1, y1]. Returns true if the // rectangle has any overlapping with the target window. 
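AffineRetarget(), whose implementation appears in the utils.cc hunk above (the flattened diff dropped the vector element types, which are `cv::Mat` frames and 2x3 affine matrices), ultimately hands each matrix to cv::warpAffine. A standalone OpenCV sketch of that call, translating a crop window to the origin of the output canvas; the sizes and offsets are illustrative:

```cpp
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>

int main() {
  // A dummy 640x480 input frame.
  cv::Mat frame(480, 640, CV_8UC3, cv::Scalar(0, 0, 0));

  // 2x3 affine matrix that shifts the crop window at (100, 50) to the origin:
  // [1 0 -100]
  // [0 1  -50]
  cv::Mat affine = (cv::Mat_<double>(2, 3) << 1, 0, -100, 0, 1, -50);

  // Output canvas the size of the crop window.
  cv::Mat cropped;
  cv::warpAffine(frame, cropped, affine, cv::Size(320, 240));
  return 0;
}
```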
-::mediapipe::Status ClampRect(const int x0, const int y0, const int x1, - const int y1, Rect* location); +absl::Status ClampRect(const int x0, const int y0, const int x1, const int y1, + Rect* location); // Convenience function to clamp a rectangle to lie within [0, 0] and // [width, height]. -::mediapipe::Status ClampRect(const int width, const int height, - Rect* location); +absl::Status ClampRect(const int width, const int height, Rect* location); // Enlarges a given rectangle to cover a new rectangle to be added. void RectUnion(const Rect& rect_to_add, Rect* rect); @@ -107,10 +105,10 @@ void RectUnion(const Rect& rect_to_add, Rect* rect); // Performs an affine retarget on a list of input images. Output vector // cropped_frames must be filled with Mats of the same size as output_size and // type. -::mediapipe::Status AffineRetarget( - const cv::Size& output_size, const std::vector& frames, - const std::vector& affine_projection, - std::vector* cropped_frames); +absl::Status AffineRetarget(const cv::Size& output_size, + const std::vector& frames, + const std::vector& affine_projection, + std::vector* cropped_frames); } // namespace autoflip } // namespace mediapipe diff --git a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc index ce73cf5bf..9ae612004 100644 --- a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc @@ -48,9 +48,9 @@ void CropRectToMat(const cv::Mat& image, cv::Rect* rect) { VisualScorer::VisualScorer(const VisualScorerOptions& options) : options_(options) {} -mediapipe::Status VisualScorer::CalculateScore(const cv::Mat& image, - const SalientRegion& region, - float* score) const { +absl::Status VisualScorer::CalculateScore(const cv::Mat& image, + const SalientRegion& region, + float* score) const { const float weight_sum = options_.area_weight() + options_.sharpness_weight() + options_.colorfulness_weight(); @@ -67,14 +67,14 @@ mediapipe::Status VisualScorer::CalculateScore(const cv::Mat& image, region.location_normalized().width() * image.cols, region.location_normalized().height() * image.rows); } else { - return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Unset region location."; } CropRectToMat(image, ®ion_rect); if (region_rect.area() == 0) { *score = 0; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Compute a score based on area covered by this region. @@ -89,7 +89,7 @@ mediapipe::Status VisualScorer::CalculateScore(const cv::Mat& image, float sharpness_score_result = 0.0; if (options_.sharpness_weight() > kEpsilon) { // TODO: implement a sharpness score or remove this code block. - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "sharpness scorer is not yet implemented, please set weight to " "0.0"; } @@ -108,11 +108,11 @@ mediapipe::Status VisualScorer::CalculateScore(const cv::Mat& image, if (*score > 1.0f || *score < 0.0f) { LOG(WARNING) << "Score of region outside expected range: " << *score; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VisualScorer::CalculateColorfulness( - const cv::Mat& image, float* colorfulness) const { +absl::Status VisualScorer::CalculateColorfulness(const cv::Mat& image, + float* colorfulness) const { // Convert the image to HSV. 
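CalculateColorfulness(), whose signature changes above and whose body continues below, scores colorfulness as the entropy of a hue/saturation histogram, normalised by log(2) so the result is in bits. A standalone sketch of just that entropy step over an already-built histogram (building the HSV histogram itself is omitted here):

```cpp
#include <cmath>
#include <iostream>
#include <vector>

// Shannon entropy, in bits, of a histogram of non-negative bin counts.
// Returns 0 for an empty or all-zero histogram, mirroring the early returns
// in CalculateColorfulness().
float HistogramEntropyBits(const std::vector<float>& bins) {
  float total = 0.0f;
  for (float b : bins) total += b;
  if (total <= 0.0f) return 0.0f;

  float entropy = 0.0f;
  for (float b : bins) {
    if (b <= 0.0f) continue;
    const float p = b / total;
    entropy -= p * std::log(p);
  }
  return entropy / std::log(2.0f);
}

int main() {
  // A flat 4-bin histogram carries 2 bits of entropy; a single spike carries 0.
  std::cout << HistogramEntropyBits({1, 1, 1, 1}) << "\n";  // 2
  std::cout << HistogramEntropyBits({4, 0, 0, 0}) << "\n";  // 0
  return 0;
}
```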
cv::Mat image_hsv; cv::cvtColor(image, image_hsv, CV_RGB2HSV); @@ -134,7 +134,7 @@ mediapipe::Status VisualScorer::CalculateColorfulness( // If the mask is empty, return. if (empty_mask) { *colorfulness = 0; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Generate a 2D histogram (hue/saturation). @@ -162,7 +162,7 @@ mediapipe::Status VisualScorer::CalculateColorfulness( } if (hue_sum == 0.0f) { *colorfulness = 0; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Compute the histogram entropy. @@ -175,7 +175,7 @@ mediapipe::Status VisualScorer::CalculateColorfulness( } *colorfulness /= std::log(2.0f); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.h b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.h index aa2d2d20b..b2c6d3af7 100644 --- a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.h +++ b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.h @@ -30,13 +30,12 @@ class VisualScorer { explicit VisualScorer(const VisualScorerOptions& options); // Computes a score on a salientregion and returns a value [0...1]. - mediapipe::Status CalculateScore(const cv::Mat& image, - const SalientRegion& region, - float* score) const; + absl::Status CalculateScore(const cv::Mat& image, const SalientRegion& region, + float* score) const; private: - mediapipe::Status CalculateColorfulness(const cv::Mat& image, - float* colorfulness) const; + absl::Status CalculateColorfulness(const cv::Mat& image, + float* colorfulness) const; VisualScorerOptions options_; }; diff --git a/mediapipe/examples/desktop/autoflip/subgraph/BUILD b/mediapipe/examples/desktop/autoflip/subgraph/BUILD index 9af7e447b..6c3e2616c 100644 --- a/mediapipe/examples/desktop/autoflip/subgraph/BUILD +++ b/mediapipe/examples/desktop/autoflip/subgraph/BUILD @@ -18,14 +18,23 @@ licenses(["notice"]) package(default_visibility = ["//mediapipe/examples:__subpackages__"]) +FACE_DETECTION_DEPS = [ + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:tflite_converter_calculator", + "//mediapipe/calculators/tflite:tflite_inference_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_letterbox_removal_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator", +] + mediapipe_simple_subgraph( name = "autoflip_face_detection_subgraph", graph = "face_detection_subgraph.pbtxt", register_as = "AutoFlipFaceDetectionSubgraph", visibility = ["//visibility:public"], - deps = [ - "//mediapipe/graphs/face_detection:desktop_tflite_calculators", - ], + deps = FACE_DETECTION_DEPS, ) mediapipe_simple_subgraph( @@ -33,16 +42,7 @@ mediapipe_simple_subgraph( graph = "front_face_detection_subgraph.pbtxt", register_as = "AutoFlipFrontFaceDetectionSubgraph", visibility = ["//visibility:public"], - deps = [ - "//mediapipe/calculators/image:image_transformation_calculator", - "//mediapipe/calculators/tflite:ssd_anchors_calculator", - "//mediapipe/calculators/tflite:tflite_converter_calculator", - "//mediapipe/calculators/tflite:tflite_inference_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator", - 
"//mediapipe/calculators/util:detection_letterbox_removal_calculator", - "//mediapipe/calculators/util:non_max_suppression_calculator", - ], + deps = FACE_DETECTION_DEPS, ) mediapipe_simple_subgraph( diff --git a/mediapipe/examples/desktop/autoflip/subgraph/front_face_detection_subgraph.pbtxt b/mediapipe/examples/desktop/autoflip/subgraph/front_face_detection_subgraph.pbtxt index b88ea0c75..3b2d410f5 100644 --- a/mediapipe/examples/desktop/autoflip/subgraph/front_face_detection_subgraph.pbtxt +++ b/mediapipe/examples/desktop/autoflip/subgraph/front_face_detection_subgraph.pbtxt @@ -1,5 +1,5 @@ # MediaPipe graph that performs face detection with TensorFlow Lite on CPU. Model paths setup for web use. -# TODO: parameterize input paths to support desktop use. +# TODO: parameterize input paths to support desktop use, for web only. input_stream: "VIDEO:input_video" output_stream: "DETECTIONS:output_detections" @@ -37,7 +37,7 @@ node { output_stream: "TENSORS:detection_tensors" options: { [mediapipe.TfLiteInferenceCalculatorOptions.ext] { - model_path: "mediapipe/models/face_detection_front.tflite" + model_path: "face_detection_front.tflite" } } } @@ -118,7 +118,7 @@ node { output_stream: "labeled_detections" options: { [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] { - label_map_path: "mediapipe/models/face_detection_front_labelmap.txt" + label_map_path: "face_detection_front_labelmap.txt" } } } diff --git a/mediapipe/examples/desktop/demo_run_graph_main.cc b/mediapipe/examples/desktop/demo_run_graph_main.cc index 25f4bb4f1..343460eac 100644 --- a/mediapipe/examples/desktop/demo_run_graph_main.cc +++ b/mediapipe/examples/desktop/demo_run_graph_main.cc @@ -40,10 +40,11 @@ DEFINE_string(output_video_path, "", "Full path of where to save result (.mp4 only). " "If not provided, show result in a window."); -::mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = @@ -56,16 +57,16 @@ DEFINE_string(output_video_path, "", LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; - const bool load_video = !FLAGS_input_video_path.empty(); + const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { - capture.open(FLAGS_input_video_path); + capture.open(absl::GetFlag(FLAGS_input_video_path)); } else { capture.open(0); } RET_CHECK(capture.isOpened()); cv::VideoWriter writer; - const bool save_video = !FLAGS_output_video_path.empty(); + const bool save_video = !absl::GetFlag(FLAGS_output_video_path).empty(); if (!save_video) { cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); #if (CV_MAJOR_VERSION >= 3) && (CV_MINOR_VERSION >= 2) @@ -86,7 +87,14 @@ DEFINE_string(output_video_path, "", // Capture opencv camera or video frame. cv::Mat camera_frame_raw; capture >> camera_frame_raw; - if (camera_frame_raw.empty()) break; // End of video. 
+ if (camera_frame_raw.empty()) { + if (!load_video) { + LOG(INFO) << "Ignore empty frames from camera."; + continue; + } + LOG(INFO) << "Empty frame, end of video reached."; + break; + } cv::Mat camera_frame; cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB); if (!load_video) { @@ -118,7 +126,7 @@ DEFINE_string(output_video_path, "", if (save_video) { if (!writer.isOpened()) { LOG(INFO) << "Prepare video writer."; - writer.open(FLAGS_output_video_path, + writer.open(absl::GetFlag(FLAGS_output_video_path), mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 capture.get(cv::CAP_PROP_FPS), output_frame_mat.size()); RET_CHECK(writer.isOpened()); @@ -141,7 +149,7 @@ DEFINE_string(output_video_path, "", int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc index b77d8d4a3..6942971f7 100644 --- a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc +++ b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc @@ -44,10 +44,11 @@ DEFINE_string(output_video_path, "", "Full path of where to save result (.mp4 only). " "If not provided, show result in a window."); -::mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = @@ -66,16 +67,16 @@ DEFINE_string(output_video_path, "", LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; - const bool load_video = !FLAGS_input_video_path.empty(); + const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { - capture.open(FLAGS_input_video_path); + capture.open(absl::GetFlag(FLAGS_input_video_path)); } else { capture.open(0); } RET_CHECK(capture.isOpened()); cv::VideoWriter writer; - const bool save_video = !FLAGS_output_video_path.empty(); + const bool save_video = !absl::GetFlag(FLAGS_output_video_path).empty(); if (!save_video) { cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); #if (CV_MAJOR_VERSION >= 3) && (CV_MINOR_VERSION >= 2) @@ -96,16 +97,23 @@ DEFINE_string(output_video_path, "", // Capture opencv camera or video frame. cv::Mat camera_frame_raw; capture >> camera_frame_raw; - if (camera_frame_raw.empty()) break; // End of video. + if (camera_frame_raw.empty()) { + if (!load_video) { + LOG(INFO) << "Ignore empty frames from camera."; + continue; + } + LOG(INFO) << "Empty frame, end of video reached."; + break; + } cv::Mat camera_frame; - cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB); + cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGBA); if (!load_video) { cv::flip(camera_frame, camera_frame, /*flipcode=HORIZONTAL*/ 1); } // Wrap Mat into an ImageFrame. 
auto input_frame = absl::make_unique( - mediapipe::ImageFormat::SRGB, camera_frame.cols, camera_frame.rows, + mediapipe::ImageFormat::SRGBA, camera_frame.cols, camera_frame.rows, mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get()); camera_frame.copyTo(input_frame_mat); @@ -115,7 +123,7 @@ DEFINE_string(output_video_path, "", (double)cv::getTickCount() / (double)cv::getTickFrequency() * 1e6; MP_RETURN_IF_ERROR( gpu_helper.RunInGlContext([&input_frame, &frame_timestamp_us, &graph, - &gpu_helper]() -> ::mediapipe::Status { + &gpu_helper]() -> absl::Status { // Convert ImageFrame to GpuBuffer. auto texture = gpu_helper.CreateSourceTexture(*input_frame.get()); auto gpu_frame = texture.GetFrame(); @@ -125,7 +133,7 @@ DEFINE_string(output_video_path, "", MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( kInputStream, mediapipe::Adopt(gpu_frame.release()) .At(mediapipe::Timestamp(frame_timestamp_us)))); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); // Get the graph result packet, or stop if that fails. @@ -135,7 +143,7 @@ DEFINE_string(output_video_path, "", // Convert GpuBuffer to ImageFrame. MP_RETURN_IF_ERROR(gpu_helper.RunInGlContext( - [&packet, &output_frame, &gpu_helper]() -> ::mediapipe::Status { + [&packet, &output_frame, &gpu_helper]() -> absl::Status { auto& gpu_frame = packet.Get(); auto texture = gpu_helper.CreateSourceTexture(gpu_frame); output_frame = absl::make_unique( @@ -143,22 +151,25 @@ DEFINE_string(output_video_path, "", gpu_frame.width(), gpu_frame.height(), mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); gpu_helper.BindFramebuffer(texture); - const auto info = - mediapipe::GlTextureInfoForGpuBufferFormat(gpu_frame.format(), 0); + const auto info = mediapipe::GlTextureInfoForGpuBufferFormat( + gpu_frame.format(), 0, gpu_helper.GetGlVersion()); glReadPixels(0, 0, texture.width(), texture.height(), info.gl_format, info.gl_type, output_frame->MutablePixelData()); glFlush(); texture.Release(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); })); // Convert back to opencv for display or saving. 
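In the GPU demo the camera frame is now converted to RGBA and wrapped in an SRGBA ImageFrame before upload, and the upload itself happens inside gpu_helper.RunInGlContext() with a lambda that reports absl::Status. A condensed sketch of that path, using only calls that appear in demo_run_graph_main_gpu.cc; the "input_video" stream name comes from the demo, and unchanged context lines such as glFlush()/texture.Release() are filled in from the surrounding file, so treat the exact ordering as approximate:

```c++
#include <memory>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/gpu/gl_calculator_helper.h"

// Wraps an RGBA cv::Mat in an SRGBA ImageFrame, uploads it inside the GL context,
// and sends it into the graph as a GpuBuffer packet (condensed from the demo loop).
absl::Status SendRgbaFrame(mediapipe::CalculatorGraph& graph,
                           mediapipe::GlCalculatorHelper& gpu_helper,
                           const cv::Mat& camera_frame,  // already RGBA
                           size_t frame_timestamp_us) {
  auto input_frame = absl::make_unique<mediapipe::ImageFrame>(
      mediapipe::ImageFormat::SRGBA, camera_frame.cols, camera_frame.rows,
      mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
  cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get());
  camera_frame.copyTo(input_frame_mat);

  return gpu_helper.RunInGlContext(
      [&input_frame, &graph, &gpu_helper, frame_timestamp_us]() -> absl::Status {
        auto texture = gpu_helper.CreateSourceTexture(*input_frame.get());
        auto gpu_frame = texture.GetFrame();
        glFlush();
        texture.Release();
        return graph.AddPacketToInputStream(
            "input_video", mediapipe::Adopt(gpu_frame.release())
                               .At(mediapipe::Timestamp(frame_timestamp_us)));
      });
}
```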
cv::Mat output_frame_mat = mediapipe::formats::MatView(output_frame.get()); - cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); + if (output_frame_mat.channels() == 4) + cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGBA2BGR); + else + cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); if (save_video) { if (!writer.isOpened()) { LOG(INFO) << "Prepare video writer."; - writer.open(FLAGS_output_video_path, + writer.open(absl::GetFlag(FLAGS_output_video_path), mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 capture.get(cv::CAP_PROP_FPS), output_frame_mat.size()); RET_CHECK(writer.isOpened()); @@ -181,7 +192,7 @@ DEFINE_string(output_video_path, "", int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/face_detection/BUILD b/mediapipe/examples/desktop/face_detection/BUILD index 55c9eb741..5743ae788 100644 --- a/mediapipe/examples/desktop/face_detection/BUILD +++ b/mediapipe/examples/desktop/face_detection/BUILD @@ -20,7 +20,7 @@ cc_binary( name = "face_detection_cpu", deps = [ "//mediapipe/examples/desktop:demo_run_graph_main", - "//mediapipe/graphs/face_detection:desktop_tflite_calculators", + "//mediapipe/graphs/face_detection:desktop_live_calculators", ], ) @@ -29,6 +29,6 @@ cc_binary( name = "face_detection_gpu", deps = [ "//mediapipe/examples/desktop:demo_run_graph_main_gpu", - "//mediapipe/graphs/face_detection:mobile_calculators", + "//mediapipe/graphs/face_detection:desktop_live_gpu_calculators", ], ) diff --git a/mediapipe/examples/desktop/hello_world/hello_world.cc b/mediapipe/examples/desktop/hello_world/hello_world.cc index b7dfa40c3..95c34146d 100644 --- a/mediapipe/examples/desktop/hello_world/hello_world.cc +++ b/mediapipe/examples/desktop/hello_world/hello_world.cc @@ -21,7 +21,7 @@ namespace mediapipe { -::mediapipe::Status PrintHelloWorld() { +absl::Status PrintHelloWorld() { // Configures a simple graph, which concatenates 2 PassThroughCalculators. CalculatorGraphConfig config = ParseTextProtoOrDie(R"( input_stream: "in" diff --git a/mediapipe/examples/desktop/multi_hand_tracking/BUILD b/mediapipe/examples/desktop/holistic_tracking/BUILD similarity index 63% rename from mediapipe/examples/desktop/multi_hand_tracking/BUILD rename to mediapipe/examples/desktop/holistic_tracking/BUILD index a7bd112ff..0f69c1e4f 100644 --- a/mediapipe/examples/desktop/multi_hand_tracking/BUILD +++ b/mediapipe/examples/desktop/holistic_tracking/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2020 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
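PrintHelloWorld() above builds its graph from an inline text proto. For reference, a config of that shape, two PassThroughCalculators chained back to back, looks roughly like the following; stream and node names are indicative rather than a verbatim copy of hello_world.cc:

```c++
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// "in" -> PassThroughCalculator -> "out1" -> PassThroughCalculator -> "out".
mediapipe::CalculatorGraphConfig MakeHelloWorldConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "in"
    output_stream: "out"
    node {
      calculator: "PassThroughCalculator"
      input_stream: "in"
      output_stream: "out1"
    }
    node {
      calculator: "PassThroughCalculator"
      input_stream: "out1"
      output_stream: "out"
    }
  )pb");
}
```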
@@ -17,26 +17,18 @@ licenses(["notice"]) package(default_visibility = ["//mediapipe/examples:__subpackages__"]) cc_binary( - name = "multi_hand_tracking_tflite", - deps = [ - "//mediapipe/examples/desktop:simple_run_graph_main", - "//mediapipe/graphs/hand_tracking:multi_hand_desktop_tflite_calculators", - ], -) - -cc_binary( - name = "multi_hand_tracking_cpu", + name = "holistic_tracking_cpu", deps = [ "//mediapipe/examples/desktop:demo_run_graph_main", - "//mediapipe/graphs/hand_tracking:multi_hand_desktop_tflite_calculators", + "//mediapipe/graphs/holistic_tracking:holistic_tracking_cpu_graph_deps", ], ) # Linux only cc_binary( - name = "multi_hand_tracking_gpu", + name = "holistic_tracking_gpu", deps = [ "//mediapipe/examples/desktop:demo_run_graph_main_gpu", - "//mediapipe/graphs/hand_tracking:multi_hand_mobile_calculators", + "//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu_deps", ], ) diff --git a/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc b/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc index 4cfab621d..515ee37b0 100644 --- a/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc +++ b/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc @@ -47,25 +47,23 @@ DEFINE_string(output_image_path, "", namespace { -::mediapipe::StatusOr ReadFileToString( - const std::string& file_path) { +absl::StatusOr ReadFileToString(const std::string& file_path) { std::string contents; - MP_RETURN_IF_ERROR(::mediapipe::file::GetContents(file_path, &contents)); + MP_RETURN_IF_ERROR(mediapipe::file::GetContents(file_path, &contents)); return contents; } -::mediapipe::Status ProcessImage( - std::unique_ptr<::mediapipe::CalculatorGraph> graph) { +absl::Status ProcessImage(std::unique_ptr graph) { LOG(INFO) << "Load the image."; ASSIGN_OR_RETURN(const std::string raw_image, - ReadFileToString(FLAGS_input_image_path)); + ReadFileToString(absl::GetFlag(FLAGS_input_image_path))); LOG(INFO) << "Start running the calculator graph."; - ASSIGN_OR_RETURN(::mediapipe::OutputStreamPoller output_image_poller, + ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller output_image_poller, graph->AddOutputStreamPoller(kOutputImageStream)); - ASSIGN_OR_RETURN(::mediapipe::OutputStreamPoller left_iris_depth_poller, + ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller left_iris_depth_poller, graph->AddOutputStreamPoller(kLeftIrisDepthMmStream)); - ASSIGN_OR_RETURN(::mediapipe::OutputStreamPoller right_iris_depth_poller, + ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller right_iris_depth_poller, graph->AddOutputStreamPoller(kRightIrisDepthMmStream)); MP_RETURN_IF_ERROR(graph->StartRun({})); @@ -74,22 +72,22 @@ namespace { (double)cv::getTickFrequency() * kMicrosPerSecond; MP_RETURN_IF_ERROR(graph->AddPacketToInputStream( - kInputStream, ::mediapipe::MakePacket(raw_image).At( - ::mediapipe::Timestamp(fake_timestamp_us)))); + kInputStream, mediapipe::MakePacket(raw_image).At( + mediapipe::Timestamp(fake_timestamp_us)))); // Get the graph result packets, or stop if that fails. 
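ReadFileToString() above is a typical absl::StatusOr producer, and ProcessImage() consumes it with ASSIGN_OR_RETURN; elsewhere in this change, output_packet.ValueOrDie() becomes output_packet.value() for the same reason. A small self-contained sketch of the pattern (ComputeSquareRoot is made up purely for illustration):

```c++
#include <cmath>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// A StatusOr producer: callers get either a value or an error, never both.
absl::StatusOr<double> ComputeSquareRoot(double x) {
  if (x < 0) return absl::InvalidArgumentError("x must be non-negative");
  return std::sqrt(x);
}

absl::Status UseSquareRoot() {
  absl::StatusOr<double> root_or = ComputeSquareRoot(2.0);
  if (!root_or.ok()) return root_or.status();  // roughly what ASSIGN_OR_RETURN expands to
  const double root = root_or.value();         // .value() replaces the older ValueOrDie()
  (void)root;
  return absl::OkStatus();
}
```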
- ::mediapipe::Packet left_iris_depth_packet; + mediapipe::Packet left_iris_depth_packet; if (!left_iris_depth_poller.Next(&left_iris_depth_packet)) { - return ::mediapipe::UnknownError( + return absl::UnknownError( "Failed to get packet from output stream 'left_iris_depth_mm'."); } const auto& left_iris_depth_mm = left_iris_depth_packet.Get(); const int left_iris_depth_cm = std::round(left_iris_depth_mm / 10); std::cout << "Left Iris Depth: " << left_iris_depth_cm << " cm." << std::endl; - ::mediapipe::Packet right_iris_depth_packet; + mediapipe::Packet right_iris_depth_packet; if (!right_iris_depth_poller.Next(&right_iris_depth_packet)) { - return ::mediapipe::UnknownError( + return absl::UnknownError( "Failed to get packet from output stream 'right_iris_depth_mm'."); } const auto& right_iris_depth_mm = right_iris_depth_packet.Get(); @@ -97,20 +95,20 @@ namespace { std::cout << "Right Iris Depth: " << right_iris_depth_cm << " cm." << std::endl; - ::mediapipe::Packet output_image_packet; + mediapipe::Packet output_image_packet; if (!output_image_poller.Next(&output_image_packet)) { - return ::mediapipe::UnknownError( + return absl::UnknownError( "Failed to get packet from output stream 'output_image'."); } - auto& output_frame = output_image_packet.Get<::mediapipe::ImageFrame>(); + auto& output_frame = output_image_packet.Get(); // Convert back to opencv for display or saving. - cv::Mat output_frame_mat = ::mediapipe::formats::MatView(&output_frame); + cv::Mat output_frame_mat = mediapipe::formats::MatView(&output_frame); cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); - const bool save_image = !FLAGS_output_image_path.empty(); + const bool save_image = !absl::GetFlag(FLAGS_output_image_path).empty(); if (save_image) { LOG(INFO) << "Saving image to file..."; - cv::imwrite(FLAGS_output_image_path, output_frame_mat); + cv::imwrite(absl::GetFlag(FLAGS_output_image_path), output_frame_mat); } else { cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); cv::imshow(kWindowName, output_frame_mat); @@ -123,26 +121,26 @@ namespace { return graph->WaitUntilDone(); } -::mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; - MP_RETURN_IF_ERROR(::mediapipe::file::GetContents( + MP_RETURN_IF_ERROR(mediapipe::file::GetContents( kCalculatorGraphConfigFile, &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; - ::mediapipe::CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie<::mediapipe::CalculatorGraphConfig>( + mediapipe::CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); LOG(INFO) << "Initialize the calculator graph."; - std::unique_ptr<::mediapipe::CalculatorGraph> graph = - absl::make_unique<::mediapipe::CalculatorGraph>(); + std::unique_ptr graph = + absl::make_unique(); MP_RETURN_IF_ERROR(graph->Initialize(config)); - const bool load_image = !FLAGS_input_image_path.empty(); + const bool load_image = !absl::GetFlag(FLAGS_input_image_path).empty(); if (load_image) { return ProcessImage(std::move(graph)); } else { - return ::mediapipe::InvalidArgumentError("Missing image file."); + return absl::InvalidArgumentError("Missing image file."); } } @@ -151,7 +149,7 @@ namespace { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); 
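The iris-depth tool above follows the usual poller workflow: request OutputStreamPollers before StartRun(), push one input packet, then call Next() on each poller for results. Stripped of the image handling, the skeleton looks like this; the "left_iris_depth_mm" stream name appears above, while the input stream name is illustrative:

```c++
#include <iostream>
#include <string>

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"

// Skeleton of a single-image run: poller first, then StartRun, then one packet in.
absl::Status RunOnce(mediapipe::CalculatorGraph* graph, const std::string& raw_image) {
  ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller depth_poller,
                   graph->AddOutputStreamPoller("left_iris_depth_mm"));
  MP_RETURN_IF_ERROR(graph->StartRun({}));

  MP_RETURN_IF_ERROR(graph->AddPacketToInputStream(
      "input_image_bytes",  // illustrative input stream name
      mediapipe::MakePacket<std::string>(raw_image).At(mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph->CloseInputStream("input_image_bytes"));

  mediapipe::Packet depth_packet;
  if (!depth_poller.Next(&depth_packet)) {
    return absl::UnknownError("No packet on 'left_iris_depth_mm'.");
  }
  std::cout << "Left iris depth: " << depth_packet.Get<float>() << " mm" << std::endl;

  return graph->WaitUntilDone();
}
```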
if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc index a9a2456be..a15f599d1 100644 --- a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc +++ b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc @@ -38,27 +38,28 @@ DEFINE_string(output_side_packets, "", "side packets and paths to write to disk for the " "CalculatorGraph."); -::mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); - std::map input_side_packets; + std::map input_side_packets; std::vector kv_pairs = - absl::StrSplit(FLAGS_input_side_packets, ','); + absl::StrSplit(absl::GetFlag(FLAGS_input_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); - RET_CHECK(!::mediapipe::ContainsKey(input_side_packets, name_and_value[0])); + RET_CHECK(!mediapipe::ContainsKey(input_side_packets, name_and_value[0])); std::string input_side_packet_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( name_and_value[1], &input_side_packet_contents)); input_side_packets[name_and_value[0]] = - ::mediapipe::MakePacket(input_side_packet_contents); + mediapipe::MakePacket(input_side_packet_contents); } LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; @@ -66,26 +67,26 @@ DEFINE_string(output_side_packets, "", LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.Run()); LOG(INFO) << "Gathering output side packets."; - kv_pairs = absl::StrSplit(FLAGS_output_side_packets, ','); + kv_pairs = absl::StrSplit(absl::GetFlag(FLAGS_output_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); - ::mediapipe::StatusOr<::mediapipe::Packet> output_packet = + absl::StatusOr output_packet = graph.GetOutputSidePacket(name_and_value[0]); RET_CHECK(output_packet.ok()) << "Packet " << name_and_value[0] << " was not available."; const std::string& serialized_string = - output_packet.ValueOrDie().Get(); + output_packet.value().Get(); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(name_and_value[1], serialized_string)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/object_detection_3d/BUILD b/mediapipe/examples/desktop/object_detection_3d/BUILD new file mode 100644 index 000000000..86e29a728 --- /dev/null +++ b/mediapipe/examples/desktop/object_detection_3d/BUILD @@ -0,0 +1,34 @@ +# Copyright 2020 The MediaPipe 
Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +# bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/object_detection_3d:objectron_cpu +# To run 3D object detection for shoes, +# bazel-bin/mediapipe/examples/desktop/object_detection_3d/objectron_cpu \ +# --calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \ +# --input_side_packets="input_video_path=,box_landmark_model_path=mediapipe/models/object_detection_3d_sneakers.tflite,output_video_path=,allowed_labels=Footwear" +# To detect objects from other categories, change box_landmark_model_path and allowed_labels accordingly. +# Chair: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_chair.tflite,allowed_labels=Chair +# Camera: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_camera.tflite,allowed_labels=Camera +# Cup: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_cup.tflite,allowed_labels=Mug +cc_binary( + name = "objectron_cpu", + deps = [ + "//mediapipe/examples/desktop:simple_run_graph_main", + "//mediapipe/graphs/object_detection_3d:desktop_cpu_calculators", + ], +) diff --git a/mediapipe/examples/desktop/pose_tracking/BUILD b/mediapipe/examples/desktop/pose_tracking/BUILD new file mode 100644 index 000000000..447e2dfdc --- /dev/null +++ b/mediapipe/examples/desktop/pose_tracking/BUILD @@ -0,0 +1,34 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +cc_binary( + name = "pose_tracking_cpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main", + "//mediapipe/graphs/pose_tracking:pose_tracking_cpu_deps", + ], +) + +# Linux only +cc_binary( + name = "pose_tracking_gpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main_gpu", + "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", + ], +) diff --git a/mediapipe/examples/desktop/simple_run_graph_main.cc b/mediapipe/examples/desktop/simple_run_graph_main.cc index 2b76de6a5..5d33af66c 100644 --- a/mediapipe/examples/desktop/simple_run_graph_main.cc +++ b/mediapipe/examples/desktop/simple_run_graph_main.cc @@ -58,31 +58,29 @@ DEFINE_string(output_side_packets_file, "", "The name of the local file to output all side packets specified " "with --output_side_packets. 
"); -::mediapipe::Status OutputStreamToLocalFile( - ::mediapipe::OutputStreamPoller& poller) { +absl::Status OutputStreamToLocalFile(mediapipe::OutputStreamPoller& poller) { std::ofstream file; - file.open(FLAGS_output_stream_file); - ::mediapipe::Packet packet; + file.open(absl::GetFlag(FLAGS_output_stream_file)); + mediapipe::Packet packet; while (poller.Next(&packet)) { std::string output_data; - if (!FLAGS_strip_timestamps) { + if (!absl::GetFlag(FLAGS_strip_timestamps)) { absl::StrAppend(&output_data, packet.Timestamp().Value(), ","); } absl::StrAppend(&output_data, packet.Get(), "\n"); file << output_data; } file.close(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status OutputSidePacketsToLocalFile( - ::mediapipe::CalculatorGraph& graph) { - if (!FLAGS_output_side_packets.empty() && - !FLAGS_output_side_packets_file.empty()) { +absl::Status OutputSidePacketsToLocalFile(mediapipe::CalculatorGraph& graph) { + if (!absl::GetFlag(FLAGS_output_side_packets).empty() && + !absl::GetFlag(FLAGS_output_side_packets_file).empty()) { std::ofstream file; - file.open(FLAGS_output_side_packets_file); + file.open(absl::GetFlag(FLAGS_output_side_packets_file)); std::vector side_packet_names = - absl::StrSplit(FLAGS_output_side_packets, ','); + absl::StrSplit(absl::GetFlag(FLAGS_output_side_packets), ','); for (const std::string& side_packet_name : side_packet_names) { ASSIGN_OR_RETURN(auto status_or_packet, graph.GetOutputSidePacket(side_packet_name)); @@ -91,47 +89,49 @@ DEFINE_string(output_side_packets_file, "", } file.close(); } else { - RET_CHECK(FLAGS_output_side_packets.empty() && - FLAGS_output_side_packets_file.empty()) + RET_CHECK(absl::GetFlag(FLAGS_output_side_packets).empty() && + absl::GetFlag(FLAGS_output_side_packets_file).empty()) << "--output_side_packets and --output_side_packets_file should be " "specified in pair."; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; - MP_RETURN_IF_ERROR(::mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + MP_RETURN_IF_ERROR(mediapipe::file::GetContents( + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; - ::mediapipe::CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie<::mediapipe::CalculatorGraphConfig>( + mediapipe::CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); - std::map input_side_packets; - if (!FLAGS_input_side_packets.empty()) { + std::map input_side_packets; + if (!absl::GetFlag(FLAGS_input_side_packets).empty()) { std::vector kv_pairs = - absl::StrSplit(FLAGS_input_side_packets, ','); + absl::StrSplit(absl::GetFlag(FLAGS_input_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); - RET_CHECK( - !::mediapipe::ContainsKey(input_side_packets, name_and_value[0])); + RET_CHECK(!mediapipe::ContainsKey(input_side_packets, name_and_value[0])); input_side_packets[name_and_value[0]] = - ::mediapipe::MakePacket(name_and_value[1]); + mediapipe::MakePacket(name_and_value[1]); } } LOG(INFO) << "Initialize the calculator graph."; - ::mediapipe::CalculatorGraph graph; + mediapipe::CalculatorGraph graph; 
MP_RETURN_IF_ERROR(graph.Initialize(config, input_side_packets)); - if (!FLAGS_output_stream.empty() && !FLAGS_output_stream_file.empty()) { - ASSIGN_OR_RETURN(auto poller, - graph.AddOutputStreamPoller(FLAGS_output_stream)); + if (!absl::GetFlag(FLAGS_output_stream).empty() && + !absl::GetFlag(FLAGS_output_stream_file).empty()) { + ASSIGN_OR_RETURN(auto poller, graph.AddOutputStreamPoller( + absl::GetFlag(FLAGS_output_stream))); LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.StartRun({})); MP_RETURN_IF_ERROR(OutputStreamToLocalFile(poller)); } else { - RET_CHECK(FLAGS_output_stream.empty() && FLAGS_output_stream_file.empty()) + RET_CHECK(absl::GetFlag(FLAGS_output_stream).empty() && + absl::GetFlag(FLAGS_output_stream_file).empty()) << "--output_stream and --output_stream_file should be specified in " "pair."; LOG(INFO) << "Start running the calculator graph."; @@ -144,7 +144,7 @@ DEFINE_string(output_side_packets_file, "", int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/youtube8m/BUILD b/mediapipe/examples/desktop/youtube8m/BUILD index af85e3113..e6347b243 100644 --- a/mediapipe/examples/desktop/youtube8m/BUILD +++ b/mediapipe/examples/desktop/youtube8m/BUILD @@ -21,7 +21,6 @@ cc_binary( "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:map_util", diff --git a/mediapipe/examples/desktop/youtube8m/README.md b/mediapipe/examples/desktop/youtube8m/README.md index 57a606fd6..94b36809d 100644 --- a/mediapipe/examples/desktop/youtube8m/README.md +++ b/mediapipe/examples/desktop/youtube8m/README.md @@ -1,7 +1,7 @@ ### Steps to run the YouTube-8M feature extraction graph 1. Checkout the repository and follow - [the installation instructions](https://github.com/google/mediapipe/blob/master/mediapipe/docs/install.md) + [the installation instructions](https://github.com/google/mediapipe/blob/master/docs/getting_started/install.md) to set up MediaPipe. 
```bash diff --git a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc index 593fb187d..a303077cc 100644 --- a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc +++ b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc @@ -39,27 +39,28 @@ DEFINE_string(output_side_packets, "", "side packets and paths to write to disk for the " "CalculatorGraph."); -::mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); - std::map input_side_packets; + std::map input_side_packets; std::vector kv_pairs = - absl::StrSplit(FLAGS_input_side_packets, ','); + absl::StrSplit(absl::GetFlag(FLAGS_input_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); - RET_CHECK(!::mediapipe::ContainsKey(input_side_packets, name_and_value[0])); + RET_CHECK(!mediapipe::ContainsKey(input_side_packets, name_and_value[0])); std::string input_side_packet_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( name_and_value[1], &input_side_packet_contents)); input_side_packets[name_and_value[0]] = - ::mediapipe::MakePacket(input_side_packet_contents); + mediapipe::MakePacket(input_side_packet_contents); } mediapipe::MatrixData inc3_pca_mean_matrix_data, @@ -75,7 +76,7 @@ DEFINE_string(output_side_packets, "", mediapipe::MatrixFromMatrixDataProto(inc3_pca_mean_matrix_data, &inc3_pca_mean_matrix); input_side_packets["inception3_pca_mean_matrix"] = - ::mediapipe::MakePacket(inc3_pca_mean_matrix); + mediapipe::MakePacket(inc3_pca_mean_matrix); MP_RETURN_IF_ERROR(mediapipe::file::GetContents( "/tmp/mediapipe/inception3_projection_matrix_data.pb", &content)); @@ -83,7 +84,7 @@ DEFINE_string(output_side_packets, "", mediapipe::MatrixFromMatrixDataProto(inc3_pca_projection_matrix_data, &inc3_pca_projection_matrix); input_side_packets["inception3_pca_projection_matrix"] = - ::mediapipe::MakePacket(inc3_pca_projection_matrix); + mediapipe::MakePacket(inc3_pca_projection_matrix); MP_RETURN_IF_ERROR(mediapipe::file::GetContents( "/tmp/mediapipe/vggish_mean_matrix_data.pb", &content)); @@ -91,7 +92,7 @@ DEFINE_string(output_side_packets, "", mediapipe::MatrixFromMatrixDataProto(vggish_pca_mean_matrix_data, &vggish_pca_mean_matrix); input_side_packets["vggish_pca_mean_matrix"] = - ::mediapipe::MakePacket(vggish_pca_mean_matrix); + mediapipe::MakePacket(vggish_pca_mean_matrix); MP_RETURN_IF_ERROR(mediapipe::file::GetContents( "/tmp/mediapipe/vggish_projection_matrix_data.pb", &content)); @@ -99,7 +100,7 @@ DEFINE_string(output_side_packets, "", mediapipe::MatrixFromMatrixDataProto(vggish_pca_projection_matrix_data, &vggish_pca_projection_matrix); input_side_packets["vggish_pca_projection_matrix"] = - ::mediapipe::MakePacket(vggish_pca_projection_matrix); + mediapipe::MakePacket(vggish_pca_projection_matrix); LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; @@ -107,26 +108,26 @@ DEFINE_string(output_side_packets, "", LOG(INFO) << "Start 
running the calculator graph."; MP_RETURN_IF_ERROR(graph.Run()); LOG(INFO) << "Gathering output side packets."; - kv_pairs = absl::StrSplit(FLAGS_output_side_packets, ','); + kv_pairs = absl::StrSplit(absl::GetFlag(FLAGS_output_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); - ::mediapipe::StatusOr<::mediapipe::Packet> output_packet = + absl::StatusOr output_packet = graph.GetOutputSidePacket(name_and_value[0]); RET_CHECK(output_packet.ok()) << "Packet " << name_and_value[0] << " was not available."; const std::string& serialized_string = - output_packet.ValueOrDie().Get(); + output_packet.value().Get(); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(name_and_value[1], serialized_string)); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/ios/common/CommonViewController.h b/mediapipe/examples/ios/common/CommonViewController.h index b4650423b..d7cb1121a 100644 --- a/mediapipe/examples/ios/common/CommonViewController.h +++ b/mediapipe/examples/ios/common/CommonViewController.h @@ -18,6 +18,7 @@ #import "mediapipe/objc/MPPGraph.h" #import "mediapipe/objc/MPPLayerRenderer.h" #import "mediapipe/objc/MPPPlayerInputSource.h" +#import "mediapipe/objc/MPPTimestampConverter.h" typedef NS_ENUM(NSInteger, MediaPipeDemoSourceMode) { MediaPipeDemoSourceCamera, @@ -36,6 +37,9 @@ typedef NS_ENUM(NSInteger, MediaPipeDemoSourceMode) { // Provides data from a video. @property(nonatomic) MPPPlayerInputSource* videoSource; +// Helps to convert timestamp. +@property(nonatomic) MPPTimestampConverter* timestampConverter; + // The data source for the demo. 
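The YouTube-8M extractor above preloads four PCA matrices from serialized MatrixData protos and injects them as mediapipe::Matrix side packets. In isolation, one such step looks roughly like the sketch below; the /tmp path is a guess in the spirit of the paths shown above, and only the inception3_pca_mean_matrix side-packet name is taken from the code:

```c++
#include <map>
#include <string>

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/matrix_data.pb.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/status.h"

absl::Status LoadPcaMeanMatrix(
    std::map<std::string, mediapipe::Packet>* input_side_packets) {
  std::string content;
  MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
      "/tmp/mediapipe/inception3_mean_matrix_data.pb", &content));  // illustrative path

  mediapipe::MatrixData matrix_data;
  if (!matrix_data.ParseFromString(content)) {
    return absl::InvalidArgumentError("Could not parse MatrixData proto.");
  }

  // mediapipe::Matrix is an Eigen matrix; the helper copies the proto contents into it.
  mediapipe::Matrix mean_matrix;
  mediapipe::MatrixFromMatrixDataProto(matrix_data, &mean_matrix);

  (*input_side_packets)["inception3_pca_mean_matrix"] =
      mediapipe::MakePacket<mediapipe::Matrix>(mean_matrix);
  return absl::OkStatus();
}
```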
@property(nonatomic) MediaPipeDemoSourceMode sourceMode; diff --git a/mediapipe/examples/ios/common/CommonViewController.mm b/mediapipe/examples/ios/common/CommonViewController.mm index aa7eb5d57..f6c47eacf 100644 --- a/mediapipe/examples/ios/common/CommonViewController.mm +++ b/mediapipe/examples/ios/common/CommonViewController.mm @@ -77,6 +77,8 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; [self.liveView.layer addSublayer:self.renderer.layer]; self.renderer.frameScaleMode = MPPFrameScaleModeFillAndCrop; + self.timestampConverter = [[MPPTimestampConverter alloc] init]; + dispatch_queue_attr_t qosAttribute = dispatch_queue_attr_make_with_qos_class( DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); self.videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); @@ -135,10 +137,10 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; [self.cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { if (granted) { - [self startGraphAndCamera]; dispatch_async(dispatch_get_main_queue(), ^{ self.noCameraLabel.hidden = YES; }); + [self startGraphAndCamera]; } }]; @@ -153,6 +155,9 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; if (![self.mediapipeGraph startWithError:&error]) { NSLog(@"Failed to start graph: %@", error); } + else if (![self.mediapipeGraph waitUntilIdleWithError:&error]) { + NSLog(@"Failed to complete graph initial run: %@", error); + } // Start fetching frames from the camera. dispatch_async(self.videoQueue, ^{ @@ -173,7 +178,8 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; [self.mediapipeGraph sendPixelBuffer:imageBuffer intoStream:self.graphInputStream - packetType:MPPPacketTypePixelBuffer]; + packetType:MPPPacketTypePixelBuffer + timestamp:[self.timestampConverter timestampForMediaTime:timestamp]]; } #pragma mark - MPPGraphDelegate methods diff --git a/mediapipe/examples/ios/facedetectioncpu/BUILD b/mediapipe/examples/ios/facedetectioncpu/BUILD index a4ae2cfca..43bff9b1e 100644 --- a/mediapipe/examples/ios/facedetectioncpu/BUILD +++ b/mediapipe/examples/ios/facedetectioncpu/BUILD @@ -54,9 +54,8 @@ ios_application( objc_library( name = "FaceDetectionCpuAppLibrary", data = [ - "//mediapipe/graphs/face_detection:mobile_cpu_binary_graph", - "//mediapipe/models:face_detection_front.tflite", - "//mediapipe/models:face_detection_front_labelmap.txt", + "//mediapipe/graphs/face_detection:face_detection_mobile_cpu.binarypb", + "//mediapipe/modules/face_detection:face_detection_front.tflite", ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", diff --git a/mediapipe/examples/ios/facedetectioncpu/Info.plist b/mediapipe/examples/ios/facedetectioncpu/Info.plist index d1738a5c7..34e1a7eee 100644 --- a/mediapipe/examples/ios/facedetectioncpu/Info.plist +++ b/mediapipe/examples/ios/facedetectioncpu/Info.plist @@ -9,6 +9,6 @@ GraphInputStream input_video GraphName - mobile_cpu + face_detection_mobile_cpu diff --git a/mediapipe/examples/ios/facedetectiongpu/BUILD b/mediapipe/examples/ios/facedetectiongpu/BUILD index 507ac45d8..51856a7f7 100644 --- a/mediapipe/examples/ios/facedetectiongpu/BUILD +++ b/mediapipe/examples/ios/facedetectiongpu/BUILD @@ -54,9 +54,8 @@ ios_application( objc_library( name = "FaceDetectionGpuAppLibrary", data = [ - "//mediapipe/graphs/face_detection:mobile_gpu_binary_graph", - "//mediapipe/models:face_detection_front.tflite", - 
"//mediapipe/models:face_detection_front_labelmap.txt", + "//mediapipe/graphs/face_detection:face_detection_mobile_gpu.binarypb", + "//mediapipe/modules/face_detection:face_detection_front.tflite", ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", diff --git a/mediapipe/examples/ios/facedetectiongpu/Info.plist b/mediapipe/examples/ios/facedetectiongpu/Info.plist index 6b4790734..45feefb45 100644 --- a/mediapipe/examples/ios/facedetectiongpu/Info.plist +++ b/mediapipe/examples/ios/facedetectiongpu/Info.plist @@ -9,6 +9,6 @@ GraphInputStream input_video GraphName - mobile_gpu + face_detection_mobile_gpu diff --git a/mediapipe/examples/ios/faceeffect/AppDelegate.h b/mediapipe/examples/ios/faceeffect/AppDelegate.h new file mode 100644 index 000000000..f7c48321c --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/mediapipe/examples/ios/faceeffect/AppDelegate.m b/mediapipe/examples/ios/faceeffect/AppDelegate.m new file mode 100644 index 000000000..42f9acd54 --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/AppDelegate.m @@ -0,0 +1,59 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + // Override point for customization after application launch. + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + // Sent when the application is about to move from active to inactive state. This can occur for + // certain types of temporary interruptions (such as an incoming phone call or SMS message) or + // when the user quits the application and it begins the transition to the background state. Use + // this method to pause ongoing tasks, disable timers, and invalidate graphics rendering + // callbacks. Games should use this method to pause the game. +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { + // Use this method to release shared resources, save user data, invalidate timers, and store + // enough application state information to restore your application to its current state in case + // it is terminated later. 
If your application supports background execution, this method is + // called instead of applicationWillTerminate: when the user quits. +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { + // Called as part of the transition from the background to the active state; here you can undo + // many of the changes made on entering the background. +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + // Restart any tasks that were paused (or not yet started) while the application was inactive. If + // the application was previously in the background, optionally refresh the user interface. +} + +- (void)applicationWillTerminate:(UIApplication *)application { + // Called when the application is about to terminate. Save data if appropriate. See also + // applicationDidEnterBackground:. +} + +@end diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD new file mode 100644 index 000000000..9e074ef2f --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -0,0 +1,119 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) + +MIN_IOS_VERSION = "10.0" + +alias( + name = "faceeffect", + actual = "FaceEffectApp", +) + +ios_application( + name = "FaceEffectApp", + app_icons = ["//mediapipe/examples/ios/common:AppIcon"], + bundle_id = BUNDLE_ID_PREFIX + ".FaceEffectGpu", + families = [ + "iphone", + "ipad", + ], + infoplists = ["Info.plist"], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = example_provisioning(), + deps = [ + ":FaceEffectAppLibrary", + "@ios_opencv//:OpencvFramework", + ], +) + +objc_library( + name = "FaceEffectViewController", + srcs = [ + "FaceEffectViewController.mm", + ], + hdrs = [ + "FaceEffectViewController.h", + ], + copts = ["-std=c++17"], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + "//mediapipe/graphs/face_effect:face_effect_gpu.binarypb", + "//mediapipe/graphs/face_effect/data:axis.binarypb", + "//mediapipe/graphs/face_effect/data:axis.pngblob", + "//mediapipe/graphs/face_effect/data:facepaint.pngblob", + "//mediapipe/graphs/face_effect/data:glasses.binarypb", + "//mediapipe/graphs/face_effect/data:glasses.pngblob", + "//mediapipe/modules/face_detection:face_detection_front.tflite", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata.binarypb", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_detection.binarypb", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_landmarks.binarypb", + "//mediapipe/modules/face_landmark:face_landmark.tflite", + ], + sdk_frameworks = [ + "AVFoundation", + "CoreGraphics", + "CoreMedia", + "UIKit", + ], + deps = [ + "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/objc:mediapipe_input_sources_ios", + 
"//mediapipe/objc:mediapipe_layer_renderer", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + "//mediapipe/framework/formats:matrix_data_cc_proto", + "//mediapipe/graphs/face_effect:face_effect_gpu_deps", + "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", + ], + }), +) + +objc_library( + name = "FaceEffectAppLibrary", + srcs = [ + "AppDelegate.m", + "main.m", + ], + hdrs = [ + "AppDelegate.h", + ], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + "//mediapipe/graphs/face_effect:face_effect_gpu.binarypb", + "//mediapipe/graphs/face_effect/data:facepaint.pngblob", + "//mediapipe/graphs/face_effect/data:glasses.binarypb", + "//mediapipe/graphs/face_effect/data:glasses.pngblob", + "//mediapipe/modules/face_detection:face_detection_front.tflite", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata.binarypb", + "//mediapipe/modules/face_landmark:face_landmark.tflite", + ], + deps = [ + ":FaceEffectViewController", + ], +) diff --git a/mediapipe/examples/ios/faceeffect/Base.lproj/LaunchScreen.storyboard b/mediapipe/examples/ios/faceeffect/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 000000000..bfa361294 --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/faceeffect/Base.lproj/Main.storyboard b/mediapipe/examples/ios/faceeffect/Base.lproj/Main.storyboard new file mode 100644 index 000000000..d7a79730e --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/Base.lproj/Main.storyboard @@ -0,0 +1,57 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/faceeffect/FaceEffectViewController.h b/mediapipe/examples/ios/faceeffect/FaceEffectViewController.h new file mode 100644 index 000000000..5d863cbf2 --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/FaceEffectViewController.h @@ -0,0 +1,19 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface FaceEffectViewController : UIViewController + +@end diff --git a/mediapipe/examples/ios/faceeffect/FaceEffectViewController.mm b/mediapipe/examples/ios/faceeffect/FaceEffectViewController.mm new file mode 100644 index 000000000..56a895c69 --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/FaceEffectViewController.mm @@ -0,0 +1,294 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "FaceEffectViewController.h" + +#import "mediapipe/objc/MPPCameraInputSource.h" +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPLayerRenderer.h" + +#include +#include +#include + +#include "mediapipe/framework/formats/matrix_data.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/modules/face_geometry/protos/face_geometry.pb.h" + +static NSString* const kGraphName = @"face_effect_gpu"; + +static const char* kInputStream = "input_video"; +static const char* kOutputStream = "output_video"; +static const char* kMultiFaceGeometryStream = "multi_face_geometry"; +static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; +static const char* kSelectedEffectIdInputStream = "selected_effect_id"; +static const char* kUseFaceDetectionInputSourceInputSidePacket = "use_face_detection_input_source"; + +static const BOOL kUseFaceDetectionInputSource = NO; +static const int kMatrixTranslationZIndex = 14; + +static const int kSelectedEffectIdAxis = 0; +static const int kSelectedEffectIdFacepaint = 1; +static const int kSelectedEffectIdGlasses = 2; + +@interface FaceEffectViewController () + +// The MediaPipe graph currently in use. Initialized in viewDidLoad, started in viewWillAppear: and +// sent video frames on _videoQueue. +@property(nonatomic) MPPGraph* graph; + +@end + +@implementation FaceEffectViewController { + /// Handle tap gestures. + UITapGestureRecognizer* _tapGestureRecognizer; + int _selectedEffectId; + + /// Handles camera access via AVCaptureSession library. + MPPCameraInputSource* _cameraSource; + + /// Inform the user when camera is unavailable. + IBOutlet UILabel* _noCameraLabel; + /// Inform the user about how to switch between effects. + UILabel* _effectSwitchingHintLabel; + /// Display the camera preview frames. + IBOutlet UIView* _liveView; + /// Render frames in a layer. + MPPLayerRenderer* _renderer; + + /// Process camera frames on this queue. + dispatch_queue_t _videoQueue; +} + +#pragma mark - Cleanup methods + +- (void)dealloc { + self.graph.delegate = nil; + [self.graph cancel]; + // Ignore errors since we're cleaning up. + [self.graph closeAllInputStreamsWithError:nil]; + [self.graph waitUntilDoneWithError:nil]; +} + +#pragma mark - MediaPipe graph methods + ++ (MPPGraph*)loadGraphFromResource:(NSString*)resource { + // Load the graph config resource. + NSError* configLoadError = nil; + NSBundle* bundle = [NSBundle bundleForClass:[self class]]; + if (!resource || resource.length == 0) { + return nil; + } + NSURL* graphURL = [bundle URLForResource:resource withExtension:@"binarypb"]; + NSData* data = [NSData dataWithContentsOfURL:graphURL options:0 error:&configLoadError]; + if (!data) { + NSLog(@"Failed to load MediaPipe graph config: %@", configLoadError); + return nil; + } + + // Parse the graph config resource into mediapipe::CalculatorGraphConfig proto object. + mediapipe::CalculatorGraphConfig config; + config.ParseFromArray(data.bytes, data.length); + + // Pass the kUseFaceDetectionInputSource flag value as an input side packet into the graph. + std::map side_packets; + side_packets[kUseFaceDetectionInputSourceInputSidePacket] = + mediapipe::MakePacket(kUseFaceDetectionInputSource); + + // Create MediaPipe graph with mediapipe::CalculatorGraphConfig proto object. 
+ MPPGraph* newGraph = [[MPPGraph alloc] initWithGraphConfig:config]; + [newGraph addSidePackets:side_packets]; + [newGraph addFrameOutputStream:kOutputStream outputPacketType:MPPPacketTypePixelBuffer]; + [newGraph addFrameOutputStream:kMultiFaceGeometryStream outputPacketType:MPPPacketTypeRaw]; + return newGraph; +} + +#pragma mark - UIViewController methods + +- (void)viewDidLoad { + [super viewDidLoad]; + + _effectSwitchingHintLabel.hidden = YES; + _tapGestureRecognizer = [[UITapGestureRecognizer alloc] initWithTarget:self + action:@selector(handleTap)]; + [self.view addGestureRecognizer:_tapGestureRecognizer]; + + // By default, render the axis effect for the face detection input source and the glasses effect + // for the face landmark input source. + if (kUseFaceDetectionInputSource) { + _selectedEffectId = kSelectedEffectIdAxis; + } else { + _selectedEffectId = kSelectedEffectIdGlasses; + } + + _renderer = [[MPPLayerRenderer alloc] init]; + _renderer.layer.frame = _liveView.layer.bounds; + [_liveView.layer insertSublayer:_renderer.layer atIndex:0]; + _renderer.frameScaleMode = MPPFrameScaleModeFillAndCrop; + _renderer.mirrored = NO; + + dispatch_queue_attr_t qosAttribute = dispatch_queue_attr_make_with_qos_class( + DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); + _videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); + + _cameraSource = [[MPPCameraInputSource alloc] init]; + [_cameraSource setDelegate:self queue:_videoQueue]; + _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; + _cameraSource.cameraPosition = AVCaptureDevicePositionFront; + // The frame's native format is rotated with respect to the portrait orientation. + _cameraSource.orientation = AVCaptureVideoOrientationPortrait; + _cameraSource.videoMirrored = YES; + + self.graph = [[self class] loadGraphFromResource:kGraphName]; + self.graph.delegate = self; + // Set maxFramesInFlight to a small value to avoid memory contention for real-time processing. + self.graph.maxFramesInFlight = 2; +} + +// In this application, there is only one ViewController which has no navigation to other view +// controllers, and there is only one View with live display showing the result of running the +// MediaPipe graph on the live video feed. If more view controllers are needed later, the graph +// setup/teardown and camera start/stop logic should be updated appropriately in response to the +// appearance/disappearance of this ViewController, as viewWillAppear: can be invoked multiple times +// depending on the application navigation flow in that case. +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; + + [_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { + if (granted) { + [self startGraphAndCamera]; + dispatch_async(dispatch_get_main_queue(), ^{ + _noCameraLabel.hidden = YES; + }); + } + }]; +} + +- (void)startGraphAndCamera { + // Start running self.graph. + NSError* error; + if (![self.graph startWithError:&error]) { + NSLog(@"Failed to start graph: %@", error); + } + + // Start fetching frames from the camera. + dispatch_async(_videoQueue, ^{ + [_cameraSource start]; + }); +} + +#pragma mark - UITapGestureRecognizer methods + +// We use the tap gesture recognizer to switch between face effects. This allows users to try +// multiple pre-bundled face effects without a need to recompile the app. +- (void)handleTap { + dispatch_async(_videoQueue, ^{ + // Avoid switching the Axis effect for the face detection input source. 
+ if (kUseFaceDetectionInputSource) { + return; + } + + // Looped effect order: glasses -> facepaint -> axis -> glasses -> ... + switch (_selectedEffectId) { + case kSelectedEffectIdAxis: { + _selectedEffectId = kSelectedEffectIdGlasses; + break; + } + + case kSelectedEffectIdFacepaint: { + _selectedEffectId = kSelectedEffectIdAxis; + break; + } + + case kSelectedEffectIdGlasses: { + _selectedEffectId = kSelectedEffectIdFacepaint; + break; + } + } + }); +} + +#pragma mark - MPPGraphDelegate methods + +// Receives CVPixelBufferRef from the MediaPipe graph. Invoked on a MediaPipe worker thread. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName { + if (streamName == kOutputStream) { + // Display the captured image on the screen. + CVPixelBufferRetain(pixelBuffer); + dispatch_async(dispatch_get_main_queue(), ^{ + _effectSwitchingHintLabel.hidden = kUseFaceDetectionInputSource; + [_renderer renderPixelBuffer:pixelBuffer]; + CVPixelBufferRelease(pixelBuffer); + }); + } +} + +// Receives a raw packet from the MediaPipe graph. Invoked on a MediaPipe worker thread. +// +// This callback demonstrates how the output face geometry packet can be obtained and used in an +// iOS app. As an example, the Z-translation component of the face pose transform matrix is logged +// for each face being equal to the approximate distance away from the camera in centimeters. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPacket:(const ::mediapipe::Packet&)packet + fromStream:(const std::string&)streamName { + if (streamName == kMultiFaceGeometryStream) { + if (packet.IsEmpty()) { + NSLog(@"[TS:%lld] No face geometry", packet.Timestamp().Value()); + return; + } + + const auto& multiFaceGeometry = + packet.Get>(); + NSLog(@"[TS:%lld] Number of face instances with geometry: %lu ", packet.Timestamp().Value(), + multiFaceGeometry.size()); + for (int faceIndex = 0; faceIndex < multiFaceGeometry.size(); ++faceIndex) { + const auto& faceGeometry = multiFaceGeometry[faceIndex]; + NSLog(@"\tApprox. distance away from camera for face[%d]: %.6f cm", faceIndex, + -faceGeometry.pose_transform_matrix().packed_data(kMatrixTranslationZIndex)); + } + } +} + +#pragma mark - MPPInputSourceDelegate methods + +// Must be invoked on _videoQueue. +- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source { + if (source != _cameraSource) { + NSLog(@"Unknown source: %@", source); + return; + } + + mediapipe::Timestamp graphTimestamp(static_cast( + mediapipe::Timestamp::kTimestampUnitsPerSecond * CMTimeGetSeconds(timestamp))); + + mediapipe::Packet selectedEffectIdPacket = + mediapipe::MakePacket(_selectedEffectId).At(graphTimestamp); + + [self.graph sendPixelBuffer:imageBuffer + intoStream:kInputStream + packetType:MPPPacketTypePixelBuffer + timestamp:graphTimestamp]; + + // Alongside the input camera frame, we also send the `selected_effect_id` int packet to indicate + // which effect should be rendered on this frame. + [self.graph movePacket:std::move(selectedEffectIdPacket) + intoStream:kSelectedEffectIdInputStream + error:nil]; +} + +@end diff --git a/mediapipe/examples/ios/faceeffect/Info.plist b/mediapipe/examples/ios/faceeffect/Info.plist new file mode 100644 index 000000000..30db14c62 --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/Info.plist @@ -0,0 +1,42 @@ + + + + + NSCameraUsageDescription + This app uses the camera to demonstrate live video processing. 
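The geometry delegate callback above logs -packed_data(14) of each face's pose transform matrix: with a column-major packed 4x4 matrix, indices 12 through 14 hold the translation column, so index 14 is the Z translation and its negation approximates the face's distance from the camera in centimeters. The same extraction in plain C++, assuming the FaceGeometry proto that the multi_face_geometry stream carries:

```c++
#include <vector>

#include "mediapipe/modules/face_geometry/protos/face_geometry.pb.h"

// Returns the approximate distance (cm) of each detected face from the camera,
// taken from the Z translation of the column-major 4x4 pose transform matrix.
std::vector<float> FaceDistancesCm(
    const std::vector<mediapipe::face_geometry::FaceGeometry>& multi_face_geometry) {
  constexpr int kMatrixTranslationZIndex = 14;  // column 3, row 2 in column-major packing
  std::vector<float> distances_cm;
  distances_cm.reserve(multi_face_geometry.size());
  for (const auto& face_geometry : multi_face_geometry) {
    distances_cm.push_back(
        -face_geometry.pose_transform_matrix().packed_data(kMatrixTranslationZIndex));
  }
  return distances_cm;
}
```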
+ <key>CFBundleDevelopmentRegion</key> + <string>en</string> + <key>CFBundleExecutable</key> + <string>$(EXECUTABLE_NAME)</string> + <key>CFBundleIdentifier</key> + <string>$(PRODUCT_BUNDLE_IDENTIFIER)</string> + <key>CFBundleInfoDictionaryVersion</key> + <string>6.0</string> + <key>CFBundleName</key> + <string>$(PRODUCT_NAME)</string> + <key>CFBundlePackageType</key> + <string>APPL</string> + <key>CFBundleShortVersionString</key> + <string>1.0</string> + <key>CFBundleVersion</key> + <string>1</string> + <key>LSRequiresIPhoneOS</key> + <true/> + <key>UILaunchStoryboardName</key> + <string>LaunchScreen</string> + <key>UIMainStoryboardFile</key> + <string>Main</string> + <key>UIRequiredDeviceCapabilities</key> + <array> + <string>armv7</string> + </array> + <key>UISupportedInterfaceOrientations</key> + <array> + <string>UIInterfaceOrientationPortrait</string> + </array> + <key>UISupportedInterfaceOrientations~ipad</key> + <array> + <string>UIInterfaceOrientationPortrait</string> + </array> +</dict> +</plist> diff --git a/mediapipe/examples/ios/faceeffect/main.m b/mediapipe/examples/ios/faceeffect/main.m new file mode 100644 index 000000000..8848aeef4 --- /dev/null +++ b/mediapipe/examples/ios/faceeffect/main.m @@ -0,0 +1,22 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <UIKit/UIKit.h> +#import "AppDelegate.h" + +int main(int argc, char* argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 11bd649bf..942a19659 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -59,8 +59,9 @@ objc_library( hdrs = [ "FaceMeshGpuViewController.h", ], + copts = ["-std=c++17"], data = [ - "//mediapipe/graphs/face_mesh:face_mesh_mobile_gpu_binary_graph", + "//mediapipe/graphs/face_mesh:face_mesh_mobile_gpu.binarypb", "//mediapipe/modules/face_detection:face_detection_front.tflite", "//mediapipe/modules/face_landmark:face_landmark.tflite", ], diff --git a/mediapipe/examples/ios/handdetectiongpu/BUILD b/mediapipe/examples/ios/handdetectiongpu/BUILD index e1fbb8bd6..1b5ed9820 100644 --- a/mediapipe/examples/ios/handdetectiongpu/BUILD +++ b/mediapipe/examples/ios/handdetectiongpu/BUILD @@ -55,8 +55,7 @@ objc_library( name = "HandDetectionGpuAppLibrary", data = [ "//mediapipe/graphs/hand_tracking:hand_detection_mobile_gpu_binary_graph", - "//mediapipe/models:palm_detection.tflite", - "//mediapipe/models:palm_detection_labelmap.txt", + "//mediapipe/modules/palm_detection:palm_detection.tflite", ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index b3ac999b6..0121150e1 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -59,12 +59,12 @@ objc_library( hdrs = [ "HandTrackingViewController.h", ], + copts = ["-std=c++17"], data = [ - "//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu_binary_graph", - "//mediapipe/models:hand_landmark.tflite", - "//mediapipe/models:handedness.txt", - "//mediapipe/models:palm_detection.tflite", - "//mediapipe/models:palm_detection_labelmap.txt", + "//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb", +
"//mediapipe/modules/hand_landmark:hand_landmark.tflite", + "//mediapipe/modules/hand_landmark:handedness.txt", + "//mediapipe/modules/palm_detection:palm_detection.tflite", ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", diff --git a/mediapipe/examples/ios/handtrackinggpu/HandTrackingViewController.mm b/mediapipe/examples/ios/handtrackinggpu/HandTrackingViewController.mm index 491d65459..87e562d01 100644 --- a/mediapipe/examples/ios/handtrackinggpu/HandTrackingViewController.mm +++ b/mediapipe/examples/ios/handtrackinggpu/HandTrackingViewController.mm @@ -17,6 +17,10 @@ #include "mediapipe/framework/formats/landmark.pb.h" static const char* kLandmarksOutputStream = "hand_landmarks"; +static const char* kNumHandsInputSidePacket = "num_hands"; + +// Max number of hands to detect/process. +static const int kNumHands = 2; @implementation HandTrackingViewController @@ -25,6 +29,8 @@ static const char* kLandmarksOutputStream = "hand_landmarks"; - (void)viewDidLoad { [super viewDidLoad]; + [self.mediapipeGraph setSidePacket:(mediapipe::MakePacket(kNumHands)) + named:kNumHandsInputSidePacket]; [self.mediapipeGraph addFrameOutputStream:kLandmarksOutputStream outputPacketType:MPPPacketTypeRaw]; } @@ -40,12 +46,16 @@ static const char* kLandmarksOutputStream = "hand_landmarks"; NSLog(@"[TS:%lld] No hand landmarks", packet.Timestamp().Value()); return; } - const auto& landmarks = packet.Get<::mediapipe::NormalizedLandmarkList>(); - NSLog(@"[TS:%lld] Number of landmarks on hand: %d", packet.Timestamp().Value(), - landmarks.landmark_size()); - for (int i = 0; i < landmarks.landmark_size(); ++i) { - NSLog(@"\tLandmark[%d]: (%f, %f, %f)", i, landmarks.landmark(i).x(), - landmarks.landmark(i).y(), landmarks.landmark(i).z()); + const auto& multiHandLandmarks = packet.Get>(); + NSLog(@"[TS:%lld] Number of hand instances with landmarks: %lu", packet.Timestamp().Value(), + multiHandLandmarks.size()); + for (int handIndex = 0; handIndex < multiHandLandmarks.size(); ++handIndex) { + const auto& landmarks = multiHandLandmarks[handIndex]; + NSLog(@"\tNumber of landmarks for hand[%d]: %d", handIndex, landmarks.landmark_size()); + for (int i = 0; i < landmarks.landmark_size(); ++i) { + NSLog(@"\t\tLandmark[%d]: (%f, %f, %f)", i, landmarks.landmark(i).x(), + landmarks.landmark(i).y(), landmarks.landmark(i).z()); + } } } } diff --git a/mediapipe/examples/ios/holistictrackinggpu/BUILD b/mediapipe/examples/ios/holistictrackinggpu/BUILD new file mode 100644 index 000000000..b8d6c00ab --- /dev/null +++ b/mediapipe/examples/ios/holistictrackinggpu/BUILD @@ -0,0 +1,76 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) + +MIN_IOS_VERSION = "10.0" + +alias( + name = "holistictrackinggpu", + actual = "HolisticTrackingGpuApp", +) + +ios_application( + name = "HolisticTrackingGpuApp", + app_icons = ["//mediapipe/examples/ios/common:AppIcon"], + bundle_id = BUNDLE_ID_PREFIX + ".HolisticTrackingGpu", + families = [ + "iphone", + "ipad", + ], + infoplists = [ + "//mediapipe/examples/ios/common:Info.plist", + "Info.plist", + ], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = example_provisioning(), + deps = [ + ":HolisticTrackingGpuAppLibrary", + "@ios_opencv//:OpencvFramework", + ], +) + +objc_library( + name = "HolisticTrackingGpuAppLibrary", + data = [ + "//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb", + "//mediapipe/modules/face_detection:face_detection_front.tflite", + "//mediapipe/modules/face_landmark:face_landmark.tflite", + "//mediapipe/modules/hand_landmark:hand_landmark.tflite", + "//mediapipe/modules/hand_landmark:handedness.txt", + "//mediapipe/modules/holistic_landmark:hand_recrop.tflite", + "//mediapipe/modules/pose_detection:pose_detection.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite", + ], + deps = [ + "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + "//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu_deps", + ], + }), +) diff --git a/mediapipe/examples/ios/holistictrackinggpu/Info.plist b/mediapipe/examples/ios/holistictrackinggpu/Info.plist new file mode 100644 index 000000000..ae92eb50f --- /dev/null +++ b/mediapipe/examples/ios/holistictrackinggpu/Info.plist @@ -0,0 +1,14 @@ + + + + + CameraPosition + back + GraphOutputStream + output_video + GraphInputStream + input_video + GraphName + holistic_tracking_gpu + + diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD index 3cf8d14f7..b58ecc104 100644 --- a/mediapipe/examples/ios/iristrackinggpu/BUILD +++ b/mediapipe/examples/ios/iristrackinggpu/BUILD @@ -59,6 +59,7 @@ objc_library( hdrs = [ "IrisTrackingViewController.h", ], + copts = ["-std=c++17"], data = [ "//mediapipe/graphs/iris_tracking:iris_tracking_gpu.binarypb", "//mediapipe/modules/face_detection:face_detection_front.tflite", diff --git a/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD new file mode 100644 index 000000000..37e0b85e9 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD @@ -0,0 +1,70 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) + +MIN_IOS_VERSION = "10.0" + +alias( + name = "objectdetectiontrackinggpu", + actual = "ObjectDetectionTrackingGpuApp", +) + +ios_application( + name = "ObjectDetectionTrackingGpuApp", + app_icons = ["//mediapipe/examples/ios/common:AppIcon"], + bundle_id = BUNDLE_ID_PREFIX + ".ObjectDetectionTrackingGpu", + families = [ + "iphone", + "ipad", + ], + infoplists = [ + "//mediapipe/examples/ios/common:Info.plist", + "Info.plist", + ], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = example_provisioning(), + deps = [ + ":ObjectDetectionTrackingGpuAppLibrary", + "@ios_opencv//:OpencvFramework", + ], +) + +objc_library( + name = "ObjectDetectionTrackingGpuAppLibrary", + data = [ + "//mediapipe/graphs/tracking:mobile_gpu_binary_graph", + "//mediapipe/models:ssdlite_object_detection.tflite", + "//mediapipe/models:ssdlite_object_detection_labelmap.txt", + ], + deps = [ + "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + "//mediapipe/graphs/tracking:mobile_calculators", + ], + }), +) diff --git a/mediapipe/examples/ios/objectdetectiontrackinggpu/Info.plist b/mediapipe/examples/ios/objectdetectiontrackinggpu/Info.plist new file mode 100644 index 000000000..7e792c9b4 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiontrackinggpu/Info.plist @@ -0,0 +1,14 @@ + + + + + CameraPosition + back + GraphName + mobile_gpu + GraphOutputStream + output_video + GraphInputStream + input_video + + diff --git a/mediapipe/examples/ios/multihandtrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD similarity index 66% rename from mediapipe/examples/ios/multihandtrackinggpu/BUILD rename to mediapipe/examples/ios/posetrackinggpu/BUILD index 5616f12b6..c78c6a674 100644 --- a/mediapipe/examples/ios/multihandtrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2020 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -27,14 +27,14 @@ licenses(["notice"]) MIN_IOS_VERSION = "10.0" alias( - name = "multihandtrackinggpu", - actual = "MultiHandTrackingGpuApp", + name = "posetrackinggpu", + actual = "PoseTrackingGpuApp", ) ios_application( - name = "MultiHandTrackingGpuApp", + name = "PoseTrackingGpuApp", app_icons = ["//mediapipe/examples/ios/common:AppIcon"], - bundle_id = BUNDLE_ID_PREFIX + ".MultiHandTrackingGpu", + bundle_id = BUNDLE_ID_PREFIX + ".PoseTrackingGpu", families = [ "iphone", "ipad", @@ -46,25 +46,24 @@ ios_application( minimum_os_version = MIN_IOS_VERSION, provisioning_profile = example_provisioning(), deps = [ - ":MultiHandTrackingGpuAppLibrary", + ":PoseTrackingGpuAppLibrary", "@ios_opencv//:OpencvFramework", ], ) objc_library( - name = "MultiHandTrackingGpuAppLibrary", + name = "PoseTrackingGpuAppLibrary", srcs = [ - "MultiHandTrackingViewController.mm", + "PoseTrackingViewController.mm", ], hdrs = [ - "MultiHandTrackingViewController.h", + "PoseTrackingViewController.h", ], + copts = ["-std=c++17"], data = [ - "//mediapipe/graphs/hand_tracking:multi_hand_tracking_mobile_gpu_binary_graph", - "//mediapipe/models:hand_landmark.tflite", - "//mediapipe/models:handedness.txt", - "//mediapipe/models:palm_detection.tflite", - "//mediapipe/models:palm_detection_labelmap.txt", + "//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb", + "//mediapipe/modules/pose_detection:pose_detection.tflite", + "//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", @@ -72,7 +71,7 @@ objc_library( "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ - "//mediapipe/graphs/hand_tracking:multi_hand_mobile_calculators", + "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", "//mediapipe/framework/formats:landmark_cc_proto", ], }), diff --git a/mediapipe/examples/ios/multihandtrackinggpu/Info.plist b/mediapipe/examples/ios/posetrackinggpu/Info.plist similarity index 75% rename from mediapipe/examples/ios/multihandtrackinggpu/Info.plist rename to mediapipe/examples/ios/posetrackinggpu/Info.plist index 46e3fbd3d..71e2e429e 100644 --- a/mediapipe/examples/ios/multihandtrackinggpu/Info.plist +++ b/mediapipe/examples/ios/posetrackinggpu/Info.plist @@ -3,14 +3,14 @@ CameraPosition - front + back MainViewController - MultiHandTrackingViewController + PoseTrackingViewController GraphOutputStream output_video GraphInputStream input_video GraphName - multi_hand_tracking_mobile_gpu + pose_tracking_gpu diff --git a/mediapipe/examples/ios/multihandtrackinggpu/MultiHandTrackingViewController.h b/mediapipe/examples/ios/posetrackinggpu/PoseTrackingViewController.h similarity index 85% rename from mediapipe/examples/ios/multihandtrackinggpu/MultiHandTrackingViewController.h rename to mediapipe/examples/ios/posetrackinggpu/PoseTrackingViewController.h index 17ea6feeb..f5dc4674a 100644 --- a/mediapipe/examples/ios/multihandtrackinggpu/MultiHandTrackingViewController.h +++ b/mediapipe/examples/ios/posetrackinggpu/PoseTrackingViewController.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2020 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -16,6 +16,6 @@ #import "mediapipe/examples/ios/common/CommonViewController.h" -@interface MultiHandTrackingViewController : CommonViewController +@interface PoseTrackingViewController : CommonViewController @end diff --git a/mediapipe/examples/ios/multihandtrackinggpu/MultiHandTrackingViewController.mm b/mediapipe/examples/ios/posetrackinggpu/PoseTrackingViewController.mm similarity index 56% rename from mediapipe/examples/ios/multihandtrackinggpu/MultiHandTrackingViewController.mm rename to mediapipe/examples/ios/posetrackinggpu/PoseTrackingViewController.mm index 6c1deb7da..0f082031c 100644 --- a/mediapipe/examples/ios/multihandtrackinggpu/MultiHandTrackingViewController.mm +++ b/mediapipe/examples/ios/posetrackinggpu/PoseTrackingViewController.mm @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2020 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#import "MultiHandTrackingViewController.h" +#import "PoseTrackingViewController.h" #include "mediapipe/framework/formats/landmark.pb.h" -static const char* kLandmarksOutputStream = "multi_hand_landmarks"; +static const char* kLandmarksOutputStream = "pose_landmarks"; -@implementation MultiHandTrackingViewController +@implementation PoseTrackingViewController #pragma mark - UIViewController methods @@ -37,19 +37,15 @@ static const char* kLandmarksOutputStream = "multi_hand_landmarks"; fromStream:(const std::string&)streamName { if (streamName == kLandmarksOutputStream) { if (packet.IsEmpty()) { - NSLog(@"[TS:%lld] No hand landmarks", packet.Timestamp().Value()); + NSLog(@"[TS:%lld] No pose landmarks", packet.Timestamp().Value()); return; } - const auto& multi_hand_landmarks = packet.Get>(); - NSLog(@"[TS:%lld] Number of hand instances with landmarks: %lu", packet.Timestamp().Value(), - multi_hand_landmarks.size()); - for (int hand_index = 0; hand_index < multi_hand_landmarks.size(); ++hand_index) { - const auto& landmarks = multi_hand_landmarks[hand_index]; - NSLog(@"\tNumber of landmarks for hand[%d]: %d", hand_index, landmarks.landmark_size()); - for (int i = 0; i < landmarks.landmark_size(); ++i) { - NSLog(@"\t\tLandmark[%d]: (%f, %f, %f)", i, landmarks.landmark(i).x(), - landmarks.landmark(i).y(), landmarks.landmark(i).z()); - } + const auto& landmarks = packet.Get<::mediapipe::NormalizedLandmarkList>(); + NSLog(@"[TS:%lld] Number of pose landmarks: %d", packet.Timestamp().Value(), + landmarks.landmark_size()); + for (int i = 0; i < landmarks.landmark_size(); ++i) { + NSLog(@"\tLandmark[%d]: (%f, %f, %f)", i, landmarks.landmark(i).x(), + landmarks.landmark(i).y(), landmarks.landmark(i).z()); } } } diff --git a/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD b/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD index 0a2402857..3455fbbf8 100644 --- a/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD +++ b/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD @@ -59,6 +59,7 @@ objc_library( hdrs = [ "UpperBodyPoseTrackingViewController.h", ], + copts = ["-std=c++17"], data = [ "//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb", "//mediapipe/modules/pose_detection:pose_detection.tflite", diff --git a/mediapipe/examples/python/upper_body_pose_tracker.py b/mediapipe/examples/python/upper_body_pose_tracker.py deleted file mode 100644 index 
edb1dfe73..000000000 --- a/mediapipe/examples/python/upper_body_pose_tracker.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2020 The MediaPipe Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""MediaPipe upper body pose tracker. - -MediaPipe upper body pose tracker takes an RGB image as the input and returns -a pose landmark list and an annotated RGB image represented as a numpy ndarray. - -Usage examples: - pose_tracker = UpperBodyPoseTracker() - - pose_landmarks, _ = pose_tracker.run( - input_file='/tmp/input.png', - output_file='/tmp/output.png') - - input_image = cv2.imread('/tmp/input.png')[:, :, ::-1] - pose_landmarks, annotated_image = pose_tracker.run(input_image) - - pose_tracker.run_live() - - pose_tracker.close() -""" - -import os -import time -from typing import Tuple, Union - -import cv2 -import mediapipe.python as mp -import numpy as np -# resources dependency -from mediapipe.framework.formats import landmark_pb2 - -# Input and output stream names. -INPUT_VIDEO = 'input_video' -OUTPUT_VIDEO = 'output_video' -POSE_LANDMARKS = 'pose_landmarks' - - -class UpperBodyPoseTracker: - """MediaPipe upper body pose tracker.""" - - def __init__(self): - """The init method of MediaPipe upper body pose tracker. - - The method reads the upper body pose tracking cpu binary graph and - initializes a CalculatorGraph from it. The output packets of pose_landmarks - and output_video output streams will be observed by callbacks. The graph - will be started at the end of this method, waiting for input packets. - """ - # MediaPipe package root path - root_path = os.sep.join( os.path.abspath(__file__).split(os.sep)[:-4]) - mp.resource_util.set_resource_dir(root_path) - - self._graph = mp.CalculatorGraph( - binary_graph_path=os.path.join( - root_path, - 'mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.binarypb' - )) - self._outputs = {} - for stream_name in [POSE_LANDMARKS, OUTPUT_VIDEO]: - self._graph.observe_output_stream(stream_name, self._assign_packet) - self._graph.start_run() - - def run( - self, - input_frame: np.ndarray = None, - *, - input_file: str = None, - output_file: str = None - ) -> Tuple[Union[None, landmark_pb2.NormalizedLandmarkList], np.ndarray]: - """The run method of MediaPipe upper body pose tracker. - - MediaPipe upper body pose tracker takes either the path to an image file or - an RGB image represented as a numpy ndarray and it returns the pose - landmarks list and the annotated RGB image represented as a numpy ndarray. - - Args: - input_frame: An RGB image represented as a numpy ndarray. - input_file: The path to an image file. - output_file: The file path that the annotated image will be saved into. - - Returns: - pose_landmarks: The pose landmarks list. - annotated_image: The image with pose landmarks annotations. - - Raises: - RuntimeError: If the input frame doesn't contain 3 channels (RGB format) - or the input arg is not correctly provided. 
- - Examples - pose_tracker = UpperBodyPoseTracker() - pose_landmarks, _ = pose_tracker.run( - input_file='/tmp/input.png', - output_file='/tmp/output.png') - - # Read an image and convert the BGR image to RGB. - input_image = cv2.cvtColor(cv2.imread('/tmp/input.png'), COLOR_BGR2RGB) - pose_landmarks, annotated_image = pose_tracker.run(input_image) - pose_tracker.close() - """ - if input_file is None and input_frame is None: - raise RuntimeError( - 'Must provide either a path to an image file or an RGB image represented as a numpy.ndarray.' - ) - - if input_file: - if input_frame is not None: - raise RuntimeError( - 'Must only provide either \'input_file\' or \'input_frame\'.') - else: - input_frame = cv2.imread(input_file)[:, :, ::-1] - - pose_landmarks, annotated_image = self._run_graph(input_frame) - if output_file: - cv2.imwrite(output_file, annotated_image[:, :, ::-1]) - return pose_landmarks, annotated_image - - def run_live(self) -> None: - """Run MediaPipe upper body pose tracker with live camera input. - - The method will be self-terminated after 30 seconds. If you need to - terminate it earlier, press the Esc key to stop the run manually. Note that - you need to select the output image window rather than the terminal window - first and then press the key. - - Examples: - pose_tracker = UpperBodyPoseTracker() - pose_tracker.run_live() - pose_tracker.close() - """ - cap = cv2.VideoCapture(0) - start_time = time.time() - print( - 'Press Esc within the output image window to stop the run, or let it ' - 'self terminate after 30 seconds.') - while cap.isOpened() and time.time() - start_time < 30: - success, input_frame = cap.read() - if not success: - break - input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB) - input_frame.flags.writeable = False - _, output_frame = self._run_graph(input_frame) - cv2.imshow('MediaPipe upper body pose tracker', - cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)) - if cv2.waitKey(5) & 0xFF == 27: - break - cap.release() - cv2.destroyAllWindows() - - def close(self) -> None: - self._graph.close() - self._graph = None - self._outputs = None - - def _run_graph( - self, - input_frame: np.ndarray = None, - ) -> Tuple[Union[None, landmark_pb2.NormalizedLandmarkList], np.ndarray]: - """The internal run graph method. - - Args: - input_frame: An RGB image represented as a numpy ndarray. - - Returns: - pose_landmarks: The pose landmarks list. - annotated_image: The image with pose landmarks annotations. - - Raises: - RuntimeError: If the input frame doesn't contain 3 channels representing - RGB. 
- """ - - if input_frame.shape[2] != 3: - raise RuntimeError('input frame must have 3 channels.') - - self._outputs.clear() - start_time = time.time() - self._graph.add_packet_to_input_stream( - stream=INPUT_VIDEO, - packet=mp.packet_creator.create_image_frame( - image_format=mp.ImageFormat.SRGB, data=input_frame), - timestamp=mp.Timestamp.from_seconds(start_time)) - self._graph.wait_until_idle() - - pose_landmarks = None - if POSE_LANDMARKS in self._outputs: - pose_landmarks = mp.packet_getter.get_proto(self._outputs[POSE_LANDMARKS]) - annotated_image = mp.packet_getter.get_image_frame( - self._outputs[OUTPUT_VIDEO]).numpy_view() - print('UpperBodyPoseTracker.Run() took', - time.time() - start_time, 'seconds') - return pose_landmarks, annotated_image - - def _assign_packet(self, stream_name: str, packet: mp.Packet) -> None: - self._outputs[stream_name] = packet diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index a61ee12df..d2ed6cf1e 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -269,12 +269,6 @@ cc_library( "calculator_graph.h", "scheduler.h", ], - defines = select({ - "//conditions:default": [], - "//mediapipe/gpu:disable_gpu": [ - "MEDIAPIPE_DISABLE_GPU", - ], - }), visibility = [ ":mediapipe_internal", ], @@ -351,7 +345,6 @@ cc_library( ":calculator_base", ":calculator_context", ":calculator_context_manager", - ":calculator_registry_util", ":calculator_state", ":counter_factory", ":input_side_packet_handler", @@ -405,27 +398,6 @@ cc_library( ], ) -cc_library( - name = "calculator_registry_util", - srcs = ["calculator_registry_util.cc"], - hdrs = ["calculator_registry_util.h"], - visibility = [ - ":mediapipe_internal", - ], - deps = [ - ":calculator_base", - ":calculator_context", - ":calculator_state", - ":collection", - ":collection_item_id", - ":packet_set", - ":timestamp", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:statusor", - "//mediapipe/framework/tool:tag_map", - ], -) - cc_library( name = "calculator_runner", testonly = 1, @@ -482,6 +454,7 @@ cc_library( ":type_map", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:tag_map", + "//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -939,7 +912,7 @@ cc_library( "//conditions:default": [], }) + select({ "//conditions:default": [], - "//mediapipe/gpu:disable_gpu": ["MEDIAPIPE_DISABLE_GPU"], + "//mediapipe/gpu:disable_gpu": ["MEDIAPIPE_DISABLE_GPU=1"], }) + select({ "//conditions:default": [], "//mediapipe/framework:disable_rtti_and_exceptions": [ @@ -947,8 +920,10 @@ cc_library( ], }), visibility = [ + "//mediapipe/calculators:__subpackages__", "//mediapipe/framework:__subpackages__", "//mediapipe/framework/port:__pkg__", + "//mediapipe/gpu:__pkg__", "//mediapipe/util:__subpackages__", ], ) @@ -1106,6 +1081,19 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "basic_types_registration", + srcs = ["basic_types_registration.cc"], + visibility = ["//visibility:public"], + deps = [ + ":type_map", + "//mediapipe/framework/port:integral_types", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + cc_library( name = "validated_graph_config", srcs = ["validated_graph_config.cc"], @@ -1114,7 +1102,6 @@ cc_library( deps = [ ":calculator_base", ":calculator_contract", - ":calculator_registry_util", ":legacy_calculator_support", ":packet", ":packet_generator", diff --git 
a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD new file mode 100644 index 000000000..7c9a45e36 --- /dev/null +++ b/mediapipe/framework/api2/BUILD @@ -0,0 +1,227 @@ +package( + default_visibility = [":preview_users"], + features = ["-use_header_modules"], +) + +package_group( + name = "preview_users", + packages = [ + "//mediapipe/...", + ], +) + +licenses(["notice"]) + +cc_library( + name = "const_str", + hdrs = ["const_str.h"], +) + +cc_library( + name = "builder", + hdrs = ["builder.h"], + deps = [ + ":const_str", + ":contract", + ":node", + ":packet", + ":port", + "//mediapipe/framework:calculator_base", + "//mediapipe/framework:calculator_contract", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +cc_test( + name = "builder_test", + srcs = ["builder_test.cc"], + deps = [ + ":builder", + ":node", + ":packet", + ":port", + ":tag", + ":test_contracts", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "contract", + hdrs = ["contract.h"], + deps = [ + ":const_str", + ":packet", + ":port", + ":tag", + ":tuple", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_contract", + "//mediapipe/framework:output_side_packet", + "//mediapipe/framework/port:logging", + ], +) + +cc_test( + name = "contract_test", + srcs = ["contract_test.cc"], + deps = [ + ":contract", + ":port", + ":tag", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + +cc_library( + name = "node", + srcs = ["node.cc"], + hdrs = ["node.h"], + deps = [ + ":const_str", + ":contract", + ":packet", + ":port", + "//mediapipe/framework:calculator_base", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_contract", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/deps:no_destructor", + ], +) + +cc_library( + name = "test_contracts", + testonly = 1, + hdrs = ["test_contracts.h"], + deps = [ + ":node", + ], +) + +cc_test( + name = "node_test", + srcs = ["node_test.cc"], + deps = [ + ":node", + ":packet", + ":port", + ":test_contracts", + ":tuple", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) + +cc_library( + name = "packet", + srcs = ["packet.cc"], + hdrs = ["packet.h"], + deps = [ + ":tuple", + "//mediapipe/framework:packet", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/meta:type_traits", + ], +) + +cc_test( + name = "packet_test", + size = "small", + srcs = [ + "packet_test.cc", + ], + deps = [ + ":packet", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "port", + hdrs = ["port.h"], + deps = [ + ":const_str", + ":packet", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_contract", + "//mediapipe/framework:output_side_packet", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "port_test", + size = "small", + srcs = [ + "port_test.cc", + ], + deps = [ + ":port", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "subgraph_test", + srcs = ["subgraph_test.cc"], + deps = [ + ":builder", + ":node", + ":packet", 
+ ":port", + ":test_contracts", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:subgraph_expansion", + ], +) + +cc_library( + name = "tag", + hdrs = ["tag.h"], + deps = [":const_str"], +) + +cc_test( + name = "tag_test", + size = "small", + srcs = [ + "tag_test.cc", + ], + deps = [ + ":tag", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_library( + name = "tuple", + hdrs = ["tuple.h"], + deps = ["@com_google_absl//absl/meta:type_traits"], +) + +cc_test( + name = "tuple_test", + size = "small", + srcs = [ + "tuple_test.cc", + ], + deps = [ + ":tuple", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/framework/api2/README.md b/mediapipe/framework/api2/README.md new file mode 100644 index 000000000..32a699bd2 --- /dev/null +++ b/mediapipe/framework/api2/README.md @@ -0,0 +1,111 @@ +# Experimental new APIs + +This directory defines new APIs for MediaPipe: + +- Node API, an update to the Calculator API for defining MediaPipe components. +- Builder API, for assembling CalculatorGraphConfigs with C++, as an alternative + to using the proto API directly. + +The code is working, and the new APIs interoperate fully with the existing +framework code. They are considered a work in progress, but are being released +now so we can begin adopting them in our calculators. + +Developers are welcome to try out these APIs as early adopters, but should +expect breaking changes. The placement of this code under the `mediapipe::api2` +namespace is not final. + +## Node API + +This API can be used to define calculators. It is designed to be more type-safe +and less verbose than the original API. + +Input/output ports (streams and side packets) can now be declared as typed +constants, instead of using plain strings for access. + +For example, instead of + +``` +constexpr char kSelectTag[] = "SELECT"; +if (cc->Inputs().HasTag(kSelectTag)) { + cc->Inputs().Tag(kSelectTag).Set(); +} +``` + +you can write + +``` +static constexpr Input::Optional kSelect{"SELECT"}; +``` + +Instead of setting up the contract procedurally in `GetContract`, add ports to +the contract declaratively, as follows: + +``` +MEDIAPIPE_NODE_CONTRACT(kInput, kOutput); +``` + +To access an input in Process, instead of + +``` +int select = cc->Inputs().Tag(kSelectTag).Get(); +``` + +write + +``` +int select = kSelectTag(cc).Get(); // alternative: *kSelectTag(cc) +``` + +Sets of multiple ports can be declared with `::Multiple`. Note, also, that a tag +string must always be provided when declaring a port; use `""` for untagged +ports. For example: + + +``` +for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).SetAny(); +} +``` + +becomes + +``` +static constexpr Input::Multiple kIn{""}; +``` + +For output ports, the payload can be passed directly to the `Send` method. For +example, instead of + +``` +cc->Outputs().Index(0).Add( + new std::pair(cc->Inputs().Index(0).Value(), + cc->Inputs().Index(1).Value()), + cc->InputTimestamp()); +``` + +you can write + +``` +kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()}); +``` + +The input timestamp is propagated to the outputs by default. If your calculator +wants to alter timestamps, it must add a `TimestampChange` entry to its contract +declaration. 
For example: + +``` +MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop, + StreamHandler("ImmediateInputStreamHandler"), + TimestampChange::Arbitrary()); +``` + +Several calculators in +[`calculators/core`](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core) and +[`calculators/tensor`](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/tensor) +have been updated to use this API. Reference them for more examples. + +More complete documentation will be provided in the future. + +## Builder API + +Documentation will be provided in the future. diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h new file mode 100644 index 000000000..ae32c628a --- /dev/null +++ b/mediapipe/framework/api2/builder.h @@ -0,0 +1,575 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_ +#define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "mediapipe/framework/api2/const_str.h" +#include "mediapipe/framework/api2/contract.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_contract.h" + +namespace mediapipe { +namespace api2 { +namespace builder { + +template <typename T> +T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, int index) { + auto& vec = *vecp; + if (vec.size() <= index) { + vec.resize(index + 1); + } + if (vec[index] == nullptr) { + vec[index] = absl::make_unique<T>(); + } + return *vec[index]; +} + +struct TagIndexLocation { + const std::string& tag; + std::size_t index; + std::size_t count; +}; + +template <typename T> +class TagIndexMap { + public: + std::vector<std::unique_ptr<T>>& operator[](const std::string& tag) { + return map_[tag]; + } + + void Visit(std::function<void(const TagIndexLocation&, const T&)> fun) const { + for (const auto& tagged : map_) { + TagIndexLocation loc{tagged.first, 0, tagged.second.size()}; + for (const auto& item : tagged.second) { + fun(loc, *item); + ++loc.index; + } + } + } + + void Visit(std::function<void(const TagIndexLocation&, T*)> fun) { + for (auto& tagged : map_) { + TagIndexLocation loc{tagged.first, 0, tagged.second.size()}; + for (auto& item : tagged.second) { + fun(loc, item.get()); + ++loc.index; + } + } + } + + // Note: entries are held by a unique_ptr to ensure pointers remain valid. + // Should use absl::flat_hash_map but ordering keys for now. + std::map<std::string, std::vector<std::unique_ptr<T>>> map_; +}; + +// These structs are used internally to store information about the endpoints +// of a connection. +struct SourceBase; +struct DestinationBase { + SourceBase* source = nullptr; +}; +struct SourceBase { + std::vector<DestinationBase*> dests_; + std::string name_; +}; + +// Following existing GraphConfig usage, we allow using a multiport as a single +// port as well. This is necessary for generic nodes, since we have no +// information about which ports are meant to be multiports or not, but it is +// also convenient with typed nodes. +template <typename Single> +class MultiPort : public Single { + public: + using Base = typename Single::Base; + + explicit MultiPort(std::vector<std::unique_ptr<Base>>* vec) + : Single(vec), vec_(*vec) {} + + Single operator[](int index) { + CHECK_GE(index, 0); + return Single{&GetWithAutoGrow(&vec_, index)}; + } + + private: + std::vector<std::unique_ptr<Base>>& vec_; +}; + +// These classes wrap references to the underlying source/destination +// endpoints, adding type information and the user-visible API. 
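+//
+// A minimal usage sketch of the resulting builder API, assuming the Graph
+// helpers defined later in this file (AddNode, In, Out, GetConfig). The
+// calculator name "MyCalculator" and the "IN"/"OUT"/"INPUT"/"OUTPUT" tags are
+// hypothetical placeholders, not identifiers defined by this header:
+//
+//   Graph graph;
+//   auto& node = graph.AddNode("MyCalculator");
+//   graph.In("IN") >> node.In("INPUT");       // graph input feeds the node
+//   node.Out("OUTPUT") >> graph.Out("OUT");   // node output becomes graph output
+//   CalculatorGraphConfig config = graph.GetConfig();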
+template <bool IsSide, typename T = internal::Generic> +class DestinationImpl { + public: + using Base = DestinationBase; + + explicit DestinationImpl(std::vector<std::unique_ptr<DestinationBase>>* vec) + : DestinationImpl(&GetWithAutoGrow(vec, 0)) {} + explicit DestinationImpl(DestinationBase* base) : base_(*base) {} + DestinationBase& base_; +}; + +template <bool IsSide, typename T> +class DestinationImpl<IsSide, MultiPort<T>> + : public MultiPort<DestinationImpl<IsSide, T>> { + public: + using MultiPort<DestinationImpl<IsSide, T>>::MultiPort; +}; + +template <bool IsSide, typename T = internal::Generic> +class SourceImpl { + public: + using Base = SourceBase; + + // Src is used as the return type of fluent methods below. Since these are + // single-port methods, it is desirable to always decay to a reference to the + // single-port superclass, even if they are called on a multiport. + using Src = SourceImpl<IsSide, T>; + template <typename U> + using Dst = DestinationImpl<IsSide, U>; + + // clang-format off + template <typename U> + struct AllowConnection : public std::integral_constant<bool, + std::is_same<T, U>{} || std::is_same<T, internal::Generic>{} || + std::is_same<U, internal::Generic>{}> {}; + // clang-format on + + explicit SourceImpl(std::vector<std::unique_ptr<SourceBase>>* vec) + : SourceImpl(&GetWithAutoGrow(vec, 0)) {} + explicit SourceImpl(SourceBase* base) : base_(*base) {} + + template <typename U, typename std::enable_if<AllowConnection<U>{}, int>::type = 0> + Src& AddTarget(const Dst<U>& dest) { + CHECK(dest.base_.source == nullptr); + dest.base_.source = &base_; + base_.dests_.emplace_back(&dest.base_); + return *this; + } + Src& SetName(std::string name) { + base_.name_ = std::move(name); + return *this; + } + template <typename U> + Src& operator>>(const Dst<U>& dest) { + return AddTarget(dest); + } + + private: + SourceBase& base_; +}; + +template <bool IsSide, typename T> +class SourceImpl<IsSide, MultiPort<T>> + : public MultiPort<SourceImpl<IsSide, T>> { + public: + using MultiPort<SourceImpl<IsSide, T>>::MultiPort; +}; + +// A source and a destination correspond to an output/input stream on a node, +// and a side source and side destination correspond to an output/input side +// packet. +// For graph inputs/outputs, however, the inputs are sources, and the outputs +// are destinations. This is because graph ports are connected "from inside" +// when building the graph. +template <typename T = internal::Generic> +using Source = SourceImpl<false, T>; +template <typename T = internal::Generic> +using SideSource = SourceImpl<true, T>; +template <typename T = internal::Generic> +using Destination = DestinationImpl<false, T>; +template <typename T = internal::Generic> +using SideDestination = DestinationImpl<true, T>; + +class NodeBase { + public: + // TODO: right now access to an indexed port is made directly by + // specifying both a tag and an index. It would be better to represent this + // as a two-step lookup, first getting a multi-port, and then accessing one + // of its entries by index. However, for nodes without visible contracts we + // can't know whether a tag is indexable or not, so we would need the + // multi-port to also be usable as a port directly (representing index 0). + Source Out(const std::string& tag) { + return Source(&out_streams_[tag]); + } + + Destination In(const std::string& tag) { + return Destination(&in_streams_[tag]); + } + + SideSource SideOut(const std::string& tag) { + return SideSource(&out_sides_[tag]); + } + + SideDestination SideIn(const std::string& tag) { + return SideDestination(&in_sides_[tag]); + } + + // Convenience methods for accessing purely index-based ports. 
+ Source Out(int index) { return Out("")[index]; } + + Destination In(int index) { return In("")[index]; } + + SideSource SideOut(int index) { return SideOut("")[index]; } + + SideDestination SideIn(int index) { return SideIn("")[index]; } + + template <typename T> + T& GetOptions() { + options_used_ = true; + return *options_.MutableExtension(T::ext); + } + + protected: + NodeBase(std::string type) : type_(std::move(type)) {} + + std::string type_; + TagIndexMap<DestinationBase> in_streams_; + TagIndexMap<SourceBase> out_streams_; + TagIndexMap<DestinationBase> in_sides_; + TagIndexMap<SourceBase> out_sides_; + CalculatorOptions options_; + // ideally we'd just check if any extensions are set on options_ + bool options_used_ = false; + friend class Graph; +}; + +template <class Calc = internal::Generic> +class Node; +#if __cplusplus >= 201703L +// Deduction guide to silence -Wctad-maybe-unsupported. +explicit Node()->Node<internal::Generic>; +#endif // C++17 + +template <> +class Node<internal::Generic> : public NodeBase { + public: + Node(std::string type) : NodeBase(std::move(type)) {} +}; + +using GenericNode = Node<internal::Generic>; + +template