From 7c331ad58b2cca0dca468e342768900041d65adc Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 25 Mar 2021 15:01:44 -0700 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 6e4aff1cc351be3ae4537b677f36d139ee50ce09 --- .../.bazelversion => .bazelversion | 0 Dockerfile | 2 +- MANIFEST.in | 4 + README.md | 2 +- WORKSPACE | 38 +- docs/framework_concepts/packets.md | 23 +- docs/getting_started/android.md | 2 +- docs/getting_started/install.md | 85 +-- docs/index.md | 2 +- docs/solutions/face_detection.md | 4 +- docs/solutions/objectron.md | 51 +- docs/solutions/pose_classification.md | 25 +- docs/solutions/solutions.md | 2 +- docs/tools/tracing_and_profiling.md | 17 +- mediapipe/calculators/audio/BUILD | 24 +- .../audio/audio_decoder_calculator_test.cc | 2 +- mediapipe/calculators/core/BUILD | 24 +- .../core/flow_limiter_calculator.cc | 2 +- .../calculators/core/nonzero_calculator.cc | 42 ++ .../core/packet_resampler_calculator.cc | 708 +++++++++++++----- .../core/packet_resampler_calculator.h | 259 ++++++- .../core/packet_resampler_calculator.proto | 15 + .../core/packet_resampler_calculator_test.cc | 72 ++ mediapipe/calculators/tensor/BUILD | 2 +- .../tensor/image_to_tensor_calculator_test.cc | 2 +- .../tensor/inference_calculator.cc | 2 +- .../calculators/tensor/inference_calculator.h | 4 +- .../tensor/inference_calculator.proto | 4 +- .../tensor/inference_calculator_cpu.cc | 2 +- ...nference_calculator_face_detection_test.cc | 2 +- .../tensor/inference_calculator_gl.cc | 6 +- .../tensor/inference_calculator_metal.cc | 2 +- mediapipe/calculators/tensorflow/BUILD | 10 +- .../tensorflow_inference_calculator_test.cc | 2 +- ...ssion_from_frozen_graph_calculator_test.cc | 2 +- ...ession_from_frozen_graph_generator_test.cc | 2 +- ...ession_from_saved_model_calculator_test.cc | 2 +- ...session_from_saved_model_generator_test.cc | 2 +- .../unpack_media_sequence_calculator.cc | 1 + .../unpack_media_sequence_calculator.proto | 7 - 
mediapipe/calculators/tflite/BUILD | 2 +- .../tflite/ssd_anchors_calculator_test.cc | 2 +- .../tflite/tflite_inference_calculator.cc | 35 +- .../tflite/tflite_inference_calculator.proto | 4 +- mediapipe/calculators/util/BUILD | 60 +- ...ction_classifications_merger_calculator.cc | 149 ++++ ..._classifications_merger_calculator_test.cc | 320 ++++++++ .../util/landmarks_smoothing_calculator.cc | 256 +++++-- .../util/landmarks_smoothing_calculator.proto | 30 + .../landmarks_to_render_data_calculator.cc | 137 ++-- .../landmarks_to_render_data_calculator.proto | 6 + .../util/visibility_copy_calculator.cc | 194 +++++ .../util/visibility_copy_calculator.proto | 29 + .../util/visibility_smoothing_calculator.cc | 243 ++++++ .../visibility_smoothing_calculator.proto | 40 + .../world_landmark_projection_calculator.cc | 108 +++ mediapipe/calculators/video/BUILD | 3 + mediapipe/examples/coral/BUILD | 3 +- mediapipe/examples/coral/Dockerfile | 2 +- .../examples/coral/demo_run_graph_main.cc | 22 +- mediapipe/examples/desktop/BUILD | 9 +- .../desktop/autoflip/calculators/BUILD | 20 +- .../calculators/content_zooming_calculator.cc | 326 ++++++-- .../content_zooming_calculator.proto | 12 +- .../content_zooming_calculator_state.h | 38 + .../content_zooming_calculator_test.cc | 273 ++++++- .../shot_boundary_calculator_test.cc | 2 +- .../examples/desktop/autoflip/quality/BUILD | 10 +- .../quality/padding_effect_generator_test.cc | 7 +- .../scene_camera_motion_analyzer_test.cc | 2 +- .../autoflip/quality/scene_cropping_viz.cc | 2 +- .../examples/desktop/demo_run_graph_main.cc | 22 +- .../desktop/demo_run_graph_main_gpu.cc | 22 +- .../examples/desktop/iris_tracking/BUILD | 3 +- .../iris_depth_from_image_desktop.cc | 17 +- .../examples/desktop/media_sequence/BUILD | 3 +- .../media_sequence/run_graph_file_io_main.cc | 28 +- .../desktop/object_detection_3d/BUILD | 2 +- .../examples/desktop/simple_run_graph_main.cc | 44 +- mediapipe/examples/desktop/youtube8m/BUILD | 3 +- 
.../youtube8m/extract_yt8m_features.cc | 28 +- mediapipe/framework/BUILD | 89 ++- mediapipe/framework/calculator_context.cc | 4 +- mediapipe/framework/calculator_context.h | 21 +- mediapipe/framework/calculator_graph.cc | 56 +- mediapipe/framework/calculator_graph.h | 18 +- .../framework/calculator_graph_bounds_test.cc | 32 + mediapipe/framework/calculator_graph_test.cc | 251 ------- mediapipe/framework/calculator_node.cc | 8 +- mediapipe/framework/calculator_state.cc | 8 +- mediapipe/framework/calculator_state.h | 18 +- mediapipe/framework/deps/ret_check.h | 2 +- mediapipe/framework/deps/status.cc | 2 +- mediapipe/framework/deps/status_macros.h | 4 +- mediapipe/framework/encode_binary_proto.bzl | 22 +- mediapipe/framework/formats/BUILD | 4 +- .../framework/formats/classification.proto | 2 + mediapipe/framework/formats/image.h | 13 +- mediapipe/framework/formats/landmark.proto | 4 +- mediapipe/framework/formats/motion/BUILD | 2 +- .../formats/motion/optical_flow_field_test.cc | 2 +- mediapipe/framework/graph_service.h | 13 + mediapipe/framework/graph_service_manager.cc | 21 + mediapipe/framework/graph_service_manager.h | 42 ++ .../framework/graph_service_manager_test.cc | 53 ++ mediapipe/framework/input_stream_handler.cc | 11 +- mediapipe/framework/port/BUILD | 12 - mediapipe/framework/profiler/BUILD | 2 +- .../framework/profiler/graph_tracer_test.cc | 2 +- mediapipe/framework/profiler/reporter/BUILD | 2 +- mediapipe/framework/subgraph.cc | 12 +- mediapipe/framework/subgraph.h | 67 +- mediapipe/framework/subgraph_test.cc | 58 ++ mediapipe/framework/tool/BUILD | 29 +- .../framework/tool/subgraph_expansion.cc | 11 +- mediapipe/framework/tool/subgraph_expansion.h | 7 +- .../framework/tool/subgraph_expansion_test.cc | 38 + .../framework/tool/text_to_binary_graph.cc | 16 +- mediapipe/framework/validated_graph_config.cc | 27 +- mediapipe/framework/validated_graph_config.h | 10 +- .../framework/validated_graph_config_test.cc | 165 ++++ mediapipe/gpu/gl_context.cc | 2 +- 
mediapipe/gpu/gl_texture_buffer_pool.cc | 22 +- mediapipe/gpu/gl_texture_buffer_pool.h | 2 +- .../instant_motion_tracking/calculators/BUILD | 2 +- .../object_detection_3d/calculators/BUILD | 2 +- .../objectron_desktop_cpu.pbtxt | 2 +- .../com/google/mediapipe/framework/Graph.java | 30 +- .../framework/PacketListCallback.java} | 22 +- .../com/google/mediapipe/framework/jni/BUILD | 2 +- .../mediapipe/framework/jni/class_registry.h | 2 + .../google/mediapipe/framework/jni/graph.cc | 64 ++ .../google/mediapipe/framework/jni/graph.h | 7 + .../mediapipe/framework/jni/graph_jni.cc | 30 + .../mediapipe/framework/jni/graph_jni.h | 4 + .../mediapipe/framework/jni/jni_util.cc | 16 + .../google/mediapipe/framework/jni/jni_util.h | 2 + .../framework/jni/register_natives.cc | 7 + .../com/google/mediapipe/mediapipe_aar.bzl | 7 + .../face_detection_front_by_roi_cpu.pbtxt | 2 +- mediapipe/modules/face_geometry/libs/BUILD | 4 +- mediapipe/modules/objectron/calculators/BUILD | 12 +- mediapipe/modules/pose_landmark/BUILD | 1 + .../pose_landmark_filtering.pbtxt | 64 +- mediapipe/python/solution_base.py | 5 +- mediapipe/python/solutions/drawing_utils.py | 30 +- .../python/solutions/drawing_utils_test.py | 43 ++ mediapipe/python/solutions/objectron.py | 30 +- mediapipe/python/solutions/pose.py | 1 + mediapipe/python/solutions/pose_test.py | 73 +- mediapipe/util/BUILD | 70 +- mediapipe/util/audio_decoder.cc | 16 +- mediapipe/util/audio_decoder.h | 2 +- mediapipe/util/cpu_util.cc | 15 +- mediapipe/util/filtering/BUILD | 12 + mediapipe/util/filtering/one_euro_filter.cc | 84 +++ mediapipe/util/filtering/one_euro_filter.h | 40 + mediapipe/util/resource_util.cc | 26 +- mediapipe/util/resource_util.h | 1 - mediapipe/util/resource_util_android.cc | 54 +- mediapipe/util/resource_util_apple.cc | 32 +- mediapipe/util/resource_util_custom.h | 18 + mediapipe/util/resource_util_default.cc | 43 ++ mediapipe/util/resource_util_internal.h | 19 + mediapipe/util/tflite/cpu_op_resolver.h | 3 +- 
mediapipe/util/tflite/op_resolver.h | 3 +- mediapipe/util/tflite/tflite_model_loader.cc | 1 - mediapipe/util/tracking/BUILD | 11 +- .../tracking/region_flow_computation_test.cc | 4 +- requirements.txt | 6 +- setup.py | 6 +- setup_opencv.sh | 4 +- third_party/BUILD | 8 + .../org_tensorflow_compatibility_fixes.diff | 22 +- third_party/org_tensorflow_objc_cxx17.diff | 2 +- 175 files changed, 4804 insertions(+), 1325 deletions(-) rename mediapipe/opensource_only/.bazelversion => .bazelversion (100%) create mode 100644 mediapipe/calculators/core/nonzero_calculator.cc create mode 100644 mediapipe/calculators/util/detection_classifications_merger_calculator.cc create mode 100644 mediapipe/calculators/util/detection_classifications_merger_calculator_test.cc create mode 100644 mediapipe/calculators/util/visibility_copy_calculator.cc create mode 100644 mediapipe/calculators/util/visibility_copy_calculator.proto create mode 100644 mediapipe/calculators/util/visibility_smoothing_calculator.cc create mode 100644 mediapipe/calculators/util/visibility_smoothing_calculator.proto create mode 100644 mediapipe/calculators/util/world_landmark_projection_calculator.cc create mode 100644 mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h create mode 100644 mediapipe/framework/graph_service_manager.cc create mode 100644 mediapipe/framework/graph_service_manager.h create mode 100644 mediapipe/framework/graph_service_manager_test.cc create mode 100644 mediapipe/framework/validated_graph_config_test.cc rename mediapipe/{framework/port/commandlineflags.h => java/com/google/mediapipe/framework/PacketListCallback.java} (60%) create mode 100644 mediapipe/util/filtering/one_euro_filter.cc create mode 100644 mediapipe/util/filtering/one_euro_filter.h create mode 100644 mediapipe/util/resource_util_custom.h create mode 100644 mediapipe/util/resource_util_default.cc create mode 100644 mediapipe/util/resource_util_internal.h diff --git 
a/mediapipe/opensource_only/.bazelversion b/.bazelversion similarity index 100% rename from mediapipe/opensource_only/.bazelversion rename to .bazelversion diff --git a/Dockerfile b/Dockerfile index dc3b034a2..1b46ccdc4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -54,7 +54,7 @@ RUN pip3 install tf_slim RUN ln -s /usr/bin/python3 /usr/bin/python # Install bazel -ARG BAZEL_VERSION=3.4.1 +ARG BAZEL_VERSION=3.7.2 RUN mkdir /bazel && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ diff --git a/MANIFEST.in b/MANIFEST.in index 1994721f3..ba8014db8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -10,3 +10,7 @@ include requirements.txt recursive-include mediapipe/modules *.tflite *.txt *.binarypb exclude mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite exclude mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite +exclude mediapipe/modules/objectron/object_detection_3d_sneakers.tflite +exclude mediapipe/modules/objectron/object_detection_3d_chair.tflite +exclude mediapipe/modules/objectron/object_detection_3d_camera.tflite +exclude mediapipe/modules/objectron/object_detection_3d_cup.tflite diff --git a/README.md b/README.md index 06fa39b5e..8c75978a4 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Hair Segmentation [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | 
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | diff --git a/WORKSPACE b/WORKSPACE index 32b466e6c..3932b19c5 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,16 +2,19 @@ workspace(name = "mediapipe") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -skylib_version = "0.9.0" http_archive( name = "bazel_skylib", type = "tar.gz", - url = "https://github.com/bazelbuild/bazel-skylib/releases/download/{}/bazel_skylib-{}.tar.gz".format (skylib_version, skylib_version), - sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0", + urls = [ + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz", + ], + sha256 = "1c531376ac7e5a180e0237938a2536de0c54d93f5c278634818e0efc952dd56c", ) +load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") +bazel_skylib_workspace() load("@bazel_skylib//lib:versions.bzl", "versions") -versions.check(minimum_bazel_version = "3.4.0") - +versions.check(minimum_bazel_version = "3.7.2") # ABSL cpp library lts_2020_09_23 http_archive( @@ -38,8 +41,8 @@ http_archive( http_archive( name = "rules_foreign_cc", - strip_prefix = "rules_foreign_cc-main", - url = "https://github.com/bazelbuild/rules_foreign_cc/archive/main.zip", + strip_prefix = "rules_foreign_cc-0.1.0", + url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip", ) load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies") @@ -304,8 +307,8 @@ http_archive( # Maven dependencies. 
-RULES_JVM_EXTERNAL_TAG = "3.2" -RULES_JVM_EXTERNAL_SHA = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af" +RULES_JVM_EXTERNAL_TAG = "4.0" +RULES_JVM_EXTERNAL_SHA = "31701ad93dbfe544d597dbe62c9a1fdd76d81d8a9150c2bf1ecf928ecdf97169" http_archive( name = "rules_jvm_external", @@ -318,7 +321,6 @@ load("@rules_jvm_external//:defs.bzl", "maven_install") # Important: there can only be one maven_install rule. Add new maven deps here. maven_install( - name = "maven", artifacts = [ "androidx.concurrent:concurrent-futures:1.0.0-alpha03", "androidx.lifecycle:lifecycle-common:2.2.0", @@ -343,10 +345,10 @@ maven_install( "org.hamcrest:hamcrest-library:1.3", ], repositories = [ - "https://jcenter.bintray.com", "https://maven.google.com", "https://dl.google.com/dl/android/maven2", "https://repo1.maven.org/maven2", + "https://jcenter.bintray.com", ], fetch_sources = True, version_conflict_policy = "pinned", @@ -363,10 +365,10 @@ http_archive( ], ) -#Tensorflow repo should always go after the other external dependencies. -# 2020-12-09 -_TENSORFLOW_GIT_COMMIT = "0eadbb13cef1226b1bae17c941f7870734d97f8a" -_TENSORFLOW_SHA256= "4ae06daa5b09c62f31b7bc1f781fd59053f286dd64355830d8c2ac601b795ef0" +# Tensorflow repo should always go after the other external dependencies. 
+# 2021-03-25 +_TENSORFLOW_GIT_COMMIT = "c67f68021824410ebe9f18513b8856ac1c6d4887" +_TENSORFLOW_SHA256= "fd07d0b39422dc435e268c5e53b2646a8b4b1e3151b87837b43f86068faae87f" http_archive( name = "org_tensorflow", urls = [ @@ -383,5 +385,7 @@ http_archive( sha256 = _TENSORFLOW_SHA256, ) -load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") -tf_workspace(tf_repo_name = "org_tensorflow") +load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") +tf_workspace3() +load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") +tf_workspace2() diff --git a/docs/framework_concepts/packets.md b/docs/framework_concepts/packets.md index bdf11c69f..bf5cd4ea6 100644 --- a/docs/framework_concepts/packets.md +++ b/docs/framework_concepts/packets.md @@ -12,19 +12,30 @@ nav_order: 3 {:toc} --- -Each calculator is a node of of a graph. We describe how to create a new calculator, how to initialize a calculator, how to perform its calculations, input and output streams, timestamps, and options +Calculators communicate by sending and receiving packets. Typically a single +packet is sent along each input stream at each input timestamp. A packet can +contain any kind of data, such as a single frame of video or a single integer +detection count. ## Creating a packet -Packets are generally created with `MediaPipe::Adopt()` (from packet.h). +Packets are generally created with `mediapipe::MakePacket()` or +`mediapipe::Adopt()` (from packet.h). ```c++ -// Create some data. -auto data = absl::make_unique("constructor_argument"); -// Create a packet to own the data. -Packet p = Adopt(data.release()); +// Create a packet containing some new data. +Packet p = MakePacket("constructor_argument"); // Make a new packet with the same data and a different timestamp. Packet p2 = p.At(Timestamp::PostStream()); ``` +or: + +```c++ +// Create some new data. +auto data = absl::make_unique("constructor_argument"); +// Create a packet to own the data. 
+Packet p = Adopt(data.release()).At(Timestamp::PostStream()); +``` + Data within a packet is accessed with `Packet::Get()` diff --git a/docs/getting_started/android.md b/docs/getting_started/android.md index 855f5fa29..ee83116dd 100644 --- a/docs/getting_started/android.md +++ b/docs/getting_started/android.md @@ -28,7 +28,7 @@ Gradle. * Install MediaPipe following these [instructions](./install.md). * Setup Java Runtime. * Setup Android SDK release 28.0.3 and above. -* Setup Android NDK r18b and above. +* Setup Android NDK version between 18 and 21. MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see below for Android Studio setup). However, if you prefer using MediaPipe without diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index 7a02def53..c0a240ae8 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -25,25 +25,11 @@ install --user six`. ## Installing on Debian and Ubuntu -1. Install Bazel. +1. Install Bazelisk. Follow the official - [Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html) - to install Bazel 3.4 or higher. - - For Nvidia Jetson and Raspberry Pi devices with aarch64 Linux, Bazel needs - to be built from source: - - ```bash - # For Bazel 3.4.1 - mkdir $HOME/bazel-3.4.1 - cd $HOME/bazel-3.4.1 - wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-dist.zip - sudo apt-get install build-essential openjdk-8-jdk python zip unzip - unzip bazel-3.4.1-dist.zip - env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh - sudo cp output/bazel /usr/local/bin/ - ``` + [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) + to install Bazelisk. 2. Checkout MediaPipe repository. @@ -207,11 +193,11 @@ build issues. **Disclaimer**: Running MediaPipe on CentOS is experimental. -1. Install Bazel. +1. Install Bazelisk. 
Follow the official - [Bazel documentation](https://docs.bazel.build/versions/master/install-redhat.html) - to install Bazel 3.4 or higher. + [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) + to install Bazelisk. 2. Checkout MediaPipe repository. @@ -336,11 +322,11 @@ build issues. * Install [Xcode](https://developer.apple.com/xcode/) and its Command Line Tools by `xcode-select --install`. -2. Install Bazel. +2. Install Bazelisk. Follow the official - [Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x) - to install Bazel 3.4 or higher. + [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) + to install Bazelisk. 3. Checkout MediaPipe repository. @@ -353,7 +339,7 @@ build issues. 4. Install OpenCV and FFmpeg. Option 1. Use HomeBrew package manager tool to install the pre-compiled - OpenCV 3.4.5 libraries. FFmpeg will be installed via OpenCV. + OpenCV 3 libraries. FFmpeg will be installed via OpenCV. ```bash $ brew install opencv@3 @@ -484,29 +470,36 @@ next section. 4. Install Visual C++ Build Tools 2019 and WinSDK - Go to https://visualstudio.microsoft.com/visual-cpp-build-tools, download - build tools, and install Microsoft Visual C++ 2019 Redistributable and - Microsoft Build Tools 2019. + Go to + [the Visual Studio website](https://visualstudio.microsoft.com/visual-cpp-build-tools), + download build tools, and install Microsoft Visual C++ 2019 Redistributable + and Microsoft Build Tools 2019. Download the WinSDK from - https://developer.microsoft.com/en-us/windows/downloads/windows-10-sdk/ and - install. + [the official Microsoft website](https://developer.microsoft.com/en-us/windows/downloads/windows-10-sdk/) + and install. -5. Install Bazel and add the location of the Bazel executable to the `%PATH%` - environment variable. +5. 
Install Bazel or Bazelisk and add the location of the Bazel executable to + the `%PATH%` environment variable. - Follow the official - [Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html) - to install Bazel 3.4 or higher. + Option 1. Follow + [the official Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html) + to install Bazel 3.7.2 or higher. -6. Set Bazel variables. + Option 2. Follow the official + [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) + to install Bazelisk. + +6. Set Bazel variables. Learn more details about + ["Build on Windows"](https://docs.bazel.build/versions/master/windows.html#build-c-with-msvc) + in the Bazel official documentation. ``` - # Find the exact paths and version numbers from your local version. + # Please find the exact paths and version numbers from your local version. C:\> set BAZEL_VS=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools C:\> set BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC - C:\> set BAZEL_VC_FULL_VERSION=14.25.28610 - C:\> set BAZEL_WINSDK_FULL_VERSION=10.1.18362.1 + C:\> set BAZEL_VC_FULL_VERSION=<Your local VC version> + C:\> set BAZEL_WINSDK_FULL_VERSION=<Your local WinSDK version> ``` 7. Checkout MediaPipe repository. @@ -593,19 +586,11 @@ cameras. Alternatively, you use a video file as input. username@DESKTOP-TMVLBJ1:~$ sudo apt-get update && sudo apt-get install -y build-essential git python zip adb openjdk-8-jdk ``` -5. Install Bazel. +5. Install Bazelisk. 
- ```bash - username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \ - https://storage.googleapis.com/bazel/3.4.1/release/bazel-3.4.1-installer-linux-x86_64.sh && \ - sudo mkdir -p /usr/local/bazel/3.4.1 && \ - chmod 755 bazel-3.4.1-installer-linux-x86_64.sh && \ - sudo ./bazel-3.4.1-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.1 && \ - source /usr/local/bazel/3.4.1/lib/bazel/bin/bazel-complete.bash - - username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.1/lib/bazel/bin/bazel version && \ - alias bazel='/usr/local/bazel/3.4.1/lib/bazel/bin/bazel' - ``` + Follow the official + [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) + to install Bazelisk. 6. Checkout MediaPipe repository. diff --git a/docs/index.md b/docs/index.md index d3db8892d..9035bf106 100644 --- a/docs/index.md +++ b/docs/index.md @@ -44,7 +44,7 @@ Hair Segmentation [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index f04af27d7..de2f5d4a5 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -183,8 +183,8 @@ function onResults(results) { canvasCtx.restore(); } -const faceDetection = new Objectron({locateFile: (file) => { - return 
`https://cdn.jsdelivr.net/npm/@mediapipe/objectron@0.0/${file}`; +const faceDetection = new FaceDetection({locateFile: (file) => { + return `https://cdn.jsdelivr.net/npm/@mediapipe/face_detection@0.0/${file}`; }}); faceDetection.setOptions({ minDetectionConfidence: 0.5 diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index c689f9c40..0164e23b3 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -358,15 +358,17 @@ cap.release() ## Example Apps Please first see general instructions for -[Android](../getting_started/android.md) and [iOS](../getting_started/ios.md) on -how to build MediaPipe examples. +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md), and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see [visualizer documentation](../tools/visualizer.md). 
-### Two-stage Objectron +### Mobile + +#### Two-stage Objectron * Graph: [`mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt) @@ -404,7 +406,7 @@ to visualize its associated subgraphs, please see * iOS target: Not available -### Single-stage Objectron +#### Single-stage Objectron * Graph: [`mediapipe/graphs/object_detection_3d/object_occlusion_tracking_1stage.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt) @@ -428,7 +430,7 @@ to visualize its associated subgraphs, please see * iOS target: Not available -### Assets +#### Assets Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc) using a parsing of the sequenced .obj file format into a custom .uuu format. This can be done for user assets as follows: @@ -449,9 +451,35 @@ Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](http > single .uuu animation file, using the order given by sorting the filenames alphanumerically. Also the ObjParser directory inputs must be given as > absolute paths, not relative paths. See parser utility library at [`mediapipe/graphs/object_detection_3d/obj_parser/`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/obj_parser/) for more details. 
-### Coordinate Systems -#### Object Coordinate +### Desktop + +To build the application, run: + +```bash +bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/object_detection_3d:objectron_cpu +``` + +To run the application, replace `<input video path>` and `<output video path>` +in the command below with your own paths, and `<landmark model path>` and +`<allowed labels>` with the following: + +Category | `<landmark model path>` | `<allowed labels>` +:------- | :-------------------------------------------------------------------------- | :----------------- +Shoe | mediapipe/modules/objectron/object_detection_3d_sneakers.tflite | Footwear +Chair | mediapipe/modules/objectron/object_detection_3d_chair.tflite | Chair +Cup | mediapipe/modules/objectron/object_detection_3d_cup.tflite | Mug +Camera | mediapipe/modules/objectron/object_detection_3d_camera.tflite | Camera + +``` +GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection_3d/objectron_cpu \ + --calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \ + --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>,box_landmark_model_path=<landmark model path>,allowed_labels=<allowed labels> +``` + +## Coordinate Systems + +### Object Coordinate Each object has its object coordinate frame. We use the below object coordinate definition, with `+x` pointing right, `+y` pointing up and `+z` pointing front, @@ -459,7 +487,7 @@ origin is at the center of the 3D bounding box. ![box_coordinate.svg](../images/box_coordinate.svg) -#### Camera Coordinate +### Camera Coordinate A 3D object is parameterized by its `scale` and `rotation`, `translation` with regard to the camera coordinate frame. 
In this API we use the below camera @@ -476,7 +504,7 @@ camera frame by applying `rotation` and `translation`: landmarks_3d = rotation * scale * unit_box + translation ``` -#### NDC Space +### NDC Space In this API we use [NDC(normalized device coordinates)](http://www.songho.ca/opengl/gl_projectionmatrix.html) @@ -495,7 +523,7 @@ y_ndc = -fy * Y / Z + py z_ndc = 1 / Z ``` -#### Pixel Space +### Pixel Space In this API we set upper-left coner of an image as the origin of pixel coordinate. One can convert from NDC to pixel space as follows: @@ -532,10 +560,11 @@ py = -py_pixel * 2.0 / image_height + 1.0 [Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html) * Google AI Blog: [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html) +* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in CVPR 2021 * Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak Shape Supervision](https://arxiv.org/abs/2003.03522) * Paper: [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8) - ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)) + ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth Workshop on Computer Vision for AR/VR, CVPR 2020 * [Models and model cards](./models.md#objectron) * [Python Colab](https://mediapipe.page.link/objectron_py_colab) diff --git a/docs/solutions/pose_classification.md b/docs/solutions/pose_classification.md index 9595dc7d1..21f87a95d 100644 --- a/docs/solutions/pose_classification.md +++ b/docs/solutions/pose_classification.md @@ -25,10 +25,11 @@ One of the applications [BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) can enable is fitness. 
More specifically - pose classification and repetition counting. In this section we'll provide basic guidance on building a custom pose -classifier with the help of [Colabs](#colabs) and wrap it in a simple -[fitness app](https://mediapipe.page.link/mlkit-pose-classification-demo-app) -powered by [ML Kit](https://developers.google.com/ml-kit). Push-ups and squats -are used for demonstration purposes as the most common exercises. +classifier with the help of [Colabs](#colabs) and wrap it in a simple fitness +demo within +[ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app). +Push-ups and squats are used for demonstration purposes as the most common +exercises. ![pose_classification_pushups_and_squats.gif](../images/mobile/pose_classification_pushups_and_squats.gif) | :--------------------------------------------------------------------------------------------------------: | @@ -47,7 +48,7 @@ determines the object's class based on the closest samples in the training set. classifier and form a training set using these [Colabs](#colabs), 3. Perform the classification itself followed by repetition counting (e.g., in the - [ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app)). + [ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app)). ## Training Set @@ -76,7 +77,7 @@ video right in the Colab. Code of the classifier is available both in the [`Pose Classification Colab (Extended)`] and in the -[ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app). +[ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app). Please refer to them for details of the approach described below. 
The k-NN algorithm used for pose classification requires a feature vector @@ -127,11 +128,13 @@ where the pose class and the counter can't be changed. ## Future Work -We are actively working on improving BlazePose GHUM 3D's Z prediction. It will -allow us to use joint angles in the feature vectors, which are more natural and -easier to configure (although distances can still be useful to detect touches -between body parts) and to perform rotation normalization of poses and reduce -the number of camera angles required for accurate k-NN classification. +We are actively working on improving +[BlazePose GHUM 3D](./pose.md#pose-landmark-model-blazepose-ghum-3d)'s Z +prediction. It will allow us to use joint angles in the feature vectors, which +are more natural and easier to configure (although distances can still be useful +to detect touches between body parts) and to perform rotation normalization of +poses and reduce the number of camera angles required for accurate k-NN +classification. ## Colabs diff --git a/docs/solutions/solutions.md b/docs/solutions/solutions.md index c78dffea0..a95f0c032 100644 --- a/docs/solutions/solutions.md +++ b/docs/solutions/solutions.md @@ -28,7 +28,7 @@ has_toc: false [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | diff --git a/docs/tools/tracing_and_profiling.md 
b/docs/tools/tracing_and_profiling.md index ed58eb61b..30b4bd993 100644 --- a/docs/tools/tracing_and_profiling.md +++ b/docs/tools/tracing_and_profiling.md @@ -41,6 +41,7 @@ profiler_config { trace_enabled: true enable_profiler: true trace_log_interval_count: 200 + trace_log_path: "/sdcard/Download/" } ``` @@ -64,7 +65,7 @@ MediaPipe will emit data into a pre-specified directory: * On the desktop, this will be the `/tmp` directory. -* On Android, this will be the `/sdcard` directory. +* On Android, this will be the external storage directory (e.g., `/storage/emulated/0/`). * On iOS, this can be reached through XCode. Select "Window/Devices and Simulators" and select the "Devices" tab. @@ -103,7 +104,7 @@ we record ten intervals of half a second each. This can be overridden by adding * Include the line below in your `AndroidManifest.xml` file. ```xml - + ``` * Grant the permission either upon first app launch, or by going into @@ -130,8 +131,8 @@ we record ten intervals of half a second each. This can be overridden by adding events to a trace log files at: ```bash - /sdcard/mediapipe_trace_0.binarypb - /sdcard/mediapipe_trace_1.binarypb + /storage/emulated/0/Download/mediapipe_trace_0.binarypb + /storage/emulated/0/Download/mediapipe_trace_1.binarypb ``` After every 5 sec, writing shifts to a successive trace log file, such that @@ -139,10 +140,10 @@ we record ten intervals of half a second each. This can be overridden by adding trace files have been written to the device using adb shell. ```bash - adb shell "ls -la /sdcard/" + adb shell "ls -la /storage/emulated/0/Download" ``` - On android, MediaPipe selects the external storage directory `/sdcard` for + On android, MediaPipe selects the external storage (e.g., `/storage/emulated/0/`) for trace logs. This directory can be overridden using the setting `trace_log_path`, like: @@ -150,7 +151,7 @@ we record ten intervals of half a second each. 
This can be overridden by adding profiler_config { trace_enabled: true enable_profiler: true - trace_log_path: "/sdcard/profiles/" + trace_log_path: "/sdcard/Download/profiles/" } ``` @@ -161,7 +162,7 @@ we record ten intervals of half a second each. This can be overridden by adding ```bash # from your terminal - adb pull /sdcard/mediapipe_trace_0.binarypb + adb pull /storage/emulated/0/Download/mediapipe_trace_0.binarypb # if successful you should see something like # /sdcard/mediapipe_trace_0.binarypb: 1 file pulled. 0.1 MB/s (6766 bytes in 0.045s) ``` diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index 9667e11d5..ed6a509dc 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -128,7 +128,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/util:time_series_util", "@com_google_absl//absl/strings", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -147,7 +147,7 @@ cc_library( "//mediapipe/util:time_series_util", "@com_google_absl//absl/strings", "@com_google_audio_tools//audio/dsp/mfcc", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -168,7 +168,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_audio_tools//audio/dsp:resampler", "@com_google_audio_tools//audio/dsp:resampler_q", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -208,7 +208,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_audio_tools//audio/dsp:window_functions", "@com_google_audio_tools//audio/dsp/spectrogram", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -228,7 +228,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/util:time_series_util", "@com_google_audio_tools//audio/dsp:window_functions", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -242,9 +242,9 @@ cc_test( 
"//mediapipe/framework:calculator_runner", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:time_series_header_cc_proto", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/flags:flag", ], ) @@ -261,7 +261,7 @@ cc_test( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/util:time_series_test_util", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -276,7 +276,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/util:time_series_test_util", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -296,7 +296,7 @@ cc_test( "//mediapipe/framework/port:status", "//mediapipe/util:time_series_test_util", "@com_google_audio_tools//audio/dsp:number_util", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -314,7 +314,7 @@ cc_test( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:status", "//mediapipe/util:time_series_test_util", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -333,7 +333,7 @@ cc_test( "//mediapipe/framework/port:status", "//mediapipe/util:time_series_test_util", "@com_google_audio_tools//audio/dsp:window_functions", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -352,6 +352,6 @@ cc_test( "//mediapipe/framework/tool:validate_type", "//mediapipe/util:time_series_test_util", "@com_google_audio_tools//audio/dsp:signal_vector_util", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) diff --git a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc index 33ab9e04f..8107a56f9 100644 --- a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc +++ b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc @@ -12,10 +12,10 @@ // See the License 
for the specific language governing permissions and // limitations under the License. +#include "absl/flags/flag.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/time_series_header.pb.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 61d402f74..f319aef5b 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -414,7 +414,7 @@ cc_library( "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -430,7 +430,7 @@ cc_library( "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -450,6 +450,20 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "nonzero_calculator", + srcs = ["nonzero_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + cc_test( name = "mux_calculator_test", srcs = ["mux_calculator_test.cc"], @@ -776,7 +790,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:validate_type", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -793,7 +807,7 @@ cc_test( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:validate_type", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -1024,7 +1038,7 @@ cc_library( 
"//mediapipe/framework/tool:status_util", "//mediapipe/util:time_series_util", "@com_google_absl//absl/memory", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc index 4fbfced96..eba621ce3 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -57,7 +57,7 @@ namespace mediapipe { // // The "ALLOW" stream indicates the transition between accepting frames and // dropping frames. "ALLOW = true" indicates the start of accepting frames -// including the current timestamp, and "ALLOW = true" indicates the start of +// including the current timestamp, and "ALLOW = false" indicates the start of // dropping frames including the current timestamp. // // FlowLimiterCalculator provides limited support for multiple input streams. diff --git a/mediapipe/calculators/core/nonzero_calculator.cc b/mediapipe/calculators/core/nonzero_calculator.cc new file mode 100644 index 000000000..9a5928231 --- /dev/null +++ b/mediapipe/calculators/core/nonzero_calculator.cc @@ -0,0 +1,42 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { +namespace api2 { + +// A Calculator that returns 0 if INPUT is 0, and 1 otherwise. +class NonZeroCalculator : public Node { + public: + static constexpr Input::SideFallback kIn{"INPUT"}; + static constexpr Output kOut{"OUTPUT"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Process(CalculatorContext* cc) final { + if (!kIn(cc).IsEmpty()) { + auto output = std::make_unique((*kIn(cc) != 0) ? 1 : 0); + kOut(cc).Send(std::move(output)); + } + return absl::OkStatus(); + } +}; + +MEDIAPIPE_REGISTER_NODE(NonZeroCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc index 32b1c850a..43253520a 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator.cc @@ -87,7 +87,6 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { flush_last_packet_ = resampler_options.flush_last_packet(); jitter_ = resampler_options.jitter(); - jitter_with_reflection_ = resampler_options.jitter_with_reflection(); input_data_id_ = cc->Inputs().GetId("DATA", 0); if (!input_data_id_.IsValid()) { @@ -98,11 +97,7 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { output_data_id_ = cc->Outputs().GetId("", 0); } - period_count_ = 0; frame_rate_ = resampler_options.frame_rate(); - base_timestamp_ = resampler_options.has_base_timestamp() - ? Timestamp(resampler_options.base_timestamp()) - : Timestamp::Unset(); start_time_ = resampler_options.has_start_time() ? 
Timestamp(resampler_options.start_time()) : Timestamp::Min(); @@ -141,30 +136,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { } } - if (jitter_ != 0.0) { - if (resampler_options.output_header() != - PacketResamplerCalculatorOptions::NONE) { - LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not " - "the actual value."; - } - if (flush_last_packet_) { - flush_last_packet_ = false; - LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is " - "ignored, because we are adding jitter."; - } - const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); - random_ = CreateSecureRandom(seed); - if (random_ == nullptr) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - "SecureRandom is not available. With \"jitter\" specified, " - "PacketResamplerCalculator processing cannot proceed."); - } - packet_reservoir_random_ = CreateSecureRandom(seed); - } - packet_reservoir_ = - std::make_unique(packet_reservoir_random_.get()); - return absl::OkStatus(); + strategy_ = GetSamplingStrategy(resampler_options); + + return strategy_->Open(cc); } absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { @@ -177,171 +151,13 @@ absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } } - if (jitter_ != 0.0 && random_ != nullptr) { - // Packet reservior is used to make sure there's an output for every period, - // e.g. partial period at the end of the stream. 
- if (packet_reservoir_->IsEnabled() && - (first_timestamp_ == Timestamp::Unset() || - (cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) { - auto curr_packet = cc->Inputs().Get(input_data_id_).Value(); - packet_reservoir_->AddSample(curr_packet); - } - MP_RETURN_IF_ERROR(ProcessWithJitter(cc)); - } else { - MP_RETURN_IF_ERROR(ProcessWithoutJitter(cc)); + + if (absl::Status status = strategy_->Process(cc); !status.ok()) { + return status; // Avoid MP_RETURN_IF_ERROR macro for external release. } + last_packet_ = cc->Inputs().Get(input_data_id_).Value(); - return absl::OkStatus(); -} -void PacketResamplerCalculator::InitializeNextOutputTimestampWithJitter() { - next_output_timestamp_min_ = first_timestamp_; - if (jitter_with_reflection_) { - next_output_timestamp_ = - first_timestamp_ + random_->UnbiasedUniform64(frame_time_usec_); - return; - } - next_output_timestamp_ = - first_timestamp_ + frame_time_usec_ * random_->RandFloat(); -} - -void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() { - packet_reservoir_->Clear(); - if (jitter_with_reflection_) { - next_output_timestamp_min_ += frame_time_usec_; - Timestamp next_output_timestamp_max_ = - next_output_timestamp_min_ + frame_time_usec_; - - next_output_timestamp_ += frame_time_usec_ + - random_->UnbiasedUniform64(2 * jitter_usec_ + 1) - - jitter_usec_; - next_output_timestamp_ = Timestamp(ReflectBetween( - next_output_timestamp_.Value(), next_output_timestamp_min_.Value(), - next_output_timestamp_max_.Value())); - CHECK_GE(next_output_timestamp_, next_output_timestamp_min_); - CHECK_LT(next_output_timestamp_, next_output_timestamp_max_); - return; - } - packet_reservoir_->Disable(); - next_output_timestamp_ += - frame_time_usec_ * - ((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat()); -} - -absl::Status PacketResamplerCalculator::ProcessWithJitter( - CalculatorContext* cc) { - RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); - RET_CHECK_NE(jitter_, 0.0); - - if 
(first_timestamp_ == Timestamp::Unset()) { - first_timestamp_ = cc->InputTimestamp(); - InitializeNextOutputTimestampWithJitter(); - if (first_timestamp_ == next_output_timestamp_) { - OutputWithinLimits( - cc, - cc->Inputs().Get(input_data_id_).Value().At(next_output_timestamp_)); - UpdateNextOutputTimestampWithJitter(); - } - return absl::OkStatus(); - } - - if (frame_time_usec_ < - (cc->InputTimestamp() - last_packet_.Timestamp()).Value()) { - LOG_FIRST_N(WARNING, 2) - << "Adding jitter is not very useful when upsampling."; - } - - while (true) { - const int64 last_diff = - (next_output_timestamp_ - last_packet_.Timestamp()).Value(); - RET_CHECK_GT(last_diff, 0); - const int64 curr_diff = - (next_output_timestamp_ - cc->InputTimestamp()).Value(); - if (curr_diff > 0) { - break; - } - OutputWithinLimits(cc, (std::abs(curr_diff) > last_diff - ? last_packet_ - : cc->Inputs().Get(input_data_id_).Value()) - .At(next_output_timestamp_)); - UpdateNextOutputTimestampWithJitter(); - // From now on every time a packet is emitted the timestamp of the next - // packet becomes known; that timestamp is stored in next_output_timestamp_. - // The only exception to this rule is the packet emitted from Close() which - // can only happen when jitter_with_reflection is enabled but in this case - // next_output_timestamp_min_ is a non-decreasing lower bound of any - // subsequent packet. - const Timestamp timestamp_bound = jitter_with_reflection_ - ? next_output_timestamp_min_ - : next_output_timestamp_; - cc->Outputs().Get(output_data_id_).SetNextTimestampBound(timestamp_bound); - } - return absl::OkStatus(); -} - -absl::Status PacketResamplerCalculator::ProcessWithoutJitter( - CalculatorContext* cc) { - RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); - RET_CHECK_EQ(jitter_, 0.0); - - if (first_timestamp_ == Timestamp::Unset()) { - // This is the first packet, initialize the first_timestamp_. 
- if (base_timestamp_ == Timestamp::Unset()) { - // Initialize first_timestamp_ with exactly the first packet timestamp. - first_timestamp_ = cc->InputTimestamp(); - } else { - // Initialize first_timestamp_ with the first packet timestamp - // aligned to the base_timestamp_. - int64 first_index = MathUtil::SafeRound( - (cc->InputTimestamp() - base_timestamp_).Seconds() * frame_rate_); - first_timestamp_ = - base_timestamp_ + TimestampDiffFromSeconds(first_index / frame_rate_); - } - if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) { - cc->Outputs() - .Tag("VIDEO_HEADER") - .Add(new VideoHeader(video_header_), Timestamp::PreStream()); - } - } - const Timestamp received_timestamp = cc->InputTimestamp(); - const int64 received_timestamp_idx = - TimestampToPeriodIndex(received_timestamp); - // Only consider the received packet if it belongs to the current period - // (== period_count_) or to a newer one (> period_count_). - if (received_timestamp_idx >= period_count_) { - // Fill the empty periods until we are in the same index as the received - // packet. - while (received_timestamp_idx > period_count_) { - OutputWithinLimits( - cc, last_packet_.At(PeriodIndexToTimestamp(period_count_))); - ++period_count_; - } - // Now, if the received packet has a timestamp larger than the middle of - // the current period, we can send a packet without waiting. We send the - // one closer to the middle. 
- Timestamp target_timestamp = PeriodIndexToTimestamp(period_count_); - if (received_timestamp >= target_timestamp) { - bool have_last_packet = (last_packet_.Timestamp() != Timestamp::Unset()); - bool send_current = - !have_last_packet || (received_timestamp - target_timestamp <= - target_timestamp - last_packet_.Timestamp()); - if (send_current) { - OutputWithinLimits( - cc, cc->Inputs().Get(input_data_id_).Value().At(target_timestamp)); - } else { - OutputWithinLimits(cc, last_packet_.At(target_timestamp)); - } - ++period_count_; - } - // TODO: Add a mechanism to the framework to allow these packets - // to be output earlier (without waiting for a much later packet to - // arrive) - - // Update the bound for the next packet. - cc->Outputs() - .Get(output_data_id_) - .SetNextTimestampBound(PeriodIndexToTimestamp(period_count_)); - } return absl::OkStatus(); } @@ -349,17 +165,34 @@ absl::Status PacketResamplerCalculator::Close(CalculatorContext* cc) { if (!cc->GraphStatus().ok()) { return absl::OkStatus(); } - // Emit the last packet received if we have at least one packet, but - // haven't sent anything for its period. - if (first_timestamp_ != Timestamp::Unset() && flush_last_packet_ && - TimestampToPeriodIndex(last_packet_.Timestamp()) == period_count_) { - OutputWithinLimits(cc, - last_packet_.At(PeriodIndexToTimestamp(period_count_))); + + return strategy_->Close(cc); +} + +std::unique_ptr +PacketResamplerCalculator::GetSamplingStrategy( + const PacketResamplerCalculatorOptions& options) { + if (options.reproducible_sampling()) { + if (!options.jitter_with_reflection()) { + LOG(WARNING) + << "reproducible_sampling enabled w/ jitter_with_reflection " + "disabled. 
" + << "reproducible_sampling always uses jitter with reflection, " + << "Ignoring jitter_with_reflection setting."; + } + return absl::make_unique(this); } - if (!packet_reservoir_->IsEmpty()) { - OutputWithinLimits(cc, packet_reservoir_->GetSample()); + + if (options.jitter() == 0) { + return absl::make_unique(this); } - return absl::OkStatus(); + + if (options.jitter_with_reflection()) { + return absl::make_unique(this); + } + + // With jitter and no reflection. + return absl::make_unique(this); } Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const { @@ -385,4 +218,479 @@ void PacketResamplerCalculator::OutputWithinLimits(CalculatorContext* cc, } } +absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) { + const auto resampler_options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); + + if (resampler_options.output_header() != + PacketResamplerCalculatorOptions::NONE) { + LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not " + "the actual value."; + } + + if (calculator_->flush_last_packet_) { + LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is " + "ignored, because we are adding jitter."; + } + + const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + random_ = CreateSecureRandom(seed); + if (random_ == nullptr) { + return absl::InvalidArgumentError( + "SecureRandom is not available. 
With \"jitter\" specified, " + "PacketResamplerCalculator processing cannot proceed."); + } + + packet_reservoir_random_ = CreateSecureRandom(seed); + packet_reservoir_ = + std::make_unique(packet_reservoir_random_.get()); + + return absl::OkStatus(); +} +absl::Status LegacyJitterWithReflectionStrategy::Close(CalculatorContext* cc) { + if (!packet_reservoir_->IsEmpty()) { + LOG(INFO) << "Emitting pack from reservoir."; + calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample()); + } + return absl::OkStatus(); +} +absl::Status LegacyJitterWithReflectionStrategy::Process( + CalculatorContext* cc) { + RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); + + if (packet_reservoir_->IsEnabled() && + (first_timestamp_ == Timestamp::Unset() || + (cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) { + auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value(); + packet_reservoir_->AddSample(curr_packet); + } + + if (first_timestamp_ == Timestamp::Unset()) { + first_timestamp_ = cc->InputTimestamp(); + InitializeNextOutputTimestampWithJitter(); + if (first_timestamp_ == next_output_timestamp_) { + calculator_->OutputWithinLimits(cc, cc->Inputs() + .Get(calculator_->input_data_id_) + .Value() + .At(next_output_timestamp_)); + UpdateNextOutputTimestampWithJitter(); + } + return absl::OkStatus(); + } + + if (calculator_->frame_time_usec_ < + (cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) { + LOG_FIRST_N(WARNING, 2) + << "Adding jitter is not very useful when upsampling."; + } + + while (true) { + const int64 last_diff = + (next_output_timestamp_ - calculator_->last_packet_.Timestamp()) + .Value(); + RET_CHECK_GT(last_diff, 0); + const int64 curr_diff = + (next_output_timestamp_ - cc->InputTimestamp()).Value(); + if (curr_diff > 0) { + break; + } + calculator_->OutputWithinLimits( + cc, (std::abs(curr_diff) > last_diff + ? 
calculator_->last_packet_ + : cc->Inputs().Get(calculator_->input_data_id_).Value()) + .At(next_output_timestamp_)); + UpdateNextOutputTimestampWithJitter(); + // From now on every time a packet is emitted the timestamp of the next + // packet becomes known; that timestamp is stored in next_output_timestamp_. + // The only exception to this rule is the packet emitted from Close() which + // can only happen when jitter_with_reflection is enabled but in this case + // next_output_timestamp_min_ is a non-decreasing lower bound of any + // subsequent packet. + const Timestamp timestamp_bound = next_output_timestamp_min_; + cc->Outputs() + .Get(calculator_->output_data_id_) + .SetNextTimestampBound(timestamp_bound); + } + return absl::OkStatus(); +} + +void LegacyJitterWithReflectionStrategy:: + InitializeNextOutputTimestampWithJitter() { + next_output_timestamp_min_ = first_timestamp_; + next_output_timestamp_ = + first_timestamp_ + + random_->UnbiasedUniform64(calculator_->frame_time_usec_); +} + +void LegacyJitterWithReflectionStrategy::UpdateNextOutputTimestampWithJitter() { + packet_reservoir_->Clear(); + next_output_timestamp_min_ += calculator_->frame_time_usec_; + Timestamp next_output_timestamp_max_ = + next_output_timestamp_min_ + calculator_->frame_time_usec_; + + next_output_timestamp_ += + calculator_->frame_time_usec_ + + random_->UnbiasedUniform64(2 * calculator_->jitter_usec_ + 1) - + calculator_->jitter_usec_; + next_output_timestamp_ = Timestamp(ReflectBetween( + next_output_timestamp_.Value(), next_output_timestamp_min_.Value(), + next_output_timestamp_max_.Value())); + CHECK_GE(next_output_timestamp_, next_output_timestamp_min_); + CHECK_LT(next_output_timestamp_, next_output_timestamp_max_); +} + +absl::Status ReproducibleJitterWithReflectionStrategy::Open( + CalculatorContext* cc) { + const auto resampler_options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); + + if (resampler_options.output_header() != + 
PacketResamplerCalculatorOptions::NONE) { + LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not " + "the actual value."; + } + + if (calculator_->flush_last_packet_) { + LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is " + "ignored, because we are adding jitter."; + } + + const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + random_ = CreateSecureRandom(seed); + if (random_ == nullptr) { + return absl::InvalidArgumentError( + "SecureRandom is not available. With \"jitter\" specified, " + "PacketResamplerCalculator processing cannot proceed."); + } + + return absl::OkStatus(); +} +absl::Status ReproducibleJitterWithReflectionStrategy::Close( + CalculatorContext* cc) { + // If last packet is non-empty and a packet hasn't been emitted for this + // period, emit the last packet. + if (!calculator_->last_packet_.IsEmpty() && !packet_emitted_this_period_) { + calculator_->OutputWithinLimits( + cc, calculator_->last_packet_.At(next_output_timestamp_)); + } + return absl::OkStatus(); +} +absl::Status ReproducibleJitterWithReflectionStrategy::Process( + CalculatorContext* cc) { + RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); + + Packet current_packet = cc->Inputs().Get(calculator_->input_data_id_).Value(); + + if (calculator_->last_packet_.IsEmpty()) { + // last_packet is empty, this is the first packet of the stream. + + InitializeNextOutputTimestamp(current_packet.Timestamp()); + + // If next_output_timestamp_ happens to fall before current_packet, emit + // current packet. Only a single packet can be emitted at the beginning + // of the stream. + if (next_output_timestamp_ < current_packet.Timestamp()) { + calculator_->OutputWithinLimits( + cc, current_packet.At(next_output_timestamp_)); + packet_emitted_this_period_ = true; + } + + return absl::OkStatus(); + } + + // Last packet is set, so we are mid-stream. 
+ if (calculator_->frame_time_usec_ < + (current_packet.Timestamp() - calculator_->last_packet_.Timestamp()) + .Value()) { + // Note, if the stream is upsampling, this could lead to the same packet + // being emitted twice. Upsampling and jitter doesn't make much sense + // but does technically work. + LOG_FIRST_N(WARNING, 2) + << "Adding jitter is not very useful when upsampling."; + } + + // Since we may be upsampling, we need to iteratively advance the + // next_output_timestamp_ one period at a time until it reaches the period + // current_packet is in. During this process, last_packet and/or + // current_packet may be repeatedly emitted. + + UpdateNextOutputTimestamp(current_packet.Timestamp()); + + while (!packet_emitted_this_period_ && + next_output_timestamp_ <= current_packet.Timestamp()) { + // last_packet < next_output_timestamp_ <= current_packet, + // so emit the closest packet. + Packet packet_to_emit = + current_packet.Timestamp() - next_output_timestamp_ < + next_output_timestamp_ - calculator_->last_packet_.Timestamp() + ? current_packet + : calculator_->last_packet_; + calculator_->OutputWithinLimits(cc, + packet_to_emit.At(next_output_timestamp_)); + + packet_emitted_this_period_ = true; + + // If we are upsampling, packet_emitted_this_period_ can be reset by + // the following UpdateNext and the loop will iterate. + UpdateNextOutputTimestamp(current_packet.Timestamp()); + } + + // Set the bounds on the output stream. Note, if we emitted a packet + // above, it will already be set at next_output_timestamp_ + 1, in which + // case we have to skip setting it. 
+ if (cc->Outputs().Get(calculator_->output_data_id_).NextTimestampBound() < + next_output_timestamp_) { + cc->Outputs() + .Get(calculator_->output_data_id_) + .SetNextTimestampBound(next_output_timestamp_); + } + return absl::OkStatus(); +} + +void ReproducibleJitterWithReflectionStrategy::InitializeNextOutputTimestamp( + Timestamp current_timestamp) { + if (next_output_timestamp_min_ != Timestamp::Unset()) { + return; + } + + next_output_timestamp_min_ = Timestamp(0); + next_output_timestamp_ = + Timestamp(GetNextRandom(calculator_->frame_time_usec_)); + + // While the current timestamp is ahead of the max (i.e. min + frame_time), + // fast-forward. + while (current_timestamp >= + next_output_timestamp_min_ + calculator_->frame_time_usec_) { + packet_emitted_this_period_ = true; // Force update... + UpdateNextOutputTimestamp(current_timestamp); + } +} + +void ReproducibleJitterWithReflectionStrategy::UpdateNextOutputTimestamp( + Timestamp current_timestamp) { + if (packet_emitted_this_period_ && + current_timestamp >= + next_output_timestamp_min_ + calculator_->frame_time_usec_) { + next_output_timestamp_min_ += calculator_->frame_time_usec_; + Timestamp next_output_timestamp_max_ = + next_output_timestamp_min_ + calculator_->frame_time_usec_; + + next_output_timestamp_ += calculator_->frame_time_usec_ + + GetNextRandom(2 * calculator_->jitter_usec_ + 1) - + calculator_->jitter_usec_; + next_output_timestamp_ = Timestamp(ReflectBetween( + next_output_timestamp_.Value(), next_output_timestamp_min_.Value(), + next_output_timestamp_max_.Value())); + + packet_emitted_this_period_ = false; + } +} + +absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) { + const auto resampler_options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); + + if (resampler_options.output_header() != + PacketResamplerCalculatorOptions::NONE) { + LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not " + "the actual value."; + 
} + + if (calculator_->flush_last_packet_) { + LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is " + "ignored, because we are adding jitter."; + } + + const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + random_ = CreateSecureRandom(seed); + if (random_ == nullptr) { + return absl::InvalidArgumentError( + "SecureRandom is not available. With \"jitter\" specified, " + "PacketResamplerCalculator processing cannot proceed."); + } + + packet_reservoir_random_ = CreateSecureRandom(seed); + packet_reservoir_ = + absl::make_unique(packet_reservoir_random_.get()); + + return absl::OkStatus(); +} +absl::Status JitterWithoutReflectionStrategy::Close(CalculatorContext* cc) { + if (!packet_reservoir_->IsEmpty()) { + calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample()); + } + return absl::OkStatus(); +} +absl::Status JitterWithoutReflectionStrategy::Process(CalculatorContext* cc) { + RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); + + // Packet reservoir is used to make sure there's an output for every period, + // e.g. partial period at the end of the stream. 
+ if (packet_reservoir_->IsEnabled() && + (calculator_->first_timestamp_ == Timestamp::Unset() || + (cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) { + auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value(); + packet_reservoir_->AddSample(curr_packet); + } + + if (calculator_->first_timestamp_ == Timestamp::Unset()) { + calculator_->first_timestamp_ = cc->InputTimestamp(); + InitializeNextOutputTimestamp(); + if (calculator_->first_timestamp_ == next_output_timestamp_) { + calculator_->OutputWithinLimits(cc, cc->Inputs() + .Get(calculator_->input_data_id_) + .Value() + .At(next_output_timestamp_)); + UpdateNextOutputTimestamp(); + } + return absl::OkStatus(); + } + + if (calculator_->frame_time_usec_ < + (cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) { + LOG_FIRST_N(WARNING, 2) + << "Adding jitter is not very useful when upsampling."; + } + + while (true) { + const int64 last_diff = + (next_output_timestamp_ - calculator_->last_packet_.Timestamp()) + .Value(); + RET_CHECK_GT(last_diff, 0); + const int64 curr_diff = + (next_output_timestamp_ - cc->InputTimestamp()).Value(); + if (curr_diff > 0) { + break; + } + calculator_->OutputWithinLimits( + cc, (std::abs(curr_diff) > last_diff + ? 
calculator_->last_packet_ + : cc->Inputs().Get(calculator_->input_data_id_).Value()) + .At(next_output_timestamp_)); + UpdateNextOutputTimestamp(); + cc->Outputs() + .Get(calculator_->output_data_id_) + .SetNextTimestampBound(next_output_timestamp_); + } + return absl::OkStatus(); +} + +void JitterWithoutReflectionStrategy::InitializeNextOutputTimestamp() { + next_output_timestamp_min_ = calculator_->first_timestamp_; + next_output_timestamp_ = calculator_->first_timestamp_ + + calculator_->frame_time_usec_ * random_->RandFloat(); +} + +void JitterWithoutReflectionStrategy::UpdateNextOutputTimestamp() { + packet_reservoir_->Clear(); + packet_reservoir_->Disable(); + next_output_timestamp_ += calculator_->frame_time_usec_ * + ((1.0 - calculator_->jitter_) + + 2.0 * calculator_->jitter_ * random_->RandFloat()); +} + +absl::Status NoJitterStrategy::Open(CalculatorContext* cc) { + const auto resampler_options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); + base_timestamp_ = resampler_options.has_base_timestamp() + ? Timestamp(resampler_options.base_timestamp()) + : Timestamp::Unset(); + + period_count_ = 0; + + return absl::OkStatus(); +} +absl::Status NoJitterStrategy::Close(CalculatorContext* cc) { + // Emit the last packet received if we have at least one packet, but + // haven't sent anything for its period. + if (calculator_->first_timestamp_ != Timestamp::Unset() && + calculator_->flush_last_packet_ && + calculator_->TimestampToPeriodIndex( + calculator_->last_packet_.Timestamp()) == period_count_) { + calculator_->OutputWithinLimits( + cc, calculator_->last_packet_.At( + calculator_->PeriodIndexToTimestamp(period_count_))); + } + return absl::OkStatus(); +} +absl::Status NoJitterStrategy::Process(CalculatorContext* cc) { + RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); + + if (calculator_->first_timestamp_ == Timestamp::Unset()) { + // This is the first packet, initialize the first_timestamp_. 
+ if (base_timestamp_ == Timestamp::Unset()) { + // Initialize first_timestamp_ with exactly the first packet timestamp. + calculator_->first_timestamp_ = cc->InputTimestamp(); + } else { + // Initialize first_timestamp_ with the first packet timestamp + // aligned to the base_timestamp_. + int64 first_index = MathUtil::SafeRound( + (cc->InputTimestamp() - base_timestamp_).Seconds() * + calculator_->frame_rate_); + calculator_->first_timestamp_ = + base_timestamp_ + + TimestampDiffFromSeconds(first_index / calculator_->frame_rate_); + } + if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) { + cc->Outputs() + .Tag("VIDEO_HEADER") + .Add(new VideoHeader(calculator_->video_header_), + Timestamp::PreStream()); + } + } + const Timestamp received_timestamp = cc->InputTimestamp(); + const int64 received_timestamp_idx = + calculator_->TimestampToPeriodIndex(received_timestamp); + // Only consider the received packet if it belongs to the current period + // (== period_count_) or to a newer one (> period_count_). + if (received_timestamp_idx >= period_count_) { + // Fill the empty periods until we are in the same index as the received + // packet. + while (received_timestamp_idx > period_count_) { + calculator_->OutputWithinLimits( + cc, calculator_->last_packet_.At( + calculator_->PeriodIndexToTimestamp(period_count_))); + ++period_count_; + } + // Now, if the received packet has a timestamp larger than the middle of + // the current period, we can send a packet without waiting. We send the + // one closer to the middle. 
+ Timestamp target_timestamp = + calculator_->PeriodIndexToTimestamp(period_count_); + if (received_timestamp >= target_timestamp) { + bool have_last_packet = + (calculator_->last_packet_.Timestamp() != Timestamp::Unset()); + bool send_current = + !have_last_packet || + (received_timestamp - target_timestamp <= + target_timestamp - calculator_->last_packet_.Timestamp()); + if (send_current) { + calculator_->OutputWithinLimits(cc, + cc->Inputs() + .Get(calculator_->input_data_id_) + .Value() + .At(target_timestamp)); + } else { + calculator_->OutputWithinLimits( + cc, calculator_->last_packet_.At(target_timestamp)); + } + ++period_count_; + } + // TODO: Add a mechanism to the framework to allow these packets + // to be output earlier (without waiting for a much later packet to + // arrive) + + // Update the bound for the next packet. + cc->Outputs() + .Get(calculator_->output_data_id_) + .SetNextTimestampBound( + calculator_->PeriodIndexToTimestamp(period_count_)); + } + return absl::OkStatus(); +} + } // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_resampler_calculator.h b/mediapipe/calculators/core/packet_resampler_calculator.h index 4a1a3ffaa..fbecdb0e7 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.h +++ b/mediapipe/calculators/core/packet_resampler_calculator.h @@ -55,7 +55,7 @@ class PacketReservoir { // correspond to timestamp t. // - The next packet is chosen randomly (uniform distribution) among frames // that correspond to [t+(1-jitter)/frame_rate, t+(1+jitter)/frame_rate]. -// - if jitter_with_reflection_ is true, the timestamp will be reflected +// - if jitter_with_reflection is true, the timestamp will be reflected // against the boundaries of [t_0 + (k-1)/frame_rate, t_0 + k/frame_rate) // so that its marginal distribution is uniform within this interval. // In the formula, t_0 is the timestamp of the first sampled @@ -66,6 +66,17 @@ class PacketReservoir { // the resampling. 
For Cloud ML Video Intelligence API, the hash of the // input video should serve this purpose. For YouTube, either video ID or // content hex ID of the input video should do. +// - If reproducible_sampling is true, care is taken to allow reproducible +// "mid-stream" sampling. The calculator can be executed on a stream that +// doesn't start at the first period. For instance, if the calculator +// is run on a 10 second stream it will produce the same set of samples +// as two runs of the calculator, the first with 3 seconds of input starting +// at time 0 and the second with 7 seconds of input starting at time +3s. +// - In order to guarantee the exact same samples, 1) the inputs must be +// aligned with the sampling period. For instance, if the sampling rate +// is 2 frames per second, streams should be aligned on 0.5 second +// boundaries, and 2) the stream must include at least one extra packet +// before and after the second aligned sampling period. // // If jitter_ is not specified: // - The first packet defines the first_timestamp of the output stream, @@ -105,19 +116,6 @@ class PacketResamplerCalculator : public CalculatorBase { absl::Status Close(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override; - private: - // Calculates the first sampled timestamp that incorporates a jittering - // offset. - void InitializeNextOutputTimestampWithJitter(); - // Calculates the next sampled timestamp that incorporates a jittering offset. - void UpdateNextOutputTimestampWithJitter(); - - // Logic for Process() when jitter_ != 0.0. - absl::Status ProcessWithJitter(CalculatorContext* cc); - - // Logic for Process() when jitter_ == 0.0. - absl::Status ProcessWithoutJitter(CalculatorContext* cc); - // Given the current count of periods that have passed, this returns // the next valid timestamp of the middle point of the next period: // if count is 0, it returns the first_timestamp_. 
@@ -141,6 +139,16 @@ class PacketResamplerCalculator : public CalculatorBase { // Outputs a packet if it is in range (start_time_, end_time_). void OutputWithinLimits(CalculatorContext* cc, const Packet& packet) const; + protected: + // Returns Sampling Strategy to use. + // + // Virtual to allow injection of testing strategies. + virtual std::unique_ptr GetSamplingStrategy( + const mediapipe::PacketResamplerCalculatorOptions& options); + + private: + std::unique_ptr strategy_; + // The timestamp of the first packet received. Timestamp first_timestamp_; @@ -150,14 +158,6 @@ class PacketResamplerCalculator : public CalculatorBase { // Inverse of frame_rate_. int64 frame_time_usec_; - // Number of periods that have passed (= #packets sent to the output). - // - // Can only be used if jitter_ equals zero. - int64 period_count_; - - // The last packet that was received. - Packet last_packet_; - VideoHeader video_header_; // The "DATA" input stream. CollectionItemId input_data_id_; @@ -165,23 +165,15 @@ class PacketResamplerCalculator : public CalculatorBase { CollectionItemId output_data_id_; // Indicator whether to flush last packet even if its timestamp is greater - // than the final stream timestamp. Set to false when jitter_ is non-zero. + // than the final stream timestamp. bool flush_last_packet_; - // Jitter-related variables. - std::unique_ptr random_; double jitter_ = 0.0; - bool jitter_with_reflection_; - int64 jitter_usec_; - Timestamp next_output_timestamp_; - // If jittering_with_reflection_ is true, next_output_timestamp_ will be - // kept within the interval - // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_) - Timestamp next_output_timestamp_min_; - // If specified, output timestamps are aligned with base_timestamp. - // Otherwise, they are aligned with the first input timestamp. - Timestamp base_timestamp_; + int64 jitter_usec_; + + // The last packet that was received. 
+ Packet last_packet_; // If specified, only outputs at/after start_time are included. Timestamp start_time_; @@ -191,15 +183,210 @@ class PacketResamplerCalculator : public CalculatorBase { // If set, the output timestamps nearest to start_time and end_time // are included in the output, even if the nearest timestamp is not - // between start_time and end_time.W + // between start_time and end_time. bool round_limits_; + // Allow strategies access to all internal calculator state. + // + // The calculator and strategies are intimately tied together so this should + // not break encapsulation. + friend class LegacyJitterWithReflectionStrategy; + friend class ReproducibleJitterWithReflectionStrategy; + friend class JitterWithoutReflectionStrategy; + friend class NoJitterStrategy; +}; + +// Abstract class encapsulating sampling strategy. +// +// These are used solely by PacketResamplerCalculator, but are exposed here +// to facilitate tests. +class PacketResamplerStrategy { + public: + PacketResamplerStrategy(PacketResamplerCalculator* calculator) + : calculator_(calculator) {} + virtual ~PacketResamplerStrategy() = default; + + // Delegate for CalculatorBase::Open. See CalculatorBase for relevant + // implementation considerations. + virtual absl::Status Open(CalculatorContext* cc) = 0; + // Delegate for CalculatorBase::Close. See CalculatorBase for relevant + // implementation considerations. + virtual absl::Status Close(CalculatorContext* cc) = 0; + // Delegate for CalculatorBase::Process. See CalculatorBase for relevant + // implementation considerations. + virtual absl::Status Process(CalculatorContext* cc) = 0; + + protected: + // Calculator running strategy. + PacketResamplerCalculator* calculator_; +}; + +// Strategy that applies Jitter with reflection based sampling. +// +// Used by PacketResamplerCalculator when both Jitter and reflection are +// enabled. 
+// +// This applies the legacy jitter with reflection which doesn't allow +// for reproducibility of sampling when starting mid-stream. This is maintained +// for backward compatibility. +class LegacyJitterWithReflectionStrategy : public PacketResamplerStrategy { + public: + LegacyJitterWithReflectionStrategy(PacketResamplerCalculator* calculator) + : PacketResamplerStrategy(calculator) {} + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + void InitializeNextOutputTimestampWithJitter(); + void UpdateNextOutputTimestampWithJitter(); + + // Jitter-related variables. + std::unique_ptr random_; + + // The timestamp of the first packet received. + Timestamp first_timestamp_; + + // Next packet to be emitted. Since packets may not align perfectly with + // next_output_timestamp_, the closest packet will be emitted. + Timestamp next_output_timestamp_; + + // Lower bound for next timestamp. + // + // next_output_timestamp_ will be kept within the interval + // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_) + Timestamp next_output_timestamp_min_ = Timestamp::Unset(); + // packet reservior used for sampling random packet out of partial // period when jitter is enabled std::unique_ptr packet_reservoir_; + // random number generator used in packet_reservoir_. std::unique_ptr packet_reservoir_random_; }; +// Strategy that applies reproducible jitter with reflection based sampling. +// +// Used by PacketResamplerCalculator when both Jitter and reflection are +// enabled. 
+class ReproducibleJitterWithReflectionStrategy + : public PacketResamplerStrategy { + public: + ReproducibleJitterWithReflectionStrategy( + PacketResamplerCalculator* calculator) + : PacketResamplerStrategy(calculator) {} + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + protected: + // Returns next random in range (0,n]. + // + // Exposed as virtual function for testing Jitter with reflection. + // This is the only way random_ is accessed. + virtual uint64 GetNextRandom(uint64 n) { + return random_->UnbiasedUniform64(n); + } + + private: + // Initializes Jitter with reflection. + // + // This will fast-forward to the period containing current_timestamp. + // next_output_timestamp_ is guaranteed to be in current_timestamp's period + // and packet_emitted_this_period_ will be set to false. + void InitializeNextOutputTimestamp(Timestamp current_timestamp); + + // Potentially advances next_output_timestamp_ a single period. + // + // next_output_timestamp_ will only be advanced if packet_emitted_this_period_ + // is true. next_output_timestamp_ will never be advanced beyond + // current_timestamp's period. + // + // However, next_output_timestamp_ could fall before current_timestamp's + // period since only a single period can be advanced at a time. + void UpdateNextOutputTimestamp(Timestamp current_timestamp); + + // Jitter-related variables. + std::unique_ptr random_; + + // Next packet to be emitted. Since packets may not align perfectly with + // next_output_timestamp_, the closest packet will be emitted. + Timestamp next_output_timestamp_; + + // Lower bound for next timestamp. + // + // next_output_timestamp_ will be kept within the interval + // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_) + Timestamp next_output_timestamp_min_ = Timestamp::Unset(); + + // Indicates packet was emitted for current period (i.e. 
the period + // next_output_timestamp_ falls in). + bool packet_emitted_this_period_ = false; +}; + +// Strategy that applies Jitter without reflection based sampling. +// +// Used by PacketResamplerCalculator when Jitter is enabled and reflection is +// not enabled. +class JitterWithoutReflectionStrategy : public PacketResamplerStrategy { + public: + JitterWithoutReflectionStrategy(PacketResamplerCalculator* calculator) + : PacketResamplerStrategy(calculator) {} + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + // Calculates the first sampled timestamp that incorporates a jittering + // offset. + void InitializeNextOutputTimestamp(); + + // Calculates the next sampled timestamp that incorporates a jittering offset. + void UpdateNextOutputTimestamp(); + + // Jitter-related variables. + std::unique_ptr random_; + + // Next packet to be emitted. Since packets may not align perfectly with + // next_output_timestamp_, the closest packet will be emitted. + Timestamp next_output_timestamp_; + + // Lower bound for next timestamp. + // + // next_output_timestamp_ will be kept within the interval + // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_) + Timestamp next_output_timestamp_min_ = Timestamp::Unset(); + + // packet reservoir used for sampling random packet out of partial period. + std::unique_ptr packet_reservoir_; + + // random number generator used in packet_reservoir_. + std::unique_ptr packet_reservoir_random_; +}; + +// Strategy that applies sampling without any jitter. +// +// Used by PacketResamplerCalculator when jitter is not enabled. 
+class NoJitterStrategy : public PacketResamplerStrategy { + public: + NoJitterStrategy(PacketResamplerCalculator* calculator) + : PacketResamplerStrategy(calculator) {} + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + // Number of periods that have passed (= #packets sent to the output). + int64 period_count_; + + // If specified, output timestamps are aligned with base_timestamp. + // Otherwise, they are aligned with the first input timestamp. + Timestamp base_timestamp_; +}; + } // namespace mediapipe #endif // MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_ diff --git a/mediapipe/calculators/core/packet_resampler_calculator.proto b/mediapipe/calculators/core/packet_resampler_calculator.proto index d037ee9de..f7ca47023 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.proto +++ b/mediapipe/calculators/core/packet_resampler_calculator.proto @@ -68,8 +68,23 @@ message PacketResamplerCalculatorOptions { // pseudo-random number generator does its job and the number of frames is // sufficiently large, the average frame rate will be close to this value. optional double jitter = 4; + + // Enables reflection when applying jitter. + // + // This option is ignored when reproducible_sampling is true, in which case + // reflection will be used. + // + // New use cases should use reproducible_sampling = true, as + // jitter_with_reflection is deprecated and will be removed at some point. optional bool jitter_with_reflection = 9 [default = false]; + // If set, enables reproducible sampling, allowing frames to be sampled + // without regard to where the stream starts. See + // packet_resampler_calculator.h for details. + // + // This enables reflection (ignoring jitter_with_reflection setting). 
+ optional bool reproducible_sampling = 10 [default = false]; + // If specified, output timestamps are aligned with base_timestamp. // Otherwise, they are aligned with the first input timestamp. // diff --git a/mediapipe/calculators/core/packet_resampler_calculator_test.cc b/mediapipe/calculators/core/packet_resampler_calculator_test.cc index ffd23684e..58d58767e 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator_test.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator_test.cc @@ -30,6 +30,7 @@ namespace mediapipe { +using ::testing::ElementsAre; namespace { // A simple version of CalculatorRunner with built-in convenience // methods for setting inputs from a vector and checking outputs @@ -96,6 +97,77 @@ class SimpleRunner : public CalculatorRunner { static int static_count_; }; +// Matcher for Packets with uint64 payload, comparing arg packet's +// timestamp and uint64 payload. +MATCHER_P2(PacketAtTimestamp, payload, timestamp, + absl::StrCat(negation ? "isn't" : "is", " a packet with payload ", + payload, " @ time ", timestamp)) { + if (timestamp != arg.Timestamp().Value()) { + *result_listener << "at incorrect timestamp = " << arg.Timestamp().Value(); + return false; + } + int64 actual_payload = arg.template Get(); + if (actual_payload != payload) { + *result_listener << "with incorrect payload = " << actual_payload; + return false; + } + return true; +} + +// JitterWithReflectionStrategy child class which injects a specified stream +// of "random" numbers. +// +// Calculators are created through factory methods, making testing and injection +// tricky. This class utilizes a static variable, random_sequence, to pass +// the desired random sequence into the calculator. 
+class ReproducibleJitterWithReflectionStrategyForTesting + : public ReproducibleJitterWithReflectionStrategy { + public: + ReproducibleJitterWithReflectionStrategyForTesting( + PacketResamplerCalculator* calculator) + : ReproducibleJitterWithReflectionStrategy(calculator) {} + + // Statically accessed random sequence to use for jitter with reflection. + // + // An EXPECT will fail if sequence is less than the number requested during + // processing. + static std::vector random_sequence; + + protected: + virtual uint64 GetNextRandom(uint64 n) { + EXPECT_LT(sequence_index_, random_sequence.size()); + return random_sequence[sequence_index_++] % n; + } + + private: + int32 sequence_index_ = 0; +}; +std::vector + ReproducibleJitterWithReflectionStrategyForTesting::random_sequence; + +// PacketResamplerCalculator child class which injects a specified stream +// of "random" numbers. +// +// Calculators are created through factory methods, making testing and injection +// tricky. This class utilizes a static variable, random_sequence, to pass +// the desired random sequence into the calculator. 
+class ReproducibleResamplerCalculatorForTesting + : public PacketResamplerCalculator { + public: + static absl::Status GetContract(CalculatorContract* cc) { + return PacketResamplerCalculator::GetContract(cc); + } + + protected: + std::unique_ptr GetSamplingStrategy( + const mediapipe::PacketResamplerCalculatorOptions& Options) { + return absl::make_unique< + ReproducibleJitterWithReflectionStrategyForTesting>(this); + } +}; + +REGISTER_CALCULATOR(ReproducibleResamplerCalculatorForTesting); + int SimpleRunner::static_count_ = 0; TEST(PacketResamplerCalculatorTest, NoPacketsInStream) { diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 5a0631007..e27347a7e 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -561,13 +561,13 @@ cc_test( "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 233424720..275c33559 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -15,6 +15,7 @@ #include #include +#include "absl/flags/flag.h" #include "absl/memory/memory.h" #include "absl/strings/substitute.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter.h" @@ -28,7 +29,6 @@ #include "mediapipe/framework/formats/image_frame_opencv.h" #include 
"mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/opencv_core_inc.h" diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index 89a02b713..fb05fa5cd 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -41,7 +41,7 @@ class InferenceCalculatorSelectorImpl (options.has_delegate() && options.delegate().has_gpu()); if (should_use_gpu) { impls.emplace_back("Metal"); - impls.emplace_back("MlDrift"); + impls.emplace_back("MlDriftWebGl"); impls.emplace_back("Gl"); } impls.emplace_back("Cpu"); diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h index a746684ff..f2a0b4360 100644 --- a/mediapipe/calculators/tensor/inference_calculator.h +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -118,8 +118,8 @@ struct InferenceCalculatorGl : public InferenceCalculator { static constexpr char kCalculatorName[] = "InferenceCalculatorGl"; }; -struct InferenceCalculatorMlDrift : public InferenceCalculator { - static constexpr char kCalculatorName[] = "InferenceCalculatorMlDrift"; +struct InferenceCalculatorMlDriftWebGl : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorMlDriftWebGl"; }; struct InferenceCalculatorMetal : public InferenceCalculator { diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index 0efb61d4a..e0b538a91 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -51,12 +51,12 @@ message InferenceCalculatorOptions { // This option is valid 
for TFLite GPU delegate API2 only, // Choose any of available APIs to force running inference using it. - enum API { + enum Api { ANY = 0; OPENGL = 1; OPENCL = 2; } - optional API api = 4 [default = ANY]; + optional Api api = 4 [default = ANY]; // This option is valid for TFLite GPU delegate API2 only, // Set to true to use 16-bit float precision. If max precision is needed, diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index d931b93fa..e93ad4a3a 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -136,7 +136,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) { const auto& model = *model_packet_.Get(); tflite::ops::builtin::BuiltinOpResolver op_resolver = kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolver()); + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); RET_CHECK(interpreter_); diff --git a/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc index 5fb7c974a..f1dc8c8fe 100644 --- a/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc +++ b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc @@ -59,7 +59,7 @@ const std::vector& GetParams() { p.back().delegate.mutable_gpu(); #endif // TARGET_IPHONE_SIMULATOR #if __EMSCRIPTEN__ - p.push_back({"MlDrift", "MlDrift"}); + p.push_back({"MlDriftWebGl", "MlDriftWebGl"}); p.back().delegate.mutable_gpu(); #endif // __EMSCRIPTEN__ #if __ANDROID__ && 0 // Disabled for now since emulator can't go GLESv3 diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index 081b12d3c..d2cabfcac 100644 --- 
a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -63,7 +63,7 @@ class InferenceCalculatorGlImpl mediapipe::GlCalculatorHelper gpu_helper_; std::unique_ptr tflite_gpu_runner_; bool allow_precision_loss_ = false; - mediapipe::InferenceCalculatorOptions::Delegate::Gpu::API + mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api tflite_gpu_runner_api_; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE @@ -244,7 +244,7 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( const auto& model = *model_packet_.Get(); tflite::ops::builtin::BuiltinOpResolver op_resolver = kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolver()); + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); // Create runner tflite::gpu::InferenceOptions options; @@ -294,7 +294,7 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { const auto& model = *model_packet_.Get(); tflite::ops::builtin::BuiltinOpResolver op_resolver = kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolver()); + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); RET_CHECK(interpreter_); diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index 490189aec..a81d0d460 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -200,7 +200,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) { const auto& model = *model_packet_.Get(); tflite::ops::builtin::BuiltinOpResolver op_resolver = kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolver()); + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); 
RET_CHECK(interpreter_); diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 4c2b90a59..0dbbd57da 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -892,13 +892,13 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework:packet", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:framework", @@ -923,13 +923,13 @@ cc_test( "//mediapipe/framework:packet", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:framework", @@ -954,11 +954,11 @@ cc_test( "//mediapipe/framework:packet", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:all_kernels", 
"@org_tensorflow//tensorflow/core:direct_session", @@ -981,11 +981,11 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework:packet", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:all_kernels", "@org_tensorflow//tensorflow/core:direct_session", @@ -1144,8 +1144,8 @@ cc_test( ":tensorflow_inference_calculator", ":tensorflow_session_from_frozen_graph_generator", ":tensorflow_session_from_frozen_graph_generator_cc_proto", + "@com_google_absl//absl/flags:flag", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:integral_types", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc index 20e80bf33..6a931679d 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc @@ -16,12 +16,12 @@ #include #include +#include "absl/flags/flag.h" #include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" diff --git 
a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc index 8d3d3fdff..bdf90dcbb 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/flags/flag.h" #include "absl/strings/substitute.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h" @@ -19,7 +20,6 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc index c7f06bbc4..34d7e8828 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "absl/flags/flag.h" #include "absl/strings/substitute.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" @@ -19,7 +20,6 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_generator.pb.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc index 912d71600..7016f14bb 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "absl/flags/flag.h" #include "absl/strings/str_replace.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h" @@ -20,7 +21,6 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc index 92d0d5de4..aca506f0b 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "absl/flags/flag.h" #include "absl/strings/str_replace.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h" @@ -19,7 +20,6 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_generator.pb.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc index 1f4cda359..6f52f09ef 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "mediapipe/calculators/core/packet_resampler_calculator.pb.h" #include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h" diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto index 444e46ffd..51cc870c7 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto @@ -56,11 +56,4 @@ message UnpackMediaSequenceCalculatorOptions { // the clip start and end times and outputs these for the // AudioDecoderCalculator to consume. 
optional AudioDecoderOptions base_audio_decoder_options = 9; - - optional string keypoint_names = 10 [ - default = - "NOSE,LEFT_EAR,RIGHT_EAR,LEFT_SHOULDER,RIGHT_SHOULDER,LEFT_FORE_PAW,RIGHT_FORE_PAW,LEFT_HIP,RIGHT_HIP,LEFT_HIND_PAW,RIGHT_HIND_PAW,ROOT_TAIL" - ]; - // When the keypoint doesn't exists, output this default value. - optional float default_keypoint_location = 11 [default = -1.0]; } diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 6bc4636b1..2d1037d20 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -147,11 +147,11 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats/object_detection:anchor_cc_proto", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/flags:flag", ], ) diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc index 906eeed21..851d26b3d 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "absl/flags/flag.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/object_detection/anchor.pb.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index a2fc7ec3a..bc05f51b5 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -278,7 +278,7 @@ class TfLiteInferenceCalculator : public CalculatorBase { bool use_advanced_gpu_api_ = false; bool allow_precision_loss_ = false; - mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::API + mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::Api tflite_gpu_runner_api_; bool use_kernel_caching_ = false; @@ -702,11 +702,16 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( #if MEDIAPIPE_TFLITE_GL_INFERENCE ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver; + + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates + default_op_resolver; + auto op_resolver_ptr = + static_cast( + &default_op_resolver); if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { - op_resolver = cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") - .Get(); + op_resolver_ptr = &(cc->InputSidePackets() + .Tag("CUSTOM_OP_RESOLVER") + .Get()); } // Create runner @@ -733,7 +738,7 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( } } MP_RETURN_IF_ERROR( - tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); + tflite_gpu_runner_->InitializeWithModel(model, *op_resolver_ptr)); // 
Allocate interpreter memory for cpu output. if (!gpu_output_) { @@ -786,18 +791,24 @@ absl::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) { ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver; + + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates + default_op_resolver; + auto op_resolver_ptr = + static_cast( + &default_op_resolver); + if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { - op_resolver = cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") - .Get(); + op_resolver_ptr = &(cc->InputSidePackets() + .Tag("CUSTOM_OP_RESOLVER") + .Get()); } #if defined(MEDIAPIPE_EDGE_TPU) interpreter_ = - BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get()); + BuildEdgeTpuInterpreter(model, op_resolver_ptr, edgetpu_context_.get()); #else - tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + tflite::InterpreterBuilder(model, *op_resolver_ptr)(&interpreter_); #endif // MEDIAPIPE_EDGE_TPU RET_CHECK(interpreter_); diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.proto b/mediapipe/calculators/tflite/tflite_inference_calculator.proto index 862de8b0b..02dc20831 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.proto @@ -51,12 +51,12 @@ message TfLiteInferenceCalculatorOptions { // This option is valid for TFLite GPU delegate API2 only, // Choose any of available APIs to force running inference using it. - enum API { + enum Api { ANY = 0; OPENGL = 1; OPENCL = 2; } - optional API api = 4 [default = ANY]; + optional Api api = 4 [default = ANY]; // This option is valid for TFLite GPU delegate API2 only, // Set to true to use 16-bit float precision. 
If max precision is needed, diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index df6d5c6d6..bc24e5994 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -841,12 +841,39 @@ cc_library( "//mediapipe/framework:timestamp", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", + "//mediapipe/util/filtering:one_euro_filter", "//mediapipe/util/filtering:relative_velocity_filter", "@com_google_absl//absl/algorithm:container", ], alwayslink = 1, ) +mediapipe_proto_library( + name = "visibility_smoothing_calculator_proto", + srcs = ["visibility_smoothing_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "visibility_smoothing_calculator", + srcs = ["visibility_smoothing_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":visibility_smoothing_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util/filtering:low_pass_filter", + "@com_google_absl//absl/algorithm:container", + ], + alwayslink = 1, +) + cc_library( name = "landmarks_to_floats_calculator", srcs = ["landmarks_to_floats_calculator.cc"], @@ -858,7 +885,7 @@ cc_library( "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -1194,3 +1221,34 @@ cc_library( }), alwayslink = 1, ) + +cc_library( + name = "detection_classifications_merger_calculator", + srcs = ["detection_classifications_merger_calculator.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + 
"//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "detection_classifications_merger_calculator_test", + srcs = ["detection_classifications_merger_calculator_test.cc"], + deps = [ + ":detection_classifications_merger_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) diff --git a/mediapipe/calculators/util/detection_classifications_merger_calculator.cc b/mediapipe/calculators/util/detection_classifications_merger_calculator.cc new file mode 100644 index 000000000..86f26b0dc --- /dev/null +++ b/mediapipe/calculators/util/detection_classifications_merger_calculator.cc @@ -0,0 +1,149 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/strings/substitute.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { +namespace api2 { + +namespace {} // namespace + +// Replaces the classification labels and scores from the input `Detection` with +// the ones provided into the input `ClassificationList`. Namely: +// * `label_id[i]` becomes `classification[i].index` +// * `score[i]` becomes `classification[i].score` +// * `label[i]` becomes `classification[i].label` (if present) +// +// In case the input `ClassificationList` contains no results (i.e. +// `classification` is empty, which may happen if the classifier uses a score +// threshold and no confident enough result were returned), the input +// `Detection` is returned unchanged. +// +// This is specifically designed for two-stage detection cascades where the +// detections returned by a standalone detector (typically a class-agnostic +// localizer) are fed e.g. into a `TfLiteTaskImageClassifierCalculator` through +// the optional "RECT" or "NORM_RECT" input, e.g: +// +// node { +// calculator: "DetectionsToRectsCalculator" +// # Output of an upstream object detector. 
+// input_stream: "DETECTION:detection" +// output_stream: "NORM_RECT:norm_rect" +// } +// node { +// calculator: "TfLiteTaskImageClassifierCalculator" +// input_stream: "IMAGE:image" +// input_stream: "NORM_RECT:norm_rect" +// output_stream: "CLASSIFICATION_RESULT:classification_result" +// } +// node { +// calculator: "TfLiteTaskClassificationResultToClassificationsCalculator" +// input_stream: "CLASSIFICATION_RESULT:classification_result" +// output_stream: "CLASSIFICATION_LIST:classification_list" +// } +// node { +// calculator: "DetectionClassificationsMergerCalculator" +// input_stream: "INPUT_DETECTION:detection" +// input_stream: "CLASSIFICATION_LIST:classification_list" +// # Final output. +// output_stream: "OUTPUT_DETECTION:classified_detection" +// } +// +// Inputs: +// INPUT_DETECTION: `Detection` proto. +// CLASSIFICATION_LIST: `ClassificationList` proto. +// +// Output: +// OUTPUT_DETECTION: modified `Detection` proto. +class DetectionClassificationsMergerCalculator : public Node { + public: + static constexpr Input kInputDetection{"INPUT_DETECTION"}; + static constexpr Input kClassificationList{ + "CLASSIFICATION_LIST"}; + static constexpr Output kOutputDetection{"OUTPUT_DETECTION"}; + + MEDIAPIPE_NODE_CONTRACT(kInputDetection, kClassificationList, + kOutputDetection); + + absl::Status Process(CalculatorContext* cc) override; +}; +MEDIAPIPE_REGISTER_NODE(DetectionClassificationsMergerCalculator); + +absl::Status DetectionClassificationsMergerCalculator::Process( + CalculatorContext* cc) { + if (kInputDetection(cc).IsEmpty() && kClassificationList(cc).IsEmpty()) { + return absl::OkStatus(); + } + RET_CHECK(!kInputDetection(cc).IsEmpty()); + RET_CHECK(!kClassificationList(cc).IsEmpty()); + + Detection detection = *kInputDetection(cc); + const ClassificationList& classification_list = *kClassificationList(cc); + + // Update input detection only if classification did return results. 
+ if (classification_list.classification_size() != 0) { + detection.clear_label_id(); + detection.clear_score(); + detection.clear_label(); + detection.clear_display_name(); + for (const auto& classification : classification_list.classification()) { + if (!classification.has_index()) { + return absl::InvalidArgumentError( + "Missing required 'index' field in Classification proto."); + } + detection.add_label_id(classification.index()); + if (!classification.has_score()) { + return absl::InvalidArgumentError( + "Missing required 'score' field in Classification proto."); + } + detection.add_score(classification.score()); + if (classification.has_label()) { + detection.add_label(classification.label()); + } + if (classification.has_display_name()) { + detection.add_display_name(classification.display_name()); + } + } + // Post-conversion sanity checks. + if (detection.label_size() != 0 && + detection.label_size() != detection.label_id_size()) { + return absl::InvalidArgumentError(absl::Substitute( + "Each input Classification is expected to either always or never " + "provide a 'label' field. Found $0 'label' fields for $1 " + "'Classification' objects.", + /*$0=*/detection.label_size(), /*$1=*/detection.label_id_size())); + } + if (detection.display_name_size() != 0 && + detection.display_name_size() != detection.label_id_size()) { + return absl::InvalidArgumentError(absl::Substitute( + "Each input Classification is expected to either always or never " + "provide a 'display_name' field. 
Found $0 'display_name' fields for " + "$1 'Classification' objects.", + /*$0=*/detection.display_name_size(), + /*$1=*/detection.label_id_size())); + } + } + kOutputDetection(cc).Send(detection); + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_classifications_merger_calculator_test.cc b/mediapipe/calculators/util/detection_classifications_merger_calculator_test.cc new file mode 100644 index 000000000..926f13e14 --- /dev/null +++ b/mediapipe/calculators/util/detection_classifications_merger_calculator_test.cc @@ -0,0 +1,320 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +constexpr char kGraphConfig[] = R"( + input_stream: "input_detection" + input_stream: "classification_list" + output_stream: "output_detection" + node { + calculator: "DetectionClassificationsMergerCalculator" + input_stream: "INPUT_DETECTION:input_detection" + input_stream: "CLASSIFICATION_LIST:classification_list" + output_stream: "OUTPUT_DETECTION:output_detection" + } + )"; + +constexpr char kInputDetection[] = R"( + label: "entity" + label_id: 1 + score: 0.9 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 } + } + display_name: "Entity" + )"; + +// Checks that the input Detection is returned unchanged if the input +// ClassificationList does not contain any result. +TEST(DetectionClassificationsMergerCalculator, SucceedsWithNoClassification) { + auto graph_config = ParseTextProtoOrDie(kGraphConfig); + + // Prepare input packets. + const Detection& input_detection = + ParseTextProtoOrDie(kInputDetection); + Packet input_detection_packet = + MakePacket(input_detection).At(Timestamp(0)); + const ClassificationList& classification_list = + ParseTextProtoOrDie(""); + Packet classification_list_packet = + MakePacket(classification_list).At(Timestamp(0)); + + // Catch output. + std::vector output_packets; + tool::AddVectorSink("output_detection", &graph_config, &output_packets); + + // Run the graph. 
+ CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_detection", input_detection_packet)); + MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list", + classification_list_packet)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Get and validate output. + EXPECT_THAT(output_packets, testing::SizeIs(1)); + const Detection& output_detection = output_packets[0].Get(); + EXPECT_THAT(output_detection, mediapipe::EqualsProto(input_detection)); +} + +// Checks that merging succeeds when the input ClassificationList includes +// labels and display names. +TEST(DetectionClassificationsMergerCalculator, + SucceedsWithLabelsAndDisplayNames) { + auto graph_config = ParseTextProtoOrDie(kGraphConfig); + + // Prepare input packets. + const Detection& input_detection = + ParseTextProtoOrDie(kInputDetection); + Packet input_detection_packet = + MakePacket(input_detection).At(Timestamp(0)); + const ClassificationList& classification_list = + ParseTextProtoOrDie(R"( + classification { index: 11 score: 0.5 label: "dog" display_name: "Dog" } + classification { index: 12 score: 0.4 label: "fox" display_name: "Fox" } + )"); + Packet classification_list_packet = + MakePacket(classification_list).At(Timestamp(0)); + + // Catch output. + std::vector output_packets; + tool::AddVectorSink("output_detection", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_detection", input_detection_packet)); + MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list", + classification_list_packet)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Get and validate output. 
+ EXPECT_THAT(output_packets, testing::SizeIs(1)); + const Detection& output_detection = output_packets[0].Get(); + EXPECT_THAT(output_detection, + mediapipe::EqualsProto(ParseTextProtoOrDie(R"( + label: "dog" + label: "fox" + label_id: 11 + label_id: 12 + score: 0.5 + score: 0.4 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 } + } + display_name: "Dog" + display_name: "Fox" + )"))); +} + +// Checks that merging succeeds when the input ClassificationList doesn't +// include labels and display names. +TEST(DetectionClassificationsMergerCalculator, + SucceedsWithoutLabelsAndDisplayNames) { + auto graph_config = ParseTextProtoOrDie(kGraphConfig); + + // Prepare input packets. + const Detection& input_detection = + ParseTextProtoOrDie(kInputDetection); + Packet input_detection_packet = + MakePacket(input_detection).At(Timestamp(0)); + const ClassificationList& classification_list = + ParseTextProtoOrDie(R"( + classification { index: 11 score: 0.5 } + classification { index: 12 score: 0.4 } + )"); + Packet classification_list_packet = + MakePacket(classification_list).At(Timestamp(0)); + + // Catch output. + std::vector output_packets; + tool::AddVectorSink("output_detection", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_detection", input_detection_packet)); + MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list", + classification_list_packet)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Get and validate output. 
+ EXPECT_THAT(output_packets, testing::SizeIs(1)); + const Detection& output_detection = output_packets[0].Get(); + EXPECT_THAT(output_detection, + mediapipe::EqualsProto(ParseTextProtoOrDie(R"( + label_id: 11 + label_id: 12 + score: 0.5 + score: 0.4 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 } + } + )"))); +} + +// Checks that merging fails if the input ClassificationList misses mandatory +// "index" field. +TEST(DetectionClassificationsMergerCalculator, FailsWithMissingIndex) { + auto graph_config = ParseTextProtoOrDie(kGraphConfig); + + // Prepare input packets. + const Detection& input_detection = + ParseTextProtoOrDie(kInputDetection); + Packet input_detection_packet = + MakePacket(input_detection).At(Timestamp(0)); + const ClassificationList& classification_list = + ParseTextProtoOrDie(R"( + classification { score: 0.5 label: "dog" } + )"); + Packet classification_list_packet = + MakePacket(classification_list).At(Timestamp(0)); + + // Catch output. + std::vector output_packets; + tool::AddVectorSink("output_detection", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_detection", input_detection_packet)); + MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list", + classification_list_packet)); + ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument); +} + +// Checks that merging fails if the input ClassificationList misses mandatory +// "score" field. +TEST(DetectionClassificationsMergerCalculator, FailsWithMissingScore) { + auto graph_config = ParseTextProtoOrDie(kGraphConfig); + + // Prepare input packets. 
+ const Detection& input_detection = + ParseTextProtoOrDie(kInputDetection); + Packet input_detection_packet = + MakePacket(input_detection).At(Timestamp(0)); + const ClassificationList& classification_list = + ParseTextProtoOrDie(R"( + classification { index: 11 label: "dog" } + )"); + Packet classification_list_packet = + MakePacket(classification_list).At(Timestamp(0)); + + // Catch output. + std::vector output_packets; + tool::AddVectorSink("output_detection", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_detection", input_detection_packet)); + MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list", + classification_list_packet)); + ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument); +} + +// Checks that merging fails if the input ClassificationList has an +// inconsistent number of labels. +TEST(DetectionClassificationsMergerCalculator, + FailsWithInconsistentNumberOfLabels) { + auto graph_config = ParseTextProtoOrDie(kGraphConfig); + + // Prepare input packets. + const Detection& input_detection = + ParseTextProtoOrDie(kInputDetection); + Packet input_detection_packet = + MakePacket(input_detection).At(Timestamp(0)); + const ClassificationList& classification_list = + ParseTextProtoOrDie(R"( + classification { index: 11 score: 0.5 label: "dog" display_name: "Dog" } + classification { index: 12 score: 0.4 display_name: "Fox" } + )"); + Packet classification_list_packet = + MakePacket(classification_list).At(Timestamp(0)); + + // Catch output. + std::vector output_packets; + tool::AddVectorSink("output_detection", &graph_config, &output_packets); + + // Run the graph. 
+ CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_detection", input_detection_packet)); + MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list", + classification_list_packet)); + ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument); +} + +// Checks that merging fails if the input ClassificationList has an +// inconsistent number of display names. +TEST(DetectionClassificationsMergerCalculator, + FailsWithInconsistentNumberOfDisplayNames) { + auto graph_config = ParseTextProtoOrDie(kGraphConfig); + + // Prepare input packets. + const Detection& input_detection = + ParseTextProtoOrDie(kInputDetection); + Packet input_detection_packet = + MakePacket(input_detection).At(Timestamp(0)); + const ClassificationList& classification_list = + ParseTextProtoOrDie(R"( + classification { index: 11 score: 0.5 label: "dog" } + classification { index: 12 score: 0.4 label: "fox" display_name: "Fox" } + )"); + Packet classification_list_packet = + MakePacket(classification_list).At(Timestamp(0)); + + // Catch output. + std::vector output_packets; + tool::AddVectorSink("output_detection", &graph_config, &output_packets); + + // Run the graph. 
+ CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_detection", input_detection_packet)); + MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list", + classification_list_packet)); + ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index 4f1d4a608..38bdb9d04 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "absl/algorithm/container.h" #include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/filtering/one_euro_filter.h" #include "mediapipe/util/filtering/relative_velocity_filter.h" namespace mediapipe { @@ -25,19 +28,54 @@ namespace mediapipe { namespace { constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; +constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; +using mediapipe::OneEuroFilter; using mediapipe::RelativeVelocityFilter; +void NormalizedLandmarksToLandmarks( + const NormalizedLandmarkList& norm_landmarks, const int image_width, + const int image_height, LandmarkList* landmarks) { + for (int i = 0; i < norm_landmarks.landmark_size(); ++i) { + const auto& norm_landmark = 
norm_landmarks.landmark(i); + + auto* landmark = landmarks->add_landmark(); + landmark->set_x(norm_landmark.x() * image_width); + landmark->set_y(norm_landmark.y() * image_height); + // Scale Z the same way as X (using image width). + landmark->set_z(norm_landmark.z() * image_width); + landmark->set_visibility(norm_landmark.visibility()); + landmark->set_presence(norm_landmark.presence()); + } +} + +void LandmarksToNormalizedLandmarks(const LandmarkList& landmarks, + const int image_width, + const int image_height, + NormalizedLandmarkList* norm_landmarks) { + for (int i = 0; i < landmarks.landmark_size(); ++i) { + const auto& landmark = landmarks.landmark(i); + + auto* norm_landmark = norm_landmarks->add_landmark(); + norm_landmark->set_x(landmark.x() / image_width); + norm_landmark->set_y(landmark.y() / image_height); + // Scale Z the same way as X (using image width). + norm_landmark->set_z(landmark.z() / image_width); + norm_landmark->set_visibility(landmark.visibility()); + norm_landmark->set_presence(landmark.presence()); + } +} + // Estimate object scale to use its inverse value as velocity scale for // RelativeVelocityFilter. If value will be too small (less than // `options_.min_allowed_object_scale`) smoothing will be disabled and // landmarks will be returned as is. // Object scale is calculated as average between bounding box width and height // with sides parallel to axis. 
-float GetObjectScale(const NormalizedLandmarkList& landmarks, int image_width, - int image_height) { +float GetObjectScale(const LandmarkList& landmarks) { const auto& lm_minmax_x = absl::c_minmax_element( landmarks.landmark(), [](const auto& a, const auto& b) { return a.x() < b.x(); }); @@ -50,8 +88,8 @@ float GetObjectScale(const NormalizedLandmarkList& landmarks, int image_width, const float y_min = lm_minmax_y.first->y(); const float y_max = lm_minmax_y.second->y(); - const float object_width = (x_max - x_min) * image_width; - const float object_height = (y_max - y_min) * image_height; + const float object_width = x_max - x_min; + const float object_height = y_max - y_min; return (object_width + object_height) / 2.0f; } @@ -63,19 +101,17 @@ class LandmarksFilter { virtual absl::Status Reset() { return absl::OkStatus(); } - virtual absl::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, + virtual absl::Status Apply(const LandmarkList& in_landmarks, const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) = 0; + LandmarkList* out_landmarks) = 0; }; // Returns landmarks as is without smoothing. 
class NoFilter : public LandmarksFilter { public: - absl::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, + absl::Status Apply(const LandmarkList& in_landmarks, const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) override { + LandmarkList* out_landmarks) override { *out_landmarks = in_landmarks; return absl::OkStatus(); } @@ -85,10 +121,11 @@ class NoFilter : public LandmarksFilter { class VelocityFilter : public LandmarksFilter { public: VelocityFilter(int window_size, float velocity_scale, - float min_allowed_object_scale) + float min_allowed_object_scale, bool disable_value_scaling) : window_size_(window_size), velocity_scale_(velocity_scale), - min_allowed_object_scale_(min_allowed_object_scale) {} + min_allowed_object_scale_(min_allowed_object_scale), + disable_value_scaling_(disable_value_scaling) {} absl::Status Reset() override { x_filters_.clear(); @@ -97,45 +134,37 @@ class VelocityFilter : public LandmarksFilter { return absl::OkStatus(); } - absl::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, + absl::Status Apply(const LandmarkList& in_landmarks, const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) override { - // Get image size. - int image_width; - int image_height; - std::tie(image_width, image_height) = image_size; - + LandmarkList* out_landmarks) override { // Get value scale as inverse value of the object scale. // If value is too small smoothing will be disabled and landmarks will be // returned as is. 
- const float object_scale = - GetObjectScale(in_landmarks, image_width, image_height); - if (object_scale < min_allowed_object_scale_) { - *out_landmarks = in_landmarks; - return absl::OkStatus(); + float value_scale = 1.0f; + if (!disable_value_scaling_) { + const float object_scale = GetObjectScale(in_landmarks); + if (object_scale < min_allowed_object_scale_) { + *out_landmarks = in_landmarks; + return absl::OkStatus(); + } + value_scale = 1.0f / object_scale; } - const float value_scale = 1.0f / object_scale; // Initialize filters once. MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size())); // Filter landmarks. Every axis of every landmark is filtered separately. for (int i = 0; i < in_landmarks.landmark_size(); ++i) { - const NormalizedLandmark& in_landmark = in_landmarks.landmark(i); + const auto& in_landmark = in_landmarks.landmark(i); - NormalizedLandmark* out_landmark = out_landmarks->add_landmark(); + auto* out_landmark = out_landmarks->add_landmark(); *out_landmark = in_landmark; - out_landmark->set_x(x_filters_[i].Apply(timestamp, value_scale, - in_landmark.x() * image_width) / - image_width); - out_landmark->set_y(y_filters_[i].Apply(timestamp, value_scale, - in_landmark.y() * image_height) / - image_height); - // Scale Z the save was as X (using image width). 
- out_landmark->set_z(z_filters_[i].Apply(timestamp, value_scale, - in_landmark.z() * image_width) / - image_width); + out_landmark->set_x( + x_filters_[i].Apply(timestamp, value_scale, in_landmark.x())); + out_landmark->set_y( + y_filters_[i].Apply(timestamp, value_scale, in_landmark.y())); + out_landmark->set_z( + z_filters_[i].Apply(timestamp, value_scale, in_landmark.z())); } return absl::OkStatus(); @@ -165,12 +194,83 @@ class VelocityFilter : public LandmarksFilter { int window_size_; float velocity_scale_; float min_allowed_object_scale_; + bool disable_value_scaling_; std::vector x_filters_; std::vector y_filters_; std::vector z_filters_; }; +// Please check OneEuroFilter documentation for details. +class OneEuroFilterImpl : public LandmarksFilter { + public: + OneEuroFilterImpl(double frequency, double min_cutoff, double beta, + double derivate_cutoff) + : frequency_(frequency), + min_cutoff_(min_cutoff), + beta_(beta), + derivate_cutoff_(derivate_cutoff) {} + + absl::Status Reset() override { + x_filters_.clear(); + y_filters_.clear(); + z_filters_.clear(); + return absl::OkStatus(); + } + + absl::Status Apply(const LandmarkList& in_landmarks, + const absl::Duration& timestamp, + LandmarkList* out_landmarks) override { + // Initialize filters once. + MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size())); + + // Filter landmarks. Every axis of every landmark is filtered separately. + for (int i = 0; i < in_landmarks.landmark_size(); ++i) { + const auto& in_landmark = in_landmarks.landmark(i); + + auto* out_landmark = out_landmarks->add_landmark(); + *out_landmark = in_landmark; + out_landmark->set_x(x_filters_[i].Apply(timestamp, in_landmark.x())); + out_landmark->set_y(y_filters_[i].Apply(timestamp, in_landmark.y())); + out_landmark->set_z(z_filters_[i].Apply(timestamp, in_landmark.z())); + } + + return absl::OkStatus(); + } + + private: + // Initializes filters for the first time or after Reset. 
If initialized then + // check the size. + absl::Status InitializeFiltersIfEmpty(const int n_landmarks) { + if (!x_filters_.empty()) { + RET_CHECK_EQ(x_filters_.size(), n_landmarks); + RET_CHECK_EQ(y_filters_.size(), n_landmarks); + RET_CHECK_EQ(z_filters_.size(), n_landmarks); + return absl::OkStatus(); + } + + for (int i = 0; i < n_landmarks; ++i) { + x_filters_.push_back( + OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_)); + y_filters_.push_back( + OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_)); + z_filters_.push_back( + OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_)); + } + + return absl::OkStatus(); + } + + double frequency_; + double min_cutoff_; + double beta_; + double derivate_cutoff_; + + std::vector x_filters_; + std::vector y_filters_; + std::vector z_filters_; +}; + } // namespace // A calculator to smooth landmarks over time. @@ -207,16 +307,21 @@ class LandmarksSmoothingCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) override; private: - LandmarksFilter* landmarks_filter_; + std::unique_ptr landmarks_filter_; }; REGISTER_CALCULATOR(LandmarksSmoothingCalculator); absl::Status LandmarksSmoothingCalculator::GetContract(CalculatorContract* cc) { - cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); - cc->Inputs().Tag(kImageSizeTag).Set>(); - cc->Outputs() - .Tag(kNormalizedFilteredLandmarksTag) - .Set(); + if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) { + cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); + cc->Inputs().Tag(kImageSizeTag).Set>(); + cc->Outputs() + .Tag(kNormalizedFilteredLandmarksTag) + .Set(); + } else { + cc->Inputs().Tag(kLandmarksTag).Set(); + cc->Outputs().Tag(kFilteredLandmarksTag).Set(); + } return absl::OkStatus(); } @@ -227,12 +332,19 @@ absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { // Pick landmarks filter. 
const auto& options = cc->Options(); if (options.has_no_filter()) { - landmarks_filter_ = new NoFilter(); + landmarks_filter_ = absl::make_unique(); } else if (options.has_velocity_filter()) { - landmarks_filter_ = new VelocityFilter( + landmarks_filter_ = absl::make_unique( options.velocity_filter().window_size(), options.velocity_filter().velocity_scale(), - options.velocity_filter().min_allowed_object_scale()); + options.velocity_filter().min_allowed_object_scale(), + options.velocity_filter().disable_value_scaling()); + } else if (options.has_one_euro_filter()) { + landmarks_filter_ = absl::make_unique( + options.one_euro_filter().frequency(), + options.one_euro_filter().min_cutoff(), + options.one_euro_filter().beta(), + options.one_euro_filter().derivate_cutoff()); } else { RET_CHECK_FAIL() << "Landmarks filter is either not specified or not supported"; @@ -244,25 +356,53 @@ absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { // Check that landmarks are not empty and reset the filter if so. // Don't emit an empty packet for this timestamp. 
- if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) { + if ((cc->Inputs().HasTag(kNormalizedLandmarksTag) && + cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) || + (cc->Inputs().HasTag(kLandmarksTag) && + cc->Inputs().Tag(kLandmarksTag).IsEmpty())) { MP_RETURN_IF_ERROR(landmarks_filter_->Reset()); return absl::OkStatus(); } - const auto& in_landmarks = - cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); - const auto& image_size = - cc->Inputs().Tag(kImageSizeTag).Get>(); const auto& timestamp = absl::Microseconds(cc->InputTimestamp().Microseconds()); - auto out_landmarks = absl::make_unique(); - MP_RETURN_IF_ERROR(landmarks_filter_->Apply(in_landmarks, image_size, - timestamp, out_landmarks.get())); + if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) { + const auto& in_norm_landmarks = + cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); - cc->Outputs() - .Tag(kNormalizedFilteredLandmarksTag) - .Add(out_landmarks.release(), cc->InputTimestamp()); + int image_width; + int image_height; + std::tie(image_width, image_height) = + cc->Inputs().Tag(kImageSizeTag).Get>(); + + auto in_landmarks = absl::make_unique(); + NormalizedLandmarksToLandmarks(in_norm_landmarks, image_width, image_height, + in_landmarks.get()); + + auto out_landmarks = absl::make_unique(); + MP_RETURN_IF_ERROR(landmarks_filter_->Apply(*in_landmarks, timestamp, + out_landmarks.get())); + + auto out_norm_landmarks = absl::make_unique(); + LandmarksToNormalizedLandmarks(*out_landmarks, image_width, image_height, + out_norm_landmarks.get()); + + cc->Outputs() + .Tag(kNormalizedFilteredLandmarksTag) + .Add(out_norm_landmarks.release(), cc->InputTimestamp()); + } else { + const auto& in_landmarks = + cc->Inputs().Tag(kLandmarksTag).Get(); + + auto out_landmarks = absl::make_unique(); + MP_RETURN_IF_ERROR( + landmarks_filter_->Apply(in_landmarks, timestamp, out_landmarks.get())); + + cc->Outputs() + .Tag(kFilteredLandmarksTag) + .Add(out_landmarks.release(), cc->InputTimestamp()); + } return 
absl::OkStatus(); } diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.proto b/mediapipe/calculators/util/landmarks_smoothing_calculator.proto index aca539cab..2466fafe6 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.proto +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.proto @@ -39,10 +39,40 @@ message LandmarksSmoothingCalculatorOptions { // If calculated object scale is less than given value smoothing will be // disabled and landmarks will be returned as is. optional float min_allowed_object_scale = 3 [default = 1e-6]; + + // Disable value scaling based on object size and use `1.0` instead. + // Value scale is calculated as inverse value of object size. Object size is + // calculated as the average of the width and height of the object's + // axis-aligned bounding box in the XY plane. + optional bool disable_value_scaling = 4 [default = false]; + } + + // For the details of the filter implementation and the procedure of its + // configuration please check http://cristal.univ-lille.fr/~casiez/1euro/ + message OneEuroFilter { + // Frequency of incoming frames defined in seconds. Used only if can't be + // calculated from provided events (e.g. on the very first frame). + optional float frequency = 1 [default = 0.033]; + + // Minimum cutoff frequency. Start by tuning this parameter while keeping + // `beta = 0` to reduce jittering to the desired level. 1Hz (the default + // value) is a good starting point. + optional float min_cutoff = 2 [default = 1.0]; + + // Cutoff slope. After `min_cutoff` is configured, start increasing `beta` + // value to reduce the lag introduced by the `min_cutoff`. Find the desired + // balance between jittering and lag. + optional float beta = 3 [default = 0.0]; + + // Cutoff frequency for derivate. It is set to 1Hz in the original + // algorithm, but can be tuned to further smooth the speed (i.e. derivate) + // on the object.
+ optional float derivate_cutoff = 4 [default = 1.0]; } oneof filter_options { NoFilter no_filter = 1; VelocityFilter velocity_filter = 2; + OneEuroFilter one_euro_filter = 3; } } diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc index 7818ad8cd..f2cec3ae3 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc @@ -34,6 +34,34 @@ constexpr char kRenderScaleTag[] = "RENDER_SCALE"; constexpr char kRenderDataTag[] = "RENDER_DATA"; constexpr char kLandmarkLabel[] = "KEYPOINT"; +inline Color DefaultMinDepthLineColor() { + Color color; + color.set_r(0); + color.set_g(0); + color.set_b(0); + return color; +} + +inline Color DefaultMaxDepthLineColor() { + Color color; + color.set_r(255); + color.set_g(255); + color.set_b(255); + return color; +} + +inline Color MixColors(const Color& color1, const Color& color2, + float color1_weight) { + Color color; + color.set_r(static_cast(color1.r() * color1_weight + + color2.r() * (1.f - color1_weight))); + color.set_g(static_cast(color1.g() * color1_weight + + color2.g() * (1.f - color1_weight))); + color.set_b(static_cast(color1.b() * color1_weight + + color2.b() * (1.f - color1_weight))); + return color; +} + inline void SetColor(RenderAnnotation* annotation, const Color& color) { annotation->mutable_color()->set_r(color.r()); annotation->mutable_color()->set_g(color.g()); @@ -57,6 +85,23 @@ inline void GetMinMaxZ(const LandmarkListType& landmarks, float* z_min, } } +template +bool IsLandmarkVisibileAndPresent(const LandmarkType& landmark, + bool utilize_visibility, + float visibility_threshold, + bool utilize_presence, + float presence_threshold) { + if (utilize_visibility && landmark.has_visibility() && + landmark.visibility() < visibility_threshold) { + return false; + } + if (utilize_presence && landmark.has_presence() && + 
landmark.presence() < presence_threshold) { + return false; + } + return true; +} + void SetColorSizeValueFromZ(float z, float z_min, float z_max, RenderAnnotation* render_annotation, float min_depth_circle_thickness, @@ -75,8 +120,9 @@ void SetColorSizeValueFromZ(float z, float z_min, float z_max, template void AddConnectionToRenderData(const LandmarkType& start, - const LandmarkType& end, int gray_val1, - int gray_val2, float thickness, bool normalized, + const LandmarkType& end, + const Color& color_start, const Color& color_end, + float thickness, bool normalized, RenderData* render_data) { auto* connection_annotation = render_data->add_render_annotations(); RenderAnnotation::GradientLine* line = @@ -86,12 +132,13 @@ void AddConnectionToRenderData(const LandmarkType& start, line->set_x_end(end.x()); line->set_y_end(end.y()); line->set_normalized(normalized); - line->mutable_color1()->set_r(gray_val1); - line->mutable_color1()->set_g(gray_val1); - line->mutable_color1()->set_b(gray_val1); - line->mutable_color2()->set_r(gray_val2); - line->mutable_color2()->set_g(gray_val2); - line->mutable_color2()->set_b(gray_val2); + line->mutable_color1()->set_r(color_start.r()); + line->mutable_color1()->set_g(color_start.g()); + line->mutable_color1()->set_b(color_start.b()); + line->mutable_color2()->set_r(color_end.r()); + line->mutable_color2()->set_g(color_end.g()); + line->mutable_color2()->set_b(color_end.b()); + connection_annotation->set_thickness(thickness); } @@ -102,26 +149,26 @@ void AddConnectionsWithDepth(const LandmarkListType& landmarks, float visibility_threshold, bool utilize_presence, float presence_threshold, float thickness, bool normalized, float min_z, float max_z, + const Color& min_depth_line_color, + const Color& max_depth_line_color, RenderData* render_data) { for (int i = 0; i < landmark_connections.size(); i += 2) { const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); 
- if (utilize_visibility && - ((ld0.has_visibility() && ld0.visibility() < visibility_threshold) || - (ld1.has_visibility() && ld1.visibility() < visibility_threshold))) { + if (!IsLandmarkVisibileAndPresent( + ld0, utilize_visibility, visibility_threshold, utilize_presence, + presence_threshold) || + !IsLandmarkVisibileAndPresent( + ld1, utilize_visibility, visibility_threshold, utilize_presence, + presence_threshold)) { continue; } - if (utilize_presence && - ((ld0.has_presence() && ld0.presence() < presence_threshold) || - (ld1.has_presence() && ld1.presence() < presence_threshold))) { - continue; - } - const int gray_val1 = - 255 - static_cast(Remap(ld0.z(), min_z, max_z, 255)); - const int gray_val2 = - 255 - static_cast(Remap(ld1.z(), min_z, max_z, 255)); - AddConnectionToRenderData(ld0, ld1, gray_val1, gray_val2, - thickness, normalized, render_data); + const Color color0 = MixColors(min_depth_line_color, max_depth_line_color, + Remap(ld0.z(), min_z, max_z, 1.f)); + const Color color1 = MixColors(min_depth_line_color, max_depth_line_color, + Remap(ld1.z(), min_z, max_z, 1.f)); + AddConnectionToRenderData(ld0, ld1, color0, color1, thickness, + normalized, render_data); } } @@ -151,14 +198,12 @@ void AddConnections(const LandmarkListType& landmarks, for (int i = 0; i < landmark_connections.size(); i += 2) { const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); - if (utilize_visibility && - ((ld0.has_visibility() && ld0.visibility() < visibility_threshold) || - (ld1.has_visibility() && ld1.visibility() < visibility_threshold))) { - continue; - } - if (utilize_presence && - ((ld0.has_presence() && ld0.presence() < presence_threshold) || - (ld1.has_presence() && ld1.presence() < presence_threshold))) { + if (!IsLandmarkVisibileAndPresent( + ld0, utilize_visibility, visibility_threshold, utilize_presence, + presence_threshold) || + !IsLandmarkVisibileAndPresent( + ld1, utilize_visibility, 
visibility_threshold, utilize_presence, + presence_threshold)) { continue; } AddConnectionToRenderData(ld0, ld1, connection_color, @@ -232,6 +277,13 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { float z_min = 0.f; float z_max = 0.f; + const Color min_depth_line_color = options_.has_min_depth_line_color() + ? options_.min_depth_line_color() + : DefaultMinDepthLineColor(); + const Color max_depth_line_color = options_.has_max_depth_line_color() + ? options_.max_depth_line_color() + : DefaultMaxDepthLineColor(); + // Apply scale to `thickness` of rendered landmarks and connections to make // them bigger when object (e.g. pose, hand or face) is closer/bigger and // snaller when object is further/smaller. @@ -254,7 +306,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { landmarks, landmark_connections_, options_.utilize_visibility(), options_.visibility_threshold(), options_.utilize_presence(), options_.presence_threshold(), thickness, /*normalized=*/false, z_min, - z_max, render_data.get()); + z_max, min_depth_line_color, max_depth_line_color, render_data.get()); } else { AddConnections( landmarks, landmark_connections_, options_.utilize_visibility(), @@ -265,13 +317,10 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < landmarks.landmark_size(); ++i) { const Landmark& landmark = landmarks.landmark(i); - if (options_.utilize_visibility() && landmark.has_visibility() && - landmark.visibility() < options_.visibility_threshold()) { - continue; - } - - if (options_.utilize_presence() && landmark.has_presence() && - landmark.presence() < options_.presence_threshold()) { + if (!IsLandmarkVisibileAndPresent( + landmark, options_.utilize_visibility(), + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold())) { continue; } @@ -303,7 +352,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { 
landmarks, landmark_connections_, options_.utilize_visibility(), options_.visibility_threshold(), options_.utilize_presence(), options_.presence_threshold(), thickness, /*normalized=*/true, z_min, - z_max, render_data.get()); + z_max, min_depth_line_color, max_depth_line_color, render_data.get()); } else { AddConnections( landmarks, landmark_connections_, options_.utilize_visibility(), @@ -314,12 +363,10 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < landmarks.landmark_size(); ++i) { const NormalizedLandmark& landmark = landmarks.landmark(i); - if (options_.utilize_visibility() && landmark.has_visibility() && - landmark.visibility() < options_.visibility_threshold()) { - continue; - } - if (options_.utilize_presence() && landmark.has_presence() && - landmark.presence() < options_.presence_threshold()) { + if (!IsLandmarkVisibileAndPresent( + landmark, options_.utilize_visibility(), + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold())) { continue; } diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto index 34f073f26..990919540 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto @@ -64,4 +64,10 @@ message LandmarksToRenderDataCalculatorOptions { // Max thickness of the drawing for landmark circle. optional double max_depth_circle_thickness = 11 [default = 18.0]; + + // Gradient color for the lines connecting landmarks at the minimum depth. + optional Color min_depth_line_color = 12; + + // Gradient color for the lines connecting landmarks at the maximum depth. 
+ optional Color max_depth_line_color = 13; } diff --git a/mediapipe/calculators/util/visibility_copy_calculator.cc b/mediapipe/calculators/util/visibility_copy_calculator.cc new file mode 100644 index 000000000..f85ff9ea2 --- /dev/null +++ b/mediapipe/calculators/util/visibility_copy_calculator.cc @@ -0,0 +1,194 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/algorithm/container.h" +#include "mediapipe/calculators/util/visibility_copy_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { + +namespace { + +constexpr char kLandmarksFromTag[] = "LANDMARKS_FROM"; +constexpr char kNormalizedLandmarksFromTag[] = "NORM_LANDMARKS_FROM"; +constexpr char kLandmarksToTag[] = "LANDMARKS_TO"; +constexpr char kNormalizedLandmarksToTag[] = "NORM_LANDMARKS_TO"; + +} // namespace + +// A calculator to copy visibility and presence between landmarks. +// +// Landmarks to copy from and to copy to can be of different type (normalized or +// non-normalized), but landmarks to copy to and output landmarks should be of +// the same type. Exactly one stream to copy landmarks from, to copy to and to +// output should be provided. +// +// Inputs: +// LANDMARKS_FROM (optional): A LandmarkList of landmarks to copy from.
+// NORM_LANDMARKS_FROM (optional): A NormalizedLandmarkList of landmarks to +// copy from. +// LANDMARKS_TO (optional): A LandmarkList of landmarks to copy to. +// NORM_LANDMARKS_TO (optional): A NormalizedLandmarkList of landmarks to copy +// to. +// +// Outputs: +// LANDMARKS_TO (optional): A LandmarkList of landmarks from LANDMARKS_TO and +// visibility/presence from LANDMARKS_FROM or NORM_LANDMARKS_FROM. +// NORM_LANDMARKS_TO (optional): A NormalizedLandmarkList of landmarks to copy +// to. +// +// Example config: +// node { +// calculator: "VisibilityCopyCalculator" +// input_stream: "NORM_LANDMARKS_FROM:pose_landmarks" +// input_stream: "LANDMARKS_TO:pose_world_landmarks" +// output_stream: "LANDMARKS_TO:pose_world_landmarks_with_visibility" +// options: { +// [mediapipe.VisibilityCopyCalculatorOptions.ext] { +// copy_visibility: true +// copy_presence: true +// } +// } +// } +// +class VisibilityCopyCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + template + absl::Status CopyVisibility(CalculatorContext* cc, + const std::string& landmarks_from_tag, + const std::string& landmarks_to_tag); + + bool copy_visibility_; + bool copy_presence_; +}; +REGISTER_CALCULATOR(VisibilityCopyCalculator); + +absl::Status VisibilityCopyCalculator::GetContract(CalculatorContract* cc) { + // Landmarks to copy from. + RET_CHECK(cc->Inputs().HasTag(kLandmarksFromTag) ^ + cc->Inputs().HasTag(kNormalizedLandmarksFromTag)) + << "Exatly one landmarks stream to copy from should be provided"; + if (cc->Inputs().HasTag(kLandmarksFromTag)) { + cc->Inputs().Tag(kLandmarksFromTag).Set(); + } else { + cc->Inputs().Tag(kNormalizedLandmarksFromTag).Set(); + } + + // Landmarks to copy to and corresponding output landmarks. 
+ RET_CHECK(cc->Inputs().HasTag(kLandmarksToTag) ^ + cc->Inputs().HasTag(kNormalizedLandmarksToTag)) + << "Exatly one landmarks stream to copy to should be provided"; + if (cc->Inputs().HasTag(kLandmarksToTag)) { + cc->Inputs().Tag(kLandmarksToTag).Set(); + + RET_CHECK(cc->Outputs().HasTag(kLandmarksToTag)) + << "Landmarks to copy to and output stream types should be the same"; + cc->Outputs().Tag(kLandmarksToTag).Set(); + } else { + cc->Inputs().Tag(kNormalizedLandmarksToTag).Set(); + + RET_CHECK(cc->Outputs().HasTag(kNormalizedLandmarksToTag)) + << "Landmarks to copy to and output stream types should be the same"; + cc->Outputs().Tag(kNormalizedLandmarksToTag).Set(); + } + + return absl::OkStatus(); +} + +absl::Status VisibilityCopyCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + const auto& options = cc->Options(); + copy_visibility_ = options.copy_visibility(); + copy_presence_ = options.copy_presence(); + + return absl::OkStatus(); +} + +absl::Status VisibilityCopyCalculator::Process(CalculatorContext* cc) { + // Switch between all four possible combinations of landmarks from and + // landmarks to types (normalized and non-normalized). + auto status = absl::OkStatus(); + if (cc->Inputs().HasTag(kLandmarksFromTag)) { + if (cc->Inputs().HasTag(kLandmarksToTag)) { + status = CopyVisibility(cc, kLandmarksFromTag, + kLandmarksToTag); + } else { + status = CopyVisibility( + cc, kLandmarksFromTag, kNormalizedLandmarksToTag); + } + } else { + if (cc->Inputs().HasTag(kLandmarksToTag)) { + status = CopyVisibility( + cc, kNormalizedLandmarksFromTag, kLandmarksToTag); + } else { + status = CopyVisibility( + cc, kNormalizedLandmarksFromTag, kNormalizedLandmarksToTag); + } + } + + return status; +} + +template +absl::Status VisibilityCopyCalculator::CopyVisibility( + CalculatorContext* cc, const std::string& landmarks_from_tag, + const std::string& landmarks_to_tag) { + // Check that both landmarks to copy from and to copy to are non empty. 
+ if (cc->Inputs().Tag(landmarks_from_tag).IsEmpty() || + cc->Inputs().Tag(landmarks_to_tag).IsEmpty()) { + return absl::OkStatus(); + } + + const auto landmarks_from = + cc->Inputs().Tag(landmarks_from_tag).Get(); + const auto landmarks_to = + cc->Inputs().Tag(landmarks_to_tag).Get(); + auto landmarks_out = absl::make_unique(); + + for (int i = 0; i < landmarks_from.landmark_size(); ++i) { + const auto& landmark_from = landmarks_from.landmark(i); + const auto& landmark_to = landmarks_to.landmark(i); + + // Create output landmark and copy all fields from the `to` landmark. + const auto& landmark_out = landmarks_out->add_landmark(); + *landmark_out = landmark_to; + + // Copy visibility and presence from the `from` landmark. + if (copy_visibility_) { + landmark_out->set_visibility(landmark_from.visibility()); + } + if (copy_presence_) { + landmark_out->set_presence(landmark_from.presence()); + } + } + + cc->Outputs() + .Tag(landmarks_to_tag) + .Add(landmarks_out.release(), cc->InputTimestamp()); + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/visibility_copy_calculator.proto b/mediapipe/calculators/util/visibility_copy_calculator.proto new file mode 100644 index 000000000..df25937b8 --- /dev/null +++ b/mediapipe/calculators/util/visibility_copy_calculator.proto @@ -0,0 +1,29 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator_options.proto"; + +message VisibilityCopyCalculatorOptions { + extend CalculatorOptions { + optional VisibilityCopyCalculatorOptions ext = 363728421; + } + + optional bool copy_visibility = 1 [default = true]; + + optional bool copy_presence = 2 [default = true]; +} diff --git a/mediapipe/calculators/util/visibility_smoothing_calculator.cc b/mediapipe/calculators/util/visibility_smoothing_calculator.cc new file mode 100644 index 000000000..cd6ce5f0d --- /dev/null +++ b/mediapipe/calculators/util/visibility_smoothing_calculator.cc @@ -0,0 +1,243 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "absl/algorithm/container.h" +#include "mediapipe/calculators/util/visibility_smoothing_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/filtering/low_pass_filter.h" + +namespace mediapipe { + +namespace { + +constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; +constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; + +using mediapipe::LowPassFilter; + +// Abstract class for various visibility filters. +class VisibilityFilter { + public: + virtual ~VisibilityFilter() = default; + + virtual absl::Status Reset() { return absl::OkStatus(); } + + virtual absl::Status Apply(const LandmarkList& in_landmarks, + const absl::Duration& timestamp, + LandmarkList* out_landmarks) = 0; + + virtual absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) = 0; +}; + +// Returns visibility as is without smoothing. +class NoFilter : public VisibilityFilter { + public: + absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) override { + *out_landmarks = in_landmarks; + return absl::OkStatus(); + } + + absl::Status Apply(const LandmarkList& in_landmarks, + const absl::Duration& timestamp, + LandmarkList* out_landmarks) override { + *out_landmarks = in_landmarks; + return absl::OkStatus(); + } +}; + +// Please check LowPassFilter documentation for details. 
+class LowPassVisibilityFilter : public VisibilityFilter { + public: + LowPassVisibilityFilter(float alpha) : alpha_(alpha) {} + + absl::Status Reset() override { + visibility_filters_.clear(); + return absl::OkStatus(); + } + + absl::Status Apply(const LandmarkList& in_landmarks, + const absl::Duration& timestamp, + LandmarkList* out_landmarks) override { + return ApplyImpl(in_landmarks, timestamp, out_landmarks); + } + + absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) override { + return ApplyImpl(in_landmarks, timestamp, + out_landmarks); + } + + private: + template + absl::Status ApplyImpl(const LandmarksType& in_landmarks, + const absl::Duration& timestamp, + LandmarksType* out_landmarks) { + // Initializes filters for the first time or after Reset. If initialized + // then check the size. + int n_landmarks = in_landmarks.landmark_size(); + if (!visibility_filters_.empty()) { + RET_CHECK_EQ(visibility_filters_.size(), n_landmarks); + } else { + visibility_filters_.resize(n_landmarks, LowPassFilter(alpha_)); + } + + // Filter visibilities. + for (int i = 0; i < in_landmarks.landmark_size(); ++i) { + const auto& in_landmark = in_landmarks.landmark(i); + + auto* out_landmark = out_landmarks->add_landmark(); + *out_landmark = in_landmark; + out_landmark->set_visibility( + visibility_filters_[i].Apply(in_landmark.visibility())); + } + + return absl::OkStatus(); + } + + float alpha_; + std::vector visibility_filters_; +}; + +} // namespace + +// A calculator to smooth landmark visibilities over time. +// +// Exactly one landmarks input stream is expected. Output stream type should be +// the same as the input one. +// +// Inputs: +// LANDMARKS (optional): A LandmarkList of landmarks you want to smooth. +// NORM_LANDMARKS (optional): A NormalizedLandmarkList of landmarks you want +// to smooth. 
+// +// Outputs: +// FILTERED_LANDMARKS (optional): A LandmarkList of smoothed landmarks. +// NORM_FILTERED_LANDMARKS (optional): A NormalizedLandmarkList of smoothed +// landmarks. +// +// Example config: +// node { +// calculator: "VisibilitySmoothingCalculator" +// input_stream: "NORM_LANDMARKS:pose_landmarks" +// output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_filtered" +// options: { +// [mediapipe.VisibilitySmoothingCalculatorOptions.ext] { +// low_pass_filter: { +// alpha: 0.1 +// } +// } +// } +// } +// +class VisibilitySmoothingCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + std::unique_ptr visibility_filter_; +}; +REGISTER_CALCULATOR(VisibilitySmoothingCalculator); + +absl::Status VisibilitySmoothingCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kNormalizedLandmarksTag) ^ + cc->Inputs().HasTag(kLandmarksTag)) + << "Exactly one landmarks input stream is expected"; + if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) { + cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); + RET_CHECK(cc->Outputs().HasTag(kNormalizedFilteredLandmarksTag)) + << "Landmarks output stream should of the same type as input one"; + cc->Outputs() + .Tag(kNormalizedFilteredLandmarksTag) + .Set(); + } else { + cc->Inputs().Tag(kLandmarksTag).Set(); + RET_CHECK(cc->Outputs().HasTag(kFilteredLandmarksTag)) + << "Landmarks output stream should of the same type as input one"; + cc->Outputs().Tag(kFilteredLandmarksTag).Set(); + } + + return absl::OkStatus(); +} + +absl::Status VisibilitySmoothingCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + // Pick visibility filter. 
+ const auto& options = cc->Options(); + if (options.has_no_filter()) { + visibility_filter_ = absl::make_unique(); + } else if (options.has_low_pass_filter()) { + visibility_filter_ = absl::make_unique( + options.low_pass_filter().alpha()); + } else { + RET_CHECK_FAIL() + << "Visibility filter is either not specified or not supported"; + } + + return absl::OkStatus(); +} + +absl::Status VisibilitySmoothingCalculator::Process(CalculatorContext* cc) { + // Check that landmarks are not empty and reset the filter if so. + // Don't emit an empty packet for this timestamp. + if ((cc->Inputs().HasTag(kNormalizedLandmarksTag) && + cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) || + (cc->Inputs().HasTag(kLandmarksTag) && + cc->Inputs().Tag(kLandmarksTag).IsEmpty())) { + MP_RETURN_IF_ERROR(visibility_filter_->Reset()); + return absl::OkStatus(); + } + + const auto& timestamp = + absl::Microseconds(cc->InputTimestamp().Microseconds()); + + if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) { + const auto& in_landmarks = + cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); + auto out_landmarks = absl::make_unique(); + MP_RETURN_IF_ERROR(visibility_filter_->Apply(in_landmarks, timestamp, + out_landmarks.get())); + cc->Outputs() + .Tag(kNormalizedFilteredLandmarksTag) + .Add(out_landmarks.release(), cc->InputTimestamp()); + } else { + const auto& in_landmarks = + cc->Inputs().Tag(kLandmarksTag).Get(); + auto out_landmarks = absl::make_unique(); + MP_RETURN_IF_ERROR(visibility_filter_->Apply(in_landmarks, timestamp, + out_landmarks.get())); + cc->Outputs() + .Tag(kFilteredLandmarksTag) + .Add(out_landmarks.release(), cc->InputTimestamp()); + } + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/visibility_smoothing_calculator.proto b/mediapipe/calculators/util/visibility_smoothing_calculator.proto new file mode 100644 index 000000000..3b991923c --- /dev/null +++ b/mediapipe/calculators/util/visibility_smoothing_calculator.proto 
@@ -0,0 +1,40 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator_options.proto"; + +message VisibilitySmoothingCalculatorOptions { + extend CalculatorOptions { + optional VisibilitySmoothingCalculatorOptions ext = 360207350; + } + + // Default behaviour and fast way to disable smoothing. + message NoFilter {} + + message LowPassFilter { + // Coefficient applied to a new value, while `1 - alpha` is applied to a + // stored value. Should be in [0, 1] range. The smaller the value - the + // smoother result and the bigger lag. + optional float alpha = 1 [default = 0.1]; + } + + oneof filter_options { + NoFilter no_filter = 1; + LowPassFilter low_pass_filter = 2; + } +} diff --git a/mediapipe/calculators/util/world_landmark_projection_calculator.cc b/mediapipe/calculators/util/world_landmark_projection_calculator.cc new file mode 100644 index 000000000..28cf9498d --- /dev/null +++ b/mediapipe/calculators/util/world_landmark_projection_calculator.cc @@ -0,0 +1,108 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kRectTag[] = "NORM_RECT"; + +} // namespace + +// Projects world landmarks from the rectangle to original coordinates. +// +// World landmarks are predicted in meters rather than in pixels of the image +// and have origin in the middle of the hips rather than in the corner of the +// pose image (cropped with given rectangle). Thus only rotation (but not scale +// and translation) is applied to the landmarks to transform them back to +// original coordinates. +// +// Input: +// LANDMARKS: A LandmarkList representing world landmarks in the rectangle. +// NORM_RECT: A NormalizedRect representing a normalized rectangle in image +// coordinates. +// +// Output: +// LANDMARKS: A LandmarkList representing world landmarks projected (rotated +// but not scaled or translated) from the rectangle to original +// coordinates. 
+// +// Usage example: +// node { +// calculator: "WorldLandmarkProjectionCalculator" +// input_stream: "LANDMARKS:landmarks" +// input_stream: "NORM_RECT:rect" +// output_stream: "LANDMARKS:projected_landmarks" +// } +// +class WorldLandmarkProjectionCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag(kLandmarksTag).Set(); + cc->Inputs().Tag(kRectTag).Set(); + cc->Outputs().Tag(kLandmarksTag).Set(); + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + // Check that landmarks and rect are not empty. + if (cc->Inputs().Tag(kLandmarksTag).IsEmpty() || + cc->Inputs().Tag(kRectTag).IsEmpty()) { + return absl::OkStatus(); + } + + const auto& in_landmarks = + cc->Inputs().Tag(kLandmarksTag).Get(); + const auto& in_rect = cc->Inputs().Tag(kRectTag).Get(); + + auto out_landmarks = absl::make_unique(); + for (int i = 0; i < in_landmarks.landmark_size(); ++i) { + const auto& in_landmark = in_landmarks.landmark(i); + + Landmark* out_landmark = out_landmarks->add_landmark(); + *out_landmark = in_landmark; + + const float angle = in_rect.rotation(); + out_landmark->set_x(std::cos(angle) * in_landmark.x() - + std::sin(angle) * in_landmark.y()); + out_landmark->set_y(std::sin(angle) * in_landmark.x() + + std::cos(angle) * in_landmark.y()); + } + + cc->Outputs() + .Tag(kLandmarksTag) + .Add(out_landmarks.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } +}; +REGISTER_CALCULATOR(WorldLandmarkProjectionCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index af526044a..806b9f1fa 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -426,6 +426,7 @@ cc_test( "//mediapipe/framework/port:logging", 
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/flags:flag", ], ) @@ -450,6 +451,7 @@ cc_test( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/flags:flag", ], ) @@ -534,6 +536,7 @@ cc_test( "//mediapipe/framework/stream_handler:sync_set_input_stream_handler", "//mediapipe/util/tracking:box_tracker_cc_proto", "//mediapipe/util/tracking:tracking_cc_proto", + "@com_google_absl//absl/flags:flag", ], ) diff --git a/mediapipe/examples/coral/BUILD b/mediapipe/examples/coral/BUILD index ec747573b..50f0b38c7 100644 --- a/mediapipe/examples/coral/BUILD +++ b/mediapipe/examples/coral/BUILD @@ -27,13 +27,14 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", ], ) diff --git a/mediapipe/examples/coral/Dockerfile b/mediapipe/examples/coral/Dockerfile index bc655c580..ea99e5b08 100644 --- a/mediapipe/examples/coral/Dockerfile +++ b/mediapipe/examples/coral/Dockerfile @@ -62,7 +62,7 @@ COPY . /mediapipe/ # Install bazel # Please match the current MediaPipe Bazel requirements according to docs. 
-ARG BAZEL_VERSION=3.4.1 +ARG BAZEL_VERSION=3.7.2 RUN mkdir /bazel && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ wget --no-check-certificate -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ diff --git a/mediapipe/examples/coral/demo_run_graph_main.cc b/mediapipe/examples/coral/demo_run_graph_main.cc index 698955472..6f1c56268 100644 --- a/mediapipe/examples/coral/demo_run_graph_main.cc +++ b/mediapipe/examples/coral/demo_run_graph_main.cc @@ -15,10 +15,11 @@ // An example of sending OpenCV webcam frames into a MediaPipe graph. #include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" @@ -30,15 +31,14 @@ constexpr char kInputStream[] = "input_video"; constexpr char kOutputStream[] = "output_video"; constexpr char kWindowName[] = "MediaPipe"; -DEFINE_string( - calculator_graph_config_file, "", - "Name of file containing text format CalculatorGraphConfig proto."); -DEFINE_string(input_video_path, "", - "Full path of video to load. " - "If not provided, attempt to use a webcam."); -DEFINE_string(output_video_path, "", - "Full path of where to save result (.mp4 only). " - "If not provided, show result in a window."); +ABSL_FLAG(std::string, calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); +ABSL_FLAG(std::string, input_video_path, "", + "Full path of video to load. 
" + "If not provided, attempt to use a webcam."); +ABSL_FLAG(std::string, output_video_path, "", + "Full path of where to save result (.mp4 only). " + "If not provided, show result in a window."); absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; @@ -143,7 +143,7 @@ absl::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); + absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); diff --git a/mediapipe/examples/desktop/BUILD b/mediapipe/examples/desktop/BUILD index 7772e21da..80cb7ad81 100644 --- a/mediapipe/examples/desktop/BUILD +++ b/mediapipe/examples/desktop/BUILD @@ -23,13 +23,14 @@ cc_library( srcs = ["simple_run_graph_main.cc"], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/strings", ], ) @@ -41,13 +42,14 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", ], ) @@ -62,7 +64,6 @@ cc_library( "//mediapipe/framework:calculator_framework", 
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_imgproc", @@ -72,5 +73,7 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", ], ) diff --git a/mediapipe/examples/desktop/autoflip/calculators/BUILD b/mediapipe/examples/desktop/autoflip/calculators/BUILD index 99b9d6fff..7ab845cbb 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/BUILD +++ b/mediapipe/examples/desktop/autoflip/calculators/BUILD @@ -54,18 +54,27 @@ mediapipe_cc_proto_library( deps = [":border_detection_calculator_proto"], ) +cc_library( + name = "content_zooming_calculator_state", + hdrs = ["content_zooming_calculator_state.h"], + deps = [ + "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:rect_cc_proto", + ], +) + cc_library( name = "content_zooming_calculator", srcs = ["content_zooming_calculator.cc"], deps = [ ":content_zooming_calculator_cc_proto", + ":content_zooming_calculator_state", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", - "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:location_data_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", ], @@ -88,7 +97,9 @@ mediapipe_cc_proto_library( "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_cc_proto", "//mediapipe/framework:calculator_cc_proto", ], - 
visibility = ["//mediapipe/examples:__subpackages__"], + visibility = [ + "//mediapipe/examples:__subpackages__", + ], deps = [ ":content_zooming_calculator_proto", ], @@ -127,6 +138,7 @@ cc_test( deps = [ ":content_zooming_calculator", ":content_zooming_calculator_cc_proto", + ":content_zooming_calculator_state", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver", "//mediapipe/framework:calculator_framework", @@ -368,7 +380,6 @@ cc_test( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgcodecs", @@ -376,6 +387,7 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc index c2ee6b0ff..28c34b2b5 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc @@ -17,12 +17,11 @@ #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h" -#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h" +#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include 
"mediapipe/framework/formats/location_data.pb.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" @@ -33,12 +32,18 @@ constexpr char kSalientRegions[] = "SALIENT_REGIONS"; constexpr char kDetections[] = "DETECTIONS"; constexpr char kDetectedBorders[] = "BORDERS"; constexpr char kCropRect[] = "CROP_RECT"; +constexpr char kFirstCropRect[] = "FIRST_CROP_RECT"; // Field-of-view (degrees) of the camera's x-axis (width). // TODO: Parameterize FOV based on camera specs. constexpr float kFieldOfView = 60; +// A pointer to a ContentZoomingCalculatorStateCacheType in a side packet. +// Used to save state on Close and load state on Open in a new graph. +// Can be used to preserve state between graphs. +constexpr char kStateCache[] = "STATE_CACHE"; namespace mediapipe { namespace autoflip { +using StateCacheType = ContentZoomingCalculatorStateCacheType; // Content zooming calculator zooms in on content when a detection has // "only_required" set true or any raw detection input. It does this by @@ -49,8 +54,7 @@ namespace autoflip { // include mobile makeover and autofliplive face reframing. 
class ContentZoomingCalculator : public CalculatorBase { public: - ContentZoomingCalculator() - : initialized_(false), last_only_required_detection_(0) {} + ContentZoomingCalculator() : initialized_(false) {} ~ContentZoomingCalculator() override {} ContentZoomingCalculator(const ContentZoomingCalculator&) = delete; ContentZoomingCalculator& operator=(const ContentZoomingCalculator&) = delete; @@ -58,8 +62,25 @@ class ContentZoomingCalculator : public CalculatorBase { static absl::Status GetContract(mediapipe::CalculatorContract* cc); absl::Status Open(mediapipe::CalculatorContext* cc) override; absl::Status Process(mediapipe::CalculatorContext* cc) override; + absl::Status Close(mediapipe::CalculatorContext* cc) override; private: + // Tries to load state from a state-cache, if provided. Falls back to + // initializing state if no cache or no value in the cache are available. + absl::Status MaybeLoadState(mediapipe::CalculatorContext* cc, int frame_width, + int frame_height); + // Saves state to a state-cache, if provided. + absl::Status SaveState(mediapipe::CalculatorContext* cc) const; + // Initializes the calculator for the given frame size, creating path solvers + // and resetting history like last measured values. + absl::Status InitializeState(int frame_width, int frame_height); + // Adjusts state to work with an updated frame size. + absl::Status UpdateForResolutionChange(int frame_width, int frame_height); + // Returns true if we are zooming to the initial rect. + bool IsZoomingToInitialRect(const Timestamp& timestamp) const; + // Builds the output rectangle when zooming to the initial rect. + absl::StatusOr GetInitialZoomingRect( + int frame_width, int frame_height, const Timestamp& timestamp) const; // Converts bounds to tilt offset, pan offset and height. 
absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin, float ymax, int* tilt_offset, @@ -76,6 +97,10 @@ class ContentZoomingCalculator : public CalculatorBase { std::unique_ptr path_solver_tilt_; // Are parameters initialized. bool initialized_; + // Stores the time of the first crop rectangle. + Timestamp first_rect_timestamp_; + // Stores the first crop rectangle. + mediapipe::NormalizedRect first_rect_; // Stores the time of the last "only_required" input. int64 last_only_required_detection_; // Rect values of last message with detection(s). @@ -116,6 +141,12 @@ absl::Status ContentZoomingCalculator::GetContract( if (cc->Outputs().HasTag(kCropRect)) { cc->Outputs().Tag(kCropRect).Set(); } + if (cc->Outputs().HasTag(kFirstCropRect)) { + cc->Outputs().Tag(kFirstCropRect).Set(); + } + if (cc->InputSidePackets().HasTag(kStateCache)) { + cc->InputSidePackets().Tag(kStateCache).Set(); + } return absl::OkStatus(); } @@ -135,6 +166,13 @@ absl::Status ContentZoomingCalculator::Open(mediapipe::CalculatorContext* cc) { return absl::OkStatus(); } +absl::Status ContentZoomingCalculator::Close(mediapipe::CalculatorContext* cc) { + if (initialized_) { + MP_RETURN_IF_ERROR(SaveState(cc)); + } + return absl::OkStatus(); +} + absl::Status ContentZoomingCalculator::ConvertToPanTiltZoom( float xmin, float xmax, float ymin, float ymax, int* tilt_offset, int* pan_offset, int* height) { @@ -275,39 +313,89 @@ absl::Status ContentZoomingCalculator::UpdateAspectAndMax() { return absl::OkStatus(); } -absl::Status ContentZoomingCalculator::Process( - mediapipe::CalculatorContext* cc) { - // For async subgraph support, return on empty video size packets. - if (cc->Inputs().HasTag(kVideoSize) && - cc->Inputs().Tag(kVideoSize).IsEmpty()) { +absl::Status ContentZoomingCalculator::MaybeLoadState( + mediapipe::CalculatorContext* cc, int frame_width, int frame_height) { + const auto* state_cache = + cc->InputSidePackets().HasTag(kStateCache) + ? 
cc->InputSidePackets().Tag(kStateCache).Get() + : nullptr; + if (!state_cache || !state_cache->has_value()) { + return InitializeState(frame_width, frame_height); + } + + const ContentZoomingCalculatorState& state = state_cache->value(); + frame_width_ = state.frame_width; + frame_height_ = state.frame_height; + path_solver_pan_ = + std::make_unique(state.path_solver_pan); + path_solver_tilt_ = + std::make_unique(state.path_solver_tilt); + path_solver_zoom_ = + std::make_unique(state.path_solver_zoom); + first_rect_timestamp_ = state.first_rect_timestamp; + first_rect_ = state.first_rect; + last_only_required_detection_ = state.last_only_required_detection; + last_measured_height_ = state.last_measured_height; + last_measured_x_offset_ = state.last_measured_x_offset; + last_measured_y_offset_ = state.last_measured_y_offset; + MP_RETURN_IF_ERROR(UpdateAspectAndMax()); + + return UpdateForResolutionChange(frame_width, frame_height); +} + +absl::Status ContentZoomingCalculator::SaveState( + mediapipe::CalculatorContext* cc) const { + auto* state_cache = + cc->InputSidePackets().HasTag(kStateCache) + ? cc->InputSidePackets().Tag(kStateCache).Get() + : nullptr; + if (!state_cache) { return absl::OkStatus(); } - int frame_width, frame_height; - MP_RETURN_IF_ERROR(GetVideoResolution(cc, &frame_width, &frame_height)); - // Init on first call. 
- if (!initialized_) { - frame_width_ = frame_width; - frame_height_ = frame_height; - path_solver_pan_ = std::make_unique( - options_.kinematic_options_pan(), 0, frame_width_, - static_cast(frame_width_) / kFieldOfView); - path_solver_tilt_ = std::make_unique( - options_.kinematic_options_tilt(), 0, frame_height_, - static_cast(frame_height_) / kFieldOfView); - MP_RETURN_IF_ERROR(UpdateAspectAndMax()); - int min_zoom_size = frame_height_ * (options_.max_zoom_value_deg() / - static_cast(kFieldOfView)); - path_solver_zoom_ = std::make_unique( - options_.kinematic_options_zoom(), min_zoom_size, - max_frame_value_ * frame_height_, - static_cast(frame_height_) / kFieldOfView); - last_measured_height_ = max_frame_value_ * frame_height_; - last_measured_x_offset_ = target_aspect_ * frame_width_; - last_measured_y_offset_ = frame_width_ / 2; - initialized_ = true; - } + *state_cache = ContentZoomingCalculatorState{ + .frame_height = frame_height_, + .frame_width = frame_width_, + .path_solver_zoom = *path_solver_zoom_, + .path_solver_pan = *path_solver_pan_, + .path_solver_tilt = *path_solver_tilt_, + .first_rect_timestamp = first_rect_timestamp_, + .first_rect = first_rect_, + .last_only_required_detection = last_only_required_detection_, + .last_measured_height = last_measured_height_, + .last_measured_x_offset = last_measured_x_offset_, + .last_measured_y_offset = last_measured_y_offset_, + }; + return absl::OkStatus(); +} +absl::Status ContentZoomingCalculator::InitializeState(int frame_width, + int frame_height) { + frame_width_ = frame_width; + frame_height_ = frame_height; + path_solver_pan_ = std::make_unique( + options_.kinematic_options_pan(), 0, frame_width_, + static_cast(frame_width_) / kFieldOfView); + path_solver_tilt_ = std::make_unique( + options_.kinematic_options_tilt(), 0, frame_height_, + static_cast(frame_height_) / kFieldOfView); + MP_RETURN_IF_ERROR(UpdateAspectAndMax()); + int min_zoom_size = frame_height_ * (options_.max_zoom_value_deg() / + 
static_cast(kFieldOfView)); + path_solver_zoom_ = std::make_unique( + options_.kinematic_options_zoom(), min_zoom_size, + max_frame_value_ * frame_height_, + static_cast(frame_height_) / kFieldOfView); + first_rect_timestamp_ = Timestamp::Unset(); + last_only_required_detection_ = 0; + last_measured_height_ = max_frame_value_ * frame_height_; + last_measured_x_offset_ = target_aspect_ * frame_width_; + last_measured_y_offset_ = frame_width_ / 2; + return absl::OkStatus(); +} + +absl::Status ContentZoomingCalculator::UpdateForResolutionChange( + int frame_width, int frame_height) { // Update state for change in input resolution. if (frame_width_ != frame_width || frame_height_ != frame_height) { double width_scale = frame_width / static_cast(frame_width_); @@ -328,6 +416,74 @@ absl::Status ContentZoomingCalculator::Process( MP_RETURN_IF_ERROR(path_solver_zoom_->UpdatePixelsPerDegree( static_cast(frame_height_) / kFieldOfView)); } + return absl::OkStatus(); +} + +bool ContentZoomingCalculator::IsZoomingToInitialRect( + const Timestamp& timestamp) const { + if (options_.us_to_first_rect() == 0 || + first_rect_timestamp_ == Timestamp::Unset()) { + return false; + } + + const int64 delta_us = (timestamp - first_rect_timestamp_).Value(); + return (0 <= delta_us && delta_us <= options_.us_to_first_rect()); +} + +namespace { +double easeInQuad(double t) { return t * t; } +double easeOutQuad(double t) { return -1 * t * (t - 2); } +double easeInOutQuad(double t) { + if (t < 0.5) { + return easeInQuad(t * 2) * 0.5; + } else { + return easeOutQuad(t * 2 - 1) * 0.5 + 0.5; + } +} +double lerp(double a, double b, double i) { return a * (1 - i) + b * i; } +} // namespace + +absl::StatusOr ContentZoomingCalculator::GetInitialZoomingRect( + int frame_width, int frame_height, const Timestamp& timestamp) const { + RET_CHECK(IsZoomingToInitialRect(timestamp)) + << "Must only be called if zooming to initial rect."; + + const int64 delta_us = (timestamp - first_rect_timestamp_).Value(); 
+ const int64 delay = options_.us_to_first_rect_delay(); + const double interpolation = easeInOutQuad(std::max( + 0.0, (delta_us - delay) / + static_cast(options_.us_to_first_rect() - delay))); + + const double x_center = lerp(0.5, first_rect_.x_center(), interpolation); + const double y_center = lerp(0.5, first_rect_.y_center(), interpolation); + const double width = lerp(1.0, first_rect_.width(), interpolation); + const double height = lerp(1.0, first_rect_.height(), interpolation); + + mediapipe::Rect gpu_rect; + gpu_rect.set_x_center(x_center * frame_width); + gpu_rect.set_width(width * frame_width); + gpu_rect.set_y_center(y_center * frame_height); + gpu_rect.set_height(height * frame_height); + return gpu_rect; +} + +absl::Status ContentZoomingCalculator::Process( + mediapipe::CalculatorContext* cc) { + // For async subgraph support, return on empty video size packets. + if (cc->Inputs().HasTag(kVideoSize) && + cc->Inputs().Tag(kVideoSize).IsEmpty()) { + return absl::OkStatus(); + } + int frame_width, frame_height; + MP_RETURN_IF_ERROR(GetVideoResolution(cc, &frame_width, &frame_height)); + + // Init on first call or re-init always if configured to be stateless. 
+ if (!initialized_) { + MP_RETURN_IF_ERROR(MaybeLoadState(cc, frame_width, frame_height)); + initialized_ = !options_.is_stateless(); + } else { + MP_RETURN_IF_ERROR(UpdateForResolutionChange(frame_width, frame_height)); + } bool only_required_found = false; @@ -348,31 +504,52 @@ absl::Status ContentZoomingCalculator::Process( if (cc->Inputs().HasTag(kDetections)) { if (cc->Inputs().Tag(kDetections).IsEmpty()) { - auto default_rect = absl::make_unique(); - default_rect->set_x_center(frame_width_ / 2); - default_rect->set_y_center(frame_height_ / 2); - default_rect->set_width(frame_width_); - default_rect->set_height(frame_height_); - cc->Outputs().Tag(kCropRect).Add(default_rect.release(), - Timestamp(cc->InputTimestamp())); - return absl::OkStatus(); - } - auto raw_detections = - cc->Inputs().Tag(kDetections).Get>(); - for (const auto& detection : raw_detections) { - only_required_found = true; - MP_RETURN_IF_ERROR(UpdateRanges( - detection, options_.detection_shift_vertical(), - options_.detection_shift_horizontal(), &xmin, &xmax, &ymin, &ymax)); + if (last_only_required_detection_ == 0) { + // If no detections are available and we never had any, + // simply return the full-image rectangle as crop-rect. + if (cc->Outputs().HasTag(kCropRect)) { + auto default_rect = absl::make_unique(); + default_rect->set_x_center(frame_width_ / 2); + default_rect->set_y_center(frame_height_ / 2); + default_rect->set_width(frame_width_); + default_rect->set_height(frame_height_); + cc->Outputs().Tag(kCropRect).Add(default_rect.release(), + Timestamp(cc->InputTimestamp())); + } + // Also provide a first crop rect: in this case a zero-sized one. 
+ if (cc->Outputs().HasTag(kFirstCropRect)) { + cc->Outputs() + .Tag(kFirstCropRect) + .Add(new mediapipe::NormalizedRect(), + Timestamp(cc->InputTimestamp())); + } + return absl::OkStatus(); + } + } else { + auto raw_detections = cc->Inputs() + .Tag(kDetections) + .Get>(); + for (const auto& detection : raw_detections) { + only_required_found = true; + MP_RETURN_IF_ERROR(UpdateRanges( + detection, options_.detection_shift_vertical(), + options_.detection_shift_horizontal(), &xmin, &xmax, &ymin, &ymax)); + } } } - // Convert bounds to tilt/zoom and in pixel coordinates. - int offset_y, height, offset_x; - MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y, - &offset_x, &height)); + bool zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp()); - if (only_required_found) { + int offset_y, height, offset_x; + if (zooming_to_initial_rect) { + // If we are zooming to the first rect, ignore any new incoming detections. + height = last_measured_height_; + offset_x = last_measured_x_offset_; + offset_y = last_measured_y_offset_; + } else if (only_required_found) { + // Convert bounds to tilt/zoom and in pixel coordinates. + MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y, + &offset_x, &height)); // A only required detection was found. last_only_required_detection_ = cc->InputTimestamp().Microseconds(); last_measured_height_ = height; @@ -383,7 +560,9 @@ absl::Status ContentZoomingCalculator::Process( options_.us_before_zoomout()) { // No only_require detections found within salient regions packets // arriving since us_before_zoomout duration. 
- height = max_frame_value_ * frame_height_; + height = max_frame_value_ * frame_height_ + + (options_.kinematic_options_zoom().min_motion_to_reframe() * + (static_cast(frame_height_) / kFieldOfView)); offset_x = (target_aspect_ * height) / 2; offset_y = frame_height_ / 2; } else { @@ -463,17 +642,44 @@ absl::Status ContentZoomingCalculator::Process( .AddPacket(Adopt(features.release()).At(cc->InputTimestamp())); } + if (first_rect_timestamp_ == Timestamp::Unset() && + options_.us_to_first_rect() != 0) { + first_rect_timestamp_ = cc->InputTimestamp(); + first_rect_.set_x_center(path_offset_x / static_cast(frame_width_)); + first_rect_.set_width(path_height * target_aspect_ / + static_cast(frame_width_)); + first_rect_.set_y_center(path_offset_y / static_cast(frame_height_)); + first_rect_.set_height(path_height / static_cast(frame_height_)); + // After setting the first rectangle, check whether we should zoom to it. + zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp()); + } + // Transmit downstream to glcroppingcalculator. 
if (cc->Outputs().HasTag(kCropRect)) { - auto gpu_rect = absl::make_unique(); - gpu_rect->set_x_center(path_offset_x); - gpu_rect->set_width(path_height * target_aspect_); - gpu_rect->set_y_center(path_offset_y); - gpu_rect->set_height(path_height); + std::unique_ptr gpu_rect; + if (zooming_to_initial_rect) { + auto rect = GetInitialZoomingRect(frame_width, frame_height, + cc->InputTimestamp()); + MP_RETURN_IF_ERROR(rect.status()); + gpu_rect = absl::make_unique(*rect); + } else { + gpu_rect = absl::make_unique(); + gpu_rect->set_x_center(path_offset_x); + gpu_rect->set_width(path_height * target_aspect_); + gpu_rect->set_y_center(path_offset_y); + gpu_rect->set_height(path_height); + } cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(), Timestamp(cc->InputTimestamp())); } + if (cc->Outputs().HasTag(kFirstCropRect)) { + cc->Outputs() + .Tag(kFirstCropRect) + .Add(new mediapipe::NormalizedRect(first_rect_), + Timestamp(cc->InputTimestamp())); + } + return absl::OkStatus(); } diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto index c0d4dd78b..4564b88be 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto @@ -19,7 +19,7 @@ package mediapipe.autoflip; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/framework/calculator.proto"; -// NextTag: 14 +// NextTag: 17 message ContentZoomingCalculatorOptions { extend mediapipe.CalculatorOptions { optional ContentZoomingCalculatorOptions ext = 313091992; @@ -55,6 +55,16 @@ message ContentZoomingCalculatorOptions { // Defines the smallest value in degrees the camera is permitted to zoom. optional float max_zoom_value_deg = 13 [default = 35]; + // Whether to keep state between frames or to compute the final crop rect. 
+ optional bool is_stateless = 14 [default = false]; + + // Duration (in MicroSeconds) for moving to the first crop rect. + optional int64 us_to_first_rect = 15 [default = 0]; + // Duration (in MicroSeconds) to delay moving to the first crop rect. + // Used only if us_to_first_rect is set and is interpreted as part of the + // us_to_first_rect time budget. + optional int64 us_to_first_rect_delay = 16 [default = 0]; + // Deprecated parameters optional KinematicOptions kinematic_options = 2 [deprecated = true]; optional int64 min_motion_to_reframe = 4 [deprecated = true]; diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h new file mode 100644 index 000000000..c01a6fbb5 --- /dev/null +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h @@ -0,0 +1,38 @@ +#ifndef MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_ +#define MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_ + +#include + +#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { +namespace autoflip { + +struct ContentZoomingCalculatorState { + int frame_height = -1; + int frame_width = -1; + // Path solver used to smooth top/bottom border crop values. + KinematicPathSolver path_solver_zoom; + KinematicPathSolver path_solver_pan; + KinematicPathSolver path_solver_tilt; + // Stores the time of the first crop rectangle. + Timestamp first_rect_timestamp; + // Stores the first crop rectangle. + mediapipe::NormalizedRect first_rect; + // Stores the time of the last "only_required" input. + int64 last_only_required_detection = 0; + // Rect values of last message with detection(s). 
+ int last_measured_height = 0; + int last_measured_x_offset = 0; + int last_measured_y_offset = 0; +}; + +using ContentZoomingCalculatorStateCacheType = + std::optional; + +} // namespace autoflip +} // namespace mediapipe + +#endif // MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_ diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc index 0db252fec..6859da11f 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc @@ -16,6 +16,7 @@ #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h" +#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h" #include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" @@ -109,6 +110,7 @@ const char kConfigD[] = R"( input_stream: "VIDEO_SIZE:size" input_stream: "DETECTIONS:detections" output_stream: "CROP_RECT:rect" + output_stream: "FIRST_CROP_RECT:first_rect" options: { [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { max_zoom_value_deg: 0 @@ -147,19 +149,24 @@ void AddDetectionFrameSize(const cv::Rect_& position, const int64 time, const int width, const int height, CalculatorRunner* runner) { auto detections = std::make_unique>(); - mediapipe::Detection detection; - detection.mutable_location_data()->set_format( - mediapipe::LocationData::RELATIVE_BOUNDING_BOX); - detection.mutable_location_data() - ->mutable_relative_bounding_box() - ->set_height(position.height); - detection.mutable_location_data()->mutable_relative_bounding_box()->set_width( - position.width); - 
detection.mutable_location_data()->mutable_relative_bounding_box()->set_xmin( - position.x); - detection.mutable_location_data()->mutable_relative_bounding_box()->set_ymin( - position.y); - detections->push_back(detection); + if (position.width > 0 && position.height > 0) { + mediapipe::Detection detection; + detection.mutable_location_data()->set_format( + mediapipe::LocationData::RELATIVE_BOUNDING_BOX); + detection.mutable_location_data() + ->mutable_relative_bounding_box() + ->set_height(position.height); + detection.mutable_location_data() + ->mutable_relative_bounding_box() + ->set_width(position.width); + detection.mutable_location_data() + ->mutable_relative_bounding_box() + ->set_xmin(position.x); + detection.mutable_location_data() + ->mutable_relative_bounding_box() + ->set_ymin(position.y); + detections->push_back(detection); + } runner->MutableInputs() ->Tag("DETECTIONS") .packets.push_back(Adopt(detections.release()).At(Timestamp(time))); @@ -185,7 +192,6 @@ void CheckCropRect(const int x_center, const int y_center, const int width, EXPECT_EQ(rect.width(), width); EXPECT_EQ(rect.height(), height); } - TEST(ContentZoomingCalculatorTest, ZoomTest) { auto runner = ::absl::make_unique( ParseTextProtoOrDie(kConfigA)); @@ -244,6 +250,46 @@ TEST(ContentZoomingCalculatorTest, PanConfig) { runner->Outputs().Tag("CROP_RECT").packets); } +TEST(ContentZoomingCalculatorTest, PanConfigWithCache) { + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache; + auto config = ParseTextProtoOrDie(kConfigD); + config.add_input_side_packet("STATE_CACHE:state_cache"); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0); + options->mutable_kinematic_options_pan()->set_update_rate_seconds(2); + options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(50.0); + 
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(50.0); + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); + AddDetection(cv::Rect_(.4, .5, .1, .1), 0, runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(450, 550, 111, 111, 0, + runner->Outputs().Tag("CROP_RECT").packets); + } + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); + AddDetection(cv::Rect_(.45, .55, .15, .15), 1000000, runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(483, 550, 111, 111, 0, + runner->Outputs().Tag("CROP_RECT").packets); + } + // Now repeat the last frame for a new runner without the cache to see a reset + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(nullptr); + AddDetection(cv::Rect_(.45, .55, .15, .15), 2000000, runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(525, 625, 166, 166, 0, // Without a cache, state was lost. 
+ runner->Outputs().Tag("CROP_RECT").packets); + } +} + TEST(ContentZoomingCalculatorTest, TiltConfig) { auto config = ParseTextProtoOrDie(kConfigD); auto* options = config.mutable_options()->MutableExtension( @@ -280,6 +326,46 @@ TEST(ContentZoomingCalculatorTest, ZoomConfig) { runner->Outputs().Tag("CROP_RECT").packets); } +TEST(ContentZoomingCalculatorTest, ZoomConfigWithCache) { + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache; + auto config = ParseTextProtoOrDie(kConfigD); + config.add_input_side_packet("STATE_CACHE:state_cache"); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(50.0); + options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(50.0); + options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0); + options->mutable_kinematic_options_zoom()->set_update_rate_seconds(2); + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); + AddDetection(cv::Rect_(.4, .5, .1, .1), 0, runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(450, 550, 111, 111, 0, + runner->Outputs().Tag("CROP_RECT").packets); + } + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); + AddDetection(cv::Rect_(.45, .55, .15, .15), 1000000, runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(450, 550, 139, 139, 0, + runner->Outputs().Tag("CROP_RECT").packets); + } + // Now repeat the last frame for a new runner without the cache to see a reset + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(nullptr); + 
AddDetection(cv::Rect_(.45, .55, .15, .15), 2000000, runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(525, 625, 166, 166, 0, // Without a cache, state was lost. + runner->Outputs().Tag("CROP_RECT").packets); + } +} + TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) { auto runner = ::absl::make_unique( ParseTextProtoOrDie(kConfigB)); @@ -509,6 +595,32 @@ TEST(ContentZoomingCalculatorTest, ResolutionChangeStationary) { runner->Outputs().Tag("CROP_RECT").packets); } +TEST(ContentZoomingCalculatorTest, ResolutionChangeStationaryWithCache) { + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache; + auto config = ParseTextProtoOrDie(kConfigD); + config.add_input_side_packet("STATE_CACHE:state_cache"); + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 222, 222, 0, + runner->Outputs().Tag("CROP_RECT").packets); + } + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1, 500, 500, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500 * 0.5, 500 * 0.5, 222 * 0.5, 222 * 0.5, 0, + runner->Outputs().Tag("CROP_RECT").packets); + } +} + TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) { auto config = ParseTextProtoOrDie(kConfigD); auto runner = ::absl::make_unique(config); @@ -527,6 +639,37 @@ TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) { runner->Outputs().Tag("CROP_RECT").packets); } +TEST(ContentZoomingCalculatorTest, ResolutionChangeZoomingWithCache) { + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache; + auto config = 
ParseTextProtoOrDie(kConfigD); + config.add_input_side_packet("STATE_CACHE:state_cache"); + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); + AddDetectionFrameSize(cv::Rect_(.1, .1, .8, .8), 0, 1000, 1000, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 888, 888, 0, + runner->Outputs().Tag("CROP_RECT").packets); + } + // The second runner should just resume based on state from the first runner. + { + auto runner = ::absl::make_unique(config); + runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< + mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1000000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 2000000, 500, 500, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 588, 588, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500 * 0.5, 500 * 0.5, 288 * 0.5, 288 * 0.5, 1, + runner->Outputs().Tag("CROP_RECT").packets); + } +} + TEST(ContentZoomingCalculatorTest, MaxZoomValue) { auto config = ParseTextProtoOrDie(kConfigD); auto* options = config.mutable_options()->MutableExtension( @@ -540,6 +683,108 @@ TEST(ContentZoomingCalculatorTest, MaxZoomValue) { CheckCropRect(500, 500, 916, 916, 0, runner->Outputs().Tag("CROP_RECT").packets); } +TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) { + auto config = ParseTextProtoOrDie(kConfigD); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->set_scale_factor(1.0); + options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.025, .025, .95, .95), 0, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(0, 0, -1, -1), 1000000, 1000, 1000, + 
runner.get()); + AddDetectionFrameSize(cv::Rect_(0, 0, -1, -1), 2000000, 1000, 1000, + runner.get()); + MP_ASSERT_OK(runner->Run()); + // 0.95 * 1000 = 950 + CheckCropRect(500, 500, 950, 950, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 1000, 1000, 2, + runner->Outputs().Tag("CROP_RECT").packets); +} +TEST(ContentZoomingCalculatorTest, StartZoomedOut) { + auto config = ParseTextProtoOrDie(kConfigD); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->set_us_to_first_rect(1000000); + options->set_us_to_first_rect_delay(500000); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 400000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 800000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1000000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1500000, 1000, 1000, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 1000, 1000, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 1000, 1000, 1, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 470, 470, 2, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 222, 222, 3, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 222, 222, 4, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) { + auto config = ParseTextProtoOrDie(kConfigD); + auto runner = ::absl::make_unique(config); + + auto input_size = ::absl::make_unique>(1000, 1000); + runner->MutableInputs() + ->Tag("VIDEO_SIZE") + .packets.push_back(Adopt(input_size.release()).At(Timestamp(0))); + + MP_ASSERT_OK(runner->Run()); + + const std::vector& output_packets = + 
runner->Outputs().Tag("FIRST_CROP_RECT").packets; + ASSERT_EQ(output_packets.size(), 1); + const auto& rect = output_packets[0].Get(); + EXPECT_EQ(rect.x_center(), 0); + EXPECT_EQ(rect.y_center(), 0); + EXPECT_EQ(rect.width(), 0); + EXPECT_EQ(rect.height(), 0); +} + +TEST(ContentZoomingCalculatorTest, ProvidesConstantFirstRect) { + auto config = ParseTextProtoOrDie(kConfigD); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->set_us_to_first_rect(500000); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 500000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1000000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1500000, 1000, 1000, + runner.get()); + MP_ASSERT_OK(runner->Run()); + const std::vector& output_packets = + runner->Outputs().Tag("FIRST_CROP_RECT").packets; + ASSERT_EQ(output_packets.size(), 4); + const auto& first_rect = output_packets[0].Get(); + EXPECT_NEAR(first_rect.x_center(), 0.5, 0.05); + EXPECT_NEAR(first_rect.y_center(), 0.5, 0.05); + EXPECT_NEAR(first_rect.width(), 0.222, 0.05); + EXPECT_NEAR(first_rect.height(), 0.222, 0.05); + for (int i = 1; i < 4; ++i) { + const auto& rect = output_packets[i].Get(); + EXPECT_EQ(first_rect.x_center(), rect.x_center()); + EXPECT_EQ(first_rect.y_center(), rect.y_center()); + EXPECT_EQ(first_rect.width(), rect.width()); + EXPECT_EQ(first_rect.height(), rect.height()); + } +} } // namespace } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc index e2b4f659d..a26c7e44c 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc +++ 
b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/flags/flag.h" #include "absl/strings/string_view.h" #include "mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -19,7 +20,6 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/opencv_core_inc.h" diff --git a/mediapipe/examples/desktop/autoflip/quality/BUILD b/mediapipe/examples/desktop/autoflip/quality/BUILD index 4a5ac3b7a..307953c19 100644 --- a/mediapipe/examples/desktop/autoflip/quality/BUILD +++ b/mediapipe/examples/desktop/autoflip/quality/BUILD @@ -46,7 +46,9 @@ proto_library( mediapipe_cc_proto_library( name = "kinematic_path_solver_cc_proto", srcs = ["kinematic_path_solver.proto"], - visibility = ["//mediapipe/examples:__subpackages__"], + visibility = [ + "//mediapipe/examples:__subpackages__", + ], deps = [":kinematic_path_solver_proto"], ) @@ -96,11 +98,11 @@ cc_library( deps = [ "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", ], ) @@ -249,10 +251,10 @@ cc_test( ":scene_camera_motion_analyzer", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:commandlineflags", 
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", ], ) @@ -280,13 +282,13 @@ cc_test( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc index fcdcf4b09..787baa370 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc @@ -14,11 +14,11 @@ #include "mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h" +#include "absl/flags/flag.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -28,8 +28,9 @@ #include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/port/status_matchers.h" -DEFINE_string(input_image, "", "The path to an input image."); -DEFINE_string(output_folder, "", "The folder to output test result images."); +ABSL_FLAG(std::string, input_image, "", "The path to an input image."); 
+ABSL_FLAG(std::string, output_folder, "", + "The folder to output test result images."); namespace mediapipe { namespace autoflip { diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc index 1e8805b09..703d17534 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc @@ -19,12 +19,12 @@ #include #include +#include "absl/flags/flag.h" #include "absl/strings/str_split.h" #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/quality/focus_point.pb.h" #include "mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.h" #include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc index d99292fa3..d27423c9a 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc @@ -202,7 +202,7 @@ absl::Status DrawFocusPointAndCropWindow( const auto& point = focus_point_frames[i].point(j); const int x = point.norm_point_x() * scene_frame.cols; const int y = point.norm_point_y() * scene_frame.rows; - cv::circle(viz_mat, cv::Point(x, y), 3, kRed, CV_FILLED); + cv::circle(viz_mat, cv::Point(x, y), 3, kRed, cv::FILLED); center_x += x; center_y += y; } diff --git a/mediapipe/examples/desktop/demo_run_graph_main.cc b/mediapipe/examples/desktop/demo_run_graph_main.cc index 343460eac..0d26aa0d3 100644 --- a/mediapipe/examples/desktop/demo_run_graph_main.cc 
+++ b/mediapipe/examples/desktop/demo_run_graph_main.cc @@ -15,10 +15,11 @@ // An example of sending OpenCV webcam frames into a MediaPipe graph. #include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" @@ -30,15 +31,14 @@ constexpr char kInputStream[] = "input_video"; constexpr char kOutputStream[] = "output_video"; constexpr char kWindowName[] = "MediaPipe"; -DEFINE_string( - calculator_graph_config_file, "", - "Name of file containing text format CalculatorGraphConfig proto."); -DEFINE_string(input_video_path, "", - "Full path of video to load. " - "If not provided, attempt to use a webcam."); -DEFINE_string(output_video_path, "", - "Full path of where to save result (.mp4 only). " - "If not provided, show result in a window."); +ABSL_FLAG(std::string, calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); +ABSL_FLAG(std::string, input_video_path, "", + "Full path of video to load. " + "If not provided, attempt to use a webcam."); +ABSL_FLAG(std::string, output_video_path, "", + "Full path of where to save result (.mp4 only). 
" + "If not provided, show result in a window."); absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; @@ -148,7 +148,7 @@ absl::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); + absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); diff --git a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc index 6942971f7..586565db4 100644 --- a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc +++ b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc @@ -16,10 +16,11 @@ // This example requires a linux computer and a GPU with EGL support drivers. #include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" @@ -34,15 +35,14 @@ constexpr char kInputStream[] = "input_video"; constexpr char kOutputStream[] = "output_video"; constexpr char kWindowName[] = "MediaPipe"; -DEFINE_string( - calculator_graph_config_file, "", - "Name of file containing text format CalculatorGraphConfig proto."); -DEFINE_string(input_video_path, "", - "Full path of video to load. " - "If not provided, attempt to use a webcam."); -DEFINE_string(output_video_path, "", - "Full path of where to save result (.mp4 only). 
" - "If not provided, show result in a window."); +ABSL_FLAG(std::string, calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); +ABSL_FLAG(std::string, input_video_path, "", + "Full path of video to load. " + "If not provided, attempt to use a webcam."); +ABSL_FLAG(std::string, output_video_path, "", + "Full path of where to save result (.mp4 only). " + "If not provided, show result in a window."); absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; @@ -191,7 +191,7 @@ absl::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); + absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); diff --git a/mediapipe/examples/desktop/iris_tracking/BUILD b/mediapipe/examples/desktop/iris_tracking/BUILD index 29812d21c..c6596de0b 100644 --- a/mediapipe/examples/desktop/iris_tracking/BUILD +++ b/mediapipe/examples/desktop/iris_tracking/BUILD @@ -23,7 +23,6 @@ cc_binary( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_imgproc", @@ -31,6 +30,8 @@ cc_binary( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", "//mediapipe/graphs/iris_tracking:iris_depth_cpu_deps", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", ], ) diff --git a/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc b/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc index 515ee37b0..928ebb207 100644 --- 
a/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc +++ b/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc @@ -17,11 +17,12 @@ #include #include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/port/canonical_errors.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" @@ -38,12 +39,12 @@ constexpr char kCalculatorGraphConfigFile[] = "mediapipe/graphs/iris_tracking/iris_depth_cpu.pbtxt"; constexpr float kMicrosPerSecond = 1e6; -DEFINE_string(input_image_path, "", - "Full path of image to load. " - "If not provided, nothing will run."); -DEFINE_string(output_image_path, "", - "Full path of where to save image result (.jpg only). " - "If not provided, show result in a window."); +ABSL_FLAG(std::string, input_image_path, "", + "Full path of image to load. " + "If not provided, nothing will run."); +ABSL_FLAG(std::string, output_image_path, "", + "Full path of where to save image result (.jpg only). 
" + "If not provided, show result in a window."); namespace { @@ -148,7 +149,7 @@ absl::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); + absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); diff --git a/mediapipe/examples/desktop/media_sequence/BUILD b/mediapipe/examples/desktop/media_sequence/BUILD index 4e94ebe53..1a88aa109 100644 --- a/mediapipe/examples/desktop/media_sequence/BUILD +++ b/mediapipe/examples/desktop/media_sequence/BUILD @@ -21,11 +21,12 @@ cc_library( srcs = ["run_graph_file_io_main.cc"], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc index a15f599d1..06212b013 100644 --- a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc +++ b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc @@ -17,26 +17,26 @@ // to disk. 
#include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" #include "absl/strings/str_split.h" #include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status.h" -DEFINE_string( - calculator_graph_config_file, "", - "Name of file containing text format CalculatorGraphConfig proto."); -DEFINE_string(input_side_packets, "", - "Comma-separated list of key=value pairs specifying side packets " - "and corresponding file paths for the CalculatorGraph. The side " - "packets are read from the files and fed to the graph as strings " - "even if they represent doubles, floats, etc."); -DEFINE_string(output_side_packets, "", - "Comma-separated list of key=value pairs specifying the output " - "side packets and paths to write to disk for the " - "CalculatorGraph."); +ABSL_FLAG(std::string, calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); +ABSL_FLAG(std::string, input_side_packets, "", + "Comma-separated list of key=value pairs specifying side packets " + "and corresponding file paths for the CalculatorGraph. 
The side " + "packets are read from the files and fed to the graph as strings " + "even if they represent doubles, floats, etc."); +ABSL_FLAG(std::string, output_side_packets, "", + "Comma-separated list of key=value pairs specifying the output " + "side packets and paths to write to disk for the " + "CalculatorGraph."); absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; @@ -85,7 +85,7 @@ absl::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); + absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); diff --git a/mediapipe/examples/desktop/object_detection_3d/BUILD b/mediapipe/examples/desktop/object_detection_3d/BUILD index 86e29a728..8a58e1129 100644 --- a/mediapipe/examples/desktop/object_detection_3d/BUILD +++ b/mediapipe/examples/desktop/object_detection_3d/BUILD @@ -20,7 +20,7 @@ package(default_visibility = ["//mediapipe/examples:__subpackages__"]) # To run 3D object detection for shoes, # bazel-bin/mediapipe/examples/desktop/object_detection_3d/objectron_cpu \ # --calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \ -# --input_side_packets="input_video_path=,box_landmark_model_path=mediapipe/models/object_detection_3d_sneakers.tflite,output_video_path=,allowed_labels=Footwear" +# --input_side_packets="input_video_path=,box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_sneakers.tflite,output_video_path=,allowed_labels=Footwear" # To detect objects from other categories, change box_landmark_model_path and allowed_labels accordingly. 
# Chair: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_chair.tflite,allowed_labels=Chair # Camera: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_camera.tflite,allowed_labels=Camera diff --git a/mediapipe/examples/desktop/simple_run_graph_main.cc b/mediapipe/examples/desktop/simple_run_graph_main.cc index 5d33af66c..96d9839a8 100644 --- a/mediapipe/examples/desktop/simple_run_graph_main.cc +++ b/mediapipe/examples/desktop/simple_run_graph_main.cc @@ -20,11 +20,12 @@ #include #include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -32,31 +33,30 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" -DEFINE_string( - calculator_graph_config_file, "", - "Name of file containing text format CalculatorGraphConfig proto."); +ABSL_FLAG(std::string, calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); -DEFINE_string(input_side_packets, "", - "Comma-separated list of key=value pairs specifying side packets " - "for the CalculatorGraph. All values will be treated as the " - "string type even if they represent doubles, floats, etc."); +ABSL_FLAG(std::string, input_side_packets, "", + "Comma-separated list of key=value pairs specifying side packets " + "for the CalculatorGraph. All values will be treated as the " + "string type even if they represent doubles, floats, etc."); // Local file output flags. 
// Output stream -DEFINE_string(output_stream, "", - "The output stream to output to the local file in csv format."); -DEFINE_string(output_stream_file, "", - "The name of the local file to output all packets sent to " - "the stream specified with --output_stream. "); -DEFINE_bool(strip_timestamps, false, - "If true, only the packet contents (without timestamps) will be " - "written into the local file."); +ABSL_FLAG(std::string, output_stream, "", + "The output stream to output to the local file in csv format."); +ABSL_FLAG(std::string, output_stream_file, "", + "The name of the local file to output all packets sent to " + "the stream specified with --output_stream. "); +ABSL_FLAG(bool, strip_timestamps, false, + "If true, only the packet contents (without timestamps) will be " + "written into the local file."); // Output side packets -DEFINE_string(output_side_packets, "", - "A CSV of output side packets to output to local file."); -DEFINE_string(output_side_packets_file, "", - "The name of the local file to output all side packets specified " - "with --output_side_packets. "); +ABSL_FLAG(std::string, output_side_packets, "", + "A CSV of output side packets to output to local file."); +ABSL_FLAG(std::string, output_side_packets_file, "", + "The name of the local file to output all side packets specified " + "with --output_side_packets. 
"); absl::Status OutputStreamToLocalFile(mediapipe::OutputStreamPoller& poller) { std::ofstream file; @@ -143,7 +143,7 @@ absl::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); + absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); diff --git a/mediapipe/examples/desktop/youtube8m/BUILD b/mediapipe/examples/desktop/youtube8m/BUILD index e6347b243..e0e44c4d9 100644 --- a/mediapipe/examples/desktop/youtube8m/BUILD +++ b/mediapipe/examples/desktop/youtube8m/BUILD @@ -18,10 +18,11 @@ cc_binary( name = "extract_yt8m_features", srcs = ["extract_yt8m_features.cc"], deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:parse_text_proto", diff --git a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc index a303077cc..9030e9255 100644 --- a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc +++ b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc @@ -17,27 +17,27 @@ // to disk. 
#include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" #include "absl/strings/str_split.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status.h" -DEFINE_string( - calculator_graph_config_file, "", - "Name of file containing text format CalculatorGraphConfig proto."); -DEFINE_string(input_side_packets, "", - "Comma-separated list of key=value pairs specifying side packets " - "and corresponding file paths for the CalculatorGraph. The side " - "packets are read from the files and fed to the graph as strings " - "even if they represent doubles, floats, etc."); -DEFINE_string(output_side_packets, "", - "Comma-separated list of key=value pairs specifying the output " - "side packets and paths to write to disk for the " - "CalculatorGraph."); +ABSL_FLAG(std::string, calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); +ABSL_FLAG(std::string, input_side_packets, "", + "Comma-separated list of key=value pairs specifying side packets " + "and corresponding file paths for the CalculatorGraph. 
The side " + "packets are read from the files and fed to the graph as strings " + "even if they represent doubles, floats, etc."); +ABSL_FLAG(std::string, output_side_packets, "", + "Comma-separated list of key=value pairs specifying the output " + "side packets and paths to write to disk for the " + "CalculatorGraph."); absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; @@ -126,7 +126,7 @@ absl::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); + absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index d2ed6cf1e..2124ca580 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -23,7 +23,6 @@ package(default_visibility = ["//visibility:private"]) package_group( name = "mediapipe_internal", packages = [ - "//java/com/google/mediapipe/framework/...", "//mediapipe/...", ], ) @@ -78,21 +77,19 @@ mediapipe_proto_library( mediapipe_proto_library( name = "mediapipe_options_proto", srcs = ["mediapipe_options.proto"], - visibility = ["//mediapipe/framework:__subpackages__"], + visibility = [":mediapipe_internal"], ) mediapipe_proto_library( name = "packet_factory_proto", srcs = ["packet_factory.proto"], - visibility = ["//mediapipe/framework:__subpackages__"], + visibility = [":mediapipe_internal"], ) mediapipe_proto_library( name = "packet_generator_proto", srcs = ["packet_generator.proto"], - visibility = [ - "//mediapipe:__subpackages__", - ], + visibility = [":mediapipe_internal"], ) mediapipe_proto_library( @@ -105,7 +102,7 @@ mediapipe_proto_library( mediapipe_proto_library( name = "status_handler_proto", srcs = ["status_handler.proto"], - visibility = ["//mediapipe/framework:__subpackages__"], + visibility = [":mediapipe_internal"], deps = 
["//mediapipe/framework:mediapipe_options_proto"], ) @@ -274,14 +271,17 @@ cc_library( ], deps = [ ":calculator_base", + ":calculator_node", ":counter_factory", ":delegating_executor", ":mediapipe_profiling", ":executor", ":graph_output_stream", + ":graph_service", + ":graph_service_manager", ":input_stream_manager", ":input_stream_shard", - ":graph_service", + ":output_side_packet_impl", ":output_stream", ":output_stream_manager", ":output_stream_poller", @@ -303,29 +303,27 @@ cc_library( "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework:thread_pool_executor_cc_proto", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "//mediapipe/gpu:graph_support", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - ":calculator_node", - ":output_side_packet_impl", - "//mediapipe/framework/profiler:graph_profiler", - "//mediapipe/framework/tool:fill_packet_set", - "//mediapipe/framework/tool:status_util", - "//mediapipe/framework/tool:tag_map", - "//mediapipe/framework/tool:validate", - "//mediapipe/framework/tool:validate_name", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", + "//mediapipe/framework/profiler:graph_profiler", + "//mediapipe/framework/tool:fill_packet_set", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:tag_map", + "//mediapipe/framework/tool:validate", + "//mediapipe/framework/tool:validate_name", + 
"//mediapipe/gpu:graph_support", "//mediapipe/util:cpu_util", ] + select({ "//conditions:default": [ @@ -336,6 +334,28 @@ cc_library( }), ) +cc_library( + name = "graph_service_manager", + srcs = ["graph_service_manager.cc"], + hdrs = ["graph_service_manager.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":graph_service", + "//mediapipe/framework:packet", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "graph_service_manager_test", + srcs = ["graph_service_manager_test.cc"], + deps = [ + ":graph_service_manager", + "//mediapipe/framework:packet", + "//mediapipe/framework/port:gtest_main", + ], +) + cc_library( name = "calculator_node", srcs = ["calculator_node.cc"], @@ -425,6 +445,7 @@ cc_library( ":counter", ":counter_factory", ":graph_service", + ":graph_service_manager", ":input_stream", ":output_stream", ":packet", @@ -977,6 +998,8 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ + ":graph_service", + ":graph_service_manager", ":port", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto", @@ -989,6 +1012,8 @@ cc_library( "//mediapipe/framework/tool:template_expander", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", ], ) @@ -1008,7 +1033,7 @@ cc_library( "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -1102,6 +1127,7 @@ cc_library( deps = [ ":calculator_base", ":calculator_contract", + ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", @@ -1136,6 +1162,24 @@ cc_library( ], ) +cc_test( + name = "validated_graph_config_test", + srcs = ["validated_graph_config_test.cc"], + deps = [ + ":calculator_framework", + ":graph_service", + ":graph_service_manager", + 
":validated_graph_config", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "graph_validation", hdrs = ["graph_validation.h"], @@ -1591,13 +1635,16 @@ cc_test( srcs = ["subgraph_test.cc"], deps = [ ":calculator_framework", + ":graph_service_manager", ":subgraph", ":test_calculators", + "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework/port:gtest_main", - "//mediapipe/framework/port:status", + "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:sink", "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/mediapipe/framework/calculator_context.cc b/mediapipe/framework/calculator_context.cc index 0d1c05b1e..4452f45e3 100644 --- a/mediapipe/framework/calculator_context.cc +++ b/mediapipe/framework/calculator_context.cc @@ -41,9 +41,9 @@ Counter* CalculatorContext::GetCounter(const std::string& name) { return calculator_state_->GetCounter(name); } -CounterSet* CalculatorContext::GetCounterSet() { +CounterFactory* CalculatorContext::GetCounterFactory() { CHECK(calculator_state_); - return calculator_state_->GetCounterSet(); + return calculator_state_->GetCounterFactory(); } const PacketSet& CalculatorContext::InputSidePackets() const { diff --git a/mediapipe/framework/calculator_context.h b/mediapipe/framework/calculator_context.h index e73dd66ce..34c2c2425 100644 --- a/mediapipe/framework/calculator_context.h +++ b/mediapipe/framework/calculator_context.h @@ -76,7 +76,7 @@ class CalculatorContext { // Returns the counter set, which can be used to create new counters. // No prefix is added to counters created in this way. 
- CounterSet* GetCounterSet(); + CounterFactory* GetCounterFactory(); // Returns the current input timestamp, or Timestamp::Unset if there are // no input packets. @@ -113,26 +113,9 @@ class CalculatorContext { return calculator_state_->GetSharedProfilingContext().get(); } - template - class ServiceBinding { - public: - bool IsAvailable() { - return calculator_state_->IsServiceAvailable(service_); - } - T& GetObject() { return calculator_state_->GetServiceObject(service_); } - - ServiceBinding(CalculatorState* calculator_state, - const GraphService& service) - : calculator_state_(calculator_state), service_(service) {} - - private: - CalculatorState* calculator_state_; - const GraphService& service_; - }; - template ServiceBinding Service(const GraphService& service) { - return ServiceBinding(calculator_state_, service); + return ServiceBinding(calculator_state_->GetServiceObject(service)); } private: diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index 961861abe..ccbde4381 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -36,6 +36,7 @@ #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/counter_factory.h" #include "mediapipe/framework/delegating_executor.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/input_stream_manager.h" #include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/packet_generator.h" @@ -392,7 +393,8 @@ absl::Status CalculatorGraph::Initialize( const CalculatorGraphConfig& input_config, const std::map& side_packets) { auto validated_graph = absl::make_unique(); - MP_RETURN_IF_ERROR(validated_graph->Initialize(input_config)); + MP_RETURN_IF_ERROR(validated_graph->Initialize( + input_config, /*graph_registry=*/nullptr, &service_manager_)); return Initialize(std::move(validated_graph), side_packets); } @@ -402,8 +404,8 @@ absl::Status 
CalculatorGraph::Initialize( const std::map& side_packets, const std::string& graph_type, const Subgraph::SubgraphOptions* options) { auto validated_graph = absl::make_unique(); - MP_RETURN_IF_ERROR(validated_graph->Initialize(input_configs, input_templates, - graph_type, options)); + MP_RETURN_IF_ERROR(validated_graph->Initialize( + input_configs, input_templates, graph_type, options, &service_manager_)); return Initialize(std::move(validated_graph), side_packets); } @@ -509,19 +511,15 @@ absl::Status CalculatorGraph::StartRun( #if !MEDIAPIPE_DISABLE_GPU absl::Status CalculatorGraph::SetGpuResources( std::shared_ptr<::mediapipe::GpuResources> resources) { - RET_CHECK(!ContainsKey(service_packets_, kGpuService.key)) + auto gpu_service = service_manager_.GetServiceObject(kGpuService); + RET_CHECK_EQ(gpu_service, nullptr) << "The GPU resources have already been configured."; - service_packets_[kGpuService.key] = - MakePacket>( - std::move(resources)); - return absl::OkStatus(); + return service_manager_.SetServiceObject(kGpuService, std::move(resources)); } std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources() const { - auto service_iter = service_packets_.find(kGpuService.key); - if (service_iter == service_packets_.end()) return nullptr; - return service_iter->second.Get>(); + return service_manager_.GetServiceObject(kGpuService); } absl::StatusOr> CalculatorGraph::PrepareGpu( @@ -536,8 +534,7 @@ absl::StatusOr> CalculatorGraph::PrepareGpu( } } if (uses_gpu) { - auto service_iter = service_packets_.find(kGpuService.key); - bool has_service = service_iter != service_packets_.end(); + auto gpu_resources = service_manager_.GetServiceObject(kGpuService); auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName); // Workaround for b/116875321: CalculatorRunner provides an empty packet, @@ -545,15 +542,12 @@ absl::StatusOr> CalculatorGraph::PrepareGpu( bool has_legacy_sp = legacy_sp_iter != side_packets.end() && 
!legacy_sp_iter->second.IsEmpty(); - std::shared_ptr<::mediapipe::GpuResources> gpu_resources; - if (has_service) { + if (gpu_resources) { if (has_legacy_sp) { LOG(WARNING) << "::mediapipe::GpuSharedData provided as a side packet while the " << "graph already had one; ignoring side packet"; } - gpu_resources = service_iter->second - .Get>(); update_sp = true; } else { if (has_legacy_sp) { @@ -564,8 +558,8 @@ absl::StatusOr> CalculatorGraph::PrepareGpu( ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create()); update_sp = true; } - service_packets_[kGpuService.key] = - MakePacket>(gpu_resources); + MP_RETURN_IF_ERROR( + service_manager_.SetServiceObject(kGpuService, gpu_resources)); } // Create or replace the legacy side packet if needed. @@ -682,8 +676,10 @@ absl::Status CalculatorGraph::PrepareForRun( std::placeholders::_1, std::placeholders::_2); node.SetQueueSizeCallbacks(queue_size_callback, queue_size_callback); scheduler_.AssignNodeToSchedulerQueue(&node); + // TODO: update calculator node to use GraphServiceManager + // instead of service packets? 
const absl::Status result = node.PrepareForRun( - current_run_side_packets_, service_packets_, + current_run_side_packets_, service_manager_.ServicePackets(), std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_, &node), std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_, @@ -811,6 +807,11 @@ absl::Status CalculatorGraph::AddPacketToInputStreamInternal( CHECK_GE(node_id, validated_graph_->CalculatorInfos().size()); { absl::MutexLock lock(&full_input_streams_mutex_); + if (full_input_streams_.empty()) { + return mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) + << "CalculatorGraph::AddPacketToInputStream() is called before " + "StartRun()"; + } if (graph_input_stream_add_mode_ == GraphInputStreamAddMode::ADD_IF_NOT_FULL) { if (has_error_) { @@ -1170,21 +1171,6 @@ void CalculatorGraph::Pause() { scheduler_.Pause(); } void CalculatorGraph::Resume() { scheduler_.Resume(); } -absl::Status CalculatorGraph::SetServicePacket(const GraphServiceBase& service, - Packet p) { - // TODO: check that the graph has not been started! 
- service_packets_[service.key] = std::move(p); - return absl::OkStatus(); -} - -Packet CalculatorGraph::GetServicePacket(const GraphServiceBase& service) { - auto it = service_packets_.find(service.key); - if (it == service_packets_.end()) { - return {}; - } - return it->second; -} - absl::Status CalculatorGraph::SetExecutorInternal( const std::string& name, std::shared_ptr executor) { if (!executors_.emplace(name, executor).second) { diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index a70da438b..4c9079f0a 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -38,6 +38,7 @@ #include "mediapipe/framework/executor.h" #include "mediapipe/framework/graph_output_stream.h" #include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/output_side_packet_impl.h" #include "mediapipe/framework/output_stream.h" @@ -377,19 +378,20 @@ class CalculatorGraph { template absl::Status SetServiceObject(const GraphService& service, std::shared_ptr object) { - return SetServicePacket(service, - MakePacket>(std::move(object))); + // TODO: check that the graph has not been started! + return service_manager_.SetServiceObject(service, object); } template std::shared_ptr GetServiceObject(const GraphService& service) { - Packet p = GetServicePacket(service); - if (p.IsEmpty()) return nullptr; - return p.Get>(); + return service_manager_.GetServiceObject(service); } // Only the Java API should call this directly. - absl::Status SetServicePacket(const GraphServiceBase& service, Packet p); + absl::Status SetServicePacket(const GraphServiceBase& service, Packet p) { + // TODO: check that the graph has not been started! + return service_manager_.SetServicePacket(service, p); + } private: // GraphRunState is used as a parameter in the function CallStatusHandlers. 
@@ -523,7 +525,6 @@ class CalculatorGraph { // status before taking any action. void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full); - Packet GetServicePacket(const GraphServiceBase& service); #if !MEDIAPIPE_DISABLE_GPU // Owns the legacy GpuSharedData if we need to create one for backwards // compatibility. @@ -598,7 +599,8 @@ class CalculatorGraph { // The processed input side packet map for this run. std::map current_run_side_packets_; - std::map service_packets_; + // Object to manage graph services. + GraphServiceManager service_manager_; // Vector of errors encountered while running graph. Always use RecordError() // to add an error to this vector. diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc index 2c71f8cb3..44f3c9a43 100644 --- a/mediapipe/framework/calculator_graph_bounds_test.cc +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -1361,6 +1361,38 @@ TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Passthrough) { MP_ASSERT_OK(graph.WaitUntilDone()); } +TEST(CalculatorGraphBoundsTest, PostStreamPacketToSetProcessTimestampBound) { + std::string config_str = R"( + input_stream: "input_0" + node { + calculator: "ProcessBoundToPacketCalculator" + input_stream: "input_0" + output_stream: "output_0" + } + )"; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector output_0_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { + output_0_packets.push_back(p); + return absl::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_0", MakePacket(0).At(Timestamp::PostStream()))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(output_0_packets.size(), 1); + EXPECT_EQ(output_0_packets[0].Timestamp(), Timestamp::PostStream()); + 
+ // Shutdown the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + // A Calculator that sends a timestamp bound for every other input. class OccasionalBoundCalculator : public CalculatorBase { public: diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index 9464f5c32..2ea02e041 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -4356,256 +4356,5 @@ TEST(CalculatorGraph, GraphInputStreamWithTag) { ASSERT_EQ(5, packet_dump.size()); } -// Returns the first packet of the input stream. -class FirstPacketFilterCalculator : public CalculatorBase { - public: - FirstPacketFilterCalculator() {} - ~FirstPacketFilterCalculator() override {} - - static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return absl::OkStatus(); - } - - absl::Status Process(CalculatorContext* cc) override { - if (!seen_first_packet_) { - cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - cc->Outputs().Index(0).Close(); - seen_first_packet_ = true; - } - return absl::OkStatus(); - } - - private: - bool seen_first_packet_ = false; -}; -REGISTER_CALCULATOR(FirstPacketFilterCalculator); -constexpr int kDefaultMaxCount = 1000; - -TEST(CalculatorGraph, TestPollPacket) { - CalculatorGraphConfig config; - CalculatorGraphConfig::Node* node = config.add_node(); - node->set_calculator("CountingSourceCalculator"); - node->add_output_stream("output"); - node->add_input_side_packet("MAX_COUNT:max_count"); - - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - auto status_or_poller = graph.AddOutputStreamPoller("output"); - ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.value()); - MP_ASSERT_OK( - graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); - Packet packet; - 
int num_packets = 0; - while (poller.Next(&packet)) { - EXPECT_EQ(num_packets, packet.Get()); - ++num_packets; - } - MP_ASSERT_OK(graph.CloseAllPacketSources()); - MP_ASSERT_OK(graph.WaitUntilDone()); - EXPECT_FALSE(poller.Next(&packet)); - EXPECT_EQ(kDefaultMaxCount, num_packets); -} - -TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) { - CalculatorGraphConfig config; - CalculatorGraphConfig::Node* node = config.add_node(); - node->set_calculator("CountingSourceCalculator"); - node->add_output_stream("output"); - node->add_input_side_packet("MAX_COUNT:max_count"); - - for (int queue_size = 1; queue_size < 10; ++queue_size) { - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - auto status_or_poller = graph.AddOutputStreamPoller("output"); - ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.value()); - poller.SetMaxQueueSize(queue_size); - MP_ASSERT_OK( - graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); - Packet packet; - int num_packets = 0; - while (poller.Next(&packet)) { - EXPECT_EQ(num_packets, packet.Get()); - ++num_packets; - } - MP_ASSERT_OK(graph.CloseAllPacketSources()); - MP_ASSERT_OK(graph.WaitUntilDone()); - EXPECT_FALSE(poller.Next(&packet)); - EXPECT_EQ(kDefaultMaxCount, num_packets); - } -} - -TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) { - CalculatorGraphConfig config; - CalculatorGraphConfig::Node* node1 = config.add_node(); - node1->set_calculator("CountingSourceCalculator"); - node1->add_output_stream("stream1"); - node1->add_input_side_packet("MAX_COUNT:max_count"); - CalculatorGraphConfig::Node* node2 = config.add_node(); - node2->set_calculator("PassThroughCalculator"); - node2->add_input_stream("stream1"); - node2->add_output_stream("stream2"); - - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - auto status_or_poller1 = graph.AddOutputStreamPoller("stream1"); - ASSERT_TRUE(status_or_poller1.ok()); - OutputStreamPoller 
poller1 = std::move(status_or_poller1.value()); - auto status_or_poller2 = graph.AddOutputStreamPoller("stream2"); - ASSERT_TRUE(status_or_poller2.ok()); - OutputStreamPoller poller2 = std::move(status_or_poller2.value()); - MP_ASSERT_OK( - graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); - Packet packet1; - Packet packet2; - int num_packets1 = 0; - int num_packets2 = 0; - int running_pollers = 2; - while (running_pollers > 0) { - if (poller1.Next(&packet1)) { - EXPECT_EQ(num_packets1++, packet1.Get()); - } else { - --running_pollers; - } - if (poller2.Next(&packet2)) { - EXPECT_EQ(num_packets2++, packet2.Get()); - } else { - --running_pollers; - } - } - MP_ASSERT_OK(graph.CloseAllPacketSources()); - MP_ASSERT_OK(graph.WaitUntilDone()); - EXPECT_FALSE(poller1.Next(&packet1)); - EXPECT_FALSE(poller2.Next(&packet2)); - EXPECT_EQ(kDefaultMaxCount, num_packets1); - EXPECT_EQ(kDefaultMaxCount, num_packets2); -} - -// Ensure that when a custom input stream handler is used to handle packets from -// input streams, an error message is outputted with the appropriate link to -// resolve the issue when the calculator doesn't handle inputs in monotonically -// increasing order of timestamps. -TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) { - CalculatorGraph graph; - CalculatorGraphConfig config = - mediapipe::ParseTextProtoOrDie(R"( - input_stream: 'input0' - input_stream: 'input1' - node { - calculator: 'SimpleMuxCalculator' - input_stream: 'input0' - input_stream: 'input1' - input_stream_handler { - input_stream_handler: "ImmediateInputStreamHandler" - } - output_stream: 'output' - } - )"); - std::vector packet_dump; - tool::AddVectorSink("output", &config, &packet_dump); - - MP_ASSERT_OK(graph.Initialize(config)); - MP_ASSERT_OK(graph.StartRun({})); - - // Send packets to input stream "input0" at timestamps 0 and 1 consecutively. 
- Timestamp input0_timestamp = Timestamp(0); - MP_EXPECT_OK(graph.AddPacketToInputStream( - "input0", MakePacket(1).At(input0_timestamp))); - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(1, packet_dump.size()); - EXPECT_EQ(1, packet_dump[0].Get()); - - ++input0_timestamp; - MP_EXPECT_OK(graph.AddPacketToInputStream( - "input0", MakePacket(3).At(input0_timestamp))); - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(2, packet_dump.size()); - EXPECT_EQ(3, packet_dump[1].Get()); - - // Send a packet to input stream "input1" at timestamp 0 after sending two - // packets at timestamps 0 and 1 to input stream "input0". This will result - // in a mismatch in timestamps as the SimpleMuxCalculator doesn't handle - // inputs from all streams in monotonically increasing order of timestamps. - Timestamp input1_timestamp = Timestamp(0); - MP_EXPECT_OK(graph.AddPacketToInputStream( - "input1", MakePacket(2).At(input1_timestamp))); - absl::Status run_status = graph.WaitUntilIdle(); - EXPECT_THAT( - run_status.ToString(), - testing::AllOf( - // The core problem. - testing::HasSubstr("timestamp mismatch on a calculator"), - testing::HasSubstr( - "timestamps that are not strictly monotonically increasing"), - // Link to the possible solution. 
- testing::HasSubstr("ImmediateInputStreamHandler class comment"))); -} - -void DoTestMultipleGraphRuns(absl::string_view input_stream_handler, - bool select_packet) { - std::string graph_proto = absl::StrFormat(R"( - input_stream: 'input' - input_stream: 'select' - node { - calculator: 'PassThroughCalculator' - input_stream: 'input' - input_stream: 'select' - input_stream_handler { - input_stream_handler: "%s" - } - output_stream: 'output' - output_stream: 'select_out' - } - )", - input_stream_handler.data()); - CalculatorGraphConfig config = - mediapipe::ParseTextProtoOrDie(graph_proto); - std::vector packet_dump; - tool::AddVectorSink("output", &config, &packet_dump); - - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - - struct Run { - Timestamp timestamp; - int value; - }; - std::vector runs = {{.timestamp = Timestamp(2000), .value = 2}, - {.timestamp = Timestamp(1000), .value = 1}}; - for (const Run& run : runs) { - MP_ASSERT_OK(graph.StartRun({})); - - if (select_packet) { - MP_EXPECT_OK(graph.AddPacketToInputStream( - "select", MakePacket(0).At(run.timestamp))); - } - MP_EXPECT_OK(graph.AddPacketToInputStream( - "input", MakePacket(run.value).At(run.timestamp))); - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(1, packet_dump.size()); - EXPECT_EQ(run.value, packet_dump[0].Get()); - EXPECT_EQ(run.timestamp, packet_dump[0].Timestamp()); - - MP_ASSERT_OK(graph.CloseAllPacketSources()); - MP_ASSERT_OK(graph.WaitUntilDone()); - - packet_dump.clear(); - } -} - -TEST(CalculatorGraph, MultipleRunsWithDifferentInputStreamHandlers) { - DoTestMultipleGraphRuns("BarrierInputStreamHandler", true); - DoTestMultipleGraphRuns("DefaultInputStreamHandler", true); - DoTestMultipleGraphRuns("EarlyCloseInputStreamHandler", true); - DoTestMultipleGraphRuns("FixedSizeInputStreamHandler", true); - DoTestMultipleGraphRuns("ImmediateInputStreamHandler", false); - DoTestMultipleGraphRuns("MuxInputStreamHandler", true); - 
DoTestMultipleGraphRuns("SyncSetInputStreamHandler", true); - DoTestMultipleGraphRuns("TimestampAlignInputStreamHandler", true); -} - } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/calculator_node.cc b/mediapipe/framework/calculator_node.cc index 8d538486b..de9b29914 100644 --- a/mediapipe/framework/calculator_node.cc +++ b/mediapipe/framework/calculator_node.cc @@ -408,13 +408,13 @@ absl::Status CalculatorNode::PrepareForRun( validated_graph_->CalculatorInfos()[node_id_].Contract(); for (const auto& svc_req : contract.ServiceRequests()) { const auto& req = svc_req.second; - std::string key{req.Service().key}; - auto it = service_packets.find(key); + auto it = service_packets.find(req.Service().key); if (it == service_packets.end()) { RET_CHECK(req.IsOptional()) - << "required service '" << key << "' was not provided"; + << "required service '" << req.Service().key << "' was not provided"; } else { - calculator_state_->SetServicePacket(key, it->second); + MP_RETURN_IF_ERROR( + calculator_state_->SetServicePacket(req.Service(), it->second)); } } diff --git a/mediapipe/framework/calculator_state.cc b/mediapipe/framework/calculator_state.cc index fcd20c2a1..3b0264e97 100644 --- a/mediapipe/framework/calculator_state.cc +++ b/mediapipe/framework/calculator_state.cc @@ -61,13 +61,9 @@ Counter* CalculatorState::GetCounter(const std::string& name) { return counter_factory_->GetCounter(absl::StrCat(NodeName(), "-", name)); } -CounterSet* CalculatorState::GetCounterSet() { +CounterFactory* CalculatorState::GetCounterFactory() { CHECK(counter_factory_); - return counter_factory_->GetCounterSet(); -} - -void CalculatorState::SetServicePacket(const std::string& key, Packet packet) { - service_packets_[key] = std::move(packet); + return counter_factory_; } } // namespace mediapipe diff --git a/mediapipe/framework/calculator_state.h b/mediapipe/framework/calculator_state.h index 42d1f1d4a..8a50f5d8e 100644 --- a/mediapipe/framework/calculator_state.h 
+++ b/mediapipe/framework/calculator_state.h @@ -27,6 +27,7 @@ #include "mediapipe/framework/counter.h" #include "mediapipe/framework/counter_factory.h" #include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/port.h" @@ -81,7 +82,7 @@ class CalculatorState { // Returns a counter set, which can be passed to other classes, to generate // counters. NOTE: This differs from GetCounter, in that the counters // created by this counter set do not have the NodeName prefix. - CounterSet* GetCounterSet(); + CounterFactory* GetCounterFactory(); std::shared_ptr GetSharedProfilingContext() const { return profiling_context_; @@ -99,17 +100,14 @@ class CalculatorState { counter_factory_ = counter_factory; } - void SetServicePacket(const std::string& key, Packet packet); - - bool IsServiceAvailable(const GraphServiceBase& service) { - return ContainsKey(service_packets_, service.key); + absl::Status SetServicePacket(const GraphServiceBase& service, + Packet packet) { + return graph_service_manager_.SetServicePacket(service, packet); } template - T& GetServiceObject(const GraphService& service) { - auto it = service_packets_.find(service.key); - CHECK(it != service_packets_.end()); - return *it->second.template Get>(); + std::shared_ptr GetServiceObject(const GraphService& service) { + return graph_service_manager_.GetServiceObject(service); } private: @@ -129,7 +127,7 @@ class CalculatorState { // The graph tracing and profiling interface. std::shared_ptr profiling_context_; - std::map service_packets_; + GraphServiceManager graph_service_manager_; //////////////////////////////////////// // Variables which ARE cleared by ResetBetweenRuns(). 
diff --git a/mediapipe/framework/deps/ret_check.h b/mediapipe/framework/deps/ret_check.h index fec7a0318..3b1c3674a 100644 --- a/mediapipe/framework/deps/ret_check.h +++ b/mediapipe/framework/deps/ret_check.h @@ -37,7 +37,7 @@ inline StatusBuilder RetCheckImpl(const absl::Status& status, const char* condition, mediapipe::source_location location) { if (ABSL_PREDICT_TRUE(status.ok())) - return mediapipe::StatusBuilder(OkStatus(), location); + return mediapipe::StatusBuilder(absl::OkStatus(), location); return RetCheckFailSlowPath(location, condition, status); } diff --git a/mediapipe/framework/deps/status.cc b/mediapipe/framework/deps/status.cc index b51c9f9db..c6e7b68b5 100644 --- a/mediapipe/framework/deps/status.cc +++ b/mediapipe/framework/deps/status.cc @@ -18,7 +18,7 @@ namespace mediapipe { -std::ostream& operator<<(std::ostream& os, const Status& x) { +std::ostream& operator<<(std::ostream& os, const absl::Status& x) { os << x.ToString(); return os; } diff --git a/mediapipe/framework/deps/status_macros.h b/mediapipe/framework/deps/status_macros.h index c9df245f2..d31c81c2d 100644 --- a/mediapipe/framework/deps/status_macros.h +++ b/mediapipe/framework/deps/status_macros.h @@ -194,10 +194,10 @@ namespace status_macro_internal { // that declares a variable. 
class StatusAdaptorForMacros { public: - StatusAdaptorForMacros(const Status& status, const char* file, int line) + StatusAdaptorForMacros(const absl::Status& status, const char* file, int line) : builder_(status, file, line) {} - StatusAdaptorForMacros(Status&& status, const char* file, int line) + StatusAdaptorForMacros(absl::Status&& status, const char* file, int line) : builder_(std::move(status), file, line) {} StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */, diff --git a/mediapipe/framework/encode_binary_proto.bzl b/mediapipe/framework/encode_binary_proto.bzl index 7ab235beb..3af435f75 100644 --- a/mediapipe/framework/encode_binary_proto.bzl +++ b/mediapipe/framework/encode_binary_proto.bzl @@ -79,13 +79,10 @@ def _get_proto_provider(dep): def _encode_binary_proto_impl(ctx): """Implementation of the encode_binary_proto rule.""" - all_protos = depset() - for dep in ctx.attr.deps: - provider = _get_proto_provider(dep) - all_protos = depset( - direct = [], - transitive = [all_protos, provider.transitive_sources], - ) + all_protos = depset( + direct = [], + transitive = [_get_proto_provider(dep).transitive_sources for dep in ctx.attr.deps], + ) textpb = ctx.file.input binarypb = ctx.outputs.output or ctx.actions.declare_file( @@ -120,7 +117,7 @@ def _encode_binary_proto_impl(ctx): data_runfiles = ctx.runfiles(transitive_files = output_depset), )] -encode_binary_proto = rule( +_encode_binary_proto = rule( implementation = _encode_binary_proto_impl, attrs = { "_proto_compiler": attr.label( @@ -142,6 +139,15 @@ encode_binary_proto = rule( }, ) +def encode_binary_proto(name, input, message_type, deps, **kwargs): + _encode_binary_proto( + name = name, + input = input, + message_type = message_type, + deps = deps, + **kwargs + ) + def _generate_proto_descriptor_set_impl(ctx): """Implementation of the generate_proto_descriptor_set rule.""" all_protos = depset(transitive = [ diff --git a/mediapipe/framework/formats/BUILD 
b/mediapipe/framework/formats/BUILD index e404e7218..3067eb246 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -114,7 +114,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -260,9 +260,11 @@ mediapipe_register_type( include_headers = ["mediapipe/framework/formats/landmark.pb.h"], types = [ "::mediapipe::Landmark", + "::mediapipe::LandmarkList", "::mediapipe::NormalizedLandmark", "::mediapipe::NormalizedLandmarkList", "::std::vector<::mediapipe::Landmark>", + "::std::vector<::mediapipe::LandmarkList>", "::std::vector<::mediapipe::NormalizedLandmark>", "::std::vector<::mediapipe::NormalizedLandmarkList>", ], diff --git a/mediapipe/framework/formats/classification.proto b/mediapipe/framework/formats/classification.proto index 8a777f105..dbe079f7a 100644 --- a/mediapipe/framework/formats/classification.proto +++ b/mediapipe/framework/formats/classification.proto @@ -31,6 +31,8 @@ message Classification { optional float score = 2; // Label or name of the class. optional string label = 3; + // Optional human-readable string for display purposes. + optional string display_name = 4; } // Group of Classification protos. diff --git a/mediapipe/framework/formats/image.h b/mediapipe/framework/formats/image.h index 9f36471c8..58184afca 100644 --- a/mediapipe/framework/formats/image.h +++ b/mediapipe/framework/formats/image.h @@ -78,6 +78,12 @@ class Image { pixel_mutex_ = std::make_shared(); } + // CPU getters. + const ImageFrameSharedPtr& GetImageFrameSharedPtr() const { + if (use_gpu_ == true) ConvertToCpu(); + return image_frame_; + } + // Creates an Image representing the same image content as the input GPU // buffer in platform-specific representations. 
#if !MEDIAPIPE_DISABLE_GPU @@ -95,13 +101,8 @@ class Image { gpu_buffer_ = gpu_buffer; pixel_mutex_ = std::make_shared(); } -#endif // !MEDIAPIPE_DISABLE_GPU - const ImageFrameSharedPtr& GetImageFrameSharedPtr() const { - if (use_gpu_ == true) ConvertToCpu(); - return image_frame_; - } -#if !MEDIAPIPE_DISABLE_GPU + // GPU getters. #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER CVPixelBufferRef GetCVPixelBufferRef() const { if (use_gpu_ == false) ConvertToGpu(); diff --git a/mediapipe/framework/formats/landmark.proto b/mediapipe/framework/formats/landmark.proto index 3cb77e148..eb6a454ed 100644 --- a/mediapipe/framework/formats/landmark.proto +++ b/mediapipe/framework/formats/landmark.proto @@ -47,8 +47,8 @@ message LandmarkList { repeated Landmark landmark = 1; } -// A normalized version of above Landmark proto. All coordiates should be within -// [0, 1]. +// A normalized version of above Landmark proto. All coordinates should be +// within [0, 1]. message NormalizedLandmark { optional float x = 1; optional float y = 2; diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 5beeb5703..3bc3a1394 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -67,11 +67,11 @@ cc_test( deps = [ ":optical_flow_field", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/flags:flag", "@org_tensorflow//tensorflow/core:framework", ], ) diff --git a/mediapipe/framework/formats/motion/optical_flow_field_test.cc b/mediapipe/framework/formats/motion/optical_flow_field_test.cc index 44474120f..5eb92a806 100644 --- a/mediapipe/framework/formats/motion/optical_flow_field_test.cc +++ b/mediapipe/framework/formats/motion/optical_flow_field_test.cc @@ -18,8 +18,8 @@ 
#include #include +#include "absl/flags/flag.h" #include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" diff --git a/mediapipe/framework/graph_service.h b/mediapipe/framework/graph_service.h index 983815410..920603929 100644 --- a/mediapipe/framework/graph_service.h +++ b/mediapipe/framework/graph_service.h @@ -41,6 +41,19 @@ struct GraphService : public GraphServiceBase { constexpr GraphService(const char* key) : GraphServiceBase(key) {} }; +template +class ServiceBinding { + public: + bool IsAvailable() { return service_ != nullptr; } + T& GetObject() { return *service_; } + + ServiceBinding() {} + explicit ServiceBinding(std::shared_ptr service) : service_(service) {} + + private: + std::shared_ptr service_; +}; + } // namespace mediapipe #endif // MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_ diff --git a/mediapipe/framework/graph_service_manager.cc b/mediapipe/framework/graph_service_manager.cc new file mode 100644 index 000000000..dae84afa7 --- /dev/null +++ b/mediapipe/framework/graph_service_manager.cc @@ -0,0 +1,21 @@ +#include "mediapipe/framework/graph_service_manager.h" + +namespace mediapipe { + +absl::Status GraphServiceManager::SetServicePacket( + const GraphServiceBase& service, Packet p) { + // TODO: check service is already set? 
+ service_packets_[service.key] = std::move(p); + return absl::OkStatus(); +} + +Packet GraphServiceManager::GetServicePacket( + const GraphServiceBase& service) const { + auto it = service_packets_.find(service.key); + if (it == service_packets_.end()) { + return {}; + } + return it->second; +} + +} // namespace mediapipe diff --git a/mediapipe/framework/graph_service_manager.h b/mediapipe/framework/graph_service_manager.h new file mode 100644 index 000000000..a8b9cc1fb --- /dev/null +++ b/mediapipe/framework/graph_service_manager.h @@ -0,0 +1,42 @@ +#ifndef MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_MANAGER_H_ +#define MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_MANAGER_H_ + +#include + +#include "absl/status/status.h" +#include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/packet.h" + +namespace mediapipe { + +class GraphServiceManager { + public: + template + absl::Status SetServiceObject(const GraphService& service, + std::shared_ptr object) { + return SetServicePacket(service, + MakePacket>(std::move(object))); + } + + absl::Status SetServicePacket(const GraphServiceBase& service, Packet p); + + template + std::shared_ptr GetServiceObject(const GraphService& service) const { + Packet p = GetServicePacket(service); + if (p.IsEmpty()) return nullptr; + return p.Get>(); + } + + const std::map& ServicePackets() { + return service_packets_; + } + + private: + Packet GetServicePacket(const GraphServiceBase& service) const; + + std::map service_packets_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_MANAGER_H_ diff --git a/mediapipe/framework/graph_service_manager_test.cc b/mediapipe/framework/graph_service_manager_test.cc new file mode 100644 index 000000000..f38148006 --- /dev/null +++ b/mediapipe/framework/graph_service_manager_test.cc @@ -0,0 +1,53 @@ +#include "mediapipe/framework/graph_service_manager.h" + +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include 
"mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +TEST(GraphServiceManager, SetGetServiceObject) { + GraphServiceManager service_manager; + + constexpr GraphService kIntService("mediapipe::IntService"); + EXPECT_EQ(service_manager.GetServiceObject(kIntService), nullptr); + + MP_EXPECT_OK(service_manager.SetServiceObject(kIntService, + std::make_shared(100))); + ASSERT_NE(service_manager.GetServiceObject(kIntService), nullptr); + EXPECT_EQ(*service_manager.GetServiceObject(kIntService), 100); +} + +TEST(GraphServiceManager, SetServicePacket) { + GraphServiceManager service_manager; + + constexpr GraphService kIntService("mediapipe::IntService"); + + MP_EXPECT_OK(service_manager.SetServicePacket( + kIntService, + mediapipe::MakePacket>(std::make_shared(100)))); + ASSERT_NE(service_manager.GetServiceObject(kIntService), nullptr); + EXPECT_EQ(*service_manager.GetServiceObject(kIntService), 100); +} + +TEST(GraphServiceManager, ServicePackets) { + GraphServiceManager service_manager; + + EXPECT_TRUE(service_manager.ServicePackets().empty()); + + constexpr GraphService kIntService("mediapipe::IntService"); + + MP_EXPECT_OK(service_manager.SetServiceObject(kIntService, + std::make_shared(100))); + + EXPECT_EQ(service_manager.ServicePackets().size(), 1); + ASSERT_NE(service_manager.ServicePackets().find(kIntService.key), + service_manager.ServicePackets().end()); + EXPECT_EQ(*service_manager.ServicePackets() + .at(kIntService.key) + .Get>(), + 100); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc index 66d9bfede..4af4c7370 100644 --- a/mediapipe/framework/input_stream_handler.cc +++ b/mediapipe/framework/input_stream_handler.cc @@ -365,9 +365,14 @@ NodeReadiness SyncSet::GetReadiness(Timestamp* min_stream_timestamp) { } } else { // Any unprocessed input_ts can be processed. 
- // Note that (min_bound - 1) is the highest fully settled timestamp. - Timestamp input_timestamp = - std::min(min_packet, min_bound.PreviousAllowedInStream()); + // The settled timestamp is the highest timestamp at which no future packets + // can arrive. Timestamp::PostStream is treated specially because it is + // omitted by Timestamp::PreviousAllowedInStream. + Timestamp settled = + (min_packet == Timestamp::PostStream() && min_bound > min_packet) + ? min_packet + : min_bound.PreviousAllowedInStream(); + Timestamp input_timestamp = std::min(min_packet, settled); if (input_timestamp > std::max(last_processed_ts_, Timestamp::Unstarted())) { *min_stream_timestamp = input_timestamp; diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index cc15572d6..a6827d2ef 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -89,18 +89,6 @@ cc_library( ], ) -cc_library( - name = "commandlineflags", - hdrs = [ - "commandlineflags.h", - ], - visibility = ["//visibility:public"], - deps = [ - "//third_party:glog", - "@com_google_absl//absl/flags:flag", - ], -) - cc_library( name = "core_proto", hdrs = [ diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index cabc980d2..ade6bf9e3 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -235,7 +235,6 @@ cc_test( "//mediapipe/framework/deps:clock", "//mediapipe/framework/deps:message_matchers", "//mediapipe/framework/port:advanced_proto", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", @@ -247,6 +246,7 @@ cc_test( "//mediapipe/framework/tool:simulation_clock", "//mediapipe/framework/tool:simulation_clock_executor", "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/time", ], ) diff --git 
a/mediapipe/framework/profiler/graph_tracer_test.cc b/mediapipe/framework/profiler/graph_tracer_test.cc index 4c50a6c91..6731e4d15 100644 --- a/mediapipe/framework/profiler/graph_tracer_test.cc +++ b/mediapipe/framework/profiler/graph_tracer_test.cc @@ -21,6 +21,7 @@ #include #include +#include "absl/flags/flag.h" #include "absl/time/time.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -28,7 +29,6 @@ #include "mediapipe/framework/deps/clock.h" #include "mediapipe/framework/deps/message_matchers.h" #include "mediapipe/framework/port/advanced_proto_inc.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/framework/profiler/reporter/BUILD b/mediapipe/framework/profiler/reporter/BUILD index 3d92efd8d..511929724 100644 --- a/mediapipe/framework/profiler/reporter/BUILD +++ b/mediapipe/framework/profiler/reporter/BUILD @@ -31,7 +31,6 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_profile_cc_proto", "//mediapipe/framework/port:advanced_proto", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:re2", @@ -39,6 +38,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/mediapipe/framework/subgraph.cc b/mediapipe/framework/subgraph.cc index 5ca99aec3..d0f018e1a 100644 --- a/mediapipe/framework/subgraph.cc +++ b/mediapipe/framework/subgraph.cc @@ -93,17 +93,17 @@ bool GraphRegistry::IsRegistered(const std::string& ns, absl::StatusOr 
GraphRegistry::CreateByName( const std::string& ns, const std::string& type_name, - const Subgraph::SubgraphOptions* options) const { - Subgraph::SubgraphOptions graph_options; - if (options) { - graph_options = *options; - } + SubgraphContext* context) const { absl::StatusOr> maker = local_factories_.IsRegistered(ns, type_name) ? local_factories_.Invoke(ns, type_name) : global_factories_->Invoke(ns, type_name); MP_RETURN_IF_ERROR(maker.status()); - return maker.value()->GetConfig(graph_options); + if (context != nullptr) { + return maker.value()->GetConfig(context); + } + SubgraphContext default_context; + return maker.value()->GetConfig(&default_context); } } // namespace mediapipe diff --git a/mediapipe/framework/subgraph.h b/mediapipe/framework/subgraph.h index 64ebc313c..f40def24c 100644 --- a/mediapipe/framework/subgraph.h +++ b/mediapipe/framework/subgraph.h @@ -19,8 +19,12 @@ #include "absl/base/macros.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/deps/registration.h" +#include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/tool/calculator_graph_template.pb.h" @@ -28,6 +32,51 @@ namespace mediapipe { +class SubgraphContext { + public: + SubgraphContext() : SubgraphContext(nullptr, nullptr) {} + // @node and/or @service_manager can be nullptr. + SubgraphContext(const CalculatorGraphConfig::Node* node, + const GraphServiceManager* service_manager) + : default_node_(node ? absl::nullopt + : absl::optional( + CalculatorGraphConfig::Node())), + original_node_(node ? *node : default_node_.value()), + default_service_manager_( + service_manager + ? absl::nullopt + : absl::optional(GraphServiceManager())), + service_manager_(service_manager ? 
*service_manager + : default_service_manager_.value()), + options_map_(std::move(tool::OptionsMap().Initialize(original_node_))) { + } + + template + const T& Options() { + return options_map_.Get(); + } + + const CalculatorGraphConfig::Node& OriginalNode() { return original_node_; } + + template + ServiceBinding Service(const GraphService& service) const { + return ServiceBinding(service_manager_.GetServiceObject(service)); + } + + private: + // Populated if node is not provided during construction. + const absl::optional default_node_; + + const CalculatorGraphConfig::Node& original_node_; + + // Populated if service manager is not provided during construction. + const absl::optional default_service_manager_; + + const GraphServiceManager& service_manager_; + + tool::OptionsMap options_map_; +}; + // Instances of this class are responsible for providing a subgraph config. // They are only used during graph construction. They do not stay alive once // the graph is running. @@ -36,13 +85,25 @@ class Subgraph { using SubgraphOptions = CalculatorGraphConfig::Node; Subgraph(); virtual ~Subgraph(); + // Returns the config to use for one instantiation of the subgraph. The // nodes and generators in this config will replace the subgraph node in // the parent graph. - // Subclasses may use the options argument to parameterize the config. + // Subclasses may use `SubgraphContext*` param to parameterize the config. // TODO: make this static? + virtual absl::StatusOr GetConfig(SubgraphContext* sc) { + if (sc == nullptr) { + return GetConfig(SubgraphOptions{}); + } + return GetConfig(sc->OriginalNode()); + } + + // Kept for backward compatibility - please override `GetConfig` taking + // `SubgraphContext*` param. virtual absl::StatusOr GetConfig( - const SubgraphOptions& options) = 0; + const SubgraphOptions& options) { + return absl::UnimplementedError("Not implemented."); + } // Returns options of a specific type. 
template @@ -120,7 +181,7 @@ class GraphRegistry { // Returns the specified graph config. absl::StatusOr CreateByName( const std::string& ns, const std::string& type_name, - const Subgraph::SubgraphOptions* options = nullptr) const; + SubgraphContext* context = nullptr) const; static GraphRegistry global_graph_registry; diff --git a/mediapipe/framework/subgraph_test.cc b/mediapipe/framework/subgraph_test.cc index f112d20fb..0428db789 100644 --- a/mediapipe/framework/subgraph_test.cc +++ b/mediapipe/framework/subgraph_test.cc @@ -14,11 +14,16 @@ #include "mediapipe/framework/subgraph.h" +#include + +#include "absl/strings/str_format.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" // Because of portability issues, we include this directly. +#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" // NOLINT(build/deprecated) namespace mediapipe { @@ -75,5 +80,58 @@ TEST_F(SubgraphTest, LinkedSubgraph) { TestGraphEnclosing("DubQuadTestSubgraph"); } +const mediapipe::GraphService kStringTestService{ + "mediapipe::StringTestService"}; +class EmitSideServiceStringTestSubgraph : public Subgraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* sc) override { + auto string_service = sc->Service(kStringTestService); + RET_CHECK(string_service.IsAvailable()) << "Service not available"; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + absl::StrFormat(R"( + output_side_packet: "string" + node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:string" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { string_value: "%s" } + } + } + } + )", + string_service.GetObject())); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(EmitSideServiceStringTestSubgraph); + 
+TEST(SubgraphServicesTest, EmitStringFromTestService) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(R"( + output_side_packet: "str" + node { + calculator: "EmitSideServiceStringTestSubgraph" + output_side_packet: "str" + } + )"); + + Packet side_string; + tool::AddSidePacketSink("str", &config, &side_string); + + CalculatorGraph graph; + // It's important that service object is set before Initialize() + MP_ASSERT_OK(graph.SetServiceObject( + kStringTestService, std::make_shared("Expected STRING"))); + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_EQ(side_string.Get(), "Expected STRING"); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 991814515..fdf35b591 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -19,6 +19,7 @@ load( "data_as_c_string", "mediapipe_binary_graph", ) +load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") licenses(["notice"]) @@ -35,9 +36,10 @@ cc_library( deps = [ "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:advanced_proto", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", ], ) @@ -150,6 +152,28 @@ cc_library( ], ) +mediapipe_cc_test( + name = "options_util_test", + size = "small", + srcs = ["options_util_test.cc"], + data = [":node_chain_subgraph.proto"], + requires_full_emulation = False, + deps = [ + ":options_util", + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", + "//mediapipe/framework:basic_types_registration", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + 
"//mediapipe/framework:validated_graph_config", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/testdata:night_light_calculator_cc_proto", + "//mediapipe/util:header_util", + ], +) + cc_library( name = "packet_util", hdrs = ["packet_util.h"], @@ -227,6 +251,7 @@ cc_library( ":name_util", ":tag_map", "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:graph_service_manager", "//mediapipe/framework:packet_generator", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:port", @@ -560,7 +585,9 @@ cc_test( ":subgraph_expansion", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:graph_service_manager", "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:packet", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", diff --git a/mediapipe/framework/tool/subgraph_expansion.cc b/mediapipe/framework/tool/subgraph_expansion.cc index ab5a7c464..dbba25ac4 100644 --- a/mediapipe/framework/tool/subgraph_expansion.cc +++ b/mediapipe/framework/tool/subgraph_expansion.cc @@ -25,6 +25,7 @@ #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/core_proto_inc.h" @@ -273,7 +274,8 @@ absl::Status ConnectSubgraphStreams( } absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, - const GraphRegistry* graph_registry) { + const GraphRegistry* graph_registry, + const GraphServiceManager* service_manager) { graph_registry = graph_registry ? 
graph_registry : &GraphRegistry::global_graph_registry; RET_CHECK(config); @@ -292,9 +294,10 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, int node_id = it - nodes->begin(); std::string node_name = CanonicalNodeName(*config, node_id); MP_RETURN_IF_ERROR(ValidateSubgraphFields(node)); - ASSIGN_OR_RETURN(auto subgraph, - graph_registry->CreateByName(config->package(), - node.calculator(), &node)); + SubgraphContext subgraph_context(&node, service_manager); + ASSIGN_OR_RETURN(auto subgraph, graph_registry->CreateByName( + config->package(), node.calculator(), + &subgraph_context)); MP_RETURN_IF_ERROR(PrefixNames(node_name, &subgraph)); MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph)); subgraphs.push_back(subgraph); diff --git a/mediapipe/framework/tool/subgraph_expansion.h b/mediapipe/framework/tool/subgraph_expansion.h index 5c4e1c5cf..f8eabc27c 100644 --- a/mediapipe/framework/tool/subgraph_expansion.h +++ b/mediapipe/framework/tool/subgraph_expansion.h @@ -19,6 +19,7 @@ #include "absl/strings/string_view.h" #include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/subgraph.h" @@ -68,8 +69,10 @@ absl::Status ConnectSubgraphStreams( // Replaces subgraph nodes in the given config with the contents of the // corresponding subgraphs. Nested subgraphs are retrieved from the // graph registry and expanded recursively. 
-absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, - const GraphRegistry* graph_registry = nullptr); +absl::Status ExpandSubgraphs( + CalculatorGraphConfig* config, + const GraphRegistry* graph_registry = nullptr, + const GraphServiceManager* service_manager = nullptr); // Creates a graph wrapping the provided node and exposing all of its // connections diff --git a/mediapipe/framework/tool/subgraph_expansion_test.cc b/mediapipe/framework/tool/subgraph_expansion_test.cc index 07e0b512d..b56c08435 100644 --- a/mediapipe/framework/tool/subgraph_expansion_test.cc +++ b/mediapipe/framework/tool/subgraph_expansion_test.cc @@ -19,6 +19,8 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/graph_service_manager.h" +#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/port/gmock.h" @@ -526,5 +528,41 @@ TEST(SubgraphExpansionTest, ExecutorFieldOfNodeInSubgraphPreserved) { EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); } +const mediapipe::GraphService kStringTestService{ + "mediapipe::StringTestService"}; +class GraphServicesClientTestSubgraph : public Subgraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + auto string_service = sc->Service(kStringTestService); + RET_CHECK(string_service.IsAvailable()) << "Service not available"; + CalculatorGraphConfig config; + config.add_node()->set_calculator(string_service.GetObject()); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(GraphServicesClientTestSubgraph); + +TEST(SubgraphExpansionTest, GraphServicesUsage) { + CalculatorGraphConfig supergraph = + mediapipe::ParseTextProtoOrDie(R"( + node { calculator: "GraphServicesClientTestSubgraph" } + )"); + + CalculatorGraphConfig expected_graph = + mediapipe::ParseTextProtoOrDie(R"( + 
node { + name: "graphservicesclienttestsubgraph__ExpectedNode" + calculator: "ExpectedNode" + } + )"); + GraphServiceManager service_manager; + MP_ASSERT_OK(service_manager.SetServiceObject( + kStringTestService, std::make_shared("ExpectedNode"))); + MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph, /*graph_registry=*/nullptr, + &service_manager)); + EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/tool/text_to_binary_graph.cc b/mediapipe/framework/tool/text_to_binary_graph.cc index 4282f748e..b6b38dea7 100644 --- a/mediapipe/framework/tool/text_to_binary_graph.cc +++ b/mediapipe/framework/tool/text_to_binary_graph.cc @@ -19,19 +19,19 @@ #include #include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/canonical_errors.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -DEFINE_string(proto_source, "", - "The template source file containing CalculatorGraphConfig " - "protobuf text with inline template params."); -DEFINE_string( - proto_output, "", - "An output template file in binary CalculatorGraphTemplate form."); +ABSL_FLAG(std::string, proto_source, "", + "The template source file containing CalculatorGraphConfig " + "protobuf text with inline template params."); +ABSL_FLAG(std::string, proto_output, "", + "An output template file in binary CalculatorGraphTemplate form."); #define EXIT_IF_ERROR(status) \ if (!status.ok()) { \ @@ -92,7 +92,7 @@ absl::Status WriteFile(const std::string& proto_output, bool write_text, int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); + absl::ParseCommandLine(argc, argv); // Validate command line options. 
absl::Status status; diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 559a4a53c..b4173a9c2 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -21,6 +21,7 @@ #include "absl/strings/substitute.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/legacy_calculator_support.h" #include "mediapipe/framework/packet_generator.h" #include "mediapipe/framework/packet_generator.pb.h" @@ -142,10 +143,11 @@ absl::Status AddPredefinedExecutorConfigs(CalculatorGraphConfig* graph_config) { absl::Status PerformBasicTransforms( const CalculatorGraphConfig& input_graph_config, const GraphRegistry* graph_registry, + const GraphServiceManager* service_manager, CalculatorGraphConfig* output_graph_config) { *output_graph_config = input_graph_config; - MP_RETURN_IF_ERROR( - tool::ExpandSubgraphs(output_graph_config, graph_registry)); + MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry, + service_manager)); MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config)); @@ -344,7 +346,8 @@ absl::Status NodeTypeInfo::Initialize( absl::Status ValidatedGraphConfig::Initialize( const CalculatorGraphConfig& input_config, - const GraphRegistry* graph_registry) { + const GraphRegistry* graph_registry, + const GraphServiceManager* service_manager) { RET_CHECK(!initialized_) << "ValidatedGraphConfig can be initialized only once."; @@ -353,8 +356,8 @@ absl::Status ValidatedGraphConfig::Initialize( << input_config.DebugString(); #endif - MP_RETURN_IF_ERROR( - PerformBasicTransforms(input_config, graph_registry, &config_)); + MP_RETURN_IF_ERROR(PerformBasicTransforms(input_config, graph_registry, + service_manager, &config_)); // Initialize the basic node information. 
MP_RETURN_IF_ERROR(InitializeGeneratorInfo()); @@ -429,18 +432,22 @@ absl::Status ValidatedGraphConfig::Initialize( absl::Status ValidatedGraphConfig::Initialize( const std::string& graph_type, const Subgraph::SubgraphOptions* options, - const GraphRegistry* graph_registry) { + const GraphRegistry* graph_registry, + const GraphServiceManager* service_manager) { graph_registry = graph_registry ? graph_registry : &GraphRegistry::global_graph_registry; - auto status_or_config = graph_registry->CreateByName("", graph_type, options); + SubgraphContext subgraph_context(options, service_manager); + auto status_or_config = + graph_registry->CreateByName("", graph_type, &subgraph_context); MP_RETURN_IF_ERROR(status_or_config.status()); - return Initialize(status_or_config.value(), graph_registry); + return Initialize(status_or_config.value(), graph_registry, service_manager); } absl::Status ValidatedGraphConfig::Initialize( const std::vector& input_configs, const std::vector& input_templates, - const std::string& graph_type, const Subgraph::SubgraphOptions* options) { + const std::string& graph_type, const Subgraph::SubgraphOptions* arguments, + const GraphServiceManager* service_manager) { GraphRegistry graph_registry; for (auto& config : input_configs) { graph_registry.Register(config.type(), config); @@ -448,7 +455,7 @@ absl::Status ValidatedGraphConfig::Initialize( for (auto& templ : input_templates) { graph_registry.Register(templ.config().type(), templ); } - return Initialize(graph_type, options, &graph_registry); + return Initialize(graph_type, arguments, &graph_registry, service_manager); } absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() { diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h index f509707f5..0bdaac251 100644 --- a/mediapipe/framework/validated_graph_config.h +++ b/mediapipe/framework/validated_graph_config.h @@ -21,6 +21,7 @@ #include "absl/container/flat_hash_set.h" #include 
"mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_contract.h" +#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/port/map_util.h" @@ -195,7 +196,8 @@ class ValidatedGraphConfig { // before any other functions. Subgraphs are specified through the // global graph registry or an optional local graph registry. absl::Status Initialize(const CalculatorGraphConfig& input_config, - const GraphRegistry* graph_registry = nullptr); + const GraphRegistry* graph_registry = nullptr, + const GraphServiceManager* service_manager = nullptr); // Initializes the ValidatedGraphConfig from registered graph and subgraph // configs. Subgraphs are retrieved from the specified graph registry or from @@ -203,7 +205,8 @@ class ValidatedGraphConfig { // specifying its type in |graph_type|. absl::Status Initialize(const std::string& graph_type, const Subgraph::SubgraphOptions* options = nullptr, - const GraphRegistry* graph_registry = nullptr); + const GraphRegistry* graph_registry = nullptr, + const GraphServiceManager* service_manager = nullptr); // Initializes the ValidatedGraphConfig from the specified graph and subgraph // configs. Template graph and subgraph configs can be specified through @@ -215,7 +218,8 @@ class ValidatedGraphConfig { const std::vector& input_configs, const std::vector& input_templates, const std::string& graph_type = "", - const Subgraph::SubgraphOptions* arguments = nullptr); + const Subgraph::SubgraphOptions* arguments = nullptr, + const GraphServiceManager* service_manager = nullptr); // Returns true if the ValidatedGraphConfig has been initialized. 
bool Initialized() const { return initialized_; } diff --git a/mediapipe/framework/validated_graph_config_test.cc b/mediapipe/framework/validated_graph_config_test.cc new file mode 100644 index 000000000..e293461b9 --- /dev/null +++ b/mediapipe/framework/validated_graph_config_test.cc @@ -0,0 +1,165 @@ +#include "mediapipe/framework/validated_graph_config.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +class NoOp : public mediapipe::api2::Node { + public: + static constexpr mediapipe::api2::Input::Optional kInputNotNeeded{"NN"}; + static constexpr mediapipe::api2::Output::Optional kOutputNotNeeded{ + "NN"}; + MEDIAPIPE_NODE_CONTRACT(kInputNotNeeded, kOutputNotNeeded); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); + } +}; + +using CalculatorA = NoOp; +MEDIAPIPE_REGISTER_NODE(CalculatorA); +using CalculatorB = NoOp; +MEDIAPIPE_REGISTER_NODE(CalculatorB); +using CalculatorC = NoOp; +MEDIAPIPE_REGISTER_NODE(CalculatorC); + +CalculatorGraphConfig ExpectedConfig(const std::string& node_name) { + CalculatorGraphConfig config; + config.add_node()->set_calculator(node_name); + config.add_executor(); + return config; +} + +CalculatorGraphConfig ExpectedConfigExpandedFromGraph( + const std::string& graph_name, const std::string& node_name) { + CalculatorGraphConfig config; + auto* node = config.add_node(); + node->set_calculator(node_name); + node->set_name( + absl::StrCat(absl::AsciiStrToLower(graph_name), "__", node_name)); + config.add_executor(); + 
return config; +} + +class AlwaysCalculatorALegacySubgraph : public Subgraph { + absl::StatusOr GetConfig( + const SubgraphOptions& options) override { + return ExpectedConfig("CalculatorA"); + } +}; +REGISTER_MEDIAPIPE_GRAPH(AlwaysCalculatorALegacySubgraph); + +TEST(ValidatedGraphConfigTest, InitializeByTypeLegacySubgraphHardcoded) { + ValidatedGraphConfig config; + MP_EXPECT_OK(config.Initialize("AlwaysCalculatorALegacySubgraph", + /*options=*/nullptr, + /*graph_registry=*/nullptr, + /*service_manager=*/nullptr)); + ASSERT_TRUE(config.Initialized()); + EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfig("CalculatorA"))); +} + +TEST(ValidatedGraphConfigTest, InitializeLegacySubgraphHardcoded) { + CalculatorGraphConfig graph; + graph.add_node()->set_calculator("AlwaysCalculatorALegacySubgraph"); + + ValidatedGraphConfig config; + MP_EXPECT_OK(config.Initialize(graph, + /*graph_registry=*/nullptr, + /*service_manager=*/nullptr)); + ASSERT_TRUE(config.Initialized()); + EXPECT_THAT(config.Config(), + EqualsProto(ExpectedConfigExpandedFromGraph( + "AlwaysCalculatorALegacySubgraph", "CalculatorA"))); +} + +class AlwaysCalculatorASubgraph : public Subgraph { + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + return ExpectedConfig("CalculatorA"); + } +}; +REGISTER_MEDIAPIPE_GRAPH(AlwaysCalculatorASubgraph); + +TEST(ValidatedGraphConfigTest, InitializeByTypeSubgraphHardcoded) { + ValidatedGraphConfig config; + MP_EXPECT_OK(config.Initialize("AlwaysCalculatorASubgraph", + /*options=*/nullptr, + /*graph_registry=*/nullptr, + /*service_manager=*/nullptr)); + ASSERT_TRUE(config.Initialized()); + EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfig("CalculatorA"))); +} + +TEST(ValidatedGraphConfigTest, InitializeSubgraphHardcoded) { + CalculatorGraphConfig graph; + graph.add_node()->set_calculator("AlwaysCalculatorASubgraph"); + + ValidatedGraphConfig config; + MP_EXPECT_OK(config.Initialize(graph, + /*graph_registry=*/nullptr, + 
/*service_manager=*/nullptr)); + ASSERT_TRUE(config.Initialized()); + EXPECT_THAT(config.Config(), + EqualsProto(ExpectedConfigExpandedFromGraph( + "AlwaysCalculatorASubgraph", "CalculatorA"))); +} + +const mediapipe::GraphService kStringTestService{ + "mediapipe::StringTestService"}; + +class TestServiceSubgraph : public Subgraph { + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + return ExpectedConfig(sc->Service(kStringTestService).GetObject()); + } +}; +REGISTER_MEDIAPIPE_GRAPH(TestServiceSubgraph); + +TEST(ValidatedGraphConfigTest, InitializeByTypeSubgraphWithServiceCalculatorB) { + for (const std::string& calculator_name : + {"CalculatorA", "CalculatorB", "CalculatorC"}) { + ValidatedGraphConfig config; + GraphServiceManager service_manager; + MP_ASSERT_OK(service_manager.SetServiceObject( + kStringTestService, std::make_shared(calculator_name))); + MP_EXPECT_OK(config.Initialize("TestServiceSubgraph", + /*options=*/nullptr, + /*graph_registry=*/nullptr, + /*service_manager=*/&service_manager)); + ASSERT_TRUE(config.Initialized()); + EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfig(calculator_name))); + } +} + +TEST(ValidatedGraphConfigTest, InitializeSubgraphWithServiceCalculatorB) { + for (const std::string& calculator_name : + {"CalculatorA", "CalculatorB", "CalculatorC"}) { + CalculatorGraphConfig graph; + graph.add_node()->set_calculator("TestServiceSubgraph"); + + ValidatedGraphConfig config; + GraphServiceManager service_manager; + MP_ASSERT_OK(service_manager.SetServiceObject( + kStringTestService, std::make_shared(calculator_name))); + MP_EXPECT_OK(config.Initialize(graph, + /*graph_registry=*/nullptr, + /*service_manager=*/&service_manager)); + ASSERT_TRUE(config.Initialized()); + EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfigExpandedFromGraph( + "TestServiceSubgraph", calculator_name))); + } +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 92c67035a..05a18bf4c 
100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -782,7 +782,7 @@ bool GlContext::CheckForGlErrors() { return CheckForGlErrors(false); } bool GlContext::CheckForGlErrors(bool force) { #if UNSAFE_EMSCRIPTEN_SKIP_GL_ERROR_HANDLING if (!force) { - LOG_FIRST_N(WARNING, 1) << "MediaPipe OpenGL error checking is disabled"; + LOG_FIRST_N(WARNING, 1) << "OpenGL error checking is disabled"; return false; } #endif diff --git a/mediapipe/gpu/gl_texture_buffer_pool.cc b/mediapipe/gpu/gl_texture_buffer_pool.cc index 76902256d..3d5a8cdaa 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.cc +++ b/mediapipe/gpu/gl_texture_buffer_pool.cc @@ -52,15 +52,15 @@ GlTextureBufferSharedPtr GlTextureBufferPool::GetBuffer() { // Return a shared_ptr with a custom deleter that adds the buffer back // to our available list. std::weak_ptr weak_pool(shared_from_this()); - return std::shared_ptr(buffer.release(), - [weak_pool](GlTextureBuffer* buf) { - auto pool = weak_pool.lock(); - if (pool) { - pool->Return(buf); - } else { - delete buf; - } - }); + return std::shared_ptr( + buffer.release(), [weak_pool](GlTextureBuffer* buf) { + auto pool = weak_pool.lock(); + if (pool) { + pool->Return(absl::WrapUnique(buf)); + } else { + delete buf; + } + }); } std::pair GlTextureBufferPool::GetInUseAndAvailableCounts() { @@ -68,12 +68,12 @@ std::pair GlTextureBufferPool::GetInUseAndAvailableCounts() { return {in_use_count_, available_.size()}; } -void GlTextureBufferPool::Return(GlTextureBuffer* buf) { +void GlTextureBufferPool::Return(std::unique_ptr buf) { std::vector> trimmed; { absl::MutexLock lock(&mutex_); --in_use_count_; - available_.emplace_back(buf); + available_.emplace_back(std::move(buf)); TrimAvailable(&trimmed); } // The trimmed buffers will be released without holding the lock. 
diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index d2e8fc39f..4dcad305e 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -56,7 +56,7 @@ class GlTextureBufferPool int keep_count); // Return a buffer to the pool. - void Return(GlTextureBuffer* buf); + void Return(std::unique_ptr buf); // If the total number of buffers is greater than keep_count, destroys any // surplus buffers that are no longer in use. diff --git a/mediapipe/graphs/instant_motion_tracking/calculators/BUILD b/mediapipe/graphs/instant_motion_tracking/calculators/BUILD index b8242cfdd..93af68c21 100644 --- a/mediapipe/graphs/instant_motion_tracking/calculators/BUILD +++ b/mediapipe/graphs/instant_motion_tracking/calculators/BUILD @@ -65,7 +65,7 @@ cc_library( "//mediapipe/modules/objectron/calculators:box", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) diff --git a/mediapipe/graphs/object_detection_3d/calculators/BUILD b/mediapipe/graphs/object_detection_3d/calculators/BUILD index 5550128af..8f803124a 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/BUILD +++ b/mediapipe/graphs/object_detection_3d/calculators/BUILD @@ -89,7 +89,7 @@ cc_library( "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) diff --git a/mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt b/mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt index bc8b78b34..0a962d7d0 100644 --- a/mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt +++ b/mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt @@ -35,7 +35,7 @@ node { } # Subgraph that renders annotations and overlays them on top of the input -# images (see renderer_gpu.pbtxt). 
+# images (see renderer_cpu.pbtxt). node { calculator: "RendererSubgraph" input_stream: "IMAGE:input_video" diff --git a/mediapipe/java/com/google/mediapipe/framework/Graph.java b/mediapipe/java/com/google/mediapipe/framework/Graph.java index ede4ed3db..b90e51d8a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/Graph.java +++ b/mediapipe/java/com/google/mediapipe/framework/Graph.java @@ -33,8 +33,8 @@ public class Graph { private static final FluentLogger logger = FluentLogger.forEnclosingClass(); private static final int MAX_BUFFER_SIZE = 20; private long nativeGraphHandle; - // Hold the references to callbacks. - private final List packetCallbacks = new ArrayList<>(); + // Hold the references to callbacks (PacketCallback and PacketListCallback). + private final List callbacks = new ArrayList<>(); // Side packets used for running the graph. private Map sidePackets = new HashMap<>(); // Stream headers used for running the graph. @@ -151,10 +151,29 @@ public class Graph { Preconditions.checkNotNull(streamName); Preconditions.checkNotNull(callback); Preconditions.checkState(!graphRunning && !startRunningGraphCalled); - packetCallbacks.add(callback); + callbacks.add(callback); nativeAddPacketCallback(nativeGraphHandle, streamName, callback); } + /** + * Adds a {@link PacketListCallback} to the context for callback during graph running. + * + * @param streamNames The output stream names in the graph for callback. + * @param callback The callback for handling the call when all output streams listed in + * streamNames get {@link Packet}. + * @throws MediaPipeException for any error status. 
+ */ + public synchronized void addMultiStreamCallback( + List streamNames, PacketListCallback callback) { + Preconditions.checkState( + nativeGraphHandle != 0, "Invalid context, tearDown() might have been called already."); + Preconditions.checkNotNull(streamNames); + Preconditions.checkNotNull(callback); + Preconditions.checkState(!graphRunning && !startRunningGraphCalled); + callbacks.add(callback); + nativeAddMultiStreamCallback(nativeGraphHandle, streamNames, callback); + } + /** * Adds a {@link SurfaceOutput} for a stream producing GpuBuffers. * @@ -443,7 +462,7 @@ public class Graph { nativeGraphHandle = 0; } } - packetCallbacks.clear(); + callbacks.clear(); } /** @@ -580,6 +599,9 @@ public class Graph { private native void nativeAddPacketCallback( long context, String streamName, PacketCallback callback); + private native void nativeAddMultiStreamCallback( + long context, List streamName, PacketListCallback callback); + private native long nativeAddSurfaceOutput(long context, String streamName); private native void nativeLoadBinaryGraph(long context, String path); diff --git a/mediapipe/framework/port/commandlineflags.h b/mediapipe/java/com/google/mediapipe/framework/PacketListCallback.java similarity index 60% rename from mediapipe/framework/port/commandlineflags.h rename to mediapipe/java/com/google/mediapipe/framework/PacketListCallback.java index a3d17c71e..f256e41bd 100644 --- a/mediapipe/framework/port/commandlineflags.h +++ b/mediapipe/java/com/google/mediapipe/framework/PacketListCallback.java @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,19 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef MEDIAPIPE_PORT_COMMANDLINEFLAGS_H_ -#define MEDIAPIPE_PORT_COMMANDLINEFLAGS_H_ +package com.google.mediapipe.framework; -#include "gflags/gflags.h" -namespace absl { -template -T GetFlag(const T& f) { - return f; -} -template -void SetFlag(T* f, const U& u) { - *f = u; -} -} // namespace absl +import java.util.List; -#endif // MEDIAPIPE_PORT_COMMANDLINEFLAGS_H_ +/** Interface for MediaPipe callback with packets from multiple output streams. */ +public interface PacketListCallback { + public void process(List packets); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index 3fe9efd1f..cd98e4595 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -93,7 +93,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", "//mediapipe/framework:camera_intrinsics", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:matrix", diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h b/mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h index 998c070ee..1cf0fb2ce 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h @@ -49,6 +49,8 @@ class ClassRegistry { "com/google/mediapipe/framework/MediaPipeException"; static constexpr char const* kPacketCallbackClassName = "com/google/mediapipe/framework/PacketCallback"; + static constexpr char const* kPacketListCallbackClassName = + "com/google/mediapipe/framework/PacketListCallback"; static constexpr char const* kPacketCreatorClassName = "com/google/mediapipe/framework/PacketCreator"; static constexpr char const* kPacketGetterClassName = diff --git 
a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index dde43f567..e244c1186 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -89,11 +89,21 @@ class CallbackHandler { packet, header); } + void PacketListCallback(const std::vector& packets) { + context_->CallbackToJava(mediapipe::java::GetJNIEnv(), java_callback_, + packets); + } + std::function CreateCallback() { return std::bind(&CallbackHandler::PacketCallback, this, std::placeholders::_1); } + std::function&)> CreatePacketListCallback() { + return std::bind(&CallbackHandler::PacketListCallback, this, + std::placeholders::_1); + } + std::function CreateCallbackWithHeader() { return std::bind(&CallbackHandler::PacketWithHeaderCallback, this, std::placeholders::_1, std::placeholders::_2); @@ -191,6 +201,23 @@ absl::Status Graph::AddCallbackHandler(std::string output_stream_name, return absl::OkStatus(); } +absl::Status Graph::AddMultiStreamCallbackHandler( + std::vector output_stream_names, jobject java_callback) { + if (!graph_config()) { + return absl::InternalError("Graph is not loaded!"); + } + auto handler = + absl::make_unique(this, java_callback); + std::pair side_packet_pair; + tool::AddMultiStreamCallback(output_stream_names, + handler->CreatePacketListCallback(), + graph_config(), &side_packet_pair); + side_packets_[side_packet_pair.first] = side_packet_pair.second; + EnsureMinimumExecutorStackSizeForJava(); + callback_handlers_.emplace_back(std::move(handler)); + return absl::OkStatus(); +} + int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { if (!graph_config()) { LOG(ERROR) << "Graph is not loaded!"; @@ -333,6 +360,43 @@ void Graph::CallbackToJava(JNIEnv* env, jobject java_callback_obj, env->DeleteLocalRef(java_header_packet); } +void Graph::CallbackToJava(JNIEnv* env, jobject java_callback_obj, + const std::vector& 
packets) { + jclass callback_cls = env->GetObjectClass(java_callback_obj); + + auto& class_registry = mediapipe::android::ClassRegistry::GetInstance(); + const std::string process_method_name = class_registry.GetMethodName( + mediapipe::android::ClassRegistry::kPacketListCallbackClassName, + "process"); + jmethodID processMethod = env->GetMethodID( + callback_cls, process_method_name.c_str(), "(Ljava/util/List;)V"); + + jclass list_cls = env->FindClass("java/util/ArrayList"); + jobject java_list = + env->NewObject(list_cls, env->GetMethodID(list_cls, "", "()V")); + jmethodID add_method = + env->GetMethodID(list_cls, "add", "(Ljava/lang/Object;)Z"); + std::vector packet_handles; + for (const Packet& packet : packets) { + int64_t packet_handle = WrapPacketIntoContext(packet); + packet_handles.push_back(packet_handle); + jobject java_packet = + CreateJavaPacket(env, global_java_packet_cls_, packet_handle); + env->CallBooleanMethod(java_list, add_method, java_packet); + env->DeleteLocalRef(java_packet); + } + + VLOG(2) << "Calling java callback."; + env->CallVoidMethod(java_callback_obj, processMethod, java_list); + // release the packet after callback. + for (int64_t packet_handle : packet_handles) { + RemovePacket(packet_handle); + } + env->DeleteLocalRef(callback_cls); + env->DeleteLocalRef(java_list); + VLOG(2) << "Returned from java callback."; +} + void Graph::SetPacketJavaClass(JNIEnv* env) { if (global_java_packet_cls_ == nullptr) { auto& class_registry = ClassRegistry::GetInstance(); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h index 488920f8e..2a29c04cb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h @@ -51,6 +51,9 @@ class Graph { // Adds a callback for a given stream name. 
absl::Status AddCallbackHandler(std::string output_stream_name, jobject java_callback); + // Adds a callback for multiple output streams. + absl::Status AddMultiStreamCallbackHandler( + std::vector output_stream_names, jobject java_callback); // Loads a binary graph from a file. absl::Status LoadBinaryGraph(std::string path_to_graph); @@ -158,6 +161,10 @@ class Graph { void CallbackToJava(JNIEnv* env, jobject java_callback_obj, const Packet& packet, const Packet& header_packet); + // Invokes a Java packet list callback. + void CallbackToJava(JNIEnv* env, jobject java_callback_obj, + const std::vector& packets); + ProfilingContext* GetProfilingContext(); private: diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc index a9ed0ccc8..ec8cc3efd 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc @@ -15,6 +15,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h" #include +#include #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -23,6 +24,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" +using mediapipe::android::JavaListToStdStringVector; using mediapipe::android::JStringToStdString; using mediapipe::android::ThrowIfError; @@ -187,6 +189,34 @@ GRAPH_METHOD(nativeAddPacketCallback)(JNIEnv* env, jobject thiz, jlong context, global_callback_ref)); } +JNIEXPORT void JNICALL GRAPH_METHOD(nativeAddMultiStreamCallback)( + JNIEnv* env, jobject thiz, jlong context, jobject stream_names, + jobject callback) { + mediapipe::android::Graph* mediapipe_graph = + reinterpret_cast(context); + std::vector output_stream_names = + JavaListToStdStringVector(env, stream_names); + for (const std::string& s : output_stream_names) { + 
if (s.empty()) { + ThrowIfError(env, + absl::InternalError("streamNames is not correctly parsed or " + "it contains empty std::string.")); + return; + } + } + + // Create a global reference to the callback object, so that it can + // be accessed later. + jobject global_callback_ref = env->NewGlobalRef(callback); + if (!global_callback_ref) { + ThrowIfError(env, + absl::InternalError("Failed to allocate packets callback")); + return; + } + ThrowIfError(env, mediapipe_graph->AddMultiStreamCallbackHandler( + output_stream_names, global_callback_ref)); +} + JNIEXPORT jlong JNICALL GRAPH_METHOD(nativeAddSurfaceOutput)( JNIEnv* env, jobject thiz, jlong context, jstring stream_name) { mediapipe::android::Graph* mediapipe_graph = diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h index fdd63f508..c7c321171 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h @@ -62,6 +62,10 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeAddPacketCallback)( JNIEnv* env, jobject thiz, jlong context, jstring stream_name, jobject callback); +JNIEXPORT void JNICALL GRAPH_METHOD(nativeAddMultiStreamCallback)( + JNIEnv* env, jobject thiz, jlong context, jobject stream_names, + jobject callback); + JNIEXPORT jlong JNICALL GRAPH_METHOD(nativeAddSurfaceOutput)( JNIEnv* env, jobject thiz, jlong context, jstring stream_name); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc index 08b340495..fa33db570 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc @@ -111,6 +111,22 @@ std::string JStringToStdString(JNIEnv* env, jstring jstr) { return str; } +// Converts a `java.util.List` to a `std::vector`. 
+std::vector JavaListToStdStringVector(JNIEnv* env, jobject from) { + jclass cls = env->FindClass("java/util/List"); + int size = env->CallIntMethod(from, env->GetMethodID(cls, "size", "()I")); + std::vector result; + result.reserve(size); + for (int i = 0; i < size; i++) { + jobject element = env->CallObjectMethod( + from, env->GetMethodID(cls, "get", "(I)Ljava/lang/Object;"), i); + result.push_back(JStringToStdString(env, static_cast(element))); + env->DeleteLocalRef(element); + } + env->DeleteLocalRef(cls); + return result; +} + jthrowable CreateMediaPipeException(JNIEnv* env, absl::Status status) { auto& class_registry = mediapipe::android::ClassRegistry::GetInstance(); std::string mpe_class_name = class_registry.GetClassName( diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h index 2524467ff..3e5639ef1 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h @@ -27,6 +27,8 @@ namespace android { std::string JStringToStdString(JNIEnv* env, jstring jstr); +std::vector JavaListToStdStringVector(JNIEnv* env, jobject from); + // Creates a java MediaPipeException object for a absl::Status. 
jthrowable CreateMediaPipeException(JNIEnv* env, absl::Status status); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc index d687fbecd..63e59c011 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc @@ -103,6 +103,14 @@ void RegisterGraphNatives(JNIEnv *env) { AddJNINativeMethod(&graph_methods, graph, "nativeAddPacketCallback", native_add_packet_callback_signature.c_str(), (void *)&GRAPH_METHOD(nativeAddPacketCallback)); + std::string packet_list_callback_name = class_registry.GetClassName( + mediapipe::android::ClassRegistry::kPacketListCallbackClassName); + // JNI descriptor for (long, java.util.List, PacketListCallback) -> void; List lives in java.util, not java.lang. + std::string native_add_multi_stream_callback_signature = + absl::StrFormat("(JLjava/util/List;L%s;)V", packet_list_callback_name); + AddJNINativeMethod(&graph_methods, graph, "nativeAddMultiStreamCallback", + native_add_multi_stream_callback_signature.c_str(), + (void *)&GRAPH_METHOD(nativeAddMultiStreamCallback)); AddJNINativeMethod(&graph_methods, graph, "nativeMovePacketToInputStream", "(JLjava/lang/String;JJ)V", (void *)&GRAPH_METHOD(nativeMovePacketToInputStream)); diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 5aa4fa20e..7f8ee7079 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -139,6 +139,12 @@ cat > $(OUTS) < $(OUTS) <': _PacketDataType.PROTO_LIST, + '::std::vector<::mediapipe::LandmarkList>': + _PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::NormalizedLandmark>': _PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::NormalizedLandmarkList>': diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index 06936741a..40259153e 100644 ---
a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -28,6 +28,8 @@ from mediapipe.framework.formats import landmark_pb2 PRESENCE_THRESHOLD = 0.5 RGB_CHANNELS = 3 RED_COLOR = (0, 0, 255) +GREEN_COLOR = (0, 128, 0) +BLUE_COLOR = (255, 0, 0) VISIBILITY_THRESHOLD = 0.5 @@ -178,9 +180,7 @@ def draw_axis( focal_length: Tuple[float, float] = (1.0, 1.0), principal_point: Tuple[float, float] = (0.0, 0.0), axis_length: float = 0.1, - x_axis_drawing_spec: DrawingSpec = DrawingSpec(color=(0, 0, 255)), - y_axis_drawing_spec: DrawingSpec = DrawingSpec(color=(0, 128, 0)), - z_axis_drawing_spec: DrawingSpec = DrawingSpec(color=(255, 0, 0))): + axis_drawing_spec: DrawingSpec = DrawingSpec()): """Draws the 3D axis on the image. Args: @@ -190,12 +190,8 @@ def draw_axis( focal_length: camera focal length along x and y directions. principal_point: camera principal point in x and y. axis_length: length of the axis in the drawing. - x_axis_drawing_spec: A DrawingSpec object that specifies the x axis - drawing settings such as color, line thickness. - y_axis_drawing_spec: A DrawingSpec object that specifies the y axis - drawing settings such as color, line thickness. - z_axis_drawing_spec: A DrawingSpec object that specifies the z axis - drawing settings such as color, line thickness. + axis_drawing_spec: A DrawingSpec object that specifies the xyz axis + drawing settings such as line thickness. Raises: ValueError: If one of the followings: @@ -213,8 +209,8 @@ def draw_axis( # Project 3D points to NDC space. fx, fy = focal_length px, py = principal_point - x_ndc = -fx * x / z + px - y_ndc = -fy * y / z + py + x_ndc = np.clip(-fx * x / (z + 1e-5) + px, -1., 1.) + y_ndc = np.clip(-fy * y / (z + 1e-5) + py, -1., 1.) # Convert from NDC space to image space. 
x_im = np.int32((1 + x_ndc) * 0.5 * image_cols) y_im = np.int32((1 - y_ndc) * 0.5 * image_rows) @@ -223,9 +219,9 @@ def draw_axis( x_axis = (x_im[1], y_im[1]) y_axis = (x_im[2], y_im[2]) z_axis = (x_im[3], y_im[3]) - image = cv2.arrowedLine(image, origin, x_axis, x_axis_drawing_spec.color, - x_axis_drawing_spec.thickness) - image = cv2.arrowedLine(image, origin, y_axis, y_axis_drawing_spec.color, - y_axis_drawing_spec.thickness) - image = cv2.arrowedLine(image, origin, z_axis, z_axis_drawing_spec.color, - z_axis_drawing_spec.thickness) + cv2.arrowedLine(image, origin, x_axis, RED_COLOR, + axis_drawing_spec.thickness) + cv2.arrowedLine(image, origin, y_axis, GREEN_COLOR, + axis_drawing_spec.thickness) + cv2.arrowedLine(image, origin, z_axis, BLUE_COLOR, + axis_drawing_spec.thickness) diff --git a/mediapipe/python/solutions/drawing_utils_test.py b/mediapipe/python/solutions/drawing_utils_test.py index 6391aca80..e8e0bdfcc 100644 --- a/mediapipe/python/solutions/drawing_utils_test.py +++ b/mediapipe/python/solutions/drawing_utils_test.py @@ -28,6 +28,7 @@ from mediapipe.python.solutions import drawing_utils DEFAULT_BBOX_DRAWING_SPEC = drawing_utils.DrawingSpec() DEFAULT_CONNECTION_DRAWING_SPEC = drawing_utils.DrawingSpec() DEFAULT_CIRCLE_DRAWING_SPEC = drawing_utils.DrawingSpec(color=(0, 0, 255)) +DEFAULT_AXIS_DRAWING_SPEC = drawing_utils.DrawingSpec() class DrawingUtilTest(parameterized.TestCase): @@ -40,6 +41,11 @@ class DrawingUtilTest(parameterized.TestCase): with self.assertRaisesRegex( ValueError, 'Input image must contain three channel rgb data.'): drawing_utils.draw_detection(image, detection_pb2.Detection()) + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + rotation = np.eye(3, dtype=np.float32) + translation = np.array([0., 0., 1.]) + drawing_utils.draw_axis(image, rotation, translation) def test_invalid_connection(self): landmark_list = text_format.Parse( @@ -133,6 +139,43 @@ class 
DrawingUtilTest(parameterized.TestCase): image=image, landmark_list=landmark_list, connections=[(0, 1)]) np.testing.assert_array_equal(image, expected_result) + def test_draw_axis(self): + image = np.zeros((100, 100, 3), np.uint8) + expected_result = np.copy(image) + origin = (50, 50) + x_axis = (75, 50) + y_axis = (50, 22) + z_axis = (50, 77) + cv2.arrowedLine(expected_result, origin, x_axis, drawing_utils.RED_COLOR, + DEFAULT_AXIS_DRAWING_SPEC.thickness) + cv2.arrowedLine(expected_result, origin, y_axis, drawing_utils.GREEN_COLOR, + DEFAULT_AXIS_DRAWING_SPEC.thickness) + cv2.arrowedLine(expected_result, origin, z_axis, drawing_utils.BLUE_COLOR, + DEFAULT_AXIS_DRAWING_SPEC.thickness) + r = np.sqrt(2.) / 2. + rotation = np.array([[1., 0., 0.], [0., r, -r], [0., r, r]]) + translation = np.array([0, 0, -0.2]) + drawing_utils.draw_axis(image, rotation, translation) + np.testing.assert_array_equal(image, expected_result) + + def test_draw_axis_zero_translation(self): + image = np.zeros((100, 100, 3), np.uint8) + expected_result = np.copy(image) + origin = (50, 50) + x_axis = (0, 50) + y_axis = (50, 100) + z_axis = (50, 50) + cv2.arrowedLine(expected_result, origin, x_axis, drawing_utils.RED_COLOR, + DEFAULT_AXIS_DRAWING_SPEC.thickness) + cv2.arrowedLine(expected_result, origin, y_axis, drawing_utils.GREEN_COLOR, + DEFAULT_AXIS_DRAWING_SPEC.thickness) + cv2.arrowedLine(expected_result, origin, z_axis, drawing_utils.BLUE_COLOR, + DEFAULT_AXIS_DRAWING_SPEC.thickness) + rotation = np.eye(3, dtype=np.float32) + translation = np.zeros((3,), dtype=np.float32) + drawing_utils.draw_axis(image, rotation, translation) + np.testing.assert_array_equal(image, expected_result) + def test_min_and_max_coordinate_values(self): landmark_list = text_format.Parse( 'landmark {x: 0.0 y: 1.0}' diff --git a/mediapipe/python/solutions/objectron.py b/mediapipe/python/solutions/objectron.py index 7802c814b..9681b645a 100644 --- a/mediapipe/python/solutions/objectron.py +++ 
b/mediapipe/python/solutions/objectron.py @@ -15,7 +15,10 @@ """MediaPipe Objectron.""" import enum +import os +import shutil from typing import List, Tuple, NamedTuple, Optional +import urllib.request import attr import numpy as np @@ -89,6 +92,23 @@ BOX_CONNECTIONS = frozenset([ (BoxLandmark.FRONT_BOTTOM_RIGHT, BoxLandmark.FRONT_TOP_RIGHT), (BoxLandmark.BACK_TOP_RIGHT, BoxLandmark.FRONT_TOP_RIGHT), ]) +_OSS_URL_PREFIX = 'https://github.com/google/mediapipe/raw/master/' + + +def _download_oss_model(model_path: str): + """Download the objectron oss model from GitHub if it doesn't exist in the package.""" + + mp_root_path = os.sep.join(os.path.abspath(__file__).split(os.sep)[:-4]) + model_abspath = os.path.join(mp_root_path, model_path) + if os.path.exists(model_abspath): + return + model_url = _OSS_URL_PREFIX + model_path + with urllib.request.urlopen(model_url) as response, open(model_abspath, + 'wb') as out_file: + if response.code != 200: + raise ConnectionError('Cannot download ' + model_path + + ' from the MediaPipe Github repo.') + shutil.copyfileobj(response, out_file) @attr.s(auto_attribs=True) @@ -132,9 +152,10 @@ _MODEL_DICT = { } -def GetModelByName(name: str) -> ObjectronModel: +def get_model_by_name(name: str) -> ObjectronModel: if name not in _MODEL_DICT: raise ValueError(f'{name} is not a valid model name for Objectron.') + _download_oss_model(_MODEL_DICT[name].model_path) return _MODEL_DICT[name] @@ -186,6 +207,10 @@ class Objectron(SolutionBase): conversions inside the API. image_size (Optional): size (image_width, image_height) of the input image , ONLY needed when use focal_length and principal_point in pixel space. + + Raises: + ConnectionError: If the objectron open source model can't be downloaded + from the MediaPipe Github repo. """ # Get Camera parameters. fx, fy = focal_length @@ -199,7 +224,7 @@ class Objectron(SolutionBase): py = - (py - half_height) / half_height # Create and init model. 
- model = GetModelByName(model_name) + model = get_model_by_name(model_name) super().__init__( binary_graph_path=BINARYPB_FILE_PATH, side_inputs={ @@ -275,4 +300,3 @@ class Objectron(SolutionBase): new_outputs.append(ObjectronOutputs(landmarks_2d, landmarks_3d, rotation, translation, scale=scale)) return new_outputs - diff --git a/mediapipe/python/solutions/pose.py b/mediapipe/python/solutions/pose.py index c295bbe7a..74d8166af 100644 --- a/mediapipe/python/solutions/pose.py +++ b/mediapipe/python/solutions/pose.py @@ -36,6 +36,7 @@ from mediapipe.calculators.util import logic_calculator_pb2 from mediapipe.calculators.util import non_max_suppression_calculator_pb2 from mediapipe.calculators.util import rect_transformation_calculator_pb2 from mediapipe.calculators.util import thresholding_calculator_pb2 +from mediapipe.calculators.util import visibility_smoothing_calculator_pb2 # pylint: enable=unused-import from mediapipe.python.solution_base import SolutionBase diff --git a/mediapipe/python/solutions/pose_test.py b/mediapipe/python/solutions/pose_test.py index b15408b39..b5d108460 100644 --- a/mediapipe/python/solutions/pose_test.py +++ b/mediapipe/python/solutions/pose_test.py @@ -13,7 +13,9 @@ # limitations under the License. 
"""Tests for mediapipe.python.solutions.pose.""" +import json import os +import tempfile from absl.testing import absltest from absl.testing import parameterized @@ -52,7 +54,7 @@ class PoseTest(parameterized.TestCase): def _landmarks_list_to_array(self, landmark_list, image_shape): rows, cols, _ = image_shape - return np.asarray([(lmk.x * cols, lmk.y * rows) + return np.asarray([(lmk.x * cols, lmk.y * rows, lmk.z * cols) for lmk in landmark_list.landmark]) def _assert_diff_less(self, array1, array2, threshold): @@ -81,9 +83,9 @@ class PoseTest(parameterized.TestCase): for _ in range(num_frames): results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) self._assert_diff_less( - self._landmarks_list_to_array(results.pose_landmarks, image.shape), - EXPECTED_UPPER_BODY_LANDMARKS, - DIFF_THRESHOLD) + self._landmarks_list_to_array(results.pose_landmarks, + image.shape)[:, :2], + EXPECTED_UPPER_BODY_LANDMARKS, DIFF_THRESHOLD) @parameterized.named_parameters(('static_image_mode', True, 3), ('video_mode', False, 3)) @@ -95,9 +97,66 @@ class PoseTest(parameterized.TestCase): for _ in range(num_frames): results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) self._assert_diff_less( - self._landmarks_list_to_array(results.pose_landmarks, image.shape), - EXPECTED_FULL_BODY_LANDMARKS, - DIFF_THRESHOLD) + self._landmarks_list_to_array(results.pose_landmarks, + image.shape)[:, :2], + EXPECTED_FULL_BODY_LANDMARKS, DIFF_THRESHOLD) + + @parameterized.named_parameters( + ('full_body', False, 'pose_squats.full_body.npz'), + ('upper_body', True, 'pose_squats.upper_body.npz')) + def test_on_video(self, upper_body_only, expected_name): + """Tests pose models on a video.""" + # If set to `True` will dump actual predictions to .npz and JSON files. + dump_predictions = False + + # Set threshold for comparing actual and expected predictions in pixels. 
+ diff_threshold = 50 + + video_path = os.path.join(os.path.dirname(__file__), + 'testdata/pose_squats.mp4') + expected_path = os.path.join(os.path.dirname(__file__), + 'testdata/{}'.format(expected_name)) + + # Predict pose landmarks for each frame. + video_cap = cv2.VideoCapture(video_path) + actual_per_frame = [] + with mp_pose.Pose( + static_image_mode=False, upper_body_only=upper_body_only) as pose: + while True: + # Get next frame of the video. + success, input_frame = video_cap.read() + if not success: + break + + # Run pose tracker. + input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB) + result = pose.process(image=input_frame) + pose_landmarks = self._landmarks_list_to_array(result.pose_landmarks, + input_frame.shape) + + actual_per_frame.append(pose_landmarks) + actual = np.asarray(actual_per_frame) + + if dump_predictions: + # Dump .npz + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + np.savez(tmp_file, predictions=np.array(actual)) + print('Predictions saved as .npz to {}'.format(tmp_file.name)) + + # Dump JSON + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + with open(tmp_file.name, 'w') as fl: + dump_data = {'predictions': np.around(actual, 3).tolist()} + fl.write(json.dumps(dump_data, indent=2, separators=(',', ': '))) + print('Predictions saved as JSON to {}'.format(tmp_file.name)) + + # Validate actual vs. expected predictions. + expected = np.load(expected_path)['predictions'] + assert actual.shape == expected.shape, ( + 'Unexpected shape of predictions: {} instead of {}'.format( + actual.shape, expected.shape)) + self._assert_diff_less( + actual[..., :2], expected[..., :2], threshold=diff_threshold) if __name__ == '__main__': diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index d115dd087..1fa71c6e7 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
# load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") +load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") licenses(["notice"]) @@ -45,7 +46,7 @@ cc_library( name = "audio_decoder", srcs = ["audio_decoder.cc"], hdrs = ["audio_decoder.h"], - visibility = ["//mediapipe:__subpackages__"], + visibility = ["//visibility:public"], deps = [ ":audio_decoder_cc_proto", "//mediapipe/framework:packet", @@ -53,7 +54,6 @@ cc_library( "//mediapipe/framework/deps:cleanup", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:time_series_header_cc_proto", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", @@ -61,9 +61,10 @@ cc_library( "//mediapipe/framework/tool:status_util", "//third_party:libffmpeg", "@com_google_absl//absl/base:endian", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -71,11 +72,10 @@ cc_library( name = "cpu_util", srcs = ["cpu_util.cc"], hdrs = ["cpu_util.h"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], + visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:status", @@ -92,9 +92,7 @@ cc_library( name = "header_util", srcs = ["header_util.cc"], hdrs = ["header_util.h"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet_set", @@ -107,9 +105,7 @@ cc_library( name = "image_frame_util", srcs = ["image_frame_util.cc"], hdrs = ["image_frame_util.h"], - visibility = [ - "//visibility:public", - ], + visibility = 
["//visibility:public"], deps = [ "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:image_format_cc_proto", @@ -132,9 +128,7 @@ cc_library( name = "annotation_renderer", srcs = ["annotation_renderer.cc"], hdrs = ["annotation_renderer.h"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ ":render_data_cc_proto", "//mediapipe/framework/port:logging", @@ -145,15 +139,31 @@ cc_library( ], ) +# Prefer to use ":resource_util", Customization of the resource util is being restricted +# while we explore how it should best be implemented. +cc_library( + name = "resource_util_custom", + hdrs = ["resource_util_custom.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/port:status", + ], +) + cc_library( name = "resource_util", - srcs = select({ - "//conditions:default": ["resource_util.cc"], + srcs = [ + "resource_util.cc", + "resource_util_internal.h", + ] + select({ + "//conditions:default": ["resource_util_default.cc"], "//mediapipe:android": ["resource_util_android.cc"], "//mediapipe:ios": ["resource_util_apple.cc"], - "//mediapipe:macos": ["resource_util.cc"], + "//mediapipe:macos": ["resource_util_default.cc"], }), - hdrs = ["resource_util.h"], + hdrs = [ + "resource_util.h", + ], # We use Objective-C++ on iOS. copts = select({ "//conditions:default": [], @@ -162,10 +172,10 @@ cc_library( ], "//mediapipe:macos": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], + visibility = ["//visibility:public"], deps = [ + ":resource_util_custom", + "@com_google_absl//absl/container:flat_hash_map", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:singleton", "//mediapipe/framework/port:status", @@ -197,9 +207,7 @@ cc_library( # Layering check doesn't play nicely with portable proto wrappers. 
"no_layering_check", ], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", @@ -221,9 +229,7 @@ cc_library( name = "time_series_util", srcs = ["time_series_util.cc"], hdrs = ["time_series_util.h"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", @@ -239,9 +245,7 @@ cc_library( name = "time_series_test_util", testonly = 1, hdrs = ["time_series_test_util.h"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ ":time_series_util", "//mediapipe/framework:calculator_framework", @@ -255,7 +259,7 @@ cc_library( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", "@com_google_absl//absl/strings", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -269,6 +273,6 @@ cc_test( "//mediapipe/framework/deps:message_matchers", "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:gtest_main", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) diff --git a/mediapipe/util/audio_decoder.cc b/mediapipe/util/audio_decoder.cc index 9ebc79aad..7f964a63d 100644 --- a/mediapipe/util/audio_decoder.cc +++ b/mediapipe/util/audio_decoder.cc @@ -41,14 +41,14 @@ extern "C" { #include "libavutil/samplefmt.h" } -DEFINE_int64(media_decoder_allowed_audio_gap_merge, 5, - "The time gap forwards or backwards in the audio to ignore. " - "Timestamps in media files are restricted by the container format " - "and stream codec and are invariably not accurate to exact sample " - "numbers. 
If the discrepency between time based on counting " - "samples and based on the container timestamps grows beyond this " - "value it will be reset to the value in the audio stream and " - "counting based on samples will resume."); +ABSL_FLAG(int64_t, media_decoder_allowed_audio_gap_merge, 5, + "The time gap forwards or backwards in the audio to ignore. " + "Timestamps in media files are restricted by the container format " + "and stream codec and are invariably not accurate to exact sample " + "numbers. If the discrepency between time based on counting " + "samples and based on the container timestamps grows beyond this " + "value it will be reset to the value in the audio stream and " + "counting based on samples will resume."); namespace mediapipe { diff --git a/mediapipe/util/audio_decoder.h b/mediapipe/util/audio_decoder.h index ae2ab3b33..f95bf316b 100644 --- a/mediapipe/util/audio_decoder.h +++ b/mediapipe/util/audio_decoder.h @@ -20,10 +20,10 @@ #include #include +#include "absl/flags/flag.h" #include "absl/time/time.h" #include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/packet.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/timestamp.h" diff --git a/mediapipe/util/cpu_util.cc b/mediapipe/util/cpu_util.cc index 33e0dacde..7c6f982eb 100644 --- a/mediapipe/util/cpu_util.cc +++ b/mediapipe/util/cpu_util.cc @@ -26,6 +26,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/flags/flag.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" @@ -33,14 +34,24 @@ #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/statusor.h" +ABSL_FLAG(std::string, system_cpu_max_freq_file, + "/sys/devices/system/cpu/cpu$0/cpufreq/cpuinfo_max_freq", + "The file pattern for CPU max frequencies, where $0 
will be replaced " + "with the CPU id."); + namespace mediapipe { namespace { constexpr uint32 kBufferLength = 64; absl::StatusOr GetFilePath(int cpu) { - return absl::Substitute( - "/sys/devices/system/cpu/cpu$0/cpufreq/cpuinfo_max_freq", cpu); + if (absl::GetFlag(FLAGS_system_cpu_max_freq_file).find("$0") == + std::string::npos) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid frequency file: ", + absl::GetFlag(FLAGS_system_cpu_max_freq_file))); + } + return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu); } absl::StatusOr GetCpuMaxFrequency(int cpu) { diff --git a/mediapipe/util/filtering/BUILD b/mediapipe/util/filtering/BUILD index e02842186..e167a3333 100644 --- a/mediapipe/util/filtering/BUILD +++ b/mediapipe/util/filtering/BUILD @@ -38,6 +38,18 @@ cc_test( ], ) +cc_library( + name = "one_euro_filter", + srcs = ["one_euro_filter.cc"], + hdrs = ["one_euro_filter.h"], + deps = [ + ":low_pass_filter", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/time", + ], +) + cc_library( name = "relative_velocity_filter", srcs = ["relative_velocity_filter.cc"], diff --git a/mediapipe/util/filtering/one_euro_filter.cc b/mediapipe/util/filtering/one_euro_filter.cc new file mode 100644 index 000000000..c2451c6dc --- /dev/null +++ b/mediapipe/util/filtering/one_euro_filter.cc @@ -0,0 +1,84 @@ +#include "mediapipe/util/filtering/one_euro_filter.h" + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/util/filtering/low_pass_filter.h" + +namespace mediapipe { + +static const double kEpsilon = 0.000001; + +OneEuroFilter::OneEuroFilter(double frequency, double min_cutoff, double beta, + double derivate_cutoff) { + SetFrequency(frequency); + SetMinCutoff(min_cutoff); + SetBeta(beta); + SetDerivateCutoff(derivate_cutoff); + x_ = absl::make_unique(GetAlpha(min_cutoff)); + dx_ = absl::make_unique(GetAlpha(derivate_cutoff)); + last_time_ = 
0; +} + +double OneEuroFilter::Apply(absl::Duration timestamp, double value) { + int64_t new_timestamp = absl::ToInt64Nanoseconds(timestamp); + if (last_time_ >= new_timestamp) { + // Results are unpredictable in this case, so nothing to do but + // return same value + LOG(WARNING) << "New timestamp is equal or less than the last one."; + return value; + } + + // update the sampling frequency based on timestamps + if (last_time_ != 0 && new_timestamp != 0) { + static constexpr double kNanoSecondsToSecond = 1e-9; + frequency_ = 1.0 / ((new_timestamp - last_time_) * kNanoSecondsToSecond); + } + last_time_ = new_timestamp; + + // estimate the current variation per second + double dvalue = x_->HasLastRawValue() + ? (value - x_->LastRawValue()) * frequency_ + : 0.0; // FIXME: 0.0 or value? + double edvalue = dx_->ApplyWithAlpha(dvalue, GetAlpha(derivate_cutoff_)); + // use it to update the cutoff frequency + double cutoff = min_cutoff_ + beta_ * std::fabs(edvalue); + + // filter the given value + return x_->ApplyWithAlpha(value, GetAlpha(cutoff)); +} + +double OneEuroFilter::GetAlpha(double cutoff) { + double te = 1.0 / frequency_; + double tau = 1.0 / (2 * M_PI * cutoff); + return 1.0 / (1.0 + tau / te); +} + +void OneEuroFilter::SetFrequency(double frequency) { + if (frequency <= kEpsilon) { + LOG(ERROR) << "frequency should be > 0"; + return; + } + frequency_ = frequency; +} + +void OneEuroFilter::SetMinCutoff(double min_cutoff) { + if (min_cutoff <= kEpsilon) { + LOG(ERROR) << "min_cutoff should be > 0"; + return; + } + min_cutoff_ = min_cutoff; +} + +void OneEuroFilter::SetBeta(double beta) { beta_ = beta; } + +void OneEuroFilter::SetDerivateCutoff(double derivate_cutoff) { + if (derivate_cutoff <= kEpsilon) { + LOG(ERROR) << "derivate_cutoff should be > 0"; + return; + } + derivate_cutoff_ = derivate_cutoff; +} + +} // namespace mediapipe diff --git a/mediapipe/util/filtering/one_euro_filter.h b/mediapipe/util/filtering/one_euro_filter.h new file mode 100644 index 
000000000..0d4dd2916 --- /dev/null +++ b/mediapipe/util/filtering/one_euro_filter.h @@ -0,0 +1,40 @@ +#ifndef MEDIAPIPE_UTIL_FILTERING_ONE_EURO_FILTER_H_ +#define MEDIAPIPE_UTIL_FILTERING_ONE_EURO_FILTER_H_ + +#include + +#include "absl/time/time.h" +#include "mediapipe/util/filtering/low_pass_filter.h" + +namespace mediapipe { + +class OneEuroFilter { + public: + OneEuroFilter(double frequency, double min_cutoff, double beta, + double derivate_cutoff); + + double Apply(absl::Duration timestamp, double value); + + private: + double GetAlpha(double cutoff); + + void SetFrequency(double frequency); + + void SetMinCutoff(double min_cutoff); + + void SetBeta(double beta); + + void SetDerivateCutoff(double derivate_cutoff); + + double frequency_; + double min_cutoff_; + double beta_; + double derivate_cutoff_; + std::unique_ptr x_; + std::unique_ptr dx_; + int64_t last_time_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_FILTERING_ONE_EURO_FILTER_H_ diff --git a/mediapipe/util/resource_util.cc b/mediapipe/util/resource_util.cc index 87659b7f0..042d1e810 100644 --- a/mediapipe/util/resource_util.cc +++ b/mediapipe/util/resource_util.cc @@ -14,27 +14,31 @@ #include "mediapipe/util/resource_util.h" -#include "absl/flags/flag.h" +#include + #include "absl/strings/str_split.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/ret_check.h" - -ABSL_FLAG( - std::string, resource_root_dir, "", - "The absolute path to the resource directory." 
- "If specified, resource_root_dir will be prepended to the original path."); +#include "mediapipe/util/resource_util_custom.h" +#include "mediapipe/util/resource_util_internal.h" namespace mediapipe { -absl::StatusOr PathToResourceAsFile(const std::string& path) { - return mediapipe::file::JoinPath(absl::GetFlag(FLAGS_resource_root_dir), - path); -} +namespace { +ResourceProviderFn resource_provider_ = nullptr; +} // namespace absl::Status GetResourceContents(const std::string& path, std::string* output, bool read_as_binary) { - return mediapipe::file::GetContents(path, output, read_as_binary); + if (resource_provider_ == nullptr || !resource_provider_(path, output).ok()) { + return internal::DefaultGetResourceContents(path, output, read_as_binary); + } + return absl::OkStatus(); +} + +void SetCustomGlobalResourceProvider(ResourceProviderFn fn) { + resource_provider_ = std::move(fn); } } // namespace mediapipe diff --git a/mediapipe/util/resource_util.h b/mediapipe/util/resource_util.h index c870900e2..a0c26c505 100644 --- a/mediapipe/util/resource_util.h +++ b/mediapipe/util/resource_util.h @@ -17,7 +17,6 @@ #include -#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" namespace mediapipe { diff --git a/mediapipe/util/resource_util_android.cc b/mediapipe/util/resource_util_android.cc index 323c31c02..b7589d8fe 100644 --- a/mediapipe/util/resource_util_android.cc +++ b/mediapipe/util/resource_util_android.cc @@ -18,9 +18,9 @@ #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/singleton.h" +#include "mediapipe/framework/port/statusor.h" #include "mediapipe/util/android/asset_manager_util.h" #include "mediapipe/util/android/file/base/helpers.h" -#include "mediapipe/util/resource_util.h" namespace mediapipe { @@ -31,6 +31,37 @@ absl::StatusOr PathToResourceAsFileInternal( } } // namespace +namespace internal { +absl::Status 
DefaultGetResourceContents(const std::string& path, + std::string* output, + bool read_as_binary) { + if (!read_as_binary) { + LOG(WARNING) + << "Setting \"read_as_binary\" to false is a no-op on Android."; + } + if (absl::StartsWith(path, "/")) { + return file::GetContents(path, output, file::Defaults()); + } + + if (absl::StartsWith(path, "content://")) { + MP_RETURN_IF_ERROR( + Singleton::get()->ReadContentUri(path, output)); + return absl::OkStatus(); + } + + // Try the test environment. + absl::string_view workspace = "mediapipe"; + auto test_path = file::JoinPath(std::getenv("TEST_SRCDIR"), workspace, path); + if (file::Exists(test_path).ok()) { + return file::GetContents(path, output, file::Defaults()); + } + + RET_CHECK(Singleton::get()->ReadFile(path, output)) + << "could not read asset: " << path; + return absl::OkStatus(); +} +} // namespace internal + absl::StatusOr PathToResourceAsFile(const std::string& path) { // Return full path. if (absl::StartsWith(path, "/")) { @@ -68,25 +99,4 @@ absl::StatusOr PathToResourceAsFile(const std::string& path) { return path; } -absl::Status GetResourceContents(const std::string& path, std::string* output, - bool read_as_binary) { - if (!read_as_binary) { - LOG(WARNING) - << "Setting \"read_as_binary\" to false is a no-op on Android."; - } - if (absl::StartsWith(path, "/")) { - return file::GetContents(path, output, file::Defaults()); - } - - if (absl::StartsWith(path, "content://")) { - MP_RETURN_IF_ERROR( - Singleton::get()->ReadContentUri(path, output)); - return absl::OkStatus(); - } - - RET_CHECK(Singleton::get()->ReadFile(path, output)) - << "could not read asset: " << path; - return absl::OkStatus(); -} - } // namespace mediapipe diff --git a/mediapipe/util/resource_util_apple.cc b/mediapipe/util/resource_util_apple.cc index 1750c67e4..428018ee4 100644 --- a/mediapipe/util/resource_util_apple.cc +++ b/mediapipe/util/resource_util_apple.cc @@ -20,6 +20,7 @@ #include "absl/strings/match.h" #include 
"mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/statusor.h" #include "mediapipe/util/resource_util.h" namespace mediapipe { @@ -40,6 +41,23 @@ absl::StatusOr PathToResourceAsFileInternal( } } // namespace +namespace internal { +absl::Status DefaultGetResourceContents(const std::string& path, + std::string* output, + bool read_as_binary) { + if (!read_as_binary) { + LOG(WARNING) << "Setting \"read_as_binary\" to false is a no-op on ios."; + } + ASSIGN_OR_RETURN(std::string full_path, PathToResourceAsFile(path)); + + std::ifstream input_file(full_path); + std::stringstream buffer; + buffer << input_file.rdbuf(); + buffer.str().swap(*output); + return absl::OkStatus(); +} +} // namespace internal + absl::StatusOr PathToResourceAsFile(const std::string& path) { // Return full path. if (absl::StartsWith(path, "/")) { @@ -83,18 +101,4 @@ absl::StatusOr PathToResourceAsFile(const std::string& path) { return path; } -absl::Status GetResourceContents(const std::string& path, std::string* output, - bool read_as_binary) { - if (!read_as_binary) { - LOG(WARNING) << "Setting \"read_as_binary\" to false is a no-op on ios."; - } - ASSIGN_OR_RETURN(std::string full_path, PathToResourceAsFile(path)); - - std::ifstream input_file(full_path); - std::stringstream buffer; - buffer << input_file.rdbuf(); - buffer.str().swap(*output); - return absl::OkStatus(); -} - } // namespace mediapipe diff --git a/mediapipe/util/resource_util_custom.h b/mediapipe/util/resource_util_custom.h new file mode 100644 index 000000000..6bc1513c6 --- /dev/null +++ b/mediapipe/util/resource_util_custom.h @@ -0,0 +1,18 @@ +#ifndef MEDIAPIPE_UTIL_RESOURCE_UTIL_CUSTOM_H_ +#define MEDIAPIPE_UTIL_RESOURCE_UTIL_CUSTOM_H_ + +#include + +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +typedef std::function + ResourceProviderFn; + +// Overrides the behavior of GetResourceContents. 
+void SetCustomGlobalResourceProvider(ResourceProviderFn fn); + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_RESOURCE_UTIL_CUSTOM_H_ diff --git a/mediapipe/util/resource_util_default.cc b/mediapipe/util/resource_util_default.cc new file mode 100644 index 000000000..8eaae6738 --- /dev/null +++ b/mediapipe/util/resource_util_default.cc @@ -0,0 +1,43 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/flags/flag.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/statusor.h" + +ABSL_FLAG( + std::string, resource_root_dir, "", + "The absolute path to the resource directory." 
+ "If specified, resource_root_dir will be prepended to the original path."); + +namespace mediapipe { + +using mediapipe::file::GetContents; +using mediapipe::file::JoinPath; + +namespace internal { + +absl::Status DefaultGetResourceContents(const std::string& path, + std::string* output, + bool read_as_binary) { + return GetContents(path, output, read_as_binary); +} +} // namespace internal + +absl::StatusOr PathToResourceAsFile(const std::string& path) { + return JoinPath(absl::GetFlag(FLAGS_resource_root_dir), path); +} + +} // namespace mediapipe diff --git a/mediapipe/util/resource_util_internal.h b/mediapipe/util/resource_util_internal.h new file mode 100644 index 000000000..8ae127a28 --- /dev/null +++ b/mediapipe/util/resource_util_internal.h @@ -0,0 +1,19 @@ +#ifndef MEDIAPIPE_UTIL_RESOURCE_UTIL_INTERNAL_H_ +#define MEDIAPIPE_UTIL_RESOURCE_UTIL_INTERNAL_H_ + +#include + +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +namespace internal { + +// Tries to return the contents of a file given the path. Implementation is +// platform-dependent. +absl::Status DefaultGetResourceContents(const std::string& path, + std::string* output, + bool read_as_binary); + +} // namespace internal +} // namespace mediapipe +#endif // MEDIAPIPE_UTIL_RESOURCE_UTIL_INTERNAL_H_ diff --git a/mediapipe/util/tflite/cpu_op_resolver.h b/mediapipe/util/tflite/cpu_op_resolver.h index 9754fbfc8..887683013 100644 --- a/mediapipe/util/tflite/cpu_op_resolver.h +++ b/mediapipe/util/tflite/cpu_op_resolver.h @@ -27,7 +27,8 @@ extern "C" void MediaPipe_RegisterTfLiteOpResolver(tflite::MutableOpResolver*); // This resolver is used for the custom ops introduced by // `MediaPipe_RegisterTfLiteOpResolver` (see above). 
-class CpuOpResolver : public tflite::ops::builtin::BuiltinOpResolver { +class CpuOpResolver + : public tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates { public: CpuOpResolver() { MediaPipe_RegisterTfLiteOpResolver(this); } }; diff --git a/mediapipe/util/tflite/op_resolver.h b/mediapipe/util/tflite/op_resolver.h index c84fbc0d6..4ca179ef1 100644 --- a/mediapipe/util/tflite/op_resolver.h +++ b/mediapipe/util/tflite/op_resolver.h @@ -20,7 +20,8 @@ namespace mediapipe { // This OpResolver is used for supporting "Convolution2DTransposeBias" on GPU. -class OpResolver : public tflite::ops::builtin::BuiltinOpResolver { +class OpResolver + : public tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates { public: OpResolver(); }; diff --git a/mediapipe/util/tflite/tflite_model_loader.cc b/mediapipe/util/tflite/tflite_model_loader.cc index 7a27b1ea3..a87d94bd6 100644 --- a/mediapipe/util/tflite/tflite_model_loader.cc +++ b/mediapipe/util/tflite/tflite_model_loader.cc @@ -24,7 +24,6 @@ absl::StatusOr> TfLiteModelLoader::LoadFromPath( std::string model_path = path; ASSIGN_OR_RETURN(model_path, mediapipe::PathToResourceAsFile(model_path)); - auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str()); RET_CHECK(model) << "Failed to load model from path " << model_path; return api2::MakePacket( diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index db9b004f9..7d5156049 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -290,7 +290,7 @@ cc_library( "//mediapipe/framework/port:vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings:str_format", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -429,7 +429,7 @@ cc_library( "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -526,7 +526,7 @@ 
cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/memory", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], ) @@ -624,7 +624,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", - "@eigen_archive//:eigen", + "@eigen_archive//:eigen3", ], alwayslink = 1, ) @@ -735,7 +735,6 @@ cc_test( ":region_flow_cc_proto", ":region_flow_computation", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", @@ -744,6 +743,7 @@ cc_test( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/time", ], ) @@ -757,6 +757,7 @@ cc_test( ":box_tracker", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/flags:flag", ], ) diff --git a/mediapipe/util/tracking/region_flow_computation_test.cc b/mediapipe/util/tracking/region_flow_computation_test.cc index 91437681a..0ac6dc2a5 100644 --- a/mediapipe/util/tracking/region_flow_computation_test.cc +++ b/mediapipe/util/tracking/region_flow_computation_test.cc @@ -21,9 +21,9 @@ #include #include +#include "absl/flags/flag.h" #include "absl/time/clock.h" #include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" @@ -37,7 +37,7 @@ // To ensure that the selected thresholds are robust, it is recommend // to run this test mutiple times with time seed, if changes are made. 
-DEFINE_bool(time_seed, false, "Activate to test thresholds"); +ABSL_FLAG(bool, time_seed, false, "Activate to test thresholds"); namespace mediapipe { namespace { diff --git a/requirements.txt b/requirements.txt index 37cad28fe..709a31d3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ absl-py -attrs +attrs>=19.1.0 dataclasses -numpy == 1.19.3 -opencv-python +numpy +opencv-contrib-python protobuf>=3.11.4 six wheel diff --git a/setup.py b/setup.py index 5acad7bdb..8a4c71274 100644 --- a/setup.py +++ b/setup.py @@ -180,13 +180,13 @@ class GeneratePyProtos(setuptools.Command): 'mediapipe/util/**/*.proto' ]: for proto_file in glob.glob(pattern, recursive=True): + proto_dir = os.path.dirname(os.path.abspath(proto_file)) # Ignore test protos. if proto_file.endswith('test.proto'): continue - # Ignore tensorflow protos. - if 'tensorflow' in proto_file: + # Ignore tensorflow protos in mediapipe/calculators/tensorflow. + if 'tensorflow' in proto_dir: continue - proto_dir = os.path.dirname(os.path.abspath(proto_file)) # Ignore testdata dir. 
if proto_dir.endswith('testdata'): continue diff --git a/setup_opencv.sh b/setup_opencv.sh index 983ccc7cd..b09068e65 100644 --- a/setup_opencv.sh +++ b/setup_opencv.sh @@ -64,7 +64,9 @@ if [ -z "$1" ] -DBUILD_opencv_plot=OFF -DBUILD_opencv_quality=OFF -DBUILD_opencv_reg=OFF \ -DBUILD_opencv_rgbd=OFF -DBUILD_opencv_saliency=OFF -DBUILD_opencv_shape=OFF \ -DBUILD_opencv_structured_light=OFF -DBUILD_opencv_surface_matching=OFF \ - -DBUILD_opencv_world=OFF -DBUILD_opencv_xobjdetect=OFF -DBUILD_opencv_xphoto=OFF + -DBUILD_opencv_world=OFF -DBUILD_opencv_xobjdetect=OFF -DBUILD_opencv_xphoto=OFF \ + -DCV_ENABLE_INTRINSICS=ON -DWITH_EIGEN=ON -DWITH_PTHREADS=ON -DWITH_PTHREADS_PF=ON \ + -DWITH_JPEG=ON -DWITH_PNG=ON -DWITH_TIFF=ON make -j 16 sudo make install rm -rf /tmp/build_opencv diff --git a/third_party/BUILD b/third_party/BUILD index ef408e4a2..654f0cb72 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -109,7 +109,15 @@ cmake_external( "BUILD_SHARED_LIBS": "ON" if OPENCV_SHARED_LIBS else "OFF", "WITH_ITT": "OFF", "WITH_JASPER": "OFF", + "WITH_JPEG": "ON", + "WITH_PNG": "ON", + "WITH_TIFF": "ON", "WITH_WEBP": "OFF", + # Optimization flags + "CV_ENABLE_INTRINSICS": "ON", + "WITH_EIGEN": "ON", + "WITH_PTHREADS": "ON", + "WITH_PTHREADS_PF": "ON", # When building tests, by default Bazel builds them in dynamic mode. # See https://docs.bazel.build/versions/master/be/c-cpp.html#cc_binary.linkstatic # For example, when building //mediapipe/calculators/video:opencv_video_encoder_calculator_test, diff --git a/third_party/org_tensorflow_compatibility_fixes.diff b/third_party/org_tensorflow_compatibility_fixes.diff index 502a994e8..2846fcc80 100644 --- a/third_party/org_tensorflow_compatibility_fixes.diff +++ b/third_party/org_tensorflow_compatibility_fixes.diff @@ -42,9 +42,27 @@ index 67bd587162e..2a3c6bd30dc 100644 @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
- + -include "tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs"; +include "../common/task/serialization_base.fbs"; - + namespace tflite.gpu.cl.data; +diff --git a/third_party/eigen3/eigen_archive.BUILD b/third_party/eigen3/eigen_archive.BUILD +index dad592bec48..670017c2c0f 100644 +--- a/third_party/eigen3/eigen_archive.BUILD ++++ b/third_party/eigen3/eigen_archive.BUILD +@@ -49,6 +49,13 @@ cc_library( + visibility = ["//visibility:public"], + ) + ++# For backward compatibility. ++alias( ++ name = "eigen", ++ actual=":eigen3", ++ visibility = ["//visibility:public"], ++) ++ + filegroup( + name = "eigen_header_files", + srcs = EIGEN_MPL2_HEADER_FILES, diff --git a/third_party/org_tensorflow_objc_cxx17.diff b/third_party/org_tensorflow_objc_cxx17.diff index 3242f65bd..a9da53fdf 100644 --- a/third_party/org_tensorflow_objc_cxx17.diff +++ b/third_party/org_tensorflow_objc_cxx17.diff @@ -21,4 +21,4 @@ index 6dcde34a62f..1adfc28aad9 100644 + "-std=c++17", ] - cc_library( + objc_library(