diff --git a/.github/ISSUE_TEMPLATE/00-build-installation-issue.md b/.github/ISSUE_TEMPLATE/00-build-installation-issue.md new file mode 100644 index 000000000..f4300e42a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/00-build-installation-issue.md @@ -0,0 +1,27 @@ +--- +name: "Build/Installation Issue" +about: Use this template for build/installation issues +labels: type:build/install + +--- +Please make sure that this is a build/installation issue and also refer to the [troubleshooting](https://google.github.io/mediapipe/getting_started/troubleshooting.html) documentation before raising any issues. + +**System information** (Please provide as much relevant information as possible) +- OS Platform and Distribution (e.g. Linux Ubuntu 16.04, Android 11, iOS 14.4): +- Compiler version (e.g. gcc/g++ 8 /Apple clang version 12.0.0): +- Programming Language and version ( e.g. C++ 14, Python 3.6, Java ): +- Installed using virtualenv? pip? Conda? (if python): +- [MediaPipe version](https://github.com/google/mediapipe/releases): +- Bazel version: +- XCode and Tulsi versions (if iOS): +- Android SDK and NDK versions (if android): +- Android [AAR](https://google.github.io/mediapipe/getting_started/android_archive_library.html) ( if android): +- OpenCV version (if running on desktop): + +**Describe the problem**: + + +**[Provide the exact sequence of commands / steps that you executed before running into the problem](https://google.github.io/mediapipe/getting_started/getting_started.html):** + +**Complete Logs:** +Include Complete Log information or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached: diff --git a/.github/ISSUE_TEMPLATE/10-solution-issue.md b/.github/ISSUE_TEMPLATE/10-solution-issue.md new file mode 100644 index 000000000..a5332cb36 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/10-solution-issue.md @@ -0,0 +1,26 @@ +--- +name: "Solution Issue" +about: Use this template for assistance with a specific mediapipe solution, such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +labels: type:support + +--- +Please make sure that this is a [solution](https://google.github.io/mediapipe/solutions/solutions.html) issue. + +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in Mediapipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- [MediaPipe version](https://github.com/google/mediapipe/releases): +- Bazel version: +- Solution (e.g. FaceMesh, Pose, Holistic): +- Programming Language and version ( e.g. C++, Python, Java): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need :** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab/repo link /any notebook: + +**Other info / Complete Logs :** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback. 
Large logs and files should be attached: diff --git a/.github/ISSUE_TEMPLATE/20-documentation-issue.md b/.github/ISSUE_TEMPLATE/20-documentation-issue.md new file mode 100644 index 000000000..2918e03b4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/20-documentation-issue.md @@ -0,0 +1,51 @@ +--- +name: "Documentation Issue" +about: Use this template for documentation related issues +labels: type:docs + +--- +Thank you for submitting a MediaPipe documentation issue. +The MediaPipe docs are open source! To get involved, read the documentation Contributor Guide +## URL(s) with the issue: + +Please provide a link to the documentation entry, for example: https://github.com/google/mediapipe/blob/master/docs/solutions/face_mesh.md#models + +## Description of issue (what needs changing): + +Kinds of documentation problems: + +### Clear description + +For example, why should someone use this method? How is it useful? + +### Correct links + +Is the link to the source code correct? + +### Parameters defined +Are all parameters defined and formatted correctly? + +### Returns defined + +Are return values defined? + +### Raises listed and defined + +Are the errors defined? For example, + +### Usage example + +Is there a usage example? + +See the API guide: +on how to write testable usage examples. + +### Request visuals, if applicable + +Are there currently visuals? If not, will it clarify the content? + +### Submit a pull request? + +Are you planning to also submit a pull request to fix the issue? See the docs +https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md + diff --git a/.github/ISSUE_TEMPLATE/30-bug-issue.md b/.github/ISSUE_TEMPLATE/30-bug-issue.md new file mode 100644 index 000000000..996c06cf5 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/30-bug-issue.md @@ -0,0 +1,32 @@ +--- +name: "Bug Issue" +about: Use this template for reporting a bug +labels: type:bug + +--- +Please make sure that this is a bug and also refer to the [troubleshooting](https://google.github.io/mediapipe/getting_started/troubleshooting.html), FAQ documentation before raising any issues. + +**System information** (Please provide as much relevant information as possible) + +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: +- Browser and version (e.g. Google Chrome, Safari) if the issue happens on browser: +- Programming Language and version ( e.g. C++, Python, Java): +- [MediaPipe version](https://github.com/google/mediapipe/releases): +- Bazel version (if compiling from source): +- Solution ( e.g. FaceMesh, Pose, Holistic ): +- Android Studio, NDK, SDK versions (if issue is related to building in Android environment): +- Xcode & Tulsi version (if issue is related to building for iOS): + +**Describe the current behavior:** + +**Describe the expected behavior:** + +**Standalone code to reproduce the issue:** +Provide a reproducible test case that is the bare minimum necessary to replicate the problem. If possible, please share a link to Colab/repo link /any notebook: + +**Other info / Complete Logs :** + Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback. 
Large logs and files should be attached diff --git a/.github/ISSUE_TEMPLATE/40-feature-request.md b/.github/ISSUE_TEMPLATE/40-feature-request.md new file mode 100644 index 000000000..2e1aafc7a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/40-feature-request.md @@ -0,0 +1,24 @@ +--- +name: "Feature Request" +about: Use this template for raising a feature request +labels: type:feature + +--- +Please make sure that this is a feature request. + +**System information** (Please provide as much relevant information as possible) + +- MediaPipe Solution (you are using): +- Programming language : C++/typescript/Python/Objective C/Android Java +- Are you willing to contribute it (Yes/No): + + +**Describe the feature and the current behavior/state:** + +**Will this change the current api? How?** + +**Who will benefit with this feature?** + +**Please specify the use cases for this feature:** + +**Any Other info:** diff --git a/.github/ISSUE_TEMPLATE/50-other-issues.md b/.github/ISSUE_TEMPLATE/50-other-issues.md new file mode 100644 index 000000000..e51add916 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/50-other-issues.md @@ -0,0 +1,14 @@ +--- +name: "Other Issue" +about: Use this template for any other non-support related issues. +labels: type:others + +--- +This template is for miscellaneous issues not covered by the other issue categories + +For questions on how to work with MediaPipe, or support for problems that are not verified bugs in MediaPipe, please go to [StackOverflow](https://stackoverflow.com/questions/tagged/mediapipe) and [Slack](https://mediapipe.page.link/joinslack) communities. + +If you are reporting a vulnerability, please use the [dedicated reporting process](https://github.com/google/mediapipe/security). + +For high-level discussions about MediaPipe, please post to discuss@mediapipe.org, for questions about the development or internal workings of MediaPipe, or if you would like to know how to contribute to MediaPipe, please post to developers@mediapipe.org. + diff --git a/.github/bot_config.yml b/.github/bot_config.yml new file mode 100644 index 000000000..b1b2d98ea --- /dev/null +++ b/.github/bot_config.yml @@ -0,0 +1,18 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# A list of assignees +assignees: + - sgowroji diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 000000000..03c67d0f6 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,34 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. + +# Number of days of inactivity before an Issue or Pull Request becomes stale +daysUntilStale: 7 +# Number of days of inactivity before a stale Issue or Pull Request is closed +daysUntilClose: 7 +# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled) +onlyLabels: + - stat:awaiting response +# Comment to post when marking as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you. +# Comment to post when removing the stale label. Set to `false` to disable +unmarkComment: false +closeComment: > + Closing as stale. Please reopen if you'd like to work on this further. diff --git a/MANIFEST.in b/MANIFEST.in index 8d5c4ec50..14afffebe 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,6 +8,7 @@ include README.md include requirements.txt recursive-include mediapipe/modules *.tflite *.txt *.binarypb +exclude mediapipe/modules/face_detection/face_detection_full_range.tflite exclude mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite exclude mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite exclude mediapipe/modules/objectron/object_detection_3d_sneakers.tflite diff --git a/README.md b/README.md index 23e0d9981..9ea72ab8a 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,12 @@ Hair Segmentation [Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | [Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | [Holistic](https://google.github.io/mediapipe/solutions/holistic) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Selfie Segmentation](https://google.github.io/mediapipe/solutions/selfie_segmentation) | ✅ | ✅ | ✅ | ✅ | ✅ | [Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | | [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | ✅ | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | @@ -54,46 +55,22 @@ See also [MediaPipe Models and Model Cards](https://google.github.io/mediapipe/solutions/models) for ML models released in MediaPipe. -## MediaPipe in Python - -MediaPipe offers customizable Python solutions as a prebuilt Python package on -[PyPI](https://pypi.org/project/mediapipe/), which can be installed simply with -`pip install mediapipe`. It also provides tools for users to build their own -solutions. Please see -[MediaPipe in Python](https://google.github.io/mediapipe/getting_started/python) -for more info. 
- -## MediaPipe on the Web - -MediaPipe on the Web is an effort to run the same ML solutions built for mobile -and desktop also in web browsers. The official API is under construction, but -the core technology has been proven effective. Please see -[MediaPipe on the Web](https://developers.googleblog.com/2020/01/mediapipe-on-web.html) -in Google Developers Blog for details. - -You can use the following links to load a demo in the MediaPipe Visualizer, and -over there click the "Runner" icon in the top bar like shown below. The demos -use your webcam video as input, which is processed all locally in real-time and -never leaves your device. - -![visualizer_runner](docs/images/visualizer_runner.png) - -* [MediaPipe Face Detection](https://viz.mediapipe.dev/demo/face_detection) -* [MediaPipe Iris](https://viz.mediapipe.dev/demo/iris_tracking) -* [MediaPipe Iris: Depth-from-Iris](https://viz.mediapipe.dev/demo/iris_depth) -* [MediaPipe Hands](https://viz.mediapipe.dev/demo/hand_tracking) -* [MediaPipe Hands (palm/hand detection only)](https://viz.mediapipe.dev/demo/hand_detection) -* [MediaPipe Pose](https://viz.mediapipe.dev/demo/pose_tracking) -* [MediaPipe Hair Segmentation](https://viz.mediapipe.dev/demo/hair_segmentation) - ## Getting started -Learn how to [install](https://google.github.io/mediapipe/getting_started/install) -MediaPipe and -[build example applications](https://google.github.io/mediapipe/getting_started/building_examples), -and start exploring our ready-to-use -[solutions](https://google.github.io/mediapipe/solutions/solutions) that you can -further extend and customize. +To start using MediaPipe +[solutions](https://google.github.io/mediapipe/solutions/solutions) with only a few +lines code, see example code and demos in +[MediaPipe in Python](https://google.github.io/mediapipe/getting_started/python) and +[MediaPipe in JavaScript](https://google.github.io/mediapipe/getting_started/javascript). + +To use MediaPipe in C++, Android and iOS, which allow further customization of +the [solutions](https://google.github.io/mediapipe/solutions/solutions) as well as +building your own, learn how to +[install](https://google.github.io/mediapipe/getting_started/install) MediaPipe and +start building example applications in +[C++](https://google.github.io/mediapipe/getting_started/cpp), +[Android](https://google.github.io/mediapipe/getting_started/android) and +[iOS](https://google.github.io/mediapipe/getting_started/ios). 
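For instance, a minimal sketch of the "few lines of code" Python path referenced above. This is illustrative only: it assumes `pip install mediapipe opencv-python` and a placeholder local image file named `input.jpg`.

```python
import cv2
import mediapipe as mp

mp_face_mesh = mp.solutions.face_mesh

# Read an image and convert BGR (OpenCV's default) to RGB, which MediaPipe expects.
image = cv2.imread("input.jpg")
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Run the Face Mesh solution on a single frame in static-image mode.
with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1) as face_mesh:
    results = face_mesh.process(rgb)

if results.multi_face_landmarks:
    print(f"Detected {len(results.multi_face_landmarks[0].landmark)} face landmarks")
```

The JavaScript packages listed in the JavaScript getting-started guide expose the same solutions with a similar surface.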
The source code is hosted in the [MediaPipe Github repository](https://github.com/google/mediapipe), and you can @@ -167,6 +144,13 @@ bash build_macos_desktop_examples.sh --cpu i386 --app face_detection -r ## Publications +* [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html) + in Google Developers Blog +* [Prosthesis control via Mirru App using MediaPipe hand tracking](https://developers.googleblog.com/2021/05/control-your-mirru-prosthesis-with-mediapipe-hand-tracking.html) + in Google Developers Blog +* [SignAll SDK: Sign language interface using MediaPipe is now available for + developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html) + in Google Developers Blog * [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) in Google AI Blog * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) diff --git a/WORKSPACE b/WORKSPACE index c7cb94346..9b0a7e86c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -65,26 +65,19 @@ rules_foreign_cc_dependencies() all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])""" # GoogleTest/GoogleMock framework. Used by most unit-tests. -# Last updated 2020-06-30. +# Last updated 2021-07-02. http_archive( name = "com_google_googletest", - urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"], - patches = [ - # fix for https://github.com/google/googletest/issues/2817 - "@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff" - ], - patch_args = [ - "-p1", - ], - strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e", - sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895", + urls = ["https://github.com/google/googletest/archive/4ec4cd23f486bf70efcc5d2caa40f24368f752e3.zip"], + strip_prefix = "googletest-4ec4cd23f486bf70efcc5d2caa40f24368f752e3", + sha256 = "de682ea824bfffba05b4e33b67431c247397d6175962534305136aa06f92e049", ) # Google Benchmark library. 
http_archive( name = "com_google_benchmark", - urls = ["https://github.com/google/benchmark/archive/master.zip"], - strip_prefix = "benchmark-master", + urls = ["https://github.com/google/benchmark/archive/main.zip"], + strip_prefix = "benchmark-main", build_file = "@//third_party:benchmark.BUILD", ) @@ -176,11 +169,11 @@ http_archive( http_archive( name = "pybind11", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz", - "https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.7.1.tar.gz", + "https://github.com/pybind/pybind11/archive/v2.7.1.tar.gz", ], - sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d", - strip_prefix = "pybind11-2.4.3", + sha256 = "616d1c42e4cf14fa27b2a4ff759d7d7b33006fdc5ad8fd603bb2c22622f27020", + strip_prefix = "pybind11-2.7.1", build_file = "@pybind11_bazel//:pybind11.BUILD", ) @@ -254,6 +247,20 @@ http_archive( url = "https://github.com/opencv/opencv/releases/download/3.2.0/opencv-3.2.0-ios-framework.zip", ) +http_archive( + name = "stblib", + strip_prefix = "stb-b42009b3b9d4ca35bc703f5310eedc74f584be58", + sha256 = "13a99ad430e930907f5611325ec384168a958bf7610e63e60e2fd8e7b7379610", + urls = ["https://github.com/nothings/stb/archive/b42009b3b9d4ca35bc703f5310eedc74f584be58.tar.gz"], + build_file = "@//third_party:stblib.BUILD", + patches = [ + "@//third_party:stb_image_impl.diff" + ], + patch_args = [ + "-p1", + ], +) + # You may run setup_android.sh to install Android SDK and NDK. android_ndk_repository( name = "androidndk", @@ -336,7 +343,9 @@ load("@rules_jvm_external//:defs.bzl", "maven_install") maven_install( artifacts = [ "androidx.concurrent:concurrent-futures:1.0.0-alpha03", - "androidx.lifecycle:lifecycle-common:2.2.0", + "androidx.lifecycle:lifecycle-common:2.3.1", + "androidx.activity:activity:1.2.2", + "androidx.fragment:fragment:1.3.4", "androidx.annotation:annotation:aar:1.1.0", "androidx.appcompat:appcompat:aar:1.1.0-rc01", "androidx.camera:camera-core:1.0.0-beta10", @@ -349,11 +358,11 @@ maven_install( "androidx.test.espresso:espresso-core:3.1.1", "com.github.bumptech.glide:glide:4.11.0", "com.google.android.material:material:aar:1.0.0-rc01", - "com.google.auto.value:auto-value:1.6.4", - "com.google.auto.value:auto-value-annotations:1.6.4", - "com.google.code.findbugs:jsr305:3.0.2", - "com.google.flogger:flogger-system-backend:0.3.1", - "com.google.flogger:flogger:0.3.1", + "com.google.auto.value:auto-value:1.8.1", + "com.google.auto.value:auto-value-annotations:1.8.1", + "com.google.code.findbugs:jsr305:latest.release", + "com.google.flogger:flogger-system-backend:latest.release", + "com.google.flogger:flogger:latest.release", "com.google.guava:guava:27.0.1-android", "com.google.guava:listenablefuture:1.0", "junit:junit:4.12", @@ -381,9 +390,9 @@ http_archive( ) # Tensorflow repo should always go after the other external dependencies. 
-# 2021-04-30 -_TENSORFLOW_GIT_COMMIT = "5bd3c57ef184543d22e34e36cff9d9bea608e06d" -_TENSORFLOW_SHA256= "9a45862834221aafacf6fb275f92b3876bc89443cbecc51be93f13839a6609f0" +# 2021-07-29 +_TENSORFLOW_GIT_COMMIT = "52a2905cbc21034766c08041933053178c5d10e3" +_TENSORFLOW_SHA256 = "06d4691bcdb700f3275fa0971a1585221c2b9f3dffe867963be565a6643d7f56" http_archive( name = "org_tensorflow", urls = [ @@ -404,3 +413,18 @@ load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") tf_workspace3() load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") tf_workspace2() + +# Edge TPU +http_archive( + name = "libedgetpu", + sha256 = "14d5527a943a25bc648c28a9961f954f70ba4d79c0a9ca5ae226e1831d72fe80", + strip_prefix = "libedgetpu-3164995622300286ef2bb14d7fdc2792dae045b7", + urls = [ + "https://github.com/google-coral/libedgetpu/archive/3164995622300286ef2bb14d7fdc2792dae045b7.tar.gz" + ], +) +load("@libedgetpu//:workspace.bzl", "libedgetpu_dependencies") +libedgetpu_dependencies() + +load("@coral_crosstool//:configure.bzl", "cc_crosstool") +cc_crosstool(name = "crosstool") diff --git a/build_desktop_examples.sh b/build_desktop_examples.sh index a35556cf0..7ff8db29c 100644 --- a/build_desktop_examples.sh +++ b/build_desktop_examples.sh @@ -97,6 +97,7 @@ for app in ${apps}; do if [[ ${target_name} == "holistic_tracking" || ${target_name} == "iris_tracking" || ${target_name} == "pose_tracking" || + ${target_name} == "selfie_segmentation" || ${target_name} == "upper_body_pose_tracking" ]]; then graph_suffix="cpu" else diff --git a/docs/framework_concepts/calculators.md b/docs/framework_concepts/calculators.md index 98bf1def4..9548fa461 100644 --- a/docs/framework_concepts/calculators.md +++ b/docs/framework_concepts/calculators.md @@ -248,12 +248,70 @@ absl::Status MyCalculator::Process() { } ``` +## Calculator options + +Calculators accept processing parameters through (1) input stream packets (2) +input side packets, and (3) calculator options. Calculator options, if +specified, appear as literal values in the `node_options` field of the +`CalculatorGraphConfiguration.Node` message. + +``` + node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:main_model_input" + output_stream: "TENSORS:main_model_output" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/detection_model.tflite" + } + } + } +``` + +The `node_options` field accepts the proto3 syntax. Alternatively, calculator +options can be specified in the `options` field using proto2 syntax. + +``` + node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:main_model_input" + output_stream: "TENSORS:main_model_output" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/detection_model.tflite" + } + } + } +``` + +Not all calculators accept calcuator options. In order to accept options, a +calculator will normally define a new protobuf message type to represent its +options, such as `PacketClonerCalculatorOptions`. The calculator will then +read that protobuf message in its `CalculatorBase::Open` method, and possibly +also in its `CalculatorBase::GetContract` function or its +`CalculatorBase::Process` method. Normally, the new protobuf message type will +be defined as a protobuf schema using a ".proto" file and a +`mediapipe_proto_library()` build rule. 
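As a side note on the `options` field mentioned above: the proto2 form uses protobuf extension syntax, where the options proto declares an extension (conventionally named `ext`) on `CalculatorOptions`. A minimal sketch of what that usually looks like, assuming `TfLiteInferenceCalculatorOptions` follows that convention:

```
  node {
    calculator: "TfLiteInferenceCalculator"
    input_stream: "TENSORS:main_model_input"
    output_stream: "TENSORS:main_model_output"
    options: {
      [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
        model_path: "mediapipe/models/detection_model.tflite"
      }
    }
  }
```

A `mediapipe_proto_library()` rule for such an options proto typically looks like the following: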
+ +``` + mediapipe_proto_library( + name = "packet_cloner_calculator_proto", + srcs = ["packet_cloner_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], + ) +``` + + ## Example calculator This section discusses the implementation of `PacketClonerCalculator`, which does a relatively simple job, and is used in many calculator graphs. -`PacketClonerCalculator` simply produces a copy of its most recent input -packets on demand. +`PacketClonerCalculator` simply produces a copy of its most recent input packets +on demand. `PacketClonerCalculator` is useful when the timestamps of arriving data packets are not aligned perfectly. Suppose we have a room with a microphone, light @@ -279,8 +337,8 @@ input streams: imageframe of video data representing video collected from camera in the room with timestamp. -Below is the implementation of the `PacketClonerCalculator`. You can see -the `GetContract()`, `Open()`, and `Process()` methods as well as the instance +Below is the implementation of the `PacketClonerCalculator`. You can see the +`GetContract()`, `Open()`, and `Process()` methods as well as the instance variable `current_` which holds the most recent input packets. ```c++ @@ -401,6 +459,6 @@ node { The diagram below shows how the `PacketClonerCalculator` defines its output packets (bottom) based on its series of input packets (top). -| ![Graph using PacketClonerCalculator](../images/packet_cloner_calculator.png) | -| :---------------------------------------------------------------------------: | -| *Each time it receives a packet on its TICK input stream, the PacketClonerCalculator outputs the most recent packet from each of its input streams. The sequence of output packets (bottom) is determined by the sequence of input packets (top) and their timestamps. The timestamps are shown along the right side of the diagram.* | +![Graph using PacketClonerCalculator](../images/packet_cloner_calculator.png) | +:--------------------------------------------------------------------------: | +*Each time it receives a packet on its TICK input stream, the PacketClonerCalculator outputs the most recent packet from each of its input streams. The sequence of output packets (bottom) is determined by the sequence of input packets (top) and their timestamps. The timestamps are shown along the right side of the diagram.* | diff --git a/docs/framework_concepts/framework_concepts.md b/docs/framework_concepts/framework_concepts.md index dcf446a9d..dd43d830c 100644 --- a/docs/framework_concepts/framework_concepts.md +++ b/docs/framework_concepts/framework_concepts.md @@ -111,11 +111,11 @@ component known as an InputStreamHandler. See [Synchronization](synchronization.md) for more details. -### Realtime data streams +### Real-time streams MediaPipe calculator graphs are often used to process streams of video or audio frames for interactive applications. Normally, each Calculator runs as soon as all of its input packets for a given timestamp become available. Calculators -used in realtime graphs need to define output timestamp bounds based on input +used in real-time graphs need to define output timestamp bounds based on input timestamp bounds in order to allow downstream calculators to be scheduled -promptly. See [Realtime data streams](realtime.md) for details. +promptly. See [Real-time Streams](realtime_streams.md) for details. 
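To make the timestamp-bound requirement above concrete, here is a minimal sketch of a calculator that either forwards its input or, when it produces nothing for a timestamp, advances its output timestamp bound. The calculator name and the `ShouldForward()` predicate are illustrative placeholders, not code from the MediaPipe sources.

```c++
#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

// Illustrative only: forwards packets it decides to keep; otherwise it still
// advances the output timestamp bound so downstream calculators can run.
class SketchForwardingCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    if (ShouldForward(cc)) {
      // Emitting a packet at the input timestamp also defines the bound.
      cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    } else {
      // No packet this round: publish a timestamp bound instead, so that
      // downstream calculators are not blocked waiting on this stream.
      cc->Outputs().Index(0).SetNextTimestampBound(
          cc->InputTimestamp().NextAllowedInStream());
    }
    return absl::OkStatus();
  }

 private:
  // Placeholder decision; a real calculator would apply its own logic here.
  bool ShouldForward(CalculatorContext* cc) {
    return !cc->Inputs().Index(0).IsEmpty();
  }
};
REGISTER_CALCULATOR(SketchForwardingCalculator);

}  // namespace mediapipe
```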
diff --git a/docs/framework_concepts/realtime.md b/docs/framework_concepts/realtime_streams.md similarity index 91% rename from docs/framework_concepts/realtime.md rename to docs/framework_concepts/realtime_streams.md index 36b606825..038081453 100644 --- a/docs/framework_concepts/realtime.md +++ b/docs/framework_concepts/realtime_streams.md @@ -1,29 +1,28 @@ --- layout: default -title: Processing real-time data streams +title: Real-time Streams +parent: Framework Concepts nav_order: 6 -has_children: true -has_toc: false --- -# Processing real-time data streams +# Real-time Streams {: .no_toc } 1. TOC {:toc} --- -## Realtime timestamps +## Real-time timestamps MediaPipe calculator graphs are often used to process streams of video or audio frames for interactive applications. The MediaPipe framework requires only that successive packets be assigned monotonically increasing timestamps. By -convention, realtime calculators and graphs use the recording time or the +convention, real-time calculators and graphs use the recording time or the presentation time of each frame as its timestamp, with each timestamp indicating the microseconds since `Jan/1/1970:00:00:00`. This allows packets from various sources to be processed in a globally consistent sequence. -## Realtime scheduling +## Real-time scheduling Normally, each Calculator runs as soon as all of its input packets for a given timestamp become available. Normally, this happens when the calculator has @@ -38,7 +37,7 @@ When a calculator does not produce any output packets for a given timestamp, it can instead output a "timestamp bound" indicating that no packet will be produced for that timestamp. This indication is necessary to allow downstream calculators to run at that timestamp, even though no packet has arrived for -certain streams for that timestamp. This is especially important for realtime +certain streams for that timestamp. This is especially important for real-time graphs in interactive applications, where it is crucial that each calculator begin processing as soon as possible. @@ -83,12 +82,12 @@ For example, `Timestamp(1).NextAllowedInStream() == Timestamp(2)`. ## Propagating timestamp bounds -Calculators that will be used in realtime graphs need to define output timestamp -bounds based on input timestamp bounds in order to allow downstream calculators -to be scheduled promptly. A common pattern is for calculators to output packets -with the same timestamps as their input packets. In this case, simply outputting -a packet on every call to `Calculator::Process` is sufficient to define output -timestamp bounds. +Calculators that will be used in real-time graphs need to define output +timestamp bounds based on input timestamp bounds in order to allow downstream +calculators to be scheduled promptly. A common pattern is for calculators to +output packets with the same timestamps as their input packets. In this case, +simply outputting a packet on every call to `Calculator::Process` is sufficient +to define output timestamp bounds. However, calculators are not required to follow this common pattern for output timestamps, they are only required to choose monotonically increasing output diff --git a/docs/getting_started/android.md b/docs/getting_started/android.md index 71224a258..c3c6506ee 100644 --- a/docs/getting_started/android.md +++ b/docs/getting_started/android.md @@ -16,12 +16,14 @@ nav_order: 1 Please follow instructions below to build Android example apps in the supported MediaPipe [solutions](../solutions/solutions.md). 
To learn more about these -example apps, start from [Hello World! on Android](./hello_world_android.md). To -incorporate MediaPipe into an existing Android Studio project, see these -[instructions](./android_archive_library.md) that use Android Archive (AAR) and -Gradle. +example apps, start from [Hello World! on Android](./hello_world_android.md). -## Building Android example apps +To incorporate MediaPipe into Android Studio projects, see these +[instructions](./android_solutions.md) to use the MediaPipe Android Solution +APIs (currently in alpha) that are now available in +[Google's Maven Repository](https://maven.google.com/web/index.html?#com.google.mediapipe). + +## Building Android example apps with Bazel ### Prerequisite @@ -51,16 +53,6 @@ $YOUR_INTENDED_API_LEVEL` in android_ndk_repository() and/or android_sdk_repository() in the [`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) file. -Please verify all the necessary packages are installed. - -* Android SDK Platform API Level 28 or 29 -* Android SDK Build-Tools 28 or 29 -* Android SDK Platform-Tools 28 or 29 -* Android SDK Tools 26.1.1 -* Android NDK 19c or above - -### Option 1: Build with Bazel in Command Line - Tip: You can run this [script](https://github.com/google/mediapipe/blob/master/build_android_examples.sh) to build (and install) all MediaPipe Android example apps. @@ -84,108 +76,3 @@ to build (and install) all MediaPipe Android example apps. ```bash adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/handtrackinggpu.apk ``` - -### Option 2: Build with Bazel in Android Studio - -The MediaPipe project can be imported into Android Studio using the Bazel -plugins. This allows the MediaPipe examples to be built and modified in Android -Studio. - -To incorporate MediaPipe into an existing Android Studio project, see these -[instructions](./android_archive_library.md) that use Android Archive (AAR) and -Gradle. - -The steps below use Android Studio 3.5 to build and install a MediaPipe example -app: - -1. Install and launch Android Studio 3.5. - -2. Select `Configure` -> `SDK Manager` -> `SDK Platforms`. - - * Verify that Android SDK Platform API Level 28 or 29 is installed. - * Take note of the Android SDK Location, e.g., - `/usr/local/home/Android/Sdk`. - -3. Select `Configure` -> `SDK Manager` -> `SDK Tools`. - - * Verify that Android SDK Build-Tools 28 or 29 is installed. - * Verify that Android SDK Platform-Tools 28 or 29 is installed. - * Verify that Android SDK Tools 26.1.1 is installed. - * Verify that Android NDK 19c or above is installed. - * Take note of the Android NDK Location, e.g., - `/usr/local/home/Android/Sdk/ndk-bundle` or - `/usr/local/home/Android/Sdk/ndk/20.0.5594570`. - -4. Set environment variables `$ANDROID_HOME` and `$ANDROID_NDK_HOME` to point - to the installed SDK and NDK. - - ```bash - export ANDROID_HOME=/usr/local/home/Android/Sdk - - # If the NDK libraries are installed by a previous version of Android Studio, do - export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk-bundle - # If the NDK libraries are installed by Android Studio 3.5, do - export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk/ - ``` - -5. Select `Configure` -> `Plugins` to install `Bazel`. - -6. On Linux, select `File` -> `Settings` -> `Bazel settings`. On macos, select - `Android Studio` -> `Preferences` -> `Bazel settings`. Then, modify `Bazel - binary location` to be the same as the output of `$ which bazel`. - -7. Select `Import Bazel Project`. 
- - * Select `Workspace`: `/path/to/mediapipe` and select `Next`. - * Select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` and select - `Next`. - * Modify `Project View` to be the following and select `Finish`. - - ``` - directories: - # read project settings, e.g., .bazelrc - . - -mediapipe/objc - -mediapipe/examples/ios - - targets: - //mediapipe/examples/android/...:all - //mediapipe/java/...:all - - android_sdk_platform: android-29 - - sync_flags: - --host_crosstool_top=@bazel_tools//tools/cpp:toolchain - ``` - -8. Select `Bazel` -> `Sync` -> `Sync project with Build files`. - - Note: Even after doing step 4, if you still see the error: `"no such package - '@androidsdk//': Either the path attribute of android_sdk_repository or the - ANDROID_HOME environment variable must be set."`, please modify the - [`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) - file to point to your SDK and NDK library locations, as below: - - ``` - android_sdk_repository( - name = "androidsdk", - path = "/path/to/android/sdk" - ) - - android_ndk_repository( - name = "androidndk", - path = "/path/to/android/ndk" - ) - ``` - -9. Connect an Android device to the workstation. - -10. Select `Run...` -> `Edit Configurations...`. - - * Select `Templates` -> `Bazel Command`. - * Enter Target Expression: - `//mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu:handtrackinggpu` - * Enter Bazel command: `mobile-install`. - * Enter Bazel flags: `-c opt --config=android_arm64`. - * Press the `[+]` button to add the new configuration. - * Select `Run` to run the example app on the connected Android device. diff --git a/docs/getting_started/android_archive_library.md b/docs/getting_started/android_archive_library.md index 2c2ca99f3..d2f25213f 100644 --- a/docs/getting_started/android_archive_library.md +++ b/docs/getting_started/android_archive_library.md @@ -3,7 +3,7 @@ layout: default title: MediaPipe Android Archive parent: MediaPipe on Android grand_parent: Getting Started -nav_order: 2 +nav_order: 3 --- # MediaPipe Android Archive @@ -92,12 +92,12 @@ each project. and copy [the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41) and - [the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite). + [the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_short_range.tflite). ```bash bazel build -c opt mediapipe/graphs/face_detection:face_detection_mobile_gpu_binary_graph cp bazel-bin/mediapipe/graphs/face_detection/face_detection_mobile_gpu.binarypb /path/to/your/app/src/main/assets/ - cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/ + cp mediapipe/modules/face_detection/face_detection_short_range.tflite /path/to/your/app/src/main/assets/ ``` ![Screenshot](../images/mobile/assets_location.png) @@ -113,10 +113,9 @@ each project. 
androidTestImplementation 'androidx.test.ext:junit:1.1.0' androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1' // MediaPipe deps - implementation 'com.google.flogger:flogger:0.3.1' - implementation 'com.google.flogger:flogger-system-backend:0.3.1' - implementation 'com.google.code.findbugs:jsr305:3.0.2' - implementation 'com.google.guava:guava:27.0.1-android' + implementation 'com.google.flogger:flogger:latest.release' + implementation 'com.google.flogger:flogger-system-backend:latest.release' + implementation 'com.google.code.findbugs:jsr305:latest.release' implementation 'com.google.guava:guava:27.0.1-android' implementation 'com.google.protobuf:protobuf-java:3.11.4' // CameraX core library @@ -125,7 +124,7 @@ each project. implementation "androidx.camera:camera-camera2:$camerax_version" implementation "androidx.camera:camera-lifecycle:$camerax_version" // AutoValue - def auto_value_version = "1.6.4" + def auto_value_version = "1.8.1" implementation "com.google.auto.value:auto-value-annotations:$auto_value_version" annotationProcessor "com.google.auto.value:auto-value:$auto_value_version" } diff --git a/docs/getting_started/android_solutions.md b/docs/getting_started/android_solutions.md new file mode 100644 index 000000000..de7135c18 --- /dev/null +++ b/docs/getting_started/android_solutions.md @@ -0,0 +1,79 @@ +--- +layout: default +title: Android Solutions +parent: MediaPipe on Android +grand_parent: Getting Started +nav_order: 2 +--- + +# Android Solution APIs +{: .no_toc } + +1. TOC +{:toc} +--- + +Please follow instructions below to use the MediaPipe Solution APIs in Android +Studio projects and build the Android example apps in the supported MediaPipe +[solutions](../solutions/solutions.md). + +## Integrate MediaPipe Android Solutions in Android Studio + +MediaPipe Android Solution APIs (currently in alpha) are now available in +[Google's Maven Repository](https://maven.google.com/web/index.html?#com.google.mediapipe). +To incorporate MediaPipe Android Solutions into an Android Studio project, add +the following into the project's Gradle dependencies: + +``` +dependencies { + // MediaPipe solution-core is the foundation of any MediaPipe solutions. + implementation 'com.google.mediapipe:solution-core:latest.release' + // Optional: MediaPipe Hands solution. + implementation 'com.google.mediapipe:hands:latest.release' + // Optional: MediaPipe FaceMesh solution. + implementation 'com.google.mediapipe:facemesh:latest.release' + // MediaPipe deps + implementation 'com.google.flogger:flogger:latest.release' + implementation 'com.google.flogger:flogger-system-backend:latest.release' + implementation 'com.google.guava:guava:27.0.1-android' + implementation 'com.google.protobuf:protobuf-java:3.11.4' + // CameraX core library + def camerax_version = "1.0.0-beta10" + implementation "androidx.camera:camera-core:$camerax_version" + implementation "androidx.camera:camera-camera2:$camerax_version" + implementation "androidx.camera:camera-lifecycle:$camerax_version" +} +``` + +See the detailed solutions API usage examples for different use cases in the +solution example apps' +[source code](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions). +If the prebuilt maven packages are not sufficient, building the MediaPipe +Android archive library locally by following these +[instructions](./android_archive_library.md). + +## Build solution example apps in Android Studio + +1. Open Android Studio Arctic Fox on Linux, macOS, or Windows. + +2. 
Import mediapipe/examples/android/solutions directory into Android Studio. + + ![Screenshot](../images/import_mp_android_studio_project.png) + +3. For Windows users, run `create_win_symlinks.bat` as administrator to create + res directory symlinks. + + ![Screenshot](../images/run_create_win_symlinks.png) + +4. Select "File" -> "Sync Project with Gradle Files" to sync project. + +5. Run solution example app in Android Studio. + + ![Screenshot](../images/run_android_solution_app.png) + +6. (Optional) Run solutions on CPU. + + MediaPipe solution example apps run the pipeline and the model inference on + GPU by default. If needed, for example to run the apps on Android Emulator, + set the `RUN_ON_GPU` boolean variable to `false` in the app's + MainActivity.java to run the pipeline and the model inference on CPU. diff --git a/docs/getting_started/hello_world_android.md b/docs/getting_started/hello_world_android.md index 9f277f799..6674d4023 100644 --- a/docs/getting_started/hello_world_android.md +++ b/docs/getting_started/hello_world_android.md @@ -31,8 +31,8 @@ stream on an Android device. ## Setup -1. Install MediaPipe on your system, see [MediaPipe installation guide] for - details. +1. Install MediaPipe on your system, see + [MediaPipe installation guide](./install.md) for details. 2. Install Android Development SDK and Android NDK. See how to do so also in [MediaPipe installation guide]. 3. Enable [developer options] on your Android device. @@ -770,7 +770,6 @@ If you ran into any issues, please see the full code of the tutorial [`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java [`FrameLayout`]:https://developer.android.com/reference/android/widget/FrameLayout [`FrameProcessor`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java -[MediaPipe installation guide]:./install.md [`PermissionHelper`]: https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java [`SurfaceHolder.Callback`]:https://developer.android.com/reference/android/view/SurfaceHolder.Callback.html [`SurfaceView`]:https://developer.android.com/reference/android/view/SurfaceView diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md index 06d79c67d..4591b5f33 100644 --- a/docs/getting_started/hello_world_ios.md +++ b/docs/getting_started/hello_world_ios.md @@ -31,8 +31,8 @@ stream on an iOS device. ## Setup -1. Install MediaPipe on your system, see [MediaPipe installation guide] for - details. +1. Install MediaPipe on your system, see + [MediaPipe installation guide](./install.md) for details. 2. Setup your iOS device for development. 3. Setup [Bazel] on your system to build and deploy the iOS app. @@ -113,6 +113,10 @@ bazel to build the iOS application. The content of the 5. `Main.storyboard` and `Launch.storyboard` 6. `Assets.xcassets` directory. +Note: In newer versions of Xcode, you may see additional files `SceneDelegate.h` +and `SceneDelegate.m`. Make sure to copy them too and add them to the `BUILD` +file mentioned below. + Copy these files to a directory named `HelloWorld` to a location that can access the MediaPipe source code. For example, the source code of the application that we will build in this tutorial is located in @@ -247,6 +251,12 @@ We need to get frames from the `_cameraSource` into our application `MPPInputSourceDelegate`. 
So our application `ViewController` can be a delegate of `_cameraSource`. +Update the interface definition of `ViewController` accordingly: + +``` +@interface ViewController () +``` + To handle camera setup and process incoming frames, we should use a queue different from the main queue. Add the following to the implementation block of the `ViewController`: @@ -288,6 +298,12 @@ utility called `MPPLayerRenderer` to display images on the screen. This utility can be used to display `CVPixelBufferRef` objects, which is the type of the images provided by `MPPCameraInputSource` to its delegates. +In `ViewController.m`, add the following import line: + +``` +#import "mediapipe/objc/MPPLayerRenderer.h" +``` + To display images of the screen, we need to add a new `UIView` object called `_liveView` to the `ViewController`. @@ -411,6 +427,12 @@ Objective-C++. ### Use the graph in `ViewController` +In `ViewController.m`, add the following import line: + +``` +#import "mediapipe/objc/MPPGraph.h" +``` + Declare a static constant with the name of the graph, the input stream and the output stream: @@ -549,6 +571,12 @@ method to receive packets on this output stream and display them on the screen: } ``` +Update the interface definition of `ViewController` with `MPPGraphDelegate`: + +``` +@interface ViewController () +``` + And that is all! Build and run the app on your iOS device. You should see the results of running the edge detection graph on a live video feed. Congrats! @@ -560,6 +588,5 @@ appropriate `BUILD` file dependencies for the edge detection graph. [Bazel]:https://bazel.build/ [`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt -[MediaPipe installation guide]:./install.md -[common]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common) -[helloworld]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld) +[common]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common +[helloworld]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index 95dce1d17..bb2539d33 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -43,104 +43,189 @@ install --user six`. 3. Install OpenCV and FFmpeg. - Option 1. Use package manager tool to install the pre-compiled OpenCV - libraries. FFmpeg will be installed via libopencv-video-dev. + **Option 1**. Use package manager tool to install the pre-compiled OpenCV + libraries. FFmpeg will be installed via `libopencv-video-dev`. - Note: Debian 9 and Ubuntu 16.04 provide OpenCV 2.4.9. You may want to take - option 2 or 3 to install OpenCV 3 or above. + OS | OpenCV + -------------------- | ------ + Debian 9 (stretch) | 2.4 + Debian 10 (buster) | 3.2 + Debian 11 (bullseye) | 4.5 + Ubuntu 16.04 LTS | 2.4 + Ubuntu 18.04 LTS | 3.2 + Ubuntu 20.04 LTS | 4.2 + Ubuntu 20.04 LTS | 4.2 + Ubuntu 21.04 | 4.5 ```bash - $ sudo apt-get install libopencv-core-dev libopencv-highgui-dev \ - libopencv-calib3d-dev libopencv-features2d-dev \ - libopencv-imgproc-dev libopencv-video-dev + $ sudo apt-get install -y \ + libopencv-core-dev \ + libopencv-highgui-dev \ + libopencv-calib3d-dev \ + libopencv-features2d-dev \ + libopencv-imgproc-dev \ + libopencv-video-dev ``` - Debian 9 and Ubuntu 18.04 install the packages in - `/usr/lib/x86_64-linux-gnu`. 
MediaPipe's [`opencv_linux.BUILD`] and - [`ffmpeg_linux.BUILD`] are configured for this library path. Ubuntu 20.04 - may install the OpenCV and FFmpeg packages in `/usr/local`, Please follow - the option 3 below to modify the [`WORKSPACE`], [`opencv_linux.BUILD`] and - [`ffmpeg_linux.BUILD`] files accordingly. - - Moreover, for Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, the - library path needs to be modified like the following: + MediaPipe's [`opencv_linux.BUILD`] and [`WORKSPACE`] are already configured + for OpenCV 2/3 and should work correctly on any architecture: ```bash - sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD + # WORKSPACE + new_local_repository( + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr", + ) + + # opencv_linux.BUILD for OpenCV 2/3 installed from Debian package + cc_library( + name = "opencv", + linkopts = [ + "-l:libopencv_core.so", + "-l:libopencv_calib3d.so", + "-l:libopencv_features2d.so", + "-l:libopencv_highgui.so", + "-l:libopencv_imgcodecs.so", + "-l:libopencv_imgproc.so", + "-l:libopencv_video.so", + "-l:libopencv_videoio.so", + ], + ) ``` - Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source - and modify MediaPipe's OpenCV config. + For OpenCV 4 you need to modify [`opencv_linux.BUILD`] taking into account + current architecture: - Option 3. Follow OpenCV's + ```bash + # WORKSPACE + new_local_repository( + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr", + ) + + # opencv_linux.BUILD for OpenCV 4 installed from Debian package + cc_library( + name = "opencv", + hdrs = glob([ + # Uncomment according to your multiarch value (gcc -print-multiarch): + # "include/aarch64-linux-gnu/opencv4/opencv2/cvconfig.h", + # "include/arm-linux-gnueabihf/opencv4/opencv2/cvconfig.h", + # "include/x86_64-linux-gnu/opencv4/opencv2/cvconfig.h", + "include/opencv4/opencv2/**/*.h*", + ]), + includes = [ + # Uncomment according to your multiarch value (gcc -print-multiarch): + # "include/aarch64-linux-gnu/opencv4/", + # "include/arm-linux-gnueabihf/opencv4/", + # "include/x86_64-linux-gnu/opencv4/", + "include/opencv4/", + ], + linkopts = [ + "-l:libopencv_core.so", + "-l:libopencv_calib3d.so", + "-l:libopencv_features2d.so", + "-l:libopencv_highgui.so", + "-l:libopencv_imgcodecs.so", + "-l:libopencv_imgproc.so", + "-l:libopencv_video.so", + "-l:libopencv_videoio.so", + ], + ) + ``` + + **Option 2**. Run [`setup_opencv.sh`] to automatically build OpenCV from + source and modify MediaPipe's OpenCV config. This option will do all steps + defined in Option 3 automatically. + + **Option 3**. Follow OpenCV's [documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html) to manually build OpenCV from source code. - Note: You may need to modify [`WORKSPACE`], [`opencv_linux.BUILD`] and - [`ffmpeg_linux.BUILD`] to point MediaPipe to your own OpenCV and FFmpeg - libraries. For example if OpenCV and FFmpeg are both manually installed in - "/usr/local/", you will need to update: (1) the "linux_opencv" and - "linux_ffmpeg" new_local_repository rules in [`WORKSPACE`], (2) the "opencv" - cc_library rule in [`opencv_linux.BUILD`], and (3) the "libffmpeg" - cc_library rule in [`ffmpeg_linux.BUILD`]. These 3 changes are shown below: + You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to point + MediaPipe to your own OpenCV libraries. 
Assume OpenCV would be installed to + `/usr/local/` which is recommended by default. + + OpenCV 2/3 setup: ```bash + # WORKSPACE new_local_repository( - name = "linux_opencv", - build_file = "@//third_party:opencv_linux.BUILD", - path = "/usr/local", + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr/local", ) + # opencv_linux.BUILD for OpenCV 2/3 installed to /usr/local + cc_library( + name = "opencv", + linkopts = [ + "-L/usr/local/lib", + "-l:libopencv_core.so", + "-l:libopencv_calib3d.so", + "-l:libopencv_features2d.so", + "-l:libopencv_highgui.so", + "-l:libopencv_imgcodecs.so", + "-l:libopencv_imgproc.so", + "-l:libopencv_video.so", + "-l:libopencv_videoio.so", + ], + ) + ``` + + OpenCV 4 setup: + + ```bash + # WORKSPACE new_local_repository( - name = "linux_ffmpeg", - build_file = "@//third_party:ffmpeg_linux.BUILD", - path = "/usr/local", + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr/local", ) + # opencv_linux.BUILD for OpenCV 4 installed to /usr/local cc_library( - name = "opencv", - srcs = glob( - [ - "lib/libopencv_core.so", - "lib/libopencv_highgui.so", - "lib/libopencv_imgcodecs.so", - "lib/libopencv_imgproc.so", - "lib/libopencv_video.so", - "lib/libopencv_videoio.so", - ], - ), - hdrs = glob([ - # For OpenCV 3.x - "include/opencv2/**/*.h*", - # For OpenCV 4.x - # "include/opencv4/opencv2/**/*.h*", - ]), - includes = [ - # For OpenCV 3.x - "include/", - # For OpenCV 4.x - # "include/opencv4/", - ], - linkstatic = 1, - visibility = ["//visibility:public"], + name = "opencv", + hdrs = glob([ + "include/opencv4/opencv2/**/*.h*", + ]), + includes = [ + "include/opencv4/", + ], + linkopts = [ + "-L/usr/local/lib", + "-l:libopencv_core.so", + "-l:libopencv_calib3d.so", + "-l:libopencv_features2d.so", + "-l:libopencv_highgui.so", + "-l:libopencv_imgcodecs.so", + "-l:libopencv_imgproc.so", + "-l:libopencv_video.so", + "-l:libopencv_videoio.so", + ], + ) + ``` + + Current FFmpeg setup is defined in [`ffmpeg_linux.BUILD`] and should work + for any architecture: + + ```bash + # WORKSPACE + new_local_repository( + name = "linux_ffmpeg", + build_file = "@//third_party:ffmpeg_linux.BUILD", + path = "/usr" ) + # ffmpeg_linux.BUILD for FFmpeg installed from Debian package cc_library( - name = "libffmpeg", - srcs = glob( - [ - "lib/libav*.so", - ], - ), - hdrs = glob(["include/libav*/*.h"]), - includes = ["include"], - linkopts = [ - "-lavcodec", - "-lavformat", - "-lavutil", - ], - linkstatic = 1, - visibility = ["//visibility:public"], + name = "libffmpeg", + linkopts = [ + "-l:libavcodec.so", + "-l:libavformat.so", + "-l:libavutil.so", + ], ) ``` @@ -711,7 +796,7 @@ This will use a Docker image that will isolate mediapipe's installation from the ```bash $ docker run -it --name mediapipe mediapipe:latest - root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazel run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world + root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazelisk run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world # Should print: # Hello World! 
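# Note: GLOG_logtostderr=1 makes the example's LOG(INFO) output visible in the
# terminal, and --define MEDIAPIPE_DISABLE_GPU=1 builds the CPU-only path,
# which is what a container without GPU pass-through can run.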
diff --git a/docs/getting_started/javascript.md b/docs/getting_started/javascript.md index 0c49e1dd4..f56abcd6e 100644 --- a/docs/getting_started/javascript.md +++ b/docs/getting_started/javascript.md @@ -16,17 +16,29 @@ nav_order: 4 MediaPipe currently offers the following solutions: -Solution | NPM Package | Example ------------------ | ----------------------------- | ------- -[Face Mesh][F-pg] | [@mediapipe/face_mesh][F-npm] | [mediapipe.dev/demo/face_mesh][F-demo] -[Face Detection][Fd-pg] | [@mediapipe/face_detection][Fd-npm] | [mediapipe.dev/demo/face_detection][Fd-demo] -[Hands][H-pg] | [@mediapipe/hands][H-npm] | [mediapipe.dev/demo/hands][H-demo] -[Holistic][Ho-pg] | [@mediapipe/holistic][Ho-npm] | [mediapipe.dev/demo/holistic][Ho-demo] -[Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo] +Solution | NPM Package | Example +--------------------------- | --------------------------------------- | ------- +[Face Mesh][F-pg] | [@mediapipe/face_mesh][F-npm] | [mediapipe.dev/demo/face_mesh][F-demo] +[Face Detection][Fd-pg] | [@mediapipe/face_detection][Fd-npm] | [mediapipe.dev/demo/face_detection][Fd-demo] +[Hands][H-pg] | [@mediapipe/hands][H-npm] | [mediapipe.dev/demo/hands][H-demo] +[Holistic][Ho-pg] | [@mediapipe/holistic][Ho-npm] | [mediapipe.dev/demo/holistic][Ho-demo] +[Objectron][Ob-pg] | [@mediapipe/objectron][Ob-npm] | [mediapipe.dev/demo/objectron][Ob-demo] +[Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo] +[Selfie Segmentation][S-pg] | [@mediapipe/selfie_segmentation][S-npm] | [mediapipe.dev/demo/selfie_segmentation][S-demo] Click on a solution link above for more information, including API and code snippets. +### Supported plaforms: + +| Browser | Platform | Notes | +| ------- | ----------------------- | -------------------------------------- | +| Chrome | Android / Windows / Mac | Pixel 4 and older unsupported. Fuschia | +| | | unsupported. | +| Chrome | iOS | Camera unavailable in Chrome on iOS. | +| Safari | iPad/iPhone/Mac | iOS and Safari on iPad / iPhone / | +| | | MacBook | + The quickest way to get acclimated is to look at the examples above. Each demo has a link to a [CodePen][codepen] so that you can edit the code and try it yourself. We have included a number of utility packages to help you get started: @@ -66,29 +78,25 @@ affecting your work, restrict your request to a `` number. 
e.g., [F-pg]: ../solutions/face_mesh#javascript-solution-api [Fd-pg]: ../solutions/face_detection#javascript-solution-api [H-pg]: ../solutions/hands#javascript-solution-api +[Ob-pg]: ../solutions/objectron#javascript-solution-api [P-pg]: ../solutions/pose#javascript-solution-api +[S-pg]: ../solutions/selfie_segmentation#javascript-solution-api [Ho-npm]: https://www.npmjs.com/package/@mediapipe/holistic [F-npm]: https://www.npmjs.com/package/@mediapipe/face_mesh [Fd-npm]: https://www.npmjs.com/package/@mediapipe/face_detection [H-npm]: https://www.npmjs.com/package/@mediapipe/hands +[Ob-npm]: https://www.npmjs.com/package/@mediapipe/objectron [P-npm]: https://www.npmjs.com/package/@mediapipe/pose +[S-npm]: https://www.npmjs.com/package/@mediapipe/selfie_segmentation [draw-npm]: https://www.npmjs.com/package/@mediapipe/drawing_utils [cam-npm]: https://www.npmjs.com/package/@mediapipe/camera_utils [ctrl-npm]: https://www.npmjs.com/package/@mediapipe/control_utils -[Ho-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/holistic -[F-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_mesh -[Fd-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_detection -[H-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/hands -[P-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/pose -[Ho-pen]: https://code.mediapipe.dev/codepen/holistic -[F-pen]: https://code.mediapipe.dev/codepen/face_mesh -[Fd-pen]: https://code.mediapipe.dev/codepen/face_detection -[H-pen]: https://code.mediapipe.dev/codepen/hands -[P-pen]: https://code.mediapipe.dev/codepen/pose [Ho-demo]: https://mediapipe.dev/demo/holistic [F-demo]: https://mediapipe.dev/demo/face_mesh [Fd-demo]: https://mediapipe.dev/demo/face_detection [H-demo]: https://mediapipe.dev/demo/hands +[Ob-demo]: https://mediapipe.dev/demo/objectron [P-demo]: https://mediapipe.dev/demo/pose +[S-demo]: https://mediapipe.dev/demo/selfie_segmentation [npm]: https://www.npmjs.com/package/@mediapipe [codepen]: https://code.mediapipe.dev/codepen diff --git a/docs/getting_started/python.md b/docs/getting_started/python.md index d59f35bbf..83550be84 100644 --- a/docs/getting_started/python.md +++ b/docs/getting_started/python.md @@ -51,6 +51,7 @@ details in each solution via the links below: * [MediaPipe Holistic](../solutions/holistic#python-solution-api) * [MediaPipe Objectron](../solutions/objectron#python-solution-api) * [MediaPipe Pose](../solutions/pose#python-solution-api) +* [MediaPipe Selfie Segmentation](../solutions/selfie_segmentation#python-solution-api) ## MediaPipe on Google Colab @@ -62,6 +63,7 @@ details in each solution via the links below: * [MediaPipe Pose Colab](https://mediapipe.page.link/pose_py_colab) * [MediaPipe Pose Classification Colab (Basic)](https://mediapipe.page.link/pose_classification_basic) * [MediaPipe Pose Classification Colab (Extended)](https://mediapipe.page.link/pose_classification_extended) +* [MediaPipe Selfie Segmentation Colab](https://mediapipe.page.link/selfie_segmentation_py_colab) ## MediaPipe Python Framework diff --git a/docs/getting_started/python_framework.md b/docs/getting_started/python_framework.md index ece14bc91..688285d87 100644 --- a/docs/getting_started/python_framework.md +++ b/docs/getting_started/python_framework.md @@ -74,7 +74,7 @@ Mapping\[str, Packet\] | std::map | create_st np.ndarray
(cv.mat and PIL.Image) | mp::ImageFrame | create_image_frame(
        format=ImageFormat.SRGB,
        data=mat) | get_image_frame(packet) np.ndarray | mp::Matrix | create_matrix(data) | get_matrix(packet) Google Proto Message | Google Proto Message | create_proto(proto) | get_proto(packet) -List\[Proto\] | std::vector\ | create_proto_vector(proto_list) | get_proto_list(packet) +List\[Proto\] | std::vector\ | n/a | get_proto_list(packet) It's not uncommon that users create custom C++ classes and and send those into the graphs and calculators. To allow the custom classes to be used in Python diff --git a/docs/images/import_mp_android_studio_project.png b/docs/images/import_mp_android_studio_project.png new file mode 100644 index 000000000..aa02b95ce Binary files /dev/null and b/docs/images/import_mp_android_studio_project.png differ diff --git a/docs/images/mobile/pose_segmentation.mp4 b/docs/images/mobile/pose_segmentation.mp4 new file mode 100644 index 000000000..e0a68da70 Binary files /dev/null and b/docs/images/mobile/pose_segmentation.mp4 differ diff --git a/docs/images/mobile/pose_tracking_pck_chart.png b/docs/images/mobile/pose_tracking_pck_chart.png index 8b781e630..1fa4bf97d 100644 Binary files a/docs/images/mobile/pose_tracking_pck_chart.png and b/docs/images/mobile/pose_tracking_pck_chart.png differ diff --git a/docs/images/mobile/pose_world_landmarks.mp4 b/docs/images/mobile/pose_world_landmarks.mp4 new file mode 100644 index 000000000..4a5bf3016 Binary files /dev/null and b/docs/images/mobile/pose_world_landmarks.mp4 differ diff --git a/docs/images/run_android_solution_app.png b/docs/images/run_android_solution_app.png new file mode 100644 index 000000000..aa21f3c24 Binary files /dev/null and b/docs/images/run_android_solution_app.png differ diff --git a/docs/images/run_create_win_symlinks.png b/docs/images/run_create_win_symlinks.png new file mode 100644 index 000000000..69b94b75f Binary files /dev/null and b/docs/images/run_create_win_symlinks.png differ diff --git a/docs/images/selfie_segmentation_web.mp4 b/docs/images/selfie_segmentation_web.mp4 new file mode 100644 index 000000000..d9e62838e Binary files /dev/null and b/docs/images/selfie_segmentation_web.mp4 differ diff --git a/docs/index.md b/docs/index.md index 9035bf106..86d6ddc5e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -40,11 +40,12 @@ Hair Segmentation [Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | [Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | [Holistic](https://google.github.io/mediapipe/solutions/holistic) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Selfie Segmentation](https://google.github.io/mediapipe/solutions/selfie_segmentation) | ✅ | ✅ | ✅ | ✅ | ✅ | [Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | | [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | ✅ | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | @@ -54,46 +55,22 @@ See also [MediaPipe Models and Model 
Cards](https://google.github.io/mediapipe/solutions/models) for ML models released in MediaPipe. -## MediaPipe in Python - -MediaPipe offers customizable Python solutions as a prebuilt Python package on -[PyPI](https://pypi.org/project/mediapipe/), which can be installed simply with -`pip install mediapipe`. It also provides tools for users to build their own -solutions. Please see -[MediaPipe in Python](https://google.github.io/mediapipe/getting_started/python) -for more info. - -## MediaPipe on the Web - -MediaPipe on the Web is an effort to run the same ML solutions built for mobile -and desktop also in web browsers. The official API is under construction, but -the core technology has been proven effective. Please see -[MediaPipe on the Web](https://developers.googleblog.com/2020/01/mediapipe-on-web.html) -in Google Developers Blog for details. - -You can use the following links to load a demo in the MediaPipe Visualizer, and -over there click the "Runner" icon in the top bar like shown below. The demos -use your webcam video as input, which is processed all locally in real-time and -never leaves your device. - -![visualizer_runner](images/visualizer_runner.png) - -* [MediaPipe Face Detection](https://viz.mediapipe.dev/demo/face_detection) -* [MediaPipe Iris](https://viz.mediapipe.dev/demo/iris_tracking) -* [MediaPipe Iris: Depth-from-Iris](https://viz.mediapipe.dev/demo/iris_depth) -* [MediaPipe Hands](https://viz.mediapipe.dev/demo/hand_tracking) -* [MediaPipe Hands (palm/hand detection only)](https://viz.mediapipe.dev/demo/hand_detection) -* [MediaPipe Pose](https://viz.mediapipe.dev/demo/pose_tracking) -* [MediaPipe Hair Segmentation](https://viz.mediapipe.dev/demo/hair_segmentation) - ## Getting started -Learn how to [install](https://google.github.io/mediapipe/getting_started/install) -MediaPipe and -[build example applications](https://google.github.io/mediapipe/getting_started/building_examples), -and start exploring our ready-to-use -[solutions](https://google.github.io/mediapipe/solutions/solutions) that you can -further extend and customize. +To start using MediaPipe +[solutions](https://google.github.io/mediapipe/solutions/solutions) with only a few +lines of code, see example code and demos in +[MediaPipe in Python](https://google.github.io/mediapipe/getting_started/python) and +[MediaPipe in JavaScript](https://google.github.io/mediapipe/getting_started/javascript). + +To use MediaPipe in C++, Android and iOS, which allow further customization of +the [solutions](https://google.github.io/mediapipe/solutions/solutions) as well as +building your own, learn how to +[install](https://google.github.io/mediapipe/getting_started/install) MediaPipe and +start building example applications in +[C++](https://google.github.io/mediapipe/getting_started/cpp), +[Android](https://google.github.io/mediapipe/getting_started/android) and +[iOS](https://google.github.io/mediapipe/getting_started/ios).
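To make the "few lines of code" claim above concrete, here is a minimal sketch using the prebuilt Python package. It assumes `pip install mediapipe opencv-python`; the image path is a placeholder, and the snippet simply mirrors the Face Detection Python usage documented later in this change.

```python
# Minimal sketch: run MediaPipe Face Detection on a single image.
# Assumes `pip install mediapipe opencv-python`; "image.jpg" is a placeholder path.
import cv2
import mediapipe as mp

mp_face_detection = mp.solutions.face_detection

image = cv2.imread("image.jpg")  # BGR image loaded with OpenCV
with mp_face_detection.FaceDetection(min_detection_confidence=0.5) as face_detection:
  # MediaPipe solutions expect RGB input.
  results = face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

print("Faces detected:", len(results.detections or []))
```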
The source code is hosted in the [MediaPipe Github repository](https://github.com/google/mediapipe), and you can @@ -102,6 +79,13 @@ run code search using ## Publications +* [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html) + in Google Developers Blog +* [Prosthesis control via Mirru App using MediaPipe hand tracking](https://developers.googleblog.com/2021/05/control-your-mirru-prosthesis-with-mediapipe-hand-tracking.html) + in Google Developers Blog +* [SignAll SDK: Sign language interface using MediaPipe is now available for + developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html) + in Google Developers Blog * [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) in Google AI Blog * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) diff --git a/docs/solutions/autoflip.md b/docs/solutions/autoflip.md index 0e118cc55..676abcae8 100644 --- a/docs/solutions/autoflip.md +++ b/docs/solutions/autoflip.md @@ -2,7 +2,7 @@ layout: default title: AutoFlip (Saliency-aware Video Cropping) parent: Solutions -nav_order: 13 +nav_order: 14 --- # AutoFlip: Saliency-aware Video Cropping diff --git a/docs/solutions/box_tracking.md b/docs/solutions/box_tracking.md index 0e7550e7f..b84a015d1 100644 --- a/docs/solutions/box_tracking.md +++ b/docs/solutions/box_tracking.md @@ -2,7 +2,7 @@ layout: default title: Box Tracking parent: Solutions -nav_order: 9 +nav_order: 10 --- # MediaPipe Box Tracking diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 8d5de36eb..9d08ee482 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -45,6 +45,15 @@ section. Naming style and availability may differ slightly across platforms/languages. +#### model_selection + +An integer index `0` or `1`. Use `0` to select a short-range model that works +best for faces within 2 meters from the camera, and `1` for a full-range model +best for faces within 5 meters. For the full-range option, a sparse model is +used for its improved inference speed. Please refer to the +[model cards](./models.md#face_detection) for details. Default to `0` if not +specified. + #### min_detection_confidence Minimum confidence value (`[0.0, 1.0]`) from the face detection model for the @@ -68,10 +77,11 @@ normalized to `[0.0, 1.0]` by the image width and height respectively. Please first follow general [instructions](../getting_started/python.md) to install MediaPipe Python package, then learn more in the companion -[Python Colab](#resources) and the following usage example. +[Python Colab](#resources) and the usage example below. Supported configuration options: +* [model_selection](#model_selection) * [min_detection_confidence](#min_detection_confidence) ```python @@ -81,9 +91,10 @@ mp_face_detection = mp.solutions.face_detection mp_drawing = mp.solutions.drawing_utils # For static images: +IMAGE_FILES = [] with mp_face_detection.FaceDetection( - min_detection_confidence=0.5) as face_detection: - for idx, file in enumerate(file_list): + model_selection=1, min_detection_confidence=0.5) as face_detection: + for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) # Convert the BGR image to RGB and process it with MediaPipe Face Detection. 
results = face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) @@ -102,7 +113,7 @@ with mp_face_detection.FaceDetection( # For webcam input: cap = cv2.VideoCapture(0) with mp_face_detection.FaceDetection( - min_detection_confidence=0.5) as face_detection: + model_selection=0, min_detection_confidence=0.5) as face_detection: while cap.isOpened(): success, image = cap.read() if not success: @@ -138,6 +149,7 @@ and the following usage example. Supported configuration options: +* [modelSelection](#model_selection) * [minDetectionConfidence](#min_detection_confidence) ```html @@ -188,6 +200,7 @@ const faceDetection = new FaceDetection({locateFile: (file) => { return `https://cdn.jsdelivr.net/npm/@mediapipe/face_detection@0.0/${file}`; }}); faceDetection.setOptions({ + modelSelection: 0 minDetectionConfidence: 0.5 }); faceDetection.onResults(onResults); @@ -254,10 +267,6 @@ same configuration as the GPU pipeline, runs entirely on CPU. * Target: [`mediapipe/examples/desktop/face_detection:face_detection_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/face_detection/BUILD) -### Web - -Please refer to [these instructions](../index.md#mediapipe-on-the-web). - ### Coral Please refer to diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index 0c620120c..a94785324 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -69,7 +69,7 @@ and renders using a dedicated The [face landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark_front_gpu.pbtxt) internally uses a -[face_detection_subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front_gpu.pbtxt) +[face_detection_subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt) from the [face detection module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection). @@ -265,7 +265,7 @@ magnitude of `z` uses roughly the same scale as `x`. Please first follow general [instructions](../getting_started/python.md) to install MediaPipe Python package, then learn more in the companion -[Python Colab](#resources) and the following usage example. +[Python Colab](#resources) and the usage example below. Supported configuration options: @@ -278,15 +278,17 @@ Supported configuration options: import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles mp_face_mesh = mp.solutions.face_mesh # For static images: +IMAGE_FILES = [] drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) with mp_face_mesh.FaceMesh( static_image_mode=True, max_num_faces=1, min_detection_confidence=0.5) as face_mesh: - for idx, file in enumerate(file_list): + for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) # Convert the BGR image to RGB before processing. 
results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) @@ -300,9 +302,17 @@ with mp_face_mesh.FaceMesh( mp_drawing.draw_landmarks( image=annotated_image, landmark_list=face_landmarks, - connections=mp_face_mesh.FACE_CONNECTIONS, - landmark_drawing_spec=drawing_spec, - connection_drawing_spec=drawing_spec) + connections=mp_face_mesh.FACEMESH_TESSELATION, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_tesselation_style()) + mp_drawing.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACEMESH_CONTOURS, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_contours_style()) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) # For webcam input: @@ -334,9 +344,17 @@ with mp_face_mesh.FaceMesh( mp_drawing.draw_landmarks( image=image, landmark_list=face_landmarks, - connections=mp_face_mesh.FACE_CONNECTIONS, - landmark_drawing_spec=drawing_spec, - connection_drawing_spec=drawing_spec) + connections=mp_face_mesh.FACEMESH_TESSELATION, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_tesselation_style()) + mp_drawing.draw_landmarks( + image=image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACEMESH_CONTOURS, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_contours_style()) cv2.imshow('MediaPipe FaceMesh', image) if cv2.waitKey(5) & 0xFF == 27: break @@ -422,6 +440,200 @@ camera.start(); ``` +### Android Solution API + +Please first follow general +[instructions](../getting_started/android_solutions.md#integrate-mediapipe-android-solutions-api) +to add MediaPipe Gradle dependencies, then try the FaceMash solution API in the +companion +[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facemesh) +following +[these instructions](../getting_started/android_solutions.md#build-solution-example-apps-in-android-studio) +and learn more in the usage example below. + +Supported configuration options: + +* [staticImageMode](#static_image_mode) +* [maxNumFaces](#max_num_faces) +* runOnGpu: Run the pipeline and the model inference on GPU or CPU. + +#### Camera Input + +```java +// For camera input and result rendering with OpenGL. +FaceMeshOptions faceMeshOptions = + FaceMeshOptions.builder() + .setMode(FaceMeshOptions.STREAMING_MODE) // API soon to become + .setMaxNumFaces(1) // setStaticImageMode(false) + .setRunOnGpu(true).build(); +FaceMesh facemesh = new FaceMesh(this, faceMeshOptions); +facemesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); + +// Initializes a new CameraInput instance and connects it to MediaPipe FaceMesh. +CameraInput cameraInput = new CameraInput(this); +cameraInput.setNewFrameListener( + textureFrame -> facemesh.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java +// as an example. 
+SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, facemesh.getGlContext(), facemesh.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new FaceMeshResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +facemesh.setResultListener( + faceMeshResult -> { + NormalizedLandmark noseLandmark = + result.multiFaceLandmarks().get(0).getLandmarkList().get(1); + Log.i( + TAG, + String.format( + "MediaPipe FaceMesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + noseLandmark.getX(), noseLandmark.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(faceMeshResult); + glSurfaceView.requestRender(); + }); + +// The runnable to start camera after the GLSurfaceView is attached. +glSurfaceView.post( + () -> + cameraInput.start( + this, + facemesh.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); +``` + +#### Image Input + +```java +// For reading images from gallery and drawing the output in an ImageView. +FaceMeshOptions faceMeshOptions = + FaceMeshOptions.builder() + .setMode(FaceMeshOptions.STATIC_IMAGE_MODE) // API soon to become + .setMaxNumFaces(1) // setStaticImageMode(true) + .setRunOnGpu(true).build(); +FaceMesh facemesh = new FaceMesh(this, faceMeshOptions); + +// Connects MediaPipe FaceMesh to the user-defined ImageView instance that allows +// users to have the custom drawing of the output landmarks on it. +// See mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java +// as an example. +FaceMeshResultImageView imageView = new FaceMeshResultImageView(this); +facemesh.setResultListener( + faceMeshResult -> { + int width = faceMeshResult.inputBitmap().getWidth(); + int height = faceMeshResult.inputBitmap().getHeight(); + NormalizedLandmark noseLandmark = + result.multiFaceLandmarks().get(0).getLandmarkList().get(1); + Log.i( + TAG, + String.format( + "MediaPipe FaceMesh nose coordinates (pixel values): x=%f, y=%f", + noseLandmark.getX() * width, noseLandmark.getY() * height)); + // Request canvas drawing. + imageView.setFaceMeshResult(faceMeshResult); + runOnUiThread(() -> imageView.update()); + }); +facemesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); + +// ActivityResultLauncher to get an image from the gallery as Bitmap. +ActivityResultLauncher imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null && result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData()); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + if (bitmap != null) { + facemesh.send(bitmap); + } + } + }); +Intent gallery = new Intent( + Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI); +imageGetter.launch(gallery); +``` + +#### Video Input + +```java +// For video input and result rendering with OpenGL. 
+FaceMeshOptions faceMeshOptions = + FaceMeshOptions.builder() + .setMode(FaceMeshOptions.STREAMING_MODE) // API soon to become + .setMaxNumFaces(1) // setStaticImageMode(false) + .setRunOnGpu(true).build(); +FaceMesh facemesh = new FaceMesh(this, faceMeshOptions); +facemesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); + +// Initializes a new VideoInput instance and connects it to MediaPipe FaceMesh. +VideoInput videoInput = new VideoInput(this); +videoInput.setNewFrameListener( + textureFrame -> facemesh.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java +// as an example. +SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, facemesh.getGlContext(), facemesh.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new FaceMeshResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +facemesh.setResultListener( + faceMeshResult -> { + NormalizedLandmark noseLandmark = + result.multiFaceLandmarks().get(0).getLandmarkList().get(1); + Log.i( + TAG, + String.format( + "MediaPipe FaceMesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + noseLandmark.getX(), noseLandmark.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(faceMeshResult); + glSurfaceView.requestRender(); + }); + +ActivityResultLauncher videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + facemesh.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); +Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Video.Media.INTERNAL_CONTENT_URI); +videoGetter.launch(gallery); +``` + ## Example Apps Please first see general instructions for diff --git a/docs/solutions/hair_segmentation.md b/docs/solutions/hair_segmentation.md index 5e2e4a7c5..9dd997b95 100644 --- a/docs/solutions/hair_segmentation.md +++ b/docs/solutions/hair_segmentation.md @@ -2,7 +2,7 @@ layout: default title: Hair Segmentation parent: Solutions -nav_order: 7 +nav_order: 8 --- # MediaPipe Hair Segmentation @@ -51,7 +51,14 @@ to visualize its associated subgraphs, please see ### Web -Please refer to [these instructions](../index.md#mediapipe-on-the-web). +Use [this link](https://viz.mediapipe.dev/demo/hair_segmentation) to load a demo +in the MediaPipe Visualizer, and over there click the "Runner" icon in the top +bar like shown below. The demos use your webcam video as input, which is +processed all locally in real-time and never leaves your device. Please see +[MediaPipe on the Web](https://developers.googleblog.com/2020/01/mediapipe-on-web.html) +in Google Developers Blog for details. + +![visualizer_runner](../images/visualizer_runner.png) ## Resources diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md index ac10124f2..c3088d64c 100644 --- a/docs/solutions/hands.md +++ b/docs/solutions/hands.md @@ -206,7 +206,7 @@ is not the case, please swap the handedness output in the application. 
Please first follow general [instructions](../getting_started/python.md) to install MediaPipe Python package, then learn more in the companion -[Python Colab](#resources) and the following usage example. +[Python Colab](#resources) and the usage example below. Supported configuration options: @@ -219,14 +219,16 @@ Supported configuration options: import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles mp_hands = mp.solutions.hands # For static images: +IMAGE_FILES = [] with mp_hands.Hands( static_image_mode=True, max_num_hands=2, min_detection_confidence=0.5) as hands: - for idx, file in enumerate(file_list): + for idx, file in enumerate(IMAGE_FILES): # Read an image, flip it around y-axis for correct handedness output (see # above). image = cv2.flip(cv2.imread(file), 1) @@ -247,7 +249,11 @@ with mp_hands.Hands( f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})' ) mp_drawing.draw_landmarks( - annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS) + annotated_image, + hand_landmarks, + mp_hands.HAND_CONNECTIONS, + mp_drawing_styles.get_default_hand_landmarks_style(), + mp_drawing_styles.get_default_hand_connections_style()) cv2.imwrite( '/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1)) @@ -277,7 +283,11 @@ with mp_hands.Hands( if results.multi_hand_landmarks: for hand_landmarks in results.multi_hand_landmarks: mp_drawing.draw_landmarks( - image, hand_landmarks, mp_hands.HAND_CONNECTIONS) + image, + hand_landmarks, + mp_hands.HAND_CONNECTIONS, + mp_drawing_styles.get_default_hand_landmarks_style(), + mp_drawing_styles.get_default_hand_connections_style()) cv2.imshow('MediaPipe Hands', image) if cv2.waitKey(5) & 0xFF == 27: break @@ -358,6 +368,200 @@ camera.start(); ``` +### Android Solution API + +Please first follow general +[instructions](../getting_started/android_solutions.md#integrate-mediapipe-android-solutions-api) +to add MediaPipe Gradle dependencies, then try the Hands solution API in the +companion +[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/hands) +following +[these instructions](../getting_started/android_solutions.md#build-solution-example-apps-in-android-studio) +and learn more in usage example below. + +Supported configuration options: + +* [staticImageMode](#static_image_mode) +* [maxNumHands](#max_num_hands) +* runOnGpu: Run the pipeline and the model inference on GPU or CPU. + +#### Camera Input + +```java +// For camera input and result rendering with OpenGL. +HandsOptions handsOptions = + HandsOptions.builder() + .setMode(HandsOptions.STREAMING_MODE) // API soon to become + .setMaxNumHands(1) // setStaticImageMode(false) + .setRunOnGpu(true).build(); +Hands hands = new Hands(this, handsOptions); +hands.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + +// Initializes a new CameraInput instance and connects it to MediaPipe Hands. +CameraInput cameraInput = new CameraInput(this); +cameraInput.setNewFrameListener( + textureFrame -> hands.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java +// as an example. 
+SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, hands.getGlContext(), hands.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new HandsResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +hands.setResultListener( + handsResult -> { + NormalizedLandmark wristLandmark = Hands.getHandLandmark( + handsResult, 0, HandLandmark.WRIST); + Log.i( + TAG, + String.format( + "MediaPipe Hand wrist normalized coordinates (value range: [0, 1]): x=%f, y=%f", + wristLandmark.getX(), wristLandmark.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(handsResult); + glSurfaceView.requestRender(); + }); + +// The runnable to start camera after the GLSurfaceView is attached. +glSurfaceView.post( + () -> + cameraInput.start( + this, + hands.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); +``` + +#### Image Input + +```java +// For reading images from gallery and drawing the output in an ImageView. +HandsOptions handsOptions = + HandsOptions.builder() + .setMode(HandsOptions.STATIC_IMAGE_MODE) // API soon to become + .setMaxNumHands(1) // setStaticImageMode(true) + .setRunOnGpu(true).build(); +Hands hands = new Hands(this, handsOptions); + +// Connects MediaPipe Hands to the user-defined ImageView instance that allows +// users to have the custom drawing of the output landmarks on it. +// See mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java +// as an example. +HandsResultImageView imageView = new HandsResultImageView(this); +hands.setResultListener( + handsResult -> { + int width = handsResult.inputBitmap().getWidth(); + int height = handsResult.inputBitmap().getHeight(); + NormalizedLandmark wristLandmark = Hands.getHandLandmark( + handsResult, 0, HandLandmark.WRIST); + Log.i( + TAG, + String.format( + "MediaPipe Hand wrist coordinates (pixel values): x=%f, y=%f", + wristLandmark.getX() * width, wristLandmark.getY() * height)); + // Request canvas drawing. + imageView.setHandsResult(handsResult); + runOnUiThread(() -> imageView.update()); + }); +hands.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + +// ActivityResultLauncher to get an image from the gallery as Bitmap. +ActivityResultLauncher imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null && result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData()); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + if (bitmap != null) { + hands.send(bitmap); + } + } + }); +Intent gallery = new Intent( + Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI); +imageGetter.launch(gallery); +``` + +#### Video Input + +```java +// For video input and result rendering with OpenGL. +HandsOptions handsOptions = + HandsOptions.builder() + .setMode(HandsOptions.STREAMING_MODE) // API soon to become + .setMaxNumHands(1) // setStaticImageMode(false) + .setRunOnGpu(true).build(); +Hands hands = new Hands(this, handsOptions); +hands.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + +// Initializes a new VideoInput instance and connects it to MediaPipe Hands. 
+VideoInput videoInput = new VideoInput(this); +videoInput.setNewFrameListener( + textureFrame -> hands.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java +// as an example. +SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, hands.getGlContext(), hands.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new HandsResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +hands.setResultListener( + handsResult -> { + NormalizedLandmark wristLandmark = Hands.getHandLandmark( + handsResult, 0, HandLandmark.WRIST); + Log.i( + TAG, + String.format( + "MediaPipe Hand wrist normalized coordinates (value range: [0, 1]): x=%f, y=%f", + wristLandmark.getX(), wristLandmark.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(handsResult); + glSurfaceView.requestRender(); + }); + +ActivityResultLauncher videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + hands.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); +Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Video.Media.INTERNAL_CONTENT_URI); +videoGetter.launch(gallery); +``` + ## Example Apps Please first see general instructions for diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 7c02c8d75..0532a33dd 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -176,6 +176,16 @@ A list of pose landmarks. Each landmark consists of the following: * `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the landmark being visible (present and not occluded) in the image. +#### pose_world_landmarks + +Another list of pose landmarks in world coordinates. Each landmark consists of +the following: + +* `x`, `y` and `z`: Real-world 3D coordinates in meters with the origin at the + center between hips. +* `visibility`: Identical to that defined in the corresponding + [pose_landmarks](#pose_landmarks). + #### face_landmarks A list of 468 face landmarks. Each landmark consists of `x`, `y` and `z`. `x` @@ -201,7 +211,7 @@ A list of 21 hand landmarks on the right hand, in the same representation as Please first follow general [instructions](../getting_started/python.md) to install MediaPipe Python package, then learn more in the companion -[Python Colab](#resources) and the following usage example. +[Python Colab](#resources) and the usage example below. Supported configuration options: @@ -215,13 +225,15 @@ Supported configuration options: import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles mp_holistic = mp.solutions.holistic # For static images: +IMAGE_FILES = [] with mp_holistic.Holistic( static_image_mode=True, model_complexity=2) as holistic: - for idx, file in enumerate(file_list): + for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) image_height, image_width, _ = image.shape # Convert the BGR image to RGB before processing. 
@@ -236,14 +248,22 @@ with mp_holistic.Holistic( # Draw pose, left and right hands, and face landmarks on the image. annotated_image = image.copy() mp_drawing.draw_landmarks( - annotated_image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) + annotated_image, + results.face_landmarks, + mp_holistic.FACEMESH_TESSELATION, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_tesselation_style()) mp_drawing.draw_landmarks( - annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) + annotated_image, + results.pose_landmarks, + mp_holistic.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing_styles. + get_default_pose_landmarks_style()) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + # Plot pose world landmarks. + mp_drawing.plot_landmarks( + results.pose_world_landmarks, mp_holistic.POSE_CONNECTIONS) # For webcam input: cap = cv2.VideoCapture(0) @@ -269,13 +289,18 @@ with mp_holistic.Holistic( image.flags.writeable = True image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) mp_drawing.draw_landmarks( - image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) + image, + results.face_landmarks, + mp_holistic.FACEMESH_CONTOURS, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_contours_style()) mp_drawing.draw_landmarks( - image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) + image, + results.pose_landmarks, + mp_holistic.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing_styles + .get_default_pose_landmarks_style()) cv2.imshow('MediaPipe Holistic', image) if cv2.waitKey(5) & 0xFF == 27: break diff --git a/docs/solutions/instant_motion_tracking.md b/docs/solutions/instant_motion_tracking.md index 36e5e83e0..9fea7ec1c 100644 --- a/docs/solutions/instant_motion_tracking.md +++ b/docs/solutions/instant_motion_tracking.md @@ -2,7 +2,7 @@ layout: default title: Instant Motion Tracking parent: Solutions -nav_order: 10 +nav_order: 11 --- # MediaPipe Instant Motion Tracking diff --git a/docs/solutions/iris.md b/docs/solutions/iris.md index 61ca8049c..af71c895f 100644 --- a/docs/solutions/iris.md +++ b/docs/solutions/iris.md @@ -69,7 +69,7 @@ and renders using a dedicated The [face landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark_front_gpu.pbtxt) internally uses a -[face detection subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front_gpu.pbtxt) +[face detection subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt) from the [face detection module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection). @@ -193,7 +193,17 @@ on how to build MediaPipe examples. ### Web -Please refer to [these instructions](../index.md#mediapipe-on-the-web). +You can use the following links to load a demo in the MediaPipe Visualizer, and +over there click the "Runner" icon in the top bar like shown below. 
The demos +use your webcam video as input, which is processed all locally in real-time and +never leaves your device. Please see +[MediaPipe on the Web](https://developers.googleblog.com/2020/01/mediapipe-on-web.html) +in Google Developers Blog for details. + +![visualizer_runner](../images/visualizer_runner.png) + +* [MediaPipe Iris](https://viz.mediapipe.dev/demo/iris_tracking) +* [MediaPipe Iris: Depth-from-Iris](https://viz.mediapipe.dev/demo/iris_depth) ## Resources diff --git a/docs/solutions/knift.md b/docs/solutions/knift.md index 41691c418..b008f1496 100644 --- a/docs/solutions/knift.md +++ b/docs/solutions/knift.md @@ -2,7 +2,7 @@ layout: default title: KNIFT (Template-based Feature Matching) parent: Solutions -nav_order: 12 +nav_order: 13 --- # MediaPipe KNIFT diff --git a/docs/solutions/media_sequence.md b/docs/solutions/media_sequence.md index cd3b7ecef..e6bd5fd44 100644 --- a/docs/solutions/media_sequence.md +++ b/docs/solutions/media_sequence.md @@ -2,7 +2,7 @@ layout: default title: Dataset Preparation with MediaSequence parent: Solutions -nav_order: 14 +nav_order: 15 --- # Dataset Preparation with MediaSequence diff --git a/docs/solutions/models.md b/docs/solutions/models.md index e0ff4d14a..2f3001722 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -14,12 +14,27 @@ nav_order: 30 ### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection) -* Face detection model for front-facing/selfie camera: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite), - [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite) -* Face detection model for back-facing camera: - [TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_back.tflite) -* [Model card](https://mediapipe.page.link/blazeface-mc) +* Short-range model (best for faces within 2 meters from the camera): + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_short_range.tflite), + [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite), + [Model card](https://mediapipe.page.link/blazeface-mc) +* Full-range model (dense, best for faces within 5 meters from the camera): + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_full_range.tflite), + [Model card](https://mediapipe.page.link/blazeface-back-mc) +* Full-range model (sparse, best for faces within 5 meters from the camera): + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite), + [Model card](https://mediapipe.page.link/blazeface-back-sparse-mc) + +Full-range dense and sparse models have the same quality in terms of +[F-score](https://en.wikipedia.org/wiki/F-score) however differ in underlying +metrics. The dense model is slightly better in +[Recall](https://en.wikipedia.org/wiki/Precision_and_recall) whereas the sparse +model outperforms the dense one in +[Precision](https://en.wikipedia.org/wiki/Precision_and_recall). Speed-wise +sparse model is ~30% faster when executing on CPU via +[XNNPACK](https://github.com/google/XNNPACK) whereas on GPU the models +demonstrate comparable latencies. 
Depending on your application, you may prefer +one over the other. ### [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) @@ -60,6 +75,12 @@ nav_order: 30 * Hand recrop model: [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/holistic_landmark/hand_recrop.tflite) +### [Selfie Segmentation](https://google.github.io/mediapipe/solutions/selfie_segmentation) + +* [TFLite model (general)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/selfie_segmentation/selfie_segmentation.tflite) +* [TFLite model (landscape)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/selfie_segmentation/selfie_segmentation_landscape.tflite) +* [Model card](https://mediapipe.page.link/selfiesegmentation-mc) + ### [Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) * [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/hair_segmentation.tflite) diff --git a/docs/solutions/object_detection.md b/docs/solutions/object_detection.md index 044748537..d7cc2cec1 100644 --- a/docs/solutions/object_detection.md +++ b/docs/solutions/object_detection.md @@ -2,7 +2,7 @@ layout: default title: Object Detection parent: Solutions -nav_order: 8 +nav_order: 9 --- # MediaPipe Object Detection diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index 0164e23b3..d7dc8f045 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -2,7 +2,7 @@ layout: default title: Objectron (3D Object Detection) parent: Solutions -nav_order: 11 +nav_order: 12 --- # MediaPipe Objectron @@ -224,29 +224,33 @@ where object detection simply runs on every image. Default to `0.99`. #### model_name -Name of the model to use for predicting 3D bounding box landmarks. Currently supports -`{'Shoe', 'Chair', 'Cup', 'Camera'}`. +Name of the model to use for predicting 3D bounding box landmarks. Currently +supports `{'Shoe', 'Chair', 'Cup', 'Camera'}`. Default to `Shoe`. #### focal_length -Camera focal length `(fx, fy)`, by default is defined in -[NDC space](#ndc-space). To use focal length `(fx_pixel, fy_pixel)` in -[pixel space](#pixel-space), users should provide `image_size` = `(image_width, -image_height)` to enable conversions inside the API. For further details about -NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). +By default, camera focal length defined in [NDC space](#ndc-space), i.e., `(fx, +fy)`. Default to `(1.0, 1.0)`. To specify focal length in +[pixel space](#pixel-space) instead, i.e., `(fx_pixel, fy_pixel)`, users should +provide [`image_size`](#image_size) = `(image_width, image_height)` to enable +conversions inside the API. For further details about NDC and pixel space, +please see [Coordinate Systems](#coordinate-systems). #### principal_point -Camera principal point `(px, py)`, by default is defined in -[NDC space](#ndc-space). To use principal point `(px_pixel, py_pixel)` in -[pixel space](#pixel-space), users should provide `image_size` = `(image_width, -image_height)` to enable conversions inside the API. For further details about -NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). +By default, camera principal point defined in [NDC space](#ndc-space), i.e., +`(px, py)`. Default to `(0.0, 0.0)`. To specify principal point in +[pixel space](#pixel-space), i.e.,`(px_pixel, py_pixel)`, users should provide +[`image_size`](#image_size) = `(image_width, image_height)` to enable +conversions inside the API. 
For further details about NDC and pixel space, +please see [Coordinate Systems](#coordinate-systems). #### image_size -(**Optional**) size `(image_width, image_height)` of the input image, **ONLY** -needed when use `focal_length` and `principal_point` in pixel space. +**Specify only when [`focal_length`](#focal_length) and +[`principal_point`](#principal_point) are specified in pixel space.** + +Size of the input image, i.e., `(image_width, image_height)`. ### Output @@ -277,7 +281,7 @@ following: Please first follow general [instructions](../getting_started/python.md) to install MediaPipe Python package, then learn more in the companion -[Python Colab](#resources) and the following usage example. +[Python Colab](#resources) and the usage example below. Supported configuration options: @@ -297,11 +301,12 @@ mp_drawing = mp.solutions.drawing_utils mp_objectron = mp.solutions.objectron # For static images: +IMAGE_FILES = [] with mp_objectron.Objectron(static_image_mode=True, max_num_objects=5, min_detection_confidence=0.5, model_name='Shoe') as objectron: - for idx, file in enumerate(file_list): + for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) # Convert the BGR image to RGB and process it with MediaPipe Objectron. results = objectron.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) @@ -355,6 +360,89 @@ with mp_objectron.Objectron(static_image_mode=False, cap.release() ``` +## JavaScript Solution API + +Please first see general [introduction](../getting_started/javascript.md) on +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. + +Supported configuration options: + +* [staticImageMode](#static_image_mode) +* [maxNumObjects](#max_num_objects) +* [minDetectionConfidence](#min_detection_confidence) +* [minTrackingConfidence](#min_tracking_confidence) +* [modelName](#model_name) +* [focalLength](#focal_length) +* [principalPoint](#principal_point) +* [imageSize](#image_size) + +```html + + + + + + + + + + + + +
+ + +
+ + +``` + +```javascript + +``` + ## Example Apps Please first see general instructions for @@ -441,7 +529,7 @@ Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](http > ``` > and then run > -> ```build +> ```bash > bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR] > ``` > INPUT_DIR should be the folder with initial asset .obj files to be processed, @@ -560,11 +648,15 @@ py = -py_pixel * 2.0 / image_height + 1.0 [Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html) * Google AI Blog: [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html) -* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in CVPR 2021 +* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the + Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in + CVPR 2021 * Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak Shape Supervision](https://arxiv.org/abs/2003.03522) * Paper: [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8) - ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth Workshop on Computer Vision for AR/VR, CVPR 2020 + ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth + Workshop on Computer Vision for AR/VR, CVPR 2020 * [Models and model cards](./models.md#objectron) +* [Web demo](https://code.mediapipe.dev/codepen/objectron) * [Python Colab](https://mediapipe.page.link/objectron_py_colab) diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index feed2ad34..271199bb5 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -30,7 +30,8 @@ overlay of digital content and information on top of the physical world in augmented reality. MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring -33 3D landmarks on the whole body from RGB video frames utilizing our +33 3D landmarks and background segmentation mask on the whole body from RGB +video frames utilizing our [BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) research that also powers the [ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection). @@ -49,11 +50,11 @@ The solution utilizes a two-step detector-tracker ML pipeline, proven to be effective in our [MediaPipe Hands](./hands.md) and [MediaPipe Face Mesh](./face_mesh.md) solutions. Using a detector, the pipeline first locates the person/pose region-of-interest (ROI) within the frame. The -tracker subsequently predicts the pose landmarks within the ROI using the -ROI-cropped frame as input. Note that for video use cases the detector is -invoked only as needed, i.e., for the very first frame and when the tracker -could no longer identify body pose presence in the previous frame. For other -frames the pipeline simply derives the ROI from the previous frame’s pose +tracker subsequently predicts the pose landmarks and segmentation mask within +the ROI using the ROI-cropped frame as input. Note that for video use cases the +detector is invoked only as needed, i.e., for the very first frame and when the +tracker could no longer identify body pose presence in the previous frame. 
For +other frames the pipeline simply derives the ROI from the previous frame’s pose landmarks. The pipeline is implemented as a MediaPipe @@ -87,11 +88,11 @@ from [COCO topology](https://cocodataset.org/#keypoints-2020). Method | Yoga
[`mAP`] | Yoga
[`PCK@0.2`] | Dance
[`mAP`] | Dance
[`PCK@0.2`] | HIIT
[`mAP`] | HIIT
[`PCK@0.2`] ----------------------------------------------------------------------------------------------------- | -----------------: | ---------------------: | ------------------: | ----------------------: | -----------------: | ---------------------: -BlazePose.Heavy | 68.1 | **96.4** | 73.0 | **97.2** | 74.0 | **97.5** -BlazePose.Full | 62.6 | **95.5** | 67.4 | **96.3** | 68.0 | **95.7** -BlazePose.Lite | 45.0 | **90.2** | 53.6 | **92.5** | 53.8 | **93.5** -[AlphaPose.ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 63.4 | **96.0** | 57.8 | **95.5** | 63.4 | **96.0** -[Apple.Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 32.8 | **82.7** | 36.4 | **91.4** | 44.5 | **88.6** +BlazePose GHUM Heavy | 68.1 | **96.4** | 73.0 | **97.2** | 74.0 | **97.5** +BlazePose GHUM Full | 62.6 | **95.5** | 67.4 | **96.3** | 68.0 | **95.7** +BlazePose GHUM Lite | 45.0 | **90.2** | 53.6 | **92.5** | 53.8 | **93.5** +[AlphaPose ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 63.4 | **96.0** | 57.8 | **95.5** | 63.4 | **96.0** +[Apple Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 32.8 | **82.7** | 36.4 | **91.4** | 44.5 | **88.6** ![pose_tracking_pck_chart.png](../images/mobile/pose_tracking_pck_chart.png) | :--------------------------------------------------------------------------: | @@ -100,11 +101,11 @@ BlazePose.Lite We designed our models specifically for live perception use cases, so all of them work in real-time on the majority of modern devices. -Method | Latency
Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | Latency
MacBook Pro (15-inch 2017) ---------------- | -------------------------------------------------------------------------------------------: | ---------------------------------------: -BlazePose.Heavy | 53 ms | 38 ms -BlazePose.Full | 25 ms | 27 ms -BlazePose.Lite | 20 ms | 25 ms +Method | Latency
Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | Latency
MacBook Pro (15-inch 2017) +-------------------- | -------------------------------------------------------------------------------------------: | ---------------------------------------: +BlazePose GHUM Heavy | 53 ms | 38 ms +BlazePose GHUM Full | 25 ms | 27 ms +BlazePose GHUM Lite | 20 ms | 25 ms ## Models @@ -129,16 +130,19 @@ hip midpoints. The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks (see figure below). -Please find more detail in the -[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), -this [paper](https://arxiv.org/abs/2006.10204) and -[the model card](./models.md#pose), and the attributes in each landmark -[below](#pose_landmarks). - ![pose_tracking_full_body_landmarks.png](../images/mobile/pose_tracking_full_body_landmarks.png) | :----------------------------------------------------------------------------------------------: | *Fig 4. 33 pose landmarks.* | +Optionally, MediaPipe Pose can predicts a full-body +[segmentation mask](#segmentation_mask) represented as a two-class segmentation +(human or background). + +Please find more detail in the +[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), +this [paper](https://arxiv.org/abs/2006.10204), +[the model card](./models.md#pose) and the [Output](#output) section below. + ## Solution APIs ### Cross-platform Configuration Options @@ -167,6 +171,18 @@ If set to `true`, the solution filters pose landmarks across different input images to reduce jitter, but ignored if [static_image_mode](#static_image_mode) is also set to `true`. Default to `true`. +#### enable_segmentation + +If set to `true`, in addition to the pose landmarks the solution also generates +the segmentation mask. Default to `false`. + +#### smooth_segmentation + +If set to `true`, the solution filters segmentation masks across different input +images to reduce jitter. Ignored if [enable_segmentation](#enable_segmentation) +is `false` or [static_image_mode](#static_image_mode) is `true`. Default to +`true`. + #### min_detection_confidence Minimum confidence value (`[0.0, 1.0]`) from the person-detection model for the @@ -187,28 +203,56 @@ Naming style may differ slightly across platforms/languages. #### pose_landmarks -A list of pose landmarks. Each lanmark consists of the following: +A list of pose landmarks. Each landmark consists of the following: * `x` and `y`: Landmark coordinates normalized to `[0.0, 1.0]` by the image width and height respectively. * `z`: Represents the landmark depth with the depth at the midpoint of hips being the origin, and the smaller the value the closer the landmark is to the camera. The magnitude of `z` uses roughly the same scale as `x`. - * `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the landmark being visible (present and not occluded) in the image. +#### pose_world_landmarks + +*Fig 5. Example of MediaPipe Pose real-world 3D coordinates.* | +:-----------------------------------------------------------: | + | + +Another list of pose landmarks in world coordinates. Each landmark consists of +the following: + +* `x`, `y` and `z`: Real-world 3D coordinates in meters with the origin at the + center between hips. +* `visibility`: Identical to that defined in the corresponding + [pose_landmarks](#pose_landmarks). + +#### segmentation_mask + +The output segmentation mask, predicted only when +[enable_segmentation](#enable_segmentation) is set to `true`. 
The mask has the +same width and height as the input image, and contains values in `[0.0, 1.0]` +where `1.0` and `0.0` indicate high certainty of a "human" and "background" +pixel respectively. Please refer to the platform-specific usage examples below +for usage details. + +*Fig 6. Example of MediaPipe Pose segmentation mask.* | +:---------------------------------------------------: | + | + ### Python Solution API Please first follow general [instructions](../getting_started/python.md) to install MediaPipe Python package, then learn more in the companion -[Python Colab](#resources) and the following usage example. +[Python Colab](#resources) and the usage example below. Supported configuration options: * [static_image_mode](#static_image_mode) * [model_complexity](#model_complexity) * [smooth_landmarks](#smooth_landmarks) +* [enable_segmentation](#enable_segmentation) +* [smooth_segmentation](#smooth_segmentation) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -216,14 +260,18 @@ Supported configuration options: import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles mp_pose = mp.solutions.pose # For static images: +IMAGE_FILES = [] +BG_COLOR = (192, 192, 192) # gray with mp_pose.Pose( static_image_mode=True, model_complexity=2, + enable_segmentation=True, min_detection_confidence=0.5) as pose: - for idx, file in enumerate(file_list): + for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) image_height, image_width, _ = image.shape # Convert the BGR image to RGB before processing. @@ -233,14 +281,28 @@ with mp_pose.Pose( continue print( f'Nose coordinates: (' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' + f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].x * image_width}, ' + f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].y * image_height})' ) - # Draw pose landmarks on the image. + annotated_image = image.copy() + # Draw segmentation on the image. + # To improve segmentation around boundaries, consider applying a joint + # bilateral filter to "results.segmentation_mask" with "image". + condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1 + bg_image = np.zeros(image.shape, dtype=np.uint8) + bg_image[:] = BG_COLOR + annotated_image = np.where(condition, annotated_image, bg_image) + # Draw pose landmarks on the image. mp_drawing.draw_landmarks( - annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) + annotated_image, + results.pose_landmarks, + mp_pose.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style()) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + # Plot pose world landmarks. 
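+    # (plot_landmarks uses matplotlib to open an interactive 3D view of the
+    # hip-centered world landmarks in meters; matplotlib must be installed,
+    # and the call blocks until the plot window is closed.)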
+ mp_drawing.plot_landmarks( + results.pose_world_landmarks, mp_pose.POSE_CONNECTIONS) # For webcam input: cap = cv2.VideoCapture(0) @@ -266,7 +328,10 @@ with mp_pose.Pose( image.flags.writeable = True image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) mp_drawing.draw_landmarks( - image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) + image, + results.pose_landmarks, + mp_pose.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style()) cv2.imshow('MediaPipe Pose', image) if cv2.waitKey(5) & 0xFF == 27: break @@ -283,6 +348,8 @@ Supported configuration options: * [modelComplexity](#model_complexity) * [smoothLandmarks](#smooth_landmarks) +* [enableSegmentation](#enable_segmentation) +* [smoothSegmentation](#smooth_segmentation) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -293,6 +360,7 @@ Supported configuration options: + @@ -301,6 +369,7 @@ Supported configuration options:
+
@@ -311,17 +380,38 @@ Supported configuration options: const videoElement = document.getElementsByClassName('input_video')[0]; const canvasElement = document.getElementsByClassName('output_canvas')[0]; const canvasCtx = canvasElement.getContext('2d'); +const landmarkContainer = document.getElementsByClassName('landmark-grid-container')[0]; +const grid = new LandmarkGrid(landmarkContainer); function onResults(results) { + if (!results.poseLandmarks) { + grid.updateLandmarks([]); + return; + } + canvasCtx.save(); canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height); + canvasCtx.drawImage(results.segmentationMask, 0, 0, + canvasElement.width, canvasElement.height); + + // Only overwrite existing pixels. + canvasCtx.globalCompositeOperation = 'source-in'; + canvasCtx.fillStyle = '#00FF00'; + canvasCtx.fillRect(0, 0, canvasElement.width, canvasElement.height); + + // Only overwrite missing pixels. + canvasCtx.globalCompositeOperation = 'destination-atop'; canvasCtx.drawImage( results.image, 0, 0, canvasElement.width, canvasElement.height); + + canvasCtx.globalCompositeOperation = 'source-over'; drawConnectors(canvasCtx, results.poseLandmarks, POSE_CONNECTIONS, {color: '#00FF00', lineWidth: 4}); drawLandmarks(canvasCtx, results.poseLandmarks, {color: '#FF0000', lineWidth: 2}); canvasCtx.restore(); + + grid.updateLandmarks(results.poseWorldLandmarks); } const pose = new Pose({locateFile: (file) => { @@ -330,6 +420,8 @@ const pose = new Pose({locateFile: (file) => { pose.setOptions({ modelComplexity: 1, smoothLandmarks: true, + enableSegmentation: true, + smoothSegmentation: true, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 }); diff --git a/docs/solutions/selfie_segmentation.md b/docs/solutions/selfie_segmentation.md new file mode 100644 index 000000000..2cb155fb3 --- /dev/null +++ b/docs/solutions/selfie_segmentation.md @@ -0,0 +1,290 @@ +--- +layout: default +title: Selfie Segmentation +parent: Solutions +nav_order: 7 +--- + +# MediaPipe Selfie Segmentation +{: .no_toc } + +
+ + Table of contents + + {: .text-delta } +1. TOC +{:toc} +
+--- + +## Overview + +*Fig 1. Example of MediaPipe Selfie Segmentation.* | +:------------------------------------------------: | + | + +MediaPipe Selfie Segmentation segments the prominent humans in the scene. It can +run in real-time on both smartphones and laptops. The intended use cases include +selfie effects and video conferencing, where the person is close (< 2m) to the +camera. + +## Models + +In this solution, we provide two models: general and landscape. Both models are +based on +[MobileNetV3](https://ai.googleblog.com/2019/11/introducing-next-generation-on-device.html), +with modifications to make them more efficient. The general model operates on a +256x256x3 (HWC) tensor, and outputs a 256x256x1 tensor representing the +segmentation mask. The landscape model is similar to the general model, but +operates on a 144x256x3 (HWC) tensor. It has fewer FLOPs than the general model, +and therefore, runs faster. Note that MediaPipe Selfie Segmentation +automatically resizes the input image to the desired tensor dimension before +feeding it into the ML models. + +The general model is also powering +[ML Kit](https://developers.google.com/ml-kit/vision/selfie-segmentation), and a +variant of the landscape model is powering +[Google Meet](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html). +Please find more detail about the models in the +[model card](./models.md#selfie-segmentation). + +## ML Pipeline + +The pipeline is implemented as a MediaPipe +[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/selfie_segmentation/selfie_segmentation_gpu.pbtxt) +that uses a +[selfie segmentation subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu.pbtxt) +from the +[selfie segmentation module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/selfie_segmentation). + +Note: To visualize a graph, copy the graph and paste it into +[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how +to visualize its associated subgraphs, please see +[visualizer documentation](../tools/visualizer.md). + +## Solution APIs + +### Cross-platform Configuration Options + +Naming style and availability may differ slightly across platforms/languages. + +#### model_selection + +An integer index `0` or `1`. Use `0` to select the general model, and `1` to +select the landscape model (see details in [Models](#models)). Default to `0` if +not specified. + +### Output + +Naming style may differ slightly across platforms/languages. + +#### segmentation_mask + +The output segmentation mask, which has the same dimension as the input image. + +### Python Solution API + +Please first follow general [instructions](../getting_started/python.md) to +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the usage example below. + +Supported configuration options: + +* [model_selection](#model_selection) + +```python +import cv2 +import mediapipe as mp +import numpy as np +mp_drawing = mp.solutions.drawing_utils +mp_selfie_segmentation = mp.solutions.selfie_segmentation + +# For static images: +IMAGE_FILES = [] +BG_COLOR = (192, 192, 192) # gray +MASK_COLOR = (255, 255, 255) # white +with mp_selfie_segmentation.SelfieSegmentation( + model_selection=0) as selfie_segmentation: + for idx, file in enumerate(IMAGE_FILES): + image = cv2.imread(file) + image_height, image_width, _ = image.shape + # Convert the BGR image to RGB before processing. 
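+    # (process() uses the array as given, so skipping this conversion would
+    # not raise an error, but the model would see swapped color channels and
+    # the mask quality may degrade.)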
+ results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Draw selfie segmentation on the background image. + # To improve segmentation around boundaries, consider applying a joint + # bilateral filter to "results.segmentation_mask" with "image". + condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1 + # Generate solid color images for showing the output selfie segmentation mask. + fg_image = np.zeros(image.shape, dtype=np.uint8) + fg_image[:] = MASK_COLOR + bg_image = np.zeros(image.shape, dtype=np.uint8) + bg_image[:] = BG_COLOR + output_image = np.where(condition, fg_image, bg_image) + cv2.imwrite('/tmp/selfie_segmentation_output' + str(idx) + '.png', output_image) + +# For webcam input: +BG_COLOR = (192, 192, 192) # gray +cap = cv2.VideoCapture(0) +with mp_selfie_segmentation.SelfieSegmentation( + model_selection=1) as selfie_segmentation: + bg_image = None + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = selfie_segmentation.process(image) + + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + # Draw selfie segmentation on the background image. + # To improve segmentation around boundaries, consider applying a joint + # bilateral filter to "results.segmentation_mask" with "image". + condition = np.stack( + (results.segmentation_mask,) * 3, axis=-1) > 0.1 + # The background can be customized. + # a) Load an image (with the same width and height of the input image) to + # be the background, e.g., bg_image = cv2.imread('/path/to/image/file') + # b) Blur the input image by applying image filtering, e.g., + # bg_image = cv2.GaussianBlur(image,(55,55),0) + if bg_image is None: + bg_image = np.zeros(image.shape, dtype=np.uint8) + bg_image[:] = BG_COLOR + output_image = np.where(condition, image, bg_image) + + cv2.imshow('MediaPipe Selfie Segmentation', output_image) + if cv2.waitKey(5) & 0xFF == 27: + break +cap.release() +``` + +### JavaScript Solution API + +Please first see general [introduction](../getting_started/javascript.md) on +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. + +Supported configuration options: + +* [modelSelection](#model_selection) + +```html + + + + + + + + + + + +
+ + +
+ + +``` + +```javascript + +``` + +## Example Apps + +Please first see general instructions for +[Android](../getting_started/android.md), [iOS](../getting_started/ios.md), and +[desktop](../getting_started/cpp.md) on how to build MediaPipe examples. + +Note: To visualize a graph, copy the graph and paste it into +[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how +to visualize its associated subgraphs, please see +[visualizer documentation](../tools/visualizer.md). + +### Mobile + +* Graph: + [`mediapipe/graphs/selfie_segmentation/selfie_segmentation_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/selfie_segmentation/selfie_segmentation_gpu.pbtxt) +* Android target: + [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1DoeyGzMmWUsjfVgZfGGecrn7GKzYcEAo/view?usp=sharing) + [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu:selfiesegmentationgpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu/BUILD) +* iOS target: + [`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/selfiesegmentationgpu/BUILD) + +### Desktop + +Please first see general instructions for [desktop](../getting_started/cpp.md) +on how to build MediaPipe examples. + +* Running on CPU + * Graph: + [`mediapipe/graphs/selfie_segmentation/selfie_segmentation_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/selfie_segmentation/selfie_segmentation_cpu.pbtxt) + * Target: + [`mediapipe/examples/desktop/selfie_segmentation:selfie_segmentation_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/selfie_segmentation/BUILD) +* Running on GPU + * Graph: + [`mediapipe/graphs/selfie_segmentation/selfie_segmentation_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/selfie_segmentation/selfie_segmentation_gpu.pbtxt) + * Target: + [`mediapipe/examples/desktop/selfie_segmentation:selfie_segmentation_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/selfie_segmentation/BUILD) + +## Resources + +* Google AI Blog: + [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) +* [ML Kit Selfie Segmentation API](https://developers.google.com/ml-kit/vision/selfie-segmentation) +* [Models and model cards](./models.md#selfie-segmentation) +* [Web demo](https://code.mediapipe.dev/codepen/selfie_segmentation) +* [Python Colab](https://mediapipe.page.link/selfie_segmentation_py_colab) diff --git a/docs/solutions/solutions.md b/docs/solutions/solutions.md index a95f0c032..e9e4cdc38 100644 --- a/docs/solutions/solutions.md +++ b/docs/solutions/solutions.md @@ -13,6 +13,9 @@ has_toc: false {:toc} --- +MediaPipe offers open source cross-platform, customizable ML solutions for live +and streaming media. 
+ @@ -24,11 +27,12 @@ has_toc: false [Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | [Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | [Holistic](https://google.github.io/mediapipe/solutions/holistic) | ✅ | ✅ | ✅ | ✅ | ✅ | +[Selfie Segmentation](https://google.github.io/mediapipe/solutions/selfie_segmentation) | ✅ | ✅ | ✅ | ✅ | ✅ | [Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | | [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | ✅ | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | diff --git a/docs/solutions/youtube_8m.md b/docs/solutions/youtube_8m.md index abef6f1b6..5415c146a 100644 --- a/docs/solutions/youtube_8m.md +++ b/docs/solutions/youtube_8m.md @@ -2,7 +2,7 @@ layout: default title: YouTube-8M Feature Extraction and Model Inference parent: Solutions -nav_order: 15 +nav_order: 16 --- # YouTube-8M Feature Extraction and Model Inference diff --git a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen index 11daafdcb..f3b74900c 100644 --- a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen +++ b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen @@ -16,6 +16,7 @@ "mediapipe/examples/ios/objectdetectiongpu/BUILD", "mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD", "mediapipe/examples/ios/posetrackinggpu/BUILD", + "mediapipe/examples/ios/selfiesegmentationgpu/BUILD", "mediapipe/framework/BUILD", "mediapipe/gpu/BUILD", "mediapipe/objc/BUILD", @@ -35,6 +36,7 @@ "//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp", "//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp", "//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp", + "//mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp", "//mediapipe/objc:mediapipe_framework_ios" ], "optionSet" : { @@ -103,6 +105,7 @@ "mediapipe/examples/ios/objectdetectioncpu", "mediapipe/examples/ios/objectdetectiongpu", "mediapipe/examples/ios/posetrackinggpu", + "mediapipe/examples/ios/selfiesegmentationgpu", "mediapipe/framework", "mediapipe/framework/deps", "mediapipe/framework/formats", @@ -120,6 +123,7 @@ "mediapipe/graphs/hand_tracking", "mediapipe/graphs/object_detection", "mediapipe/graphs/pose_tracking", + "mediapipe/graphs/selfie_segmentation", "mediapipe/models", "mediapipe/modules", "mediapipe/objc", diff --git a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf index 33498e8c1..a2fe886cf 100644 --- a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf +++ b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf @@ -22,6 +22,7 @@ "mediapipe/examples/ios/objectdetectiongpu", "mediapipe/examples/ios/objectdetectiontrackinggpu", "mediapipe/examples/ios/posetrackinggpu", + "mediapipe/examples/ios/selfiesegmentationgpu", "mediapipe/objc" ], 
"projectName" : "Mediapipe", diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 0c9dbcd99..be5a0aaf1 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -140,6 +140,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "graph_profile_calculator_proto", + srcs = ["graph_profile_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + cc_library( name = "add_header_calculator", srcs = ["add_header_calculator.cc"], @@ -419,6 +429,23 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "make_pair_calculator_test", + size = "small", + srcs = ["make_pair_calculator_test.cc"], + deps = [ + ":make_pair_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:validate_type", + "//mediapipe/util:packet_test_util", + "//mediapipe/util:time_series_test_util", + ], +) + cc_library( name = "matrix_multiply_calculator", srcs = ["matrix_multiply_calculator.cc"], @@ -933,8 +960,8 @@ cc_test( ) cc_library( - name = "split_normalized_landmark_list_calculator", - srcs = ["split_normalized_landmark_list_calculator.cc"], + name = "split_landmarks_calculator", + srcs = ["split_landmarks_calculator.cc"], visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", @@ -948,10 +975,10 @@ cc_library( ) cc_test( - name = "split_normalized_landmark_list_calculator_test", - srcs = ["split_normalized_landmark_list_calculator_test.cc"], + name = "split_landmarks_calculator_test", + srcs = ["split_landmarks_calculator_test.cc"], deps = [ - ":split_normalized_landmark_list_calculator", + ":split_landmarks_calculator", ":split_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", @@ -1183,3 +1210,45 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "graph_profile_calculator", + srcs = ["graph_profile_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_profile_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "graph_profile_calculator_test", + srcs = ["graph_profile_calculator_test.cc"], + deps = [ + ":graph_profile_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework:test_calculators", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:threadpool", + "//mediapipe/framework/tool:simulation_clock_executor", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) diff --git 
a/mediapipe/calculators/core/add_header_calculator_test.cc b/mediapipe/calculators/core/add_header_calculator_test.cc index 4e197918d..bbe9bdd30 100644 --- a/mediapipe/calculators/core/add_header_calculator_test.cc +++ b/mediapipe/calculators/core/add_header_calculator_test.cc @@ -24,6 +24,9 @@ namespace mediapipe { +constexpr char kDataTag[] = "DATA"; +constexpr char kHeaderTag[] = "HEADER"; + class AddHeaderCalculatorTest : public ::testing::Test {}; TEST_F(AddHeaderCalculatorTest, HeaderStream) { @@ -36,11 +39,11 @@ TEST_F(AddHeaderCalculatorTest, HeaderStream) { CalculatorRunner runner(node); // Set header and add 5 packets. - runner.MutableInputs()->Tag("HEADER").header = + runner.MutableInputs()->Tag(kHeaderTag).header = Adopt(new std::string("my_header")); for (int i = 0; i < 5; ++i) { Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); - runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet); } // Run calculator. @@ -85,13 +88,14 @@ TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) { CalculatorRunner runner(node); // Set header and add 5 packets. - runner.MutableInputs()->Tag("HEADER").header = + runner.MutableInputs()->Tag(kHeaderTag).header = Adopt(new std::string("my_header")); - runner.MutableInputs()->Tag("HEADER").packets.push_back( - Adopt(new std::string("not allowed"))); + runner.MutableInputs() + ->Tag(kHeaderTag) + .packets.push_back(Adopt(new std::string("not allowed"))); for (int i = 0; i < 5; ++i) { Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); - runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet); } // Run calculator. @@ -108,11 +112,11 @@ TEST_F(AddHeaderCalculatorTest, InputSidePacket) { CalculatorRunner runner(node); // Set header and add 5 packets. - runner.MutableSidePackets()->Tag("HEADER") = + runner.MutableSidePackets()->Tag(kHeaderTag) = Adopt(new std::string("my_header")); for (int i = 0; i < 5; ++i) { Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); - runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet); } // Run calculator. @@ -143,13 +147,13 @@ TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) { CalculatorRunner runner(node); // Set both headers and add 5 packets. - runner.MutableSidePackets()->Tag("HEADER") = + runner.MutableSidePackets()->Tag(kHeaderTag) = Adopt(new std::string("my_header")); - runner.MutableSidePackets()->Tag("HEADER") = + runner.MutableSidePackets()->Tag(kHeaderTag) = Adopt(new std::string("my_header")); for (int i = 0; i < 5; ++i) { Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); - runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet); } // Run should fail because header can only be provided one way. diff --git a/mediapipe/calculators/core/begin_loop_calculator.cc b/mediapipe/calculators/core/begin_loop_calculator.cc index 1d0f7824d..7a5a9d5e9 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.cc +++ b/mediapipe/calculators/core/begin_loop_calculator.cc @@ -42,4 +42,9 @@ REGISTER_CALCULATOR(BeginLoopDetectionCalculator); typedef BeginLoopCalculator> BeginLoopMatrixCalculator; REGISTER_CALCULATOR(BeginLoopMatrixCalculator); +// A calculator to process std::vector>. 
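+// (With the Begin/EndLoop pattern, each element of the iterable input, here
+// one vector of Matrix per iteration, is emitted as a separate packet into
+// the loop body and collected again by the matching EndLoop calculator.)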
+typedef BeginLoopCalculator>> + BeginLoopMatrixVectorCalculator; +REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/core/counting_source_calculator.cc b/mediapipe/calculators/core/counting_source_calculator.cc index 0b731d9ce..fb75669e9 100644 --- a/mediapipe/calculators/core/counting_source_calculator.cc +++ b/mediapipe/calculators/core/counting_source_calculator.cc @@ -19,6 +19,13 @@ namespace mediapipe { +constexpr char kIncrementTag[] = "INCREMENT"; +constexpr char kInitialValueTag[] = "INITIAL_VALUE"; +constexpr char kBatchSizeTag[] = "BATCH_SIZE"; +constexpr char kErrorCountTag[] = "ERROR_COUNT"; +constexpr char kMaxCountTag[] = "MAX_COUNT"; +constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN"; + // Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of // sequential numbers from INITIAL_VALUE (default 0) with a common // difference of INCREMENT (default 1) between successive numbers (with @@ -33,53 +40,53 @@ class CountingSourceCalculator : public CalculatorBase { static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) { - cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set(); + if (cc->InputSidePackets().HasTag(kErrorOnOpenTag)) { + cc->InputSidePackets().Tag(kErrorOnOpenTag).Set(); } - RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") || - cc->InputSidePackets().HasTag("ERROR_COUNT")); - if (cc->InputSidePackets().HasTag("MAX_COUNT")) { - cc->InputSidePackets().Tag("MAX_COUNT").Set(); + RET_CHECK(cc->InputSidePackets().HasTag(kMaxCountTag) || + cc->InputSidePackets().HasTag(kErrorCountTag)); + if (cc->InputSidePackets().HasTag(kMaxCountTag)) { + cc->InputSidePackets().Tag(kMaxCountTag).Set(); } - if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { - cc->InputSidePackets().Tag("ERROR_COUNT").Set(); + if (cc->InputSidePackets().HasTag(kErrorCountTag)) { + cc->InputSidePackets().Tag(kErrorCountTag).Set(); } - if (cc->InputSidePackets().HasTag("BATCH_SIZE")) { - cc->InputSidePackets().Tag("BATCH_SIZE").Set(); + if (cc->InputSidePackets().HasTag(kBatchSizeTag)) { + cc->InputSidePackets().Tag(kBatchSizeTag).Set(); } - if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) { - cc->InputSidePackets().Tag("INITIAL_VALUE").Set(); + if (cc->InputSidePackets().HasTag(kInitialValueTag)) { + cc->InputSidePackets().Tag(kInitialValueTag).Set(); } - if (cc->InputSidePackets().HasTag("INCREMENT")) { - cc->InputSidePackets().Tag("INCREMENT").Set(); + if (cc->InputSidePackets().HasTag(kIncrementTag)) { + cc->InputSidePackets().Tag(kIncrementTag).Set(); } return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { - if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") && - cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get()) { + if (cc->InputSidePackets().HasTag(kErrorOnOpenTag) && + cc->InputSidePackets().Tag(kErrorOnOpenTag).Get()) { return absl::NotFoundError("expected error"); } - if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { - error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get(); + if (cc->InputSidePackets().HasTag(kErrorCountTag)) { + error_count_ = cc->InputSidePackets().Tag(kErrorCountTag).Get(); RET_CHECK_LE(0, error_count_); } - if (cc->InputSidePackets().HasTag("MAX_COUNT")) { - max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get(); + if (cc->InputSidePackets().HasTag(kMaxCountTag)) { + max_count_ = cc->InputSidePackets().Tag(kMaxCountTag).Get(); RET_CHECK_LE(0, max_count_); } - if 
(cc->InputSidePackets().HasTag("BATCH_SIZE")) { - batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get(); + if (cc->InputSidePackets().HasTag(kBatchSizeTag)) { + batch_size_ = cc->InputSidePackets().Tag(kBatchSizeTag).Get(); RET_CHECK_LT(0, batch_size_); } - if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) { - counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get(); + if (cc->InputSidePackets().HasTag(kInitialValueTag)) { + counter_ = cc->InputSidePackets().Tag(kInitialValueTag).Get(); } - if (cc->InputSidePackets().HasTag("INCREMENT")) { - increment_ = cc->InputSidePackets().Tag("INCREMENT").Get(); + if (cc->InputSidePackets().HasTag(kIncrementTag)) { + increment_ = cc->InputSidePackets().Tag(kIncrementTag).Get(); RET_CHECK_LT(0, increment_); } RET_CHECK(error_count_ >= 0 || max_count_ >= 0); diff --git a/mediapipe/calculators/core/dequantize_byte_array_calculator.cc b/mediapipe/calculators/core/dequantize_byte_array_calculator.cc index 04a7e55a0..a8adefc63 100644 --- a/mediapipe/calculators/core/dequantize_byte_array_calculator.cc +++ b/mediapipe/calculators/core/dequantize_byte_array_calculator.cc @@ -35,11 +35,14 @@ // } namespace mediapipe { +constexpr char kFloatVectorTag[] = "FLOAT_VECTOR"; +constexpr char kEncodedTag[] = "ENCODED"; + class DequantizeByteArrayCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("ENCODED").Set(); - cc->Outputs().Tag("FLOAT_VECTOR").Set>(); + cc->Inputs().Tag(kEncodedTag).Set(); + cc->Outputs().Tag(kFloatVectorTag).Set>(); return absl::OkStatus(); } @@ -66,7 +69,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) final { const std::string& encoded = - cc->Inputs().Tag("ENCODED").Value().Get(); + cc->Inputs().Tag(kEncodedTag).Value().Get(); std::vector float_vector; float_vector.reserve(encoded.length()); for (int i = 0; i < encoded.length(); ++i) { @@ -74,7 +77,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase { static_cast(encoded.at(i)) * scalar_ + bias_); } cc->Outputs() - .Tag("FLOAT_VECTOR") + .Tag(kFloatVectorTag) .AddPacket(MakePacket>(float_vector) .At(cc->InputTimestamp())); return absl::OkStatus(); diff --git a/mediapipe/calculators/core/dequantize_byte_array_calculator_test.cc b/mediapipe/calculators/core/dequantize_byte_array_calculator_test.cc index cf0a8dc15..81b9e5562 100644 --- a/mediapipe/calculators/core/dequantize_byte_array_calculator_test.cc +++ b/mediapipe/calculators/core/dequantize_byte_array_calculator_test.cc @@ -25,6 +25,9 @@ namespace mediapipe { +constexpr char kFloatVectorTag[] = "FLOAT_VECTOR"; +constexpr char kEncodedTag[] = "ENCODED"; + TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { CalculatorGraphConfig::Node node_config = ParseTextProtoOrDie(R"pb( @@ -39,8 +42,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { )pb"); CalculatorRunner runner(node_config); std::string empty_string; - runner.MutableInputs()->Tag("ENCODED").packets.push_back( - MakePacket(empty_string).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kEncodedTag) + .packets.push_back( + MakePacket(empty_string).At(Timestamp(0))); auto status = runner.Run(); EXPECT_FALSE(status.ok()); EXPECT_THAT( @@ -64,8 +69,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) { )pb"); CalculatorRunner runner(node_config); std::string empty_string; - runner.MutableInputs()->Tag("ENCODED").packets.push_back( - MakePacket(empty_string).At(Timestamp(0))); + runner.MutableInputs() + 
->Tag(kEncodedTag) + .packets.push_back( + MakePacket(empty_string).At(Timestamp(0))); auto status = runner.Run(); EXPECT_FALSE(status.ok()); EXPECT_THAT( @@ -89,8 +96,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) { )pb"); CalculatorRunner runner(node_config); std::string empty_string; - runner.MutableInputs()->Tag("ENCODED").packets.push_back( - MakePacket(empty_string).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kEncodedTag) + .packets.push_back( + MakePacket(empty_string).At(Timestamp(0))); auto status = runner.Run(); EXPECT_FALSE(status.ok()); EXPECT_THAT( @@ -114,14 +123,16 @@ TEST(DequantizeByteArrayCalculatorTest, TestDequantization) { )pb"); CalculatorRunner runner(node_config); unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01}; - runner.MutableInputs()->Tag("ENCODED").packets.push_back( - MakePacket( - std::string(reinterpret_cast(input), 4)) - .At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kEncodedTag) + .packets.push_back( + MakePacket( + std::string(reinterpret_cast(input), 4)) + .At(Timestamp(0))); auto status = runner.Run(); MP_ASSERT_OK(runner.Run()); const std::vector& outputs = - runner.Outputs().Tag("FLOAT_VECTOR").packets; + runner.Outputs().Tag(kFloatVectorTag).packets; EXPECT_EQ(1, outputs.size()); const std::vector& result = outputs[0].Get>(); ASSERT_FALSE(result.empty()); diff --git a/mediapipe/calculators/core/end_loop_calculator.cc b/mediapipe/calculators/core/end_loop_calculator.cc index fb02f7618..2a366f992 100644 --- a/mediapipe/calculators/core/end_loop_calculator.cc +++ b/mediapipe/calculators/core/end_loop_calculator.cc @@ -28,6 +28,10 @@ typedef EndLoopCalculator> EndLoopNormalizedRectCalculator; REGISTER_CALCULATOR(EndLoopNormalizedRectCalculator); +typedef EndLoopCalculator> + EndLoopLandmarkListVectorCalculator; +REGISTER_CALCULATOR(EndLoopLandmarkListVectorCalculator); + typedef EndLoopCalculator> EndLoopNormalizedLandmarkListVectorCalculator; REGISTER_CALCULATOR(EndLoopNormalizedLandmarkListVectorCalculator); diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc index eba621ce3..b365121bc 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -24,6 +24,11 @@ namespace mediapipe { +constexpr char kFinishedTag[] = "FINISHED"; +constexpr char kAllowTag[] = "ALLOW"; +constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT"; +constexpr char kOptionsTag[] = "OPTIONS"; + // FlowLimiterCalculator is used to limit the number of frames in flight // by dropping input frames when necessary. 
// @@ -69,16 +74,19 @@ class FlowLimiterCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { auto& side_inputs = cc->InputSidePackets(); - side_inputs.Tag("OPTIONS").Set().Optional(); - cc->Inputs().Tag("OPTIONS").Set().Optional(); + side_inputs.Tag(kOptionsTag).Set().Optional(); + cc->Inputs() + .Tag(kOptionsTag) + .Set() + .Optional(); RET_CHECK_GE(cc->Inputs().NumEntries(""), 1); for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { cc->Inputs().Get("", i).SetAny(); cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i))); } cc->Inputs().Get("FINISHED", 0).SetAny(); - cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set().Optional(); - cc->Outputs().Tag("ALLOW").Set().Optional(); + cc->InputSidePackets().Tag(kMaxInFlightTag).Set().Optional(); + cc->Outputs().Tag(kAllowTag).Set().Optional(); cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetProcessTimestampBounds(true); return absl::OkStatus(); @@ -87,9 +95,9 @@ class FlowLimiterCalculator : public CalculatorBase { absl::Status Open(CalculatorContext* cc) final { options_ = cc->Options(); options_ = tool::RetrieveOptions(options_, cc->InputSidePackets()); - if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { + if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) { options_.set_max_in_flight( - cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get()); + cc->InputSidePackets().Tag(kMaxInFlightTag).Get()); } input_queues_.resize(cc->Inputs().NumEntries("")); RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); @@ -104,8 +112,8 @@ class FlowLimiterCalculator : public CalculatorBase { // Outputs a packet indicating whether a frame was sent or dropped. void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) { - if (cc->Outputs().HasTag("ALLOW")) { - cc->Outputs().Tag("ALLOW").AddPacket(MakePacket(allow).At(ts)); + if (cc->Outputs().HasTag(kAllowTag)) { + cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket(allow).At(ts)); } } @@ -155,7 +163,7 @@ class FlowLimiterCalculator : public CalculatorBase { options_ = tool::RetrieveOptions(options_, cc->Inputs()); // Process the FINISHED input stream. - Packet finished_packet = cc->Inputs().Tag("FINISHED").Value(); + Packet finished_packet = cc->Inputs().Tag(kFinishedTag).Value(); if (finished_packet.Timestamp() == cc->InputTimestamp()) { while (!frames_in_flight_.empty() && frames_in_flight_.front() <= finished_packet.Timestamp()) { @@ -210,8 +218,8 @@ class FlowLimiterCalculator : public CalculatorBase { Timestamp bound = cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream(); SetNextTimestampBound(bound, &cc->Outputs().Get("", 0)); - if (cc->Outputs().HasTag("ALLOW")) { - SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW")); + if (cc->Outputs().HasTag(kAllowTag)) { + SetNextTimestampBound(bound, &cc->Outputs().Tag(kAllowTag)); } } diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index d2294dd48..962b1c81a 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -36,6 +36,13 @@ namespace mediapipe { namespace { + +constexpr char kDropTimestampsTag[] = "DROP_TIMESTAMPS"; +constexpr char kClockTag[] = "CLOCK"; +constexpr char kWarmupTimeTag[] = "WARMUP_TIME"; +constexpr char kSleepTimeTag[] = "SLEEP_TIME"; +constexpr char kPacketTag[] = "PACKET"; + // A simple Semaphore for synchronizing test threads. 
class AtomicSemaphore { public: @@ -204,17 +211,17 @@ TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) { class SleepCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("PACKET").SetAny(); - cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); - cc->InputSidePackets().Tag("SLEEP_TIME").Set(); - cc->InputSidePackets().Tag("WARMUP_TIME").Set(); - cc->InputSidePackets().Tag("CLOCK").Set(); + cc->Inputs().Tag(kPacketTag).SetAny(); + cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag)); + cc->InputSidePackets().Tag(kSleepTimeTag).Set(); + cc->InputSidePackets().Tag(kWarmupTimeTag).Set(); + cc->InputSidePackets().Tag(kClockTag).Set(); cc->SetTimestampOffset(0); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { - clock_ = cc->InputSidePackets().Tag("CLOCK").Get(); + clock_ = cc->InputSidePackets().Tag(kClockTag).Get(); return absl::OkStatus(); } @@ -222,10 +229,12 @@ class SleepCalculator : public CalculatorBase { ++packet_count; absl::Duration sleep_time = absl::Microseconds( packet_count == 1 - ? cc->InputSidePackets().Tag("WARMUP_TIME").Get() - : cc->InputSidePackets().Tag("SLEEP_TIME").Get()); + ? cc->InputSidePackets().Tag(kWarmupTimeTag).Get() + : cc->InputSidePackets().Tag(kSleepTimeTag).Get()); clock_->Sleep(sleep_time); - cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); + cc->Outputs() + .Tag(kPacketTag) + .AddPacket(cc->Inputs().Tag(kPacketTag).Value()); return absl::OkStatus(); } @@ -240,24 +249,27 @@ REGISTER_CALCULATOR(SleepCalculator); class DropCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("PACKET").SetAny(); - cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); - cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set(); + cc->Inputs().Tag(kPacketTag).SetAny(); + cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag)); + cc->InputSidePackets().Tag(kDropTimestampsTag).Set(); cc->SetProcessTimestampBounds(true); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { - if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) { + if (!cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) { ++packet_count; } bool drop = (packet_count == 3); - if (!drop && !cc->Inputs().Tag("PACKET").Value().IsEmpty()) { - cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); + if (!drop && !cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) { + cc->Outputs() + .Tag(kPacketTag) + .AddPacket(cc->Inputs().Tag(kPacketTag).Value()); } - if (!drop || !cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get()) { - cc->Outputs().Tag("PACKET").SetNextTimestampBound( - cc->InputTimestamp().NextAllowedInStream()); + if (!drop || !cc->InputSidePackets().Tag(kDropTimestampsTag).Get()) { + cc->Outputs() + .Tag(kPacketTag) + .SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream()); } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/gate_calculator.cc b/mediapipe/calculators/core/gate_calculator.cc index 189671860..8fdb9e0a3 100644 --- a/mediapipe/calculators/core/gate_calculator.cc +++ b/mediapipe/calculators/core/gate_calculator.cc @@ -21,6 +21,11 @@ namespace mediapipe { namespace { + +constexpr char kStateChangeTag[] = "STATE_CHANGE"; +constexpr char kDisallowTag[] = "DISALLOW"; +constexpr char kAllowTag[] = "ALLOW"; + enum GateState { GATE_UNINITIALIZED, GATE_ALLOW, @@ -59,8 +64,9 @@ std::string 
ToString(GateState state) { // ALLOW or DISALLOW can also be specified as an input side packet. The rules // for evaluation remain the same as above. // -// ALLOW/DISALLOW inputs must be specified either using input stream or -// via input side packet but not both. +// ALLOW/DISALLOW inputs must be specified either using input stream or via +// input side packet but not both. If neither is specified, the behavior is then +// determined by the "allow" field in the calculator options. // // Intended to be used with the default input stream handler, which synchronizes // all data input streams with the ALLOW/DISALLOW control input stream. @@ -83,30 +89,33 @@ class GateCalculator : public CalculatorBase { GateCalculator() {} static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) { - bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") || - cc->InputSidePackets().HasTag("DISALLOW"); + bool input_via_side_packet = cc->InputSidePackets().HasTag(kAllowTag) || + cc->InputSidePackets().HasTag(kDisallowTag); bool input_via_stream = - cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW"); - // Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW - // input. - RET_CHECK(input_via_side_packet ^ input_via_stream); + cc->Inputs().HasTag(kAllowTag) || cc->Inputs().HasTag(kDisallowTag); + // Only one of input_side_packet or input_stream may specify + // ALLOW/DISALLOW input. if (input_via_side_packet) { - RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^ - cc->InputSidePackets().HasTag("DISALLOW")); + RET_CHECK(!input_via_stream); + RET_CHECK(cc->InputSidePackets().HasTag(kAllowTag) ^ + cc->InputSidePackets().HasTag(kDisallowTag)); - if (cc->InputSidePackets().HasTag("ALLOW")) { - cc->InputSidePackets().Tag("ALLOW").Set(); + if (cc->InputSidePackets().HasTag(kAllowTag)) { + cc->InputSidePackets().Tag(kAllowTag).Set().Optional(); } else { - cc->InputSidePackets().Tag("DISALLOW").Set(); + cc->InputSidePackets().Tag(kDisallowTag).Set().Optional(); } - } else { - RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW")); + } + if (input_via_stream) { + RET_CHECK(!input_via_side_packet); + RET_CHECK(cc->Inputs().HasTag(kAllowTag) ^ + cc->Inputs().HasTag(kDisallowTag)); - if (cc->Inputs().HasTag("ALLOW")) { - cc->Inputs().Tag("ALLOW").Set(); + if (cc->Inputs().HasTag(kAllowTag)) { + cc->Inputs().Tag(kAllowTag).Set(); } else { - cc->Inputs().Tag("DISALLOW").Set(); + cc->Inputs().Tag(kDisallowTag).Set(); } } return absl::OkStatus(); @@ -125,23 +134,22 @@ class GateCalculator : public CalculatorBase { cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); } - if (cc->Outputs().HasTag("STATE_CHANGE")) { - cc->Outputs().Tag("STATE_CHANGE").Set(); + if (cc->Outputs().HasTag(kStateChangeTag)) { + cc->Outputs().Tag(kStateChangeTag).Set(); } return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { - use_side_packet_for_allow_disallow_ = false; - if (cc->InputSidePackets().HasTag("ALLOW")) { + if (cc->InputSidePackets().HasTag(kAllowTag)) { use_side_packet_for_allow_disallow_ = true; allow_by_side_packet_decision_ = - cc->InputSidePackets().Tag("ALLOW").Get(); - } else if (cc->InputSidePackets().HasTag("DISALLOW")) { + cc->InputSidePackets().Tag(kAllowTag).Get(); + } else if (cc->InputSidePackets().HasTag(kDisallowTag)) { use_side_packet_for_allow_disallow_ = true; allow_by_side_packet_decision_ = - !cc->InputSidePackets().Tag("DISALLOW").Get(); + !cc->InputSidePackets().Tag(kDisallowTag).Get(); } 
cc->SetOffset(TimestampDiff(0)); @@ -152,26 +160,34 @@ class GateCalculator : public CalculatorBase { const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>(); empty_packets_as_allow_ = options.empty_packets_as_allow(); + if (!use_side_packet_for_allow_disallow_ && + !cc->Inputs().HasTag(kAllowTag) && !cc->Inputs().HasTag(kDisallowTag)) { + use_option_for_allow_disallow_ = true; + allow_by_option_decision_ = options.allow(); + } + return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { bool allow = empty_packets_as_allow_; - if (use_side_packet_for_allow_disallow_) { + if (use_option_for_allow_disallow_) { + allow = allow_by_option_decision_; + } else if (use_side_packet_for_allow_disallow_) { allow = allow_by_side_packet_decision_; } else { - if (cc->Inputs().HasTag("ALLOW") && - !cc->Inputs().Tag("ALLOW").IsEmpty()) { - allow = cc->Inputs().Tag("ALLOW").Get(); + if (cc->Inputs().HasTag(kAllowTag) && + !cc->Inputs().Tag(kAllowTag).IsEmpty()) { + allow = cc->Inputs().Tag(kAllowTag).Get(); } - if (cc->Inputs().HasTag("DISALLOW") && - !cc->Inputs().Tag("DISALLOW").IsEmpty()) { - allow = !cc->Inputs().Tag("DISALLOW").Get(); + if (cc->Inputs().HasTag(kDisallowTag) && + !cc->Inputs().Tag(kDisallowTag).IsEmpty()) { + allow = !cc->Inputs().Tag(kDisallowTag).Get(); } } const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW; - if (cc->Outputs().HasTag("STATE_CHANGE")) { + if (cc->Outputs().HasTag(kStateChangeTag)) { if (last_gate_state_ != GATE_UNINITIALIZED && last_gate_state_ != new_gate_state) { VLOG(2) << "State transition in " << cc->NodeName() << " @ " @@ -179,7 +195,7 @@ class GateCalculator : public CalculatorBase { << ToString(last_gate_state_) << " to " << ToString(new_gate_state); cc->Outputs() - .Tag("STATE_CHANGE") + .Tag(kStateChangeTag) .AddPacket(MakePacket(allow).At(cc->InputTimestamp())); } } @@ -211,8 +227,10 @@ class GateCalculator : public CalculatorBase { GateState last_gate_state_ = GATE_UNINITIALIZED; int num_data_streams_; bool empty_packets_as_allow_; - bool use_side_packet_for_allow_disallow_; + bool use_side_packet_for_allow_disallow_ = false; bool allow_by_side_packet_decision_; + bool use_option_for_allow_disallow_ = false; + bool allow_by_option_decision_; }; REGISTER_CALCULATOR(GateCalculator); diff --git a/mediapipe/calculators/core/gate_calculator.proto b/mediapipe/calculators/core/gate_calculator.proto index 76bacc74e..32402bf28 100644 --- a/mediapipe/calculators/core/gate_calculator.proto +++ b/mediapipe/calculators/core/gate_calculator.proto @@ -29,4 +29,8 @@ message GateCalculatorOptions { // disallowing the corresponding packets in the data input streams. Setting // this option to true inverts that, allowing the data packets to go through. optional bool empty_packets_as_allow = 1; + + // Whether to allow or disallow the input streams to pass when no + // ALLOW/DISALLOW input or side input is specified. + optional bool allow = 2 [default = false]; } diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc index 0b78b9b75..c523bce28 100644 --- a/mediapipe/calculators/core/gate_calculator_test.cc +++ b/mediapipe/calculators/core/gate_calculator_test.cc @@ -22,6 +22,9 @@ namespace mediapipe { namespace { +constexpr char kDisallowTag[] = "DISALLOW"; +constexpr char kAllowTag[] = "ALLOW"; + class GateCalculatorTest : public ::testing::Test { protected: // Helper to run a graph and return status. 
@@ -110,6 +113,68 @@ TEST_F(GateCalculatorTest, InvalidInputs) { )"))); } +TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + output_stream: "test_output" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(2, output.size()); + EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value()); + EXPECT_EQ(kTimestampValue1, output[1].Timestamp().Value()); + EXPECT_EQ(true, output[0].Get()); + EXPECT_EQ(false, output[1].Get()); +} + +TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + output_stream: "test_output" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: false + } + } + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(0, output.size()); +} + +TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + output_stream: "test_output" + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(0, output.size()); +} + TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) { SetRunner(R"( calculator: "GateCalculator" @@ -117,7 +182,7 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) { input_stream: "test_input" output_stream: "test_output" )"); - runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true)); + runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true)); constexpr int64 kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); @@ -139,7 +204,7 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) { input_stream: "test_input" output_stream: "test_output" )"); - runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false)); + runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false)); constexpr int64 kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); @@ -161,7 +226,7 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) { input_stream: "test_input" output_stream: "test_output" )"); - runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false)); + runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false)); constexpr int64 kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); @@ -179,7 +244,7 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) { input_stream: "test_input" output_stream: "test_output" )"); - runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true)); + runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true)); constexpr int64 kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); diff --git a/mediapipe/calculators/core/graph_profile_calculator.cc b/mediapipe/calculators/core/graph_profile_calculator.cc new file mode 100644 index 000000000..9b9aa3bb7 --- /dev/null +++ 
b/mediapipe/calculators/core/graph_profile_calculator.cc @@ -0,0 +1,70 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/core/graph_profile_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_profile.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +namespace api2 { + +// This calculator periodically copies the GraphProfile from +// mediapipe::GraphProfiler::CaptureProfile to the "PROFILE" output stream. +// +// Example config: +// node { +// calculator: "GraphProfileCalculator" +// output_stream: "FRAME:any_frame" +// output_stream: "PROFILE:graph_profile" +// } +// +class GraphProfileCalculator : public Node { + public: + static constexpr Input::Multiple kFrameIn{"FRAME"}; + static constexpr Output kProfileOut{"PROFILE"}; + + MEDIAPIPE_NODE_CONTRACT(kFrameIn, kProfileOut); + + static absl::Status UpdateContract(CalculatorContract* cc) { + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + auto options = cc->Options<::mediapipe::GraphProfileCalculatorOptions>(); + + if (prev_profile_ts_ == Timestamp::Unset() || + cc->InputTimestamp() - prev_profile_ts_ >= options.profile_interval()) { + prev_profile_ts_ = cc->InputTimestamp(); + GraphProfile result; + MP_RETURN_IF_ERROR(cc->GetProfilingContext()->CaptureProfile(&result)); + kProfileOut(cc).Send(result); + } + return absl::OkStatus(); + } + + private: + Timestamp prev_profile_ts_; +}; + +MEDIAPIPE_REGISTER_NODE(GraphProfileCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/core/graph_profile_calculator.proto b/mediapipe/calculators/core/graph_profile_calculator.proto new file mode 100644 index 000000000..2bcc480c8 --- /dev/null +++ b/mediapipe/calculators/core/graph_profile_calculator.proto @@ -0,0 +1,30 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +option objc_class_prefix = "MediaPipe"; + +message GraphProfileCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional GraphProfileCalculatorOptions ext = 367481815; + } + + // The interval in microseconds between successive reported GraphProfiles. + optional int64 profile_interval = 1 [default = 1000000]; +} diff --git a/mediapipe/calculators/core/graph_profile_calculator_test.cc b/mediapipe/calculators/core/graph_profile_calculator_test.cc new file mode 100644 index 000000000..8a7845b19 --- /dev/null +++ b/mediapipe/calculators/core/graph_profile_calculator_test.cc @@ -0,0 +1,211 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_profile.pb.h" +#include "mediapipe/framework/deps/clock.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/port/threadpool.h" +#include "mediapipe/framework/tool/simulation_clock_executor.h" + +// Tests for GraphProfileCalculator. +using testing::ElementsAre; + +namespace mediapipe { +namespace { + +constexpr char kClockTag[] = "CLOCK"; + +using mediapipe::Clock; + +// A Calculator with a fixed Process call latency. +class SleepCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets().Tag(kClockTag).Set>(); + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + cc->SetTimestampOffset(TimestampDiff(0)); + return absl::OkStatus(); + } + absl::Status Open(CalculatorContext* cc) final { + clock_ = + cc->InputSidePackets().Tag(kClockTag).Get>(); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + clock_->Sleep(absl::Milliseconds(5)); + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + return absl::OkStatus(); + } + std::shared_ptr<::mediapipe::Clock> clock_ = nullptr; +}; +REGISTER_CALCULATOR(SleepCalculator); + +// Tests showing GraphProfileCalculator reporting GraphProfile output packets. 
+class GraphProfileCalculatorTest : public ::testing::Test { + protected: + void SetUpProfileGraph() { + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( + input_stream: "input_packets_0" + node { + calculator: 'SleepCalculator' + input_side_packet: 'CLOCK:sync_clock' + input_stream: 'input_packets_0' + output_stream: 'output_packets_1' + } + node { + calculator: "GraphProfileCalculator" + options: { + [mediapipe.GraphProfileCalculatorOptions.ext]: { + profile_interval: 25000 + } + } + input_stream: "FRAME:output_packets_1" + output_stream: "PROFILE:output_packets_0" + } + )", + &graph_config_)); + } + + static Packet PacketAt(int64 ts) { + return Adopt(new int64(999)).At(Timestamp(ts)); + } + static Packet None() { return Packet().At(Timestamp::OneOverPostStream()); } + static bool IsNone(const Packet& packet) { + return packet.Timestamp() == Timestamp::OneOverPostStream(); + } + // Return the values of the timestamps of a vector of Packets. + static std::vector TimestampValues( + const std::vector& packets) { + std::vector result; + for (const Packet& p : packets) { + result.push_back(p.Timestamp().Value()); + } + return result; + } + + // Runs a CalculatorGraph with a series of packet sets. + // Returns a vector of packets from each graph output stream. + void RunGraph(const std::vector>& input_sets, + std::vector* output_packets) { + // Register output packet observers. + tool::AddVectorSink("output_packets_0", &graph_config_, output_packets); + + // Start running the graph. + std::shared_ptr executor( + new SimulationClockExecutor(3 /*num_threads*/)); + CalculatorGraph graph; + MP_ASSERT_OK(graph.SetExecutor("", executor)); + graph.profiler()->SetClock(executor->GetClock()); + MP_ASSERT_OK(graph.Initialize(graph_config_)); + executor->GetClock()->ThreadStart(); + MP_ASSERT_OK(graph.StartRun({ + {"sync_clock", + Adopt(new std::shared_ptr<::mediapipe::Clock>(executor->GetClock()))}, + })); + + // Send each packet to the graph in the specified order. + for (int t = 0; t < input_sets.size(); t++) { + const std::vector& input_set = input_sets[t]; + for (int i = 0; i < input_set.size(); i++) { + const Packet& packet = input_set[i]; + if (!IsNone(packet)) { + MP_EXPECT_OK(graph.AddPacketToInputStream( + absl::StrCat("input_packets_", i), packet)); + } + executor->GetClock()->Sleep(absl::Milliseconds(10)); + } + } + MP_ASSERT_OK(graph.CloseAllInputStreams()); + executor->GetClock()->Sleep(absl::Milliseconds(100)); + executor->GetClock()->ThreadFinish(); + MP_ASSERT_OK(graph.WaitUntilDone()); + } + + CalculatorGraphConfig graph_config_; +}; + +TEST_F(GraphProfileCalculatorTest, GraphProfile) { + SetUpProfileGraph(); + auto profiler_config = graph_config_.mutable_profiler_config(); + profiler_config->set_enable_profiler(true); + profiler_config->set_trace_enabled(false); + profiler_config->set_trace_log_disabled(true); + profiler_config->set_enable_stream_latency(true); + profiler_config->set_calculator_filter(".*Calculator"); + + // Run the graph with a series of packet sets. + std::vector> input_sets = { + {PacketAt(10000)}, // + {PacketAt(20000)}, // + {PacketAt(30000)}, // + {PacketAt(40000)}, + }; + std::vector output_packets; + RunGraph(input_sets, &output_packets); + + // Validate the output packets. 
+ EXPECT_THAT(TimestampValues(output_packets), // + ElementsAre(10000, 40000)); + + GraphProfile expected_profile = + mediapipe::ParseTextProtoOrDie(R"pb( + calculator_profiles { + name: "GraphProfileCalculator" + open_runtime: 0 + process_runtime { total: 0 count: 3 } + process_input_latency { total: 15000 count: 3 } + process_output_latency { total: 15000 count: 3 } + input_stream_profiles { + name: "output_packets_1" + back_edge: false + latency { total: 0 count: 3 } + } + } + calculator_profiles { + name: "SleepCalculator" + open_runtime: 0 + process_runtime { total: 15000 count: 3 } + process_input_latency { total: 0 count: 3 } + process_output_latency { total: 15000 count: 3 } + input_stream_profiles { + name: "input_packets_0" + back_edge: false + latency { total: 0 count: 3 } + } + })pb"); + + EXPECT_THAT(output_packets[1].Get(), + mediapipe::EqualsProto(expected_profile)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/make_pair_calculator_test.cc b/mediapipe/calculators/core/make_pair_calculator_test.cc new file mode 100644 index 000000000..ee3396697 --- /dev/null +++ b/mediapipe/calculators/core/make_pair_calculator_test.cc @@ -0,0 +1,70 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
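In the expected GraphProfile above, SleepCalculator's process_runtime total of 15000 us with a count of 3 is consistent with its fixed 5 ms simulated sleep per counted Process call (3 x 5000 us). A trivial compile-time check of that arithmetic, illustrative only:

// Each counted SleepCalculator::Process call sleeps 5 ms on the simulation
// clock, so three counted calls account for a 15000 us process_runtime total.
static_assert(3 * 5000 == 15000,
              "expected SleepCalculator process_runtime total");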
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/validate_type.h" +#include "mediapipe/util/packet_test_util.h" +#include "mediapipe/util/time_series_test_util.h" + +namespace mediapipe { + +class MakePairCalculatorTest + : public mediapipe::TimeSeriesCalculatorTest { + protected: + void SetUp() override { + calculator_name_ = "MakePairCalculator"; + num_input_streams_ = 2; + } +}; + +TEST_F(MakePairCalculatorTest, ProducesExpectedPairs) { + InitializeGraph(); + AppendInputPacket(new std::string("first packet"), Timestamp(1), + /* input_index= */ 0); + AppendInputPacket(new std::string("second packet"), Timestamp(5), + /* input_index= */ 0); + AppendInputPacket(new int(10), Timestamp(1), /* input_index= */ 1); + AppendInputPacket(new int(20), Timestamp(5), /* input_index= */ 1); + + MP_ASSERT_OK(RunGraph()); + + EXPECT_THAT( + output().packets, + ::testing::ElementsAre( + mediapipe::PacketContainsTimestampAndPayload< + std::pair>( + Timestamp(1), + ::testing::Pair( + mediapipe::PacketContainsTimestampAndPayload( + Timestamp(1), std::string("first packet")), + mediapipe::PacketContainsTimestampAndPayload( + Timestamp(1), 10))), + mediapipe::PacketContainsTimestampAndPayload< + std::pair>( + Timestamp(5), + ::testing::Pair( + mediapipe::PacketContainsTimestampAndPayload( + Timestamp(5), std::string("second packet")), + mediapipe::PacketContainsTimestampAndPayload( + Timestamp(5), 20))))); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc index 0bbf94dc8..45a0e1cd3 100644 --- a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc +++ b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc @@ -29,6 +29,9 @@ namespace mediapipe { namespace { +constexpr char kMinuendTag[] = "MINUEND"; +constexpr char kSubtrahendTag[] = "SUBTRAHEND"; + // A 3x4 Matrix of random integers in [0,1000). 
const char kMatrixText[] = "rows: 3\n" @@ -104,12 +107,13 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromInput) { CalculatorRunner runner(node_config); Matrix* side_matrix = new Matrix(); MatrixFromTextProto(kMatrixText, side_matrix); - runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix); + runner.MutableSidePackets()->Tag(kSubtrahendTag) = Adopt(side_matrix); Matrix* input_matrix = new Matrix(); MatrixFromTextProto(kMatrixText2, input_matrix); - runner.MutableInputs()->Tag("MINUEND").packets.push_back( - Adopt(input_matrix).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kMinuendTag) + .packets.push_back(Adopt(input_matrix).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); EXPECT_EQ(1, runner.Outputs().Index(0).packets.size()); @@ -133,12 +137,12 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) { CalculatorRunner runner(node_config); Matrix* side_matrix = new Matrix(); MatrixFromTextProto(kMatrixText, side_matrix); - runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix); + runner.MutableSidePackets()->Tag(kMinuendTag) = Adopt(side_matrix); Matrix* input_matrix = new Matrix(); MatrixFromTextProto(kMatrixText2, input_matrix); runner.MutableInputs() - ->Tag("SUBTRAHEND") + ->Tag(kSubtrahendTag) .packets.push_back(Adopt(input_matrix).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); diff --git a/mediapipe/calculators/core/packet_presence_calculator.cc b/mediapipe/calculators/core/packet_presence_calculator.cc index cb119a76d..ac723ad8a 100644 --- a/mediapipe/calculators/core/packet_presence_calculator.cc +++ b/mediapipe/calculators/core/packet_presence_calculator.cc @@ -17,6 +17,9 @@ namespace mediapipe { +constexpr char kPresenceTag[] = "PRESENCE"; +constexpr char kPacketTag[] = "PACKET"; + // For each non empty input packet, emits a single output packet containing a // boolean value "true", "false" in response to empty packets (a.k.a. timestamp // bound updates) This can be used to "flag" the presence of an arbitrary packet @@ -58,8 +61,8 @@ namespace mediapipe { class PacketPresenceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("PACKET").SetAny(); - cc->Outputs().Tag("PRESENCE").Set(); + cc->Inputs().Tag(kPacketTag).SetAny(); + cc->Outputs().Tag(kPresenceTag).Set(); // Process() function is invoked in response to input stream timestamp // bound updates. cc->SetProcessTimestampBounds(true); @@ -73,8 +76,8 @@ class PacketPresenceCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) final { cc->Outputs() - .Tag("PRESENCE") - .AddPacket(MakePacket(!cc->Inputs().Tag("PACKET").IsEmpty()) + .Tag(kPresenceTag) + .AddPacket(MakePacket(!cc->Inputs().Tag(kPacketTag).IsEmpty()) .At(cc->InputTimestamp())); return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc index 43253520a..81ccdbe65 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator.cc @@ -39,6 +39,11 @@ namespace mediapipe { REGISTER_CALCULATOR(PacketResamplerCalculator); namespace { + +constexpr char kSeedTag[] = "SEED"; +constexpr char kVideoHeaderTag[] = "VIDEO_HEADER"; +constexpr char kOptionsTag[] = "OPTIONS"; + // Returns a TimestampDiff (assuming microseconds) corresponding to the // given time in seconds. 
TimestampDiff TimestampDiffFromSeconds(double seconds) { @@ -50,16 +55,16 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) { absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { const auto& resampler_options = cc->Options(); - if (cc->InputSidePackets().HasTag("OPTIONS")) { - cc->InputSidePackets().Tag("OPTIONS").Set(); + if (cc->InputSidePackets().HasTag(kOptionsTag)) { + cc->InputSidePackets().Tag(kOptionsTag).Set(); } CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0); if (!input_data_id.IsValid()) { input_data_id = cc->Inputs().GetId("", 0); } cc->Inputs().Get(input_data_id).SetAny(); - if (cc->Inputs().HasTag("VIDEO_HEADER")) { - cc->Inputs().Tag("VIDEO_HEADER").Set(); + if (cc->Inputs().HasTag(kVideoHeaderTag)) { + cc->Inputs().Tag(kVideoHeaderTag).Set(); } CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0); @@ -67,15 +72,15 @@ absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { output_data_id = cc->Outputs().GetId("", 0); } cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id)); - if (cc->Outputs().HasTag("VIDEO_HEADER")) { - cc->Outputs().Tag("VIDEO_HEADER").Set(); + if (cc->Outputs().HasTag(kVideoHeaderTag)) { + cc->Outputs().Tag(kVideoHeaderTag).Set(); } if (resampler_options.jitter() != 0.0) { RET_CHECK_GT(resampler_options.jitter(), 0.0); RET_CHECK_LE(resampler_options.jitter(), 1.0); - RET_CHECK(cc->InputSidePackets().HasTag("SEED")); - cc->InputSidePackets().Tag("SEED").Set(); + RET_CHECK(cc->InputSidePackets().HasTag(kSeedTag)); + cc->InputSidePackets().Tag(kSeedTag).Set(); } return absl::OkStatus(); } @@ -143,9 +148,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream() && - cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") && - !cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) { - video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get(); + cc->Inputs().UsesTags() && cc->Inputs().HasTag(kVideoHeaderTag) && + !cc->Inputs().Tag(kVideoHeaderTag).IsEmpty()) { + video_header_ = cc->Inputs().Tag(kVideoHeaderTag).Get(); video_header_.frame_rate = frame_rate_; if (cc->Inputs().Get(input_data_id_).IsEmpty()) { return absl::OkStatus(); @@ -234,7 +239,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) { "ignored, because we are adding jitter."; } - const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get(); random_ = CreateSecureRandom(seed); if (random_ == nullptr) { return absl::InvalidArgumentError( @@ -357,7 +362,7 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Open( "ignored, because we are adding jitter."; } - const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get(); random_ = CreateSecureRandom(seed); if (random_ == nullptr) { return absl::InvalidArgumentError( @@ -504,7 +509,7 @@ absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) { "ignored, because we are adding jitter."; } - const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get(); random_ = CreateSecureRandom(seed); if (random_ == nullptr) { return absl::InvalidArgumentError( @@ -635,9 +640,9 @@ absl::Status NoJitterStrategy::Process(CalculatorContext* cc) { base_timestamp_ + 
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_); } - if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) { + if (cc->Outputs().UsesTags() && cc->Outputs().HasTag(kVideoHeaderTag)) { cc->Outputs() - .Tag("VIDEO_HEADER") + .Tag(kVideoHeaderTag) .Add(new VideoHeader(calculator_->video_header_), Timestamp::PreStream()); } diff --git a/mediapipe/calculators/core/packet_resampler_calculator_test.cc b/mediapipe/calculators/core/packet_resampler_calculator_test.cc index 191e1d842..f02da0d18 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator_test.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator_test.cc @@ -32,6 +32,12 @@ namespace mediapipe { using ::testing::ElementsAre; namespace { + +constexpr char kOptionsTag[] = "OPTIONS"; +constexpr char kSeedTag[] = "SEED"; +constexpr char kVideoHeaderTag[] = "VIDEO_HEADER"; +constexpr char kDataTag[] = "DATA"; + // A simple version of CalculatorRunner with built-in convenience // methods for setting inputs from a vector and checking outputs // against expected outputs (both timestamps and contents). @@ -464,7 +470,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) { )pb")); for (const int64 ts : {0, 5000, 10010, 15001, 19990}) { - runner.MutableInputs()->Tag("DATA").packets.push_back( + runner.MutableInputs()->Tag(kDataTag).packets.push_back( Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); } VideoHeader video_header_in; @@ -474,16 +480,16 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) { video_header_in.duration = 1.0; video_header_in.format = ImageFormat::SRGB; runner.MutableInputs() - ->Tag("VIDEO_HEADER") + ->Tag(kVideoHeaderTag) .packets.push_back( Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream())); MP_ASSERT_OK(runner.Run()); - ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size()); + ASSERT_EQ(1, runner.Outputs().Tag(kVideoHeaderTag).packets.size()); EXPECT_EQ(Timestamp::PreStream(), - runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp()); + runner.Outputs().Tag(kVideoHeaderTag).packets[0].Timestamp()); const VideoHeader& video_header_out = - runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get(); + runner.Outputs().Tag(kVideoHeaderTag).packets[0].Get(); EXPECT_EQ(video_header_in.width, video_header_out.width); EXPECT_EQ(video_header_in.height, video_header_out.height); EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate); @@ -725,7 +731,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) { [mediapipe.PacketResamplerCalculatorOptions.ext] { frame_rate: 30 })pb")); - runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options); + runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options); runner.SetInput({-222, 15000, 32000, 49999, 150000}); MP_ASSERT_OK(runner.Run()); EXPECT_EQ(6, runner.Outputs().Index(0).packets.size()); @@ -740,7 +746,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) { frame_rate: 30 base_timestamp: 0 })pb")); - runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options); + runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options); runner.SetInput({-222, 15000, 32000, 49999, 150000}); MP_ASSERT_OK(runner.Run()); diff --git a/mediapipe/calculators/core/packet_thinner_calculator.cc b/mediapipe/calculators/core/packet_thinner_calculator.cc index d3d391b61..1d94d886b 100644 --- a/mediapipe/calculators/core/packet_thinner_calculator.cc +++ b/mediapipe/calculators/core/packet_thinner_calculator.cc @@ -217,6 +217,7 @@ absl::Status 
PacketThinnerCalculator::Open(CalculatorContext* cc) { header->format = video_header.format; header->width = video_header.width; header->height = video_header.height; + header->duration = video_header.duration; header->frame_rate = new_frame_rate; cc->Outputs().Index(0).SetHeader(Adopt(header.release())); } else { diff --git a/mediapipe/calculators/core/packet_thinner_calculator_test.cc b/mediapipe/calculators/core/packet_thinner_calculator_test.cc index 86fcc00f9..3522488e7 100644 --- a/mediapipe/calculators/core/packet_thinner_calculator_test.cc +++ b/mediapipe/calculators/core/packet_thinner_calculator_test.cc @@ -29,6 +29,8 @@ namespace mediapipe { namespace { +constexpr char kPeriodTag[] = "PERIOD"; + // A simple version of CalculatorRunner with built-in convenience methods for // setting inputs from a vector and checking outputs against a vector of // expected outputs. @@ -121,7 +123,7 @@ TEST(PacketThinnerCalculatorTest, ASyncUniformStreamThinningTestBySidePacket) { SimpleRunner runner(node); runner.SetInput({2, 4, 6, 8, 10, 12, 14}); - runner.MutableSidePackets()->Tag("PERIOD") = MakePacket(5); + runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket(5); MP_ASSERT_OK(runner.Run()); const std::vector expected_timestamps = {2, 8, 14}; @@ -160,7 +162,7 @@ TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTestBySidePacket1) { SimpleRunner runner(node); runner.SetInput({2, 4, 6, 8, 10, 12, 14}); - runner.MutableSidePackets()->Tag("PERIOD") = MakePacket(5); + runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket(5); MP_ASSERT_OK(runner.Run()); const std::vector expected_timestamps = {2, 6, 10, 14}; diff --git a/mediapipe/calculators/core/previous_loopback_calculator_test.cc b/mediapipe/calculators/core/previous_loopback_calculator_test.cc index c9d431d1c..563417669 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator_test.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator_test.cc @@ -39,6 +39,8 @@ using ::testing::Pair; using ::testing::Value; namespace { +constexpr char kDisallowTag[] = "DISALLOW"; + // Returns the timestamp values for a vector of Packets. // TODO: puth this kind of test util in a common place. 
std::vector TimestampValues(const std::vector& packets) { @@ -702,14 +704,14 @@ class DroppingGateCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); - cc->Inputs().Tag("DISALLOW").Set(); + cc->Inputs().Tag(kDisallowTag).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { if (!cc->Inputs().Index(0).IsEmpty() && - !cc->Inputs().Tag("DISALLOW").Get()) { + !cc->Inputs().Tag(kDisallowTag).Get()) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator.cc b/mediapipe/calculators/core/quantize_float_vector_calculator.cc index e95509298..e1df66c1a 100644 --- a/mediapipe/calculators/core/quantize_float_vector_calculator.cc +++ b/mediapipe/calculators/core/quantize_float_vector_calculator.cc @@ -41,11 +41,14 @@ // } namespace mediapipe { +constexpr char kEncodedTag[] = "ENCODED"; +constexpr char kFloatVectorTag[] = "FLOAT_VECTOR"; + class QuantizeFloatVectorCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("FLOAT_VECTOR").Set>(); - cc->Outputs().Tag("ENCODED").Set(); + cc->Inputs().Tag(kFloatVectorTag).Set>(); + cc->Outputs().Tag(kEncodedTag).Set(); return absl::OkStatus(); } @@ -70,7 +73,7 @@ class QuantizeFloatVectorCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) final { const std::vector& float_vector = - cc->Inputs().Tag("FLOAT_VECTOR").Value().Get>(); + cc->Inputs().Tag(kFloatVectorTag).Value().Get>(); int feature_size = float_vector.size(); std::string encoded_features; encoded_features.reserve(feature_size); @@ -86,8 +89,10 @@ class QuantizeFloatVectorCalculator : public CalculatorBase { (old_value - min_quantized_value_) * (255.0 / range_)); encoded_features += encoded; } - cc->Outputs().Tag("ENCODED").AddPacket( - MakePacket(encoded_features).At(cc->InputTimestamp())); + cc->Outputs() + .Tag(kEncodedTag) + .AddPacket( + MakePacket(encoded_features).At(cc->InputTimestamp())); return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc b/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc index 8f23437b6..a3a410565 100644 --- a/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc +++ b/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc @@ -25,6 +25,9 @@ namespace mediapipe { +constexpr char kEncodedTag[] = "ENCODED"; +constexpr char kFloatVectorTag[] = "FLOAT_VECTOR"; + TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { CalculatorGraphConfig::Node node_config = ParseTextProtoOrDie(R"pb( @@ -40,7 +43,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { CalculatorRunner runner(node_config); std::vector empty_vector; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(empty_vector).At(Timestamp(0))); auto status = runner.Run(); @@ -67,7 +70,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) { CalculatorRunner runner(node_config); std::vector empty_vector; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(empty_vector).At(Timestamp(0))); auto status = runner.Run(); @@ -94,7 +97,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) { CalculatorRunner runner(node_config); std::vector 
empty_vector; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(empty_vector).At(Timestamp(0))); auto status = runner.Run(); @@ -121,11 +124,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) { CalculatorRunner runner(node_config); std::vector empty_vector; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(empty_vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); - const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + const std::vector& outputs = + runner.Outputs().Tag(kEncodedTag).packets; EXPECT_EQ(1, outputs.size()); EXPECT_TRUE(outputs[0].Get().empty()); EXPECT_EQ(Timestamp(0), outputs[0].Timestamp()); @@ -147,11 +151,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) { CalculatorRunner runner(node_config); std::vector vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f}; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); - const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + const std::vector& outputs = + runner.Outputs().Tag(kEncodedTag).packets; EXPECT_EQ(1, outputs.size()); const std::string& result = outputs[0].Get(); ASSERT_FALSE(result.empty()); @@ -185,11 +190,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) { CalculatorRunner runner(node_config); std::vector vector = {-65.0f, 65.0f}; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); - const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + const std::vector& outputs = + runner.Outputs().Tag(kEncodedTag).packets; EXPECT_EQ(1, outputs.size()); const std::string& result = outputs[0].Get(); ASSERT_FALSE(result.empty()); diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc index 277f83fe2..ef3cb9896 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -23,6 +23,9 @@ namespace mediapipe { +constexpr char kAllowTag[] = "ALLOW"; +constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT"; + // RealTimeFlowLimiterCalculator is used to limit the number of pipelined // processing operations in a section of the graph. 
// @@ -86,11 +89,11 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase { cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i))); } cc->Inputs().Get("FINISHED", 0).SetAny(); - if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { - cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set(); + if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) { + cc->InputSidePackets().Tag(kMaxInFlightTag).Set(); } - if (cc->Outputs().HasTag("ALLOW")) { - cc->Outputs().Tag("ALLOW").Set(); + if (cc->Outputs().HasTag(kAllowTag)) { + cc->Outputs().Tag(kAllowTag).Set(); } cc->SetInputStreamHandler("ImmediateInputStreamHandler"); @@ -101,8 +104,8 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase { absl::Status Open(CalculatorContext* cc) final { finished_id_ = cc->Inputs().GetId("FINISHED", 0); max_in_flight_ = 1; - if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { - max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get(); + if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) { + max_in_flight_ = cc->InputSidePackets().Tag(kMaxInFlightTag).Get(); } RET_CHECK_GE(max_in_flight_, 1); num_in_flight_ = 0; diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc index fe4785860..7fddd7fdf 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc @@ -33,6 +33,9 @@ namespace mediapipe { namespace { + +constexpr char kFinishedTag[] = "FINISHED"; + // A simple Semaphore for synchronizing test threads. class AtomicSemaphore { public: @@ -112,7 +115,7 @@ TEST(RealTimeFlowLimiterCalculator, BasicTest) { Timestamp timestamp = Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond); runner.MutableInputs() - ->Tag("FINISHED") + ->Tag(kFinishedTag) .packets.push_back(MakePacket(true).At(timestamp)); } diff --git a/mediapipe/calculators/core/sequence_shift_calculator_test.cc b/mediapipe/calculators/core/sequence_shift_calculator_test.cc index 23ad57225..8c749904c 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator_test.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator_test.cc @@ -22,6 +22,8 @@ namespace mediapipe { namespace { +constexpr char kPacketOffsetTag[] = "PACKET_OFFSET"; + // Adds packets containing integers equal to their original timestamp. void AddPackets(CalculatorRunner* runner) { for (int i = 0; i < 10; ++i) { @@ -111,7 +113,7 @@ TEST(SequenceShiftCalculatorTest, SidePacketOffset) { CalculatorRunner runner(node); AddPackets(&runner); - runner.MutableSidePackets()->Tag("PACKET_OFFSET") = Adopt(new int(-2)); + runner.MutableSidePackets()->Tag(kPacketOffsetTag) = Adopt(new int(-2)); MP_ASSERT_OK(runner.Run()); const std::vector& input_packets = runner.MutableInputs()->Index(0).packets; diff --git a/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc b/mediapipe/calculators/core/split_landmarks_calculator.cc similarity index 75% rename from mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc rename to mediapipe/calculators/core/split_landmarks_calculator.cc index d57cebe9c..5bc876bf6 100644 --- a/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc +++ b/mediapipe/calculators/core/split_landmarks_calculator.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
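RealTimeFlowLimiterCalculator above reads its in-flight limit from the optional MAX_IN_FLIGHT input side packet and defaults to 1 when the side packet is absent. A minimal sketch of supplying that side packet when starting a graph, assuming the graph config binds it to the name "max_in_flight" (the name, the value, and the helper function are illustrative):

#include "mediapipe/framework/calculator_framework.h"

// Allows up to two pipelined frames through the limited section of the graph.
absl::Status StartWithFlowLimit(mediapipe::CalculatorGraph* graph) {
  return graph->StartRun({{"max_in_flight", mediapipe::MakePacket<int>(2)}});
}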
-#ifndef MEDIAPIPE_CALCULATORS_CORE_SPLIT_NORMALIZED_LANDMARK_LIST_CALCULATOR_H_ // NOLINT -#define MEDIAPIPE_CALCULATORS_CORE_SPLIT_NORMALIZED_LANDMARK_LIST_CALCULATOR_H_ // NOLINT +#ifndef MEDIAPIPE_CALCULATORS_CORE_SPLIT_LANDMARKS_CALCULATOR_H_ // NOLINT +#define MEDIAPIPE_CALCULATORS_CORE_SPLIT_LANDMARKS_CALCULATOR_H_ // NOLINT #include "mediapipe/calculators/core/split_vector_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -24,29 +24,30 @@ namespace mediapipe { -// Splits an input packet with NormalizedLandmarkList into -// multiple NormalizedLandmarkList output packets using the [begin, end) ranges +// Splits an input packet with LandmarkListType into +// multiple LandmarkListType output packets using the [begin, end) ranges // specified in SplitVectorCalculatorOptions. If the option "element_only" is // set to true, all ranges should be of size 1 and all outputs will be elements -// of type NormalizedLandmark. If "element_only" is false, ranges can be -// non-zero in size and all outputs will be of type NormalizedLandmarkList. +// of type LandmarkType. If "element_only" is false, ranges can be +// non-zero in size and all outputs will be of type LandmarkListType. // If the option "combine_outputs" is set to true, only one output stream can be // specified and all ranges of elements will be combined into one -// NormalizedLandmarkList. -class SplitNormalizedLandmarkListCalculator : public CalculatorBase { +// LandmarkListType. +template +class SplitLandmarksCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == 1); RET_CHECK(cc->Outputs().NumEntries() != 0); - cc->Inputs().Index(0).Set(); + cc->Inputs().Index(0).Set(); const auto& options = cc->Options<::mediapipe::SplitVectorCalculatorOptions>(); if (options.combine_outputs()) { RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); - cc->Outputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); for (int i = 0; i < options.ranges_size() - 1; ++i) { for (int j = i + 1; j < options.ranges_size(); ++j) { const auto& range_0 = options.ranges(i); @@ -81,9 +82,9 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { return absl::InvalidArgumentError( "Since element_only is true, all ranges should be of size 1."); } - cc->Outputs().Index(i).Set(); + cc->Outputs().Index(i).Set(); } else { - cc->Outputs().Index(i).Set(); + cc->Outputs().Index(i).Set(); } } } @@ -110,40 +111,39 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { } absl::Status Process(CalculatorContext* cc) override { - const NormalizedLandmarkList& input = - cc->Inputs().Index(0).Get(); + const LandmarkListType& input = + cc->Inputs().Index(0).Get(); RET_CHECK_GE(input.landmark_size(), max_range_end_) << "Max range end " << max_range_end_ << " exceeds landmarks size " << input.landmark_size(); if (combine_outputs_) { - NormalizedLandmarkList output; + LandmarkListType output; for (int i = 0; i < ranges_.size(); ++i) { for (int j = ranges_[i].first; j < ranges_[i].second; ++j) { - const NormalizedLandmark& input_landmark = input.landmark(j); + const LandmarkType& input_landmark = input.landmark(j); *output.add_landmark() = input_landmark; } } RET_CHECK_EQ(output.landmark_size(), total_elements_); cc->Outputs().Index(0).AddPacket( - MakePacket(output).At(cc->InputTimestamp())); + MakePacket(output).At(cc->InputTimestamp())); } else { if (element_only_) { for (int i = 0; i < ranges_.size(); ++i) { cc->Outputs().Index(i).AddPacket( - 
MakePacket(input.landmark(ranges_[i].first)) + MakePacket(input.landmark(ranges_[i].first)) .At(cc->InputTimestamp())); } } else { for (int i = 0; i < ranges_.size(); ++i) { - NormalizedLandmarkList output; + LandmarkListType output; for (int j = ranges_[i].first; j < ranges_[i].second; ++j) { - const NormalizedLandmark& input_landmark = input.landmark(j); + const LandmarkType& input_landmark = input.landmark(j); *output.add_landmark() = input_landmark; } cc->Outputs().Index(i).AddPacket( - MakePacket(output).At( - cc->InputTimestamp())); + MakePacket(output).At(cc->InputTimestamp())); } } } @@ -159,9 +159,15 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { bool combine_outputs_ = false; }; +typedef SplitLandmarksCalculator + SplitNormalizedLandmarkListCalculator; REGISTER_CALCULATOR(SplitNormalizedLandmarkListCalculator); +typedef SplitLandmarksCalculator + SplitLandmarkListCalculator; +REGISTER_CALCULATOR(SplitLandmarkListCalculator); + } // namespace mediapipe // NOLINTNEXTLINE -#endif // MEDIAPIPE_CALCULATORS_CORE_SPLIT_NORMALIZED_LANDMARK_LIST_CALCULATOR_H_ +#endif // MEDIAPIPE_CALCULATORS_CORE_SPLIT_LANDMARKS_CALCULATOR_H_ diff --git a/mediapipe/calculators/core/split_normalized_landmark_list_calculator_test.cc b/mediapipe/calculators/core/split_landmarks_calculator_test.cc similarity index 100% rename from mediapipe/calculators/core/split_normalized_landmark_list_calculator_test.cc rename to mediapipe/calculators/core/split_landmarks_calculator_test.cc diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 39f81c046..0bbfadd05 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -80,6 +80,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "segmentation_smoothing_calculator_proto", + srcs = ["segmentation_smoothing_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + cc_library( name = "color_convert_calculator", srcs = ["color_convert_calculator.cc"], @@ -602,3 +612,187 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", ], ) + +cc_library( + name = "segmentation_smoothing_calculator", + srcs = ["segmentation_smoothing_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":segmentation_smoothing_calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gl_quad_renderer", + "//mediapipe/gpu:shader_util", + ], + }), + alwayslink = 1, +) + +cc_test( + name = "segmentation_smoothing_calculator_test", + srcs = ["segmentation_smoothing_calculator_test.cc"], + deps = [ + ":image_clone_calculator", + ":image_clone_calculator_cc_proto", + ":segmentation_smoothing_calculator", + ":segmentation_smoothing_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + 
"//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + ], +) + +cc_library( + name = "affine_transformation", + hdrs = ["affine_transformation.h"], + deps = ["@com_google_absl//absl/status:statusor"], +) + +cc_library( + name = "affine_transformation_runner_gl", + srcs = ["affine_transformation_runner_gl.cc"], + hdrs = ["affine_transformation_runner_gl.h"], + deps = [ + ":affine_transformation", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_origin_cc_proto", + "//mediapipe/gpu:shader_util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@eigen_archive//:eigen3", + ], +) + +cc_library( + name = "affine_transformation_runner_opencv", + srcs = ["affine_transformation_runner_opencv.cc"], + hdrs = ["affine_transformation_runner_opencv.h"], + deps = [ + ":affine_transformation", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@eigen_archive//:eigen3", + ], +) + +mediapipe_proto_library( + name = "warp_affine_calculator_proto", + srcs = ["warp_affine_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/gpu:gpu_origin_proto", + ], +) + +cc_library( + name = "warp_affine_calculator", + srcs = ["warp_affine_calculator.cc"], + hdrs = ["warp_affine_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + ":affine_transformation", + ":affine_transformation_runner_opencv", + ":warp_affine_calculator_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + ":affine_transformation_runner_gl", + ], + }), + alwayslink = 1, +) + +cc_test( + name = "warp_affine_calculator_test", + srcs = ["warp_affine_calculator_test.cc"], + data = [ + "//mediapipe/calculators/tensor:testdata/image_to_tensor/input.jpg", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_keep_aspect.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_keep_aspect_border_zero.png", + 
"//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png", + ], + tags = ["desktop_only_test"], + deps = [ + ":affine_transformation", + ":warp_affine_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_converter", + "//mediapipe/calculators/tensor:image_to_tensor_utils", + "//mediapipe/calculators/util:from_image_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", + "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/calculators/image/affine_transformation.h b/mediapipe/calculators/image/affine_transformation.h new file mode 100644 index 000000000..40793e7a1 --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation.h @@ -0,0 +1,55 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_H_ +#define MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_H_ + +#include + +#include "absl/status/statusor.h" + +namespace mediapipe { + +class AffineTransformation { + public: + // Pixel extrapolation method. + // When converting image to tensor it may happen that tensor needs to read + // pixels outside image boundaries. 
Border mode helps to specify how such + // pixels will be calculated. + enum class BorderMode { kZero, kReplicate }; + + struct Size { + int width; + int height; + }; + + template + class Runner { + public: + virtual ~Runner() = default; + + // Transforms input into output using @matrix as following: + // output(x, y) = input(matrix[0] * x + matrix[1] * y + matrix[3], + // matrix[4] * x + matrix[5] * y + matrix[7]) + // where x and y ranges are defined by @output_size. + virtual absl::StatusOr Run(const InputT& input, + const std::array& matrix, + const Size& output_size, + BorderMode border_mode) = 0; + }; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_H_ diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.cc b/mediapipe/calculators/image/affine_transformation_runner_gl.cc new file mode 100644 index 000000000..c38fc8e07 --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.cc @@ -0,0 +1,354 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/affine_transformation_runner_gl.h" + +#include +#include + +#include "Eigen/Core" +#include "Eigen/Geometry" +#include "Eigen/LU" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_origin.pb.h" +#include "mediapipe/gpu/shader_util.h" + +namespace mediapipe { + +namespace { + +using mediapipe::GlCalculatorHelper; +using mediapipe::GlhCreateProgram; +using mediapipe::GlTexture; +using mediapipe::GpuBuffer; +using mediapipe::GpuOrigin; + +bool IsMatrixVerticalFlipNeeded(GpuOrigin::Mode gpu_origin) { + switch (gpu_origin) { + case GpuOrigin::DEFAULT: + case GpuOrigin::CONVENTIONAL: +#ifdef __APPLE__ + return false; +#else + return true; +#endif // __APPLE__ + case GpuOrigin::TOP_LEFT: + return false; + } +} + +#ifdef __APPLE__ +#define GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED 0 +#else +#define GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED 1 +#endif // __APPLE__ + +bool IsGlClampToBorderSupported(const mediapipe::GlContext& gl_context) { + return gl_context.gl_major_version() > 3 || + (gl_context.gl_major_version() == 3 && + gl_context.gl_minor_version() >= 2); +} + +constexpr int kAttribVertex = 0; +constexpr int kAttribTexturePosition = 1; +constexpr int kNumAttributes = 2; + +class GlTextureWarpAffineRunner + : public AffineTransformation::Runner> { + public: + GlTextureWarpAffineRunner(std::shared_ptr gl_helper, + GpuOrigin::Mode gpu_origin) + : gl_helper_(gl_helper), gpu_origin_(gpu_origin) {} + absl::Status Init() { + return gl_helper_->RunInGlContext([this]() -> absl::Status { + const GLint 
attr_location[kNumAttributes] = { + kAttribVertex, + kAttribTexturePosition, + }; + const GLchar* attr_name[kNumAttributes] = { + "position", + "texture_coordinate", + }; + + constexpr GLchar kVertShader[] = R"( + in vec4 position; + in mediump vec4 texture_coordinate; + out mediump vec2 sample_coordinate; + uniform mat4 transform_matrix; + + void main() { + gl_Position = position; + vec4 tc = transform_matrix * texture_coordinate; + sample_coordinate = tc.xy; + } + )"; + + constexpr GLchar kFragShader[] = R"( + DEFAULT_PRECISION(mediump, float) + in vec2 sample_coordinate; + uniform sampler2D input_texture; + + #ifdef GL_ES + #define fragColor gl_FragColor + #else + out vec4 fragColor; + #endif // defined(GL_ES); + + void main() { + vec4 color = texture2D(input_texture, sample_coordinate); + #ifdef CUSTOM_ZERO_BORDER_MODE + float out_of_bounds = + float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || + sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0); + color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); + #endif // defined(CUSTOM_ZERO_BORDER_MODE) + fragColor = color; + } + )"; + + // Create program and set parameters. + auto create_fn = [&](const std::string& vs, + const std::string& fs) -> absl::StatusOr { + GLuint program = 0; + GlhCreateProgram(vs.c_str(), fs.c_str(), kNumAttributes, &attr_name[0], + attr_location, &program); + + RET_CHECK(program) << "Problem initializing warp affine program."; + glUseProgram(program); + glUniform1i(glGetUniformLocation(program, "input_texture"), 1); + GLint matrix_id = glGetUniformLocation(program, "transform_matrix"); + return Program{.id = program, .matrix_id = matrix_id}; + }; + + const std::string vert_src = + absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader); + + const std::string frag_src = absl::StrCat( + mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader); + + ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src)); + + auto create_custom_zero_fn = [&]() -> absl::StatusOr { + std::string custom_zero_border_mode_def = R"( + #define CUSTOM_ZERO_BORDER_MODE + )"; + const std::string frag_custom_zero_src = + absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, + custom_zero_border_mode_def, kFragShader); + return create_fn(vert_src, frag_custom_zero_src); + }; +#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + if (!IsGlClampToBorderSupported(gl_helper_->GetGlContext())) { + ASSIGN_OR_RETURN(program_custom_zero_, create_custom_zero_fn()); + } +#else + ASSIGN_OR_RETURN(program_custom_zero_, create_custom_zero_fn()); +#endif // GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + + glGenFramebuffers(1, &framebuffer_); + + // vertex storage + glGenBuffers(2, vbo_); + glGenVertexArrays(1, &vao_); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[0]); + glBufferData(GL_ARRAY_BUFFER, sizeof(mediapipe::kBasicSquareVertices), + mediapipe::kBasicSquareVertices, GL_STATIC_DRAW); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[1]); + glBufferData(GL_ARRAY_BUFFER, sizeof(mediapipe::kBasicTextureVertices), + mediapipe::kBasicTextureVertices, GL_STATIC_DRAW); + + glBindBuffer(GL_ARRAY_BUFFER, 0); + + return absl::OkStatus(); + }); + } + + absl::StatusOr> Run( + const GpuBuffer& input, const std::array& matrix, + const AffineTransformation::Size& size, + AffineTransformation::BorderMode border_mode) override { + std::unique_ptr gpu_buffer; + MP_RETURN_IF_ERROR( + gl_helper_->RunInGlContext([this, &input, &matrix, &size, &border_mode, + &gpu_buffer]() -> absl::Status { + auto input_texture = 
gl_helper_->CreateSourceTexture(input); + auto output_texture = gl_helper_->CreateDestinationTexture( + size.width, size.height, input.format()); + + MP_RETURN_IF_ERROR( + RunInternal(input_texture, matrix, border_mode, &output_texture)); + gpu_buffer = output_texture.GetFrame(); + return absl::OkStatus(); + })); + + return gpu_buffer; + } + + absl::Status RunInternal(const GlTexture& texture, + const std::array& matrix, + AffineTransformation::BorderMode border_mode, + GlTexture* output) { + glDisable(GL_DEPTH_TEST); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glViewport(0, 0, output->width(), output->height()); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, output->name()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + output->name(), 0); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(texture.target(), texture.name()); + + // a) Filtering. + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + + // b) Clamping. + std::optional program = program_; + switch (border_mode) { + case AffineTransformation::BorderMode::kReplicate: { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + break; + } + case AffineTransformation::BorderMode::kZero: { +#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + if (program_custom_zero_) { + program = program_custom_zero_; + } else { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_BORDER); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_BORDER); + glTexParameterfv(GL_TEXTURE_2D, GL_TEXTURE_BORDER_COLOR, + std::array{0.0f, 0.0f, 0.0f, 0.0f}.data()); + } +#else + RET_CHECK(program_custom_zero_) + << "Program must have been initialized."; + program = program_custom_zero_; +#endif // GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + break; + } + } + glUseProgram(program->id); + + Eigen::Matrix eigen_mat(matrix.data()); + if (IsMatrixVerticalFlipNeeded(gpu_origin_)) { + // @matrix describes affine transformation in terms of TOP LEFT origin, so + // in some cases/on some platforms an extra flipping should be done before + // and after. + const Eigen::Matrix flip_y( + {{1.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, -1.0f, 0.0f, 1.0f}, + {0.0f, 0.0f, 1.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 1.0f}}); + eigen_mat = flip_y * eigen_mat * flip_y; + } + + // If GL context is ES2, then GL_FALSE must be used for 'transpose' + // GLboolean in glUniformMatrix4fv, or else INVALID_VALUE error is reported. + // Hence, transposing the matrix and always passing transposed. + eigen_mat.transposeInPlace(); + glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data()); + + // vao + glBindVertexArray(vao_); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[0]); + glEnableVertexAttribArray(kAttribVertex); + glVertexAttribPointer(kAttribVertex, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[1]); + glEnableVertexAttribArray(kAttribTexturePosition); + glVertexAttribPointer(kAttribTexturePosition, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // Resetting to MediaPipe texture param defaults. 
+ glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + + glDisableVertexAttribArray(kAttribVertex); + glDisableVertexAttribArray(kAttribTexturePosition); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, 0); + + return absl::OkStatus(); + } + + ~GlTextureWarpAffineRunner() override { + gl_helper_->RunInGlContext([this]() { + // Release OpenGL resources. + if (framebuffer_ != 0) glDeleteFramebuffers(1, &framebuffer_); + if (program_.id != 0) glDeleteProgram(program_.id); + if (program_custom_zero_ && program_custom_zero_->id != 0) { + glDeleteProgram(program_custom_zero_->id); + } + if (vao_ != 0) glDeleteVertexArrays(1, &vao_); + glDeleteBuffers(2, vbo_); + }); + } + + private: + struct Program { + GLuint id; + GLint matrix_id; + }; + std::shared_ptr gl_helper_; + GpuOrigin::Mode gpu_origin_; + GLuint vao_ = 0; + GLuint vbo_[2] = {0, 0}; + Program program_; + std::optional program_custom_zero_; + GLuint framebuffer_ = 0; +}; + +#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + +} // namespace + +absl::StatusOr>>> +CreateAffineTransformationGlRunner( + std::shared_ptr gl_helper, GpuOrigin::Mode gpu_origin) { + auto runner = + absl::make_unique(gl_helper, gpu_origin); + MP_RETURN_IF_ERROR(runner->Init()); + return runner; +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.h b/mediapipe/calculators/image/affine_transformation_runner_gl.h new file mode 100644 index 000000000..677e0720d --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.h @@ -0,0 +1,36 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_GL_H_ +#define MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_GL_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_origin.pb.h" + +namespace mediapipe { + +absl::StatusOr>>> +CreateAffineTransformationGlRunner( + std::shared_ptr gl_helper, + mediapipe::GpuOrigin::Mode gpu_origin); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_GL_H_ diff --git a/mediapipe/calculators/image/affine_transformation_runner_opencv.cc b/mediapipe/calculators/image/affine_transformation_runner_opencv.cc new file mode 100644 index 000000000..46026a987 --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation_runner_opencv.cc @@ -0,0 +1,160 @@ +// Copyright 2021 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/affine_transformation_runner_opencv.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +cv::BorderTypes GetBorderModeForOpenCv( + AffineTransformation::BorderMode border_mode) { + switch (border_mode) { + case AffineTransformation::BorderMode::kZero: + return cv::BORDER_CONSTANT; + case AffineTransformation::BorderMode::kReplicate: + return cv::BORDER_REPLICATE; + } +} + +class OpenCvRunner + : public AffineTransformation::Runner { + public: + absl::StatusOr Run( + const ImageFrame& input, const std::array& matrix, + const AffineTransformation::Size& size, + AffineTransformation::BorderMode border_mode) override { + // OpenCV warpAffine works in absolute coordinates, so the transfom (which + // accepts and produces relative coordinates) should be adjusted to first + // normalize coordinates and then scale them. + // clang-format off + cv::Matx44f normalize_dst_coordinate({ + 1.0f / size.width, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f / size.height, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}); + cv::Matx44f scale_src_coordinate({ + 1.0f * input.Width(), 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f * input.Height(), 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}); + // clang-format on + cv::Matx44f adjust_dst_coordinate; + cv::Matx44f adjust_src_coordinate; + // TODO: update to always use accurate implementation. + constexpr bool kOpenCvCompatibility = true; + if (kOpenCvCompatibility) { + adjust_dst_coordinate = normalize_dst_coordinate; + adjust_src_coordinate = scale_src_coordinate; + } else { + // To do an accurate affine image transformation and make "on-cpu" and + // "on-gpu" calculations aligned - extra offset is required to select + // correct pixels. + // + // Each destination pixel corresponds to some pixels region from source + // image.(In case of downscaling there can be more than one pixel.) The + // offset for x and y is calculated in the way, so pixel in the middle of + // the region is selected. + // + // For simplicity sake, let's consider downscaling from 100x50 to 10x10 + // without a rotation: + // 1. Each destination pixel corresponds to 10x5 region + // X range: [0, .. , 9] + // Y range: [0, .. , 4] + // 2. Considering we have __discrete__ pixels, the center of the region is + // between (4, 2) and (5, 2) pixels, let's assume it's a "pixel" + // (4.5, 2). + // 3. When using the above as an offset for every pixel select while + // downscaling, resulting pixels are: + // (4.5, 2), (14.5, 2), .. , (94.5, 2) + // (4.5, 7), (14.5, 7), .. 
, (94.5, 7) + // .. + // (4.5, 47), (14.5, 47), .., (94.5, 47) + // instead of: + // (0, 0), (10, 0), .. , (90, 0) + // (0, 5), (10, 7), .. , (90, 5) + // .. + // (0, 45), (10, 45), .., (90, 45) + // The latter looks shifted. + // + // Offsets are needed, so that __discrete__ pixel at (0, 0) corresponds to + // the same pixel as would __non discrete__ pixel at (0.5, 0.5). Hence, + // transformation matrix should shift coordinates by (0.5, 0.5) as the + // very first step. + // + // Due to the above shift, transformed coordinates would be valid for + // float coordinates where pixel (0, 0) spans [0.0, 1.0) x [0.0, 1.0). + // T0 make it valid for __discrete__ pixels, transformation matrix should + // shift coordinate by (-0.5f, -0.5f) as the very last step. (E.g. if we + // get (0.5f, 0.5f), then it's (0, 0) __discrete__ pixel.) + // clang-format off + cv::Matx44f shift_dst({1.0f, 0.0f, 0.0f, 0.5f, + 0.0f, 1.0f, 0.0f, 0.5f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}); + cv::Matx44f shift_src({1.0f, 0.0f, 0.0f, -0.5f, + 0.0f, 1.0f, 0.0f, -0.5f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}); + // clang-format on + adjust_dst_coordinate = normalize_dst_coordinate * shift_dst; + adjust_src_coordinate = shift_src * scale_src_coordinate; + } + + cv::Matx44f transform(matrix.data()); + cv::Matx44f transform_absolute = + adjust_src_coordinate * transform * adjust_dst_coordinate; + + cv::Mat in_mat = formats::MatView(&input); + + cv::Mat cv_affine_transform(2, 3, CV_32F); + cv_affine_transform.at(0, 0) = transform_absolute.val[0]; + cv_affine_transform.at(0, 1) = transform_absolute.val[1]; + cv_affine_transform.at(0, 2) = transform_absolute.val[3]; + cv_affine_transform.at(1, 0) = transform_absolute.val[4]; + cv_affine_transform.at(1, 1) = transform_absolute.val[5]; + cv_affine_transform.at(1, 2) = transform_absolute.val[7]; + + ImageFrame out_image(input.Format(), size.width, size.height); + cv::Mat out_mat = formats::MatView(&out_image); + + cv::warpAffine(in_mat, out_mat, cv_affine_transform, + cv::Size(out_mat.cols, out_mat.rows), + /*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP, + GetBorderModeForOpenCv(border_mode)); + + return out_image; + } +}; + +} // namespace + +absl::StatusOr< + std::unique_ptr>> +CreateAffineTransformationOpenCvRunner() { + return absl::make_unique(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/affine_transformation_runner_opencv.h b/mediapipe/calculators/image/affine_transformation_runner_opencv.h new file mode 100644 index 000000000..200281c95 --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation_runner_opencv.h @@ -0,0 +1,32 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
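+
+// CPU implementation of AffineTransformation::Runner backed by
+// cv::warpAffine (see affine_transformation_runner_opencv.cc).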
+ +#ifndef MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_OPENCV_H_ +#define MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_OPENCV_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/framework/formats/image_frame.h" + +namespace mediapipe { + +absl::StatusOr< + std::unique_ptr>> +CreateAffineTransformationOpenCvRunner(); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_OPENCV_H_ diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc index 3d878bffc..6bb43dc00 100644 --- a/mediapipe/calculators/image/bilateral_filter_calculator.cc +++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc @@ -240,7 +240,7 @@ absl::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) { auto input_mat = mediapipe::formats::MatView(&input_frame); // Only 1 or 3 channel images supported by OpenCV. - if ((input_mat.channels() == 1 || input_mat.channels() == 3)) { + if (!(input_mat.channels() == 1 || input_mat.channels() == 3)) { return absl::InternalError( "CPU filtering supports only 1 or 3 channel input images."); } diff --git a/mediapipe/calculators/image/image_clone_calculator.cc b/mediapipe/calculators/image/image_clone_calculator.cc index 107c42b92..1e76848b1 100644 --- a/mediapipe/calculators/image/image_clone_calculator.cc +++ b/mediapipe/calculators/image/image_clone_calculator.cc @@ -36,7 +36,7 @@ using GpuBuffer = mediapipe::GpuBuffer; // stored on the target storage (CPU vs GPU) specified in the calculator option. // // The clone shares ownership of the input pixel data on the existing storage. -// If the target storage is diffrent from the existing one, then the data is +// If the target storage is different from the existing one, then the data is // further copied there. // // Example usage: diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index 60873ae9f..ee1bcdf96 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -102,6 +102,10 @@ mediapipe::ScaleMode_Mode ParseScaleMode( // IMAGE: ImageFrame representing the input image. // IMAGE_GPU: GpuBuffer representing the input image. // +// OUTPUT_DIMENSIONS (optional): The output width and height in pixels as +// pair. If set, it will override corresponding field in calculator +// options and input side packet. +// // ROTATION_DEGREES (optional): The counterclockwise rotation angle in // degrees. This allows different rotation angles for different frames. It has // to be a multiple of 90 degrees. 
If provided, it overrides the @@ -221,6 +225,10 @@ absl::Status ImageTransformationCalculator::GetContract( } #endif // !MEDIAPIPE_DISABLE_GPU + if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) { + cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set>(); + } + if (cc->Inputs().HasTag("ROTATION_DEGREES")) { cc->Inputs().Tag("ROTATION_DEGREES").Set(); } @@ -329,6 +337,13 @@ absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) { !cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) { flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get(); } + if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS") && + !cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) { + const auto& image_size = + cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get>(); + output_width_ = image_size.first; + output_height_ = image_size.second; + } if (use_gpu_) { #if !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/calculators/image/recolor_calculator.cc b/mediapipe/calculators/image/recolor_calculator.cc index 03d0c3c7a..062fb2cb3 100644 --- a/mediapipe/calculators/image/recolor_calculator.cc +++ b/mediapipe/calculators/image/recolor_calculator.cc @@ -37,6 +37,22 @@ constexpr char kImageFrameTag[] = "IMAGE"; constexpr char kMaskCpuTag[] = "MASK"; constexpr char kGpuBufferTag[] = "IMAGE_GPU"; constexpr char kMaskGpuTag[] = "MASK_GPU"; + +inline cv::Vec3b Blend(const cv::Vec3b& color1, const cv::Vec3b& color2, + float weight, int invert_mask, + int adjust_with_luminance) { + weight = (1 - invert_mask) * weight + invert_mask * (1.0f - weight); + + float luminance = + (1 - adjust_with_luminance) * 1.0f + + adjust_with_luminance * + (color1[0] * 0.299 + color1[1] * 0.587 + color1[2] * 0.114) / 255; + + float mix_value = weight * luminance; + + return color1 * (1.0 - mix_value) + color2 * mix_value; +} + } // namespace namespace mediapipe { @@ -44,15 +60,14 @@ namespace mediapipe { // A calculator to recolor a masked area of an image to a specified color. // // A mask image is used to specify where to overlay a user defined color. -// The luminance of the input image is used to adjust the blending weight, -// to help preserve image textures. // // Inputs: // One of the following IMAGE tags: -// IMAGE: An ImageFrame input image, RGB or RGBA. +// IMAGE: An ImageFrame input image in ImageFormat::SRGB. // IMAGE_GPU: A GpuBuffer input image, RGBA. // One of the following MASK tags: -// MASK: An ImageFrame input mask, Gray, RGB or RGBA. +// MASK: An ImageFrame input mask in ImageFormat::GRAY8, SRGB, SRGBA, or +// VEC32F1 // MASK_GPU: A GpuBuffer input mask, RGBA. // Output: // One of the following IMAGE tags: @@ -98,10 +113,12 @@ class RecolorCalculator : public CalculatorBase { void GlRender(); bool initialized_ = false; - std::vector color_; + std::vector color_; mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_; bool use_gpu_ = false; + bool invert_mask_ = false; + bool adjust_with_luminance_ = false; #if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; @@ -233,11 +250,15 @@ absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { } cv::Mat mask_full; cv::resize(mask_mat, mask_full, input_mat.size()); + const cv::Vec3b recolor = {color_[0], color_[1], color_[2]}; auto output_img = absl::make_unique( input_img.Format(), input_mat.cols, input_mat.rows); cv::Mat output_mat = mediapipe::formats::MatView(output_img.get()); + const int invert_mask = invert_mask_ ? 1 : 0; + const int adjust_with_luminance = adjust_with_luminance_ ? 
1 : 0; + // From GPU shader: /* vec4 weight = texture2D(mask, sample_coordinate); @@ -249,18 +270,23 @@ absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { fragColor = mix(color1, color2, mix_value); */ - for (int i = 0; i < output_mat.rows; ++i) { - for (int j = 0; j < output_mat.cols; ++j) { - float weight = mask_full.at(i, j) * (1.0 / 255.0); - cv::Vec3f color1 = input_mat.at(i, j); - cv::Vec3f color2 = {color_[0], color_[1], color_[2]}; - - float luminance = - (color1[0] * 0.299 + color1[1] * 0.587 + color1[2] * 0.114) / 255; - float mix_value = weight * luminance; - - cv::Vec3b mix_color = color1 * (1.0 - mix_value) + color2 * mix_value; - output_mat.at(i, j) = mix_color; + if (mask_img.Format() == ImageFormat::VEC32F1) { + for (int i = 0; i < output_mat.rows; ++i) { + for (int j = 0; j < output_mat.cols; ++j) { + const float weight = mask_full.at(i, j); + output_mat.at(i, j) = + Blend(input_mat.at(i, j), recolor, weight, invert_mask, + adjust_with_luminance); + } + } + } else { + for (int i = 0; i < output_mat.rows; ++i) { + for (int j = 0; j < output_mat.cols; ++j) { + const float weight = mask_full.at(i, j) * (1.0 / 255.0); + output_mat.at(i, j) = + Blend(input_mat.at(i, j), recolor, weight, invert_mask, + adjust_with_luminance); + } } } @@ -385,6 +411,9 @@ absl::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { color_.push_back(options.color().g()); color_.push_back(options.color().b()); + invert_mask_ = options.invert_mask(); + adjust_with_luminance_ = options.adjust_with_luminance(); + return absl::OkStatus(); } @@ -435,13 +464,20 @@ absl::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { uniform sampler2D frame; uniform sampler2D mask; uniform vec3 recolor; + uniform float invert_mask; + uniform float adjust_with_luminance; void main() { vec4 weight = texture2D(mask, sample_coordinate); vec4 color1 = texture2D(frame, sample_coordinate); vec4 color2 = vec4(recolor, 1.0); - float luminance = dot(color1.rgb, vec3(0.299, 0.587, 0.114)); + weight = mix(weight, 1.0 - weight, invert_mask); + + float luminance = mix(1.0, + dot(color1.rgb, vec3(0.299, 0.587, 0.114)), + adjust_with_luminance); + float mix_value = weight.MASK_COMPONENT * luminance; fragColor = mix(color1, color2, mix_value); @@ -458,6 +494,10 @@ absl::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { glUniform1i(glGetUniformLocation(program_, "mask"), 2); glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0] / 255.0, color_[1] / 255.0, color_[2] / 255.0); + glUniform1f(glGetUniformLocation(program_, "invert_mask"), + invert_mask_ ? 1.0f : 0.0f); + glUniform1f(glGetUniformLocation(program_, "adjust_with_luminance"), + adjust_with_luminance_ ? 1.0f : 0.0f); #endif // !MEDIAPIPE_DISABLE_GPU return absl::OkStatus(); diff --git a/mediapipe/calculators/image/recolor_calculator.proto b/mediapipe/calculators/image/recolor_calculator.proto index 76326c079..abbf0849d 100644 --- a/mediapipe/calculators/image/recolor_calculator.proto +++ b/mediapipe/calculators/image/recolor_calculator.proto @@ -36,4 +36,11 @@ message RecolorCalculatorOptions { // Color to blend into input image where mask is > 0. // The blending is based on the input image luminosity. optional Color color = 2; + + // Swap the meaning of mask values for foreground/background. + optional bool invert_mask = 3 [default = false]; + + // Whether to use the luminance of the input image to further adjust the + // blending weight, to help preserve image textures. 
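+  // When false, the blending weight is taken from the mask value alone
+  // (the luminance factor is effectively 1.0).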
+ optional bool adjust_with_luminance = 4 [default = true]; } diff --git a/mediapipe/calculators/image/scale_image_calculator.cc b/mediapipe/calculators/image/scale_image_calculator.cc index 575268da5..0669f5322 100644 --- a/mediapipe/calculators/image/scale_image_calculator.cc +++ b/mediapipe/calculators/image/scale_image_calculator.cc @@ -262,6 +262,7 @@ absl::Status ScaleImageCalculator::InitializeFrameInfo(CalculatorContext* cc) { scale_image::FindOutputDimensions(crop_width_, crop_height_, // options_.target_width(), // options_.target_height(), // + options_.target_max_area(), // options_.preserve_aspect_ratio(), // options_.scale_to_multiple_of(), // &output_width_, &output_height_)); diff --git a/mediapipe/calculators/image/scale_image_calculator.proto b/mediapipe/calculators/image/scale_image_calculator.proto index e51ccafaa..2b7572d56 100644 --- a/mediapipe/calculators/image/scale_image_calculator.proto +++ b/mediapipe/calculators/image/scale_image_calculator.proto @@ -28,6 +28,11 @@ message ScaleImageCalculatorOptions { optional int32 target_width = 1; optional int32 target_height = 2; + // If set, then automatically calculates a target_width and target_height that + // has an area below the target max area. Aspect ratio preservation cannot be + // disabled. + optional int32 target_max_area = 15; + // If true, the image is scaled up or down proportionally so that it // fits inside the box represented by target_width and target_height. // Otherwise it is scaled to fit target_width and target_height diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc index 738e83da0..490d0336a 100644 --- a/mediapipe/calculators/image/scale_image_utils.cc +++ b/mediapipe/calculators/image/scale_image_utils.cc @@ -92,12 +92,21 @@ absl::Status FindOutputDimensions(int input_width, // int input_height, // int target_width, // int target_height, // + int target_max_area, // bool preserve_aspect_ratio, // int scale_to_multiple_of, // int* output_width, int* output_height) { CHECK(output_width); CHECK(output_height); + if (target_max_area > 0 && input_width * input_height > target_max_area) { + preserve_aspect_ratio = true; + target_height = static_cast(sqrt(static_cast(target_max_area) / + (static_cast(input_width) / + static_cast(input_height)))); + target_width = -1; // Resize width to preserve aspect ratio. 
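+    // Worked example (illustrative numbers): for a 1920x1080 input with
+    // target_max_area = 518400, the aspect ratio is 1920/1080 = 16/9, so
+    // target_height = sqrt(518400 / (16/9.0)) = 540, and the width derived
+    // below from the preserved aspect ratio is 960 (960 * 540 == 518400).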
+ } + if (preserve_aspect_ratio) { RET_CHECK(scale_to_multiple_of == 2) << "FindOutputDimensions always outputs width and height that are " @@ -164,5 +173,17 @@ absl::Status FindOutputDimensions(int input_width, // << "Unable to set output dimensions based on target dimensions."; } +absl::Status FindOutputDimensions(int input_width, // + int input_height, // + int target_width, // + int target_height, // + bool preserve_aspect_ratio, // + int scale_to_multiple_of, // + int* output_width, int* output_height) { + return FindOutputDimensions( + input_width, input_height, target_width, target_height, -1, + preserve_aspect_ratio, scale_to_multiple_of, output_width, output_height); +} + } // namespace scale_image } // namespace mediapipe diff --git a/mediapipe/calculators/image/scale_image_utils.h b/mediapipe/calculators/image/scale_image_utils.h index c2c0b8f7c..e7fccd8dc 100644 --- a/mediapipe/calculators/image/scale_image_utils.h +++ b/mediapipe/calculators/image/scale_image_utils.h @@ -34,15 +34,25 @@ absl::Status FindCropDimensions(int input_width, int input_height, // int* crop_width, int* crop_height, // int* col_start, int* row_start); -// Given an input width and height, a target width and height, whether to -// preserve the aspect ratio, and whether to round-down to the multiple of a -// given number nearest to the targets, determine the output width and height. -// If target_width or target_height is non-positive, then they will be set to -// the input_width and input_height respectively. If scale_to_multiple_of is -// less than 1, it will be treated like 1. The output_width and -// output_height will be reduced as necessary to preserve_aspect_ratio if the -// option is specified. If preserving the aspect ratio is desired, you must set -// scale_to_multiple_of to 2. +// Given an input width and height, a target width and height or max area, +// whether to preserve the aspect ratio, and whether to round-down to the +// multiple of a given number nearest to the targets, determine the output width +// and height. If target_width or target_height is non-positive, then they will +// be set to the input_width and input_height respectively. If target_area is +// non-positive, then it will be ignored. If scale_to_multiple_of is less than +// 1, it will be treated like 1. The output_width and output_height will be +// reduced as necessary to preserve_aspect_ratio if the option is specified. If +// preserving the aspect ratio is desired, you must set scale_to_multiple_of +// to 2. +absl::Status FindOutputDimensions(int input_width, int input_height, // + int target_width, + int target_height, // + int target_max_area, // + bool preserve_aspect_ratio, // + int scale_to_multiple_of, // + int* output_width, int* output_height); + +// Backwards compatible helper. absl::Status FindOutputDimensions(int input_width, int input_height, // int target_width, int target_height, // diff --git a/mediapipe/calculators/image/scale_image_utils_test.cc b/mediapipe/calculators/image/scale_image_utils_test.cc index 14a58e762..bda1fa4d6 100644 --- a/mediapipe/calculators/image/scale_image_utils_test.cc +++ b/mediapipe/calculators/image/scale_image_utils_test.cc @@ -79,49 +79,49 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsPreserveRatio) { int output_width; int output_height; // Not scale. 
- MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(200, output_width); EXPECT_EQ(100, output_height); // Not scale with odd input size. - MP_ASSERT_OK(FindOutputDimensions(201, 101, -1, -1, false, 1, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(201, 101, -1, -1, -1, false, 1, + &output_width, &output_height)); EXPECT_EQ(201, output_width); EXPECT_EQ(101, output_height); // Scale down by 1/2. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(50, output_height); // Scale up, doubling dimensions. - MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(400, output_width); EXPECT_EQ(200, output_height); // Fits a 2:1 image into a 150 x 150 box. Output dimensions are always // visible by 2. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 150, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 150, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(150, output_width); EXPECT_EQ(74, output_height); // Fits a 2:1 image into a 400 x 50 box. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 400, 50, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 400, 50, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(50, output_height); // Scale to multiple number with odd targe size. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(50, output_height); // Scale to multiple number with odd targe size. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(50, output_height); // Scale to odd size. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 151, 101, false, 1, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 151, 101, -1, false, 1, + &output_width, &output_height)); EXPECT_EQ(151, output_width); EXPECT_EQ(101, output_height); } @@ -131,18 +131,18 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsNoAspectRatio) { int output_width; int output_height; // Scale width only. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, false, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, -1, false, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(100, output_height); // Scale height only. - MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, false, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, -1, false, 2, + &output_width, &output_height)); EXPECT_EQ(200, output_width); EXPECT_EQ(200, output_height); // Scale both dimensions. 
- MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, false, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, -1, false, 2, + &output_width, &output_height)); EXPECT_EQ(150, output_width); EXPECT_EQ(200, output_height); } @@ -152,41 +152,78 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsDownScaleToMultipleOf) { int output_width; int output_height; // Set no targets, downscale to a multiple of 8. - MP_ASSERT_OK(FindOutputDimensions(100, 100, -1, -1, false, 8, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(100, 100, -1, -1, -1, false, 8, + &output_width, &output_height)); EXPECT_EQ(96, output_width); EXPECT_EQ(96, output_height); // Set width target, downscale to a multiple of 8. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, false, 8, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, -1, false, 8, + &output_width, &output_height)); EXPECT_EQ(96, output_width); EXPECT_EQ(96, output_height); // Set height target, downscale to a multiple of 8. - MP_ASSERT_OK(FindOutputDimensions(201, 101, -1, 201, false, 8, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(201, 101, -1, 201, -1, false, 8, + &output_width, &output_height)); EXPECT_EQ(200, output_width); EXPECT_EQ(200, output_height); // Set both targets, downscale to a multiple of 8. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, false, 8, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, -1, false, 8, + &output_width, &output_height)); EXPECT_EQ(144, output_width); EXPECT_EQ(200, output_height); // Doesn't throw error if keep aspect is true and downscale multiple is 2. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 400, 200, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 400, 200, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(400, output_width); EXPECT_EQ(200, output_height); // Throws error if keep aspect is true, but downscale multiple is not 2. - ASSERT_THAT(FindOutputDimensions(200, 100, 400, 200, true, 4, &output_width, - &output_height), + ASSERT_THAT(FindOutputDimensions(200, 100, 400, 200, -1, true, 4, + &output_width, &output_height), testing::Not(testing::status::IsOk())); // Downscaling to multiple ignored if multiple is less than 2. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 401, 201, false, 1, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 401, 201, -1, false, 1, + &output_width, &output_height)); EXPECT_EQ(401, output_width); EXPECT_EQ(201, output_height); } +// Tests scaling without keeping the aspect ratio fixed. +TEST(ScaleImageUtilsTest, FindOutputDimensionsMaxArea) { + int output_width; + int output_height; + // Smaller area. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, 9000, false, 2, + &output_width, &output_height)); + EXPECT_NEAR( + 200 / 100, + static_cast(output_width) / static_cast(output_height), + 0.1f); + EXPECT_LE(output_width * output_height, 9000); + // Close to original area. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, 19999, false, 2, + &output_width, &output_height)); + EXPECT_NEAR( + 200.0 / 100.0, + static_cast(output_width) / static_cast(output_height), + 0.1f); + EXPECT_LE(output_width * output_height, 19999); + // Don't scale with larger area. 
+ MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, 20001, false, 2, + &output_width, &output_height)); + EXPECT_EQ(200, output_width); + EXPECT_EQ(100, output_height); + // Don't scale with equal area. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, 20000, false, 2, + &output_width, &output_height)); + EXPECT_EQ(200, output_width); + EXPECT_EQ(100, output_height); + // Don't scale at all. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, -1, false, 2, + &output_width, &output_height)); + EXPECT_EQ(200, output_width); + EXPECT_EQ(100, output_height); +} + } // namespace } // namespace scale_image } // namespace mediapipe diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc new file mode 100644 index 000000000..db339b754 --- /dev/null +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc @@ -0,0 +1,429 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/vector.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +namespace { +constexpr char kCurrentMaskTag[] = "MASK"; +constexpr char kPreviousMaskTag[] = "MASK_PREVIOUS"; +constexpr char kOutputMaskTag[] = "MASK_SMOOTHED"; + +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; +} // namespace + +// A calculator for mixing two segmentation masks together, +// based on an uncertantity probability estimate. +// +// Inputs: +// MASK - Image containing the new/current mask. +// [ImageFormat::VEC32F1, or +// GpuBufferFormat::kBGRA32/kRGB24/kGrayHalf16/kGrayFloat32] +// MASK_PREVIOUS - Image containing previous mask. +// [Same format as MASK_CURRENT] +// * If input channels is >1, only the first channel (R) is used as the mask. +// +// Output: +// MASK_SMOOTHED - Blended mask. +// [Same format as MASK_CURRENT] +// * The resulting filtered mask will be stored in R channel, +// and duplicated in A if 4 channels. +// +// Options: +// combine_with_previous_ratio - Amount of previous to blend with current. 
+// +// Example: +// node { +// calculator: "SegmentationSmoothingCalculator" +// input_stream: "MASK:mask" +// input_stream: "MASK_PREVIOUS:mask_previous" +// output_stream: "MASK_SMOOTHED:mask_smoothed" +// options: { +// [mediapipe.SegmentationSmoothingCalculatorOptions.ext] { +// combine_with_previous_ratio: 0.9 +// } +// } +// } +// +class SegmentationSmoothingCalculator : public CalculatorBase { + public: + SegmentationSmoothingCalculator() = default; + + static absl::Status GetContract(CalculatorContract* cc); + + // From Calculator. + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); + + absl::Status GlSetup(CalculatorContext* cc); + void GlRender(CalculatorContext* cc); + + float combine_with_previous_ratio_; + + bool gpu_initialized_ = false; +#if !MEDIAPIPE_DISABLE_GPU + mediapipe::GlCalculatorHelper gpu_helper_; + GLuint program_ = 0; +#endif // !MEDIAPIPE_DISABLE_GPU +}; +REGISTER_CALCULATOR(SegmentationSmoothingCalculator); + +absl::Status SegmentationSmoothingCalculator::GetContract( + CalculatorContract* cc) { + CHECK_GE(cc->Inputs().NumEntries(), 1); + + cc->Inputs().Tag(kCurrentMaskTag).Set(); + cc->Inputs().Tag(kPreviousMaskTag).Set(); + cc->Outputs().Tag(kOutputMaskTag).Set(); + +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status SegmentationSmoothingCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + auto options = + cc->Options(); + combine_with_previous_ratio_ = options.combine_with_previous_ratio(); + +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status SegmentationSmoothingCalculator::Process(CalculatorContext* cc) { + if (cc->Inputs().Tag(kCurrentMaskTag).IsEmpty()) { + return absl::OkStatus(); + } + if (cc->Inputs().Tag(kPreviousMaskTag).IsEmpty()) { + // Pass through current image if previous is not available. + cc->Outputs() + .Tag(kOutputMaskTag) + .AddPacket(cc->Inputs().Tag(kCurrentMaskTag).Value()); + return absl::OkStatus(); + } + + // Run on GPU if incoming data is on GPU. + const bool use_gpu = cc->Inputs().Tag(kCurrentMaskTag).Get().UsesGpu(); + + if (use_gpu) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + if (!gpu_initialized_) { + MP_RETURN_IF_ERROR(GlSetup(cc)); + gpu_initialized_ = true; + } + MP_RETURN_IF_ERROR(RenderGpu(cc)); + return absl::OkStatus(); + })); +#else + return absl::InternalError("GPU processing is disabled."); +#endif // !MEDIAPIPE_DISABLE_GPU + } else { + MP_RETURN_IF_ERROR(RenderCpu(cc)); + } + + return absl::OkStatus(); +} + +absl::Status SegmentationSmoothingCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + gpu_helper_.RunInGlContext([this] { + if (program_) glDeleteProgram(program_); + program_ = 0; + }); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) { + // Setup source images. 
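+  // Both masks are expected to be single-channel float (ImageFormat::VEC32F1)
+  // with matching dimensions; the checks below enforce this before blending.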
+ const auto& current_frame = cc->Inputs().Tag(kCurrentMaskTag).Get(); + const cv::Mat current_mat = mediapipe::formats::MatView(¤t_frame); + RET_CHECK_EQ(current_mat.type(), CV_32FC1) + << "Only 1-channel float input image is supported."; + + const auto& previous_frame = cc->Inputs().Tag(kPreviousMaskTag).Get(); + const cv::Mat previous_mat = mediapipe::formats::MatView(&previous_frame); + RET_CHECK_EQ(previous_mat.type(), current_mat.type()) + << "Warning: mixing input format types: " << previous_mat.type() + << " != " << previous_mat.type(); + + RET_CHECK_EQ(current_mat.rows, previous_mat.rows); + RET_CHECK_EQ(current_mat.cols, previous_mat.cols); + + // Setup destination image. + auto output_frame = std::make_shared( + current_frame.image_format(), current_mat.cols, current_mat.rows); + cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get()); + output_mat.setTo(cv::Scalar(0)); + + // Blending function. + const auto blending_fn = [&](const float prev_mask_value, + const float new_mask_value) { + /* + * Assume p := new_mask_value + * H(p) := 1 + (p * log(p) + (1-p) * log(1-p)) / log(2) + * uncertainty alpha(p) = + * Clamp(1 - (1 - H(p)) * (1 - H(p)), 0, 1) [squaring the uncertainty] + * + * The following polynomial approximates uncertainty alpha as a function + * of (p + 0.5): + */ + const float c1 = 5.68842; + const float c2 = -0.748699; + const float c3 = -57.8051; + const float c4 = 291.309; + const float c5 = -624.717; + const float t = new_mask_value - 0.5f; + const float x = t * t; + + const float uncertainty = + 1.0f - + std::min(1.0f, x * (c1 + x * (c2 + x * (c3 + x * (c4 + x * c5))))); + + return new_mask_value + (prev_mask_value - new_mask_value) * + (uncertainty * combine_with_previous_ratio_); + }; + + // Write directly to the first channel of output. + for (int i = 0; i < output_mat.rows; ++i) { + float* out_ptr = output_mat.ptr(i); + const float* curr_ptr = current_mat.ptr(i); + const float* prev_ptr = previous_mat.ptr(i); + for (int j = 0; j < output_mat.cols; ++j) { + const float new_mask_value = curr_ptr[j]; + const float prev_mask_value = prev_ptr[j]; + out_ptr[j] = blending_fn(prev_mask_value, new_mask_value); + } + } + + cc->Outputs() + .Tag(kOutputMaskTag) + .AddPacket(MakePacket(output_frame).At(cc->InputTimestamp())); + + return absl::OkStatus(); +} + +absl::Status SegmentationSmoothingCalculator::RenderGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + // Setup source textures. + const auto& current_frame = cc->Inputs().Tag(kCurrentMaskTag).Get(); + RET_CHECK( + (current_frame.format() == mediapipe::GpuBufferFormat::kBGRA32 || + current_frame.format() == mediapipe::GpuBufferFormat::kGrayHalf16 || + current_frame.format() == mediapipe::GpuBufferFormat::kGrayFloat32 || + current_frame.format() == mediapipe::GpuBufferFormat::kRGB24)) + << "Only RGBA, RGB, or 1-channel Float input image supported."; + + auto current_texture = gpu_helper_.CreateSourceTexture(current_frame); + + const auto& previous_frame = cc->Inputs().Tag(kPreviousMaskTag).Get(); + if (previous_frame.format() != current_frame.format()) { + LOG(ERROR) << "Warning: mixing input format types. "; + } + auto previous_texture = gpu_helper_.CreateSourceTexture(previous_frame); + + // Setup destination texture. + const int width = current_frame.width(), height = current_frame.height(); + auto output_texture = gpu_helper_.CreateDestinationTexture( + width, height, current_frame.format()); + + // Process shader. 
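+  // Texture units 1 and 2 correspond to the 'current_mask' and
+  // 'previous_mask' sampler uniforms configured in GlSetup().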
+ { + gpu_helper_.BindFramebuffer(output_texture); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, current_texture.name()); + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, previous_texture.name()); + GlRender(cc); + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, 0); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); + } + glFlush(); + + // Send out image as GPU packet. + auto output_frame = output_texture.GetFrame(); + cc->Outputs() + .Tag(kOutputMaskTag) + .Add(output_frame.release(), cc->InputTimestamp()); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +void SegmentationSmoothingCalculator::GlRender(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // program + glUseProgram(program_); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); + +#endif // !MEDIAPIPE_DISABLE_GPU +} + +absl::Status SegmentationSmoothingCalculator::GlSetup(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + // Shader to blend in previous mask based on computed uncertainty probability. 
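+  // The math below mirrors blending_fn in RenderCpu() (same polynomial
+  // coefficients c1..c5), keeping CPU and GPU outputs consistent.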
+ const std::string frag_src = + absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble), + R"( + DEFAULT_PRECISION(mediump, float) + + #ifdef GL_ES + #define fragColor gl_FragColor + #else + out vec4 fragColor; + #endif // defined(GL_ES); + + in vec2 sample_coordinate; + uniform sampler2D current_mask; + uniform sampler2D previous_mask; + uniform float combine_with_previous_ratio; + + void main() { + vec4 current_pix = texture2D(current_mask, sample_coordinate); + vec4 previous_pix = texture2D(previous_mask, sample_coordinate); + float new_mask_value = current_pix.r; + float prev_mask_value = previous_pix.r; + + // Assume p := new_mask_value + // H(p) := 1 + (p * log(p) + (1-p) * log(1-p)) / log(2) + // uncertainty alpha(p) = + // Clamp(1 - (1 - H(p)) * (1 - H(p)), 0, 1) [squaring the uncertainty] + // + // The following polynomial approximates uncertainty alpha as a function + // of (p + 0.5): + const float c1 = 5.68842; + const float c2 = -0.748699; + const float c3 = -57.8051; + const float c4 = 291.309; + const float c5 = -624.717; + float t = new_mask_value - 0.5; + float x = t * t; + + float uncertainty = + 1.0 - min(1.0, x * (c1 + x * (c2 + x * (c3 + x * (c4 + x * c5))))); + + new_mask_value += + (prev_mask_value - new_mask_value) * (uncertainty * combine_with_previous_ratio); + + fragColor = vec4(new_mask_value, 0.0, 0.0, new_mask_value); + } + )"); + + // Create shader program and set parameters. + mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src.c_str(), + NUM_ATTRIBUTES, (const GLchar**)&attr_name[0], + attr_location, &program_); + RET_CHECK(program_) << "Problem initializing the program."; + glUseProgram(program_); + glUniform1i(glGetUniformLocation(program_, "current_mask"), 1); + glUniform1i(glGetUniformLocation(program_, "previous_mask"), 2); + glUniform1f(glGetUniformLocation(program_, "combine_with_previous_ratio"), + combine_with_previous_ratio_); + +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator.proto b/mediapipe/calculators/image/segmentation_smoothing_calculator.proto new file mode 100644 index 000000000..12b10ccd1 --- /dev/null +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator.proto @@ -0,0 +1,35 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message SegmentationSmoothingCalculatorOptions { + extend CalculatorOptions { + optional SegmentationSmoothingCalculatorOptions ext = 377425128; + } + + // How much to blend in previous mask, based on a probability estimate. + // Range: [0-1] + // 0 = Use only current frame (no blending). + // 1 = Blend in the previous mask based on uncertainty estimate. + // With ratio at 1, the uncertainty estimate is trusted completely. + // When uncertainty is high, the previous mask is given higher weight. 
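+  // Per pixel: new_mask += (previous_mask - new_mask) * (uncertainty * ratio).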
+ // Therefore, if both ratio and uncertainty are 1, only old mask is used. + // A pixel is 'uncertain' if its value is close to the middle (0.5 or 127). + optional float combine_with_previous_ratio = 1 [default = 0.0]; +} diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc new file mode 100644 index 000000000..100d7de8a --- /dev/null +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc @@ -0,0 +1,206 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +// 4x4 VEC32F1, center 2x2 block set at ~250 +const float mask_data[] = { + 0.00, 0.00, 0.00, 0.00, // + 0.00, 0.98, 0.98, 0.00, // + 0.00, 0.98, 0.98, 0.00, // + 0.00, 0.00, 0.00, 0.00, // +}; + +void RunGraph(Packet curr_packet, Packet prev_packet, bool use_gpu, float ratio, + cv::Mat* result) { + CalculatorGraphConfig graph_config; + if (use_gpu) { + graph_config = ParseTextProtoOrDie(absl::Substitute( + R"pb( + input_stream: "curr_mask" + input_stream: "prev_mask" + output_stream: "new_mask" + node { + calculator: "ImageCloneCalculator" + input_stream: "curr_mask" + output_stream: "curr_mask_gpu" + options: { + [mediapipe.ImageCloneCalculatorOptions.ext] { + output_on_gpu: true + } + } + } + node { + calculator: "ImageCloneCalculator" + input_stream: "prev_mask" + output_stream: "prev_mask_gpu" + options: { + [mediapipe.ImageCloneCalculatorOptions.ext] { + output_on_gpu: true + } + } + } + node { + calculator: "SegmentationSmoothingCalculator" + input_stream: "MASK:curr_mask_gpu" + input_stream: "MASK_PREVIOUS:prev_mask_gpu" + output_stream: "MASK_SMOOTHED:new_mask" + node_options { + [type.googleapis.com/ + mediapipe.SegmentationSmoothingCalculatorOptions]: { + combine_with_previous_ratio: $0 + } + } + } + )pb", + ratio)); + } else { + graph_config = ParseTextProtoOrDie(absl::Substitute( + R"pb( + input_stream: "curr_mask" + input_stream: "prev_mask" + output_stream: "new_mask" + node { + calculator: "SegmentationSmoothingCalculator" + input_stream: "MASK:curr_mask" + input_stream: "MASK_PREVIOUS:prev_mask" + output_stream: "MASK_SMOOTHED:new_mask" + node_options { + [type.googleapis.com/ + mediapipe.SegmentationSmoothingCalculatorOptions]: { + 
combine_with_previous_ratio: $0 + } + } + } + )pb", + ratio)); + } + std::vector output_packets; + tool::AddVectorSink("new_mask", &graph_config, &output_packets); + CalculatorGraph graph(graph_config); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK( + graph.AddPacketToInputStream("curr_mask", curr_packet.At(Timestamp(0)))); + MP_ASSERT_OK( + graph.AddPacketToInputStream("prev_mask", prev_packet.At(Timestamp(0)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, output_packets.size()); + + Image result_image = output_packets[0].Get(); + cv::Mat result_mat = formats::MatView(&result_image); + result_mat.copyTo(*result); + + // Fully close graph at end, otherwise calculator+Images are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("curr_mask")); + MP_ASSERT_OK(graph.CloseInputStream("prev_mask")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +void RunTest(bool use_gpu, float mix_ratio, cv::Mat& test_result) { + cv::Mat mask_mat(cv::Size(4, 4), CV_32FC1, const_cast(mask_data)); + cv::Mat curr_mat = mask_mat; + // 3x3 blur of 250 block produces all pixels '111'. + cv::Mat prev_mat; + cv::blur(mask_mat, prev_mat, cv::Size(3, 3)); + + Packet curr_packet = MakePacket(std::make_unique( + ImageFormat::VEC32F1, curr_mat.size().width, curr_mat.size().height)); + curr_mat.copyTo(formats::MatView(&(curr_packet.Get()))); + Packet prev_packet = MakePacket(std::make_unique( + ImageFormat::VEC32F1, prev_mat.size().width, prev_mat.size().height)); + prev_mat.copyTo(formats::MatView(&(prev_packet.Get()))); + + cv::Mat result; + RunGraph(curr_packet, prev_packet, use_gpu, mix_ratio, &result); + + ASSERT_EQ(curr_mat.rows, result.rows); + ASSERT_EQ(curr_mat.cols, result.cols); + ASSERT_EQ(curr_mat.type(), result.type()); + result.copyTo(test_result); + + if (mix_ratio == 1.0) { + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + float in = curr_mat.at(i, j); + float out = result.at(i, j); + // Since the input has high value (250), it has low uncertainty. + // So the output should have changed lower (towards prev), + // but not too much. + if (in > 0) EXPECT_NE(in, out); + EXPECT_NEAR(in, out, 3.0 / 255.0); + } + } + } else if (mix_ratio == 0.0) { + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + float in = curr_mat.at(i, j); + float out = result.at(i, j); + EXPECT_EQ(in, out); // Output should match current. + } + } + } else { + LOG(ERROR) << "invalid ratio"; + } +} + +TEST(SegmentationSmoothingCalculatorTest, TestSmoothing) { + bool use_gpu; + float mix_ratio; + + use_gpu = false; + mix_ratio = 0.0; + cv::Mat cpu_0; + RunTest(use_gpu, mix_ratio, cpu_0); + + use_gpu = false; + mix_ratio = 1.0; + cv::Mat cpu_1; + RunTest(use_gpu, mix_ratio, cpu_1); + + use_gpu = true; + mix_ratio = 1.0; + cv::Mat gpu_1; + RunTest(use_gpu, mix_ratio, gpu_1); + + // CPU & GPU should match. + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + float gpu = gpu_1.at(i, j); + float cpu = cpu_1.at(i, j); + EXPECT_EQ(cpu, gpu); + } + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/image/warp_affine_calculator.cc b/mediapipe/calculators/image/warp_affine_calculator.cc new file mode 100644 index 000000000..e3d017a35 --- /dev/null +++ b/mediapipe/calculators/image/warp_affine_calculator.cc @@ -0,0 +1,211 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/warp_affine_calculator.h" + +#include +#include +#include + +#include "mediapipe/calculators/image/affine_transformation.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/calculators/image/affine_transformation_runner_gl.h" +#endif // !MEDIAPIPE_DISABLE_GPU +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation_runner_opencv.h" +#include "mediapipe/calculators/image/warp_affine_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/ret_check.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +namespace { + +AffineTransformation::BorderMode GetBorderMode( + mediapipe::WarpAffineCalculatorOptions::BorderMode border_mode) { + switch (border_mode) { + case mediapipe::WarpAffineCalculatorOptions::BORDER_ZERO: + return AffineTransformation::BorderMode::kZero; + case mediapipe::WarpAffineCalculatorOptions::BORDER_UNSPECIFIED: + case mediapipe::WarpAffineCalculatorOptions::BORDER_REPLICATE: + return AffineTransformation::BorderMode::kReplicate; + } +} + +template +class WarpAffineRunnerHolder {}; + +template <> +class WarpAffineRunnerHolder { + public: + using RunnerType = AffineTransformation::Runner; + absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); } + absl::StatusOr GetRunner() { + if (!runner_) { + ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner()); + } + return runner_.get(); + } + + private: + std::unique_ptr runner_; +}; + +#if !MEDIAPIPE_DISABLE_GPU +template <> +class WarpAffineRunnerHolder { + public: + using RunnerType = + AffineTransformation::Runner>; + absl::Status Open(CalculatorContext* cc) { + gpu_origin_ = + cc->Options().gpu_origin(); + gl_helper_ = std::make_shared(); + return gl_helper_->Open(cc); + } + absl::StatusOr GetRunner() { + if (!runner_) { + ASSIGN_OR_RETURN( + runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_)); + } + return runner_.get(); + } + + private: + mediapipe::GpuOrigin::Mode gpu_origin_; + std::shared_ptr gl_helper_; + std::unique_ptr runner_; +}; +#endif // !MEDIAPIPE_DISABLE_GPU + +template <> +class WarpAffineRunnerHolder { + public: + absl::Status Open(CalculatorContext* cc) { return runner_.Open(cc); } + absl::StatusOr< + AffineTransformation::Runner*> + GetRunner() { + return &runner_; + } + + private: + class Runner : public AffineTransformation::Runner { + public: + absl::Status Open(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(cpu_holder_.Open(cc)); +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_holder_.Open(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + return absl::OkStatus(); + } + absl::StatusOr Run( + const mediapipe::Image& input, const std::array& matrix, + const AffineTransformation::Size& size, + AffineTransformation::BorderMode 
border_mode) override { + if (input.UsesGpu()) { +#if !MEDIAPIPE_DISABLE_GPU + ASSIGN_OR_RETURN(auto* runner, gpu_holder_.GetRunner()); + ASSIGN_OR_RETURN(auto result, runner->Run(input.GetGpuBuffer(), matrix, + size, border_mode)); + return mediapipe::Image(*result); +#else + return absl::UnavailableError("GPU support is disabled"); +#endif // !MEDIAPIPE_DISABLE_GPU + } + ASSIGN_OR_RETURN(auto* runner, cpu_holder_.GetRunner()); + const auto& frame_ptr = input.GetImageFrameSharedPtr(); + // Wrap image into image frame. + const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(), + frame_ptr->Height(), frame_ptr->WidthStep(), + const_cast(frame_ptr->PixelData()), + [](uint8* data) {}); + ASSIGN_OR_RETURN(auto result, + runner->Run(image_frame, matrix, size, border_mode)); + return mediapipe::Image(std::make_shared(std::move(result))); + } + + private: + WarpAffineRunnerHolder cpu_holder_; +#if !MEDIAPIPE_DISABLE_GPU + WarpAffineRunnerHolder gpu_holder_; +#endif // !MEDIAPIPE_DISABLE_GPU + }; + + Runner runner_; +}; + +template +class WarpAffineCalculatorImpl : public mediapipe::api2::NodeImpl { + public: +#if !MEDIAPIPE_DISABLE_GPU + static absl::Status UpdateContract(CalculatorContract* cc) { + if constexpr (std::is_same_v || + std::is_same_v) { + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + } + return absl::OkStatus(); + } +#endif // !MEDIAPIPE_DISABLE_GPU + + absl::Status Open(CalculatorContext* cc) override { return holder_.Open(cc); } + + absl::Status Process(CalculatorContext* cc) override { + if (InterfaceT::kInImage(cc).IsEmpty() || + InterfaceT::kMatrix(cc).IsEmpty() || + InterfaceT::kOutputSize(cc).IsEmpty()) { + return absl::OkStatus(); + } + const std::array& transform = *InterfaceT::kMatrix(cc); + auto [out_width, out_height] = *InterfaceT::kOutputSize(cc); + AffineTransformation::Size output_size; + output_size.width = out_width; + output_size.height = out_height; + ASSIGN_OR_RETURN(auto* runner, holder_.GetRunner()); + ASSIGN_OR_RETURN( + auto result, + runner->Run( + *InterfaceT::kInImage(cc), transform, output_size, + GetBorderMode(cc->Options() + .border_mode()))); + InterfaceT::kOutImage(cc).Send(std::move(result)); + + return absl::OkStatus(); + } + + private: + WarpAffineRunnerHolder + holder_; +}; + +} // namespace + +MEDIAPIPE_NODE_IMPLEMENTATION( + WarpAffineCalculatorImpl); +#if !MEDIAPIPE_DISABLE_GPU +MEDIAPIPE_NODE_IMPLEMENTATION( + WarpAffineCalculatorImpl); +#endif // !MEDIAPIPE_DISABLE_GPU +MEDIAPIPE_NODE_IMPLEMENTATION(WarpAffineCalculatorImpl); + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/warp_affine_calculator.h b/mediapipe/calculators/image/warp_affine_calculator.h new file mode 100644 index 000000000..4a1b07030 --- /dev/null +++ b/mediapipe/calculators/image/warp_affine_calculator.h @@ -0,0 +1,94 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_CALCULATORS_IMAGE_WARP_AFFINE_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_IMAGE_WARP_AFFINE_CALCULATOR_H_ + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_buffer.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +// Runs affine transformation. +// +// Input: +// IMAGE - Image/ImageFrame/GpuBuffer +// +// MATRIX - std::array +// Used as following: +// output(x, y) = input(matrix[0] * x + matrix[1] * y + matrix[3], +// matrix[4] * x + matrix[5] * y + matrix[7]) +// where x and y ranges are defined by @OUTPUT_SIZE. +// +// OUTPUT_SIZE - std::pair +// Size of the output image. +// +// Output: +// IMAGE - Image/ImageFrame/GpuBuffer +// +// Note: +// - Output image type and format are the same as the input one. +// +// Usage example: +// node { +// calculator: "WarpAffineCalculator(Cpu|Gpu)" +// input_stream: "IMAGE:image" +// input_stream: "MATRIX:matrix" +// input_stream: "OUTPUT_SIZE:size" +// output_stream: "IMAGE:transformed_image" +// options: { +// [mediapipe.WarpAffineCalculatorOptions.ext] { +// border_mode: BORDER_ZERO +// } +// } +// } +template +class WarpAffineCalculatorIntf : public mediapipe::api2::NodeIntf { + public: + static constexpr mediapipe::api2::Input kInImage{"IMAGE"}; + static constexpr mediapipe::api2::Input> kMatrix{ + "MATRIX"}; + static constexpr mediapipe::api2::Input> kOutputSize{ + "OUTPUT_SIZE"}; + static constexpr mediapipe::api2::Output kOutImage{"IMAGE"}; +}; + +class WarpAffineCalculatorCpu : public WarpAffineCalculatorIntf { + public: + MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculatorCpu, kInImage, kMatrix, + kOutputSize, kOutImage); +}; +#if !MEDIAPIPE_DISABLE_GPU +class WarpAffineCalculatorGpu + : public WarpAffineCalculatorIntf { + public: + MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculatorGpu, kInImage, kMatrix, + kOutputSize, kOutImage); +}; +#endif // !MEDIAPIPE_DISABLE_GPU +class WarpAffineCalculator : public WarpAffineCalculatorIntf { + public: + MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculator, kInImage, kMatrix, kOutputSize, + kOutImage); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_IMAGE_WARP_AFFINE_CALCULATOR_H_ diff --git a/mediapipe/calculators/image/warp_affine_calculator.proto b/mediapipe/calculators/image/warp_affine_calculator.proto new file mode 100644 index 000000000..20e6c1b07 --- /dev/null +++ b/mediapipe/calculators/image/warp_affine_calculator.proto @@ -0,0 +1,46 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/gpu/gpu_origin.proto"; + +message WarpAffineCalculatorOptions { + extend CalculatorOptions { + optional WarpAffineCalculatorOptions ext = 373693895; + } + + // Pixel extrapolation methods. See @border_mode. 
+ enum BorderMode { + BORDER_UNSPECIFIED = 0; + BORDER_ZERO = 1; + BORDER_REPLICATE = 2; + } + + // Pixel extrapolation method. + // When converting image to tensor it may happen that tensor needs to read + // pixels outside image boundaries. Border mode helps to specify how such + // pixels will be calculated. + // + // BORDER_REPLICATE is used by default. + optional BorderMode border_mode = 1; + + // For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs + // to be flipped vertically as tensors are expected to start at top. + // (DEFAULT or unset interpreted as CONVENTIONAL.) + optional GpuOrigin.Mode gpu_origin = 2; +} diff --git a/mediapipe/calculators/image/warp_affine_calculator_test.cc b/mediapipe/calculators/image/warp_affine_calculator_test.cc new file mode 100644 index 000000000..959912cc9 --- /dev/null +++ b/mediapipe/calculators/image/warp_affine_calculator_test.cc @@ -0,0 +1,615 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +cv::Mat GetRgb(absl::string_view path) { + cv::Mat bgr = cv::imread(file::JoinPath("./", path)); + cv::Mat rgb(bgr.rows, bgr.cols, CV_8UC3); + int from_to[] = {0, 2, 1, 1, 2, 0}; + cv::mixChannels(&bgr, 1, &rgb, 1, from_to, 3); + return rgb; +} + +cv::Mat GetRgba(absl::string_view path) { + cv::Mat bgr = cv::imread(file::JoinPath("./", path)); + cv::Mat rgba(bgr.rows, bgr.cols, CV_8UC4, cv::Scalar(0, 0, 0, 0)); + int from_to[] = {0, 2, 1, 1, 2, 0}; + cv::mixChannels(&bgr, 1, &bgr, 1, from_to, 3); + return bgr; +} + +// Test template. +// No processing/assertions should be done after the function is invoked. 
+void RunTest(const std::string& graph_text, const std::string& tag, + const cv::Mat& input, cv::Mat expected_result, + float similarity_threshold, std::array matrix, + int out_width, int out_height, + absl::optional border_mode) { + std::string border_mode_str; + if (border_mode) { + switch (*border_mode) { + case AffineTransformation::BorderMode::kReplicate: + border_mode_str = "border_mode: BORDER_REPLICATE"; + break; + case AffineTransformation::BorderMode::kZero: + border_mode_str = "border_mode: BORDER_ZERO"; + break; + } + } + auto graph_config = mediapipe::ParseTextProtoOrDie( + absl::Substitute(graph_text, /*$0=*/border_mode_str)); + + std::vector output_packets; + tool::AddVectorSink("output_image", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + ImageFrame input_image( + input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB, + input.cols, input.rows, input.step, input.data, [](uint8*) {}); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_image", + MakePacket(std::move(input_image)).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "matrix", + MakePacket>(std::move(matrix)).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "output_size", MakePacket>( + std::pair(out_width, out_height)) + .At(Timestamp(0)))); + + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_THAT(output_packets, testing::SizeIs(1)); + + // Get and process results. + const ImageFrame& out_frame = output_packets[0].Get(); + cv::Mat result = formats::MatView(&out_frame); + double similarity = + 1.0 - cv::norm(result, expected_result, cv::NORM_RELATIVE | cv::NORM_L2); + EXPECT_GE(similarity, similarity_threshold); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("input_image")); + MP_ASSERT_OK(graph.CloseInputStream("matrix")); + MP_ASSERT_OK(graph.CloseInputStream("output_size")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +enum class InputType { kImageFrame, kImage }; + +// Similarity is checked against OpenCV results always, and due to differences +// on how OpenCV and GL treats pixels there are two thresholds. +// TODO: update to have just one threshold when OpenCV +// implementation is updated. 
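The tests in this file score an output as one minus the relative L2 difference between the produced and the expected image, so 1.0 means an exact match and the two per-backend thresholds mentioned above bound the allowed drift. A standalone sketch of that check, assuming two same-sized cv::Mat inputs:

#include "mediapipe/framework/port/opencv_core_inc.h"

// Returns 1.0 for identical images and decreases as the relative L2 distance
// between `result` and `expected` grows; the tests compare this against a
// threshold such as 0.99 on CPU or a looser value on GPU.
double Similarity(const cv::Mat& result, const cv::Mat& expected) {
  return 1.0 - cv::norm(result, expected, cv::NORM_RELATIVE | cv::NORM_L2);
}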
+struct SimilarityConfig { + double threshold_on_cpu; + double threshold_on_gpu; +}; + +void RunTest(cv::Mat input, cv::Mat expected_result, + const SimilarityConfig& similarity, std::array matrix, + int out_width, int out_height, + absl::optional border_mode) { + RunTest(R"( + input_stream: "input_image" + input_stream: "output_size" + input_stream: "matrix" + node { + calculator: "WarpAffineCalculatorCpu" + input_stream: "IMAGE:input_image" + input_stream: "MATRIX:matrix" + input_stream: "OUTPUT_SIZE:output_size" + output_stream: "IMAGE:output_image" + options { + [mediapipe.WarpAffineCalculatorOptions.ext] { + $0 # border mode + } + } + } + )", + "cpu", input, expected_result, similarity.threshold_on_cpu, matrix, + out_width, out_height, border_mode); + + RunTest(R"( + input_stream: "input_image" + input_stream: "output_size" + input_stream: "matrix" + node { + calculator: "ToImageCalculator" + input_stream: "IMAGE_CPU:input_image" + output_stream: "IMAGE:input_image_unified" + } + node { + calculator: "WarpAffineCalculator" + input_stream: "IMAGE:input_image_unified" + input_stream: "MATRIX:matrix" + input_stream: "OUTPUT_SIZE:output_size" + output_stream: "IMAGE:output_image_unified" + options { + [mediapipe.WarpAffineCalculatorOptions.ext] { + $0 # border mode + } + } + } + node { + calculator: "FromImageCalculator" + input_stream: "IMAGE:output_image_unified" + output_stream: "IMAGE_CPU:output_image" + } + )", + "cpu_image", input, expected_result, similarity.threshold_on_cpu, + matrix, out_width, out_height, border_mode); + + RunTest(R"( + input_stream: "input_image" + input_stream: "output_size" + input_stream: "matrix" + node { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "input_image" + output_stream: "input_image_gpu" + } + node { + calculator: "WarpAffineCalculatorGpu" + input_stream: "IMAGE:input_image_gpu" + input_stream: "MATRIX:matrix" + input_stream: "OUTPUT_SIZE:output_size" + output_stream: "IMAGE:output_image_gpu" + options { + [mediapipe.WarpAffineCalculatorOptions.ext] { + $0 # border mode + gpu_origin: TOP_LEFT + } + } + } + node { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "output_image_gpu" + output_stream: "output_image" + } + )", + "gpu", input, expected_result, similarity.threshold_on_gpu, matrix, + out_width, out_height, border_mode); + + RunTest(R"( + input_stream: "input_image" + input_stream: "output_size" + input_stream: "matrix" + node { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "input_image" + output_stream: "input_image_gpu" + } + node { + calculator: "ToImageCalculator" + input_stream: "IMAGE_GPU:input_image_gpu" + output_stream: "IMAGE:input_image_unified" + } + node { + calculator: "WarpAffineCalculator" + input_stream: "IMAGE:input_image_unified" + input_stream: "MATRIX:matrix" + input_stream: "OUTPUT_SIZE:output_size" + output_stream: "IMAGE:output_image_unified" + options { + [mediapipe.WarpAffineCalculatorOptions.ext] { + $0 # border mode + gpu_origin: TOP_LEFT + } + } + } + node { + calculator: "FromImageCalculator" + input_stream: "IMAGE:output_image_unified" + output_stream: "IMAGE_GPU:output_image_gpu" + } + node { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "output_image_gpu" + output_stream: "output_image" + } + )", + "gpu_image", input, expected_result, similarity.threshold_on_gpu, + matrix, out_width, out_height, border_mode); +} + +std::array GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi, + bool keep_aspect_ratio, int out_width, + int out_height) { 
+ std::array transform_mat; + mediapipe::RotatedRect roi_absolute = + mediapipe::GetRoi(input.cols, input.rows, roi); + mediapipe::PadRoi(out_width, out_height, keep_aspect_ratio, &roi_absolute) + .IgnoreError(); + mediapipe::GetRotatedSubRectToRectTransformMatrix( + roi_absolute, input.cols, input.rows, + /*flip_horizontaly=*/false, &transform_mat); + return transform_mat; +} + +TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = true; + std::optional border_mode = {}; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_border_zero.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * 90.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_with_rotation.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * 90.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 
0.99, .threshold_on_gpu = 0.75}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * -45.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * -45.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_with_rotation_border_zero.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + 
roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_border_zero.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + auto input = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = {}; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + auto input = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation_border_zero.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, NoOp) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.0f); + roi.set_height(1.0f); + roi.set_rotation(0); + auto input = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgba( + "/mediapipe/calculators/" + 
"tensor/testdata/image_to_tensor/noop_except_range.png"); + int out_width = 64; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, NoOpBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.0f); + roi.set_height(1.0f); + roi.set_rotation(0); + auto input = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/noop_except_range.png"); + int out_width = 64; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 2234787c9..72c2f5181 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -26,6 +26,11 @@ licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +exports_files( + glob(["testdata/image_to_tensor/*"]), + visibility = ["//mediapipe/calculators/image:__subpackages__"], +) + selects.config_setting_group( name = "compute_shader_unavailable", match_any = [ @@ -109,6 +114,8 @@ cc_library( "//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:gpu_buffer", "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/util/tflite:config", + "@com_google_absl//absl/memory", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", @@ -349,6 +356,57 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "landmarks_to_tensor_calculator_proto", + srcs = ["landmarks_to_tensor_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "landmarks_to_tensor_calculator", + srcs = ["landmarks_to_tensor_calculator.cc"], + hdrs = ["landmarks_to_tensor_calculator.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":landmarks_to_tensor_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +cc_test( + name = "landmarks_to_tensor_calculator_test", + srcs = ["landmarks_to_tensor_calculator_test.cc"], + deps = [ + ":landmarks_to_tensor_calculator", + ":landmarks_to_tensor_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:landmark_cc_proto", + 
"//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + mediapipe_proto_library( name = "tensors_to_floats_calculator_proto", srcs = ["tensors_to_floats_calculator.proto"], @@ -478,7 +536,6 @@ cc_library( deps = [ ":image_to_tensor_calculator_cc_proto", ":image_to_tensor_converter", - ":image_to_tensor_converter_opencv", ":image_to_tensor_utils", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image", @@ -494,6 +551,9 @@ cc_library( ] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [":image_to_tensor_calculator_gpu_deps"], + }) + select({ + "//mediapipe/framework/port:disable_opencv": [], + "//conditions:default": [":image_to_tensor_converter_opencv"], }), alwayslink = 1, ) @@ -753,3 +813,76 @@ cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +# Copied from /mediapipe/calculators/tflite/BUILD +selects.config_setting_group( + name = "gpu_inference_disabled", + match_any = [ + "//mediapipe/gpu:disable_gpu", + ], +) + +mediapipe_proto_library( + name = "tensors_to_segmentation_calculator_proto", + srcs = ["tensors_to_segmentation_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/gpu:gpu_origin_proto", + ], +) + +cc_library( + name = "tensors_to_segmentation_calculator", + srcs = ["tensors_to_segmentation_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":tensors_to_segmentation_calculator_cc_proto", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:port", + "//mediapipe/util:resource_util", + "@org_tensorflow//tensorflow/lite:framework", + "//mediapipe/gpu:gpu_origin_cc_proto", + "//mediapipe/framework/port:statusor", + ] + selects.with_or({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + ], + }) + selects.with_or({ + ":gpu_inference_disabled": [], + "//mediapipe:ios": [ + "//mediapipe/gpu:MPPMetalUtil", + "//mediapipe/gpu:MPPMetalHelper", + ], + "//conditions:default": [ + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util", + ], + }), + alwayslink = 1, +) diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index f681ab661..b579f0474 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ 
b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -18,7 +18,6 @@ #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter.h" -#include "mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" @@ -33,6 +32,10 @@ #include "mediapipe/framework/port/statusor.h" #include "mediapipe/gpu/gpu_origin.pb.h" +#if !MEDIAPIPE_DISABLE_OPENCV +#include "mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h" +#endif + #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" @@ -84,9 +87,9 @@ using GpuBuffer = mediapipe::GpuBuffer; // TENSORS - std::vector // Vector containing a single Tensor populated with an extrated RGB image. // MATRIX - std::array @Optional -// An std::array representing a 4x4 row-major-order matrix which -// can be used to map a point on the output tensor to a point on the input -// image. +// An std::array representing a 4x4 row-major-order matrix that +// maps a point on the input image to a point on the output tensor, and +// can be used to reverse the mapping by inverting the matrix. // LETTERBOX_PADDING - std::array @Optional // An std::array representing the letterbox padding from the 4 // sides ([left, top, right, bottom]) of the output image, normalized to @@ -301,8 +304,13 @@ class ImageToTensorCalculator : public Node { } } else { if (!cpu_converter_) { +#if !MEDIAPIPE_DISABLE_OPENCV ASSIGN_OR_RETURN(cpu_converter_, CreateOpenCvConverter(cc, GetBorderMode())); +#else + LOG(FATAL) << "Cannot create image to tensor opencv converter since " + "MEDIAPIPE_DISABLE_OPENCV is defined."; +#endif // !MEDIAPIPE_DISABLE_OPENCV } } return absl::OkStatus(); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc index 1c27f282a..d01916f3e 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -312,7 +312,7 @@ class GlProcessor : public ImageToTensorConverter { return absl::OkStatus(); })); - return std::move(tensor); + return tensor; } ~GlProcessor() override { @@ -338,8 +338,7 @@ CreateImageToGlBufferTensorConverter(CalculatorContext* cc, auto result = absl::make_unique(); MP_RETURN_IF_ERROR(result->Init(cc, input_starts_at_bottom, border_mode)); - // Simply "return std::move(result)" failed to build on macOS with bazel. - return std::unique_ptr(std::move(result)); + return result; } } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 26c31eaf5..eb9681521 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -334,9 +334,7 @@ CreateImageToGlTextureTensorConverter(CalculatorContext* cc, BorderMode border_mode) { auto result = absl::make_unique(); MP_RETURN_IF_ERROR(result->Init(cc, input_starts_at_bottom, border_mode)); - - // Simply "return std::move(result)" failed to build on macOS with bazel. 
- return std::unique_ptr(std::move(result)); + return result; } } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc index 1f86e1ced..9714faa51 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -383,7 +383,7 @@ class MetalProcessor : public ImageToTensorConverter { tflite::gpu::HW(output_dims.height, output_dims.width), command_buffer, buffer_view.buffer())); [command_buffer commit]; - return std::move(tensor); + return tensor; } } @@ -399,8 +399,7 @@ absl::StatusOr> CreateMetalConverter( auto result = absl::make_unique(); MP_RETURN_IF_ERROR(result->Init(cc, border_mode)); - // Simply "return std::move(result)" failed to build on macOS with bazel. - return std::unique_ptr(std::move(result)); + return result; } } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index 04a4bbd97..22131a7e7 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -103,7 +103,7 @@ class OpenCvProcessor : public ImageToTensorConverter { GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, range_min, range_max)); transformed.convertTo(dst, CV_32FC3, transform.scale, transform.offset); - return std::move(tensor); + return tensor; } private: @@ -114,10 +114,7 @@ class OpenCvProcessor : public ImageToTensorConverter { absl::StatusOr> CreateOpenCvConverter( CalculatorContext* cc, BorderMode border_mode) { - // Simply "return absl::make_unique()" failed to build on - // macOS with bazel. 
- return std::unique_ptr( - absl::make_unique(border_mode)); + return absl::make_unique(border_mode); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index 11256a338..46e0f928c 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -33,7 +33,7 @@ class InferenceCalculatorSelectorImpl absl::StatusOr GetConfig( const CalculatorGraphConfig::Node& subgraph_node) { const auto& options = - Subgraph::GetOptions<::mediapipe::InferenceCalculatorOptions>( + Subgraph::GetOptions( subgraph_node); std::vector impls; const bool should_use_gpu = diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h index 9fe06181c..1c54bc46e 100644 --- a/mediapipe/calculators/tensor/inference_calculator.h +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -99,8 +99,11 @@ class InferenceCalculator : public NodeIntf { kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; static constexpr SideInput::Optional kSideInModel{"MODEL"}; static constexpr Output> kOutTensors{"TENSORS"}; + static constexpr SideInput< + mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{ + "DELEGATE"}; MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel, - kOutTensors); + kOutTensors, kDelegate); protected: using TfLiteDelegatePtr = diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index e0b538a91..59fd6a984 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -18,6 +18,9 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option java_package = "com.google.mediapipe.calculator.proto"; +option java_outer_classname = "InferenceCalculatorProto"; + // Full Example: // // node { @@ -67,9 +70,32 @@ message InferenceCalculatorOptions { // Only available for OpenCL delegate on Android. // Kernel caching will only be enabled if this path is set. optional string cached_kernel_path = 2; + + // Encapsulated compilation/runtime tradeoffs. + enum InferenceUsage { + UNSPECIFIED = 0; + + // InferenceRunner will be used only once. Therefore, it is important to + // minimize bootstrap time as well. + FAST_SINGLE_ANSWER = 1; + + // Prefer maximizing the throughput. Same inference runner will be used + // repeatedly on different inputs. + SUSTAINED_SPEED = 2; + } + optional InferenceUsage usage = 5 [default = SUSTAINED_SPEED]; } + // Android only. - message Nnapi {} + message Nnapi { + // Directory to store compilation cache. If unspecified, NNAPI will not + // try caching the compilation. + optional string cache_dir = 1; + // Unique token identifying the model. It is the caller's responsibility + // to ensure there is no clash of the tokens. If unspecified, NNAPI will + // not try caching the compilation. + optional string model_token = 2; + } message Xnnpack { // Number of threads for XNNPACK delegate. (By default, calculator tries // to choose optimal number of threads depending on the device.) 
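The new optional DELEGATE side input declared above accepts an InferenceCalculatorOptions::Delegate proto, so delegate selection can be decided at graph-setup time instead of being baked into node options. A hedged sketch of building such a packet in client code; the cache path, model token, and function name are placeholders, not values from this change:

#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/packet.h"

// Requests the NNAPI delegate with a compilation cache; feed the returned
// packet to the graph as the "DELEGATE" input side packet.
mediapipe::Packet MakeNnapiDelegateSidePacket() {
  mediapipe::InferenceCalculatorOptions::Delegate delegate;
  delegate.mutable_nnapi()->set_cache_dir("/data/local/tmp/nnapi_cache");  // placeholder path
  delegate.mutable_nnapi()->set_model_token("my_model_v1");                // placeholder token
  return mediapipe::MakePacket<mediapipe::InferenceCalculatorOptions::Delegate>(
      delegate);
}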
diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index e93ad4a3a..7d695ad9b 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -35,20 +35,30 @@ namespace api2 { namespace { -// Returns number of threads to configure XNNPACK delegate with. -// (Equal to user provided value if specified. Otherwise, it returns number of -// high cores (hard-coded to 1 for Emscripten without Threads extension)) -int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) { - static constexpr int kDefaultNumThreads = -1; - if (opts.has_delegate() && opts.delegate().has_xnnpack() && - opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) { - return opts.delegate().xnnpack().num_threads(); - } -#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) - return InferHigherCoreIds().size(); +int GetXnnpackDefaultNumThreads() { +#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \ + defined(__EMSCRIPTEN_PTHREADS__) + constexpr int kMinNumThreadsByDefault = 1; + constexpr int kMaxNumThreadsByDefault = 4; + return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault, + kMaxNumThreadsByDefault); #else return 1; -#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ +#endif // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__ +} + +// Returns number of threads to configure XNNPACK delegate with. +// Returns user provided value if specified. Otherwise, tries to choose optimal +// number of threads depending on the device. +int GetXnnpackNumThreads( + const bool opts_has_delegate, + const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) { + static constexpr int kDefaultNumThreads = -1; + if (opts_has_delegate && opts_delegate.has_xnnpack() && + opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) { + return opts_delegate.xnnpack().num_threads(); + } + return GetXnnpackDefaultNumThreads(); } } // namespace @@ -65,6 +75,7 @@ class InferenceCalculatorCpuImpl private: absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. Packet model_packet_; @@ -83,8 +94,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract( absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(LoadModel(cc)); - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - return absl::OkStatus(); + return LoadDelegateAndAllocateTensors(cc); } absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { @@ -148,34 +158,61 @@ absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) { cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. 
- CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { const auto& calculator_opts = cc->Options(); - if (calculator_opts.has_delegate() && - calculator_opts.delegate().has_tflite()) { + auto opts_delegate = calculator_opts.delegate(); + if (!kDelegate(cc).IsEmpty()) { + mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate = + kDelegate(cc).Get(); + CHECK(input_side_packet_delegate.has_tflite() || + input_side_packet_delegate.has_xnnpack() || + input_side_packet_delegate.has_nnapi() || + input_side_packet_delegate.delegate_case() == + mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET) + << "inference_calculator_cpu only supports delegate input side packet " + << "for TFLite, XNNPack and Nnapi"; + opts_delegate.MergeFrom(input_side_packet_delegate); + } + const bool opts_has_delegate = + calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty(); + if (opts_has_delegate && opts_delegate.has_tflite()) { // Default tflite inference requeqsted - no need to modify graph. return absl::OkStatus(); } #if defined(MEDIAPIPE_ANDROID) - const bool nnapi_requested = calculator_opts.has_delegate() - ? calculator_opts.delegate().has_nnapi() - : calculator_opts.use_nnapi(); + const bool nnapi_requested = opts_has_delegate ? opts_delegate.has_nnapi() + : calculator_opts.use_nnapi(); if (nnapi_requested) { // Attempt to use NNAPI. // If not supported, the default CPU delegate will be created and used. interpreter_->SetAllowFp16PrecisionForFp32(1); - delegate_ = TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { - // No need to free according to tflite::NnApiDelegate() documentation. - }); + tflite::StatefulNnApiDelegate::Options options; + const auto& nnapi = opts_delegate.nnapi(); + // Set up cache_dir and model_token for NNAPI compilation cache. + options.cache_dir = + nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr; + options.model_token = + nnapi.has_model_token() ? 
nnapi.model_token().c_str() : nullptr; + delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options), + [](TfLiteDelegate*) {}); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); return absl::OkStatus(); @@ -185,13 +222,13 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { #if defined(__EMSCRIPTEN__) const bool use_xnnpack = true; #else - const bool use_xnnpack = calculator_opts.has_delegate() && - calculator_opts.delegate().has_xnnpack(); + const bool use_xnnpack = opts_has_delegate && opts_delegate.has_xnnpack(); #endif // defined(__EMSCRIPTEN__) if (use_xnnpack) { TfLiteXNNPackDelegateOptions xnnpack_opts{}; - xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts); + xnnpack_opts.num_threads = + GetXnnpackNumThreads(opts_has_delegate, opts_delegate); delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), &TfLiteXNNPackDelegateDelete); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index d7c0e6138..444888403 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -18,6 +18,7 @@ #include #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/util/tflite/config.h" @@ -52,6 +53,7 @@ class InferenceCalculatorGlImpl absl::Status WriteKernelsToFile(); absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); absl::Status InitTFLiteGPURunner(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. 
@@ -65,6 +67,8 @@ class InferenceCalculatorGlImpl bool allow_precision_loss_ = false; mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api tflite_gpu_runner_api_; + mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage + tflite_gpu_runner_usage_; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE #if MEDIAPIPE_TFLITE_GPU_SUPPORTED @@ -91,18 +95,30 @@ absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); - use_advanced_gpu_api_ = options.has_delegate() && - options.delegate().has_gpu() && - options.delegate().gpu().use_advanced_gpu_api(); - allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); - tflite_gpu_runner_api_ = options.delegate().gpu().api(); - use_kernel_caching_ = use_advanced_gpu_api_ && - options.delegate().gpu().has_cached_kernel_path(); + mediapipe::InferenceCalculatorOptions::Delegate delegate = options.delegate(); + if (!kDelegate(cc).IsEmpty()) { + mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate = + kDelegate(cc).Get(); + CHECK(input_side_packet_delegate.has_gpu() || + input_side_packet_delegate.delegate_case() == + mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET) + << "inference_calculator_gl only supports delegate input side packet " + << "for Gpu"; + delegate.MergeFrom(input_side_packet_delegate); + } + const bool has_delegate = options.has_delegate() || !kDelegate(cc).IsEmpty(); + use_advanced_gpu_api_ = has_delegate && delegate.has_gpu() && + delegate.gpu().use_advanced_gpu_api(); + allow_precision_loss_ = delegate.gpu().allow_precision_loss(); + tflite_gpu_runner_api_ = delegate.gpu().api(); + tflite_gpu_runner_usage_ = delegate.gpu().usage(); + use_kernel_caching_ = + use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path(); use_gpu_delegate_ = !use_advanced_gpu_api_; if (use_kernel_caching_) { #ifdef MEDIAPIPE_ANDROID - cached_kernel_filename_ = options.delegate().gpu().cached_kernel_path() + + cached_kernel_filename_ = delegate.gpu().cached_kernel_path() + mediapipe::File::Basename(options.model_path()) + ".ker"; #endif // MEDIAPIPE_ANDROID @@ -115,10 +131,11 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { } MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, - &cc]() -> ::mediapipe::Status { - return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) : LoadDelegate(cc); - })); + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { + return use_advanced_gpu_api_ ? 
InitTFLiteGPURunner(cc) + : LoadDelegateAndAllocateTensors(cc); + })); return absl::OkStatus(); } @@ -253,9 +270,27 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( : tflite::gpu::InferencePriority::MAX_PRECISION; options.priority2 = tflite::gpu::InferencePriority::AUTO; options.priority3 = tflite::gpu::InferencePriority::AUTO; - options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + switch (tflite_gpu_runner_usage_) { + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu:: + FAST_SINGLE_ANSWER: { + options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER; + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu:: + SUSTAINED_SPEED: { + options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: { + return absl::InternalError("inference usage need to be specified."); + } + } tflite_gpu_runner_ = std::make_unique(options); switch (tflite_gpu_runner_api_) { + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: { + // Do not need to force any specific API. + break; + } case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: { tflite_gpu_runner_->ForceOpenGL(); break; @@ -264,13 +299,9 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( tflite_gpu_runner_->ForceOpenCL(); break; } - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: { - // Do not need to force any specific API. - break; - } } - MP_RETURN_IF_ERROR( - tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( + model, op_resolver, /*allow_quant_ops=*/true)); // Create and bind OpenGL buffers for outputs. // The buffers are created once and their ids are passed to calculator outputs @@ -306,11 +337,19 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. - CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index d86a45c07..49e042290 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -92,6 +92,7 @@ class InferenceCalculatorMetalImpl private: absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. 
Packet model_packet_; @@ -130,8 +131,7 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) { gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - return absl::OkStatus(); + return LoadDelegateAndAllocateTensors(cc); } absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { @@ -212,11 +212,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) { interpreter_->SetNumThreads( cc->Options().cpu_num_thread()); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. - CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } @@ -226,12 +234,17 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { // Configure and create the delegate. TFLGpuDelegateOptions options; + // `enable_quantization` enables the run of sparse models i.e. the models with + // DENSIFY op preceding DEQUINTIZE op. Both ops get removed from the execution + // graph after the tensor of the weights is read. + options.enable_quantization = true; options.allow_precision_loss = allow_precision_loss_; options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait; delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); + id device = gpu_helper_.mtlDevice; // Get input image sizes. diff --git a/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.cc b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.cc new file mode 100644 index 000000000..8f9323818 --- /dev/null +++ b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.cc @@ -0,0 +1,101 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h" + +#include + +#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { +namespace api2 { + +namespace { + +float GetAttribute( + const Landmark& landmark, + const LandmarksToTensorCalculatorOptions::Attribute& attribute) { + switch (attribute) { + case LandmarksToTensorCalculatorOptions::X: + return landmark.x(); + case LandmarksToTensorCalculatorOptions::Y: + return landmark.y(); + case LandmarksToTensorCalculatorOptions::Z: + return landmark.z(); + case LandmarksToTensorCalculatorOptions::VISIBILITY: + return landmark.visibility(); + case LandmarksToTensorCalculatorOptions::PRESENCE: + return landmark.presence(); + } +} + +} // namespace + +class LandmarksToTensorCalculatorImpl + : public NodeImpl { + public: + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + RET_CHECK(options_.attributes_size() > 0) + << "At least one attribute must be specified"; + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + if (kInLandmarkList(cc).IsEmpty()) { + return absl::OkStatus(); + } + + // Get input landmarks. + const auto& in_landmarks = *kInLandmarkList(cc); + + // Determine tensor shape. + const int n_landmarks = in_landmarks.landmark_size(); + const int n_attributes = options_.attributes_size(); + auto tensor_shape = options_.flatten() + ? Tensor::Shape{1, n_landmarks * n_attributes} + : Tensor::Shape{1, n_landmarks, n_attributes}; + + // Create empty tesnor. + Tensor tensor(Tensor::ElementType::kFloat32, tensor_shape); + auto* buffer = tensor.GetCpuWriteView().buffer(); + + // Fill tensor with landmark attributes. + for (int i = 0; i < n_landmarks; ++i) { + for (int j = 0; j < n_attributes; ++j) { + buffer[i * n_attributes + j] = + GetAttribute(in_landmarks.landmark(i), options_.attributes(j)); + } + } + + // Return vector with a single tensor. + auto result = std::vector(); + result.push_back(std::move(tensor)); + kOutTensors(cc).Send(std::move(result)); + + return absl::OkStatus(); + } + + private: + LandmarksToTensorCalculatorOptions options_; +}; +MEDIAPIPE_NODE_IMPLEMENTATION(LandmarksToTensorCalculatorImpl); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h new file mode 100644 index 000000000..662f1b05f --- /dev/null +++ b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h @@ -0,0 +1,61 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_ + +#include + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { +namespace api2 { + +// A calculator for converting landmars into a Tensor. +// +// Input: +// LANDMARKS - LandmarkList +// Landmarks to be converted into a Tensor. +// +// Output: +// TENSORS - std::vector +// Vector containing a single Tensor populated with landmark values. +// +// Example: +// node { +// calculator: "LandmarksToTensorCalculator" +// input_stream: "LANDMARKS:landmarks" +// output_stream: "TENSORS:tensors" +// options: { +// [mediapipe.LandmarksToTensorCalculatorOptions.ext] { +// attributes: [X, Y, Z, VISIBILITY, PRESENCE] +// # flatten: true +// } +// } +// } +class LandmarksToTensorCalculator : public NodeIntf { + public: + static constexpr Input::Optional kInLandmarkList{"LANDMARKS"}; + static constexpr Output> kOutTensors{"TENSORS"}; + MEDIAPIPE_NODE_INTERFACE(LandmarksToTensorCalculator, kInLandmarkList, + kOutTensors); +}; + +} // namespace api2 +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_ diff --git a/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.proto b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.proto new file mode 100644 index 000000000..6ef1c8d4e --- /dev/null +++ b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.proto @@ -0,0 +1,44 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the LandmarksToTensorCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message LandmarksToTensorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional LandmarksToTensorCalculatorOptions ext = 394810235; + } + + enum Attribute { + X = 0; + Y = 1; + Z = 2; + VISIBILITY = 3; + PRESENCE = 4; + } + + // Subset and order of attributes as they should appear in the output Tensor. + // Should contain at least one attribute. + repeated Attribute attributes = 1; + + // Collapses all landmark attributes into a one dimensional tensor (i.e. + // switches from (n_landmarks, n_attributes) to (n_landmarks * n_attributes) + // representation). + optional bool flatten = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensor/landmarks_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator_test.cc new file mode 100644 index 000000000..dfda71b55 --- /dev/null +++ b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator_test.cc @@ -0,0 +1,155 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::ParseTextProtoOrDie; +using Node = ::mediapipe::CalculatorGraphConfig::Node; + +void RunLandmarks(mediapipe::CalculatorRunner* runner, + const LandmarkList& landmarks) { + runner->MutableInputs() + ->Tag("LANDMARKS") + .packets.push_back(MakePacket(landmarks).At(Timestamp(0))); + MP_ASSERT_OK(runner->Run()); +} + +const Tensor& GetOutputTensor(mediapipe::CalculatorRunner* runner) { + const auto& output_packets = runner->Outputs().Tag("TENSORS").packets; + EXPECT_EQ(output_packets.size(), 1); + + const auto& tensors = output_packets[0].Get>(); + EXPECT_EQ(tensors.size(), 1); + + return tensors[0]; +} + +void ValidateTensor(const Tensor& tensor, + const std::vector& expected_shape, + const std::vector& expected_values) { + EXPECT_EQ(tensor.shape().dims, expected_shape); + EXPECT_EQ(tensor.shape().num_elements(), expected_values.size()); + + auto* tensor_buffer = tensor.GetCpuReadView().buffer(); + const std::vector tensor_values( + tensor_buffer, tensor_buffer + tensor.shape().num_elements()); + EXPECT_THAT(tensor_values, testing::ElementsAreArray(expected_values)); +} + +TEST(LandmarksToTensorCalculatorTest, AllAttributes) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "LandmarksToTensorCalculator" + input_stream: "LANDMARKS:landmarks" + output_stream: "TENSORS:tensors" + options: { + [mediapipe.LandmarksToTensorCalculatorOptions.ext] { + attributes: [ X, Y, Z, VISIBILITY, PRESENCE ] + } + } + )pb")); + + LandmarkList landmarks; + auto* landmark1 = landmarks.add_landmark(); + landmark1->set_x(1.0f); + landmark1->set_y(2.0f); + landmark1->set_z(3.0f); + landmark1->set_visibility(4.0f); + landmark1->set_presence(5.0f); + auto* landmark2 = landmarks.add_landmark(); + landmark2->set_x(6.0f); + landmark2->set_y(7.0f); + landmark2->set_z(8.0f); + landmark2->set_visibility(9.0f); + landmark2->set_presence(10.0f); + + RunLandmarks(&runner, landmarks); + const auto& tensor = GetOutputTensor(&runner); + ValidateTensor(tensor, /*expected_shape=*/{1, 2, 5}, /*expected_values=*/ + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}); +} + +TEST(LandmarksToTensorCalculatorTest, XYZAttributes) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "LandmarksToTensorCalculator" + input_stream: "LANDMARKS:landmarks" + output_stream: "TENSORS:tensors" + options: { + [mediapipe.LandmarksToTensorCalculatorOptions.ext] { + attributes: [ X, Y, Z ] + } + } + )pb")); + + LandmarkList landmarks; + auto* 
landmark1 = landmarks.add_landmark(); + landmark1->set_x(1.0f); + landmark1->set_y(2.0f); + landmark1->set_z(3.0f); + auto* landmark2 = landmarks.add_landmark(); + landmark2->set_x(6.0f); + landmark2->set_y(7.0f); + landmark2->set_z(8.0f); + + RunLandmarks(&runner, landmarks); + const auto& tensor = GetOutputTensor(&runner); + ValidateTensor(tensor, /*expected_shape=*/{1, 2, 3}, /*expected_values=*/ + {1.0f, 2.0f, 3.0f, 6.0f, 7.0f, 8.0f}); +} + +TEST(LandmarksToTensorCalculatorTest, XYZAttributes_Flatten) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "LandmarksToTensorCalculator" + input_stream: "LANDMARKS:landmarks" + output_stream: "TENSORS:tensors" + options: { + [mediapipe.LandmarksToTensorCalculatorOptions.ext] { + attributes: [ X, Y, Z ] + flatten: true + } + } + )pb")); + + LandmarkList landmarks; + auto* landmark1 = landmarks.add_landmark(); + landmark1->set_x(1.0f); + landmark1->set_y(2.0f); + landmark1->set_z(3.0f); + auto* landmark2 = landmarks.add_landmark(); + landmark2->set_x(6.0f); + landmark2->set_y(7.0f); + landmark2->set_z(8.0f); + + RunLandmarks(&runner, landmarks); + const auto& tensor = GetOutputTensor(&runner); + ValidateTensor(tensor, /*expected_shape=*/{1, 6}, /*expected_values=*/ + {1.0f, 2.0f, 3.0f, 6.0f, 7.0f, 8.0f}); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc index 82180fe52..f3c7c7b09 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc @@ -517,8 +517,8 @@ absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { uniform sampler2D frame; void main() { - $1 // flip - vec4 pixel = texture2D(frame, sample_coordinate); + vec2 coord = $1 + vec4 pixel = texture2D(frame, coord); $2 // normalize [-1,1] fragColor.r = pixel.r; // r channel $3 // g & b channels @@ -526,8 +526,9 @@ absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { })", /*$0=*/single_channel ? "vec1" : "vec4", /*$1=*/ - flip_vertically_ ? "sample_coordinate.y = 1.0 - sample_coordinate.y;" - : "", + flip_vertically_ + ? "vec2(sample_coordinate.x, 1.0 - sample_coordinate.y);" + : "sample_coordinate;", /*$2=*/output_range_.has_value() ? absl::Substitute("pixel = pixel * float($0) + float($1);", (output_range_->second - output_range_->first), diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index 1a27cafce..498036c12 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -105,6 +105,15 @@ void ConvertAnchorsToRawValues(const std::vector& anchors, // for anchors (e.g. for SSD models) depend on the outputs of the // detection model. The size of anchor tensor must be (num_boxes * // 4). +// +// Input side packet: +// ANCHORS (optional) - The anchors used for decoding the bounding boxes, as a +// vector of `Anchor` protos. Not required if post-processing is built-in +// the model. +// IGNORE_CLASSES (optional) - The list of class ids that should be ignored, as +// a vector of integers. It overrides the corresponding field in the +// calculator options. +// // Output: // DETECTIONS - Result MediaPipe detections. 
// @@ -132,8 +141,11 @@ class TensorsToDetectionsCalculator : public Node { static constexpr Input> kInTensors{"TENSORS"}; static constexpr SideInput>::Optional kInAnchors{ "ANCHORS"}; + static constexpr SideInput>::Optional kSideInIgnoreClasses{ + "IGNORE_CLASSES"}; static constexpr Output> kOutDetections{"DETECTIONS"}; - MEDIAPIPE_NODE_CONTRACT(kInTensors, kInAnchors, kOutDetections); + MEDIAPIPE_NODE_CONTRACT(kInTensors, kInAnchors, kSideInIgnoreClasses, + kOutDetections); static absl::Status UpdateContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc) override; @@ -566,8 +578,15 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) { kNumCoordsPerBox, num_coords_); - for (int i = 0; i < options_.ignore_classes_size(); ++i) { - ignore_classes_.insert(options_.ignore_classes(i)); + if (kSideInIgnoreClasses(cc).IsConnected()) { + RET_CHECK(!kSideInIgnoreClasses(cc).IsEmpty()); + for (int ignore_class : *kSideInIgnoreClasses(cc)) { + ignore_classes_.insert(ignore_class); + } + } else { + for (int i = 0; i < options_.ignore_classes_size(); ++i) { + ignore_classes_.insert(options_.ignore_classes(i)); + } } return absl::OkStatus(); @@ -651,7 +670,8 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections( detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], detection_scores[i], detection_classes[i], options_.flip_vertically()); const auto& bbox = detection.location_data().relative_bounding_box(); - if (bbox.width() < 0 || bbox.height() < 0) { + if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || + std::isnan(bbox.height())) { // Decoded detection boxes could have negative values for width/height due // to model prediction. Filter out those boxes since some downstream // calculators may assume non-negative values. (b/171391719) diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto index 24c0a5053..364eb5cce 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto @@ -56,7 +56,7 @@ message TensorsToDetectionsCalculatorOptions { // [x_center, y_center, w, h]. optional bool reverse_output_order = 14 [default = false]; // The ids of classes that should be ignored during decoding the score for - // each predicted box. + // each predicted box. Can be overridden with IGNORE_CLASSES side packet. repeated int32 ignore_classes = 8; optional bool sigmoid_score = 15 [default = false]; diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc new file mode 100644 index 000000000..ffc96b2e4 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -0,0 +1,881 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
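Editor's note: with the IGNORE_CLASSES side input added above, the ignore list can be supplied per graph run instead of being baked into the calculator options. A hypothetical wiring sketch (the side-packet name "ignore_classes" and the surrounding graph setup are made up; the node would declare something like input_side_packet: "IGNORE_CLASSES:ignore_classes"):

#include <map>
#include <string>
#include <vector>

#include "mediapipe/framework/calculator_framework.h"

// Sketch: feed TensorsToDetectionsCalculator's optional IGNORE_CLASSES side
// packet (a std::vector<int>) when starting the graph; this overrides the
// ignore_classes field of the calculator options.
absl::Status StartRunIgnoringClasses(mediapipe::CalculatorGraph& graph) {
  std::map<std::string, mediapipe::Packet> side_packets;
  side_packets["ignore_classes"] =
      mediapipe::MakePacket<std::vector<int>>(std::vector<int>{0, 2, 5});
  return graph.StartRun(side_packets);
}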
+ +#include + +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/gpu/gpu_origin.pb.h" +#include "mediapipe/util/resource_util.h" +#include "tensorflow/lite/interpreter.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/shader_util.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#include "tensorflow/lite/delegates/gpu/gl/converters/util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + +#if MEDIAPIPE_METAL_ENABLED +#import +#import +#import + +#import "mediapipe/gpu/MPPMetalHelper.h" +#include "mediapipe/gpu/MPPMetalUtil.h" +#endif // MEDIAPIPE_METAL_ENABLED + +namespace { +constexpr int kWorkgroupSize = 8; // Block size for GPU shader. +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; + +// Commonly used to compute the number of blocks to launch in a kernel. +int NumGroups(const int size, const int group_size) { // NOLINT + return (size + group_size - 1) / group_size; +} + +bool CanUseGpu() { +#if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED + // TODO: Configure GPU usage policy in individual calculators. + constexpr bool kAllowGpuProcessing = true; + return kAllowGpuProcessing; +#else + return false; +#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED +} + +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kMaskTag[] = "MASK"; + +absl::StatusOr> GetHwcFromDims( + const std::vector& dims) { + if (dims.size() == 3) { + return std::make_tuple(dims[0], dims[1], dims[2]); + } else if (dims.size() == 4) { + // BHWC format check B == 1 + RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap"; + return std::make_tuple(dims[1], dims[2], dims[3]); + } else { + RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size(); + } +} +} // namespace + +namespace mediapipe { + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +using ::tflite::gpu::gl::GlProgram; +using ::tflite::gpu::gl::GlShader; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + +// Converts Tensors from a tflite segmentation model to an image mask. +// +// Performs optional upscale to OUTPUT_SIZE dimensions if provided, +// otherwise the mask is the same size as input tensor. +// +// If at least one input tensor is already on GPU, processing happens on GPU and +// the output mask is also stored on GPU. Otherwise, processing and the output +// mask are both on CPU. +// +// On GPU, the mask is an RGBA image, in both the R & A channels, scaled 0-1. 
+// On CPU, the mask is a ImageFormat::VEC32F1 image, with values scaled 0-1. +// +// +// Inputs: +// One of the following TENSORS tags: +// TENSORS: Vector of Tensor, +// The tensor dimensions are specified in this calculator's options. +// OUTPUT_SIZE(optional): std::pair, +// If provided, the size to upscale mask to. +// +// Output: +// MASK: An Image output mask, RGBA(GPU) / VEC32F1(CPU). +// +// Options: +// See tensors_to_segmentation_calculator.proto +// +// Usage example: +// node { +// calculator: "TensorsToSegmentationCalculator" +// input_stream: "TENSORS:tensors" +// input_stream: "OUTPUT_SIZE:size" +// output_stream: "MASK:hair_mask" +// node_options: { +// [mediapipe.TensorsToSegmentationCalculatorOptions] { +// output_layer_index: 1 +// # gpu_origin: CONVENTIONAL # or TOP_LEFT +// } +// } +// } +// +// TODO Refactor and add support for other backends/platforms. +// +class TensorsToSegmentationCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status ProcessGpu(CalculatorContext* cc); + absl::Status ProcessCpu(CalculatorContext* cc); + void GlRender(); + + bool DoesGpuTextureStartAtBottom() { + return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; + } + + template + absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat); + + ::mediapipe::TensorsToSegmentationCalculatorOptions options_; + +#if !MEDIAPIPE_DISABLE_GPU + mediapipe::GlCalculatorHelper gpu_helper_; + GLuint upsample_program_; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + std::unique_ptr mask_program_31_; +#else + GLuint mask_program_20_; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#if MEDIAPIPE_METAL_ENABLED + MPPMetalHelper* metal_helper_ = nullptr; + id mask_program_; +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU +}; +REGISTER_CALCULATOR(TensorsToSegmentationCalculator); + +// static +absl::Status TensorsToSegmentationCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(!cc->Inputs().GetTags().empty()); + RET_CHECK(!cc->Outputs().GetTags().empty()); + + // Inputs. + cc->Inputs().Tag(kTensorsTag).Set>(); + if (cc->Inputs().HasTag(kOutputSizeTag)) { + cc->Inputs().Tag(kOutputSizeTag).Set>(); + } + + // Outputs. 
+ cc->Outputs().Tag(kMaskTag).Set(); + + if (CanUseGpu()) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#if MEDIAPIPE_METAL_ENABLED + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU + } + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + bool use_gpu = false; + + if (CanUseGpu()) { +#if !MEDIAPIPE_DISABLE_GPU + use_gpu = true; + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#if MEDIAPIPE_METAL_ENABLED + metal_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(metal_helper_); +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU + } + + MP_RETURN_IF_ERROR(LoadOptions(cc)); + + if (use_gpu) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(InitGpu(cc)); +#else + RET_CHECK_FAIL() << "GPU processing disabled."; +#endif // !MEDIAPIPE_DISABLE_GPU + } + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) { + if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { + return absl::OkStatus(); + } + + const auto& input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + + bool use_gpu = false; + if (CanUseGpu()) { + // Use GPU processing only if at least one input tensor is already on GPU. + for (const auto& tensor : input_tensors) { + if (tensor.ready_on_gpu()) { + use_gpu = true; + break; + } + } + } + + // Validate tensor channels and activation type. + { + RET_CHECK(!input_tensors.empty()); + ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + int tensor_channels = std::get<2>(hwc); + typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + switch (options_.activation()) { + case Options::NONE: + RET_CHECK_EQ(tensor_channels, 1); + break; + case Options::SIGMOID: + RET_CHECK_EQ(tensor_channels, 1); + break; + case Options::SOFTMAX: + RET_CHECK_EQ(tensor_channels, 2); + break; + } + } + + if (use_gpu) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + MP_RETURN_IF_ERROR(ProcessGpu(cc)); + return absl::OkStatus(); + })); +#else + RET_CHECK_FAIL() << "GPU processing disabled."; +#endif // !MEDIAPIPE_DISABLE_GPU + } else { + MP_RETURN_IF_ERROR(ProcessCpu(cc)); + } + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + gpu_helper_.RunInGlContext([this] { + if (upsample_program_) glDeleteProgram(upsample_program_); + upsample_program_ = 0; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + mask_program_31_.reset(); +#else + if (mask_program_20_) glDeleteProgram(mask_program_20_); + mask_program_20_ = 0; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#if MEDIAPIPE_METAL_ENABLED + mask_program_ = nil; +#endif // MEDIAPIPE_METAL_ENABLED + }); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::ProcessCpu( + CalculatorContext* cc) { + // Get input streams, and dimensions. 
+ const auto& input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + int output_width = tensor_width, output_height = tensor_height; + if (cc->Inputs().HasTag(kOutputSizeTag)) { + const auto& size = + cc->Inputs().Tag(kOutputSizeTag).Get>(); + output_width = size.first; + output_height = size.second; + } + + // Create initial working mask. + cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); + + // Wrap input tensor. + auto raw_input_tensor = &input_tensors[0]; + auto raw_input_view = raw_input_tensor->GetCpuReadView(); + const float* raw_input_data = raw_input_view.buffer(); + cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), + CV_MAKETYPE(CV_32F, tensor_channels), + const_cast(raw_input_data)); + + // Process mask tensor and apply activation function. + if (tensor_channels == 2) { + MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); + } else if (tensor_channels == 1) { + RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX != + options_.activation()); // Requires 2 channels. + if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE == + options_.activation()) // Pass-through optimization. + tensor_mat.copyTo(small_mask_mat); + else + MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); + } else { + RET_CHECK_FAIL() << "Unsupported number of tensor channels " + << tensor_channels; + } + + // Send out image as CPU packet. + std::shared_ptr mask_frame = std::make_shared( + ImageFormat::VEC32F1, output_width, output_height); + std::unique_ptr output_mask = absl::make_unique(mask_frame); + cv::Mat output_mat = formats::MatView(output_mask.get()); + // Upsample small mask into output. + cv::resize(small_mask_mat, output_mat, cv::Size(output_width, output_height)); + cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); + + return absl::OkStatus(); +} + +template +absl::Status TensorsToSegmentationCalculator::ApplyActivation( + cv::Mat& tensor_mat, cv::Mat* small_mask_mat) { + // Configure activation function. + const int output_layer_index = options_.output_layer_index(); + typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + const auto activation_fn = [&](const cv::Vec2f& mask_value) { + float new_mask_value = 0; + // TODO consider moving switch out of the loop, + // and also avoid float/Vec2f casting. + switch (options_.activation()) { + case Options::NONE: { + new_mask_value = mask_value[0]; + break; + } + case Options::SIGMOID: { + const float pixel0 = mask_value[0]; + new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0); + break; + } + case Options::SOFTMAX: { + const float pixel0 = mask_value[0]; + const float pixel1 = mask_value[1]; + const float max_pixel = std::max(pixel0, pixel1); + const float min_pixel = std::min(pixel0, pixel1); + const float softmax_denom = + /*exp(max_pixel - max_pixel)=*/1.0f + + std::exp(min_pixel - max_pixel); + new_mask_value = std::exp(mask_value[output_layer_index] - max_pixel) / + softmax_denom; + break; + } + } + return new_mask_value; + }; + + // Process mask tensor. + for (int i = 0; i < tensor_mat.rows; ++i) { + for (int j = 0; j < tensor_mat.cols; ++j) { + const T& input_pix = tensor_mat.at(i, j); + const float mask_value = activation_fn(input_pix); + small_mask_mat->at(i, j) = mask_value; + } + } + + return absl::OkStatus(); +} + +// Steps: +// 1. receive tensor +// 2. 
process segmentation tensor into small mask +// 3. upsample small mask into output mask to be same size as input image +absl::Status TensorsToSegmentationCalculator::ProcessGpu( + CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + // Get input streams, and dimensions. + const auto& input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + int output_width = tensor_width, output_height = tensor_height; + if (cc->Inputs().HasTag(kOutputSizeTag)) { + const auto& size = + cc->Inputs().Tag(kOutputSizeTag).Get>(); + output_width = size.first; + output_height = size.second; + } + + // Create initial working mask texture. +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + tflite::gpu::gl::GlTexture small_mask_texture; +#else + mediapipe::GlTexture small_mask_texture; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + + // Run shader, process mask tensor. +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + { + MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture( + tflite::gpu::DataType::UINT8, // GL_RGBA8 + {tensor_width, tensor_height}, &small_mask_texture)); + + const int output_index = 0; + glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0, + GL_WRITE_ONLY, GL_RGBA8); + + auto read_view = input_tensors[0].GetOpenGlBufferReadView(); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name()); + + const tflite::gpu::uint3 workgroups = { + NumGroups(tensor_width, kWorkgroupSize), + NumGroups(tensor_height, kWorkgroupSize), 1}; + + glUseProgram(mask_program_31_->id()); + glUniform2i(glGetUniformLocation(mask_program_31_->id(), "out_size"), + tensor_width, tensor_height); + + MP_RETURN_IF_ERROR(mask_program_31_->Dispatch(workgroups)); + } +#elif MEDIAPIPE_METAL_ENABLED + { + id command_buffer = [metal_helper_ commandBuffer]; + command_buffer.label = @"SegmentationKernel"; + id command_encoder = + [command_buffer computeCommandEncoder]; + [command_encoder setComputePipelineState:mask_program_]; + + auto read_view = input_tensors[0].GetMtlBufferReadView(command_buffer); + [command_encoder setBuffer:read_view.buffer() offset:0 atIndex:0]; + + mediapipe::GpuBuffer small_mask_buffer = [metal_helper_ + mediapipeGpuBufferWithWidth:tensor_width + height:tensor_height + format:mediapipe::GpuBufferFormat::kBGRA32]; + id small_mask_texture_metal = + [metal_helper_ metalTextureWithGpuBuffer:small_mask_buffer]; + [command_encoder setTexture:small_mask_texture_metal atIndex:1]; + + unsigned int out_size[] = {static_cast(tensor_width), + static_cast(tensor_height)}; + [command_encoder setBytes:&out_size length:sizeof(out_size) atIndex:2]; + + MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1); + MTLSize threadgroups = + MTLSizeMake(NumGroups(tensor_width, kWorkgroupSize), + NumGroups(tensor_height, kWorkgroupSize), 1); + [command_encoder dispatchThreadgroups:threadgroups + threadsPerThreadgroup:threads_per_group]; + [command_encoder endEncoding]; + [command_buffer commit]; + + small_mask_texture = gpu_helper_.CreateSourceTexture(small_mask_buffer); + } +#else + { + small_mask_texture = gpu_helper_.CreateDestinationTexture( + tensor_width, tensor_height, + mediapipe::GpuBufferFormat::kBGRA32); // actually GL_RGBA8 + + // Go through CPU if not already texture 2D (no direct conversion yet). + // Tensor::GetOpenGlTexture2dReadView() doesn't automatically convert types. 
+ if (!input_tensors[0].ready_as_opengl_texture_2d()) { + (void)input_tensors[0].GetCpuReadView(); + } + + auto read_view = input_tensors[0].GetOpenGlTexture2dReadView(); + + gpu_helper_.BindFramebuffer(small_mask_texture); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, read_view.name()); + glUseProgram(mask_program_20_); + GlRender(); + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + } +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + + // Upsample small mask into output. + mediapipe::GlTexture output_texture = gpu_helper_.CreateDestinationTexture( + output_width, output_height, + mediapipe::GpuBufferFormat::kBGRA32); // actually GL_RGBA8 + + // Run shader, upsample result. + { + gpu_helper_.BindFramebuffer(output_texture); + glActiveTexture(GL_TEXTURE1); +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + glBindTexture(GL_TEXTURE_2D, small_mask_texture.id()); +#else + glBindTexture(GL_TEXTURE_2D, small_mask_texture.name()); +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + glUseProgram(upsample_program_); + GlRender(); + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + } + + // Send out image as GPU packet. + auto output_image = output_texture.GetFrame(); + cc->Outputs().Tag(kMaskTag).Add(output_image.release(), cc->InputTimestamp()); + + // Cleanup + output_texture.Release(); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +void TensorsToSegmentationCalculator::GlRender() { +#if !MEDIAPIPE_DISABLE_GPU + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); +#endif // !MEDIAPIPE_DISABLE_GPU +} + +absl::Status TensorsToSegmentationCalculator::LoadOptions( + CalculatorContext* cc) { + // Get calculator options specified in the graph. + options_ = cc->Options<::mediapipe::TensorsToSegmentationCalculatorOptions>(); + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::InitGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { + // A shader to process a segmentation tensor into an output mask. + // Currently uses 4 channels for output, and sets R+A channels as mask value. 
+#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + // GLES 3.1 + const tflite::gpu::uint3 workgroup_size = {kWorkgroupSize, kWorkgroupSize, + 1}; + const std::string shader_header = + absl::StrCat(tflite::gpu::gl::GetShaderHeader(workgroup_size), R"( +precision highp float; + +layout(rgba8, binding = 0) writeonly uniform highp image2D output_texture; + +uniform ivec2 out_size; +)"); + /* Shader defines will be inserted here. */ + + const std::string shader_src_main = R"( +layout(std430, binding = 2) readonly buffer B0 { +#ifdef TWO_CHANNEL_INPUT + vec2 elements[]; +#else + float elements[]; +#endif // TWO_CHANNEL_INPUT +} input_data; // data tensor + +void main() { + int out_width = out_size.x; + int out_height = out_size.y; + + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= out_width || gid.y >= out_height) { return; } + int linear_index = gid.y * out_width + gid.x; + +#ifdef TWO_CHANNEL_INPUT + vec2 input_value = input_data.elements[linear_index]; +#else + vec2 input_value = vec2(input_data.elements[linear_index], 0.0); +#endif // TWO_CHANNEL_INPUT + +// Run activation function. +// One and only one of FN_SOFTMAX,FN_SIGMOID,FN_NONE will be defined. +#ifdef FN_SOFTMAX + // Only two channel input tensor is supported. + vec2 input_px = input_value.rg; + float shift = max(input_px.r, input_px.g); + float softmax_denom = exp(input_px.r - shift) + exp(input_px.g - shift); + float new_mask_value = + exp(input_px[OUTPUT_LAYER_INDEX] - shift) / softmax_denom; +#endif // FN_SOFTMAX + +#ifdef FN_SIGMOID + float new_mask_value = 1.0 / (exp(-input_value.r) + 1.0); +#endif // FN_SIGMOID + +#ifdef FN_NONE + float new_mask_value = input_value.r; +#endif // FN_NONE + +#ifdef FLIP_Y_COORD + int y_coord = out_height - gid.y - 1; +#else + int y_coord = gid.y; +#endif // defined(FLIP_Y_COORD) + ivec2 output_coordinate = ivec2(gid.x, y_coord); + + vec4 out_value = vec4(new_mask_value, 0.0, 0.0, new_mask_value); + imageStore(output_texture, output_coordinate, out_value); +})"; + +#elif MEDIAPIPE_METAL_ENABLED + // METAL + const std::string shader_header = R"( +#include +using namespace metal; +)"; + /* Shader defines will be inserted here. */ + + const std::string shader_src_main = R"( +kernel void segmentationKernel( +#ifdef TWO_CHANNEL_INPUT + device float2* elements [[ buffer(0) ]], +#else + device float* elements [[ buffer(0) ]], +#endif // TWO_CHANNEL_INPUT + texture2d output_texture [[ texture(1) ]], + constant uint* out_size [[ buffer(2) ]], + uint2 gid [[ thread_position_in_grid ]]) +{ + uint out_width = out_size[0]; + uint out_height = out_size[1]; + + if (gid.x >= out_width || gid.y >= out_height) { return; } + uint linear_index = gid.y * out_width + gid.x; + +#ifdef TWO_CHANNEL_INPUT + float2 input_value = elements[linear_index]; +#else + float2 input_value = float2(elements[linear_index], 0.0); +#endif // TWO_CHANNEL_INPUT + +// Run activation function. +// One and only one of FN_SOFTMAX,FN_SIGMOID,FN_NONE will be defined. +#ifdef FN_SOFTMAX + // Only two channel input tensor is supported. 
+ float2 input_px = input_value.xy; + float shift = max(input_px.x, input_px.y); + float softmax_denom = exp(input_px.r - shift) + exp(input_px.g - shift); + float new_mask_value = + exp(input_px[OUTPUT_LAYER_INDEX] - shift) / softmax_denom; +#endif // FN_SOFTMAX + +#ifdef FN_SIGMOID + float new_mask_value = 1.0 / (exp(-input_value.x) + 1.0); +#endif // FN_SIGMOID + +#ifdef FN_NONE + float new_mask_value = input_value.x; +#endif // FN_NONE + +#ifdef FLIP_Y_COORD + int y_coord = out_height - gid.y - 1; +#else + int y_coord = gid.y; +#endif // defined(FLIP_Y_COORD) + uint2 output_coordinate = uint2(gid.x, y_coord); + + float4 out_value = float4(new_mask_value, 0.0, 0.0, new_mask_value); + output_texture.write(out_value, output_coordinate); +} +)"; + +#else + // GLES 2.0 + const std::string shader_header = absl::StrCat( + std::string(mediapipe::kMediaPipeFragmentShaderPreamble), R"( +DEFAULT_PRECISION(mediump, float) +)"); + /* Shader defines will be inserted here. */ + + const std::string shader_src_main = R"( +in vec2 sample_coordinate; + +uniform sampler2D input_texture; + +#ifdef GL_ES +#define fragColor gl_FragColor +#else +out vec4 fragColor; +#endif // defined(GL_ES); + +void main() { +#ifdef FLIP_Y_COORD + float y_coord = 1.0 - sample_coordinate.y; +#else + float y_coord = sample_coordinate.y; +#endif // defined(FLIP_Y_COORD) + vec2 adjusted_coordinate = vec2(sample_coordinate.x, y_coord); + vec4 input_value = texture2D(input_texture, adjusted_coordinate); + + // Run activation function. + // One and only one of FN_SOFTMAX,FN_SIGMOID,FN_NONE will be defined. + +#ifdef FN_SOFTMAX + // Only two channel input tensor is supported. + vec2 input_px = input_value.rg; + float shift = max(input_px.r, input_px.g); + float softmax_denom = exp(input_px.r - shift) + exp(input_px.g - shift); + float new_mask_value = + exp(mix(input_px.r, input_px.g, float(OUTPUT_LAYER_INDEX)) - shift) / softmax_denom; +#endif // FN_SOFTMAX + +#ifdef FN_SIGMOID + float new_mask_value = 1.0 / (exp(-input_value.r) + 1.0); +#endif // FN_SIGMOID + +#ifdef FN_NONE + float new_mask_value = input_value.r; +#endif // FN_NONE + + vec4 out_value = vec4(new_mask_value, 0.0, 0.0, new_mask_value); + fragColor = out_value; +})"; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + + // Shader defines. + typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + const std::string output_layer_index = + "\n#define OUTPUT_LAYER_INDEX int(" + + std::to_string(options_.output_layer_index()) + ")"; + const std::string flip_y_coord = + DoesGpuTextureStartAtBottom() ? "\n#define FLIP_Y_COORD" : ""; + const std::string fn_none = + options_.activation() == Options::NONE ? "\n#define FN_NONE" : ""; + const std::string fn_sigmoid = + options_.activation() == Options::SIGMOID ? "\n#define FN_SIGMOID" : ""; + const std::string fn_softmax = + options_.activation() == Options::SOFTMAX ? "\n#define FN_SOFTMAX" : ""; + const std::string two_channel = options_.activation() == Options::SOFTMAX + ? "\n#define TWO_CHANNEL_INPUT" + : ""; + const std::string shader_defines = + absl::StrCat(output_layer_index, flip_y_coord, fn_softmax, fn_sigmoid, + fn_none, two_channel); + + // Build full shader. + const std::string shader_src_no_previous = + absl::StrCat(shader_header, shader_defines, shader_src_main); + + // Vertex shader attributes. 
+ const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + // Main shader program & parameters +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + GlShader shader_without_previous; + MP_RETURN_IF_ERROR(GlShader::CompileShader( + GL_COMPUTE_SHADER, shader_src_no_previous, &shader_without_previous)); + mask_program_31_ = absl::make_unique(); + MP_RETURN_IF_ERROR(GlProgram::CreateWithShader(shader_without_previous, + mask_program_31_.get())); +#elif MEDIAPIPE_METAL_ENABLED + id device = metal_helper_.mtlDevice; + NSString* library_source = + [NSString stringWithUTF8String:shader_src_no_previous.c_str()]; + NSError* error = nil; + id library = [device newLibraryWithSource:library_source + options:nullptr + error:&error]; + RET_CHECK(library != nil) << "Couldn't create shader library " + << [[error localizedDescription] UTF8String]; + id kernel_func = nil; + kernel_func = [library newFunctionWithName:@"segmentationKernel"]; + RET_CHECK(kernel_func != nil) << "Couldn't create kernel function."; + mask_program_ = + [device newComputePipelineStateWithFunction:kernel_func error:&error]; + RET_CHECK(mask_program_ != nil) << "Couldn't create pipeline state " << + [[error localizedDescription] UTF8String]; +#else + mediapipe::GlhCreateProgram( + mediapipe::kBasicVertexShader, shader_src_no_previous.c_str(), + NUM_ATTRIBUTES, &attr_name[0], attr_location, &mask_program_20_); + RET_CHECK(mask_program_20_) << "Problem initializing the program."; + glUseProgram(mask_program_20_); + glUniform1i(glGetUniformLocation(mask_program_20_, "input_texture"), 1); +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + + // Simple pass-through program, used for hardware upsampling. + mediapipe::GlhCreateProgram( + mediapipe::kBasicVertexShader, mediapipe::kBasicTexturedFragmentShader, + NUM_ATTRIBUTES, &attr_name[0], attr_location, &upsample_program_); + RET_CHECK(upsample_program_) << "Problem initializing the program."; + glUseProgram(upsample_program_); + glUniform1i(glGetUniformLocation(upsample_program_, "video_frame"), 1); + + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.proto b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.proto new file mode 100644 index 000000000..1662576b6 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.proto @@ -0,0 +1,46 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
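Editor's note: the three activation branches (NONE, SIGMOID, SOFTMAX) are implemented the same way on CPU (ApplyActivation), GLES and Metal above. For reference, a standalone sketch of the shifted two-channel softmax; the shift is there for numerical stability, and selecting the second channel reduces to a sigmoid of the channel difference, which is why SIGMOID only needs a one-channel tensor.

#include <algorithm>
#include <cmath>

// Sketch of the FN_SOFTMAX branch: shift by the channel max before
// exponentiating so that large logits cannot overflow.
float TwoChannelSoftmax(float pixel0, float pixel1, int output_layer_index) {
  const float shift = std::max(pixel0, pixel1);
  const float denom = std::exp(pixel0 - shift) + std::exp(pixel1 - shift);
  const float selected = (output_layer_index == 0) ? pixel0 : pixel1;
  // For output_layer_index == 1 this equals 1 / (1 + exp(pixel0 - pixel1)),
  // i.e. sigmoid(pixel1 - pixel0).
  return std::exp(selected - shift) / denom;
}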
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/gpu/gpu_origin.proto"; + +message TensorsToSegmentationCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorsToSegmentationCalculatorOptions ext = 374311106; + } + + // For CONVENTIONAL mode in OpenGL, textures start at bottom and needs + // to be flipped vertically as tensors are expected to start at top. + // (DEFAULT or unset is interpreted as CONVENTIONAL.) + optional GpuOrigin.Mode gpu_origin = 1; + + // Supported activation functions for filtering. + enum Activation { + NONE = 0; // Assumes 1-channel input tensor. + SIGMOID = 1; // Assumes 1-channel input tensor. + SOFTMAX = 2; // Assumes 2-channel input tensor. + } + // Activation function to apply to input tensor. + // Softmax requires a 2-channel tensor, see output_layer_index below. + optional Activation activation = 2 [default = NONE]; + + // Channel to use for processing tensor. + // Only applies when using activation=SOFTMAX. + // Works on two channel input tensor only. + optional int32 output_layer_index = 3 [default = 1]; +} diff --git a/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt b/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt index b0e00346c..64d970e11 100644 --- a/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt +++ b/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt @@ -4,7 +4,7 @@ output_stream: "detections" # Subgraph that detects faces. node { - calculator: "FaceDetectionFrontCpu" + calculator: "FaceDetectionShortRangeCpu" input_stream: "IMAGE:image" output_stream: "DETECTIONS:detections" } diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 0dbbd57da..ac058610a 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -88,6 +88,13 @@ proto_library( deps = ["//mediapipe/framework:calculator_proto"], ) +proto_library( + name = "tensor_to_vector_string_calculator_options_proto", + srcs = ["tensor_to_vector_string_calculator_options.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + proto_library( name = "unpack_media_sequence_calculator_proto", srcs = ["unpack_media_sequence_calculator.proto"], @@ -257,6 +264,14 @@ mediapipe_cc_proto_library( deps = [":tensor_to_vector_float_calculator_options_proto"], ) +mediapipe_cc_proto_library( + name = "tensor_to_vector_string_calculator_options_cc_proto", + srcs = ["tensor_to_vector_string_calculator_options.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":tensor_to_vector_string_calculator_options_proto"], +) + mediapipe_cc_proto_library( name = "unpack_media_sequence_calculator_cc_proto", srcs = ["unpack_media_sequence_calculator.proto"], @@ -572,9 +587,21 @@ cc_library( "//mediapipe/framework/port:ret_check", ] + select({ "//conditions:default": [ - "//mediapipe/framework/port:file_helpers", ], - }), + "//mediapipe:android": [], + }) + select( + { + "//conditions:default": [ + ], + }, + ) + select( + { + "//conditions:default": [ + ], + "//mediapipe:android": [ + ], + }, + ), alwayslink = 1, ) @@ -694,6 +721,26 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tensor_to_vector_string_calculator", + srcs = ["tensor_to_vector_string_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + 
"//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ":tensor_to_vector_string_calculator_options_cc_proto", + ] + select({ + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:framework", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", + ], + }), + alwayslink = 1, +) + cc_library( name = "unpack_media_sequence_calculator", srcs = ["unpack_media_sequence_calculator.cc"], @@ -864,6 +911,7 @@ cc_test( "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", @@ -1058,6 +1106,20 @@ cc_test( ], ) +cc_test( + name = "tensor_to_vector_string_calculator_test", + srcs = ["tensor_to_vector_string_calculator_test.cc"], + deps = [ + ":tensor_to_vector_string_calculator", + ":tensor_to_vector_string_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + cc_test( name = "unpack_media_sequence_calculator_test", srcs = ["unpack_media_sequence_calculator_test.cc"], diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index ddb042e6a..3991f645d 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -37,6 +37,7 @@ const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; const char kImageTag[] = "IMAGE"; const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_"; const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; +const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_"; const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; const char kBBoxTag[] = "BBOX"; const char kKeypointsTag[] = "KEYPOINTS"; @@ -153,6 +154,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { cc->Inputs().Tag(tag).Set>(); } + if (absl::StartsWith(tag, kBytesFeaturePrefixTag)) { + cc->Inputs().Tag(tag).Set>(); + } } CHECK(cc->Outputs().HasTag(kSequenceExampleTag) || @@ -231,6 +235,13 @@ class PackMediaSequenceCalculator : public CalculatorBase { mpms::ClearFeatureFloats(key, sequence_.get()); mpms::ClearFeatureTimestamp(key, sequence_.get()); } + if (absl::StartsWith(tag, kBytesFeaturePrefixTag)) { + std::string key = tag.substr(sizeof(kBytesFeaturePrefixTag) / + sizeof(*kBytesFeaturePrefixTag) - + 1); + mpms::ClearFeatureBytes(key, sequence_.get()); + mpms::ClearFeatureTimestamp(key, sequence_.get()); + } if (absl::StartsWith(tag, kKeypointsTag)) { std::string key = tag.substr(sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1); @@ -243,11 +254,6 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } - if (cc->Outputs().HasTag(kSequenceExampleTag)) { - cc->Outputs() - .Tag(kSequenceExampleTag) - .SetNextTimestampBound(Timestamp::Max()); - } return absl::OkStatus(); } @@ -305,7 +311,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (cc->Outputs().HasTag(kSequenceExampleTag)) { cc->Outputs() .Tag(kSequenceExampleTag) - 
.Add(sequence_.release(), Timestamp::PostStream()); + .Add(sequence_.release(), options.output_as_zero_timestamp() + ? Timestamp(0ll) + : Timestamp::PostStream()); } sequence_.reset(); @@ -408,6 +416,17 @@ class PackMediaSequenceCalculator : public CalculatorBase { cc->Inputs().Tag(tag).Get>(), sequence_.get()); } + if (absl::StartsWith(tag, kBytesFeaturePrefixTag) && + !cc->Inputs().Tag(tag).IsEmpty()) { + std::string key = tag.substr(sizeof(kBytesFeaturePrefixTag) / + sizeof(*kBytesFeaturePrefixTag) - + 1); + mpms::AddFeatureTimestamp(key, cc->InputTimestamp().Value(), + sequence_.get()); + mpms::AddFeatureBytes( + key, cc->Inputs().Tag(tag).Get>(), + sequence_.get()); + } if (absl::StartsWith(tag, kBBoxTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = ""; if (tag != kBBoxTag) { diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto index 695eb6b5e..6ba09fb16 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto @@ -65,4 +65,7 @@ message PackMediaSequenceCalculatorOptions { // If true, will return an error status if an output sequence would be too // many bytes to serialize. optional bool skip_large_sequences = 7 [default = true]; + + // If true/false, outputs the SequenceExample at timestamp 0/PostStream. + optional bool output_as_zero_timestamp = 8 [default = false]; } diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index c163cebcd..b39a0bac0 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -29,6 +29,7 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" #include "mediapipe/util/sequence/media_sequence.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" @@ -39,12 +40,33 @@ namespace { namespace tf = ::tensorflow; namespace mpms = mediapipe::mediasequence; +constexpr char kBboxTag[] = "BBOX"; +constexpr char kEncodedMediaStartTimestampTag[] = + "ENCODED_MEDIA_START_TIMESTAMP"; +constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA"; +constexpr char kClassSegmentationTag[] = "CLASS_SEGMENTATION"; +constexpr char kKeypointsTestTag[] = "KEYPOINTS_TEST"; +constexpr char kBboxPredictedTag[] = "BBOX_PREDICTED"; +constexpr char kAudioOtherTag[] = "AUDIO_OTHER"; +constexpr char kAudioTestTag[] = "AUDIO_TEST"; +constexpr char kBytesFeatureOtherTag[] = "BYTES_FEATURE_OTHER"; +constexpr char kBytesFeatureTestTag[] = "BYTES_FEATURE_TEST"; +constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; +constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER"; +constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST"; +constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER"; +constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST"; +constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; +constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; +constexpr char kImageTag[] = "IMAGE"; + class PackMediaSequenceCalculatorTest : public ::testing::Test { protected: void SetUpCalculator(const std::vector& input_streams, const 
tf::Features& features, - bool output_only_if_all_present, - bool replace_instead_of_append) { + const bool output_only_if_all_present, + const bool replace_instead_of_append, + const bool output_as_zero_timestamp = false) { CalculatorGraphConfig::Node config; config.set_calculator("PackMediaSequenceCalculator"); config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence"); @@ -57,6 +79,7 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test { *options->mutable_context_feature_map() = features; options->set_output_only_if_all_present(output_only_if_all_present); options->set_replace_data_instead_of_append(replace_instead_of_append); + options->set_output_as_zero_timestamp(output_as_zero_timestamp); runner_ = ::absl::make_unique(config); } @@ -80,17 +103,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { for (int i = 0; i < num_images; ++i) { auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -124,17 +147,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { auto image_ptr = ::absl::make_unique(encoded_image); runner_->MutableInputs() - ->Tag("IMAGE_PREFIX") + ->Tag(kImagePrefixTag) .packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -158,21 +181,21 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) { for (int i = 0; i < num_timesteps; ++i) { auto vf_ptr = ::absl::make_unique>(2, 2 << i); runner_->MutableInputs() - ->Tag("FLOAT_FEATURE_TEST") + ->Tag(kFloatFeatureTestTag) .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); vf_ptr = ::absl::make_unique>(2, 2 << i); runner_->MutableInputs() - ->Tag("FLOAT_FEATURE_OTHER") + ->Tag(kFloatFeatureOtherTag) .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -194,20 +217,65 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) { } } -TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { - SetUpCalculator( - {"FLOAT_CONTEXT_FEATURE_TEST:test", "FLOAT_CONTEXT_FEATURE_OTHER:test2"}, - {}, false, true); - auto input_sequence = absl::make_unique(); 
+TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBytesLists) { + SetUpCalculator({"BYTES_FEATURE_TEST:test", "BYTES_FEATURE_OTHER:test2"}, {}, + false, true); + auto input_sequence = ::absl::make_unique(); - auto vf_ptr = absl::make_unique>(2, 3); - runner_->MutableInputs() - ->Tag("FLOAT_CONTEXT_FEATURE_TEST") - .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); - vf_ptr = absl::make_unique>(2, 4); - runner_->MutableInputs() - ->Tag("FLOAT_CONTEXT_FEATURE_OTHER") - .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + auto vs_ptr = ::absl::make_unique>( + 2, absl::StrCat("foo", 2 << i)); + runner_->MutableInputs() + ->Tag(kBytesFeatureTestTag) + .packets.push_back(Adopt(vs_ptr.release()).At(Timestamp(i))); + vs_ptr = ::absl::make_unique>( + 2, absl::StrCat("bar", 2 << i)); + runner_->MutableInputs() + ->Tag(kBytesFeatureOtherTag) + .packets.push_back(Adopt(vs_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureBytesSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("OTHER", output_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureBytesSize("OTHER", output_sequence)); + for (int i = 0; i < num_timesteps; ++i) { + ASSERT_EQ(i, mpms::GetFeatureTimestampAt("TEST", output_sequence, i)); + ASSERT_THAT(mpms::GetFeatureBytesAt("TEST", output_sequence, i), + ::testing::ElementsAreArray( + std::vector(2, absl::StrCat("foo", 2 << i)))); + ASSERT_EQ(i, mpms::GetFeatureTimestampAt("OTHER", output_sequence, i)); + ASSERT_THAT(mpms::GetFeatureBytesAt("OTHER", output_sequence, i), + ::testing::ElementsAreArray( + std::vector(2, absl::StrCat("bar", 2 << i)))); + } +} + +TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) { + SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true); + auto input_sequence = ::absl::make_unique(); + + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + auto vf_ptr = ::absl::make_unique>(2, 2 << i); + runner_->MutableInputs() + ->Tag("FLOAT_FEATURE_TEST") + .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); + } runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = Adopt(input_sequence.release()); @@ -217,6 +285,32 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { const std::vector& output_packets = runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; ASSERT_EQ(1, output_packets.size()); + EXPECT_EQ(output_packets[0].Timestamp().Value(), 0ll); +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { + SetUpCalculator( + {"FLOAT_CONTEXT_FEATURE_TEST:test", "FLOAT_CONTEXT_FEATURE_OTHER:test2"}, + {}, false, true); + auto input_sequence = absl::make_unique(); + + auto vf_ptr = absl::make_unique>(2, 3); + runner_->MutableInputs() + ->Tag(kFloatContextFeatureTestTag) + .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); + vf_ptr = absl::make_unique>(2, 4); + runner_->MutableInputs() + ->Tag(kFloatContextFeatureOtherTag) + 
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); + + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -233,7 +327,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { SetUpCalculator({"IMAGE:images"}, context, false, true); auto input_sequence = ::absl::make_unique(); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; @@ -242,13 +336,13 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { encoded_image.set_encoded_image(bytes.data(), bytes.size()); auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp(0))); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -281,17 +375,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) { auto flow_ptr = ::absl::make_unique(encoded_flow); runner_->MutableInputs() - ->Tag("FORWARD_FLOW_ENCODED") + ->Tag(kForwardFlowEncodedTag) .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -345,17 +439,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) { detections->push_back(detection); runner_->MutableInputs() - ->Tag("BBOX_PREDICTED") + ->Tag(kBboxPredictedTag) .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -424,11 +518,11 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) { detections->push_back(detection); runner_->MutableInputs() - ->Tag("BBOX_PREDICTED") + ->Tag(kBboxPredictedTag) .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); auto status = runner_->Run(); @@ -472,7 +566,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { detections->push_back(detection); runner_->MutableInputs() - ->Tag("BBOX_PREDICTED") + ->Tag(kBboxPredictedTag) 
.packets.push_back(Adopt(detections.release()).At(Timestamp(i))); } cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255)); @@ -487,16 +581,16 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { for (int i = 0; i < num_images; ++i) { auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -538,18 +632,18 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) { absl::flat_hash_map>> points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}}; runner_->MutableInputs() - ->Tag("KEYPOINTS_TEST") + ->Tag(kKeypointsTestTag) .packets.push_back(PointToForeign(&points).At(Timestamp(0))); runner_->MutableInputs() - ->Tag("KEYPOINTS_TEST") + ->Tag(kKeypointsTestTag) .packets.push_back(PointToForeign(&points).At(Timestamp(1))); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -589,17 +683,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) { detections->push_back(detection); runner_->MutableInputs() - ->Tag("CLASS_SEGMENTATION") + ->Tag(kClassSegmentationTag) .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -638,17 +732,17 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) { auto flow_ptr = ::absl::make_unique(encoded_flow); runner_->MutableInputs() - ->Tag("FORWARD_FLOW_ENCODED") + ->Tag(kForwardFlowEncodedTag) .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -684,11 +778,11 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) { auto flow_ptr = ::absl::make_unique(encoded_flow); runner_->MutableInputs() - ->Tag("FORWARD_FLOW_ENCODED") + ->Tag(kForwardFlowEncodedTag) .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); } - 
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); absl::Status status = runner_->Run(); @@ -705,13 +799,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingImages) { mpms::AddImageTimestamp(1, input_sequence.get()); mpms::AddImageTimestamp(2, input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -731,13 +825,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFlowImages) { mpms::AddForwardFlowTimestamp(1, input_sequence.get()); mpms::AddForwardFlowTimestamp(2, input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -768,13 +862,52 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFloatVectors) { mpms::GetFeatureTimestampSize("OTHER", *input_sequence)); ASSERT_EQ(num_timesteps, mpms::GetFeatureFloatsSize("OTHER", *input_sequence)); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(0, mpms::GetFeatureTimestampSize("TEST", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureFloatsSize("TEST", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureTimestampSize("OTHER", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureFloatsSize("OTHER", output_sequence)); +} + +TEST_F(PackMediaSequenceCalculatorTest, TestReplacingBytesVectors) { + SetUpCalculator({"BYTES_FEATURE_TEST:test", "BYTES_FEATURE_OTHER:test2"}, {}, + false, true); + auto input_sequence = ::absl::make_unique(); + + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + auto vs_ptr = ::absl::make_unique>( + 2, absl::StrCat("foo", 2 << i)); + mpms::AddFeatureBytes("TEST", *vs_ptr, input_sequence.get()); + mpms::AddFeatureTimestamp("TEST", i, input_sequence.get()); + vs_ptr = ::absl::make_unique>( + 2, absl::StrCat("bar", 2 << i)); + mpms::AddFeatureBytes("OTHER", *vs_ptr, input_sequence.get()); + mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get()); + } + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("TEST", *input_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureBytesSize("TEST", *input_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("OTHER", *input_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureBytesSize("OTHER", *input_sequence)); + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + 
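For reference, a minimal sketch of the structure the new BYTES_FEATURE_<NAME> handling writes, using the mpms:: helpers already exercised in this file; the "TEST" key mirrors the BYTES_FEATURE_TEST tag suffix and the values are arbitrary.

#include <string>
#include <vector>

#include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h"

namespace mpms = mediapipe::mediasequence;

void SketchBytesFeature() {
  tensorflow::SequenceExample sequence;
  const std::vector<std::string> step = {"foo2", "foo2"};
  // One pair of calls per timestep: a timestamp plus the bytes values
  // observed at that step, keyed by the feature name.
  mpms::AddFeatureTimestamp("TEST", /*timestamp=*/0, &sequence);
  mpms::AddFeatureBytes("TEST", step, &sequence);
  // Expected after the calls above:
  //   mpms::GetFeatureTimestampSize("TEST", sequence) == 1
  //   mpms::GetFeatureBytesAt("TEST", sequence, 0)    == {"foo2", "foo2"}
}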
const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -800,7 +933,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { for (int i = 0; i < num_images; ++i) { auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10))); } @@ -812,11 +945,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get()); mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -853,7 +986,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { for (int i = 0; i < num_images; ++i) { auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp(i))); } @@ -867,7 +1000,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5) .ConvertToProto(detection.mutable_location_data()); detections->push_back(detection); - runner_->MutableInputs()->Tag("BBOX").packets.push_back( + runner_->MutableInputs()->Tag(kBboxTag).packets.push_back( Adopt(detections.release()).At(Timestamp(i))); } @@ -883,7 +1016,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { mpms::AddBBoxTrackIndex({-1}, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); // If the all the previous values aren't cleared, this assert will fail. 
MP_ASSERT_OK(runner_->Run()); @@ -899,11 +1032,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestTooLargeInputFailsSoftly) { for (int i = 0; i < num_timesteps; ++i) { auto vf_ptr = ::absl::make_unique>(1000000, i); runner_->MutableInputs() - ->Tag("FLOAT_FEATURE_TEST") + ->Tag(kFloatFeatureTestTag) .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); ASSERT_FALSE(runner_->Run().ok()); } diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc index fce24b9b9..67ba5e90a 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc @@ -26,6 +26,8 @@ namespace mediapipe { namespace tf = ::tensorflow; namespace { +constexpr char kReferenceTag[] = "REFERENCE"; + constexpr char kMatrix[] = "MATRIX"; constexpr char kTensor[] = "TENSOR"; @@ -68,7 +70,8 @@ class TensorToMatrixCalculatorTest : public ::testing::Test { if (include_rate) { header->set_packet_rate(1.0); } - runner_->MutableInputs()->Tag("REFERENCE").header = Adopt(header.release()); + runner_->MutableInputs()->Tag(kReferenceTag).header = + Adopt(header.release()); } std::unique_ptr runner_; diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator.cc new file mode 100644 index 000000000..2c9e14d4b --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator.cc @@ -0,0 +1,118 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Calculator converts from one-dimensional Tensor of DT_STRING to +// vector OR from (batched) two-dimensional Tensor of DT_STRING to +// vector. + +#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; + +class TensorToVectorStringCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + TensorToVectorStringCalculatorOptions options_; +}; +REGISTER_CALCULATOR(TensorToVectorStringCalculator); + +absl::Status TensorToVectorStringCalculator::GetContract( + CalculatorContract* cc) { + // Start with only one input packet. 
+ RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + cc->Inputs().Index(0).Set( + // Input Tensor + ); + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + const auto& options = cc->Options(); + if (options.tensor_is_2d()) { + RET_CHECK(!options.flatten_nd()); + cc->Outputs().Index(0).Set>>( + /* "Output vector>." */); + } else { + cc->Outputs().Index(0).Set>( + // Output vector. + ); + } + return absl::OkStatus(); +} + +absl::Status TensorToVectorStringCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + + // Inform mediapipe that this calculator produces an output at time t for + // each input received at time t (i.e. this calculator does not buffer + // inputs). This enables mediapipe to propagate time of arrival estimates in + // mediapipe graphs through this calculator. + cc->SetOffset(/*offset=*/0); + + return absl::OkStatus(); +} + +absl::Status TensorToVectorStringCalculator::Process(CalculatorContext* cc) { + const tf::Tensor& input_tensor = + cc->Inputs().Index(0).Value().Get(); + RET_CHECK(tf::DT_STRING == input_tensor.dtype()) + << "expected DT_STRING input but got " + << tensorflow::DataTypeString(input_tensor.dtype()); + + if (options_.tensor_is_2d()) { + RET_CHECK(2 == input_tensor.dims()) + << "Expected 2-dimensional Tensor, but the tensor shape is: " + << input_tensor.shape().DebugString(); + auto output = absl::make_unique>>( + input_tensor.dim_size(0), + std::vector(input_tensor.dim_size(1))); + for (int i = 0; i < input_tensor.dim_size(0); ++i) { + auto& instance_output = output->at(i); + const auto& slice = + input_tensor.Slice(i, i + 1).unaligned_flat(); + for (int j = 0; j < input_tensor.dim_size(1); ++j) { + instance_output.at(j) = slice(j); + } + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else { + if (!options_.flatten_nd()) { + RET_CHECK(1 == input_tensor.dims()) + << "`flatten_nd` is not set. Expected 1-dimensional Tensor, but the " + << "tensor shape is: " << input_tensor.shape().DebugString(); + } + auto output = + absl::make_unique>(input_tensor.NumElements()); + const auto& tensor_values = input_tensor.flat(); + for (int i = 0; i < input_tensor.NumElements(); ++i) { + output->at(i) = tensor_values(i); + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.proto b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.proto new file mode 100644 index 000000000..74df1be69 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.proto @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
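For reference, a minimal sketch (not part of the patch) of wiring the new TensorToVectorStringCalculator into a graph, in the same style as the unit test further below; stream names are placeholders.

#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

CalculatorGraphConfig::Node MakeTensorToStringsNode() {
  CalculatorGraphConfig::Node node;
  node.set_calculator("TensorToVectorStringCalculator");
  node.add_input_stream("string_tensor");   // placeholder: a DT_STRING tf::Tensor
  node.add_output_stream("string_vector");  // placeholder: std::vector<std::string>
  auto* options = node.mutable_options()->MutableExtension(
      TensorToVectorStringCalculatorOptions::ext);
  // tensor_is_2d unpacks a matrix into vector<vector<string>>; flatten_nd
  // instead flattens any shape into a single vector. The calculator
  // rejects setting both.
  options->set_tensor_is_2d(false);
  options->set_flatten_nd(false);
  return node;
}

}  // namespace mediapipe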
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorToVectorStringCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorToVectorStringCalculatorOptions ext = 386534187; + } + + // If true, unpack a 2d tensor (matrix) into a vector>. If + // false, convert a 1d tensor (vector) into a vector. + optional bool tensor_is_2d = 1 [default = false]; + + // If true, an N-D tensor will be flattened to a vector. This is + // exclusive with tensor_is_2d. + optional bool flatten_nd = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_test.cc new file mode 100644 index 000000000..94dd9374d --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_test.cc @@ -0,0 +1,130 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +class TensorToVectorStringCalculatorTest : public ::testing::Test { + protected: + void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorToVectorStringCalculator"); + config.add_input_stream("input_tensor"); + config.add_output_stream("output_tensor"); + auto options = config.mutable_options()->MutableExtension( + TensorToVectorStringCalculatorOptions::ext); + options->set_tensor_is_2d(tensor_is_2d); + options->set_flatten_nd(flatten_nd); + runner_ = absl::make_unique(config); + } + + std::unique_ptr runner_; +}; + +TEST_F(TensorToVectorStringCalculatorTest, ConvertsToVectorFloat) { + SetUpRunner(false, false); + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_STRING, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + tensor_vec(i) = absl::StrCat("foo", i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector& output_vector = + output_packets[0].Get>(); + + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const std::string expected = absl::StrCat("foo", i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +TEST_F(TensorToVectorStringCalculatorTest, ConvertsBatchedToVectorVectorFloat) { + 
SetUpRunner(true, false); + const tf::TensorShape tensor_shape(std::vector{1, 5}); + auto tensor = absl::make_unique(tf::DT_STRING, tensor_shape); + auto slice = tensor->Slice(0, 1).flat(); + for (int i = 0; i < 5; ++i) { + slice(i) = absl::StrCat("foo", i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector>& output_vectors = + output_packets[0].Get>>(); + ASSERT_EQ(1, output_vectors.size()); + const std::vector& output_vector = output_vectors[0]; + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const std::string expected = absl::StrCat("foo", i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +TEST_F(TensorToVectorStringCalculatorTest, FlattenShouldTakeAllDimensions) { + SetUpRunner(false, true); + const tf::TensorShape tensor_shape(std::vector{2, 2, 2}); + auto tensor = absl::make_unique(tf::DT_STRING, tensor_shape); + auto slice = tensor->flat(); + for (int i = 0; i < 2 * 2 * 2; ++i) { + slice(i) = absl::StrCat("foo", i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector& output_vector = + output_packets[0].Get>(); + EXPECT_EQ(2 * 2 * 2, output_vector.size()); + for (int i = 0; i < 2 * 2 * 2; ++i) { + const std::string expected = absl::StrCat("foo", i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc index 625612c17..a8ecb847d 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -49,6 +49,11 @@ namespace tf = ::tensorflow; namespace mediapipe { namespace { + +constexpr char kRecurrentInitTensorsTag[] = "RECURRENT_INIT_TENSORS"; +constexpr char kSessionTag[] = "SESSION"; +constexpr char kSessionBundleTag[] = "SESSION_BUNDLE"; + // This is a simple implementation of a semaphore using standard C++ libraries. // It is supposed to be used only by TensorflowInferenceCalculator to throttle // the concurrent calls of Tensorflow Session::Run. This is useful when multiple @@ -252,10 +257,10 @@ class TensorFlowInferenceCalculator : public CalculatorBase { } // A mediapipe::TensorFlowSession with a model loaded and ready for use. // For this calculator it must include a tag_to_tensor_map. 
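// For reference: this tag_to_tensor_map is what maps stream tags such as
// "MULTIPLIED" in the tests further below to tensor names in the loaded
// graph, so the same calculator code can drive arbitrary frozen graphs.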
- cc->InputSidePackets().Tag("SESSION").Set(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { + cc->InputSidePackets().Tag(kSessionTag).Set(); + if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag)) { cc->InputSidePackets() - .Tag("RECURRENT_INIT_TENSORS") + .Tag(kRecurrentInitTensorsTag) .Set>>(); } return absl::OkStatus(); @@ -265,11 +270,11 @@ class TensorFlowInferenceCalculator : public CalculatorBase { ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { std::unique_ptr inference_state = absl::make_unique(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && - !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { + if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag) && + !cc->InputSidePackets().Tag(kRecurrentInitTensorsTag).IsEmpty()) { std::map* init_tensor_map; init_tensor_map = GetFromUniquePtr>( - cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); + cc->InputSidePackets().Tag(kRecurrentInitTensorsTag)); for (const auto& p : *init_tensor_map) { inference_state->input_tensor_batches_[p.first].emplace_back(p.second); } @@ -280,13 +285,13 @@ class TensorFlowInferenceCalculator : public CalculatorBase { absl::Status Open(CalculatorContext* cc) override { options_ = cc->Options(); - RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); + RET_CHECK(cc->InputSidePackets().HasTag(kSessionTag)); session_ = cc->InputSidePackets() - .Tag("SESSION") + .Tag(kSessionTag) .Get() .session.get(); tag_to_tensor_map_ = cc->InputSidePackets() - .Tag("SESSION") + .Tag(kSessionTag) .Get() .tag_to_tensor_map; @@ -490,7 +495,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { << keyed_tensors.first; } } else { - // Pad by replicating the first tens or, then ignore the values. + // Pad by replicating the first tensor, then ignore the values. 
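// For reference: the resize + std::fill below pad the batch up to
// options_.batch_size() with copies of the first real tensor so that
// Session::Run always receives a full batch; as the comment above says,
// the values computed for the padded slots are then ignored.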
keyed_tensors.second.resize(options_.batch_size()); std::fill(keyed_tensors.second.begin() + inference_state->batch_timestamps_.size(), diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc index 6a931679d..cc1d15043 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc @@ -41,6 +41,11 @@ namespace mediapipe { namespace tf = ::tensorflow; namespace { + +constexpr char kMultipliedTag[] = "MULTIPLIED"; +constexpr char kBTag[] = "B"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetGraphDefPath() { #ifdef __APPLE__ char path[1024]; @@ -86,8 +91,8 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test { MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options, input_side_packets, &output_side_packets)); - runner_->MutableSidePackets()->Tag("SESSION") = - output_side_packets.Tag("SESSION"); + runner_->MutableSidePackets()->Tag(kSessionTag) = + output_side_packets.Tag(kSessionTag); } Packet CreateTensorPacket(const std::vector& input, int64 time) { @@ -140,7 +145,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_b = - runner_->Outputs().Tag("B").packets; + runner_->Outputs().Tag(kBTag).packets; ASSERT_EQ(output_packets_b.size(), 1); const tf::Tensor& tensor_b = output_packets_b[0].Get(); tf::TensorShape expected_shape({1, 3}); @@ -148,7 +153,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) { tf::test::ExpectTensorEqual(expected_tensor, tensor_b); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(1, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); expected_tensor = tf::test::AsTensor({0, 0, 0}, expected_shape); @@ -181,7 +186,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(1, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); tf::TensorShape expected_shape({3}); @@ -220,7 +225,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(1, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); tf::TensorShape expected_shape({3}); @@ -274,7 +279,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -311,7 +316,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed_MaxInFlight) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, 
output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -351,7 +356,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(3, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -392,7 +397,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -430,7 +435,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -481,7 +486,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetBatchComputed_MaxInFlight) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(5, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -528,7 +533,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); LOG(INFO) << "timestamp: " << 0; @@ -569,7 +574,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); LOG(INFO) << "timestamp: " << 0; @@ -662,7 +667,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature_Skip) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(0, output_packets_mult.size()); } @@ -691,7 +696,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(1, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({9, 12, 15}); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc index 2c1d169bc..794a8a732 100644 --- 
a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc @@ -47,6 +47,11 @@ namespace mediapipe { namespace tf = ::tensorflow; namespace { + +constexpr char kSessionTag[] = "SESSION"; +constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH"; +constexpr char kStringModelTag[] = "STRING_MODEL"; + // Updates the graph nodes to use the device as specified by device_id. void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { for (auto& node : *graph_def->mutable_node()) { @@ -64,30 +69,32 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { cc->Options(); bool has_exactly_one_model = !options.graph_proto_path().empty() - ? !(cc->InputSidePackets().HasTag("STRING_MODEL") | - cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) - : (cc->InputSidePackets().HasTag("STRING_MODEL") ^ - cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")); + ? !(cc->InputSidePackets().HasTag(kStringModelTag) | + cc->InputSidePackets().HasTag(kStringModelFilePathTag)) + : (cc->InputSidePackets().HasTag(kStringModelTag) ^ + cc->InputSidePackets().HasTag(kStringModelFilePathTag)); RET_CHECK(has_exactly_one_model) << "Must have exactly one of graph_proto_path in options or " "input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH"; - if (cc->InputSidePackets().HasTag("STRING_MODEL")) { + if (cc->InputSidePackets().HasTag(kStringModelTag)) { cc->InputSidePackets() - .Tag("STRING_MODEL") + .Tag(kStringModelTag) .Set( // String model from embedded path ); - } else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) { + } else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) { cc->InputSidePackets() - .Tag("STRING_MODEL_FILE_PATH") + .Tag(kStringModelFilePathTag) .Set( // Filename of std::string model. ); } - cc->OutputSidePackets().Tag("SESSION").Set( - // A TensorFlow model loaded and ready for use along with - // a map from tags to tensor names. - ); + cc->OutputSidePackets() + .Tag(kSessionTag) + .Set( + // A TensorFlow model loaded and ready for use along with + // a map from tags to tensor names. 
+ ); RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); return absl::OkStatus(); } @@ -111,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { session->session.reset(tf::NewSession(session_options)); std::string graph_def_serialized; - if (cc->InputSidePackets().HasTag("STRING_MODEL")) { + if (cc->InputSidePackets().HasTag(kStringModelTag)) { graph_def_serialized = - cc->InputSidePackets().Tag("STRING_MODEL").Get(); - } else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) { + cc->InputSidePackets().Tag(kStringModelTag).Get(); + } else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) { const std::string& frozen_graph = cc->InputSidePackets() - .Tag("STRING_MODEL_FILE_PATH") + .Tag(kStringModelFilePathTag) .Get(); RET_CHECK_OK( mediapipe::file::GetContents(frozen_graph, &graph_def_serialized)); @@ -147,7 +154,7 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); } - cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); + cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release())); const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); LOG(INFO) << "Loaded frozen model in: " << end_time - start_time << " microseconds."; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc index bdf90dcbb..f0f8928db 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc @@ -37,6 +37,10 @@ namespace { namespace tf = ::tensorflow; +constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH"; +constexpr char kStringModelTag[] = "STRING_MODEL"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetGraphDefPath() { return mediapipe::file::JoinPath("./", "mediapipe/calculators/tensorflow/" @@ -112,7 +116,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); VerifySignatureMap(session); } @@ -190,12 +194,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - runner.MutableSidePackets()->Tag("STRING_MODEL") = + runner.MutableSidePackets()->Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); VerifySignatureMap(session); } @@ -213,12 +217,12 @@ TEST_F( } })", calculator_options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = + runner.MutableSidePackets()->Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); VerifySignatureMap(session); } @@ -234,7 +238,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, } })", calculator_options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = + 
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); auto run_status = runner.Run(); EXPECT_THAT( @@ -255,12 +259,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, } })", calculator_options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = + runner.MutableSidePackets()->Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - runner.MutableSidePackets()->Tag("STRING_MODEL") = + runner.MutableSidePackets()->Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); auto run_status = runner.Run(); EXPECT_THAT( @@ -282,12 +286,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, } })", calculator_options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = + runner.MutableSidePackets()->Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - runner.MutableSidePackets()->Tag("STRING_MODEL") = + runner.MutableSidePackets()->Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); auto run_status = runner.Run(); EXPECT_THAT( @@ -310,7 +314,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); VerifySignatureMap(session); } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc index 9f5b9e06b..09985bcf3 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc @@ -43,6 +43,11 @@ namespace mediapipe { namespace tf = ::tensorflow; namespace { + +constexpr char kSessionTag[] = "SESSION"; +constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH"; +constexpr char kStringModelTag[] = "STRING_MODEL"; + // Updates the graph nodes to use the device as specified by device_id. void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { for (auto& node : *graph_def->mutable_node()) { @@ -64,28 +69,29 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { TensorFlowSessionFromFrozenGraphGeneratorOptions::ext); bool has_exactly_one_model = !options.graph_proto_path().empty() - ? !(input_side_packets->HasTag("STRING_MODEL") | - input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) - : (input_side_packets->HasTag("STRING_MODEL") ^ - input_side_packets->HasTag("STRING_MODEL_FILE_PATH")); + ? 
!(input_side_packets->HasTag(kStringModelTag) | + input_side_packets->HasTag(kStringModelFilePathTag)) + : (input_side_packets->HasTag(kStringModelTag) ^ + input_side_packets->HasTag(kStringModelFilePathTag)); RET_CHECK(has_exactly_one_model) << "Must have exactly one of graph_proto_path in options or " "input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH"; - if (input_side_packets->HasTag("STRING_MODEL")) { - input_side_packets->Tag("STRING_MODEL") + if (input_side_packets->HasTag(kStringModelTag)) { + input_side_packets->Tag(kStringModelTag) .Set( // String model from embedded path ); - } else if (input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) { - input_side_packets->Tag("STRING_MODEL_FILE_PATH") + } else if (input_side_packets->HasTag(kStringModelFilePathTag)) { + input_side_packets->Tag(kStringModelFilePathTag) .Set( // Filename of std::string model. ); } - output_side_packets->Tag("SESSION").Set( - // A TensorFlow model loaded and ready for use along with - // a map from tags to tensor names. - ); + output_side_packets->Tag(kSessionTag) + .Set( + // A TensorFlow model loaded and ready for use along with + // a map from tags to tensor names. + ); RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); return absl::OkStatus(); } @@ -112,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { session->session.reset(tf::NewSession(session_options)); std::string graph_def_serialized; - if (input_side_packets.HasTag("STRING_MODEL")) { + if (input_side_packets.HasTag(kStringModelTag)) { graph_def_serialized = - input_side_packets.Tag("STRING_MODEL").Get(); - } else if (input_side_packets.HasTag("STRING_MODEL_FILE_PATH")) { + input_side_packets.Tag(kStringModelTag).Get(); + } else if (input_side_packets.HasTag(kStringModelFilePathTag)) { const std::string& frozen_graph = - input_side_packets.Tag("STRING_MODEL_FILE_PATH").Get(); + input_side_packets.Tag(kStringModelFilePathTag).Get(); RET_CHECK_OK( mediapipe::file::GetContents(frozen_graph, &graph_def_serialized)); } else { @@ -147,7 +153,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); } - output_side_packets->Tag("SESSION") = Adopt(session.release()); + output_side_packets->Tag(kSessionTag) = Adopt(session.release()); const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); LOG(INFO) << "Loaded frozen model in: " << end_time - start_time << " microseconds."; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc index 34d7e8828..83f947a0c 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc @@ -37,6 +37,10 @@ namespace { namespace tf = ::tensorflow; +constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH"; +constexpr char kStringModelTag[] = "STRING_MODEL"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetGraphDefPath() { return mediapipe::file::JoinPath("./", "mediapipe/calculators/tensorflow/" @@ -72,7 +76,7 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test { void VerifySignatureMap(PacketSet* output_side_packets) { const TensorFlowSession& session = - output_side_packets->Tag("SESSION").Get(); + output_side_packets->Tag(kSessionTag).Get(); // Session must be set. 
ASSERT_NE(session.session, nullptr); @@ -179,7 +183,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); generator_options_->clear_graph_proto_path(); - input_side_packets.Tag("STRING_MODEL") = + input_side_packets.Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, @@ -196,7 +200,7 @@ TEST_F( PacketSet output_side_packets( tool::CreateTagMap({"SESSION:session"}).value()); generator_options_->clear_graph_proto_path(); - input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + input_side_packets.Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, @@ -211,7 +215,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value()); PacketSet output_side_packets( tool::CreateTagMap({"SESSION:session"}).value()); - input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + input_side_packets.Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, @@ -233,9 +237,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - input_side_packets.Tag("STRING_MODEL") = + input_side_packets.Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); - input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + input_side_packets.Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); absl::Status run_status = tool::RunGenerateAndValidateTypes( @@ -258,9 +262,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - input_side_packets.Tag("STRING_MODEL") = + input_side_packets.Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); - input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + input_side_packets.Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); generator_options_->clear_graph_proto_path(); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index c169c6b1e..de600de31 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -31,6 +31,9 @@ namespace mediapipe { namespace { + +constexpr char kSessionTag[] = "SESSION"; + static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; // Given the path to a directory containing multiple tensorflow saved models @@ -108,7 +111,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { cc->InputSidePackets().Tag(kStringSavedModelPath).Set(); } // A TensorFlow model loaded and ready for use along with tensor - cc->OutputSidePackets().Tag("SESSION").Set(); + cc->OutputSidePackets().Tag(kSessionTag).Set(); return absl::OkStatus(); } @@ -160,7 +163,7 @@ class 
TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { output_signature.first, options)] = output_signature.second.name(); } - cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); + cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release())); return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc index 7016f14bb..52cd9e0bb 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc @@ -35,6 +35,9 @@ namespace { namespace tf = ::tensorflow; +constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetSavedModelDir() { std::string out_path = file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/", @@ -79,7 +82,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, options_->DebugString())); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); @@ -119,11 +122,11 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, } })", options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_SAVED_MODEL_PATH") = + runner.MutableSidePackets()->Tag(kStringSavedModelPathTag) = MakePacket(GetSavedModelDir()); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); } @@ -201,7 +204,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, options_->DebugString())); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); } @@ -224,7 +227,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, options_->DebugString())); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); // Session must be set. 
ASSERT_NE(session.session, nullptr); std::vector devices; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index 6489b0267..9b2e16a88 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -33,6 +33,9 @@ namespace mediapipe { namespace { + +constexpr char kSessionTag[] = "SESSION"; + static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; // Given the path to a directory containing multiple tensorflow saved models @@ -100,7 +103,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { input_side_packets->Tag(kStringSavedModelPath).Set(); } // A TensorFlow model loaded and ready for use along with tensor - output_side_packets->Tag("SESSION").Set(); + output_side_packets->Tag(kSessionTag).Set(); return absl::OkStatus(); } @@ -153,7 +156,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { output_signature.first, options)] = output_signature.second.name(); } - output_side_packets->Tag("SESSION") = Adopt(session.release()); + output_side_packets->Tag(kSessionTag) = Adopt(session.release()); return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc index aca506f0b..46cbf41cb 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -34,6 +34,9 @@ namespace { namespace tf = ::tensorflow; +constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetSavedModelDir() { std::string out_path = file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/", @@ -75,7 +78,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); const TensorFlowSession& session = - output_side_packets.Tag("SESSION").Get(); + output_side_packets.Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); @@ -107,7 +110,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, generator_options_->clear_saved_model_path(); PacketSet input_side_packets( tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value()); - input_side_packets.Tag("STRING_SAVED_MODEL_PATH") = + input_side_packets.Tag(kStringSavedModelPathTag) = Adopt(new std::string(GetSavedModelDir())); PacketSet output_side_packets( tool::CreateTagMap({"SESSION:session"}).value()); @@ -116,7 +119,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); const TensorFlowSession& session = - output_side_packets.Tag("SESSION").Get(); + output_side_packets.Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); } @@ -192,7 +195,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); const TensorFlowSession& session = - output_side_packets.Tag("SESSION").Get(); + output_side_packets.Tag(kSessionTag).Get(); // Session must be set. 
ASSERT_NE(session.session, nullptr); } @@ -213,7 +216,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); const TensorFlowSession& session = - output_side_packets.Tag("SESSION").Get(); + output_side_packets.Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); std::vector devices; diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index e8e40bad3..d12f91741 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -33,6 +33,33 @@ namespace { namespace tf = ::tensorflow; namespace mpms = mediapipe::mediasequence; +constexpr char kImageFrameRateTag[] = "IMAGE_FRAME_RATE"; +constexpr char kEncodedMediaStartTimestampTag[] = + "ENCODED_MEDIA_START_TIMESTAMP"; +constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA"; +constexpr char kResamplerOptionsTag[] = "RESAMPLER_OPTIONS"; +constexpr char kSandboxedDecoderOptionsTag[] = "SANDBOXED_DECODER_OPTIONS"; +constexpr char kDecoderOptionsTag[] = "DECODER_OPTIONS"; +constexpr char kAudioDecoderOptionsTag[] = "AUDIO_DECODER_OPTIONS"; +constexpr char kDataPathTag[] = "DATA_PATH"; +constexpr char kDatasetRootTag[] = "DATASET_ROOT"; +constexpr char kMediaIdTag[] = "MEDIA_ID"; +constexpr char kFloatFeatureFdenseMaxTag[] = "FLOAT_FEATURE_FDENSE_MAX"; +constexpr char kFloatFeatureFdenseAvgTag[] = "FLOAT_FEATURE_FDENSE_AVG"; +constexpr char kAudioOtherTag[] = "AUDIO_OTHER"; +constexpr char kAudioTestTag[] = "AUDIO_TEST"; +constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER"; +constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST"; +constexpr char kBboxPrefixTag[] = "BBOX_PREFIX"; +constexpr char kKeypointsTag[] = "KEYPOINTS"; +constexpr char kBboxTag[] = "BBOX"; +constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; +constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER"; +constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST"; +constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; + class UnpackMediaSequenceCalculatorTest : public ::testing::Test { protected: void SetUpCalculator(const std::vector& output_streams, @@ -95,13 +122,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneImage) { mpms::AddImageEncoded(test_image_string, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE").packets; + runner_->Outputs().Tag(kImageTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 0; i < num_images; ++i) { @@ -124,13 +151,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoImages) { mpms::AddImageEncoded(test_image_string, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE").packets; + runner_->Outputs().Tag(kImageTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 
0; i < num_images; ++i) { @@ -154,13 +181,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPrefixedImages) { mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE_PREFIX").packets; + runner_->Outputs().Tag(kImagePrefixTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 0; i < num_images; ++i) { @@ -182,12 +209,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneForwardFlowImage) { mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets; + runner_->Outputs().Tag(kForwardFlowEncodedTag).packets; ASSERT_EQ(num_forward_flow_images, output_packets.size()); for (int i = 0; i < num_forward_flow_images; ++i) { @@ -211,12 +238,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoForwardFlowImages) { mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets; + runner_->Outputs().Tag(kForwardFlowEncodedTag).packets; ASSERT_EQ(num_forward_flow_images, output_packets.size()); for (int i = 0; i < num_forward_flow_images; ++i) { @@ -240,13 +267,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksBBoxes) { mpms::AddBBoxTimestamp(i, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("BBOX").packets; + runner_->Outputs().Tag(kBboxTag).packets; ASSERT_EQ(bboxes.size(), output_packets.size()); for (int i = 0; i < bboxes.size(); ++i) { @@ -274,13 +301,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPrefixedBBoxes) { mpms::AddBBoxTimestamp(prefix, i, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("BBOX_PREFIX").packets; + runner_->Outputs().Tag(kBboxPrefixTag).packets; ASSERT_EQ(bboxes.size(), output_packets.size()); for (int i = 0; i < bboxes.size(); ++i) { @@ -306,13 +333,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) { mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("FLOAT_FEATURE_TEST").packets; + runner_->Outputs().Tag(kFloatFeatureTestTag).packets; ASSERT_EQ(num_float_lists, output_packets.size()); for (int i = 0; i < num_float_lists; ++i) { @@ -322,7 +349,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, 
UnpacksTwoFloatLists) { } const std::vector& output_packets_other = - runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets; + runner_->Outputs().Tag(kFloatFeatureOtherTag).packets; ASSERT_EQ(num_float_lists, output_packets_other.size()); for (int i = 0; i < num_float_lists; ++i) { @@ -352,12 +379,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) { mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE").packets; + runner_->Outputs().Tag(kImageTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 0; i < num_images; ++i) { @@ -366,7 +393,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) { } const std::vector& output_packets_other = - runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets; + runner_->Outputs().Tag(kFloatFeatureOtherTag).packets; ASSERT_EQ(num_float_lists, output_packets_other.size()); for (int i = 0; i < num_float_lists; ++i) { @@ -389,12 +416,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) { mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& fdense_avg_packets = - runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_AVG").packets; + runner_->Outputs().Tag(kFloatFeatureFdenseAvgTag).packets; ASSERT_EQ(fdense_avg_packets.size(), 1); const auto& fdense_avg_vector = fdense_avg_packets[0].Get>(); @@ -403,7 +430,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) { ::testing::Eq(Timestamp::PostStream())); const std::vector& fdense_max_packets = - runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets; + runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets; ASSERT_EQ(fdense_max_packets.size(), 1); const auto& fdense_max_vector = fdense_max_packets[0].Get>(); @@ -430,13 +457,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) { mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE").packets; + runner_->Outputs().Tag(kImageTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 0; i < num_images; ++i) { @@ -463,13 +490,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) { mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& fdense_max_packets = - runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets; + runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets; ASSERT_EQ(fdense_max_packets.size(), 1); const auto& fdense_max_vector = fdense_max_packets[0].Get>(); @@ -481,17 +508,17 @@ TEST_F(UnpackMediaSequenceCalculatorTest, 
UnpacksPostStreamFloatListWithImage) { TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) { SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"}); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); std::string root = "test_root"; - runner_->MutableSidePackets()->Tag("DATASET_ROOT") = PointToForeign(&root); + runner_->MutableSidePackets()->Tag(kDatasetRootTag) = PointToForeign(&root); MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->OutputSidePackets() - .Tag("DATA_PATH") + .Tag(kDataPathTag) .ValidateAsType()); - ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get(), root + "/" + data_path_); } @@ -501,28 +528,28 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromOptions) { options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) ->set_dataset_root_directory(root); SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->OutputSidePackets() - .Tag("DATA_PATH") + .Tag(kDataPathTag) .ValidateAsType()); - ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get(), root + "/" + data_path_); } TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) { SetUpCalculator({}, {"DATA_PATH:data_path"}); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->OutputSidePackets() - .Tag("DATA_PATH") + .Tag(kDataPathTag) .ValidateAsType()); - ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get(), data_path_); } @@ -534,20 +561,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptions) { ->set_padding_after_label(2); SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {}, &options); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_EXPECT_OK(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .ValidateAsType()); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .Get() .start_time(), 2.0, 1e-5); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .Get() .end_time(), 7.0, 1e-5); @@ -563,20 +590,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) { ->set_force_decoding_from_start_of_media(true); SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {}, &options); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_EXPECT_OK(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .ValidateAsType()); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .Get() .start_time(), 0.0, 1e-5); EXPECT_NEAR(runner_->OutputSidePackets() - 
.Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .Get() .end_time(), 7.0, 1e-5); @@ -594,27 +621,27 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { ->mutable_base_packet_resampler_options() ->set_frame_rate(1.0); SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_EXPECT_OK(runner_->OutputSidePackets() - .Tag("RESAMPLER_OPTIONS") + .Tag(kResamplerOptionsTag) .ValidateAsType()); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("RESAMPLER_OPTIONS") + .Tag(kResamplerOptionsTag) .Get() .GetExtension(PacketResamplerCalculatorOptions::ext) .start_time(), 2000000, 1); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("RESAMPLER_OPTIONS") + .Tag(kResamplerOptionsTag) .Get() .GetExtension(PacketResamplerCalculatorOptions::ext) .end_time(), 7000000, 1); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("RESAMPLER_OPTIONS") + .Tag(kResamplerOptionsTag) .Get() .GetExtension(PacketResamplerCalculatorOptions::ext) .frame_rate(), @@ -623,13 +650,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) { SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"}); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_EXPECT_OK(runner_->OutputSidePackets() - .Tag("IMAGE_FRAME_RATE") + .Tag(kImageFrameRateTag) .ValidateAsType()); - EXPECT_EQ(runner_->OutputSidePackets().Tag("IMAGE_FRAME_RATE").Get(), + EXPECT_EQ(runner_->OutputSidePackets().Tag(kImageFrameRateTag).Get(), image_frame_rate_); } diff --git a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc index 369c09660..a7f1a9e7f 100644 --- a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc @@ -26,6 +26,10 @@ namespace { namespace tf = ::tensorflow; +constexpr char kSingleIntTag[] = "SINGLE_INT"; +constexpr char kTensorOutTag[] = "TENSOR_OUT"; +constexpr char kVectorIntTag[] = "VECTOR_INT"; + class VectorIntToTensorCalculatorTest : public ::testing::Test { protected: void SetUpRunner( @@ -61,13 +65,13 @@ class VectorIntToTensorCalculatorTest : public ::testing::Test { const int64 time = 1234; runner_->MutableInputs() - ->Tag("VECTOR_INT") + ->Tag(kVectorIntTag) .packets.push_back(Adopt(input.release()).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); @@ -95,13 +99,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) { tensorflow::DT_INT32, false, true); const int64 time = 1234; runner_->MutableInputs() - ->Tag("SINGLE_INT") + ->Tag(kSingleIntTag) .packets.push_back(MakePacket(1).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, 
output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); @@ -121,13 +125,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TesOneDim) { } const int64 time = 1234; runner_->MutableInputs() - ->Tag("VECTOR_INT") + ->Tag(kVectorIntTag) .packets.push_back(Adopt(input.release()).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); @@ -152,13 +156,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestInt64) { tensorflow::DT_INT64, false, true); const int64 time = 1234; runner_->MutableInputs() - ->Tag("SINGLE_INT") + ->Tag(kSingleIntTag) .packets.push_back(MakePacket(1LL << 31).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); @@ -179,13 +183,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestUint8) { } const int64 time = 1234; runner_->MutableInputs() - ->Tag("VECTOR_INT") + ->Tag(kVectorIntTag) .packets.push_back(Adopt(input.release()).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 2d1037d20..55616bb83 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -162,6 +162,27 @@ selects.config_setting_group( ], ) +config_setting( + name = "edge_tpu_usb", + define_values = { + "MEDIAPIPE_EDGE_TPU": "usb", + }, +) + +config_setting( + name = "edge_tpu_pci", + define_values = { + "MEDIAPIPE_EDGE_TPU": "pci", + }, +) + +config_setting( + name = "edge_tpu_all", + define_values = { + "MEDIAPIPE_EDGE_TPU": "all", + }, +) + cc_library( name = "tflite_inference_calculator", srcs = ["tflite_inference_calculator.cc"], @@ -172,6 +193,12 @@ cc_library( ], "//conditions:default": [], }), + defines = select({ + "//conditions:default": [], + ":edge_tpu_usb": ["MEDIAPIPE_EDGE_TPU=usb"], + ":edge_tpu_pci": ["MEDIAPIPE_EDGE_TPU=pci"], + ":edge_tpu_all": ["MEDIAPIPE_EDGE_TPU=all"], + }), linkopts = select({ "//mediapipe:ios": [ "-framework CoreVideo", @@ -223,6 +250,20 @@ cc_library( "//conditions:default": [ "//mediapipe/util:cpu_util", ], + }) + select({ + "//conditions:default": [], + ":edge_tpu_usb": [ + "@libedgetpu//tflite/public:edgetpu", + "@libedgetpu//tflite/public:oss_edgetpu_direct_usb", + ], + ":edge_tpu_pci": [ + "@libedgetpu//tflite/public:edgetpu", + "@libedgetpu//tflite/public:oss_edgetpu_direct_pci", + ], + ":edge_tpu_all": [ + "@libedgetpu//tflite/public:edgetpu", + "@libedgetpu//tflite/public:oss_edgetpu_direct_all", + ], }), alwayslink = 1, ) diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index ef46460b1..8e83f3e44 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ 
b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -85,7 +85,22 @@ constexpr char kTensorsGpuTag[] = "TENSORS_GPU"; } // namespace #if defined(MEDIAPIPE_EDGE_TPU) -#include "edgetpu.h" +#include "tflite/public/edgetpu.h" + +// Checkes whether model contains Edge TPU custom op or not. +bool ContainsEdgeTpuCustomOp(const tflite::FlatBufferModel& model) { + const auto* opcodes = model.GetModel()->operator_codes(); + for (const auto* subgraph : *model.GetModel()->subgraphs()) { + for (const auto* op : *subgraph->operators()) { + const auto* opcode = opcodes->Get(op->opcode_index()); + if (opcode->custom_code() && + opcode->custom_code()->str() == edgetpu::kCustomOp) { + return true; + } + } + } + return false; +} // Creates and returns an Edge TPU interpreter to run the given edgetpu model. std::unique_ptr BuildEdgeTpuInterpreter( @@ -94,14 +109,9 @@ std::unique_ptr BuildEdgeTpuInterpreter( edgetpu::EdgeTpuContext* edgetpu_context) { resolver->AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp()); std::unique_ptr interpreter; - if (tflite::InterpreterBuilder(model, *resolver)(&interpreter) != kTfLiteOk) { - std::cerr << "Failed to build edge TPU interpreter." << std::endl; - } + CHECK_EQ(tflite::InterpreterBuilder(model, *resolver)(&interpreter), + kTfLiteOk); interpreter->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context); - interpreter->SetNumThreads(1); - if (interpreter->AllocateTensors() != kTfLiteOk) { - std::cerr << "Failed to allocate edge TPU tensors." << std::endl; - } return interpreter; } #endif // MEDIAPIPE_EDGE_TPU @@ -128,9 +138,23 @@ struct GPUData { } // namespace #endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED +namespace { + +int GetXnnpackDefaultNumThreads() { +#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \ + defined(__EMSCRIPTEN_PTHREADS__) + constexpr int kMinNumThreadsByDefault = 1; + constexpr int kMaxNumThreadsByDefault = 4; + return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault, + kMaxNumThreadsByDefault); +#else + return 1; +#endif // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__ +} + // Returns number of threads to configure XNNPACK delegate with. -// (Equal to user provided value if specified. Otherwise, it returns number of -// high cores (hard-coded to 1 for Emscripten without Threads extension)) +// Returns user provided value if specified. Otherwise, tries to choose optimal +// number of threads depending on the device. int GetXnnpackNumThreads( const mediapipe::TfLiteInferenceCalculatorOptions& opts) { static constexpr int kDefaultNumThreads = -1; @@ -138,13 +162,11 @@ int GetXnnpackNumThreads( opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) { return opts.delegate().xnnpack().num_threads(); } -#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) - return InferHigherCoreIds().size(); -#else - return 1; -#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ + return GetXnnpackDefaultNumThreads(); } +} // namespace + // Calculator Header Section // Runs inference on the provided input TFLite tensors and TFLite model. 
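// Two separate changes meet in the hunk above. First, Edge TPU support: the
// new config_settings in mediapipe/calculators/tflite/BUILD turn
// `--define MEDIAPIPE_EDGE_TPU=usb` (or pci/all) into the MEDIAPIPE_EDGE_TPU
// macro plus the matching libedgetpu dependency, and ContainsEdgeTpuCustomOp()
// lets the calculator open the TPU device only for models that actually embed
// the edgetpu custom op. Second, the XNNPACK default thread count now comes
// from the CPU core count rather than InferHigherCoreIds(). A sketch of that
// heuristic, assuming a NumCPUCores()-style helper is available:
//
//   #include <algorithm>
//
//   // Half the cores, clamped to [1, 4], on Android/iOS/WASM-with-pthreads;
//   // single-threaded everywhere else.
//   int DefaultXnnpackThreads(int num_cpu_cores, bool mobile_or_wasm_threads) {
//     if (!mobile_or_wasm_threads) return 1;
//     return std::clamp(num_cpu_cores / 2, /*lo=*/1, /*hi=*/4);
//   }
//
// For example, an 8-core phone gets 4 threads and a 2-core device gets 1; an
// explicit delegate.xnnpack.num_threads in the calculator options still wins.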
@@ -267,8 +289,7 @@ class TfLiteInferenceCalculator : public CalculatorBase { #endif // MEDIAPIPE_TFLITE_GL_INFERENCE #if defined(MEDIAPIPE_EDGE_TPU) - std::shared_ptr edgetpu_context_ = - edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice(); + std::shared_ptr edgetpu_context_; #endif bool gpu_inference_ = false; @@ -280,6 +301,8 @@ class TfLiteInferenceCalculator : public CalculatorBase { bool allow_precision_loss_ = false; mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::Api tflite_gpu_runner_api_; + mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::InferenceUsage + tflite_gpu_runner_usage_; bool use_kernel_caching_ = false; std::string cached_kernel_filename_; @@ -289,6 +312,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); // Calculator Core Section namespace { + +constexpr char kCustomOpResolverTag[] = "CUSTOM_OP_RESOLVER"; +constexpr char kModelTag[] = "MODEL"; + template bool ShouldUseGpu(CC* cc) { #if MEDIAPIPE_TFLITE_GPU_SUPPORTED @@ -313,7 +340,7 @@ absl::Status TfLiteInferenceCalculator::GetContract(CalculatorContract* cc) { const auto& options = cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>(); RET_CHECK(!options.model_path().empty() ^ - cc->InputSidePackets().HasTag("MODEL")) + cc->InputSidePackets().HasTag(kModelTag)) << "Either model as side packet or model path in options is required."; if (cc->Inputs().HasTag(kTensorsTag)) @@ -326,13 +353,13 @@ absl::Status TfLiteInferenceCalculator::GetContract(CalculatorContract* cc) { if (cc->Outputs().HasTag(kTensorsGpuTag)) cc->Outputs().Tag(kTensorsGpuTag).Set>(); - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { + if (cc->InputSidePackets().HasTag(kCustomOpResolverTag)) { cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") + .Tag(kCustomOpResolverTag) .Set(); } - if (cc->InputSidePackets().HasTag("MODEL")) { - cc->InputSidePackets().Tag("MODEL").Set(); + if (cc->InputSidePackets().HasTag(kModelTag)) { + cc->InputSidePackets().Tag(kModelTag).Set(); } if (ShouldUseGpu(cc)) { @@ -365,6 +392,7 @@ absl::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { options.delegate().gpu().use_advanced_gpu_api(); allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); tflite_gpu_runner_api_ = options.delegate().gpu().api(); + tflite_gpu_runner_usage_ = options.delegate().gpu().usage(); use_kernel_caching_ = use_advanced_gpu_api_ && options.delegate().gpu().has_cached_kernel_path(); @@ -471,8 +499,8 @@ absl::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { MP_RETURN_IF_ERROR(WriteKernelsToFile()); return RunInContextIfNeeded([this]() -> absl::Status { + interpreter_ = nullptr; if (delegate_) { - interpreter_ = nullptr; delegate_ = nullptr; #if MEDIAPIPE_TFLITE_GPU_SUPPORTED if (gpu_inference_) { @@ -486,7 +514,7 @@ absl::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { #endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED } #if defined(MEDIAPIPE_EDGE_TPU) - edgetpu_context_.reset(); + edgetpu_context_ = nullptr; #endif return absl::OkStatus(); }); @@ -708,9 +736,9 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( auto op_resolver_ptr = static_cast( &default_op_resolver); - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { + if (cc->InputSidePackets().HasTag(kCustomOpResolverTag)) { op_resolver_ptr = &(cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") + .Tag(kCustomOpResolverTag) .Get()); } @@ -721,7 +749,23 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( : tflite::gpu::InferencePriority::MAX_PRECISION; 
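// Two details in the hunks above are easy to miss. The RET_CHECK in
// GetContract() uses `^` (exclusive or), so exactly one model source must be
// supplied: either model_path in the options or a MODEL input side packet,
// never both and never neither. And in Close(), interpreter_ is now reset
// unconditionally and before delegate_, which keeps the interpreter from
// outliving a delegate it may still reference (TFLite generally requires a
// delegate to outlive the interpreter that uses it), including on the
// no-delegate path.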
options.priority2 = tflite::gpu::InferencePriority::AUTO; options.priority3 = tflite::gpu::InferencePriority::AUTO; - options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + switch (tflite_gpu_runner_usage_) { + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu:: + FAST_SINGLE_ANSWER: { + options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER; + break; + } + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu:: + SUSTAINED_SPEED: { + options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + break; + } + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu:: + UNSPECIFIED: { + return absl::InternalError("inference usage need to be specified."); + } + } + tflite_gpu_runner_ = std::make_unique(options); switch (tflite_gpu_runner_api_) { case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::OPENGL: { @@ -737,8 +781,8 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( break; } } - MP_RETURN_IF_ERROR( - tflite_gpu_runner_->InitializeWithModel(model, *op_resolver_ptr)); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( + model, *op_resolver_ptr, /*allow_quant_ops=*/true)); // Allocate interpreter memory for cpu output. if (!gpu_output_) { @@ -794,21 +838,26 @@ absl::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) { tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates default_op_resolver; - auto op_resolver_ptr = - static_cast( - &default_op_resolver); - - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { - op_resolver_ptr = &(cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") - .Get()); - } - #if defined(MEDIAPIPE_EDGE_TPU) - interpreter_ = - BuildEdgeTpuInterpreter(model, op_resolver_ptr, edgetpu_context_.get()); -#else - tflite::InterpreterBuilder(model, *op_resolver_ptr)(&interpreter_); + if (ContainsEdgeTpuCustomOp(model)) { + edgetpu_context_ = edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice(); + interpreter_ = BuildEdgeTpuInterpreter(model, &default_op_resolver, + edgetpu_context_.get()); + } else { +#endif // MEDIAPIPE_EDGE_TPU + auto op_resolver_ptr = + static_cast( + &default_op_resolver); + + if (cc->InputSidePackets().HasTag(kCustomOpResolverTag)) { + op_resolver_ptr = &(cc->InputSidePackets() + .Tag(kCustomOpResolverTag) + .Get()); + } + + tflite::InterpreterBuilder(model, *op_resolver_ptr)(&interpreter_); +#if defined(MEDIAPIPE_EDGE_TPU) + } #endif // MEDIAPIPE_EDGE_TPU RET_CHECK(interpreter_); @@ -841,8 +890,8 @@ absl::StatusOr TfLiteInferenceCalculator::GetModelAsPacket( if (!options.model_path().empty()) { return TfLiteModelLoader::LoadFromPath(options.model_path()); } - if (cc.InputSidePackets().HasTag("MODEL")) { - return cc.InputSidePackets().Tag("MODEL"); + if (cc.InputSidePackets().HasTag(kModelTag)) { + return cc.InputSidePackets().Tag(kModelTag); } return absl::Status(absl::StatusCode::kNotFound, "Must specify TFLite model as path or loaded model."); @@ -866,11 +915,15 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { // Attempt to use NNAPI. // If not supported, the default CPU delegate will be created and used. interpreter_->SetAllowFp16PrecisionForFp32(1); - delegate_ = - TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { - // No need to free according to tflite::NnApiDelegate() - // documentation. - }); + tflite::StatefulNnApiDelegate::Options options; + const auto& nnapi = calculator_opts.delegate().nnapi(); + // Set up cache_dir and model_token for NNAPI compilation cache. 
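// The switch here from tflite::NnApiDelegate() (a shared default delegate that
// needs no freeing, per the removed comment) to an owned
// tflite::StatefulNnApiDelegate is what makes per-calculator NNAPI options
// possible: cache_dir and model_token are forwarded only when both are set,
// matching the proto comments added further down. A hypothetical illustration
// of the same options outside the calculator (path and token are made up):
//
//   tflite::StatefulNnApiDelegate::Options nnapi_options;
//   nnapi_options.cache_dir = "/data/local/tmp/nnapi_cache";  // hypothetical
//   nnapi_options.model_token = "my_model_v1";                // hypothetical
//   tflite::StatefulNnApiDelegate nnapi_delegate(nnapi_options);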
+ if (nnapi.has_cache_dir() && nnapi.has_model_token()) { + options.cache_dir = nnapi.cache_dir().c_str(); + options.model_token = nnapi.model_token().c_str(); + } + delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options), + [](TfLiteDelegate*) {}); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); return absl::OkStatus(); @@ -894,6 +947,8 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { kTfLiteOk); return absl::OkStatus(); } +#else + (void)use_xnnpack; #endif // !EDGETPU // Return and use default tflite infernece (on CPU). No need for GPU @@ -969,6 +1024,10 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { const int kHalfSize = 2; // sizeof(half) // Configure and create the delegate. TFLGpuDelegateOptions options; + // `enable_quantization` enables the run of sparse models i.e. the models with + // DENSIFY op preceding DEQUINTIZE op. Both ops get removed from the execution + // graph after the tensor of the weights is read. + options.enable_quantization = true; options.allow_precision_loss = allow_precision_loss_; options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive; if (!delegate_) diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.proto b/mediapipe/calculators/tflite/tflite_inference_calculator.proto index 02dc20831..3b4d2896e 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.proto @@ -67,9 +67,31 @@ message TfLiteInferenceCalculatorOptions { // Only available for OpenCL delegate on Android. // Kernel caching will only be enabled if this path is set. optional string cached_kernel_path = 2; + + // Encapsulated compilation/runtime tradeoffs. + enum InferenceUsage { + UNSPECIFIED = 0; + + // InferenceRunner will be used only once. Therefore, it is important to + // minimize bootstrap time as well. + FAST_SINGLE_ANSWER = 1; + + // Prefer maximizing the throughput. Same inference runner will be used + // repeatedly on different inputs. + SUSTAINED_SPEED = 2; + } + optional InferenceUsage usage = 5 [default = SUSTAINED_SPEED]; } // Android only. - message Nnapi {} + message Nnapi { + // Directory to store compilation cache. If unspecified, NNAPI will not + // try caching the compilation. + optional string cache_dir = 1; + // Unique token identifying the model. It is the caller's responsibility + // to ensure there is no clash of the tokens. If unspecified, NNAPI will + // not try caching the compilation. + optional string model_token = 2; + } message Xnnpack { // Number of threads for XNNPACK delegate. (By default, calculator tries // to choose optimal number of threads depending on the device.) diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc index ef2946c32..94cfaece8 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc @@ -18,6 +18,10 @@ namespace mediapipe { +constexpr char kFloatsTag[] = "FLOATS"; +constexpr char kFloatTag[] = "FLOAT"; +constexpr char kTensorsTag[] = "TENSORS"; + // A calculator for converting TFLite tensors to to a float or a float vector. 
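// Looking back at the tflite_inference_calculator changes above: the new proto
// fields are the user-facing half of those .cc edits. delegate.gpu.usage picks
// FAST_SINGLE_ANSWER (runner used once, so bootstrap time matters) or
// SUSTAINED_SPEED (the default, for repeated inference); UNSPECIFIED is
// rejected with an InternalError by the switch added earlier.
// delegate.nnapi.cache_dir and model_token must both be set for the NNAPI
// compilation cache to be used, and delegate.xnnpack.num_threads still
// defaults to automatic selection (the -1 sentinel in the .cc). Read from C++
// the same way the calculator does, e.g.:
//
//   const auto& opts =
//       cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
//   const auto usage = opts.delegate().gpu().usage();  // SUSTAINED_SPEED default
//   const int threads = opts.delegate().xnnpack().num_threads();  // -1 = auto
//
// Separately, the TFLGpuDelegateOptions above now set enable_quantization, so
// sparse models (a DENSIFY op feeding a DEQUANTIZE op) can run on that
// delegate as well.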
// // Input: @@ -48,15 +52,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator); absl::Status TfLiteTensorsToFloatsCalculator::GetContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("TENSORS")); - RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT")); + RET_CHECK(cc->Inputs().HasTag(kTensorsTag)); + RET_CHECK(cc->Outputs().HasTag(kFloatsTag) || + cc->Outputs().HasTag(kFloatTag)); - cc->Inputs().Tag("TENSORS").Set>(); - if (cc->Outputs().HasTag("FLOATS")) { - cc->Outputs().Tag("FLOATS").Set>(); + cc->Inputs().Tag(kTensorsTag).Set>(); + if (cc->Outputs().HasTag(kFloatsTag)) { + cc->Outputs().Tag(kFloatsTag).Set>(); } - if (cc->Outputs().HasTag("FLOAT")) { - cc->Outputs().Tag("FLOAT").Set(); + if (cc->Outputs().HasTag(kFloatTag)) { + cc->Outputs().Tag(kFloatTag).Set(); } return absl::OkStatus(); @@ -69,10 +74,10 @@ absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) { } absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) { - RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty()); + RET_CHECK(!cc->Inputs().Tag(kTensorsTag).IsEmpty()); const auto& input_tensors = - cc->Inputs().Tag("TENSORS").Get>(); + cc->Inputs().Tag(kTensorsTag).Get>(); // TODO: Add option to specify which tensor to take from. const TfLiteTensor* raw_tensor = &input_tensors[0]; const float* raw_floats = raw_tensor->data.f; @@ -82,18 +87,19 @@ absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) { num_values *= raw_tensor->dims->data[i]; } - if (cc->Outputs().HasTag("FLOAT")) { + if (cc->Outputs().HasTag(kFloatTag)) { // TODO: Could add an index in the option to specifiy returning one // value of a float array. RET_CHECK_EQ(num_values, 1); - cc->Outputs().Tag("FLOAT").AddPacket( + cc->Outputs().Tag(kFloatTag).AddPacket( MakePacket(raw_floats[0]).At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("FLOATS")) { + if (cc->Outputs().HasTag(kFloatsTag)) { auto output_floats = absl::make_unique>( raw_floats, raw_floats + num_values); - cc->Outputs().Tag("FLOATS").Add(output_floats.release(), - cc->InputTimestamp()); + cc->Outputs() + .Tag(kFloatsTag) + .Add(output_floats.release(), cc->InputTimestamp()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 1ee0fb9cc..eb8950510 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -57,6 +57,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "filter_detections_calculator_proto", + srcs = ["filter_detections_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + mediapipe_proto_library( name = "timed_box_list_id_to_label_calculator_proto", srcs = ["timed_box_list_id_to_label_calculator.proto"], @@ -158,6 +168,21 @@ cc_test( ], ) +cc_test( + name = "filter_detections_calculator_test", + size = "small", + srcs = ["filter_detections_calculator_test.cc"], + deps = [ + ":filter_detections_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], @@ -372,6 +397,20 @@ cc_library( alwayslink = 1, ) +cc_library( + 
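// In the TfLiteTensorsToFloatsCalculator::Process() change above, the element
// count is the product of all tensor dimensions, so e.g. a [1, 3] tensor
// yields three floats and must be routed to the FLOATS output, while the
// FLOAT output additionally RET_CHECKs that the product is exactly 1 (shape
// [1] or [1, 1]). Only the first tensor of the TENSORS vector is read; the
// existing TODO notes the missing tensor-index option.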
name = "filter_detections_calculator", + srcs = ["filter_detections_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":filter_detections_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/memory", + ], + alwayslink = 1, +) + cc_library( name = "landmarks_to_detection_calculator", srcs = ["landmarks_to_detection_calculator.cc"], @@ -840,6 +879,20 @@ cc_test( ], ) +cc_library( + name = "world_landmark_projection_calculator", + srcs = ["world_landmark_projection_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + mediapipe_proto_library( name = "landmarks_smoothing_calculator_proto", srcs = ["landmarks_smoothing_calculator.proto"], @@ -859,6 +912,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/util/filtering:one_euro_filter", "//mediapipe/util/filtering:relative_velocity_filter", @@ -893,6 +947,31 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "visibility_copy_calculator_proto", + srcs = ["visibility_copy_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "visibility_copy_calculator", + srcs = ["visibility_copy_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":visibility_copy_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/algorithm:container", + ], + alwayslink = 1, +) + cc_library( name = "landmarks_to_floats_calculator", srcs = ["landmarks_to_floats_calculator.cc"], @@ -1313,3 +1392,34 @@ cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +cc_library( + name = "inverse_matrix_calculator", + srcs = ["inverse_matrix_calculator.cc"], + hdrs = ["inverse_matrix_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "@com_google_absl//absl/status", + "@eigen_archive//:eigen3", + ], + alwayslink = True, +) + +cc_test( + name = "inverse_matrix_calculator_test", + srcs = ["inverse_matrix_calculator_test.cc"], + tags = ["desktop_only_test"], + deps = [ + ":inverse_matrix_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index 2c0b25397..8af4a5de8 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ 
b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -272,6 +272,15 @@ absl::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { } absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { + if (cc->Inputs().HasTag(kGpuBufferTag) && + cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { + return absl::OkStatus(); + } + if (cc->Inputs().HasTag(kImageFrameTag) && + cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { + return absl::OkStatus(); + } + // Initialize render target, drawn with OpenCV. std::unique_ptr image_mat; ImageFormat::Format target_format; diff --git a/mediapipe/calculators/util/clock_timestamp_calculator.cc b/mediapipe/calculators/util/clock_timestamp_calculator.cc index 4ba56cfd0..324bc4ac7 100644 --- a/mediapipe/calculators/util/clock_timestamp_calculator.cc +++ b/mediapipe/calculators/util/clock_timestamp_calculator.cc @@ -87,7 +87,7 @@ absl::Status ClockTimestampCalculator::Open(CalculatorContext* cc) { // Initialize the clock. if (cc->InputSidePackets().HasTag(kClockTag)) { clock_ = cc->InputSidePackets() - .Tag("CLOCK") + .Tag(kClockTag) .Get>(); } else { clock_.reset( diff --git a/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc b/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc index 805ad495d..62eb1d8ae 100644 --- a/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc +++ b/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc @@ -27,6 +27,8 @@ namespace mediapipe { +constexpr char kIterableTag[] = "ITERABLE"; + typedef CollectionHasMinSizeCalculator> TestIntCollectionHasMinSizeCalculator; REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator); @@ -34,7 +36,7 @@ REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator); void AddInputVector(const std::vector& input, int64 timestamp, CalculatorRunner* runner) { runner->MutableInputs() - ->Tag("ITERABLE") + ->Tag(kIterableTag) .packets.push_back( MakePacket>(input).At(Timestamp(timestamp))); } diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc index 8f8025576..ca85a267e 100644 --- a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc @@ -144,7 +144,7 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase { } cc->Outputs() - .Tag("DETECTIONS") + .Tag(kDetectionsTag) .Add(output_detections.release(), cc->InputTimestamp()); return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc index 343ccea4f..c4f084363 100644 --- a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc @@ -25,6 +25,9 @@ namespace mediapipe { +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; +constexpr char kDetectionsTag[] = "DETECTIONS"; + LocationData CreateRelativeLocationData(double xmin, double ymin, double width, double height) { LocationData location_data; @@ -76,19 +79,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingLeftRight) { detections->push_back( CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); auto padding = 
absl::make_unique>( std::array{0.2f, 0.f, 0.3f, 0.f}); runner.MutableInputs() - ->Tag("LETTERBOX_PADDING") + ->Tag(kLetterboxPaddingTag) .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("DETECTIONS").packets; + runner.Outputs().Tag(kDetectionsTag).packets; ASSERT_EQ(1, output.size()); const auto& output_detections = output[0].Get>(); @@ -124,19 +127,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingTopBottom) { detections->push_back( CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); auto padding = absl::make_unique>( std::array{0.f, 0.2f, 0.f, 0.3f}); runner.MutableInputs() - ->Tag("LETTERBOX_PADDING") + ->Tag(kLetterboxPaddingTag) .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("DETECTIONS").packets; + runner.Outputs().Tag(kDetectionsTag).packets; ASSERT_EQ(1, output.size()); const auto& output_detections = output[0].Get>(); diff --git a/mediapipe/calculators/util/detection_projection_calculator_test.cc b/mediapipe/calculators/util/detection_projection_calculator_test.cc index 4cc85acee..176054e43 100644 --- a/mediapipe/calculators/util/detection_projection_calculator_test.cc +++ b/mediapipe/calculators/util/detection_projection_calculator_test.cc @@ -31,6 +31,9 @@ namespace mediapipe { namespace { +constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; +constexpr char kDetectionsTag[] = "DETECTIONS"; + using ::testing::ElementsAre; using ::testing::FloatNear; @@ -74,19 +77,19 @@ absl::StatusOr RunProjectionCalculator( )pb")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back(MakePacket>( std::vector({std::move(detection)})) .At(Timestamp::PostStream())); runner.MutableInputs() - ->Tag("PROJECTION_MATRIX") + ->Tag(kProjectionMatrixTag) .packets.push_back( MakePacket>(std::move(project_mat)) .At(Timestamp::PostStream())); MP_RETURN_IF_ERROR(runner.Run()); const std::vector& output = - runner.Outputs().Tag("DETECTIONS").packets; + runner.Outputs().Tag(kDetectionsTag).packets; RET_CHECK_EQ(output.size(), 1); const auto& output_detections = output[0].Get>(); diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc index 29836cb59..73a67d322 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc @@ -203,6 +203,9 @@ absl::Status DetectionsToRectsCalculator::Process(CalculatorContext* cc) { cc->Inputs().Tag(kDetectionsTag).IsEmpty()) { return absl::OkStatus(); } + if (rotate_ && !HasTagValue(cc, kImageSizeTag)) { + return absl::OkStatus(); + } std::vector detections; if (cc->Inputs().HasTag(kDetectionTag)) { @@ -323,7 +326,7 @@ absl::Status DetectionsToRectsCalculator::ComputeRotation( DetectionSpec DetectionsToRectsCalculator::GetDetectionSpec( const CalculatorContext* cc) { absl::optional> image_size; - if (cc->Inputs().HasTag(kImageSizeTag)) { + if (HasTagValue(cc->Inputs(), kImageSizeTag)) { image_size = cc->Inputs().Tag(kImageSizeTag).Get>(); } diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc 
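// For the letterbox-removal tests above: LETTERBOX_PADDING is a
// std::array<float, 4> of {left, top, right, bottom} fractions of the
// letterboxed frame, and the calculator maps relative detection coordinates
// back onto the unpadded image, roughly (a sketch, not the exact source):
//
//   float x_unpadded = (x - left) / (1.0f - left - right);
//   float y_unpadded = (y - top) / (1.0f - top - bottom);
//
// so {0.2f, 0.f, 0.3f, 0.f} rescales only the x axis and
// {0.f, 0.2f, 0.f, 0.3f} only the y axis. The HasTagValue guard added to
// DetectionsToRectsCalculator just above makes the rotation path tolerate an
// IMAGE_SIZE stream with no packet at the current timestamp.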
b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index f46640ab2..a45048d40 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -32,6 +32,14 @@ namespace mediapipe { namespace { +constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr char kRectsTag[] = "RECTS"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kRectTag[] = "RECT"; +constexpr char kDetectionTag[] = "DETECTION"; + MATCHER_P4(RectEq, x_center, y_center, width, height, "") { return testing::Value(arg.x_center(), testing::Eq(x_center)) && testing::Value(arg.y_center(), testing::Eq(y_center)) && @@ -94,12 +102,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRect) { DetectionWithLocationData(100, 200, 300, 400)); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back( Adopt(detection.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("RECT").packets; + const std::vector& output = runner.Outputs().Tag(kRectTag).packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); @@ -120,16 +128,16 @@ absl::StatusOr RunDetectionKeyPointsToRectCalculation( )pb")); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back(MakePacket(std::move(detection)) .At(Timestamp::PostStream())); runner.MutableInputs() - ->Tag("IMAGE_SIZE") + ->Tag(kImageSizeTag) .packets.push_back(MakePacket>(image_size) .At(Timestamp::PostStream())); MP_RETURN_IF_ERROR(runner.Run()); - const std::vector& output = runner.Outputs().Tag("RECT").packets; + const std::vector& output = runner.Outputs().Tag(kRectTag).packets; RET_CHECK_EQ(output.size(), 1); return output[0].Get(); } @@ -157,6 +165,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionKeyPointsToRect) { /*image_size=*/{640, 480}); MP_ASSERT_OK(status_or_value); EXPECT_THAT(status_or_value.value(), RectEq(480, 360, 320, 240)); + + status_or_value = RunDetectionKeyPointsToRectCalculation( + /*detection=*/DetectionWithKeyPoints({{0.25f, 0.25f}, {0.75f, 0.75f}}), + /*image_size=*/{0, 0}); + MP_ASSERT_OK(status_or_value); + EXPECT_THAT(status_or_value.value(), RectEq(0, 0, 0, 0)); } TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) { @@ -170,12 +184,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) { DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back( Adopt(detection.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + const std::vector& output = + runner.Outputs().Tag(kNormRectTag).packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); @@ -195,12 +210,13 @@ absl::StatusOr RunDetectionKeyPointsToNormRectCalculation( )pb")); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back(MakePacket(std::move(detection)) .At(Timestamp::PostStream())); MP_RETURN_IF_ERROR(runner.Run()); - const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + const std::vector& output = + 
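// The expectations in these detections_to_rects tests encode the box-to-rect
// conversion: a corner-based box (xmin, ymin, width, height) becomes a
// center-based rect, e.g. (100, 200, 300, 400) gives
// x_center = 100 + 300 / 2 = 250 and y_center = 200 + 400 / 2 = 400 with
// width and height unchanged; the normalized variant does the same in [0, 1]
// coordinates (0.1 + 0.3 / 2 = 0.25, 0.2 + 0.4 / 2 = 0.4). The newly added
// keypoint case with image_size {0, 0} documents that a zero image size now
// yields an all-zero rect.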
runner.Outputs().Tag(kNormRectTag).packets; RET_CHECK_EQ(output.size(), 1); return output[0].Get(); } @@ -242,12 +258,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) { detections->push_back(DetectionWithLocationData(200, 300, 400, 500)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("RECT").packets; + const std::vector& output = runner.Outputs().Tag(kRectTag).packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); @@ -265,12 +281,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) { detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + const std::vector& output = + runner.Outputs().Tag(kNormRectTag).packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); @@ -288,12 +305,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) { detections->push_back(DetectionWithLocationData(200, 300, 400, 500)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("RECTS").packets; + const std::vector& output = runner.Outputs().Tag(kRectsTag).packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); ASSERT_EQ(rects.size(), 2); @@ -313,13 +330,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) { detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("NORM_RECTS").packets; + runner.Outputs().Tag(kNormRectsTag).packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); ASSERT_EQ(rects.size(), 2); @@ -338,12 +355,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRects) { DetectionWithLocationData(100, 200, 300, 400)); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back( Adopt(detection.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("RECTS").packets; + const std::vector& output = runner.Outputs().Tag(kRectsTag).packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); EXPECT_EQ(rects.size(), 1); @@ -361,13 +378,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) { DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back( Adopt(detection.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("NORM_RECTS").packets; + 
runner.Outputs().Tag(kNormRectsTag).packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); ASSERT_EQ(rects.size(), 1); @@ -385,7 +402,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToRect) { detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); @@ -405,7 +422,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToNormalizedRect) { detections->push_back(DetectionWithLocationData(100, 200, 300, 400)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc index ea4bfc484..04d8b5bcd 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc @@ -30,6 +30,10 @@ namespace mediapipe { +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kRenderDataTag[] = "RENDER_DATA"; +constexpr char kDetectionListTag[] = "DETECTION_LIST"; + using ::testing::DoubleNear; // Error tolerance for pixels, distances, etc. @@ -97,13 +101,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionList) { CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"); runner.MutableInputs() - ->Tag("DETECTION_LIST") + ->Tag(kDetectionListTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("RENDER_DATA").packets; + runner.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(1, output.size()); const auto& actual = output[0].Get(); EXPECT_EQ(actual.render_annotations_size(), 3); @@ -131,13 +135,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) { CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("RENDER_DATA").packets; + runner.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(1, output.size()); const auto& actual = output[0].Get(); EXPECT_EQ(actual.render_annotations_size(), 3); @@ -165,7 +169,7 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) { *(detection_list->add_detection()) = CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1"); runner.MutableInputs() - ->Tag("DETECTION_LIST") + ->Tag(kDetectionListTag) .packets.push_back( Adopt(detection_list.release()).At(Timestamp::PostStream())); @@ -174,13 +178,13 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) { detections->push_back( CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& actual = - runner.Outputs().Tag("RENDER_DATA").packets; + runner.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(1, actual.size()); // Check the feature tag for item from detection 
list. EXPECT_EQ( @@ -209,19 +213,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { auto detection_list1(absl::make_unique()); runner1.MutableInputs() - ->Tag("DETECTION_LIST") + ->Tag(kDetectionListTag) .packets.push_back( Adopt(detection_list1.release()).At(Timestamp::PostStream())); auto detections1(absl::make_unique>()); runner1.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections1.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner1.Run()) << "Calculator execution failed."; const std::vector& exact1 = - runner1.Outputs().Tag("RENDER_DATA").packets; + runner1.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(0, exact1.size()); // Check when produce_empty_packet is true. @@ -240,19 +244,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { auto detection_list2(absl::make_unique()); runner2.MutableInputs() - ->Tag("DETECTION_LIST") + ->Tag(kDetectionListTag) .packets.push_back( Adopt(detection_list2.release()).At(Timestamp::PostStream())); auto detections2(absl::make_unique>()); runner2.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections2.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner2.Run()) << "Calculator execution failed."; const std::vector& exact2 = - runner2.Outputs().Tag("RENDER_DATA").packets; + runner2.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(1, exact2.size()); EXPECT_EQ(exact2[0].Get().render_annotations_size(), 0); } diff --git a/mediapipe/calculators/util/filter_collection_calculator.cc b/mediapipe/calculators/util/filter_collection_calculator.cc index 690ca2a93..ab361f450 100644 --- a/mediapipe/calculators/util/filter_collection_calculator.cc +++ b/mediapipe/calculators/util/filter_collection_calculator.cc @@ -32,11 +32,15 @@ typedef FilterCollectionCalculator> FilterNormalizedRectCollectionCalculator; REGISTER_CALCULATOR(FilterNormalizedRectCollectionCalculator); -typedef FilterCollectionCalculator< - std::vector<::mediapipe::NormalizedLandmarkList>> +typedef FilterCollectionCalculator> FilterLandmarkListCollectionCalculator; REGISTER_CALCULATOR(FilterLandmarkListCollectionCalculator); +typedef FilterCollectionCalculator< + std::vector<::mediapipe::NormalizedLandmarkList>> + FilterNormalizedLandmarkListCollectionCalculator; +REGISTER_CALCULATOR(FilterNormalizedLandmarkListCollectionCalculator); + typedef FilterCollectionCalculator> FilterClassificationListCollectionCalculator; REGISTER_CALCULATOR(FilterClassificationListCollectionCalculator); diff --git a/mediapipe/calculators/util/filter_detections_calculator.cc b/mediapipe/calculators/util/filter_detections_calculator.cc new file mode 100644 index 000000000..a1f23ba83 --- /dev/null +++ b/mediapipe/calculators/util/filter_detections_calculator.cc @@ -0,0 +1,81 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
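A minimal sketch of how the FilterDetectionsCalculator introduced in the file below might be wired into a graph (not part of this patch): the stream names "detections" and "filtered_detections" are assumptions, while the INPUT_DETECTIONS/OUTPUT_DETECTIONS tags, the options extension, and min_score come from the calculator and proto added in this change.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical wiring of the new FilterDetectionsCalculator: only detections
// with score(0) >= 0.5 are forwarded to "filtered_detections".
mediapipe::CalculatorGraphConfig MakeFilterDetectionsGraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "detections"
    output_stream: "filtered_detections"
    node {
      calculator: "FilterDetectionsCalculator"
      input_stream: "INPUT_DETECTIONS:detections"
      output_stream: "OUTPUT_DETECTIONS:filtered_detections"
      options {
        [mediapipe.FilterDetectionsCalculatorOptions.ext] { min_score: 0.5 }
      }
    }
  )pb");
}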
+ +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/util/filter_detections_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +const char kInputDetectionsTag[] = "INPUT_DETECTIONS"; +const char kOutputDetectionsTag[] = "OUTPUT_DETECTIONS"; + +// +// Calculator to filter out detections that do not meet the criteria specified +// in options. +// +class FilterDetectionsCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kInputDetectionsTag)); + RET_CHECK(cc->Outputs().HasTag(kOutputDetectionsTag)); + + cc->Inputs().Tag(kInputDetectionsTag).Set>(); + cc->Outputs().Tag(kOutputDetectionsTag).Set>(); + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + options_ = cc->Options(); + + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + const auto& input_detections = + cc->Inputs().Tag(kInputDetectionsTag).Get>(); + + auto output_detections = absl::make_unique>(); + + for (const Detection& detection : input_detections) { + RET_CHECK_GT(detection.score_size(), 0); + // Note: only score at index 0 supported. + if (detection.score(0) >= options_.min_score()) { + output_detections->push_back(detection); + } + } + + cc->Outputs() + .Tag(kOutputDetectionsTag) + .Add(output_detections.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } + + private: + mediapipe::FilterDetectionsCalculatorOptions options_; +}; + +REGISTER_CALCULATOR(FilterDetectionsCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/filter_detections_calculator.proto b/mediapipe/calculators/util/filter_detections_calculator.proto new file mode 100644 index 000000000..e16898c79 --- /dev/null +++ b/mediapipe/calculators/util/filter_detections_calculator.proto @@ -0,0 +1,28 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message FilterDetectionsCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional FilterDetectionsCalculatorOptions ext = 395478132; + } + + // Detections lower than this score get filtered out. + optional float min_score = 1; +} diff --git a/mediapipe/calculators/util/filter_detections_calculator_test.cc b/mediapipe/calculators/util/filter_detections_calculator_test.cc new file mode 100644 index 000000000..515a8b7df --- /dev/null +++ b/mediapipe/calculators/util/filter_detections_calculator_test.cc @@ -0,0 +1,100 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::testing::ElementsAre; + +absl::Status RunGraph(std::vector& input_detections, + std::vector* output_detections) { + CalculatorRunner runner(R"pb( + calculator: "FilterDetectionsCalculator" + input_stream: "INPUT_DETECTIONS:input_detections" + output_stream: "OUTPUT_DETECTIONS:output_detections" + options { + [mediapipe.FilterDetectionsCalculatorOptions.ext] { min_score: 0.5 } + } + )pb"); + + const Timestamp input_timestamp = Timestamp(0); + runner.MutableInputs() + ->Tag("INPUT_DETECTIONS") + .packets.push_back(MakePacket>(input_detections) + .At(input_timestamp)); + MP_RETURN_IF_ERROR(runner.Run()) << "Calculator run failed."; + + const std::vector& output_packets = + runner.Outputs().Tag("OUTPUT_DETECTIONS").packets; + RET_CHECK_EQ(output_packets.size(), 1); + + *output_detections = output_packets[0].Get>(); + return absl::OkStatus(); +} + +TEST(FilterDetectionsCalculatorTest, TestFilterDetections) { + std::vector input_detections; + Detection d1, d2; + d1.add_score(0.2); + d2.add_score(0.8); + input_detections.push_back(d1); + input_detections.push_back(d2); + + std::vector output_detections; + MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + + EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d2))); +} + +TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMultiple) { + std::vector input_detections; + Detection d1, d2, d3, d4; + d1.add_score(0.3); + d2.add_score(0.4); + d3.add_score(0.5); + d4.add_score(0.6); + input_detections.push_back(d1); + input_detections.push_back(d2); + input_detections.push_back(d3); + input_detections.push_back(d4); + + std::vector output_detections; + MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + + EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d3), + mediapipe::EqualsProto(d4))); +} + +TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsEmpty) { + std::vector input_detections; + + std::vector output_detections; + MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + + EXPECT_EQ(output_detections.size(), 0); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/from_image_calculator.cc b/mediapipe/calculators/util/from_image_calculator.cc index 7484d9257..0ddb342eb 100644 --- a/mediapipe/calculators/util/from_image_calculator.cc +++ b/mediapipe/calculators/util/from_image_calculator.cc @@ -33,6 +33,7 @@ namespace { constexpr char kImageFrameTag[] = "IMAGE_CPU"; constexpr char kGpuBufferTag[] = "IMAGE_GPU"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kSourceOnGpuTag[] = "SOURCE_ON_GPU"; } // namespace // A calculator for converting the 
unified image container into @@ -46,6 +47,8 @@ constexpr char kImageTag[] = "IMAGE"; // IMAGE_CPU: An ImageFrame containing output image. // IMAGE_GPU: A GpuBuffer containing output image. // +// SOURCE_ON_GPU: The source Image is stored on GPU or CPU. +// // Note: // Data is automatically transferred to/from the CPU or GPU // depending on output type. @@ -66,6 +69,7 @@ class FromImageCalculator : public CalculatorBase { absl::Status RenderGpu(CalculatorContext* cc); absl::Status RenderCpu(CalculatorContext* cc); + bool check_image_source_ = false; bool gpu_output_ = false; bool gpu_initialized_ = false; #if !MEDIAPIPE_DISABLE_GPU @@ -102,6 +106,9 @@ absl::Status FromImageCalculator::GetContract(CalculatorContract* cc) { #endif // !MEDIAPIPE_DISABLE_GPU } + if (cc->Outputs().HasTag(kSourceOnGpuTag)) { + cc->Outputs().Tag(kSourceOnGpuTag).Set(); + } return absl::OkStatus(); } @@ -111,7 +118,9 @@ absl::Status FromImageCalculator::Open(CalculatorContext* cc) { if (cc->Outputs().HasTag(kGpuBufferTag)) { gpu_output_ = true; } - + if (cc->Outputs().HasTag(kSourceOnGpuTag)) { + check_image_source_ = true; + } if (gpu_output_) { #if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); @@ -122,6 +131,13 @@ absl::Status FromImageCalculator::Open(CalculatorContext* cc) { } absl::Status FromImageCalculator::Process(CalculatorContext* cc) { + if (check_image_source_) { + auto& input = cc->Inputs().Tag(kImageTag).Get(); + cc->Outputs() + .Tag(kSourceOnGpuTag) + .AddPacket(MakePacket(input.UsesGpu()).At(cc->InputTimestamp())); + } + if (gpu_output_) { #if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([&cc]() -> absl::Status { diff --git a/mediapipe/calculators/util/inverse_matrix_calculator.cc b/mediapipe/calculators/util/inverse_matrix_calculator.cc new file mode 100644 index 000000000..5809623c0 --- /dev/null +++ b/mediapipe/calculators/util/inverse_matrix_calculator.cc @@ -0,0 +1,50 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
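A minimal sketch of the core operation performed by the InverseMatrixCalculator added below (not part of this patch): invert a row-major 4x4 matrix with Eigen and detect singular input, which the calculator reports via RET_CHECK. The helper name InvertRowMajor4x4 is hypothetical.

#include <array>

#include "Eigen/Core"
#include "Eigen/LU"

// Hypothetical standalone helper mirroring the calculator's Process() body:
// returns false if the matrix is not invertible, otherwise writes the
// row-major inverse into `out`.
bool InvertRowMajor4x4(const std::array<float, 16>& in,
                       std::array<float, 16>& out) {
  const Eigen::Matrix<float, 4, 4, Eigen::RowMajor> matrix(in.data());
  Eigen::Matrix<float, 4, 4, Eigen::RowMajor> inverse;
  bool invertible = false;
  matrix.computeInverseWithCheck(inverse, invertible);
  if (!invertible) return false;
  Eigen::Map<Eigen::Matrix<float, 4, 4, Eigen::RowMajor>>(out.data()) = inverse;
  return true;
}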
+ +#include "mediapipe/calculators/util/inverse_matrix_calculator.h" + +#include "Eigen/Core" +#include "Eigen/Geometry" +#include "Eigen/LU" +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" + +namespace mediapipe { +namespace api2 { + +class InverseMatrixCalculatorImpl : public NodeImpl { + absl::Status Process(mediapipe::CalculatorContext* cc) override { + if (kInputMatrix(cc).IsEmpty()) { + return absl::OkStatus(); + } + Eigen::Matrix matrix( + kInputMatrix(cc).Get().data()); + + Eigen::Matrix inverse_matrix; + bool inverse_check; + matrix.computeInverseWithCheck(inverse_matrix, inverse_check); + RET_CHECK(inverse_check) << "Inverse matrix cannot be calculated."; + + std::array output; + Eigen::Map>( + output.data(), 4, 4) = inverse_matrix.matrix(); + kOutputMatrix(cc).Send(std::move(output)); + return absl::OkStatus(); + } +}; +MEDIAPIPE_NODE_IMPLEMENTATION(InverseMatrixCalculatorImpl); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/util/inverse_matrix_calculator.h b/mediapipe/calculators/util/inverse_matrix_calculator.h new file mode 100644 index 000000000..ba1657348 --- /dev/null +++ b/mediapipe/calculators/util/inverse_matrix_calculator.h @@ -0,0 +1,51 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_UTIL_INVERSE_MATRIX_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_UTIL_INVERSE_MATRIX_CALCULATOR_H_ + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" + +namespace mediapipe { + +// Runs affine transformation. +// +// Input: +// MATRIX - std::array +// Row major 4x4 matrix to inverse. +// +// Output: +// MATRIX - std::array +// Row major 4x4 inversed matrix. 
+// +// Usage example: +// node { +// calculator: "dishti.aimatter.InverseMatrixCalculator" +// input_stream: "MATRIX:input_matrix" +// output_stream: "MATRIX:output_matrix" +// } +class InverseMatrixCalculator : public mediapipe::api2::NodeIntf { + public: + static constexpr mediapipe::api2::Input> kInputMatrix{ + "MATRIX"}; + static constexpr mediapipe::api2::Output> kOutputMatrix{ + "MATRIX"}; + MEDIAPIPE_NODE_INTERFACE(InverseMatrixCalculator, kInputMatrix, + kOutputMatrix); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_UTIL_INVERSE_MATRIX_CALCULATOR_H_ diff --git a/mediapipe/calculators/util/inverse_matrix_calculator_test.cc b/mediapipe/calculators/util/inverse_matrix_calculator_test.cc new file mode 100644 index 000000000..d3b629c78 --- /dev/null +++ b/mediapipe/calculators/util/inverse_matrix_calculator_test.cc @@ -0,0 +1,126 @@ +#include "mediapipe/calculators/util/inverse_matrix_calculator.h" + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +void RunTest(const std::array& matrix, + const std::array& expected_inverse_matrix) { + auto graph_config = mediapipe::ParseTextProtoOrDie( + R"pb( + input_stream: "matrix" + node { + calculator: "InverseMatrixCalculator" + input_stream: "MATRIX:matrix" + output_stream: "MATRIX:inverse_matrix" + } + )pb"); + + std::vector output_packets; + tool::AddVectorSink("inverse_matrix", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "matrix", + MakePacket>(std::move(matrix)).At(Timestamp(0)))); + + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_THAT(output_packets, testing::SizeIs(1)); + + const auto& inverse_matrix = output_packets[0].Get>(); + + EXPECT_THAT(inverse_matrix, testing::Eq(expected_inverse_matrix)); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). 
+ MP_ASSERT_OK(graph.CloseInputStream("matrix")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST(InverseMatrixCalculatorTest, Identity) { + // clang-format off + std::array matrix = { + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + std::array expected_inverse_matrix = { + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + RunTest(matrix, expected_inverse_matrix); +} + +TEST(InverseMatrixCalculatorTest, Translation) { + // clang-format off + std::array matrix = { + 1.0f, 0.0f, 0.0f, 2.0f, + 0.0f, 1.0f, 0.0f, -5.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + std::array expected_inverse_matrix = { + 1.0f, 0.0f, 0.0f, -2.0f, + 0.0f, 1.0f, 0.0f, 5.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + RunTest(matrix, expected_inverse_matrix); +} + +TEST(InverseMatrixCalculatorTest, Scale) { + // clang-format off + std::array matrix = { + 5.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + std::array expected_inverse_matrix = { + 0.2f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.5f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + RunTest(matrix, expected_inverse_matrix); +} + +TEST(InverseMatrixCalculatorTest, Rotation90) { + // clang-format off + std::array matrix = { + 0.0f, -1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + std::array expected_inverse_matrix = { + 0.0f, 1.0f, 0.0f, 0.0f, + -1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + RunTest(matrix, expected_inverse_matrix); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index 099bdc7e6..4aab3b676 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -32,6 +32,12 @@ namespace mediapipe { +constexpr char kRenderDataTag[] = "RENDER_DATA"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kScoresTag[] = "SCORES"; +constexpr char kLabelsTag[] = "LABELS"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; + constexpr float kFontHeightScale = 1.25f; // A calculator takes in pairs of labels and scores or classifications, outputs @@ -74,20 +80,20 @@ class LabelsToRenderDataCalculator : public CalculatorBase { REGISTER_CALCULATOR(LabelsToRenderDataCalculator); absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag("CLASSIFICATIONS")) { - cc->Inputs().Tag("CLASSIFICATIONS").Set(); + if (cc->Inputs().HasTag(kClassificationsTag)) { + cc->Inputs().Tag(kClassificationsTag).Set(); } else { - RET_CHECK(cc->Inputs().HasTag("LABELS")) + RET_CHECK(cc->Inputs().HasTag(kLabelsTag)) << "Must provide input stream \"LABELS\""; - cc->Inputs().Tag("LABELS").Set>(); - if (cc->Inputs().HasTag("SCORES")) { - cc->Inputs().Tag("SCORES").Set>(); + cc->Inputs().Tag(kLabelsTag).Set>(); + if (cc->Inputs().HasTag(kScoresTag)) { + cc->Inputs().Tag(kScoresTag).Set>(); } } - if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { - cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); + if (cc->Inputs().HasTag(kVideoPrestreamTag)) { + cc->Inputs().Tag(kVideoPrestreamTag).Set(); } - 
cc->Outputs().Tag("RENDER_DATA").Set(); + cc->Outputs().Tag(kRenderDataTag).Set(); return absl::OkStatus(); } @@ -100,10 +106,10 @@ absl::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) { } absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { - if (cc->Inputs().HasTag("VIDEO_PRESTREAM") && + if (cc->Inputs().HasTag(kVideoPrestreamTag) && cc->InputTimestamp() == Timestamp::PreStream()) { const VideoHeader& video_header = - cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); + cc->Inputs().Tag(kVideoPrestreamTag).Get(); video_width_ = video_header.width; video_height_ = video_header.height; return absl::OkStatus(); @@ -114,9 +120,9 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { std::vector labels; std::vector scores; - if (cc->Inputs().HasTag("CLASSIFICATIONS")) { + if (cc->Inputs().HasTag(kClassificationsTag)) { const ClassificationList& classifications = - cc->Inputs().Tag("CLASSIFICATIONS").Get(); + cc->Inputs().Tag(kClassificationsTag).Get(); labels.resize(classifications.classification_size()); scores.resize(classifications.classification_size()); for (int i = 0; i < classifications.classification_size(); ++i) { @@ -129,15 +135,15 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { } } else { const std::vector& label_vector = - cc->Inputs().Tag("LABELS").Get>(); + cc->Inputs().Tag(kLabelsTag).Get>(); labels.resize(label_vector.size()); for (int i = 0; i < label_vector.size(); ++i) { labels[i] = label_vector[i]; } - if (cc->Inputs().HasTag("SCORES")) { + if (cc->Inputs().HasTag(kScoresTag)) { std::vector score_vector = - cc->Inputs().Tag("SCORES").Get>(); + cc->Inputs().Tag(kScoresTag).Get>(); CHECK_EQ(label_vector.size(), score_vector.size()); scores.resize(label_vector.size()); for (int i = 0; i < label_vector.size(); ++i) { @@ -169,7 +175,8 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { auto* text = label_annotation->mutable_text(); std::string display_text = labels[i]; - if (cc->Inputs().HasTag("SCORES")) { + if (cc->Inputs().HasTag(kScoresTag) || + options_.display_classification_score()) { absl::StrAppend(&display_text, ":", scores[i]); } text->set_display_text(display_text); @@ -179,7 +186,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { text->set_font_face(options_.font_face()); } cc->Outputs() - .Tag("RENDER_DATA") + .Tag(kRenderDataTag) .AddPacket(MakePacket(render_data).At(cc->InputTimestamp())); return absl::OkStatus(); diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.proto b/mediapipe/calculators/util/labels_to_render_data_calculator.proto index c5012ce85..cf0ada9c2 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.proto @@ -62,4 +62,7 @@ message LabelsToRenderDataCalculatorOptions { // Uses Classification.display_name field instead of Classification.label. optional bool use_display_name = 9 [default = false]; + + // Displays Classification score if enabled. 
+ optional bool display_classification_score = 10 [default = false]; } diff --git a/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc b/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc index 556d5673d..05827220e 100644 --- a/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc +++ b/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc @@ -24,6 +24,9 @@ namespace mediapipe { +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; +constexpr char kLandmarksTag[] = "LANDMARKS"; + NormalizedLandmark CreateLandmark(float x, float y) { NormalizedLandmark landmark; landmark.set_x(x); @@ -48,18 +51,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingLeftRight) { *landmarks->add_landmark() = CreateLandmark(0.2f, 0.2f); *landmarks->add_landmark() = CreateLandmark(0.7f, 0.7f); runner.MutableInputs() - ->Tag("LANDMARKS") + ->Tag(kLandmarksTag) .packets.push_back( Adopt(landmarks.release()).At(Timestamp::PostStream())); auto padding = absl::make_unique>( std::array{0.2f, 0.f, 0.3f, 0.f}); runner.MutableInputs() - ->Tag("LETTERBOX_PADDING") + ->Tag(kLetterboxPaddingTag) .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("LANDMARKS").packets; + const std::vector& output = + runner.Outputs().Tag(kLandmarksTag).packets; ASSERT_EQ(1, output.size()); const auto& output_landmarks = output[0].Get(); @@ -84,18 +88,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingTopBottom) { landmark = landmarks->add_landmark(); *landmark = CreateLandmark(0.7f, 0.7f); runner.MutableInputs() - ->Tag("LANDMARKS") + ->Tag(kLandmarksTag) .packets.push_back( Adopt(landmarks.release()).At(Timestamp::PostStream())); auto padding = absl::make_unique>( std::array{0.0f, 0.2f, 0.0f, 0.3f}); runner.MutableInputs() - ->Tag("LETTERBOX_PADDING") + ->Tag(kLetterboxPaddingTag) .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("LANDMARKS").packets; + const std::vector& output = + runner.Outputs().Tag(kLandmarksTag).packets; ASSERT_EQ(1, output.size()); const auto& output_landmarks = output[0].Get(); diff --git a/mediapipe/calculators/util/landmark_projection_calculator_test.cc b/mediapipe/calculators/util/landmark_projection_calculator_test.cc index b15bb0f0c..2e919c30e 100644 --- a/mediapipe/calculators/util/landmark_projection_calculator_test.cc +++ b/mediapipe/calculators/util/landmark_projection_calculator_test.cc @@ -16,6 +16,10 @@ namespace mediapipe { namespace { +constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; + absl::StatusOr RunCalculator( mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) { mediapipe::CalculatorRunner runner( @@ -26,17 +30,17 @@ absl::StatusOr RunCalculator( output_stream: "NORM_LANDMARKS:projected_landmarks" )pb")); runner.MutableInputs() - ->Tag("NORM_LANDMARKS") + ->Tag(kNormLandmarksTag) .packets.push_back( MakePacket(std::move(input)) .At(Timestamp(1))); runner.MutableInputs() - ->Tag("NORM_RECT") + ->Tag(kNormRectTag) .packets.push_back(MakePacket(std::move(rect)) .At(Timestamp(1))); MP_RETURN_IF_ERROR(runner.Run()); - const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets; + const 
auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets; RET_CHECK_EQ(output_packets.size(), 1); return output_packets[0].Get(); } @@ -104,17 +108,17 @@ absl::StatusOr RunCalculator( output_stream: "NORM_LANDMARKS:projected_landmarks" )pb")); runner.MutableInputs() - ->Tag("NORM_LANDMARKS") + ->Tag(kNormLandmarksTag) .packets.push_back( MakePacket(std::move(input)) .At(Timestamp(1))); runner.MutableInputs() - ->Tag("PROJECTION_MATRIX") + ->Tag(kProjectionMatrixTag) .packets.push_back(MakePacket>(std::move(matrix)) .At(Timestamp(1))); MP_RETURN_IF_ERROR(runner.Run()); - const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets; + const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets; RET_CHECK_EQ(output_packets.size(), 1); return output_packets[0].Get(); } diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index fb2310610..6673816e7 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -18,6 +18,7 @@ #include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/util/filtering/one_euro_filter.h" @@ -30,6 +31,7 @@ namespace { constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; @@ -94,6 +96,18 @@ float GetObjectScale(const LandmarkList& landmarks) { return (object_width + object_height) / 2.0f; } +float GetObjectScale(const NormalizedRect& roi, const int image_width, + const int image_height) { + const float object_width = roi.width() * image_width; + const float object_height = roi.height() * image_height; + + return (object_width + object_height) / 2.0f; +} + +float GetObjectScale(const Rect& roi) { + return (roi.width() + roi.height()) / 2.0f; +} + // Abstract class for various landmarks filters. class LandmarksFilter { public: @@ -103,6 +117,7 @@ class LandmarksFilter { virtual absl::Status Apply(const LandmarkList& in_landmarks, const absl::Duration& timestamp, + const absl::optional object_scale_opt, LandmarkList* out_landmarks) = 0; }; @@ -111,6 +126,7 @@ class NoFilter : public LandmarksFilter { public: absl::Status Apply(const LandmarkList& in_landmarks, const absl::Duration& timestamp, + const absl::optional object_scale_opt, LandmarkList* out_landmarks) override { *out_landmarks = in_landmarks; return absl::OkStatus(); @@ -136,13 +152,15 @@ class VelocityFilter : public LandmarksFilter { absl::Status Apply(const LandmarkList& in_landmarks, const absl::Duration& timestamp, + const absl::optional object_scale_opt, LandmarkList* out_landmarks) override { // Get value scale as inverse value of the object scale. // If value is too small smoothing will be disabled and landmarks will be // returned as is. float value_scale = 1.0f; if (!disable_value_scaling_) { - const float object_scale = GetObjectScale(in_landmarks); + const float object_scale = + object_scale_opt ? 
*object_scale_opt : GetObjectScale(in_landmarks); if (object_scale < min_allowed_object_scale_) { *out_landmarks = in_landmarks; return absl::OkStatus(); @@ -205,12 +223,14 @@ class VelocityFilter : public LandmarksFilter { class OneEuroFilterImpl : public LandmarksFilter { public: OneEuroFilterImpl(double frequency, double min_cutoff, double beta, - double derivate_cutoff, float min_allowed_object_scale) + double derivate_cutoff, float min_allowed_object_scale, + bool disable_value_scaling) : frequency_(frequency), min_cutoff_(min_cutoff), beta_(beta), derivate_cutoff_(derivate_cutoff), - min_allowed_object_scale_(min_allowed_object_scale) {} + min_allowed_object_scale_(min_allowed_object_scale), + disable_value_scaling_(disable_value_scaling) {} absl::Status Reset() override { x_filters_.clear(); @@ -221,16 +241,24 @@ class OneEuroFilterImpl : public LandmarksFilter { absl::Status Apply(const LandmarkList& in_landmarks, const absl::Duration& timestamp, + const absl::optional object_scale_opt, LandmarkList* out_landmarks) override { // Initialize filters once. MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size())); - const float object_scale = GetObjectScale(in_landmarks); - if (object_scale < min_allowed_object_scale_) { - *out_landmarks = in_landmarks; - return absl::OkStatus(); + // Get value scale as inverse value of the object scale. + // If value is too small smoothing will be disabled and landmarks will be + // returned as is. + float value_scale = 1.0f; + if (!disable_value_scaling_) { + const float object_scale = + object_scale_opt ? *object_scale_opt : GetObjectScale(in_landmarks); + if (object_scale < min_allowed_object_scale_) { + *out_landmarks = in_landmarks; + return absl::OkStatus(); + } + value_scale = 1.0f / object_scale; } - const float value_scale = 1.0f / object_scale; // Filter landmarks. Every axis of every landmark is filtered separately. for (int i = 0; i < in_landmarks.landmark_size(); ++i) { @@ -277,6 +305,7 @@ class OneEuroFilterImpl : public LandmarksFilter { double beta_; double derivate_cutoff_; double min_allowed_object_scale_; + bool disable_value_scaling_; std::vector x_filters_; std::vector y_filters_; @@ -292,6 +321,10 @@ class OneEuroFilterImpl : public LandmarksFilter { // IMAGE_SIZE: A std::pair represention of image width and height. // Required to perform all computations in absolute coordinates to avoid any // influence of normalized values. +// OBJECT_SCALE_ROI (optional): A NormRect or Rect (depending on the format of +// input landmarks) used to determine the object scale for some of the +// filters. If not provided - object scale will be calculated from +// landmarks. // // Outputs: // NORM_FILTERED_LANDMARKS: A NormalizedLandmarkList of smoothed landmarks. 
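A minimal sketch (not part of this patch) of how the new optional OBJECT_SCALE_ROI input above interacts with the smoothing filters: when the ROI is present, the object scale is taken from it instead of from the landmarks, and unless disable_value_scaling is set the filters use the inverse of that scale as the value scale. The helper names below are hypothetical; the arithmetic mirrors the GetObjectScale() overloads and the filter logic added in this change.

#include "mediapipe/framework/formats/rect.pb.h"

// Object scale from a normalized ROI converted to absolute pixels: the
// average of the ROI's width and height (same as the new
// GetObjectScale(roi, image_width, image_height) overload).
float ObjectScaleFromRoi(const mediapipe::NormalizedRect& roi,
                         int image_width, int image_height) {
  const float object_width = roi.width() * image_width;
  const float object_height = roi.height() * image_height;
  return (object_width + object_height) / 2.0f;
}

// Value scale used by VelocityFilter and OneEuroFilterImpl. When the object
// scale falls below min_allowed_object_scale the calculator skips filtering
// and returns the landmarks unchanged, so that case never reaches this point.
float ValueScale(float object_scale, bool disable_value_scaling) {
  return disable_value_scaling ? 1.0f : 1.0f / object_scale;
}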
@@ -301,6 +334,7 @@ class OneEuroFilterImpl : public LandmarksFilter { // calculator: "LandmarksSmoothingCalculator" // input_stream: "NORM_LANDMARKS:pose_landmarks" // input_stream: "IMAGE_SIZE:image_size" +// input_stream: "OBJECT_SCALE_ROI:roi" // output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_filtered" // options: { // [mediapipe.LandmarksSmoothingCalculatorOptions.ext] { @@ -330,9 +364,17 @@ absl::Status LandmarksSmoothingCalculator::GetContract(CalculatorContract* cc) { cc->Outputs() .Tag(kNormalizedFilteredLandmarksTag) .Set(); + + if (cc->Inputs().HasTag(kObjectScaleRoiTag)) { + cc->Inputs().Tag(kObjectScaleRoiTag).Set(); + } } else { cc->Inputs().Tag(kLandmarksTag).Set(); cc->Outputs().Tag(kFilteredLandmarksTag).Set(); + + if (cc->Inputs().HasTag(kObjectScaleRoiTag)) { + cc->Inputs().Tag(kObjectScaleRoiTag).Set(); + } } return absl::OkStatus(); @@ -357,7 +399,8 @@ absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { options.one_euro_filter().min_cutoff(), options.one_euro_filter().beta(), options.one_euro_filter().derivate_cutoff(), - options.one_euro_filter().min_allowed_object_scale()); + options.one_euro_filter().min_allowed_object_scale(), + options.one_euro_filter().disable_value_scaling()); } else { RET_CHECK_FAIL() << "Landmarks filter is either not specified or not supported"; @@ -389,13 +432,20 @@ absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { std::tie(image_width, image_height) = cc->Inputs().Tag(kImageSizeTag).Get>(); + absl::optional object_scale; + if (cc->Inputs().HasTag(kObjectScaleRoiTag) && + !cc->Inputs().Tag(kObjectScaleRoiTag).IsEmpty()) { + auto& roi = cc->Inputs().Tag(kObjectScaleRoiTag).Get(); + object_scale = GetObjectScale(roi, image_width, image_height); + } + auto in_landmarks = absl::make_unique(); NormalizedLandmarksToLandmarks(in_norm_landmarks, image_width, image_height, in_landmarks.get()); auto out_landmarks = absl::make_unique(); - MP_RETURN_IF_ERROR(landmarks_filter_->Apply(*in_landmarks, timestamp, - out_landmarks.get())); + MP_RETURN_IF_ERROR(landmarks_filter_->Apply( + *in_landmarks, timestamp, object_scale, out_landmarks.get())); auto out_norm_landmarks = absl::make_unique(); LandmarksToNormalizedLandmarks(*out_landmarks, image_width, image_height, @@ -408,9 +458,16 @@ absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { const auto& in_landmarks = cc->Inputs().Tag(kLandmarksTag).Get(); + absl::optional object_scale; + if (cc->Inputs().HasTag(kObjectScaleRoiTag) && + !cc->Inputs().Tag(kObjectScaleRoiTag).IsEmpty()) { + auto& roi = cc->Inputs().Tag(kObjectScaleRoiTag).Get(); + object_scale = GetObjectScale(roi); + } + auto out_landmarks = absl::make_unique(); - MP_RETURN_IF_ERROR( - landmarks_filter_->Apply(in_landmarks, timestamp, out_landmarks.get())); + MP_RETURN_IF_ERROR(landmarks_filter_->Apply( + in_landmarks, timestamp, object_scale, out_landmarks.get())); cc->Outputs() .Tag(kFilteredLandmarksTag) diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.proto b/mediapipe/calculators/util/landmarks_smoothing_calculator.proto index 7699287c9..017facb30 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.proto +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.proto @@ -41,9 +41,9 @@ message LandmarksSmoothingCalculatorOptions { optional float min_allowed_object_scale = 3 [default = 1e-6]; // Disable value scaling based on object size and use `1.0` instead. 
- // Value scale is calculated as inverse value of object size. Object size is - // calculated as maximum side of rectangular bounding box of the object in - // XY plane. + // If not disabled, value scale is calculated as inverse value of object + // size. Object size is calculated as maximum side of rectangular bounding + // box of the object in XY plane. optional bool disable_value_scaling = 4 [default = false]; } @@ -72,6 +72,12 @@ message LandmarksSmoothingCalculatorOptions { // If calculated object scale is less than given value smoothing will be // disabled and landmarks will be returned as is. optional float min_allowed_object_scale = 5 [default = 1e-6]; + + // Disable value scaling based on object size and use `1.0` instead. + // If not disabled, value scale is calculated as inverse value of object + // size. Object size is calculated as maximum side of rectangular bounding + // box of the object in XY plane. + optional bool disable_value_scaling = 6 [default = false]; } oneof filter_options { diff --git a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc index fcba83a49..9cd460114 100644 --- a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc @@ -20,6 +20,11 @@ #include "mediapipe/framework/port/status.h" namespace mediapipe { + +constexpr char kContentsTag[] = "CONTENTS"; +constexpr char kFileSuffixTag[] = "FILE_SUFFIX"; +constexpr char kFileDirectoryTag[] = "FILE_DIRECTORY"; + // The calculator takes the path to local directory and desired file suffix to // mach as input side packets, and outputs the contents of those files that // match the pattern. Those matched files will be sent sequentially through the @@ -35,16 +40,16 @@ namespace mediapipe { class LocalFilePatternContentsCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->InputSidePackets().Tag("FILE_DIRECTORY").Set(); - cc->InputSidePackets().Tag("FILE_SUFFIX").Set(); - cc->Outputs().Tag("CONTENTS").Set(); + cc->InputSidePackets().Tag(kFileDirectoryTag).Set(); + cc->InputSidePackets().Tag(kFileSuffixTag).Set(); + cc->Outputs().Tag(kContentsTag).Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory( - cc->InputSidePackets().Tag("FILE_DIRECTORY").Get(), - cc->InputSidePackets().Tag("FILE_SUFFIX").Get(), + cc->InputSidePackets().Tag(kFileDirectoryTag).Get(), + cc->InputSidePackets().Tag(kFileSuffixTag).Get(), &filenames_)); return absl::OkStatus(); } @@ -57,7 +62,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase { filenames_[current_output_], contents.get())); ++current_output_; cc->Outputs() - .Tag("CONTENTS") + .Tag(kContentsTag) .Add(contents.release(), Timestamp(current_output_)); } else { return tool::StatusStop(); diff --git a/mediapipe/calculators/util/packet_latency_calculator.cc b/mediapipe/calculators/util/packet_latency_calculator.cc index 35e415505..0e5b2e885 100644 --- a/mediapipe/calculators/util/packet_latency_calculator.cc +++ b/mediapipe/calculators/util/packet_latency_calculator.cc @@ -217,7 +217,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { // Initialize the clock. 
if (cc->InputSidePackets().HasTag(kClockTag)) { clock_ = cc->InputSidePackets() - .Tag("CLOCK") + .Tag(kClockTag) .Get>(); } else { clock_ = std::shared_ptr<::mediapipe::Clock>( diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index 7c71dd5a1..e0a759bdb 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -130,8 +130,8 @@ absl::Status RectTransformationCalculator::Process(CalculatorContext* cc) { } cc->Outputs().Index(0).Add(output_rects.release(), cc->InputTimestamp()); } - if (cc->Inputs().HasTag(kNormRectTag) && - !cc->Inputs().Tag(kNormRectTag).IsEmpty()) { + if (HasTagValue(cc->Inputs(), kNormRectTag) && + HasTagValue(cc->Inputs(), kImageSizeTag)) { auto rect = cc->Inputs().Tag(kNormRectTag).Get(); const auto& image_size = cc->Inputs().Tag(kImageSizeTag).Get>(); @@ -139,8 +139,8 @@ absl::Status RectTransformationCalculator::Process(CalculatorContext* cc) { cc->Outputs().Index(0).AddPacket( MakePacket(rect).At(cc->InputTimestamp())); } - if (cc->Inputs().HasTag(kNormRectsTag) && - !cc->Inputs().Tag(kNormRectsTag).IsEmpty()) { + if (HasTagValue(cc->Inputs(), kNormRectsTag) && + HasTagValue(cc->Inputs(), kImageSizeTag)) { auto rects = cc->Inputs().Tag(kNormRectsTag).Get>(); const auto& image_size = diff --git a/mediapipe/calculators/util/thresholding_calculator.cc b/mediapipe/calculators/util/thresholding_calculator.cc index c86e6ca52..a89d8253f 100644 --- a/mediapipe/calculators/util/thresholding_calculator.cc +++ b/mediapipe/calculators/util/thresholding_calculator.cc @@ -17,6 +17,12 @@ namespace mediapipe { +constexpr char kThresholdTag[] = "THRESHOLD"; +constexpr char kRejectTag[] = "REJECT"; +constexpr char kAcceptTag[] = "ACCEPT"; +constexpr char kFlagTag[] = "FLAG"; +constexpr char kFloatTag[] = "FLOAT"; + // Applies a threshold on a stream of numeric values and outputs a flag and/or // accept/reject stream. The threshold can be specified by one of the following: // 1) Input stream. 
@@ -61,24 +67,24 @@ class ThresholdingCalculator : public CalculatorBase { REGISTER_CALCULATOR(ThresholdingCalculator); absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("FLOAT")); - cc->Inputs().Tag("FLOAT").Set(); + RET_CHECK(cc->Inputs().HasTag(kFloatTag)); + cc->Inputs().Tag(kFloatTag).Set(); - if (cc->Outputs().HasTag("FLAG")) { - cc->Outputs().Tag("FLAG").Set(); + if (cc->Outputs().HasTag(kFlagTag)) { + cc->Outputs().Tag(kFlagTag).Set(); } - if (cc->Outputs().HasTag("ACCEPT")) { - cc->Outputs().Tag("ACCEPT").Set(); + if (cc->Outputs().HasTag(kAcceptTag)) { + cc->Outputs().Tag(kAcceptTag).Set(); } - if (cc->Outputs().HasTag("REJECT")) { - cc->Outputs().Tag("REJECT").Set(); + if (cc->Outputs().HasTag(kRejectTag)) { + cc->Outputs().Tag(kRejectTag).Set(); } - if (cc->Inputs().HasTag("THRESHOLD")) { - cc->Inputs().Tag("THRESHOLD").Set(); + if (cc->Inputs().HasTag(kThresholdTag)) { + cc->Inputs().Tag(kThresholdTag).Set(); } - if (cc->InputSidePackets().HasTag("THRESHOLD")) { - cc->InputSidePackets().Tag("THRESHOLD").Set(); - RET_CHECK(!cc->Inputs().HasTag("THRESHOLD")) + if (cc->InputSidePackets().HasTag(kThresholdTag)) { + cc->InputSidePackets().Tag(kThresholdTag).Set(); + RET_CHECK(!cc->Inputs().HasTag(kThresholdTag)) << "Using both the threshold input side packet and input stream is not " "supported."; } @@ -92,43 +98,45 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) { const auto& options = cc->Options<::mediapipe::ThresholdingCalculatorOptions>(); if (options.has_threshold()) { - RET_CHECK(!cc->Inputs().HasTag("THRESHOLD")) + RET_CHECK(!cc->Inputs().HasTag(kThresholdTag)) << "Using both the threshold option and input stream is not supported."; - RET_CHECK(!cc->InputSidePackets().HasTag("THRESHOLD")) + RET_CHECK(!cc->InputSidePackets().HasTag(kThresholdTag)) << "Using both the threshold option and input side packet is not " "supported."; threshold_ = options.threshold(); } - if (cc->InputSidePackets().HasTag("THRESHOLD")) { - threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get(); + if (cc->InputSidePackets().HasTag(kThresholdTag)) { + threshold_ = cc->InputSidePackets().Tag(kThresholdTag).Get(); } return absl::OkStatus(); } absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) { - if (cc->Inputs().HasTag("THRESHOLD") && - !cc->Inputs().Tag("THRESHOLD").IsEmpty()) { - threshold_ = cc->Inputs().Tag("THRESHOLD").Get(); + if (cc->Inputs().HasTag(kThresholdTag) && + !cc->Inputs().Tag(kThresholdTag).IsEmpty()) { + threshold_ = cc->Inputs().Tag(kThresholdTag).Get(); } bool accept = false; - RET_CHECK(!cc->Inputs().Tag("FLOAT").IsEmpty()); - accept = - static_cast(cc->Inputs().Tag("FLOAT").Get()) > threshold_; + RET_CHECK(!cc->Inputs().Tag(kFloatTag).IsEmpty()); + accept = static_cast(cc->Inputs().Tag(kFloatTag).Get()) > + threshold_; - if (cc->Outputs().HasTag("FLAG")) { - cc->Outputs().Tag("FLAG").AddPacket( + if (cc->Outputs().HasTag(kFlagTag)) { + cc->Outputs().Tag(kFlagTag).AddPacket( MakePacket(accept).At(cc->InputTimestamp())); } - if (accept && cc->Outputs().HasTag("ACCEPT")) { - cc->Outputs().Tag("ACCEPT").AddPacket( - MakePacket(true).At(cc->InputTimestamp())); + if (accept && cc->Outputs().HasTag(kAcceptTag)) { + cc->Outputs() + .Tag(kAcceptTag) + .AddPacket(MakePacket(true).At(cc->InputTimestamp())); } - if (!accept && cc->Outputs().HasTag("REJECT")) { - cc->Outputs().Tag("REJECT").AddPacket( - MakePacket(false).At(cc->InputTimestamp())); + if (!accept && 
cc->Outputs().HasTag(kRejectTag)) { + cc->Outputs() + .Tag(kRejectTag) + .AddPacket(MakePacket(false).At(cc->InputTimestamp())); } return absl::OkStatus(); diff --git a/mediapipe/calculators/util/top_k_scores_calculator.cc b/mediapipe/calculators/util/top_k_scores_calculator.cc index 37d1b2ab2..42ec5715e 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator.cc @@ -39,6 +39,14 @@ namespace mediapipe { +constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION"; +constexpr char kSummaryTag[] = "SUMMARY"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kTopKLabelsTag[] = "TOP_K_LABELS"; +constexpr char kTopKScoresTag[] = "TOP_K_SCORES"; +constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES"; +constexpr char kScoresTag[] = "SCORES"; + // A calculator that takes a vector of scores and returns the indexes, scores, // labels of the top k elements, classification protos, and summary std::string // (in csv format). @@ -79,22 +87,22 @@ class TopKScoresCalculator : public CalculatorBase { REGISTER_CALCULATOR(TopKScoresCalculator); absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("SCORES")); - cc->Inputs().Tag("SCORES").Set>(); - if (cc->Outputs().HasTag("TOP_K_INDEXES")) { - cc->Outputs().Tag("TOP_K_INDEXES").Set>(); + RET_CHECK(cc->Inputs().HasTag(kScoresTag)); + cc->Inputs().Tag(kScoresTag).Set>(); + if (cc->Outputs().HasTag(kTopKIndexesTag)) { + cc->Outputs().Tag(kTopKIndexesTag).Set>(); } - if (cc->Outputs().HasTag("TOP_K_SCORES")) { - cc->Outputs().Tag("TOP_K_SCORES").Set>(); + if (cc->Outputs().HasTag(kTopKScoresTag)) { + cc->Outputs().Tag(kTopKScoresTag).Set>(); } - if (cc->Outputs().HasTag("TOP_K_LABELS")) { - cc->Outputs().Tag("TOP_K_LABELS").Set>(); + if (cc->Outputs().HasTag(kTopKLabelsTag)) { + cc->Outputs().Tag(kTopKLabelsTag).Set>(); } - if (cc->Outputs().HasTag("CLASSIFICATIONS")) { - cc->Outputs().Tag("CLASSIFICATIONS").Set(); + if (cc->Outputs().HasTag(kClassificationsTag)) { + cc->Outputs().Tag(kClassificationsTag).Set(); } - if (cc->Outputs().HasTag("SUMMARY")) { - cc->Outputs().Tag("SUMMARY").Set(); + if (cc->Outputs().HasTag(kSummaryTag)) { + cc->Outputs().Tag(kSummaryTag).Set(); } return absl::OkStatus(); } @@ -114,7 +122,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) { if (options.has_label_map_path()) { MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path())); } - if (cc->Outputs().HasTag("TOP_K_LABELS")) { + if (cc->Outputs().HasTag(kTopKLabelsTag)) { RET_CHECK(!label_map_.empty()); } return absl::OkStatus(); @@ -122,7 +130,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) { absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { const std::vector& input_vector = - cc->Inputs().Tag("SCORES").Get>(); + cc->Inputs().Tag(kScoresTag).Get>(); std::vector top_k_indexes; std::vector top_k_scores; @@ -166,26 +174,26 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { top_k_labels.push_back(label_map_[index]); } } - if (cc->Outputs().HasTag("TOP_K_INDEXES")) { + if (cc->Outputs().HasTag(kTopKIndexesTag)) { cc->Outputs() - .Tag("TOP_K_INDEXES") + .Tag(kTopKIndexesTag) .AddPacket(MakePacket>(top_k_indexes) .At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("TOP_K_SCORES")) { + if (cc->Outputs().HasTag(kTopKScoresTag)) { cc->Outputs() - .Tag("TOP_K_SCORES") + .Tag(kTopKScoresTag) .AddPacket(MakePacket>(top_k_scores) .At(cc->InputTimestamp())); } - if 
(cc->Outputs().HasTag("TOP_K_LABELS")) { + if (cc->Outputs().HasTag(kTopKLabelsTag)) { cc->Outputs() - .Tag("TOP_K_LABELS") + .Tag(kTopKLabelsTag) .AddPacket(MakePacket>(top_k_labels) .At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("SUMMARY")) { + if (cc->Outputs().HasTag(kSummaryTag)) { std::vector results; for (int index = 0; index < top_k_indexes.size(); ++index) { if (label_map_loaded_) { @@ -196,12 +204,13 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { absl::StrCat(top_k_indexes[index], ":", top_k_scores[index])); } } - cc->Outputs().Tag("SUMMARY").AddPacket( - MakePacket(absl::StrJoin(results, ",")) - .At(cc->InputTimestamp())); + cc->Outputs() + .Tag(kSummaryTag) + .AddPacket(MakePacket(absl::StrJoin(results, ",")) + .At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) { + if (cc->Outputs().HasTag(kTopKClassificationTag)) { auto classification_list = absl::make_unique(); for (int index = 0; index < top_k_indexes.size(); ++index) { Classification* classification = diff --git a/mediapipe/calculators/util/top_k_scores_calculator_test.cc b/mediapipe/calculators/util/top_k_scores_calculator_test.cc index 6e6a2ebad..e5a17af28 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator_test.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator_test.cc @@ -23,6 +23,10 @@ namespace mediapipe { +constexpr char kTopKScoresTag[] = "TOP_K_SCORES"; +constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES"; +constexpr char kScoresTag[] = "SCORES"; + TEST(TopKScoresCalculatorTest, TestNodeConfig) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TopKScoresCalculator" @@ -55,19 +59,21 @@ TEST(TopKScoresCalculatorTest, TestTopKOnly) { std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; - runner.MutableInputs()->Tag("SCORES").packets.push_back( - MakePacket>(score_vector).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kScoresTag) + .packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); const std::vector& indexes_outputs = - runner.Outputs().Tag("TOP_K_INDEXES").packets; + runner.Outputs().Tag(kTopKIndexesTag).packets; ASSERT_EQ(1, indexes_outputs.size()); const auto& indexes = indexes_outputs[0].Get>(); EXPECT_EQ(2, indexes.size()); EXPECT_EQ(3, indexes[0]); EXPECT_EQ(0, indexes[1]); const std::vector& scores_outputs = - runner.Outputs().Tag("TOP_K_SCORES").packets; + runner.Outputs().Tag(kTopKScoresTag).packets; ASSERT_EQ(1, scores_outputs.size()); const auto& scores = scores_outputs[0].Get>(); EXPECT_EQ(2, scores.size()); @@ -88,12 +94,14 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) { std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; - runner.MutableInputs()->Tag("SCORES").packets.push_back( - MakePacket>(score_vector).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kScoresTag) + .packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); const std::vector& indexes_outputs = - runner.Outputs().Tag("TOP_K_INDEXES").packets; + runner.Outputs().Tag(kTopKIndexesTag).packets; ASSERT_EQ(1, indexes_outputs.size()); const auto& indexes = indexes_outputs[0].Get>(); EXPECT_EQ(4, indexes.size()); @@ -102,7 +110,7 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) { EXPECT_EQ(2, indexes[2]); EXPECT_EQ(1, indexes[3]); const std::vector& scores_outputs = - runner.Outputs().Tag("TOP_K_SCORES").packets; + runner.Outputs().Tag(kTopKScoresTag).packets; ASSERT_EQ(1, scores_outputs.size()); const auto& scores = scores_outputs[0].Get>(); 
EXPECT_EQ(4, scores.size()); @@ -125,12 +133,14 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) { std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; - runner.MutableInputs()->Tag("SCORES").packets.push_back( - MakePacket>(score_vector).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kScoresTag) + .packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); const std::vector& indexes_outputs = - runner.Outputs().Tag("TOP_K_INDEXES").packets; + runner.Outputs().Tag(kTopKIndexesTag).packets; ASSERT_EQ(1, indexes_outputs.size()); const auto& indexes = indexes_outputs[0].Get>(); EXPECT_EQ(3, indexes.size()); @@ -138,7 +148,7 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) { EXPECT_EQ(0, indexes[1]); EXPECT_EQ(2, indexes[2]); const std::vector& scores_outputs = - runner.Outputs().Tag("TOP_K_SCORES").packets; + runner.Outputs().Tag(kTopKScoresTag).packets; ASSERT_EQ(1, scores_outputs.size()); const auto& scores = scores_outputs[0].Get>(); EXPECT_EQ(3, scores.size()); diff --git a/mediapipe/calculators/util/world_landmark_projection_calculator.cc b/mediapipe/calculators/util/world_landmark_projection_calculator.cc index 28cf9498d..bcd7352a2 100644 --- a/mediapipe/calculators/util/world_landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/world_landmark_projection_calculator.cc @@ -40,7 +40,7 @@ constexpr char kRectTag[] = "NORM_RECT"; // Input: // LANDMARKS: A LandmarkList representing world landmarks in the rectangle. // NORM_RECT: An NormalizedRect representing a normalized rectangle in image -// coordinates. +// coordinates. (Optional) // // Output: // LANDMARKS: A LandmarkList representing world landmarks projected (rotated @@ -59,7 +59,9 @@ class WorldLandmarkProjectionCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kLandmarksTag).Set(); - cc->Inputs().Tag(kRectTag).Set(); + if (cc->Inputs().HasTag(kRectTag)) { + cc->Inputs().Tag(kRectTag).Set(); + } cc->Outputs().Tag(kLandmarksTag).Set(); return absl::OkStatus(); @@ -74,13 +76,24 @@ class WorldLandmarkProjectionCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) override { // Check that landmarks and rect are not empty. 
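Aside for reviewers (not part of the patch): with NORM_RECT now optional, WorldLandmarkProjectionCalculator only rotates landmarks when the rect input is actually wired in; otherwise the landmarks are copied through unchanged. The standalone sketch below is equivalent to the rotate_fn lambda introduced in the next hunk, assuming the Landmark proto from mediapipe/framework/formats/landmark.pb.h and a rect rotation given in radians:

    #include <cmath>

    #include "mediapipe/framework/formats/landmark.pb.h"

    // Hypothetical helper mirroring rotate_fn: rotate a world landmark in the
    // x/y plane by the NORM_RECT rotation; z is left untouched because the
    // calculator copies the whole input landmark before overwriting x and y.
    void RotateWorldLandmark(float rect_rotation_rad,
                             const mediapipe::Landmark& in,
                             mediapipe::Landmark* out) {
      const float cosa = std::cos(rect_rotation_rad);
      const float sina = std::sin(rect_rotation_rad);
      out->set_x(cosa * in.x() - sina * in.y());
      out->set_y(sina * in.x() + cosa * in.y());
    }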
if (cc->Inputs().Tag(kLandmarksTag).IsEmpty() || - cc->Inputs().Tag(kRectTag).IsEmpty()) { + (cc->Inputs().HasTag(kRectTag) && + cc->Inputs().Tag(kRectTag).IsEmpty())) { return absl::OkStatus(); } const auto& in_landmarks = cc->Inputs().Tag(kLandmarksTag).Get(); - const auto& in_rect = cc->Inputs().Tag(kRectTag).Get(); + std::function rotate_fn; + if (cc->Inputs().HasTag(kRectTag)) { + const auto& in_rect = cc->Inputs().Tag(kRectTag).Get(); + const float cosa = std::cos(in_rect.rotation()); + const float sina = std::sin(in_rect.rotation()); + rotate_fn = [cosa, sina](const Landmark& in_landmark, + Landmark* out_landmark) { + out_landmark->set_x(cosa * in_landmark.x() - sina * in_landmark.y()); + out_landmark->set_y(sina * in_landmark.x() + cosa * in_landmark.y()); + }; + } auto out_landmarks = absl::make_unique(); for (int i = 0; i < in_landmarks.landmark_size(); ++i) { @@ -89,11 +102,9 @@ class WorldLandmarkProjectionCalculator : public CalculatorBase { Landmark* out_landmark = out_landmarks->add_landmark(); *out_landmark = in_landmark; - const float angle = in_rect.rotation(); - out_landmark->set_x(std::cos(angle) * in_landmark.x() - - std::sin(angle) * in_landmark.y()); - out_landmark->set_y(std::sin(angle) * in_landmark.x() + - std::cos(angle) * in_landmark.y()); + if (rotate_fn) { + rotate_fn(in_landmark, out_landmark); + } } cc->Outputs() diff --git a/mediapipe/calculators/video/box_detector_calculator.cc b/mediapipe/calculators/video/box_detector_calculator.cc index b7b91d253..55b5c458b 100644 --- a/mediapipe/calculators/video/box_detector_calculator.cc +++ b/mediapipe/calculators/video/box_detector_calculator.cc @@ -47,6 +47,21 @@ namespace mediapipe { +constexpr char kFrameAlignmentTag[] = "FRAME_ALIGNMENT"; +constexpr char kOutputIndexFilenameTag[] = "OUTPUT_INDEX_FILENAME"; +constexpr char kIndexProtoStringTag[] = "INDEX_PROTO_STRING"; +constexpr char kVizTag[] = "VIZ"; +constexpr char kBoxesTag[] = "BOXES"; +constexpr char kReacqSwitchTag[] = "REACQ_SWITCH"; +constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID"; +constexpr char kAddIndexTag[] = "ADD_INDEX"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kDescriptorsTag[] = "DESCRIPTORS"; +constexpr char kFeaturesTag[] = "FEATURES"; +constexpr char kVideoTag[] = "VIDEO"; +constexpr char kTrackedBoxesTag[] = "TRACKED_BOXES"; +constexpr char kTrackingTag[] = "TRACKING"; + // A calculator to detect reappeared box positions from single frame. 
// // Input stream: @@ -110,66 +125,66 @@ class BoxDetectorCalculator : public CalculatorBase { REGISTER_CALCULATOR(BoxDetectorCalculator); absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag("TRACKING")) { - cc->Inputs().Tag("TRACKING").Set(); + if (cc->Inputs().HasTag(kTrackingTag)) { + cc->Inputs().Tag(kTrackingTag).Set(); } - if (cc->Inputs().HasTag("TRACKED_BOXES")) { - cc->Inputs().Tag("TRACKED_BOXES").Set(); + if (cc->Inputs().HasTag(kTrackedBoxesTag)) { + cc->Inputs().Tag(kTrackedBoxesTag).Set(); } - if (cc->Inputs().HasTag("VIDEO")) { - cc->Inputs().Tag("VIDEO").Set(); + if (cc->Inputs().HasTag(kVideoTag)) { + cc->Inputs().Tag(kVideoTag).Set(); } - if (cc->Inputs().HasTag("FEATURES")) { - RET_CHECK(cc->Inputs().HasTag("DESCRIPTORS")) + if (cc->Inputs().HasTag(kFeaturesTag)) { + RET_CHECK(cc->Inputs().HasTag(kDescriptorsTag)) << "FEATURES and DESCRIPTORS need to be specified together."; - cc->Inputs().Tag("FEATURES").Set>(); + cc->Inputs().Tag(kFeaturesTag).Set>(); } - if (cc->Inputs().HasTag("DESCRIPTORS")) { - RET_CHECK(cc->Inputs().HasTag("FEATURES")) + if (cc->Inputs().HasTag(kDescriptorsTag)) { + RET_CHECK(cc->Inputs().HasTag(kFeaturesTag)) << "FEATURES and DESCRIPTORS need to be specified together."; - cc->Inputs().Tag("DESCRIPTORS").Set>(); + cc->Inputs().Tag(kDescriptorsTag).Set>(); } - if (cc->Inputs().HasTag("IMAGE_SIZE")) { - cc->Inputs().Tag("IMAGE_SIZE").Set>(); + if (cc->Inputs().HasTag(kImageSizeTag)) { + cc->Inputs().Tag(kImageSizeTag).Set>(); } - if (cc->Inputs().HasTag("ADD_INDEX")) { - cc->Inputs().Tag("ADD_INDEX").Set(); + if (cc->Inputs().HasTag(kAddIndexTag)) { + cc->Inputs().Tag(kAddIndexTag).Set(); } - if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) { - cc->Inputs().Tag("CANCEL_OBJECT_ID").Set(); + if (cc->Inputs().HasTag(kCancelObjectIdTag)) { + cc->Inputs().Tag(kCancelObjectIdTag).Set(); } - if (cc->Inputs().HasTag("REACQ_SWITCH")) { - cc->Inputs().Tag("REACQ_SWITCH").Set(); + if (cc->Inputs().HasTag(kReacqSwitchTag)) { + cc->Inputs().Tag(kReacqSwitchTag).Set(); } - if (cc->Outputs().HasTag("BOXES")) { - cc->Outputs().Tag("BOXES").Set(); + if (cc->Outputs().HasTag(kBoxesTag)) { + cc->Outputs().Tag(kBoxesTag).Set(); } - if (cc->Outputs().HasTag("VIZ")) { - RET_CHECK(cc->Inputs().HasTag("VIDEO")) + if (cc->Outputs().HasTag(kVizTag)) { + RET_CHECK(cc->Inputs().HasTag(kVideoTag)) << "Output stream VIZ requires VIDEO to be present."; - cc->Outputs().Tag("VIZ").Set(); + cc->Outputs().Tag(kVizTag).Set(); } - if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) { - cc->InputSidePackets().Tag("INDEX_PROTO_STRING").Set(); + if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) { + cc->InputSidePackets().Tag(kIndexProtoStringTag).Set(); } - if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) { - cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Set(); + if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) { + cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Set(); } - if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) { - cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Set(); + if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) { + cc->InputSidePackets().Tag(kFrameAlignmentTag).Set(); } return absl::OkStatus(); @@ -179,10 +194,10 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); box_detector_ = BoxDetectorInterface::Create(options_.detector_options()); - if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) { + if 
(cc->InputSidePackets().HasTag(kIndexProtoStringTag)) { BoxDetectorIndex predefined_index; if (!predefined_index.ParseFromString(cc->InputSidePackets() - .Tag("INDEX_PROTO_STRING") + .Tag(kIndexProtoStringTag) .Get())) { LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING"; } @@ -202,12 +217,13 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { box_detector_->AddBoxDetectorIndex(predefined_index); } - if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) { + if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) { write_index_ = true; } - if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) { - frame_alignment_ = cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Get(); + if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) { + frame_alignment_ = + cc->InputSidePackets().Tag(kFrameAlignmentTag).Get(); } return absl::OkStatus(); @@ -218,16 +234,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { const int64 timestamp_msec = timestamp.Value() / 1000; InputStream* cancel_object_id_stream = - cc->Inputs().HasTag("CANCEL_OBJECT_ID") - ? &(cc->Inputs().Tag("CANCEL_OBJECT_ID")) + cc->Inputs().HasTag(kCancelObjectIdTag) + ? &(cc->Inputs().Tag(kCancelObjectIdTag)) : nullptr; if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) { const int cancel_object_id = cancel_object_id_stream->Get(); box_detector_->CancelBoxDetection(cancel_object_id); } - InputStream* add_index_stream = cc->Inputs().HasTag("ADD_INDEX") - ? &(cc->Inputs().Tag("ADD_INDEX")) + InputStream* add_index_stream = cc->Inputs().HasTag(kAddIndexTag) + ? &(cc->Inputs().Tag(kAddIndexTag)) : nullptr; if (add_index_stream && !add_index_stream->IsEmpty()) { BoxDetectorIndex predefined_index; @@ -238,8 +254,8 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { box_detector_->AddBoxDetectorIndex(predefined_index); } - InputStream* reacq_switch_stream = cc->Inputs().HasTag("REACQ_SWITCH") - ? &(cc->Inputs().Tag("REACQ_SWITCH")) + InputStream* reacq_switch_stream = cc->Inputs().HasTag(kReacqSwitchTag) + ? &(cc->Inputs().Tag(kReacqSwitchTag)) : nullptr; if (reacq_switch_stream && !reacq_switch_stream->IsEmpty()) { detector_switch_ = reacq_switch_stream->Get(); @@ -249,16 +265,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } - InputStream* track_stream = cc->Inputs().HasTag("TRACKING") - ? &(cc->Inputs().Tag("TRACKING")) + InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag) + ? &(cc->Inputs().Tag(kTrackingTag)) : nullptr; InputStream* video_stream = - cc->Inputs().HasTag("VIDEO") ? &(cc->Inputs().Tag("VIDEO")) : nullptr; - InputStream* feature_stream = cc->Inputs().HasTag("FEATURES") - ? &(cc->Inputs().Tag("FEATURES")) + cc->Inputs().HasTag(kVideoTag) ? &(cc->Inputs().Tag(kVideoTag)) : nullptr; + InputStream* feature_stream = cc->Inputs().HasTag(kFeaturesTag) + ? &(cc->Inputs().Tag(kFeaturesTag)) : nullptr; - InputStream* descriptor_stream = cc->Inputs().HasTag("DESCRIPTORS") - ? &(cc->Inputs().Tag("DESCRIPTORS")) + InputStream* descriptor_stream = cc->Inputs().HasTag(kDescriptorsTag) + ? &(cc->Inputs().Tag(kDescriptorsTag)) : nullptr; CHECK(track_stream != nullptr || video_stream != nullptr || @@ -266,9 +282,10 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { << "One and only one of {tracking_data, input image frame, " "feature/descriptor} need to be valid."; - InputStream* tracked_boxes_stream = cc->Inputs().HasTag("TRACKED_BOXES") - ? 
&(cc->Inputs().Tag("TRACKED_BOXES")) - : nullptr; + InputStream* tracked_boxes_stream = + cc->Inputs().HasTag(kTrackedBoxesTag) + ? &(cc->Inputs().Tag(kTrackedBoxesTag)) + : nullptr; std::unique_ptr detected_boxes(new TimedBoxProtoList()); if (track_stream != nullptr) { @@ -309,7 +326,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { } const auto& image_size = - cc->Inputs().Tag("IMAGE_SIZE").Get>(); + cc->Inputs().Tag(kImageSizeTag).Get>(); float inv_scale = 1.0f / std::max(image_size.first, image_size.second); TimedBoxProtoList tracked_boxes; @@ -359,7 +376,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { detected_boxes.get()); } - if (cc->Outputs().HasTag("VIZ")) { + if (cc->Outputs().HasTag(kVizTag)) { cv::Mat viz_view; std::unique_ptr viz_frame; if (video_stream != nullptr && !video_stream->IsEmpty()) { @@ -370,11 +387,11 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { for (const auto& box : detected_boxes->box()) { RenderBox(box, &viz_view); } - cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp); + cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp); } - if (cc->Outputs().HasTag("BOXES")) { - cc->Outputs().Tag("BOXES").Add(detected_boxes.release(), timestamp); + if (cc->Outputs().HasTag(kBoxesTag)) { + cc->Outputs().Tag(kBoxesTag).Add(detected_boxes.release(), timestamp); } return absl::OkStatus(); @@ -384,7 +401,7 @@ absl::Status BoxDetectorCalculator::Close(CalculatorContext* cc) { if (write_index_) { BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex(); MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents( - cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Get(), + cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Get(), index.SerializeAsString())); } return absl::OkStatus(); diff --git a/mediapipe/calculators/video/box_tracker_calculator.cc b/mediapipe/calculators/video/box_tracker_calculator.cc index 7d04d9765..d3acc322a 100644 --- a/mediapipe/calculators/video/box_tracker_calculator.cc +++ b/mediapipe/calculators/video/box_tracker_calculator.cc @@ -293,6 +293,22 @@ const int BoxTrackerCalculator::kMotionBoxPathMinQueueSize = 2; namespace { +constexpr char kCacheDirTag[] = "CACHE_DIR"; +constexpr char kInitialPosTag[] = "INITIAL_POS"; +constexpr char kRaBoxesTag[] = "RA_BOXES"; +constexpr char kBoxesTag[] = "BOXES"; +constexpr char kVizTag[] = "VIZ"; +constexpr char kRaTrackProtoStringTag[] = "RA_TRACK_PROTO_STRING"; +constexpr char kRaTrackTag[] = "RA_TRACK"; +constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID"; +constexpr char kRestartPosTag[] = "RESTART_POS"; +constexpr char kStartPosProtoStringTag[] = "START_POS_PROTO_STRING"; +constexpr char kStartPosTag[] = "START_POS"; +constexpr char kStartTag[] = "START"; +constexpr char kVideoTag[] = "VIDEO"; +constexpr char kTrackTimeTag[] = "TRACK_TIME"; +constexpr char kTrackingTag[] = "TRACKING"; + // Convert box position according to rotation angle in degrees. void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom, float in_right, int rotation, float* out_top, @@ -374,78 +390,78 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec, } // namespace. 
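The hunks below follow the same refactor as the rest of this PR: string-literal stream tags are hoisted into file-level constexpr char constants so that GetContract, Open, and Process all reference a single definition, while the tag strings seen by graph configs stay byte-for-byte identical. A minimal sketch of the pattern on a hypothetical calculator (illustrative only, assuming the usual calculator_framework.h include):

    #include "mediapipe/framework/calculator_framework.h"

    namespace mediapipe {

    // Tag constants replace repeated "INPUT"/"OUTPUT" literals; graph configs
    // still write the tags as INPUT:stream_name and OUTPUT:stream_name.
    constexpr char kInputTag[] = "INPUT";
    constexpr char kOutputTag[] = "OUTPUT";

    class FooCalculator : public CalculatorBase {
     public:
      static absl::Status GetContract(CalculatorContract* cc) {
        cc->Inputs().Tag(kInputTag).Set<int>();
        cc->Outputs().Tag(kOutputTag).Set<int>();
        return absl::OkStatus();
      }

      absl::Status Process(CalculatorContext* cc) override {
        const int value = cc->Inputs().Tag(kInputTag).Get<int>();
        cc->Outputs()
            .Tag(kOutputTag)
            .AddPacket(MakePacket<int>(value).At(cc->InputTimestamp()));
        return absl::OkStatus();
      }
    };
    REGISTER_CALCULATOR(FooCalculator);

    }  // namespace mediapipe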
absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag("TRACKING")) { - cc->Inputs().Tag("TRACKING").Set(); + if (cc->Inputs().HasTag(kTrackingTag)) { + cc->Inputs().Tag(kTrackingTag).Set(); } - if (cc->Inputs().HasTag("TRACK_TIME")) { - RET_CHECK(cc->Inputs().HasTag("TRACKING")) + if (cc->Inputs().HasTag(kTrackTimeTag)) { + RET_CHECK(cc->Inputs().HasTag(kTrackingTag)) << "TRACK_TIME needs TRACKING input"; - cc->Inputs().Tag("TRACK_TIME").SetAny(); + cc->Inputs().Tag(kTrackTimeTag).SetAny(); } - if (cc->Inputs().HasTag("VIDEO")) { - cc->Inputs().Tag("VIDEO").Set(); + if (cc->Inputs().HasTag(kVideoTag)) { + cc->Inputs().Tag(kVideoTag).Set(); } - if (cc->Inputs().HasTag("START")) { + if (cc->Inputs().HasTag(kStartTag)) { // Actual packet content does not matter. - cc->Inputs().Tag("START").SetAny(); + cc->Inputs().Tag(kStartTag).SetAny(); } - if (cc->Inputs().HasTag("START_POS")) { - cc->Inputs().Tag("START_POS").Set(); + if (cc->Inputs().HasTag(kStartPosTag)) { + cc->Inputs().Tag(kStartPosTag).Set(); } - if (cc->Inputs().HasTag("START_POS_PROTO_STRING")) { - cc->Inputs().Tag("START_POS_PROTO_STRING").Set(); + if (cc->Inputs().HasTag(kStartPosProtoStringTag)) { + cc->Inputs().Tag(kStartPosProtoStringTag).Set(); } - if (cc->Inputs().HasTag("RESTART_POS")) { - cc->Inputs().Tag("RESTART_POS").Set(); + if (cc->Inputs().HasTag(kRestartPosTag)) { + cc->Inputs().Tag(kRestartPosTag).Set(); } - if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) { - cc->Inputs().Tag("CANCEL_OBJECT_ID").Set(); + if (cc->Inputs().HasTag(kCancelObjectIdTag)) { + cc->Inputs().Tag(kCancelObjectIdTag).Set(); } - if (cc->Inputs().HasTag("RA_TRACK")) { - cc->Inputs().Tag("RA_TRACK").Set(); + if (cc->Inputs().HasTag(kRaTrackTag)) { + cc->Inputs().Tag(kRaTrackTag).Set(); } - if (cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")) { - cc->Inputs().Tag("RA_TRACK_PROTO_STRING").Set(); + if (cc->Inputs().HasTag(kRaTrackProtoStringTag)) { + cc->Inputs().Tag(kRaTrackProtoStringTag).Set(); } - if (cc->Outputs().HasTag("VIZ")) { - RET_CHECK(cc->Inputs().HasTag("VIDEO")) + if (cc->Outputs().HasTag(kVizTag)) { + RET_CHECK(cc->Inputs().HasTag(kVideoTag)) << "Output stream VIZ requires VIDEO to be present."; - cc->Outputs().Tag("VIZ").Set(); + cc->Outputs().Tag(kVizTag).Set(); } - if (cc->Outputs().HasTag("BOXES")) { - cc->Outputs().Tag("BOXES").Set(); + if (cc->Outputs().HasTag(kBoxesTag)) { + cc->Outputs().Tag(kBoxesTag).Set(); } - if (cc->Outputs().HasTag("RA_BOXES")) { - cc->Outputs().Tag("RA_BOXES").Set(); + if (cc->Outputs().HasTag(kRaBoxesTag)) { + cc->Outputs().Tag(kRaBoxesTag).Set(); } #if defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__) - RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS")) + RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag)) << "Unsupported on mobile"; #else - if (cc->InputSidePackets().HasTag("INITIAL_POS")) { - cc->InputSidePackets().Tag("INITIAL_POS").Set(); + if (cc->InputSidePackets().HasTag(kInitialPosTag)) { + cc->InputSidePackets().Tag(kInitialPosTag).Set(); } #endif // defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__) - if (cc->InputSidePackets().HasTag("CACHE_DIR")) { - cc->InputSidePackets().Tag("CACHE_DIR").Set(); + if (cc->InputSidePackets().HasTag(kCacheDirTag)) { + cc->InputSidePackets().Tag(kCacheDirTag).Set(); } - RET_CHECK(cc->Inputs().HasTag("TRACKING") != - cc->InputSidePackets().HasTag("CACHE_DIR")) + RET_CHECK(cc->Inputs().HasTag(kTrackingTag) != + cc->InputSidePackets().HasTag(kCacheDirTag)) << "Either 
TRACKING or CACHE_DIR needs to be specified."; if (cc->InputSidePackets().HasTag(kOptionsTag)) { @@ -459,7 +475,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { options_ = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag); - RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS") || + RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag) || !options_.has_initial_position()) << "Can not specify initial position as side packet and via options"; @@ -468,11 +484,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { } #if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__) - if (cc->InputSidePackets().HasTag("INITIAL_POS")) { + if (cc->InputSidePackets().HasTag(kInitialPosTag)) { LOG(INFO) << "Parsing: " - << cc->InputSidePackets().Tag("INITIAL_POS").Get(); + << cc->InputSidePackets().Tag(kInitialPosTag).Get(); initial_pos_ = ParseTextProtoOrDie( - cc->InputSidePackets().Tag("INITIAL_POS").Get()); + cc->InputSidePackets().Tag(kInitialPosTag).Get()); } #endif // !defined(__ANDROID__) && !defined(__APPLE__) && // !defined(__EMSCRIPTEN__) @@ -484,10 +500,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { } visualize_tracking_data_ = - options_.visualize_tracking_data() && cc->Outputs().HasTag("VIZ"); - visualize_state_ = options_.visualize_state() && cc->Outputs().HasTag("VIZ"); + options_.visualize_tracking_data() && cc->Outputs().HasTag(kVizTag); + visualize_state_ = + options_.visualize_state() && cc->Outputs().HasTag(kVizTag); visualize_internal_state_ = - options_.visualize_internal_state() && cc->Outputs().HasTag("VIZ"); + options_.visualize_internal_state() && cc->Outputs().HasTag(kVizTag); // Force recording of internal state for rendering. if (visualize_internal_state_) { @@ -500,8 +517,8 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { options_.mutable_tracker_options()->set_record_path_states(true); } - if (cc->InputSidePackets().HasTag("CACHE_DIR")) { - cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get(); + if (cc->InputSidePackets().HasTag(kCacheDirTag)) { + cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get(); RET_CHECK(!cache_dir_.empty()); box_tracker_.reset(new BoxTracker(cache_dir_, options_.tracker_options())); } else { @@ -511,7 +528,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { } if (options_.streaming_track_data_cache_size() > 0) { - RET_CHECK(!cc->InputSidePackets().HasTag("CACHE_DIR")) + RET_CHECK(!cc->InputSidePackets().HasTag(kCacheDirTag)) << "Streaming mode not compatible with cache dir."; } @@ -533,11 +550,11 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } - InputStream* track_stream = cc->Inputs().HasTag("TRACKING") - ? &(cc->Inputs().Tag("TRACKING")) + InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag) + ? &(cc->Inputs().Tag(kTrackingTag)) : nullptr; - InputStream* track_time_stream = cc->Inputs().HasTag("TRACK_TIME") - ? &(cc->Inputs().Tag("TRACK_TIME")) + InputStream* track_time_stream = cc->Inputs().HasTag(kTrackTimeTag) + ? &(cc->Inputs().Tag(kTrackTimeTag)) : nullptr; // Cache tracking data if possible. @@ -562,8 +579,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } } - InputStream* start_pos_stream = cc->Inputs().HasTag("START_POS") - ? &(cc->Inputs().Tag("START_POS")) + InputStream* start_pos_stream = cc->Inputs().HasTag(kStartPosTag) + ? 
&(cc->Inputs().Tag(kStartPosTag)) : nullptr; MotionBoxMap fast_forward_boxes; @@ -575,8 +592,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } InputStream* start_pos_proto_string_stream = - cc->Inputs().HasTag("START_POS_PROTO_STRING") - ? &(cc->Inputs().Tag("START_POS_PROTO_STRING")) + cc->Inputs().HasTag(kStartPosProtoStringTag) + ? &(cc->Inputs().Tag(kStartPosProtoStringTag)) : nullptr; if (start_pos_stream == nullptr || start_pos_stream->IsEmpty()) { if (start_pos_proto_string_stream && @@ -589,8 +606,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } } - InputStream* restart_pos_stream = cc->Inputs().HasTag("RESTART_POS") - ? &(cc->Inputs().Tag("RESTART_POS")) + InputStream* restart_pos_stream = cc->Inputs().HasTag(kRestartPosTag) + ? &(cc->Inputs().Tag(kRestartPosTag)) : nullptr; if (restart_pos_stream && !restart_pos_stream->IsEmpty()) { @@ -600,8 +617,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } InputStream* cancel_object_id_stream = - cc->Inputs().HasTag("CANCEL_OBJECT_ID") - ? &(cc->Inputs().Tag("CANCEL_OBJECT_ID")) + cc->Inputs().HasTag(kCancelObjectIdTag) + ? &(cc->Inputs().Tag(kCancelObjectIdTag)) : nullptr; if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) { const int cancel_object_id = cancel_object_id_stream->Get(); @@ -616,8 +633,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { TrackingData track_data_to_render; - if (cc->Outputs().HasTag("VIZ")) { - InputStream* video_stream = &(cc->Inputs().Tag("VIDEO")); + if (cc->Outputs().HasTag(kVizTag)) { + InputStream* video_stream = &(cc->Inputs().Tag(kVideoTag)); if (!video_stream->IsEmpty()) { input_view = formats::MatView(&video_stream->Get()); @@ -745,7 +762,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { ++frame_num_since_reset_; // Generate results for queued up request. - if (cc->Outputs().HasTag("BOXES") && !queued_track_requests_.empty()) { + if (cc->Outputs().HasTag(kBoxesTag) && !queued_track_requests_.empty()) { for (int j = 0; j < queued_track_requests_.size(); ++j) { const Timestamp& past_time = queued_track_requests_[j]; RET_CHECK(past_time.Value() < timestamp.Value()) @@ -770,7 +787,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } // Output for every time. - cc->Outputs().Tag("BOXES").Add(past_box_list.release(), past_time); + cc->Outputs().Tag(kBoxesTag).Add(past_box_list.release(), past_time); } queued_track_requests_.clear(); @@ -845,8 +862,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } // Handle random access track requests. - InputStream* ra_track_stream = cc->Inputs().HasTag("RA_TRACK") - ? &(cc->Inputs().Tag("RA_TRACK")) + InputStream* ra_track_stream = cc->Inputs().HasTag(kRaTrackTag) + ? &(cc->Inputs().Tag(kRaTrackTag)) : nullptr; if (ra_track_stream && !ra_track_stream->IsEmpty()) { @@ -861,8 +878,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } InputStream* ra_track_proto_string_stream = - cc->Inputs().HasTag("RA_TRACK_PROTO_STRING") - ? &(cc->Inputs().Tag("RA_TRACK_PROTO_STRING")) + cc->Inputs().HasTag(kRaTrackProtoStringTag) + ? 
&(cc->Inputs().Tag(kRaTrackProtoStringTag)) : nullptr; if (ra_track_stream == nullptr || ra_track_stream->IsEmpty()) { if (ra_track_proto_string_stream && @@ -881,15 +898,15 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { // Always output in batch, only output in streaming if tracking data // is present (might be in fast forward mode instead). - if (cc->Outputs().HasTag("BOXES") && + if (cc->Outputs().HasTag(kBoxesTag) && (box_tracker_ || !track_stream->IsEmpty())) { std::unique_ptr boxes(new TimedBoxProtoList()); *boxes = std::move(box_track_list); - cc->Outputs().Tag("BOXES").Add(boxes.release(), timestamp); + cc->Outputs().Tag(kBoxesTag).Add(boxes.release(), timestamp); } if (viz_frame) { - cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp); + cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp); } return absl::OkStatus(); @@ -1001,7 +1018,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack( } cc->Outputs() - .Tag("RA_BOXES") + .Tag(kRaBoxesTag) .Add(result_list.release(), cc->InputTimestamp()); } diff --git a/mediapipe/calculators/video/flow_packager_calculator.cc b/mediapipe/calculators/video/flow_packager_calculator.cc index a57105928..2965cd8e6 100644 --- a/mediapipe/calculators/video/flow_packager_calculator.cc +++ b/mediapipe/calculators/video/flow_packager_calculator.cc @@ -29,6 +29,13 @@ namespace mediapipe { +constexpr char kCacheDirTag[] = "CACHE_DIR"; +constexpr char kCompleteTag[] = "COMPLETE"; +constexpr char kTrackingChunkTag[] = "TRACKING_CHUNK"; +constexpr char kTrackingTag[] = "TRACKING"; +constexpr char kCameraTag[] = "CAMERA"; +constexpr char kFlowTag[] = "FLOW"; + using mediapipe::CameraMotion; using mediapipe::FlowPackager; using mediapipe::RegionFlowFeatureList; @@ -91,27 +98,27 @@ class FlowPackagerCalculator : public CalculatorBase { REGISTER_CALCULATOR(FlowPackagerCalculator); absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) { - if (!cc->Inputs().HasTag("FLOW")) { + if (!cc->Inputs().HasTag(kFlowTag)) { return tool::StatusFail("No input flow was specified."); } - cc->Inputs().Tag("FLOW").Set(); + cc->Inputs().Tag(kFlowTag).Set(); - if (cc->Inputs().HasTag("CAMERA")) { - cc->Inputs().Tag("CAMERA").Set(); + if (cc->Inputs().HasTag(kCameraTag)) { + cc->Inputs().Tag(kCameraTag).Set(); } - if (cc->Outputs().HasTag("TRACKING")) { - cc->Outputs().Tag("TRACKING").Set(); + if (cc->Outputs().HasTag(kTrackingTag)) { + cc->Outputs().Tag(kTrackingTag).Set(); } - if (cc->Outputs().HasTag("TRACKING_CHUNK")) { - cc->Outputs().Tag("TRACKING_CHUNK").Set(); + if (cc->Outputs().HasTag(kTrackingChunkTag)) { + cc->Outputs().Tag(kTrackingChunkTag).Set(); } - if (cc->Outputs().HasTag("COMPLETE")) { - cc->Outputs().Tag("COMPLETE").Set(); + if (cc->Outputs().HasTag(kCompleteTag)) { + cc->Outputs().Tag(kCompleteTag).Set(); } - if (cc->InputSidePackets().HasTag("CACHE_DIR")) { - cc->InputSidePackets().Tag("CACHE_DIR").Set(); + if (cc->InputSidePackets().HasTag(kCacheDirTag)) { + cc->InputSidePackets().Tag(kCacheDirTag).Set(); } return absl::OkStatus(); @@ -122,24 +129,24 @@ absl::Status FlowPackagerCalculator::Open(CalculatorContext* cc) { flow_packager_.reset(new FlowPackager(options_.flow_packager_options())); - use_caching_ = cc->InputSidePackets().HasTag("CACHE_DIR"); - build_chunk_ = use_caching_ || cc->Outputs().HasTag("TRACKING_CHUNK"); + use_caching_ = cc->InputSidePackets().HasTag(kCacheDirTag); + build_chunk_ = use_caching_ || cc->Outputs().HasTag(kTrackingChunkTag); if (use_caching_) { - cache_dir_ = 
cc->InputSidePackets().Tag("CACHE_DIR").Get(); + cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get(); } return absl::OkStatus(); } absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { - InputStream* flow_stream = &(cc->Inputs().Tag("FLOW")); + InputStream* flow_stream = &(cc->Inputs().Tag(kFlowTag)); const RegionFlowFeatureList& flow = flow_stream->Get(); const Timestamp timestamp = flow_stream->Value().Timestamp(); const CameraMotion* camera_motion = nullptr; - if (cc->Inputs().HasTag("CAMERA")) { - InputStream* camera_stream = &(cc->Inputs().Tag("CAMERA")); + if (cc->Inputs().HasTag(kCameraTag)) { + InputStream* camera_stream = &(cc->Inputs().Tag(kCameraTag)); camera_motion = &camera_stream->Get(); } @@ -161,7 +168,7 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { if (frame_idx_ > 0) { item->set_prev_timestamp_usec(prev_timestamp_.Value()); } - if (cc->Outputs().HasTag("TRACKING")) { + if (cc->Outputs().HasTag(kTrackingTag)) { // Need to copy as output is requested. *item->mutable_tracking_data() = *tracking_data; } else { @@ -172,9 +179,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { options_.caching_chunk_size_msec() * (chunk_idx_ + 1); if (timestamp.Value() / 1000 >= next_chunk_msec) { - if (cc->Outputs().HasTag("TRACKING_CHUNK")) { + if (cc->Outputs().HasTag(kTrackingChunkTag)) { cc->Outputs() - .Tag("TRACKING_CHUNK") + .Tag(kTrackingChunkTag) .Add(new TrackingDataChunk(tracking_chunk_), Timestamp(tracking_chunk_.item(0).timestamp_usec())); } @@ -185,9 +192,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { } } - if (cc->Outputs().HasTag("TRACKING")) { + if (cc->Outputs().HasTag(kTrackingTag)) { cc->Outputs() - .Tag("TRACKING") + .Tag(kTrackingTag) .Add(tracking_data.release(), flow_stream->Value().Timestamp()); } @@ -199,9 +206,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { if (frame_idx_ > 0) { tracking_chunk_.set_last_chunk(true); - if (cc->Outputs().HasTag("TRACKING_CHUNK")) { + if (cc->Outputs().HasTag(kTrackingChunkTag)) { cc->Outputs() - .Tag("TRACKING_CHUNK") + .Tag(kTrackingChunkTag) .Add(new TrackingDataChunk(tracking_chunk_), Timestamp(tracking_chunk_.item(0).timestamp_usec())); } @@ -211,8 +218,8 @@ absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { } } - if (cc->Outputs().HasTag("COMPLETE")) { - cc->Outputs().Tag("COMPLETE").Add(new bool(true), Timestamp::PreStream()); + if (cc->Outputs().HasTag(kCompleteTag)) { + cc->Outputs().Tag(kCompleteTag).Add(new bool(true), Timestamp::PreStream()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/video/motion_analysis_calculator.cc b/mediapipe/calculators/video/motion_analysis_calculator.cc index 59673108c..6217d3be9 100644 --- a/mediapipe/calculators/video/motion_analysis_calculator.cc +++ b/mediapipe/calculators/video/motion_analysis_calculator.cc @@ -38,6 +38,18 @@ namespace mediapipe { +constexpr char kDownsampleTag[] = "DOWNSAMPLE"; +constexpr char kCsvFileTag[] = "CSV_FILE"; +constexpr char kGrayVideoOutTag[] = "GRAY_VIDEO_OUT"; +constexpr char kVideoOutTag[] = "VIDEO_OUT"; +constexpr char kDenseFgTag[] = "DENSE_FG"; +constexpr char kVizTag[] = "VIZ"; +constexpr char kSaliencyTag[] = "SALIENCY"; +constexpr char kCameraTag[] = "CAMERA"; +constexpr char kFlowTag[] = "FLOW"; +constexpr char kSelectionTag[] = "SELECTION"; +constexpr char kVideoTag[] = "VIDEO"; + using mediapipe::AffineAdapter; using 
mediapipe::CameraMotion; using mediapipe::FrameSelectionResult; @@ -190,55 +202,56 @@ class MotionAnalysisCalculator : public CalculatorBase { REGISTER_CALCULATOR(MotionAnalysisCalculator); absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag("VIDEO")) { - cc->Inputs().Tag("VIDEO").Set(); + if (cc->Inputs().HasTag(kVideoTag)) { + cc->Inputs().Tag(kVideoTag).Set(); } // Optional input stream from frame selection calculator. - if (cc->Inputs().HasTag("SELECTION")) { - cc->Inputs().Tag("SELECTION").Set(); + if (cc->Inputs().HasTag(kSelectionTag)) { + cc->Inputs().Tag(kSelectionTag).Set(); } - RET_CHECK(cc->Inputs().HasTag("VIDEO") || cc->Inputs().HasTag("SELECTION")) + RET_CHECK(cc->Inputs().HasTag(kVideoTag) || + cc->Inputs().HasTag(kSelectionTag)) << "Either VIDEO, SELECTION must be specified."; - if (cc->Outputs().HasTag("FLOW")) { - cc->Outputs().Tag("FLOW").Set(); + if (cc->Outputs().HasTag(kFlowTag)) { + cc->Outputs().Tag(kFlowTag).Set(); } - if (cc->Outputs().HasTag("CAMERA")) { - cc->Outputs().Tag("CAMERA").Set(); + if (cc->Outputs().HasTag(kCameraTag)) { + cc->Outputs().Tag(kCameraTag).Set(); } - if (cc->Outputs().HasTag("SALIENCY")) { - cc->Outputs().Tag("SALIENCY").Set(); + if (cc->Outputs().HasTag(kSaliencyTag)) { + cc->Outputs().Tag(kSaliencyTag).Set(); } - if (cc->Outputs().HasTag("VIZ")) { - cc->Outputs().Tag("VIZ").Set(); + if (cc->Outputs().HasTag(kVizTag)) { + cc->Outputs().Tag(kVizTag).Set(); } - if (cc->Outputs().HasTag("DENSE_FG")) { - cc->Outputs().Tag("DENSE_FG").Set(); + if (cc->Outputs().HasTag(kDenseFgTag)) { + cc->Outputs().Tag(kDenseFgTag).Set(); } - if (cc->Outputs().HasTag("VIDEO_OUT")) { - cc->Outputs().Tag("VIDEO_OUT").Set(); + if (cc->Outputs().HasTag(kVideoOutTag)) { + cc->Outputs().Tag(kVideoOutTag).Set(); } - if (cc->Outputs().HasTag("GRAY_VIDEO_OUT")) { + if (cc->Outputs().HasTag(kGrayVideoOutTag)) { // We only output grayscale video if we're actually performing full region- // flow analysis on the video. 
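(The RET_CHECK just below enforces that.) For reference, the tag half of each stream name in a graph config must match the string value of these constants, which is why the rename is behavior-preserving; a hypothetical MotionAnalysisCalculator node, sketched in the same text-proto style as the tests in this PR, might look like:

    // Illustrative only; the stream names are invented, the tag strings are
    // unchanged by this PR.
    CalculatorGraphConfig::Node node =
        ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
          calculator: "MotionAnalysisCalculator"
          input_stream: "VIDEO:input_video"
          output_stream: "FLOW:region_flow"
          output_stream: "CAMERA:camera_motion"
          output_stream: "VIZ:visualization"
        )pb");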
- RET_CHECK(cc->Inputs().HasTag("VIDEO") && - !cc->Inputs().HasTag("SELECTION")); - cc->Outputs().Tag("GRAY_VIDEO_OUT").Set(); + RET_CHECK(cc->Inputs().HasTag(kVideoTag) && + !cc->Inputs().HasTag(kSelectionTag)); + cc->Outputs().Tag(kGrayVideoOutTag).Set(); } - if (cc->InputSidePackets().HasTag("CSV_FILE")) { - cc->InputSidePackets().Tag("CSV_FILE").Set(); + if (cc->InputSidePackets().HasTag(kCsvFileTag)) { + cc->InputSidePackets().Tag(kCsvFileTag).Set(); } - if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) { - cc->InputSidePackets().Tag("DOWNSAMPLE").Set(); + if (cc->InputSidePackets().HasTag(kDownsampleTag)) { + cc->InputSidePackets().Tag(kDownsampleTag).Set(); } if (cc->InputSidePackets().HasTag(kOptionsTag)) { @@ -253,16 +266,16 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag); - video_input_ = cc->Inputs().HasTag("VIDEO"); - selection_input_ = cc->Inputs().HasTag("SELECTION"); - region_flow_feature_output_ = cc->Outputs().HasTag("FLOW"); - camera_motion_output_ = cc->Outputs().HasTag("CAMERA"); - saliency_output_ = cc->Outputs().HasTag("SALIENCY"); - visualize_output_ = cc->Outputs().HasTag("VIZ"); - dense_foreground_output_ = cc->Outputs().HasTag("DENSE_FG"); - video_output_ = cc->Outputs().HasTag("VIDEO_OUT"); - grayscale_output_ = cc->Outputs().HasTag("GRAY_VIDEO_OUT"); - csv_file_input_ = cc->InputSidePackets().HasTag("CSV_FILE"); + video_input_ = cc->Inputs().HasTag(kVideoTag); + selection_input_ = cc->Inputs().HasTag(kSelectionTag); + region_flow_feature_output_ = cc->Outputs().HasTag(kFlowTag); + camera_motion_output_ = cc->Outputs().HasTag(kCameraTag); + saliency_output_ = cc->Outputs().HasTag(kSaliencyTag); + visualize_output_ = cc->Outputs().HasTag(kVizTag); + dense_foreground_output_ = cc->Outputs().HasTag(kDenseFgTag); + video_output_ = cc->Outputs().HasTag(kVideoOutTag); + grayscale_output_ = cc->Outputs().HasTag(kGrayVideoOutTag); + csv_file_input_ = cc->InputSidePackets().HasTag(kCsvFileTag); hybrid_meta_analysis_ = options_.meta_analysis() == MotionAnalysisCalculatorOptions::META_ANALYSIS_HYBRID; @@ -310,7 +323,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { if (csv_file_input_) { // Read from file and parse. const std::string filename = - cc->InputSidePackets().Tag("CSV_FILE").Get(); + cc->InputSidePackets().Tag(kCsvFileTag).Get(); std::string file_contents; std::ifstream input_file(filename, std::ios::in); @@ -327,11 +340,12 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { // Get video header from video or selection input if present. const VideoHeader* video_header = nullptr; - if (video_input_ && !cc->Inputs().Tag("VIDEO").Header().IsEmpty()) { - video_header = &(cc->Inputs().Tag("VIDEO").Header().Get()); + if (video_input_ && !cc->Inputs().Tag(kVideoTag).Header().IsEmpty()) { + video_header = &(cc->Inputs().Tag(kVideoTag).Header().Get()); } else if (selection_input_ && - !cc->Inputs().Tag("SELECTION").Header().IsEmpty()) { - video_header = &(cc->Inputs().Tag("SELECTION").Header().Get()); + !cc->Inputs().Tag(kSelectionTag).Header().IsEmpty()) { + video_header = + &(cc->Inputs().Tag(kSelectionTag).Header().Get()); } else { LOG(WARNING) << "No input video header found. 
Downstream calculators " "expecting video headers are likely to fail."; @@ -339,7 +353,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { with_saliency_ = options_.analysis_options().compute_motion_saliency(); // Force computation of saliency if requested as output. - if (cc->Outputs().HasTag("SALIENCY")) { + if (cc->Outputs().HasTag(kSaliencyTag)) { with_saliency_ = true; if (!options_.analysis_options().compute_motion_saliency()) { LOG(WARNING) << "Enable saliency computation. Set " @@ -353,11 +367,11 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); } - if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) { + if (cc->InputSidePackets().HasTag(kDownsampleTag)) { options_.mutable_analysis_options() ->mutable_flow_options() ->set_downsample_factor( - cc->InputSidePackets().Tag("DOWNSAMPLE").Get()); + cc->InputSidePackets().Tag(kDownsampleTag).Get()); } // If no video header is provided, just return and initialize on the first @@ -369,30 +383,33 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { ////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE /////////////// if (visualize_output_) { - cc->Outputs().Tag("VIZ").SetHeader(Adopt(new VideoHeader(*video_header))); + cc->Outputs().Tag(kVizTag).SetHeader(Adopt(new VideoHeader(*video_header))); } if (video_output_) { cc->Outputs() - .Tag("VIDEO_OUT") + .Tag(kVideoOutTag) .SetHeader(Adopt(new VideoHeader(*video_header))); } - if (cc->Outputs().HasTag("DENSE_FG")) { + if (cc->Outputs().HasTag(kDenseFgTag)) { std::unique_ptr foreground_header( new VideoHeader(*video_header)); foreground_header->format = ImageFormat::GRAY8; - cc->Outputs().Tag("DENSE_FG").SetHeader(Adopt(foreground_header.release())); - } - - if (cc->Outputs().HasTag("CAMERA")) { - cc->Outputs().Tag("CAMERA").SetHeader( - Adopt(new VideoHeader(*video_header))); - } - - if (cc->Outputs().HasTag("SALIENCY")) { cc->Outputs() - .Tag("SALIENCY") + .Tag(kDenseFgTag) + .SetHeader(Adopt(foreground_header.release())); + } + + if (cc->Outputs().HasTag(kCameraTag)) { + cc->Outputs() + .Tag(kCameraTag) + .SetHeader(Adopt(new VideoHeader(*video_header))); + } + + if (cc->Outputs().HasTag(kSaliencyTag)) { + cc->Outputs() + .Tag(kSaliencyTag) .SetHeader(Adopt(new VideoHeader(*video_header))); } @@ -405,9 +422,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { } InputStream* video_stream = - video_input_ ? &(cc->Inputs().Tag("VIDEO")) : nullptr; + video_input_ ? &(cc->Inputs().Tag(kVideoTag)) : nullptr; InputStream* selection_stream = - selection_input_ ? &(cc->Inputs().Tag("SELECTION")) : nullptr; + selection_input_ ? &(cc->Inputs().Tag(kSelectionTag)) : nullptr; // Checked on Open. 
CHECK(video_stream || selection_stream); @@ -425,8 +442,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { CameraMotion output_motion = meta_motions_.front(); meta_motions_.pop_front(); output_motion.set_timestamp_usec(timestamp.Value()); - cc->Outputs().Tag("CAMERA").Add(new CameraMotion(output_motion), - timestamp); + cc->Outputs() + .Tag(kCameraTag) + .Add(new CameraMotion(output_motion), timestamp); } if (region_flow_feature_output_) { @@ -435,8 +453,8 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { meta_features_.pop_front(); output_features.set_timestamp_usec(timestamp.Value()); - cc->Outputs().Tag("FLOW").Add(new RegionFlowFeatureList(output_features), - timestamp); + cc->Outputs().Tag(kFlowTag).Add( + new RegionFlowFeatureList(output_features), timestamp); } ++frame_idx_; @@ -478,16 +496,17 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION) { // Output concatenated results, nothing to compute here. if (camera_motion_output_) { - cc->Outputs().Tag("CAMERA").Add( - frame_selection_result->release_camera_motion(), timestamp); + cc->Outputs() + .Tag(kCameraTag) + .Add(frame_selection_result->release_camera_motion(), timestamp); } if (region_flow_feature_output_) { - cc->Outputs().Tag("FLOW").Add(frame_selection_result->release_features(), - timestamp); + cc->Outputs().Tag(kFlowTag).Add( + frame_selection_result->release_features(), timestamp); } if (video_output_) { - cc->Outputs().Tag("VIDEO_OUT").AddPacket(video_stream->Value()); + cc->Outputs().Tag(kVideoOutTag).AddPacket(video_stream->Value()); } return absl::OkStatus(); @@ -549,7 +568,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { timestamp_buffer_.push_back(timestamp); ++frame_idx_; - VLOG_EVERY_N(0, 100) << "Analyzed frame " << frame_idx_; + VLOG_EVERY_N(1, 100) << "Analyzed frame " << frame_idx_; // Buffer input frames only if visualization is requested. if (visualize_output_ || video_output_) { @@ -565,7 +584,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { grayscale_mat.copyTo(image_frame_mat); cc->Outputs() - .Tag("GRAY_VIDEO_OUT") + .Tag(kGrayVideoOutTag) .Add(grayscale_image.release(), timestamp); } @@ -640,7 +659,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( *feature_list, *camera_motion, with_saliency_ ? saliency[k].get() : nullptr, &visualization); - cc->Outputs().Tag("VIZ").Add(visualization_frame.release(), timestamp); + cc->Outputs().Tag(kVizTag).Add(visualization_frame.release(), timestamp); } // Output dense foreground mask. @@ -650,26 +669,26 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( cv::Mat foreground = formats::MatView(foreground_frame.get()); motion_analysis_->ComputeDenseForeground(*feature_list, *camera_motion, &foreground); - cc->Outputs().Tag("DENSE_FG").Add(foreground_frame.release(), timestamp); + cc->Outputs().Tag(kDenseFgTag).Add(foreground_frame.release(), timestamp); } // Output flow features if requested. if (region_flow_feature_output_) { - cc->Outputs().Tag("FLOW").Add(feature_list.release(), timestamp); + cc->Outputs().Tag(kFlowTag).Add(feature_list.release(), timestamp); } // Output camera motion. 
if (camera_motion_output_) { - cc->Outputs().Tag("CAMERA").Add(camera_motion.release(), timestamp); + cc->Outputs().Tag(kCameraTag).Add(camera_motion.release(), timestamp); } if (video_output_) { - cc->Outputs().Tag("VIDEO_OUT").AddPacket(packet_buffer_[k]); + cc->Outputs().Tag(kVideoOutTag).AddPacket(packet_buffer_[k]); } // Output saliency. if (saliency_output_) { - cc->Outputs().Tag("SALIENCY").Add(saliency[k].release(), timestamp); + cc->Outputs().Tag(kSaliencyTag).Add(saliency[k].release(), timestamp); } } diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc index bf7ed3e8a..94ddbb836 100644 --- a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc @@ -27,6 +27,12 @@ namespace mediapipe { namespace { + +constexpr char kSavedAudioPathTag[] = "SAVED_AUDIO_PATH"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kVideoTag[] = "VIDEO"; +constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH"; + // cv::VideoCapture set data type to unsigned char by default. Therefore, the // image format is only related to the number of channles the cv::Mat has. ImageFormat::Format GetImageFormat(int num_channels) { @@ -87,20 +93,20 @@ ImageFormat::Format GetImageFormat(int num_channels) { class OpenCvVideoDecoderCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set(); - cc->Outputs().Tag("VIDEO").Set(); - if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { - cc->Outputs().Tag("VIDEO_PRESTREAM").Set(); + cc->InputSidePackets().Tag(kInputFilePathTag).Set(); + cc->Outputs().Tag(kVideoTag).Set(); + if (cc->Outputs().HasTag(kVideoPrestreamTag)) { + cc->Outputs().Tag(kVideoPrestreamTag).Set(); } - if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) { - cc->OutputSidePackets().Tag("SAVED_AUDIO_PATH").Set(); + if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) { + cc->OutputSidePackets().Tag(kSavedAudioPathTag).Set(); } return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { const std::string& input_file_path = - cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get(); + cc->InputSidePackets().Tag(kInputFilePathTag).Get(); cap_ = absl::make_unique(input_file_path); if (!cap_->isOpened()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) @@ -140,16 +146,16 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { header->frame_rate = fps; header->duration = frame_count_ / fps; - if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { + if (cc->Outputs().HasTag(kVideoPrestreamTag)) { cc->Outputs() - .Tag("VIDEO_PRESTREAM") + .Tag(kVideoPrestreamTag) .Add(header.release(), Timestamp::PreStream()); - cc->Outputs().Tag("VIDEO_PRESTREAM").Close(); + cc->Outputs().Tag(kVideoPrestreamTag).Close(); } // Rewind to the very first frame. 
cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0); - if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) { + if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) { #ifdef HAVE_FFMPEG std::string saved_audio_path = std::tmpnam(nullptr); std::string ffmpeg_command = @@ -159,14 +165,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { int status_code = system(absl::StrCat("ls ", saved_audio_path).c_str()); if (status_code == 0) { cc->OutputSidePackets() - .Tag("SAVED_AUDIO_PATH") + .Tag(kSavedAudioPathTag) .Set(MakePacket(saved_audio_path)); } else { LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path << " by executing the following command: " << ffmpeg_command; cc->OutputSidePackets() - .Tag("SAVED_AUDIO_PATH") + .Tag(kSavedAudioPathTag) .Set(MakePacket(std::string())); } #else @@ -208,7 +214,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { // If the timestamp of the current frame is not greater than the one of the // previous frame, the new frame will be discarded. if (prev_timestamp_ < timestamp) { - cc->Outputs().Tag("VIDEO").Add(image_frame.release(), timestamp); + cc->Outputs().Tag(kVideoTag).Add(image_frame.release(), timestamp); prev_timestamp_ = timestamp; decoded_frames_++; } diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc index 03d27b6fe..035e5a8c9 100644 --- a/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc @@ -29,6 +29,10 @@ namespace mediapipe { namespace { +constexpr char kVideoTag[] = "VIDEO"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH"; + TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { CalculatorGraphConfig::Node node_config = ParseTextProtoOrDie(R"pb( @@ -37,19 +41,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { output_stream: "VIDEO:video" output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); CalculatorRunner runner(node_config); - runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket( file::JoinPath("./", "/mediapipe/calculators/video/" "testdata/format_MP4_AVC720P_AAC.video")); MP_EXPECT_OK(runner.Run()); - EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); MP_EXPECT_OK(runner.Outputs() - .Tag("VIDEO_PRESTREAM") + .Tag(kVideoPrestreamTag) .packets[0] .ValidateAsType()); const mediapipe::VideoHeader& header = - runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get(); EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(1280, header.width); EXPECT_EQ(640, header.height); @@ -58,10 +62,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { // The number of the output packets should be 180. // Some OpenCV version returns the first two frames with the same timestamp on // macos and we might miss one frame here. 
- int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size(); + int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size(); EXPECT_GE(num_of_packets, 179); for (int i = 0; i < num_of_packets; ++i) { - Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i]; cv::Mat output_mat = formats::MatView(&(image_frame_packet.Get())); EXPECT_EQ(1280, output_mat.size().width); @@ -83,19 +87,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) { output_stream: "VIDEO:video" output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); CalculatorRunner runner(node_config); - runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket( file::JoinPath("./", "/mediapipe/calculators/video/" "testdata/format_FLV_H264_AAC.video")); MP_EXPECT_OK(runner.Run()); - EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); MP_EXPECT_OK(runner.Outputs() - .Tag("VIDEO_PRESTREAM") + .Tag(kVideoPrestreamTag) .packets[0] .ValidateAsType()); const mediapipe::VideoHeader& header = - runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get(); EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(640, header.width); EXPECT_EQ(320, header.height); @@ -103,9 +107,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) { // can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4). // EXPECT_FLOAT_EQ(6.0f, header.duration); // EXPECT_FLOAT_EQ(30.0f, header.frame_rate); - EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size()); + EXPECT_EQ(180, runner.Outputs().Tag(kVideoTag).packets.size()); for (int i = 0; i < 180; ++i) { - Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i]; cv::Mat output_mat = formats::MatView(&(image_frame_packet.Get())); EXPECT_EQ(640, output_mat.size().width); @@ -127,19 +131,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) { output_stream: "VIDEO:video" output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); CalculatorRunner runner(node_config); - runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket( file::JoinPath("./", "/mediapipe/calculators/video/" "testdata/format_MKV_VP8_VORBIS.video")); MP_EXPECT_OK(runner.Run()); - EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); MP_EXPECT_OK(runner.Outputs() - .Tag("VIDEO_PRESTREAM") + .Tag(kVideoPrestreamTag) .packets[0] .ValidateAsType()); const mediapipe::VideoHeader& header = - runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get(); EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(640, header.width); EXPECT_EQ(320, header.height); @@ -148,10 +152,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) { // The number of the output packets should be 180. // Some OpenCV version returns the first two frames with the same timestamp on // macos and we might miss one frame here. 
- int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size(); + int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size(); EXPECT_GE(num_of_packets, 179); for (int i = 0; i < num_of_packets; ++i) { - Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i]; cv::Mat output_mat = formats::MatView(&(image_frame_packet.Get())); EXPECT_EQ(640, output_mat.size().width); diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc index 9a74fb710..4af8c5955 100644 --- a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc @@ -36,6 +36,11 @@ namespace mediapipe { +constexpr char kAudioFilePathTag[] = "AUDIO_FILE_PATH"; +constexpr char kOutputFilePathTag[] = "OUTPUT_FILE_PATH"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kVideoTag[] = "VIDEO"; + // Encodes the input video stream and produces a media file. // The media file can be output to the output_file_path specified as a side // packet. Currently, the calculator only supports one video stream (in @@ -90,15 +95,15 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { }; absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("VIDEO")); - cc->Inputs().Tag("VIDEO").Set(); - if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { - cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); + RET_CHECK(cc->Inputs().HasTag(kVideoTag)); + cc->Inputs().Tag(kVideoTag).Set(); + if (cc->Inputs().HasTag(kVideoPrestreamTag)) { + cc->Inputs().Tag(kVideoPrestreamTag).Set(); } - RET_CHECK(cc->InputSidePackets().HasTag("OUTPUT_FILE_PATH")); - cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Set(); - if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) { - cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Set(); + RET_CHECK(cc->InputSidePackets().HasTag(kOutputFilePathTag)); + cc->InputSidePackets().Tag(kOutputFilePathTag).Set(); + if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) { + cc->InputSidePackets().Tag(kAudioFilePathTag).Set(); } return absl::OkStatus(); } @@ -116,7 +121,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { << "Video format must be specified in " "OpenCvVideoEncoderCalculatorOptions"; output_file_path_ = - cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Get(); + cc->InputSidePackets().Tag(kOutputFilePathTag).Get(); std::vector splited_file_path = absl::StrSplit(output_file_path_, '.'); RET_CHECK(splited_file_path.size() >= 2 && @@ -126,7 +131,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { // If the video header will be available, the video metadata will be fetched // from the video header directly. The calculator will receive the video // header packet at timestamp prestream. 
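In practice this means writer setup is deferred to Process() at Timestamp::PreStream() whenever VIDEO_PRESTREAM is connected, as the early return in the next hunk shows. A hypothetical wiring of the two OpenCV calculators that exercises this path (sketch only; the stream and side-packet names are invented):

    CalculatorGraphConfig config =
        ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
          input_side_packet: "input_file_path"
          input_side_packet: "output_file_path"
          node {
            calculator: "OpenCvVideoDecoderCalculator"
            input_side_packet: "INPUT_FILE_PATH:input_file_path"
            output_stream: "VIDEO:video"
            output_stream: "VIDEO_PRESTREAM:video_header"
          }
          node {
            calculator: "OpenCvVideoEncoderCalculator"
            input_stream: "VIDEO:video"
            input_stream: "VIDEO_PRESTREAM:video_header"
            input_side_packet: "OUTPUT_FILE_PATH:output_file_path"
          }
        )pb");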
- if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { + if (cc->Inputs().HasTag(kVideoPrestreamTag)) { return absl::OkStatus(); } return SetUpVideoWriter(options.fps(), options.width(), options.height()); @@ -135,13 +140,13 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream()) { const VideoHeader& video_header = - cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); + cc->Inputs().Tag(kVideoPrestreamTag).Get(); return SetUpVideoWriter(video_header.frame_rate, video_header.width, video_header.height); } const ImageFrame& image_frame = - cc->Inputs().Tag("VIDEO").Value().Get(); + cc->Inputs().Tag(kVideoTag).Value().Get(); ImageFormat::Format format = image_frame.Format(); cv::Mat frame; if (format == ImageFormat::GRAY8) { @@ -149,7 +154,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { if (frame.empty()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Receive empty frame at timestamp " - << cc->Inputs().Tag("VIDEO").Value().Timestamp() + << cc->Inputs().Tag(kVideoTag).Value().Timestamp() << " in OpenCvVideoEncoderCalculator::Process()"; } } else { @@ -157,7 +162,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { if (tmp_frame.empty()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Receive empty frame at timestamp " - << cc->Inputs().Tag("VIDEO").Value().Timestamp() + << cc->Inputs().Tag(kVideoTag).Value().Timestamp() << " in OpenCvVideoEncoderCalculator::Process()"; } if (format == ImageFormat::SRGB) { @@ -177,10 +182,10 @@ absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { if (writer_ && writer_->isOpened()) { writer_->release(); } - if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) { + if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) { #ifdef HAVE_FFMPEG const std::string& audio_file_path = - cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Get(); + cc->InputSidePackets().Tag(kAudioFilePathTag).Get(); if (audio_file_path.empty()) { LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the " "audio tracks to the generated video because the audio " diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc index cf00da1f7..56f3253e2 100644 --- a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc +++ b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc @@ -23,6 +23,11 @@ namespace mediapipe { namespace { +constexpr char kBackwardFlowTag[] = "BACKWARD_FLOW"; +constexpr char kForwardFlowTag[] = "FORWARD_FLOW"; +constexpr char kSecondFrameTag[] = "SECOND_FRAME"; +constexpr char kFirstFrameTag[] = "FIRST_FRAME"; + // Checks that img1 and img2 have the same dimensions. bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) { return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height()); @@ -94,19 +99,19 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase { }; absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) { - if (!cc->Inputs().HasTag("FIRST_FRAME") || - !cc->Inputs().HasTag("SECOND_FRAME")) { + if (!cc->Inputs().HasTag(kFirstFrameTag) || + !cc->Inputs().HasTag(kSecondFrameTag)) { return absl::InvalidArgumentError( "Missing required input streams. 
Both FIRST_FRAME and SECOND_FRAME " "must be specified."); } - cc->Inputs().Tag("FIRST_FRAME").Set(); - cc->Inputs().Tag("SECOND_FRAME").Set(); - if (cc->Outputs().HasTag("FORWARD_FLOW")) { - cc->Outputs().Tag("FORWARD_FLOW").Set(); + cc->Inputs().Tag(kFirstFrameTag).Set(); + cc->Inputs().Tag(kSecondFrameTag).Set(); + if (cc->Outputs().HasTag(kForwardFlowTag)) { + cc->Outputs().Tag(kForwardFlowTag).Set(); } - if (cc->Outputs().HasTag("BACKWARD_FLOW")) { - cc->Outputs().Tag("BACKWARD_FLOW").Set(); + if (cc->Outputs().HasTag(kBackwardFlowTag)) { + cc->Outputs().Tag(kBackwardFlowTag).Set(); } return absl::OkStatus(); } @@ -116,10 +121,10 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { absl::MutexLock lock(&mutex_); tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1()); } - if (cc->Outputs().HasTag("FORWARD_FLOW")) { + if (cc->Outputs().HasTag(kForwardFlowTag)) { forward_requested_ = true; } - if (cc->Outputs().HasTag("BACKWARD_FLOW")) { + if (cc->Outputs().HasTag(kBackwardFlowTag)) { backward_requested_ = true; } @@ -128,15 +133,15 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { const ImageFrame& first_frame = - cc->Inputs().Tag("FIRST_FRAME").Value().Get(); + cc->Inputs().Tag(kFirstFrameTag).Value().Get(); const ImageFrame& second_frame = - cc->Inputs().Tag("SECOND_FRAME").Value().Get(); + cc->Inputs().Tag(kSecondFrameTag).Value().Get(); if (forward_requested_) { auto forward_optical_flow_field = absl::make_unique(); MP_RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame, forward_optical_flow_field.get())); cc->Outputs() - .Tag("FORWARD_FLOW") + .Tag(kForwardFlowTag) .Add(forward_optical_flow_field.release(), cc->InputTimestamp()); } if (backward_requested_) { @@ -144,7 +149,7 @@ absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { MP_RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame, backward_optical_flow_field.get())); cc->Outputs() - .Tag("BACKWARD_FLOW") + .Tag(kBackwardFlowTag) .Add(backward_optical_flow_field.release(), cc->InputTimestamp()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/video/video_pre_stream_calculator.cc b/mediapipe/calculators/video/video_pre_stream_calculator.cc index ab9cd22a4..317d4baad 100644 --- a/mediapipe/calculators/video/video_pre_stream_calculator.cc +++ b/mediapipe/calculators/video/video_pre_stream_calculator.cc @@ -19,6 +19,9 @@ namespace mediapipe { +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kFrameTag[] = "FRAME"; + // Sets up VideoHeader based on the 1st ImageFrame and emits it with timestamp // PreStream. Note that this calculator only fills in format, width, and height, // i.e. 
frame_rate and duration will not be filled, unless: @@ -64,8 +67,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) { if (!cc->Inputs().UsesTags()) { cc->Inputs().Index(0).Set(); } else { - cc->Inputs().Tag("FRAME").Set(); - cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); + cc->Inputs().Tag(kFrameTag).Set(); + cc->Inputs().Tag(kVideoPrestreamTag).Set(); } cc->Outputs().Index(0).Set(); return absl::OkStatus(); @@ -73,8 +76,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) { absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) { frame_rate_in_prestream_ = cc->Inputs().UsesTags() && - cc->Inputs().HasTag("FRAME") && - cc->Inputs().HasTag("VIDEO_PRESTREAM"); + cc->Inputs().HasTag(kFrameTag) && + cc->Inputs().HasTag(kVideoPrestreamTag); header_ = absl::make_unique(); return absl::OkStatus(); } @@ -82,15 +85,15 @@ absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream( CalculatorContext* cc) { cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment(); if (cc->InputTimestamp() == Timestamp::PreStream()) { - RET_CHECK(cc->Inputs().Tag("FRAME").IsEmpty()); - RET_CHECK(!cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty()); - *header_ = cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); + RET_CHECK(cc->Inputs().Tag(kFrameTag).IsEmpty()); + RET_CHECK(!cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty()); + *header_ = cc->Inputs().Tag(kVideoPrestreamTag).Get(); RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero"; } else { - RET_CHECK(cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty()) + RET_CHECK(cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty()) << "Packet on VIDEO_PRESTREAM must come in at Timestamp::PreStream()."; - RET_CHECK(!cc->Inputs().Tag("FRAME").IsEmpty()); - const auto& frame = cc->Inputs().Tag("FRAME").Get(); + RET_CHECK(!cc->Inputs().Tag(kFrameTag).IsEmpty()); + const auto& frame = cc->Inputs().Tag(kFrameTag).Get(); header_->format = frame.Format(); header_->width = frame.Width(); header_->height = frame.Height(); diff --git a/mediapipe/examples/android/solutions/BUILD b/mediapipe/examples/android/solutions/BUILD new file mode 100644 index 000000000..1ba23afe6 --- /dev/null +++ b/mediapipe/examples/android/solutions/BUILD @@ -0,0 +1,21 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +filegroup( + name = "resource_files", + srcs = glob(["res/**"]), + visibility = ["//mediapipe/examples/android/solutions:__subpackages__"], +) diff --git a/mediapipe/examples/android/solutions/build.gradle b/mediapipe/examples/android/solutions/build.gradle new file mode 100644 index 000000000..691e41013 --- /dev/null +++ b/mediapipe/examples/android/solutions/build.gradle @@ -0,0 +1,24 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. 
+buildscript { + repositories { + google() + mavenCentral() + } + dependencies { + classpath "com.android.tools.build:gradle:4.2.0" + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + mavenCentral() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/mediapipe/examples/android/solutions/create_win_symlinks.bat b/mediapipe/examples/android/solutions/create_win_symlinks.bat new file mode 100644 index 000000000..ea641b6e9 --- /dev/null +++ b/mediapipe/examples/android/solutions/create_win_symlinks.bat @@ -0,0 +1,16 @@ +@rem Remove the current res dir symlinks that are for Linux and macOS and recreate res dir symlinks for Windows. +@rem This script needs administrator permission. Must run this script as administrator. + +@rem for hands example app. +cd /d %~dp0 +cd hands\src\main +rm res +mklink /d res ..\..\..\res + +@rem for facemesh example app. +cd /d %~dp0 +cd facemesh\src\main +rm res +mklink /d res ..\..\..\res +dir +pause diff --git a/mediapipe/examples/android/solutions/facemesh/build.gradle b/mediapipe/examples/android/solutions/facemesh/build.gradle new file mode 100644 index 000000000..74aedf095 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/build.gradle @@ -0,0 +1,50 @@ +plugins { + id 'com.android.application' +} + +android { + compileSdkVersion 30 + buildToolsVersion "30.0.3" + + defaultConfig { + applicationId "com.google.mediapipe.apps.hands" + minSdkVersion 21 + targetSdkVersion 30 + versionCode 1 + versionName "1.0" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar']) + implementation 'androidx.appcompat:appcompat:1.3.0' + implementation 'com.google.android.material:material:1.3.0' + implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + testImplementation 'junit:junit:4.+' + androidTestImplementation 'androidx.test.ext:junit:1.1.2' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' + // MediaPipe hands solution API and solution-core. + implementation 'com.google.mediapipe:solution-core:latest.release' + implementation 'com.google.mediapipe:facemesh:latest.release' + // MediaPipe deps + implementation 'com.google.flogger:flogger:latest.release' + implementation 'com.google.flogger:flogger-system-backend:latest.release' + implementation 'com.google.guava:guava:27.0.1-android' + implementation 'com.google.protobuf:protobuf-java:3.11.4' + // CameraX core library + def camerax_version = "1.0.0-beta10" + implementation "androidx.camera:camera-core:$camerax_version" + implementation "androidx.camera:camera-camera2:$camerax_version" + implementation "androidx.camera:camera-lifecycle:$camerax_version" +} diff --git a/mediapipe/examples/android/solutions/facemesh/proguard-rules.pro b/mediapipe/examples/android/solutions/facemesh/proguard-rules.pro new file mode 100644 index 000000000..f1b424510 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. 
+# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/AndroidManifest.xml b/mediapipe/examples/android/solutions/facemesh/src/main/AndroidManifest.xml new file mode 100644 index 000000000..de062995a --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/AndroidManifest.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/BUILD b/mediapipe/examples/android/solutions/facemesh/src/main/BUILD new file mode 100644 index 000000000..591102c3e --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/BUILD @@ -0,0 +1,44 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +android_binary( + name = "facemesh", + srcs = glob(["**/*.java"]), + custom_package = "com.google.mediapipe.examples.facemesh", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.examples.facemesh", + }, + multidex = "native", + resource_files = ["//mediapipe/examples/android/solutions:resource_files"], + deps = [ + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input", + "//mediapipe/java/com/google/mediapipe/solutioncore:mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering", + "//mediapipe/java/com/google/mediapipe/solutioncore:video_input", + "//mediapipe/java/com/google/mediapipe/solutions/facemesh", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_fragment_fragment", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java new file mode 100644 index 000000000..fd6c533d3 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java @@ -0,0 +1,186 @@ +// Copyright 2021 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facemesh; + +import android.opengl.GLES20; +import android.opengl.Matrix; +import com.google.common.collect.ImmutableSet; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutioncore.ResultGlBoundary; +import com.google.mediapipe.solutioncore.ResultGlRenderer; +import com.google.mediapipe.solutions.facemesh.FaceMeshConnections; +import com.google.mediapipe.solutions.facemesh.FaceMeshResult; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.List; + +/** A custom implementation of {@link ResultGlRenderer} to render MediaPope FaceMesh results. */ +public class FaceMeshResultGlRenderer implements ResultGlRenderer { + private static final String TAG = "FaceMeshResultGlRenderer"; + + private static final float[] TESSELATION_COLOR = new float[] {0.75f, 0.75f, 0.75f, 0.5f}; + private static final int TESSELATION_THICKNESS = 5; + private static final float[] RIGHT_EYE_COLOR = new float[] {1f, 0.2f, 0.2f, 1f}; + private static final int RIGHT_EYE_THICKNESS = 8; + private static final float[] RIGHT_EYEBROW_COLOR = new float[] {1f, 0.2f, 0.2f, 1f}; + private static final int RIGHT_EYEBROW_THICKNESS = 8; + private static final float[] LEFT_EYE_COLOR = new float[] {0.2f, 1f, 0.2f, 1f}; + private static final int LEFT_EYE_THICKNESS = 8; + private static final float[] LEFT_EYEBROW_COLOR = new float[] {0.2f, 1f, 0.2f, 1f}; + private static final int LEFT_EYEBROW_THICKNESS = 8; + private static final float[] FACE_OVAL_COLOR = new float[] {0.9f, 0.9f, 0.9f, 1f}; + private static final int FACE_OVAL_THICKNESS = 8; + private static final float[] LIPS_COLOR = new float[] {0.9f, 0.9f, 0.9f, 1f}; + private static final int LIPS_THICKNESS = 8; + private static final String VERTEX_SHADER = + "uniform mat4 uTransformMatrix;\n" + + "attribute vec4 vPosition;\n" + + "void main() {\n" + + " gl_Position = uTransformMatrix * vPosition;\n" + + "}"; + private static final String FRAGMENT_SHADER = + "precision mediump float;\n" + + "uniform vec4 uColor;\n" + + "void main() {\n" + + " gl_FragColor = uColor;\n" + + "}"; + private int program; + private int positionHandle; + private int transformMatrixHandle; + private int colorHandle; + private final float[] transformMatrix = new float[16]; + + private int loadShader(int type, String shaderCode) { + int shader = GLES20.glCreateShader(type); + GLES20.glShaderSource(shader, shaderCode); + GLES20.glCompileShader(shader); + return shader; + } + + @Override + public void setupRendering() { + program = GLES20.glCreateProgram(); + int vertexShader = loadShader(GLES20.GL_VERTEX_SHADER, VERTEX_SHADER); + int fragmentShader = loadShader(GLES20.GL_FRAGMENT_SHADER, FRAGMENT_SHADER); + GLES20.glAttachShader(program, vertexShader); + GLES20.glAttachShader(program, fragmentShader); + GLES20.glLinkProgram(program); + positionHandle = GLES20.glGetAttribLocation(program, "vPosition"); + 
transformMatrixHandle = GLES20.glGetUniformLocation(program, "uTransformMatrix"); + colorHandle = GLES20.glGetUniformLocation(program, "uColor"); + } + + @Override + public void renderResult(FaceMeshResult result, ResultGlBoundary boundary) { + if (result == null) { + return; + } + GLES20.glUseProgram(program); + // Sets the transform matrix to align the result rendering with the scaled output texture. + // Also flips the rendering vertically since OpenGL assumes the coordinate origin is at the + // bottom-left corner, whereas MediaPipe landmark data assumes the coordinate origin is at the + // top-left corner. + Matrix.setIdentityM(transformMatrix, 0); + Matrix.scaleM( + transformMatrix, + 0, + 2 / (boundary.right() - boundary.left()), + -2 / (boundary.top() - boundary.bottom()), + 1.0f); + GLES20.glUniformMatrix4fv(transformMatrixHandle, 1, false, transformMatrix, 0); + + int numFaces = result.multiFaceLandmarks().size(); + for (int i = 0; i < numFaces; ++i) { + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_TESSELATION, + TESSELATION_COLOR, + TESSELATION_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_EYE, + RIGHT_EYE_COLOR, + RIGHT_EYE_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_EYEBROW, + RIGHT_EYEBROW_COLOR, + RIGHT_EYEBROW_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_EYE, + LEFT_EYE_COLOR, + LEFT_EYE_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_EYEBR0W, + LEFT_EYEBROW_COLOR, + LEFT_EYEBROW_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_FACE_OVAL, + FACE_OVAL_COLOR, + FACE_OVAL_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LIPS, + LIPS_COLOR, + LIPS_THICKNESS); + } + } + + /** + * Calls this to delete the shader program. + * + *
This is only necessary if one wants to release the program while keeping the context around. + */ + public void release() { + GLES20.glDeleteProgram(program); + } + + private void drawLandmarks( + List faceLandmarkList, + ImmutableSet connections, + float[] colorArray, + int thickness) { + GLES20.glUniform4fv(colorHandle, 1, colorArray, 0); + GLES20.glLineWidth(thickness); + for (FaceMeshConnections.Connection c : connections) { + float[] vertex = new float[4]; + NormalizedLandmark start = faceLandmarkList.get(c.start()); + vertex[0] = normalizedLandmarkValue(start.getX()); + vertex[1] = normalizedLandmarkValue(start.getY()); + NormalizedLandmark end = faceLandmarkList.get(c.end()); + vertex[2] = normalizedLandmarkValue(end.getX()); + vertex[3] = normalizedLandmarkValue(end.getY()); + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(vertex.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(vertex); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2); + } + } + + // Normalizes the value from the landmark value range:[0, 1] to the standard OpenGL coordinate + // value range: [-1, 1]. + private float normalizedLandmarkValue(float value) { + return value * 2 - 1; + } +} diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java new file mode 100644 index 000000000..9db91a8e3 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java @@ -0,0 +1,158 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facemesh; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import android.util.Size; +import com.google.common.collect.ImmutableSet; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutions.facemesh.FaceMeshConnections; +import com.google.mediapipe.solutions.facemesh.FaceMeshResult; +import java.util.List; + +/** An ImageView implementation for displaying MediaPipe FaceMesh results. 
*/ +public class FaceMeshResultImageView extends AppCompatImageView { + private static final String TAG = "FaceMeshResultImageView"; + + private static final int TESSELATION_COLOR = Color.parseColor("#70C0C0C0"); + private static final int TESSELATION_THICKNESS = 5; + private static final int RIGHT_EYE_COLOR = Color.parseColor("#FF3030"); + private static final int RIGHT_EYE_THICKNESS = 8; + private static final int RIGHT_EYEBROW_COLOR = Color.parseColor("#FF3030"); + private static final int RIGHT_EYEBROW_THICKNESS = 8; + private static final int LEFT_EYE_COLOR = Color.parseColor("#30FF30"); + private static final int LEFT_EYE_THICKNESS = 8; + private static final int LEFT_EYEBROW_COLOR = Color.parseColor("#30FF30"); + private static final int LEFT_EYEBROW_THICKNESS = 8; + private static final int FACE_OVAL_COLOR = Color.parseColor("#E0E0E0"); + private static final int FACE_OVAL_THICKNESS = 8; + private static final int LIPS_COLOR = Color.parseColor("#E0E0E0"); + private static final int LIPS_THICKNESS = 8; + private Bitmap latest; + + public FaceMeshResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets a {@link FaceMeshResult} to render. + * + * @param result a {@link FaceMeshResult} object that contains the solution outputs and the input + * {@link Bitmap}. + */ + public void setFaceMeshResult(FaceMeshResult result) { + if (result == null) { + return; + } + Bitmap bmInput = result.inputBitmap(); + int width = bmInput.getWidth(); + int height = bmInput.getHeight(); + latest = Bitmap.createBitmap(width, height, bmInput.getConfig()); + Canvas canvas = new Canvas(latest); + Size imageSize = new Size(width, height); + canvas.drawBitmap(bmInput, new Matrix(), null); + int numFaces = result.multiFaceLandmarks().size(); + for (int i = 0; i < numFaces; ++i) { + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_TESSELATION, + imageSize, + TESSELATION_COLOR, + TESSELATION_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_EYE, + imageSize, + RIGHT_EYE_COLOR, + RIGHT_EYE_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_EYEBROW, + imageSize, + RIGHT_EYEBROW_COLOR, + RIGHT_EYEBROW_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_EYE, + imageSize, + LEFT_EYE_COLOR, + LEFT_EYE_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_EYEBR0W, + imageSize, + LEFT_EYEBROW_COLOR, + LEFT_EYEBROW_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_FACE_OVAL, + imageSize, + FACE_OVAL_COLOR, + FACE_OVAL_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LIPS, + imageSize, + LIPS_COLOR, + LIPS_THICKNESS); + } + } + + /** Updates the image view with the latest facemesh result. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + // TODO: Better hand landmark and hand connection drawing. 
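The TODO above carries over from the hands example; the HandsResultImageView added later in this change draws per-landmark circles in addition to connections. A minimal hedged sketch of the same idea for this face view follows — the helper name and the color/radius values are illustrative and not part of this change, and the sketch relies only on the imports already present in this file.

```java
// Sketch only: mirrors the per-landmark circle drawing in HandsResultImageView
// later in this change. Name and color/radius values are illustrative.
private void drawLandmarkPointsOnCanvas(
    Canvas canvas, List<NormalizedLandmark> faceLandmarkList, Size imageSize) {
  Paint landmarkPaint = new Paint();
  landmarkPaint.setColor(Color.RED);
  // Landmark coordinates are normalized to [0, 1]; scale them to pixel positions.
  for (NormalizedLandmark landmark : faceLandmarkList) {
    canvas.drawCircle(
        landmark.getX() * imageSize.getWidth(),
        landmark.getY() * imageSize.getHeight(),
        /* radius= */ 4,
        landmarkPaint);
  }
}
```

If adopted, a call to this helper could be appended at the end of drawLandmarksOnCanvas below, after the connection loop.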
+ private void drawLandmarksOnCanvas( + Canvas canvas, + List faceLandmarkList, + ImmutableSet connections, + Size imageSize, + int color, + int thickness) { + // Draw connections. + for (FaceMeshConnections.Connection c : connections) { + Paint connectionPaint = new Paint(); + connectionPaint.setColor(color); + connectionPaint.setStrokeWidth(thickness); + NormalizedLandmark start = faceLandmarkList.get(c.start()); + NormalizedLandmark end = faceLandmarkList.get(c.end()); + canvas.drawLine( + start.getX() * imageSize.getWidth(), + start.getY() * imageSize.getHeight(), + end.getX() * imageSize.getWidth(), + end.getY() * imageSize.getHeight(), + connectionPaint); + } + } +} diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java new file mode 100644 index 000000000..27c89a93e --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java @@ -0,0 +1,308 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facemesh; + +import android.content.Intent; +import android.graphics.Bitmap; +import android.os.Bundle; +import android.provider.MediaStore; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.FrameLayout; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.contract.ActivityResultContracts; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutioncore.CameraInput; +import com.google.mediapipe.solutioncore.SolutionGlSurfaceView; +import com.google.mediapipe.solutioncore.VideoInput; +import com.google.mediapipe.solutions.facemesh.FaceMesh; +import com.google.mediapipe.solutions.facemesh.FaceMeshOptions; +import com.google.mediapipe.solutions.facemesh.FaceMeshResult; +import java.io.IOException; + +/** Main activity of MediaPipe FaceMesh app. */ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private FaceMesh facemesh; + // Run the pipeline and the model inference on GPU or CPU. + private static final boolean RUN_ON_GPU = true; + + private enum InputSource { + UNKNOWN, + IMAGE, + VIDEO, + CAMERA, + } + private InputSource inputSource = InputSource.UNKNOWN; + // Image demo UI and image loader components. + private ActivityResultLauncher imageGetter; + private FaceMeshResultImageView imageView; + // Video demo UI and video loader components. + private VideoInput videoInput; + private ActivityResultLauncher videoGetter; + // Live camera demo UI and camera components. 
+ private CameraInput cameraInput; + private SolutionGlSurfaceView glSurfaceView; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + setupStaticImageDemoUiComponents(); + setupVideoDemoUiComponents(); + setupLiveDemoUiComponents(); + } + + @Override + protected void onResume() { + super.onResume(); + if (inputSource == InputSource.CAMERA) { + // Restarts the camera and the opengl surface rendering. + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> facemesh.send(textureFrame)); + glSurfaceView.post(this::startCamera); + glSurfaceView.setVisibility(View.VISIBLE); + } else if (inputSource == InputSource.VIDEO) { + videoInput.resume(); + } + } + + @Override + protected void onPause() { + super.onPause(); + if (inputSource == InputSource.CAMERA) { + glSurfaceView.setVisibility(View.GONE); + cameraInput.close(); + } else if (inputSource == InputSource.VIDEO) { + videoInput.pause(); + } + } + + /** Sets up the UI components for the static image demo. */ + private void setupStaticImageDemoUiComponents() { + // The Intent to access gallery and read images as bitmap. + imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData()); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + if (bitmap != null) { + facemesh.send(bitmap); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + stopCurrentPipeline(); + setupStaticImageModePipeline(); + } + // Reads images from gallery. + Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI); + imageGetter.launch(gallery); + }); + imageView = new FaceMeshResultImageView(this); + } + + /** The core MediaPipe FaceMesh setup workflow for its static image mode. */ + private void setupStaticImageModePipeline() { + this.inputSource = InputSource.IMAGE; + // Initializes a new MediaPipe FaceMesh instance in the static image mode. + facemesh = + new FaceMesh( + this, + FaceMeshOptions.builder() + .setMode(FaceMeshOptions.STATIC_IMAGE_MODE) + .setRunOnGpu(RUN_ON_GPU) + .build()); + + // Connects MediaPipe FaceMesh to the user-defined FaceMeshResultImageView. + facemesh.setResultListener( + faceMeshResult -> { + logNoseLandmark(faceMeshResult, /*showPixelValues=*/ true); + imageView.setFaceMeshResult(faceMeshResult); + runOnUiThread(() -> imageView.update()); + }); + facemesh.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + /** Sets up the UI components for the video demo. */ + private void setupVideoDemoUiComponents() { + // The Intent to access gallery and read a video file. 
+ videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + facemesh.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); + Button loadVideoButton = findViewById(R.id.button_load_video); + loadVideoButton.setOnClickListener( + v -> { + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.VIDEO); + // Reads video from gallery. + Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Video.Media.INTERNAL_CONTENT_URI); + videoGetter.launch(gallery); + }); + } + + /** Sets up the UI components for the live demo with camera input. */ + private void setupLiveDemoUiComponents() { + Button startCameraButton = findViewById(R.id.button_start_camera); + startCameraButton.setOnClickListener( + v -> { + if (inputSource == InputSource.CAMERA) { + return; + } + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.CAMERA); + }); + } + + /** The core MediaPipe FaceMesh setup workflow for its streaming mode. */ + private void setupStreamingModePipeline(InputSource inputSource) { + this.inputSource = inputSource; + // Initializes a new MediaPipe FaceMesh instance in the streaming mode. + facemesh = + new FaceMesh( + this, + FaceMeshOptions.builder() + .setMode(FaceMeshOptions.STREAMING_MODE) + .setRunOnGpu(RUN_ON_GPU) + .build()); + facemesh.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); + + if (inputSource == InputSource.CAMERA) { + // Initializes a new CameraInput instance and connects it to MediaPipe FaceMesh. + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> facemesh.send(textureFrame)); + } else if (inputSource == InputSource.VIDEO) { + // Initializes a new VideoInput instance and connects it to MediaPipe FaceMesh. + videoInput = new VideoInput(this); + videoInput.setNewFrameListener(textureFrame -> facemesh.send(textureFrame)); + } + + // Initializes a new Gl surface view with a user-defined FaceMeshResultGlRenderer. + glSurfaceView = + new SolutionGlSurfaceView<>(this, facemesh.getGlContext(), facemesh.getGlMajorVersion()); + glSurfaceView.setSolutionResultRenderer(new FaceMeshResultGlRenderer()); + glSurfaceView.setRenderInputImage(true); + facemesh.setResultListener( + faceMeshResult -> { + logNoseLandmark(faceMeshResult, /*showPixelValues=*/ false); + glSurfaceView.setRenderData(faceMeshResult); + glSurfaceView.requestRender(); + }); + + // The runnable to start camera after the gl surface view is attached. + // For video input source, videoInput.start() will be called when the video uri is available. + if (inputSource == InputSource.CAMERA) { + glSurfaceView.post(this::startCamera); + } + + // Updates the preview layout. 
+ FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + imageView.setVisibility(View.GONE); + frameLayout.removeAllViewsInLayout(); + frameLayout.addView(glSurfaceView); + glSurfaceView.setVisibility(View.VISIBLE); + frameLayout.requestLayout(); + } + + private void startCamera() { + cameraInput.start( + this, + facemesh.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight()); + } + + private void stopCurrentPipeline() { + if (cameraInput != null) { + cameraInput.setNewFrameListener(null); + cameraInput.close(); + } + if (videoInput != null) { + videoInput.setNewFrameListener(null); + videoInput.close(); + } + if (glSurfaceView != null) { + glSurfaceView.setVisibility(View.GONE); + } + if (facemesh != null) { + facemesh.close(); + } + } + + private void logNoseLandmark(FaceMeshResult result, boolean showPixelValues) { + if (result == null || result.multiFaceLandmarks().isEmpty()) { + return; + } + NormalizedLandmark noseLandmark = result.multiFaceLandmarks().get(0).getLandmarkList().get(1); + // For Bitmaps, show the pixel values. For texture inputs, show the normalized coordinates. + if (showPixelValues) { + int width = result.inputBitmap().getWidth(); + int height = result.inputBitmap().getHeight(); + Log.i( + TAG, + String.format( + "MediaPipe FaceMesh nose coordinates (pixel values): x=%f, y=%f", + noseLandmark.getX() * width, noseLandmark.getY() * height)); + } else { + Log.i( + TAG, + String.format( + "MediaPipe FaceMesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + noseLandmark.getX(), noseLandmark.getY())); + } + } +} diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/res b/mediapipe/examples/android/solutions/facemesh/src/main/res new file mode 120000 index 000000000..fc8850136 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/res @@ -0,0 +1 @@ +../../../res \ No newline at end of file diff --git a/mediapipe/examples/android/solutions/gradle.properties b/mediapipe/examples/android/solutions/gradle.properties new file mode 100644 index 000000000..c09e1e3b0 --- /dev/null +++ b/mediapipe/examples/android/solutions/gradle.properties @@ -0,0 +1,17 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. 
More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app"s APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 000000000..e708b1c02 Binary files /dev/null and b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar differ diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 000000000..442d9132e --- /dev/null +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/mediapipe/examples/android/solutions/gradlew b/mediapipe/examples/android/solutions/gradlew new file mode 100755 index 000000000..4f906e0c8 --- /dev/null +++ b/mediapipe/examples/android/solutions/gradlew @@ -0,0 +1,185 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. 
+if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git 
a/mediapipe/examples/android/solutions/gradlew.bat b/mediapipe/examples/android/solutions/gradlew.bat new file mode 100755 index 000000000..ac1b06f93 --- /dev/null +++ b/mediapipe/examples/android/solutions/gradlew.bat @@ -0,0 +1,89 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/mediapipe/examples/android/solutions/hands/build.gradle b/mediapipe/examples/android/solutions/hands/build.gradle new file mode 100644 index 000000000..27629fd5d --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/build.gradle @@ -0,0 +1,50 @@ +plugins { + id 'com.android.application' +} + +android { + compileSdkVersion 30 + buildToolsVersion "30.0.3" + + defaultConfig { + applicationId "com.google.mediapipe.apps.hands" + minSdkVersion 21 + targetSdkVersion 30 + versionCode 1 + versionName "1.0" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar']) + implementation 'androidx.appcompat:appcompat:1.3.0' + implementation 'com.google.android.material:material:1.3.0' + implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + testImplementation 'junit:junit:4.+' + androidTestImplementation 'androidx.test.ext:junit:1.1.2' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' + // MediaPipe hands solution API and solution-core. + implementation 'com.google.mediapipe:solution-core:latest.release' + implementation 'com.google.mediapipe:hands:latest.release' + // MediaPipe deps + implementation 'com.google.flogger:flogger:latest.release' + implementation 'com.google.flogger:flogger-system-backend:latest.release' + implementation 'com.google.guava:guava:27.0.1-android' + implementation 'com.google.protobuf:protobuf-java:3.11.4' + // CameraX core library + def camerax_version = "1.0.0-beta10" + implementation "androidx.camera:camera-core:$camerax_version" + implementation "androidx.camera:camera-camera2:$camerax_version" + implementation "androidx.camera:camera-lifecycle:$camerax_version" +} diff --git a/mediapipe/examples/android/solutions/hands/proguard-rules.pro b/mediapipe/examples/android/solutions/hands/proguard-rules.pro new file mode 100644 index 000000000..f1b424510 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/mediapipe/examples/android/solutions/hands/src/main/AndroidManifest.xml b/mediapipe/examples/android/solutions/hands/src/main/AndroidManifest.xml new file mode 100644 index 000000000..4537a2537 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/AndroidManifest.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/solutions/hands/src/main/BUILD b/mediapipe/examples/android/solutions/hands/src/main/BUILD new file mode 100644 index 000000000..0d71e4a95 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/BUILD @@ -0,0 +1,44 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +android_binary( + name = "hands", + srcs = glob(["**/*.java"]), + custom_package = "com.google.mediapipe.examples.hands", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.examples.hands", + }, + multidex = "native", + resource_files = ["//mediapipe/examples/android/solutions:resource_files"], + deps = [ + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input", + "//mediapipe/java/com/google/mediapipe/solutioncore:mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering", + "//mediapipe/java/com/google/mediapipe/solutioncore:video_input", + "//mediapipe/java/com/google/mediapipe/solutions/hands", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_fragment_fragment", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java new file mode 100644 index 000000000..720ae5509 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java @@ -0,0 +1,131 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.examples.hands; + +import android.opengl.GLES20; +import android.opengl.Matrix; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutioncore.ResultGlBoundary; +import com.google.mediapipe.solutioncore.ResultGlRenderer; +import com.google.mediapipe.solutions.hands.Hands; +import com.google.mediapipe.solutions.hands.HandsResult; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.List; + +/** A custom implementation of {@link ResultGlRenderer} to render MediaPope Hands results. */ +public class HandsResultGlRenderer implements ResultGlRenderer { + private static final String TAG = "HandsResultGlRenderer"; + + private static final float CONNECTION_THICKNESS = 20.0f; + private static final String VERTEX_SHADER = + "uniform mat4 uTransformMatrix;\n" + + "attribute vec4 vPosition;\n" + + "void main() {\n" + + " gl_Position = uTransformMatrix * vPosition;\n" + + "}"; + private static final String FRAGMENT_SHADER = + "precision mediump float;\n" + + "void main() {\n" + + " gl_FragColor = vec4(0, 1, 0, 1);\n" + + "}"; + private int program; + private int positionHandle; + private int transformMatrixHandle; + private final float[] transformMatrix = new float[16]; + + private int loadShader(int type, String shaderCode) { + int shader = GLES20.glCreateShader(type); + GLES20.glShaderSource(shader, shaderCode); + GLES20.glCompileShader(shader); + return shader; + } + + @Override + public void setupRendering() { + program = GLES20.glCreateProgram(); + int vertexShader = loadShader(GLES20.GL_VERTEX_SHADER, VERTEX_SHADER); + int fragmentShader = loadShader(GLES20.GL_FRAGMENT_SHADER, FRAGMENT_SHADER); + GLES20.glAttachShader(program, vertexShader); + GLES20.glAttachShader(program, fragmentShader); + GLES20.glLinkProgram(program); + positionHandle = GLES20.glGetAttribLocation(program, "vPosition"); + transformMatrixHandle = GLES20.glGetUniformLocation(program, "uTransformMatrix"); + } + + @Override + public void renderResult(HandsResult result, ResultGlBoundary boundary) { + if (result == null) { + return; + } + GLES20.glUseProgram(program); + // Sets the transform matrix to align the result rendering with the scaled output texture. + // Also flips the rendering vertically since OpenGL assumes the coordinate origin is at the + // bottom-left corner, whereas MediaPipe landmark data assumes the coordinate origin is at the + // top-left corner. + Matrix.setIdentityM(transformMatrix, 0); + Matrix.scaleM( + transformMatrix, + 0, + 2 / (boundary.right() - boundary.left()), + -2 / (boundary.top() - boundary.bottom()), + 1.0f); + GLES20.glUniformMatrix4fv(transformMatrixHandle, 1, false, transformMatrix, 0); + GLES20.glLineWidth(CONNECTION_THICKNESS); + + int numHands = result.multiHandLandmarks().size(); + for (int i = 0; i < numHands; ++i) { + drawLandmarks(result.multiHandLandmarks().get(i).getLandmarkList()); + } + } + + /** + * Calls this to delete the shader program. + * + *

This is only necessary if one wants to release the program while keeping the context around. + */ + public void release() { + GLES20.glDeleteProgram(program); + } + + // TODO: Better hand landmark and hand connection drawing. + private void drawLandmarks(List handLandmarkList) { + for (Hands.Connection c : Hands.HAND_CONNECTIONS) { + float[] vertex = new float[4]; + NormalizedLandmark start = handLandmarkList.get(c.start()); + vertex[0] = normalizedLandmarkValue(start.getX()); + vertex[1] = normalizedLandmarkValue(start.getY()); + NormalizedLandmark end = handLandmarkList.get(c.end()); + vertex[2] = normalizedLandmarkValue(end.getX()); + vertex[3] = normalizedLandmarkValue(end.getY()); + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(vertex.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(vertex); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2); + } + } + + // Normalizes the value from the landmark value range:[0, 1] to the standard OpenGL coordinate + // value range: [-1, 1]. + private float normalizedLandmarkValue(float value) { + return value * 2 - 1; + } +} diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java new file mode 100644 index 000000000..d4052d4e9 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java @@ -0,0 +1,102 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.hands; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import com.google.mediapipe.formats.proto.LandmarkProto; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutions.hands.Hands; +import com.google.mediapipe.solutions.hands.HandsResult; +import java.util.List; + +/** An ImageView implementation for displaying MediaPipe Hands results. */ +public class HandsResultImageView extends AppCompatImageView { + private static final String TAG = "HandsResultImageView"; + + private static final int LANDMARK_COLOR = Color.RED; + private static final int LANDMARK_RADIUS = 15; + private static final int CONNECTION_COLOR = Color.GREEN; + private static final int CONNECTION_THICKNESS = 10; + private Bitmap latest; + + public HandsResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets a {@link HandsResult} to render. 
+ * + * @param result a {@link HandsResult} object that contains the solution outputs and the input + * {@link Bitmap}. + */ + public void setHandsResult(HandsResult result) { + if (result == null) { + return; + } + Bitmap bmInput = result.inputBitmap(); + int width = bmInput.getWidth(); + int height = bmInput.getHeight(); + latest = Bitmap.createBitmap(width, height, bmInput.getConfig()); + Canvas canvas = new Canvas(latest); + + canvas.drawBitmap(bmInput, new Matrix(), null); + int numHands = result.multiHandLandmarks().size(); + for (int i = 0; i < numHands; ++i) { + drawLandmarksOnCanvas( + result.multiHandLandmarks().get(i).getLandmarkList(), canvas, width, height); + } + } + + /** Updates the image view with the latest hands result. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + // TODO: Better hand landmark and hand connection drawing. + private void drawLandmarksOnCanvas( + List handLandmarkList, Canvas canvas, int width, int height) { + // Draw connections. + for (Hands.Connection c : Hands.HAND_CONNECTIONS) { + Paint connectionPaint = new Paint(); + connectionPaint.setColor(CONNECTION_COLOR); + connectionPaint.setStrokeWidth(CONNECTION_THICKNESS); + NormalizedLandmark start = handLandmarkList.get(c.start()); + NormalizedLandmark end = handLandmarkList.get(c.end()); + canvas.drawLine( + start.getX() * width, + start.getY() * height, + end.getX() * width, + end.getY() * height, + connectionPaint); + } + Paint landmarkPaint = new Paint(); + landmarkPaint.setColor(LANDMARK_COLOR); + // Draw landmarks. + for (LandmarkProto.NormalizedLandmark landmark : handLandmarkList) { + canvas.drawCircle( + landmark.getX() * width, landmark.getY() * height, LANDMARK_RADIUS, landmarkPaint); + } + } +} diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java new file mode 100644 index 000000000..379219942 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java @@ -0,0 +1,309 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package com.google.mediapipe.examples.hands;
+
+import android.content.Intent;
+import android.graphics.Bitmap;
+import android.os.Bundle;
+import android.provider.MediaStore;
+import android.util.Log;
+import android.view.View;
+import android.widget.Button;
+import android.widget.FrameLayout;
+import androidx.activity.result.ActivityResultLauncher;
+import androidx.activity.result.contract.ActivityResultContracts;
+import androidx.appcompat.app.AppCompatActivity;
+import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
+import com.google.mediapipe.solutioncore.CameraInput;
+import com.google.mediapipe.solutioncore.SolutionGlSurfaceView;
+import com.google.mediapipe.solutioncore.VideoInput;
+import com.google.mediapipe.solutions.hands.HandLandmark;
+import com.google.mediapipe.solutions.hands.Hands;
+import com.google.mediapipe.solutions.hands.HandsOptions;
+import com.google.mediapipe.solutions.hands.HandsResult;
+import java.io.IOException;
+
+/** Main activity of MediaPipe Hands app. */
+public class MainActivity extends AppCompatActivity {
+  private static final String TAG = "MainActivity";
+
+  private Hands hands;
+  // Run the pipeline and the model inference on GPU or CPU.
+  private static final boolean RUN_ON_GPU = true;
+
+  private enum InputSource {
+    UNKNOWN,
+    IMAGE,
+    VIDEO,
+    CAMERA,
+  }
+  private InputSource inputSource = InputSource.UNKNOWN;
+
+  // Image demo UI and image loader components.
+  private ActivityResultLauncher<Intent> imageGetter;
+  private HandsResultImageView imageView;
+  // Video demo UI and video loader components.
+  private VideoInput videoInput;
+  private ActivityResultLauncher<Intent> videoGetter;
+  // Live camera demo UI and camera components.
+  private CameraInput cameraInput;
+  private SolutionGlSurfaceView<HandsResult> glSurfaceView;
+
+  @Override
+  protected void onCreate(Bundle savedInstanceState) {
+    super.onCreate(savedInstanceState);
+    setContentView(R.layout.activity_main);
+    setupStaticImageDemoUiComponents();
+    setupVideoDemoUiComponents();
+    setupLiveDemoUiComponents();
+  }
+
+  @Override
+  protected void onResume() {
+    super.onResume();
+    if (inputSource == InputSource.CAMERA) {
+      // Restarts the camera and the opengl surface rendering.
+      cameraInput = new CameraInput(this);
+      cameraInput.setNewFrameListener(textureFrame -> hands.send(textureFrame));
+      glSurfaceView.post(this::startCamera);
+      glSurfaceView.setVisibility(View.VISIBLE);
+    } else if (inputSource == InputSource.VIDEO) {
+      videoInput.resume();
+    }
+  }
+
+  @Override
+  protected void onPause() {
+    super.onPause();
+    if (inputSource == InputSource.CAMERA) {
+      glSurfaceView.setVisibility(View.GONE);
+      cameraInput.close();
+    } else if (inputSource == InputSource.VIDEO) {
+      videoInput.pause();
+    }
+  }
+
+  /** Sets up the UI components for the static image demo. */
+  private void setupStaticImageDemoUiComponents() {
+    // The Intent to access gallery and read images as bitmap.
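+    // Flow of the block below: registerForActivityResult() wires up the gallery picker; when the
+    // picker returns RESULT_OK, the selected content URI is decoded into a Bitmap through
+    // MediaStore and handed to MediaPipe Hands via hands.send(bitmap) in static image mode.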
+ imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData()); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + if (bitmap != null) { + hands.send(bitmap); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + stopCurrentPipeline(); + setupStaticImageModePipeline(); + } + // Reads images from gallery. + Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI); + imageGetter.launch(gallery); + }); + imageView = new HandsResultImageView(this); + } + + /** The core MediaPipe Hands setup workflow for its static image mode. */ + private void setupStaticImageModePipeline() { + this.inputSource = InputSource.IMAGE; + // Initializes a new MediaPipe Hands instance in the static image mode. + hands = + new Hands( + this, + HandsOptions.builder() + .setMode(HandsOptions.STATIC_IMAGE_MODE) + .setMaxNumHands(1) + .setRunOnGpu(RUN_ON_GPU) + .build()); + + // Connects MediaPipe Hands to the user-defined HandsResultImageView. + hands.setResultListener( + handsResult -> { + logWristLandmark(handsResult, /*showPixelValues=*/ true); + imageView.setHandsResult(handsResult); + runOnUiThread(() -> imageView.update()); + }); + hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + /** Sets up the UI components for the video demo. */ + private void setupVideoDemoUiComponents() { + // The Intent to access gallery and read a video file. + videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + hands.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); + Button loadVideoButton = findViewById(R.id.button_load_video); + loadVideoButton.setOnClickListener( + v -> { + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.VIDEO); + // Reads video from gallery. + Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Video.Media.INTERNAL_CONTENT_URI); + videoGetter.launch(gallery); + }); + } + + /** Sets up the UI components for the live demo with camera input. */ + private void setupLiveDemoUiComponents() { + Button startCameraButton = findViewById(R.id.button_start_camera); + startCameraButton.setOnClickListener( + v -> { + if (inputSource == InputSource.CAMERA) { + return; + } + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.CAMERA); + }); + } + + /** The core MediaPipe Hands setup workflow for its streaming mode. */ + private void setupStreamingModePipeline(InputSource inputSource) { + this.inputSource = inputSource; + // Initializes a new MediaPipe Hands instance in the streaming mode. 
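+    // In streaming mode, camera or video frames are pushed in continuously via hands.send(), and
+    // each HandsResult is delivered to the result listener registered further below.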
+ hands = + new Hands( + this, + HandsOptions.builder() + .setMode(HandsOptions.STREAMING_MODE) + .setMaxNumHands(1) + .setRunOnGpu(RUN_ON_GPU) + .build()); + hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + + if (inputSource == InputSource.CAMERA) { + // Initializes a new CameraInput instance and connects it to MediaPipe Hands. + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> hands.send(textureFrame)); + } else if (inputSource == InputSource.VIDEO) { + // Initializes a new VideoInput instance and connects it to MediaPipe Hands. + videoInput = new VideoInput(this); + videoInput.setNewFrameListener(textureFrame -> hands.send(textureFrame)); + } + + // Initializes a new Gl surface view with a user-defined HandsResultGlRenderer. + glSurfaceView = + new SolutionGlSurfaceView<>(this, hands.getGlContext(), hands.getGlMajorVersion()); + glSurfaceView.setSolutionResultRenderer(new HandsResultGlRenderer()); + glSurfaceView.setRenderInputImage(true); + hands.setResultListener( + handsResult -> { + logWristLandmark(handsResult, /*showPixelValues=*/ false); + glSurfaceView.setRenderData(handsResult); + glSurfaceView.requestRender(); + }); + + // The runnable to start camera after the gl surface view is attached. + // For video input source, videoInput.start() will be called when the video uri is available. + if (inputSource == InputSource.CAMERA) { + glSurfaceView.post(this::startCamera); + } + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + imageView.setVisibility(View.GONE); + frameLayout.removeAllViewsInLayout(); + frameLayout.addView(glSurfaceView); + glSurfaceView.setVisibility(View.VISIBLE); + frameLayout.requestLayout(); + } + + private void startCamera() { + cameraInput.start( + this, + hands.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight()); + } + + private void stopCurrentPipeline() { + if (cameraInput != null) { + cameraInput.setNewFrameListener(null); + cameraInput.close(); + } + if (videoInput != null) { + videoInput.setNewFrameListener(null); + videoInput.close(); + } + if (glSurfaceView != null) { + glSurfaceView.setVisibility(View.GONE); + } + if (hands != null) { + hands.close(); + } + } + + private void logWristLandmark(HandsResult result, boolean showPixelValues) { + NormalizedLandmark wristLandmark = Hands.getHandLandmark(result, 0, HandLandmark.WRIST); + // For Bitmaps, show the pixel values. For texture inputs, show the normalized coordinates. 
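+    // For example, a normalized wrist landmark of (0.5, 0.25) on a 1920x1080 input bitmap maps to
+    // pixel coordinates (960.0, 270.0) once scaled by the bitmap width and height.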
+ if (showPixelValues) { + int width = result.inputBitmap().getWidth(); + int height = result.inputBitmap().getHeight(); + Log.i( + TAG, + String.format( + "MediaPipe Hand wrist coordinates (pixel values): x=%f, y=%f", + wristLandmark.getX() * width, wristLandmark.getY() * height)); + } else { + Log.i( + TAG, + String.format( + "MediaPipe Hand wrist normalized coordinates (value range: [0, 1]): x=%f, y=%f", + wristLandmark.getX(), wristLandmark.getY())); + } + } +} diff --git a/mediapipe/examples/android/solutions/hands/src/main/res b/mediapipe/examples/android/solutions/hands/src/main/res new file mode 120000 index 000000000..fc8850136 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/res @@ -0,0 +1 @@ +../../../res \ No newline at end of file diff --git a/mediapipe/examples/android/solutions/res/drawable-v24/ic_launcher_foreground.xml b/mediapipe/examples/android/solutions/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 000000000..c7bd21dbd --- /dev/null +++ b/mediapipe/examples/android/solutions/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/mediapipe/examples/android/solutions/res/drawable/ic_launcher_background.xml b/mediapipe/examples/android/solutions/res/drawable/ic_launcher_background.xml new file mode 100644 index 000000000..01f0af0ad --- /dev/null +++ b/mediapipe/examples/android/solutions/res/drawable/ic_launcher_background.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/solutions/res/layout/activity_main.xml b/mediapipe/examples/android/solutions/res/layout/activity_main.xml new file mode 100644 index 000000000..834e9a3e6 --- /dev/null +++ b/mediapipe/examples/android/solutions/res/layout/activity_main.xml @@ -0,0 +1,40 @@ + + + +