diff --git a/README.md b/README.md index 8cd42fa7b..2d9550d37 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Hair Segmentation [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | ✅ | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | @@ -79,6 +79,13 @@ run code search using ## Publications +* [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html) + in Google Developers Blog +* [Prosthesis control via Mirru App using MediaPipe hand tracking](https://developers.googleblog.com/2021/05/control-your-mirru-prosthesis-with-mediapipe-hand-tracking.html) + in Google Developers Blog +* [SignAll SDK: Sign language interface using MediaPipe is now available for + developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html) + in Google Developers Blog * [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) in Google AI Blog * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) diff --git a/WORKSPACE b/WORKSPACE index a1dcb4724..c2aaca658 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -16,11 +16,11 @@ bazel_skylib_workspace() load("@bazel_skylib//lib:versions.bzl", "versions") versions.check(minimum_bazel_version = "3.7.2") -# ABSL cpp library lts_2020_09_23 +# ABSL cpp library lts_2021_03_24, patch 2. http_archive( name = "com_google_absl", urls = [ - "https://github.com/abseil/abseil-cpp/archive/20200923.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20210324.2.tar.gz", ], # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. patches = [ @@ -29,8 +29,8 @@ http_archive( patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20200923", - sha256 = "b3744a4f7a249d5eaf2309daad597631ce77ea62e0fc6abffbab4b4c3dc0fc08" + strip_prefix = "abseil-cpp-20210324.2", + sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f" ) http_archive( @@ -53,19 +53,12 @@ rules_foreign_cc_dependencies() all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])""" # GoogleTest/GoogleMock framework. Used by most unit-tests. -# Last updated 2020-06-30. +# Last updated 2021-07-02. 
http_archive( name = "com_google_googletest", - urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"], - patches = [ - # fix for https://github.com/google/googletest/issues/2817 - "@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff" - ], - patch_args = [ - "-p1", - ], - strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e", - sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895", + urls = ["https://github.com/google/googletest/archive/4ec4cd23f486bf70efcc5d2caa40f24368f752e3.zip"], + strip_prefix = "googletest-4ec4cd23f486bf70efcc5d2caa40f24368f752e3", + sha256 = "de682ea824bfffba05b4e33b67431c247397d6175962534305136aa06f92e049", ) # Google Benchmark library. @@ -164,11 +157,11 @@ http_archive( http_archive( name = "pybind11", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz", - "https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.7.1.tar.gz", + "https://github.com/pybind/pybind11/archive/v2.7.1.tar.gz", ], - sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d", - strip_prefix = "pybind11-2.4.3", + sha256 = "616d1c42e4cf14fa27b2a4ff759d7d7b33006fdc5ad8fd603bb2c22622f27020", + strip_prefix = "pybind11-2.7.1", build_file = "@pybind11_bazel//:pybind11.BUILD", ) @@ -338,7 +331,10 @@ load("@rules_jvm_external//:defs.bzl", "maven_install") maven_install( artifacts = [ "androidx.concurrent:concurrent-futures:1.0.0-alpha03", - "androidx.lifecycle:lifecycle-common:2.2.0", + "androidx.lifecycle:lifecycle-common:2.3.1", + "androidx.activity:activity:1.2.2", + "androidx.exifinterface:exifinterface:1.3.3", + "androidx.fragment:fragment:1.3.4", "androidx.annotation:annotation:aar:1.1.0", "androidx.appcompat:appcompat:aar:1.1.0-rc01", "androidx.camera:camera-core:1.0.0-beta10", @@ -353,9 +349,12 @@ maven_install( "com.google.android.material:material:aar:1.0.0-rc01", "com.google.auto.value:auto-value:1.8.1", "com.google.auto.value:auto-value-annotations:1.8.1", - "com.google.code.findbugs:jsr305:3.0.2", - "com.google.flogger:flogger-system-backend:0.3.1", - "com.google.flogger:flogger:0.3.1", + "com.google.code.findbugs:jsr305:latest.release", + "com.google.android.datatransport:transport-api:3.0.0", + "com.google.android.datatransport:transport-backend-cct:3.1.0", + "com.google.android.datatransport:transport-runtime:3.1.0", + "com.google.flogger:flogger-system-backend:0.6", + "com.google.flogger:flogger:0.6", "com.google.guava:guava:27.0.1-android", "com.google.guava:listenablefuture:1.0", "junit:junit:4.12", @@ -383,9 +382,9 @@ http_archive( ) # Tensorflow repo should always go after the other external dependencies. -# 2021-06-07 -_TENSORFLOW_GIT_COMMIT = "700533808e6016dc458bb2eeecfca4babfc482ec" -_TENSORFLOW_SHA256 = "b6edd7f4039bfc19f3e77594ecff558ba620091d0dc48181484b3d9085026126" +# 2021-07-29 +_TENSORFLOW_GIT_COMMIT = "52a2905cbc21034766c08041933053178c5d10e3" +_TENSORFLOW_SHA256 = "06d4691bcdb700f3275fa0971a1585221c2b9f3dffe867963be565a6643d7f56" http_archive( name = "org_tensorflow", urls = [ @@ -394,6 +393,8 @@ http_archive( patches = [ "@//third_party:org_tensorflow_compatibility_fixes.diff", "@//third_party:org_tensorflow_objc_cxx17.diff", + # Diff is generated with a script, don't update it manually. 
+ "@//third_party:org_tensorflow_custom_ops.diff", ], patch_args = [ "-p1", @@ -406,3 +407,18 @@ load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") tf_workspace3() load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") tf_workspace2() + +# Edge TPU +http_archive( + name = "libedgetpu", + sha256 = "14d5527a943a25bc648c28a9961f954f70ba4d79c0a9ca5ae226e1831d72fe80", + strip_prefix = "libedgetpu-3164995622300286ef2bb14d7fdc2792dae045b7", + urls = [ + "https://github.com/google-coral/libedgetpu/archive/3164995622300286ef2bb14d7fdc2792dae045b7.tar.gz" + ], +) +load("@libedgetpu//:workspace.bzl", "libedgetpu_dependencies") +libedgetpu_dependencies() + +load("@coral_crosstool//:configure.bzl", "cc_crosstool") +cc_crosstool(name = "crosstool") diff --git a/docs/getting_started/android.md b/docs/getting_started/android.md index 71224a258..c3c6506ee 100644 --- a/docs/getting_started/android.md +++ b/docs/getting_started/android.md @@ -16,12 +16,14 @@ nav_order: 1 Please follow instructions below to build Android example apps in the supported MediaPipe [solutions](../solutions/solutions.md). To learn more about these -example apps, start from [Hello World! on Android](./hello_world_android.md). To -incorporate MediaPipe into an existing Android Studio project, see these -[instructions](./android_archive_library.md) that use Android Archive (AAR) and -Gradle. +example apps, start from [Hello World! on Android](./hello_world_android.md). -## Building Android example apps +To incorporate MediaPipe into Android Studio projects, see these +[instructions](./android_solutions.md) to use the MediaPipe Android Solution +APIs (currently in alpha) that are now available in +[Google's Maven Repository](https://maven.google.com/web/index.html?#com.google.mediapipe). + +## Building Android example apps with Bazel ### Prerequisite @@ -51,16 +53,6 @@ $YOUR_INTENDED_API_LEVEL` in android_ndk_repository() and/or android_sdk_repository() in the [`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) file. -Please verify all the necessary packages are installed. - -* Android SDK Platform API Level 28 or 29 -* Android SDK Build-Tools 28 or 29 -* Android SDK Platform-Tools 28 or 29 -* Android SDK Tools 26.1.1 -* Android NDK 19c or above - -### Option 1: Build with Bazel in Command Line - Tip: You can run this [script](https://github.com/google/mediapipe/blob/master/build_android_examples.sh) to build (and install) all MediaPipe Android example apps. @@ -84,108 +76,3 @@ to build (and install) all MediaPipe Android example apps. ```bash adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/handtrackinggpu.apk ``` - -### Option 2: Build with Bazel in Android Studio - -The MediaPipe project can be imported into Android Studio using the Bazel -plugins. This allows the MediaPipe examples to be built and modified in Android -Studio. - -To incorporate MediaPipe into an existing Android Studio project, see these -[instructions](./android_archive_library.md) that use Android Archive (AAR) and -Gradle. - -The steps below use Android Studio 3.5 to build and install a MediaPipe example -app: - -1. Install and launch Android Studio 3.5. - -2. Select `Configure` -> `SDK Manager` -> `SDK Platforms`. - - * Verify that Android SDK Platform API Level 28 or 29 is installed. - * Take note of the Android SDK Location, e.g., - `/usr/local/home/Android/Sdk`. - -3. Select `Configure` -> `SDK Manager` -> `SDK Tools`. 
- - * Verify that Android SDK Build-Tools 28 or 29 is installed. - * Verify that Android SDK Platform-Tools 28 or 29 is installed. - * Verify that Android SDK Tools 26.1.1 is installed. - * Verify that Android NDK 19c or above is installed. - * Take note of the Android NDK Location, e.g., - `/usr/local/home/Android/Sdk/ndk-bundle` or - `/usr/local/home/Android/Sdk/ndk/20.0.5594570`. - -4. Set environment variables `$ANDROID_HOME` and `$ANDROID_NDK_HOME` to point - to the installed SDK and NDK. - - ```bash - export ANDROID_HOME=/usr/local/home/Android/Sdk - - # If the NDK libraries are installed by a previous version of Android Studio, do - export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk-bundle - # If the NDK libraries are installed by Android Studio 3.5, do - export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk/ - ``` - -5. Select `Configure` -> `Plugins` to install `Bazel`. - -6. On Linux, select `File` -> `Settings` -> `Bazel settings`. On macos, select - `Android Studio` -> `Preferences` -> `Bazel settings`. Then, modify `Bazel - binary location` to be the same as the output of `$ which bazel`. - -7. Select `Import Bazel Project`. - - * Select `Workspace`: `/path/to/mediapipe` and select `Next`. - * Select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` and select - `Next`. - * Modify `Project View` to be the following and select `Finish`. - - ``` - directories: - # read project settings, e.g., .bazelrc - . - -mediapipe/objc - -mediapipe/examples/ios - - targets: - //mediapipe/examples/android/...:all - //mediapipe/java/...:all - - android_sdk_platform: android-29 - - sync_flags: - --host_crosstool_top=@bazel_tools//tools/cpp:toolchain - ``` - -8. Select `Bazel` -> `Sync` -> `Sync project with Build files`. - - Note: Even after doing step 4, if you still see the error: `"no such package - '@androidsdk//': Either the path attribute of android_sdk_repository or the - ANDROID_HOME environment variable must be set."`, please modify the - [`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) - file to point to your SDK and NDK library locations, as below: - - ``` - android_sdk_repository( - name = "androidsdk", - path = "/path/to/android/sdk" - ) - - android_ndk_repository( - name = "androidndk", - path = "/path/to/android/ndk" - ) - ``` - -9. Connect an Android device to the workstation. - -10. Select `Run...` -> `Edit Configurations...`. - - * Select `Templates` -> `Bazel Command`. - * Enter Target Expression: - `//mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu:handtrackinggpu` - * Enter Bazel command: `mobile-install`. - * Enter Bazel flags: `-c opt --config=android_arm64`. - * Press the `[+]` button to add the new configuration. - * Select `Run` to run the example app on the connected Android device. diff --git a/docs/getting_started/android_archive_library.md b/docs/getting_started/android_archive_library.md index ec34a8352..d2f25213f 100644 --- a/docs/getting_started/android_archive_library.md +++ b/docs/getting_started/android_archive_library.md @@ -3,7 +3,7 @@ layout: default title: MediaPipe Android Archive parent: MediaPipe on Android grand_parent: Getting Started -nav_order: 2 +nav_order: 3 --- # MediaPipe Android Archive @@ -113,9 +113,9 @@ each project. 
androidTestImplementation 'androidx.test.ext:junit:1.1.0' androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1' // MediaPipe deps - implementation 'com.google.flogger:flogger:0.3.1' - implementation 'com.google.flogger:flogger-system-backend:0.3.1' - implementation 'com.google.code.findbugs:jsr305:3.0.2' + implementation 'com.google.flogger:flogger:latest.release' + implementation 'com.google.flogger:flogger-system-backend:latest.release' + implementation 'com.google.code.findbugs:jsr305:latest.release' implementation 'com.google.guava:guava:27.0.1-android' implementation 'com.google.protobuf:protobuf-java:3.11.4' // CameraX core library diff --git a/docs/getting_started/android_solutions.md b/docs/getting_started/android_solutions.md new file mode 100644 index 000000000..9df98043f --- /dev/null +++ b/docs/getting_started/android_solutions.md @@ -0,0 +1,131 @@ +--- +layout: default +title: MediaPipe Android Solutions +parent: MediaPipe on Android +grand_parent: Getting Started +nav_order: 2 +--- + +# MediaPipe Android Solutions +{: .no_toc } + +1. TOC +{:toc} +--- + +MediaPipe Android Solution APIs (currently in alpha) are available in: + +* [MediaPipe Face Detection](../solutions/face_detection#android-solution-api) +* [MediaPipe Face Mesh](../solutions/face_mesh#android-solution-api) +* [MediaPipe Hands](../solutions/hands#android-solution-api) + +## Incorporation in Android Studio + +Prebuilt packages of Android Solution APIs can be found in +[Google's Maven Repository](https://maven.google.com/web/index.html?#com.google.mediapipe). +To incorporate them into an Android Studio project, add the following into the +project's Gradle dependencies: + +``` +dependencies { + // MediaPipe solution-core is the foundation of any MediaPipe Solutions. + implementation 'com.google.mediapipe:solution-core:latest.release' + // Optional: MediaPipe Face Detection Solution. + implementation 'com.google.mediapipe:facedetection:latest.release' + // Optional: MediaPipe Face Mesh Solution. + implementation 'com.google.mediapipe:facemesh:latest.release' + // Optional: MediaPipe Hands Solution. + implementation 'com.google.mediapipe:hands:latest.release' +} +``` + +If you need further customization, instead of using the prebuilt maven packages +consider building a MediaPipe Android Archive library locally from source by +following these [instructions](./android_archive_library.md). + +## Building solution example apps + +Detailed usage examples of the Android Solution APIs can be found in the +[source code](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions) +of the solution example apps. + +To build these apps: + +1. Open Android Studio Arctic Fox on Linux, macOS, or Windows. + +2. Import mediapipe/examples/android/solutions directory into Android Studio. + + ![Screenshot](../images/import_mp_android_studio_project.png) + +3. For Windows users, run `create_win_symlinks.bat` as administrator to create + res directory symlinks. + + ![Screenshot](../images/run_create_win_symlinks.png) + +4. Select "File" -> "Sync Project with Gradle Files" to sync project. + +5. Run solution example app in Android Studio. + + ![Screenshot](../images/run_android_solution_app.png) + +6. (Optional) Run solutions on CPU. + + MediaPipe solution example apps run the pipeline and model inference on GPU + by default. 
If needed, for example to run the apps on Android Emulator, set + the `RUN_ON_GPU` boolean variable to `false` in the app's + `MainActivity.java` to run the pipeline and model inference on CPU. + +## MediaPipe Solution APIs Terms of Service + +Last modified: November 12, 2021 + +Use of MediaPipe Solution APIs is subject to the +[Google APIs Terms of Service](https://developers.google.com/terms), +[Google API Services User Data Policy](https://developers.google.com/terms/api-services-user-data-policy), +and the terms below. Please check back from time to time as these terms and +policies are occasionally updated. + +**Privacy** + +When you use MediaPipe Solution APIs, processing of the input data (e.g. images, +video, text) fully happens on-device, and **MediaPipe does not send that input +data to Google servers**. As a result, you can use our APIs for processing data +that should not leave the device. + +MediaPipe Android Solution APIs will contact Google servers from time to time in +order to receive things like bug fixes, updated models, and hardware accelerator +compatibility information. MediaPipe Android Solution APIs also send metrics +about the performance and utilization of the APIs in your app to Google. Google +uses this metrics data to measure performance, API usage, debug, maintain and +improve the APIs, and detect misuse or abuse, as further described in our +[Privacy Policy](https://policies.google.com/privacy). + +**You are responsible for obtaining informed consent from your app users about +Google’s processing of MediaPipe metrics data as required by applicable law.** + +Data we collect may include the following, across all MediaPipe Android Solution +APIs: + +- Device information (such as manufacturer, model, OS version and build) and + available ML hardware accelerators (GPU and DSP). Used for diagnostics and + usage analytics. + +- App identification information (package name / bundle id, app version). Used + for diagnostics and usage analytics. + +- API configuration (such as image format, resolution, and MediaPipe version + used). Used for diagnostics and usage analytics. + +- Event type (such as initialize, download model, update, run, and detection). + Used for diagnostics and usage analytics. + +- Error codes. Used for diagnostics. + +- Performance metrics. Used for diagnostics. + +- Per-installation identifiers that do not uniquely identify a user or + physical device. Used for operation of remote configuration and usage + analytics. + +- Network request sender IP addresses. Used for remote configuration + diagnostics. Collected IP addresses are retained temporarily. diff --git a/docs/getting_started/faq.md b/docs/getting_started/faq.md index 75bf8ad97..c42ef898c 100644 --- a/docs/getting_started/faq.md +++ b/docs/getting_started/faq.md @@ -103,7 +103,7 @@ monotonically increasing timestamps. By convention, realtime calculators and graphs use the recording time or the presentation time as the timestamp for each packet, with each timestamp representing microseconds since `Jan/1/1970:00:00:00`. This allows packets from various sources to be processed -in a gloablly consistent order. +in a globally consistent order. Normally for offline processing, every input packet is processed and processing continues as long as necessary. 
For online processing, it is often necessary to diff --git a/docs/getting_started/hello_world_android.md b/docs/getting_started/hello_world_android.md index 9f277f799..6674d4023 100644 --- a/docs/getting_started/hello_world_android.md +++ b/docs/getting_started/hello_world_android.md @@ -31,8 +31,8 @@ stream on an Android device. ## Setup -1. Install MediaPipe on your system, see [MediaPipe installation guide] for - details. +1. Install MediaPipe on your system, see + [MediaPipe installation guide](./install.md) for details. 2. Install Android Development SDK and Android NDK. See how to do so also in [MediaPipe installation guide]. 3. Enable [developer options] on your Android device. @@ -770,7 +770,6 @@ If you ran into any issues, please see the full code of the tutorial [`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java [`FrameLayout`]:https://developer.android.com/reference/android/widget/FrameLayout [`FrameProcessor`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java -[MediaPipe installation guide]:./install.md [`PermissionHelper`]: https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java [`SurfaceHolder.Callback`]:https://developer.android.com/reference/android/view/SurfaceHolder.Callback.html [`SurfaceView`]:https://developer.android.com/reference/android/view/SurfaceView diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md index 06d79c67d..4591b5f33 100644 --- a/docs/getting_started/hello_world_ios.md +++ b/docs/getting_started/hello_world_ios.md @@ -31,8 +31,8 @@ stream on an iOS device. ## Setup -1. Install MediaPipe on your system, see [MediaPipe installation guide] for - details. +1. Install MediaPipe on your system, see + [MediaPipe installation guide](./install.md) for details. 2. Setup your iOS device for development. 3. Setup [Bazel] on your system to build and deploy the iOS app. @@ -113,6 +113,10 @@ bazel to build the iOS application. The content of the 5. `Main.storyboard` and `Launch.storyboard` 6. `Assets.xcassets` directory. +Note: In newer versions of Xcode, you may see additional files `SceneDelegate.h` +and `SceneDelegate.m`. Make sure to copy them too and add them to the `BUILD` +file mentioned below. + Copy these files to a directory named `HelloWorld` to a location that can access the MediaPipe source code. For example, the source code of the application that we will build in this tutorial is located in @@ -247,6 +251,12 @@ We need to get frames from the `_cameraSource` into our application `MPPInputSourceDelegate`. So our application `ViewController` can be a delegate of `_cameraSource`. +Update the interface definition of `ViewController` accordingly: + +``` +@interface ViewController () +``` + To handle camera setup and process incoming frames, we should use a queue different from the main queue. Add the following to the implementation block of the `ViewController`: @@ -288,6 +298,12 @@ utility called `MPPLayerRenderer` to display images on the screen. This utility can be used to display `CVPixelBufferRef` objects, which is the type of the images provided by `MPPCameraInputSource` to its delegates. 
+In `ViewController.m`, add the following import line: + +``` +#import "mediapipe/objc/MPPLayerRenderer.h" +``` + To display images of the screen, we need to add a new `UIView` object called `_liveView` to the `ViewController`. @@ -411,6 +427,12 @@ Objective-C++. ### Use the graph in `ViewController` +In `ViewController.m`, add the following import line: + +``` +#import "mediapipe/objc/MPPGraph.h" +``` + Declare a static constant with the name of the graph, the input stream and the output stream: @@ -549,6 +571,12 @@ method to receive packets on this output stream and display them on the screen: } ``` +Update the interface definition of `ViewController` with `MPPGraphDelegate`: + +``` +@interface ViewController () +``` + And that is all! Build and run the app on your iOS device. You should see the results of running the edge detection graph on a live video feed. Congrats! @@ -560,6 +588,5 @@ appropriate `BUILD` file dependencies for the edge detection graph. [Bazel]:https://bazel.build/ [`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt -[MediaPipe installation guide]:./install.md -[common]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common) -[helloworld]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld) +[common]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common +[helloworld]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index 95dce1d17..bb2539d33 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -43,104 +43,189 @@ install --user six`. 3. Install OpenCV and FFmpeg. - Option 1. Use package manager tool to install the pre-compiled OpenCV - libraries. FFmpeg will be installed via libopencv-video-dev. + **Option 1**. Use package manager tool to install the pre-compiled OpenCV + libraries. FFmpeg will be installed via `libopencv-video-dev`. - Note: Debian 9 and Ubuntu 16.04 provide OpenCV 2.4.9. You may want to take - option 2 or 3 to install OpenCV 3 or above. + OS | OpenCV + -------------------- | ------ + Debian 9 (stretch) | 2.4 + Debian 10 (buster) | 3.2 + Debian 11 (bullseye) | 4.5 + Ubuntu 16.04 LTS | 2.4 + Ubuntu 18.04 LTS | 3.2 + Ubuntu 20.04 LTS | 4.2 + Ubuntu 20.04 LTS | 4.2 + Ubuntu 21.04 | 4.5 ```bash - $ sudo apt-get install libopencv-core-dev libopencv-highgui-dev \ - libopencv-calib3d-dev libopencv-features2d-dev \ - libopencv-imgproc-dev libopencv-video-dev + $ sudo apt-get install -y \ + libopencv-core-dev \ + libopencv-highgui-dev \ + libopencv-calib3d-dev \ + libopencv-features2d-dev \ + libopencv-imgproc-dev \ + libopencv-video-dev ``` - Debian 9 and Ubuntu 18.04 install the packages in - `/usr/lib/x86_64-linux-gnu`. MediaPipe's [`opencv_linux.BUILD`] and - [`ffmpeg_linux.BUILD`] are configured for this library path. Ubuntu 20.04 - may install the OpenCV and FFmpeg packages in `/usr/local`, Please follow - the option 3 below to modify the [`WORKSPACE`], [`opencv_linux.BUILD`] and - [`ffmpeg_linux.BUILD`] files accordingly. 
- - Moreover, for Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, the - library path needs to be modified like the following: + MediaPipe's [`opencv_linux.BUILD`] and [`WORKSPACE`] are already configured + for OpenCV 2/3 and should work correctly on any architecture: ```bash - sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD + # WORKSPACE + new_local_repository( + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr", + ) + + # opencv_linux.BUILD for OpenCV 2/3 installed from Debian package + cc_library( + name = "opencv", + linkopts = [ + "-l:libopencv_core.so", + "-l:libopencv_calib3d.so", + "-l:libopencv_features2d.so", + "-l:libopencv_highgui.so", + "-l:libopencv_imgcodecs.so", + "-l:libopencv_imgproc.so", + "-l:libopencv_video.so", + "-l:libopencv_videoio.so", + ], + ) ``` - Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source - and modify MediaPipe's OpenCV config. + For OpenCV 4 you need to modify [`opencv_linux.BUILD`] taking into account + current architecture: - Option 3. Follow OpenCV's + ```bash + # WORKSPACE + new_local_repository( + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr", + ) + + # opencv_linux.BUILD for OpenCV 4 installed from Debian package + cc_library( + name = "opencv", + hdrs = glob([ + # Uncomment according to your multiarch value (gcc -print-multiarch): + # "include/aarch64-linux-gnu/opencv4/opencv2/cvconfig.h", + # "include/arm-linux-gnueabihf/opencv4/opencv2/cvconfig.h", + # "include/x86_64-linux-gnu/opencv4/opencv2/cvconfig.h", + "include/opencv4/opencv2/**/*.h*", + ]), + includes = [ + # Uncomment according to your multiarch value (gcc -print-multiarch): + # "include/aarch64-linux-gnu/opencv4/", + # "include/arm-linux-gnueabihf/opencv4/", + # "include/x86_64-linux-gnu/opencv4/", + "include/opencv4/", + ], + linkopts = [ + "-l:libopencv_core.so", + "-l:libopencv_calib3d.so", + "-l:libopencv_features2d.so", + "-l:libopencv_highgui.so", + "-l:libopencv_imgcodecs.so", + "-l:libopencv_imgproc.so", + "-l:libopencv_video.so", + "-l:libopencv_videoio.so", + ], + ) + ``` + + **Option 2**. Run [`setup_opencv.sh`] to automatically build OpenCV from + source and modify MediaPipe's OpenCV config. This option will do all steps + defined in Option 3 automatically. + + **Option 3**. Follow OpenCV's [documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html) to manually build OpenCV from source code. - Note: You may need to modify [`WORKSPACE`], [`opencv_linux.BUILD`] and - [`ffmpeg_linux.BUILD`] to point MediaPipe to your own OpenCV and FFmpeg - libraries. For example if OpenCV and FFmpeg are both manually installed in - "/usr/local/", you will need to update: (1) the "linux_opencv" and - "linux_ffmpeg" new_local_repository rules in [`WORKSPACE`], (2) the "opencv" - cc_library rule in [`opencv_linux.BUILD`], and (3) the "libffmpeg" - cc_library rule in [`ffmpeg_linux.BUILD`]. These 3 changes are shown below: + You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to point + MediaPipe to your own OpenCV libraries. Assume OpenCV would be installed to + `/usr/local/` which is recommended by default. 
+ + OpenCV 2/3 setup: ```bash + # WORKSPACE new_local_repository( - name = "linux_opencv", - build_file = "@//third_party:opencv_linux.BUILD", - path = "/usr/local", + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr/local", ) + # opencv_linux.BUILD for OpenCV 2/3 installed to /usr/local + cc_library( + name = "opencv", + linkopts = [ + "-L/usr/local/lib", + "-l:libopencv_core.so", + "-l:libopencv_calib3d.so", + "-l:libopencv_features2d.so", + "-l:libopencv_highgui.so", + "-l:libopencv_imgcodecs.so", + "-l:libopencv_imgproc.so", + "-l:libopencv_video.so", + "-l:libopencv_videoio.so", + ], + ) + ``` + + OpenCV 4 setup: + + ```bash + # WORKSPACE new_local_repository( - name = "linux_ffmpeg", - build_file = "@//third_party:ffmpeg_linux.BUILD", - path = "/usr/local", + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr/local", ) + # opencv_linux.BUILD for OpenCV 4 installed to /usr/local cc_library( - name = "opencv", - srcs = glob( - [ - "lib/libopencv_core.so", - "lib/libopencv_highgui.so", - "lib/libopencv_imgcodecs.so", - "lib/libopencv_imgproc.so", - "lib/libopencv_video.so", - "lib/libopencv_videoio.so", - ], - ), - hdrs = glob([ - # For OpenCV 3.x - "include/opencv2/**/*.h*", - # For OpenCV 4.x - # "include/opencv4/opencv2/**/*.h*", - ]), - includes = [ - # For OpenCV 3.x - "include/", - # For OpenCV 4.x - # "include/opencv4/", - ], - linkstatic = 1, - visibility = ["//visibility:public"], + name = "opencv", + hdrs = glob([ + "include/opencv4/opencv2/**/*.h*", + ]), + includes = [ + "include/opencv4/", + ], + linkopts = [ + "-L/usr/local/lib", + "-l:libopencv_core.so", + "-l:libopencv_calib3d.so", + "-l:libopencv_features2d.so", + "-l:libopencv_highgui.so", + "-l:libopencv_imgcodecs.so", + "-l:libopencv_imgproc.so", + "-l:libopencv_video.so", + "-l:libopencv_videoio.so", + ], + ) + ``` + + Current FFmpeg setup is defined in [`ffmpeg_linux.BUILD`] and should work + for any architecture: + + ```bash + # WORKSPACE + new_local_repository( + name = "linux_ffmpeg", + build_file = "@//third_party:ffmpeg_linux.BUILD", + path = "/usr" ) + # ffmpeg_linux.BUILD for FFmpeg installed from Debian package cc_library( - name = "libffmpeg", - srcs = glob( - [ - "lib/libav*.so", - ], - ), - hdrs = glob(["include/libav*/*.h"]), - includes = ["include"], - linkopts = [ - "-lavcodec", - "-lavformat", - "-lavutil", - ], - linkstatic = 1, - visibility = ["//visibility:public"], + name = "libffmpeg", + linkopts = [ + "-l:libavcodec.so", + "-l:libavformat.so", + "-l:libavutil.so", + ], ) ``` @@ -711,7 +796,7 @@ This will use a Docker image that will isolate mediapipe's installation from the ```bash $ docker run -it --name mediapipe mediapipe:latest - root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazel run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world + root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazelisk run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world # Should print: # Hello World! 
diff --git a/docs/getting_started/javascript.md b/docs/getting_started/javascript.md index 98a4f19bc..f56abcd6e 100644 --- a/docs/getting_started/javascript.md +++ b/docs/getting_started/javascript.md @@ -22,12 +22,23 @@ Solution | NPM Package | Example [Face Detection][Fd-pg] | [@mediapipe/face_detection][Fd-npm] | [mediapipe.dev/demo/face_detection][Fd-demo] [Hands][H-pg] | [@mediapipe/hands][H-npm] | [mediapipe.dev/demo/hands][H-demo] [Holistic][Ho-pg] | [@mediapipe/holistic][Ho-npm] | [mediapipe.dev/demo/holistic][Ho-demo] +[Objectron][Ob-pg] | [@mediapipe/objectron][Ob-npm] | [mediapipe.dev/demo/objectron][Ob-demo] [Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo] [Selfie Segmentation][S-pg] | [@mediapipe/selfie_segmentation][S-npm] | [mediapipe.dev/demo/selfie_segmentation][S-demo] Click on a solution link above for more information, including API and code snippets. +### Supported plaforms: + +| Browser | Platform | Notes | +| ------- | ----------------------- | -------------------------------------- | +| Chrome | Android / Windows / Mac | Pixel 4 and older unsupported. Fuschia | +| | | unsupported. | +| Chrome | iOS | Camera unavailable in Chrome on iOS. | +| Safari | iPad/iPhone/Mac | iOS and Safari on iPad / iPhone / | +| | | MacBook | + The quickest way to get acclimated is to look at the examples above. Each demo has a link to a [CodePen][codepen] so that you can edit the code and try it yourself. We have included a number of utility packages to help you get started: @@ -67,33 +78,24 @@ affecting your work, restrict your request to a `` number. e.g., [F-pg]: ../solutions/face_mesh#javascript-solution-api [Fd-pg]: ../solutions/face_detection#javascript-solution-api [H-pg]: ../solutions/hands#javascript-solution-api +[Ob-pg]: ../solutions/objectron#javascript-solution-api [P-pg]: ../solutions/pose#javascript-solution-api [S-pg]: ../solutions/selfie_segmentation#javascript-solution-api [Ho-npm]: https://www.npmjs.com/package/@mediapipe/holistic [F-npm]: https://www.npmjs.com/package/@mediapipe/face_mesh [Fd-npm]: https://www.npmjs.com/package/@mediapipe/face_detection [H-npm]: https://www.npmjs.com/package/@mediapipe/hands +[Ob-npm]: https://www.npmjs.com/package/@mediapipe/objectron [P-npm]: https://www.npmjs.com/package/@mediapipe/pose [S-npm]: https://www.npmjs.com/package/@mediapipe/selfie_segmentation [draw-npm]: https://www.npmjs.com/package/@mediapipe/drawing_utils [cam-npm]: https://www.npmjs.com/package/@mediapipe/camera_utils [ctrl-npm]: https://www.npmjs.com/package/@mediapipe/control_utils -[Ho-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/holistic -[F-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_mesh -[Fd-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_detection -[H-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/hands -[P-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/pose -[P-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/selfie_segmentation -[Ho-pen]: https://code.mediapipe.dev/codepen/holistic -[F-pen]: https://code.mediapipe.dev/codepen/face_mesh -[Fd-pen]: https://code.mediapipe.dev/codepen/face_detection -[H-pen]: https://code.mediapipe.dev/codepen/hands -[P-pen]: https://code.mediapipe.dev/codepen/pose -[S-pen]: https://code.mediapipe.dev/codepen/selfie_segmentation [Ho-demo]: https://mediapipe.dev/demo/holistic [F-demo]: https://mediapipe.dev/demo/face_mesh [Fd-demo]: https://mediapipe.dev/demo/face_detection [H-demo]: https://mediapipe.dev/demo/hands +[Ob-demo]: 
https://mediapipe.dev/demo/objectron [P-demo]: https://mediapipe.dev/demo/pose [S-demo]: https://mediapipe.dev/demo/selfie_segmentation [npm]: https://www.npmjs.com/package/@mediapipe diff --git a/docs/getting_started/python_framework.md b/docs/getting_started/python_framework.md index ece14bc91..688285d87 100644 --- a/docs/getting_started/python_framework.md +++ b/docs/getting_started/python_framework.md @@ -74,7 +74,7 @@ Mapping\[str, Packet\] | std::map | create_st np.ndarray
(cv.mat and PIL.Image) | mp::ImageFrame | create_image_frame(
        format=ImageFormat.SRGB,
        data=mat) | get_image_frame(packet) np.ndarray | mp::Matrix | create_matrix(data) | get_matrix(packet) Google Proto Message | Google Proto Message | create_proto(proto) | get_proto(packet) -List\[Proto\] | std::vector\ | create_proto_vector(proto_list) | get_proto_list(packet) +List\[Proto\] | std::vector\ | n/a | get_proto_list(packet) It's not uncommon that users create custom C++ classes and and send those into the graphs and calculators. To allow the custom classes to be used in Python diff --git a/docs/images/attention_mesh_architecture.png b/docs/images/attention_mesh_architecture.png new file mode 100644 index 000000000..3a38de5c9 Binary files /dev/null and b/docs/images/attention_mesh_architecture.png differ diff --git a/docs/images/import_mp_android_studio_project.png b/docs/images/import_mp_android_studio_project.png new file mode 100644 index 000000000..aa02b95ce Binary files /dev/null and b/docs/images/import_mp_android_studio_project.png differ diff --git a/docs/images/mobile/pose_segmentation.mp4 b/docs/images/mobile/pose_segmentation.mp4 new file mode 100644 index 000000000..e0a68da70 Binary files /dev/null and b/docs/images/mobile/pose_segmentation.mp4 differ diff --git a/docs/images/mobile/pose_tracking_pck_chart.png b/docs/images/mobile/pose_tracking_pck_chart.png index 8b781e630..1fa4bf97d 100644 Binary files a/docs/images/mobile/pose_tracking_pck_chart.png and b/docs/images/mobile/pose_tracking_pck_chart.png differ diff --git a/docs/images/run_android_solution_app.png b/docs/images/run_android_solution_app.png new file mode 100644 index 000000000..aa21f3c24 Binary files /dev/null and b/docs/images/run_android_solution_app.png differ diff --git a/docs/images/run_create_win_symlinks.png b/docs/images/run_create_win_symlinks.png new file mode 100644 index 000000000..69b94b75f Binary files /dev/null and b/docs/images/run_create_win_symlinks.png differ diff --git a/docs/index.md b/docs/index.md index d3a22bd22..86d6ddc5e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -45,7 +45,7 @@ Hair Segmentation [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | ✅ | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | @@ -79,6 +79,13 @@ run code search using ## Publications +* [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html) + in Google Developers Blog +* [Prosthesis control via Mirru App using MediaPipe hand tracking](https://developers.googleblog.com/2021/05/control-your-mirru-prosthesis-with-mediapipe-hand-tracking.html) + in Google Developers Blog +* [SignAll SDK: Sign language interface using MediaPipe is now available for + developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html) + in Google Developers Blog * [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on 
Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) in Google AI Blog * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 9d08ee482..04d429987 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -121,12 +121,10 @@ with mp_face_detection.FaceDetection( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = face_detection.process(image) # Draw the face detection annotations on the image. @@ -135,7 +133,8 @@ with mp_face_detection.FaceDetection( if results.detections: for detection in results.detections: mp_drawing.draw_detection(image, detection) - cv2.imshow('MediaPipe Face Detection', image) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Face Detection', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() @@ -200,7 +199,7 @@ const faceDetection = new FaceDetection({locateFile: (file) => { return `https://cdn.jsdelivr.net/npm/@mediapipe/face_detection@0.0/${file}`; }}); faceDetection.setOptions({ - modelSelection: 0 + modelSelection: 0, minDetectionConfidence: 0.5 }); faceDetection.onResults(onResults); @@ -216,6 +215,214 @@ camera.start(); ``` +### Android Solution API + +Please first follow general +[instructions](../getting_started/android_solutions.md) to add MediaPipe Gradle +dependencies and try the Android Solution API in the companion +[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facedetection), +and learn more in the usage example below. + +Supported configuration options: + +* [staticImageMode](#static_image_mode) +* [modelSelection](#model_selection) + +#### Camera Input + +```java +// For camera input and result rendering with OpenGL. +FaceDetectionOptions faceDetectionOptions = + FaceDetectionOptions.builder() + .setStaticImageMode(false) + .setModelSelection(0).build(); +FaceDetection faceDetection = new FaceDetection(this, faceDetectionOptions); +faceDetection.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); + +// Initializes a new CameraInput instance and connects it to MediaPipe Face Detection Solution. +CameraInput cameraInput = new CameraInput(this); +cameraInput.setNewFrameListener( + textureFrame -> faceDetection.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java +// as an example. 
+SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, faceDetection.getGlContext(), faceDetection.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new FaceDetectionResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); +faceDetection.setResultListener( + faceDetectionResult -> { + if (faceDetectionResult.multiFaceDetections().isEmpty()) { + return; + } + RelativeKeypoint noseTip = + faceDetectionResult + .multiFaceDetections() + .get(0) + .getLocationData() + .getRelativeKeypoints(FaceKeypoint.NOSE_TIP); + Log.i( + TAG, + String.format( + "MediaPipe Face Detection nose tip normalized coordinates (value range: [0, 1]): x=%f, y=%f", + noseTip.getX(), noseTip.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(faceDetectionResult); + glSurfaceView.requestRender(); + }); + +// The runnable to start camera after the GLSurfaceView is attached. +glSurfaceView.post( + () -> + cameraInput.start( + this, + faceDetection.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); +``` + +#### Image Input + +```java +// For reading images from gallery and drawing the output in an ImageView. +FaceDetectionOptions faceDetectionOptions = + FaceDetectionOptions.builder() + .setStaticImageMode(true) + .setModelSelection(0).build(); +FaceDetection faceDetection = new FaceDetection(this, faceDetectionOptions); + +// Connects MediaPipe Face Detection Solution to the user-defined ImageView +// instance that allows users to have the custom drawing of the output landmarks +// on it. See mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java +// as an example. +FaceDetectionResultImageView imageView = new FaceDetectionResultImageView(this); +faceDetection.setResultListener( + faceDetectionResult -> { + if (faceDetectionResult.multiFaceDetections().isEmpty()) { + return; + } + int width = faceDetectionResult.inputBitmap().getWidth(); + int height = faceDetectionResult.inputBitmap().getHeight(); + RelativeKeypoint noseTip = + faceDetectionResult + .multiFaceDetections() + .get(0) + .getLocationData() + .getRelativeKeypoints(FaceKeypoint.NOSE_TIP); + Log.i( + TAG, + String.format( + "MediaPipe Face Detection nose tip coordinates (pixel values): x=%f, y=%f", + noseTip.getX() * width, noseTip.getY() * height)); + // Request canvas drawing. + imageView.setFaceDetectionResult(faceDetectionResult); + runOnUiThread(() -> imageView.update()); + }); +faceDetection.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); + +// ActivityResultLauncher to get an image from the gallery as Bitmap. +ActivityResultLauncher imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null && result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData()); + // Please also rotate the Bitmap based on its orientation. 
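                // A minimal sketch of the rotation step, based on assumptions and not part of
                // the original example: read the EXIF orientation with androidx.exifinterface
                // (added to the WORKSPACE maven_install in this change) and redraw the Bitmap
                // accordingly. Assumes imports of androidx.exifinterface.media.ExifInterface,
                // android.graphics.Matrix and java.io.InputStream.
                try (InputStream exifStream =
                    this.getContentResolver().openInputStream(resultIntent.getData())) {
                  ExifInterface exif = new ExifInterface(exifStream);
                  int orientation =
                      exif.getAttributeInt(
                          ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
                  float degrees = 0;
                  if (orientation == ExifInterface.ORIENTATION_ROTATE_90) {
                    degrees = 90;
                  } else if (orientation == ExifInterface.ORIENTATION_ROTATE_180) {
                    degrees = 180;
                  } else if (orientation == ExifInterface.ORIENTATION_ROTATE_270) {
                    degrees = 270;
                  }
                  if (degrees != 0) {
                    Matrix matrix = new Matrix();
                    matrix.postRotate(degrees);
                    // Redraw the decoded Bitmap with the rotation applied.
                    bitmap =
                        Bitmap.createBitmap(
                            bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
                  }
                }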
+ } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + if (bitmap != null) { + faceDetection.send(bitmap); + } + } + }); +Intent pickImageIntent = new Intent(Intent.ACTION_PICK); +pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); +imageGetter.launch(pickImageIntent); +``` + +#### Video Input + +```java +// For video input and result rendering with OpenGL. +FaceDetectionOptions faceDetectionOptions = + FaceDetectionOptions.builder() + .setStaticImageMode(false) + .setModelSelection(0).build(); +FaceDetection faceDetection = new FaceDetection(this, faceDetectionOptions); +faceDetection.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); + +// Initializes a new VideoInput instance and connects it to MediaPipe Face Detection Solution. +VideoInput videoInput = new VideoInput(this); +videoInput.setNewFrameListener( + textureFrame -> faceDetection.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java +// as an example. +SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, faceDetection.getGlContext(), faceDetection.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new FaceDetectionResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +faceDetection.setResultListener( + faceDetectionResult -> { + if (faceDetectionResult.multiFaceDetections().isEmpty()) { + return; + } + RelativeKeypoint noseTip = + faceDetectionResult + .multiFaceDetections() + .get(0) + .getLocationData() + .getRelativeKeypoints(FaceKeypoint.NOSE_TIP); + Log.i( + TAG, + String.format( + "MediaPipe Face Detection nose tip normalized coordinates (value range: [0, 1]): x=%f, y=%f", + noseTip.getX(), noseTip.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(faceDetectionResult); + glSurfaceView.requestRender(); + }); + +ActivityResultLauncher videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + faceDetection.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); +Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); +pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); +videoGetter.launch(pickVideoIntent); +``` + ## Example Apps Please first see general instructions for diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index 5de1b41d3..57bf4de5b 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -111,6 +111,23 @@ You can find more information about the face landmark model in this :------------------------------------------------------------------------: | *Fig 2. 
Face landmarks: the red box indicates the cropped area as input to the landmark model, the red dots represent the 468 landmarks in 3D, and the green lines connecting landmarks illustrate the contours around the eyes, eyebrows, lips and the entire face.* | +#### Attention Mesh Model + +In addition to the [Face Landmark Model](#face-landmark-model) we provide +another model that applies +[attention](https://en.wikipedia.org/wiki/Attention_(machine_learning)) to +semantically meaningful face regions, and therefore predicting landmarks more +accurately around lips, eyes and irises, at the expense of more compute. It +enables applications like AR makeup and AR puppeteering. + +The attention mesh model can be selected in the Solution APIs via the +[refine_landmarks](#refine_landmarks) option. You can also find more information +about the model in this [paper](https://arxiv.org/abs/2006.10962). + +![attention_mesh_architecture.png](../images/attention_mesh_architecture.png) | +:---------------------------------------------------------------------------: | +*Fig 3. Attention Mesh: Overview of model architecture.* | + ## Face Geometry Module The [Face Landmark Model](#face-landmark-model) performs a single-camera face landmark @@ -145,8 +162,8 @@ be set freely, however for better results it is advised to set them as close to the *real physical camera parameters* as possible. ![face_geometry_metric_3d_space.gif](../images/face_geometry_metric_3d_space.gif) | -:----------------------------------------------------------------------------: | -*Fig 3. A visualization of multiple key elements in the Metric 3D space.* | +:-------------------------------------------------------------------------------: | +*Fig 4. A visualization of multiple key elements in the Metric 3D space.* | #### Canonical Face Model @@ -210,7 +227,7 @@ The effect renderer is implemented as a MediaPipe | ![face_geometry_renderer.gif](../images/face_geometry_renderer.gif) | | :---------------------------------------------------------------------: | -| *Fig 4. An example of face effects rendered by the Face Geometry Effect Renderer.* | +| *Fig 5. An example of face effects rendered by the Face Geometry Effect Renderer.* | ## Solution APIs @@ -234,6 +251,12 @@ unrelated, images. Default to `false`. Maximum number of faces to detect. Default to `1`. +#### refine_landmarks + +Whether to further refine the landmark coordinates around the eyes and lips, and +output additional landmarks around the irises by applying the +[Attention Mesh Model](#attention-mesh-model). Default to `false`. 
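
A minimal Python sketch (not part of the official example further below, and assuming a local
placeholder image `face.jpg`) of what this option changes: with `refine_landmarks=True`, each
detected face carries 478 landmarks instead of 468, the additional ten being the iris landmarks.

```python
import cv2
import mediapipe as mp

mp_face_mesh = mp.solutions.face_mesh

image = cv2.imread('face.jpg')  # placeholder input image
with mp_face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,  # enables the Attention Mesh model
    min_detection_confidence=0.5) as face_mesh:
  results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

if results.multi_face_landmarks:
  # Prints 478 with refine_landmarks=True, 468 otherwise.
  print(len(results.multi_face_landmarks[0].landmark))
```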
+ #### min_detection_confidence Minimum confidence value (`[0.0, 1.0]`) from the face detection model for the @@ -271,6 +294,7 @@ Supported configuration options: * [static_image_mode](#static_image_mode) * [max_num_faces](#max_num_faces) +* [refine_landmarks](#refine_landmarks) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -278,6 +302,7 @@ Supported configuration options: import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles mp_face_mesh = mp.solutions.face_mesh # For static images: @@ -286,6 +311,7 @@ drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) with mp_face_mesh.FaceMesh( static_image_mode=True, max_num_faces=1, + refine_landmarks=True, min_detection_confidence=0.5) as face_mesh: for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) @@ -301,15 +327,32 @@ with mp_face_mesh.FaceMesh( mp_drawing.draw_landmarks( image=annotated_image, landmark_list=face_landmarks, - connections=mp_face_mesh.FACE_CONNECTIONS, - landmark_drawing_spec=drawing_spec, - connection_drawing_spec=drawing_spec) + connections=mp_face_mesh.FACEMESH_TESSELATION, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_tesselation_style()) + mp_drawing.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACEMESH_CONTOURS, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_contours_style()) + mp_drawing.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACEMESH_IRISES, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_iris_connections_style()) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) # For webcam input: drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) cap = cv2.VideoCapture(0) with mp_face_mesh.FaceMesh( + max_num_faces=1, + refine_landmarks=True, min_detection_confidence=0.5, min_tracking_confidence=0.5) as face_mesh: while cap.isOpened(): @@ -319,12 +362,10 @@ with mp_face_mesh.FaceMesh( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = face_mesh.process(image) # Draw the face mesh annotations on the image. 
@@ -335,10 +376,26 @@ with mp_face_mesh.FaceMesh( mp_drawing.draw_landmarks( image=image, landmark_list=face_landmarks, - connections=mp_face_mesh.FACE_CONNECTIONS, - landmark_drawing_spec=drawing_spec, - connection_drawing_spec=drawing_spec) - cv2.imshow('MediaPipe FaceMesh', image) + connections=mp_face_mesh.FACEMESH_TESSELATION, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_tesselation_style()) + mp_drawing.draw_landmarks( + image=image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACEMESH_CONTOURS, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_contours_style()) + mp_drawing.draw_landmarks( + image=image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACEMESH_IRISES, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_iris_connections_style()) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Face Mesh', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() @@ -353,6 +410,7 @@ and the following usage example. Supported configuration options: * [maxNumFaces](#max_num_faces) +* [refineLandmarks](#refine_landmarks) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -393,8 +451,10 @@ function onResults(results) { {color: '#C0C0C070', lineWidth: 1}); drawConnectors(canvasCtx, landmarks, FACEMESH_RIGHT_EYE, {color: '#FF3030'}); drawConnectors(canvasCtx, landmarks, FACEMESH_RIGHT_EYEBROW, {color: '#FF3030'}); + drawConnectors(canvasCtx, landmarks, FACEMESH_RIGHT_IRIS, {color: '#FF3030'}); drawConnectors(canvasCtx, landmarks, FACEMESH_LEFT_EYE, {color: '#30FF30'}); drawConnectors(canvasCtx, landmarks, FACEMESH_LEFT_EYEBROW, {color: '#30FF30'}); + drawConnectors(canvasCtx, landmarks, FACEMESH_LEFT_IRIS, {color: '#30FF30'}); drawConnectors(canvasCtx, landmarks, FACEMESH_FACE_OVAL, {color: '#E0E0E0'}); drawConnectors(canvasCtx, landmarks, FACEMESH_LIPS, {color: '#E0E0E0'}); } @@ -407,6 +467,7 @@ const faceMesh = new FaceMesh({locateFile: (file) => { }}); faceMesh.setOptions({ maxNumFaces: 1, + refineLandmarks: true, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 }); @@ -423,6 +484,202 @@ camera.start(); ``` +### Android Solution API + +Please first follow general +[instructions](../getting_started/android_solutions.md) to add MediaPipe Gradle +dependencies and try the Android Solution API in the companion +[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facemesh), +and learn more in the usage example below. + +Supported configuration options: + +* [staticImageMode](#static_image_mode) +* [maxNumFaces](#max_num_faces) +* [refineLandmarks](#refine_landmarks) +* runOnGpu: Run the pipeline and the model inference on GPU or CPU. + +#### Camera Input + +```java +// For camera input and result rendering with OpenGL. +FaceMeshOptions faceMeshOptions = + FaceMeshOptions.builder() + .setStaticImageMode(false) + .setRefineLandmarks(true) + .setMaxNumFaces(1) + .setRunOnGpu(true).build(); +FaceMesh faceMesh = new FaceMesh(this, faceMeshOptions); +faceMesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); + +// Initializes a new CameraInput instance and connects it to MediaPipe Face Mesh Solution. 
+CameraInput cameraInput = new CameraInput(this); +cameraInput.setNewFrameListener( + textureFrame -> faceMesh.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java +// as an example. +SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, faceMesh.getGlContext(), faceMesh.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new FaceMeshResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +faceMesh.setResultListener( + faceMeshResult -> { + NormalizedLandmark noseLandmark = + result.multiFaceLandmarks().get(0).getLandmarkList().get(1); + Log.i( + TAG, + String.format( + "MediaPipe Face Mesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + noseLandmark.getX(), noseLandmark.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(faceMeshResult); + glSurfaceView.requestRender(); + }); + +// The runnable to start camera after the GLSurfaceView is attached. +glSurfaceView.post( + () -> + cameraInput.start( + this, + faceMesh.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); +``` + +#### Image Input + +```java +// For reading images from gallery and drawing the output in an ImageView. +FaceMeshOptions faceMeshOptions = + FaceMeshOptions.builder() + .setStaticImageMode(true) + .setRefineLandmarks(true) + .setMaxNumFaces(1) + .setRunOnGpu(true).build(); +FaceMesh faceMesh = new FaceMesh(this, faceMeshOptions); + +// Connects MediaPipe Face Mesh Solution to the user-defined ImageView instance +// that allows users to have the custom drawing of the output landmarks on it. +// See mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java +// as an example. +FaceMeshResultImageView imageView = new FaceMeshResultImageView(this); +faceMesh.setResultListener( + faceMeshResult -> { + int width = faceMeshResult.inputBitmap().getWidth(); + int height = faceMeshResult.inputBitmap().getHeight(); + NormalizedLandmark noseLandmark = + result.multiFaceLandmarks().get(0).getLandmarkList().get(1); + Log.i( + TAG, + String.format( + "MediaPipe Face Mesh nose coordinates (pixel values): x=%f, y=%f", + noseLandmark.getX() * width, noseLandmark.getY() * height)); + // Request canvas drawing. + imageView.setFaceMeshResult(faceMeshResult); + runOnUiThread(() -> imageView.update()); + }); +faceMesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); + +// ActivityResultLauncher to get an image from the gallery as Bitmap. +ActivityResultLauncher imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null && result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData()); + // Please also rotate the Bitmap based on its orientation. 
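+                // One possible way to do that (a sketch, not taken from the
+                // official sample) is to read the EXIF tag with
+                // androidx.exifinterface and rotate with a Matrix:
+                //   InputStream is =
+                //       this.getContentResolver().openInputStream(resultIntent.getData());
+                //   int degrees = new ExifInterface(is).getRotationDegrees();
+                //   if (degrees != 0) {
+                //     Matrix matrix = new Matrix();
+                //     matrix.postRotate(degrees);
+                //     bitmap = Bitmap.createBitmap(
+                //         bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
+                //   }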
+ } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + if (bitmap != null) { + faceMesh.send(bitmap); + } + } + }); +Intent pickImageIntent = new Intent(Intent.ACTION_PICK); +pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); +imageGetter.launch(pickImageIntent); +``` + +#### Video Input + +```java +// For video input and result rendering with OpenGL. +FaceMeshOptions faceMeshOptions = + FaceMeshOptions.builder() + .setStaticImageMode(false) + .setRefineLandmarks(true) + .setMaxNumFaces(1) + .setRunOnGpu(true).build(); +FaceMesh faceMesh = new FaceMesh(this, faceMeshOptions); +faceMesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); + +// Initializes a new VideoInput instance and connects it to MediaPipe Face Mesh Solution. +VideoInput videoInput = new VideoInput(this); +videoInput.setNewFrameListener( + textureFrame -> faceMesh.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java +// as an example. +SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, faceMesh.getGlContext(), faceMesh.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new FaceMeshResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +faceMesh.setResultListener( + faceMeshResult -> { + NormalizedLandmark noseLandmark = + result.multiFaceLandmarks().get(0).getLandmarkList().get(1); + Log.i( + TAG, + String.format( + "MediaPipe Face Mesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + noseLandmark.getX(), noseLandmark.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(faceMeshResult); + glSurfaceView.requestRender(); + }); + +ActivityResultLauncher videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + faceMesh.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); +Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); +pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); +videoGetter.launch(pickVideoIntent); +``` + ## Example Apps Please first see general instructions for diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md index 9dd2898ba..e8a75ee16 100644 --- a/docs/solutions/hands.md +++ b/docs/solutions/hands.md @@ -91,8 +91,10 @@ To detect initial hand locations, we designed a mobile real-time uses in a manner similar to the face detection model in [MediaPipe Face Mesh](./face_mesh.md). 
Detecting hands is a decidedly complex task: our -[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite) -has to work across a variety of hand sizes with a large scale span (~20x) +[lite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection_lite.tflite) +and +[full model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection_full.tflite) +have to work across a variety of hand sizes with a large scale span (~20x) relative to the image frame and be able to detect occluded and self-occluded hands. Whereas faces have high contrast patterns, e.g., in the eye and mouth region, the lack of such features in hands makes it comparatively difficult to @@ -120,7 +122,7 @@ just 86.22%. ### Hand Landmark Model After the palm detection over the whole image our subsequent hand landmark -[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite) +[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite) performs precise keypoint localization of 21 3D hand-knuckle coordinates inside the detected hand regions via regression, that is direct coordinate prediction. The model learns a consistent internal hand pose representation and is robust @@ -163,6 +165,11 @@ unrelated, images. Default to `false`. Maximum number of hands to detect. Default to `2`. +#### model_complexity + +Complexity of the hand landmark model: `0` or `1`. Landmark accuracy as well as +inference latency generally go up with the model complexity. Default to `1`. + #### min_detection_confidence Minimum confidence value (`[0.0, 1.0]`) from the hand detection model for the @@ -190,6 +197,17 @@ of 21 hand landmarks and each landmark is composed of `x`, `y` and `z`. `x` and and the smaller the value the closer the landmark is to the camera. The magnitude of `z` uses roughly the same scale as `x`. +#### multi_hand_world_landmarks + +Collection of detected/tracked hands, where each hand is represented as a list +of 21 hand landmarks in world coordinates. Each landmark consists of the +following: + +* `x`, `y` and `z`: Real-world 3D coordinates in meters with the origin at the + hand's approximate geometric center. +* `visibility`: Identical to that defined in the corresponding + [multi_hand_landmarks](#multi_hand_landmarks). + #### multi_handedness Collection of handedness of the detected/tracked hands (i.e. 
is it a left or @@ -212,6 +230,7 @@ Supported configuration options: * [static_image_mode](#static_image_mode) * [max_num_hands](#max_num_hands) +* [model_complexity](#model_complexity) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -219,6 +238,7 @@ Supported configuration options: import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles mp_hands = mp.solutions.hands # For static images: @@ -248,13 +268,24 @@ with mp_hands.Hands( f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})' ) mp_drawing.draw_landmarks( - annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS) + annotated_image, + hand_landmarks, + mp_hands.HAND_CONNECTIONS, + mp_drawing_styles.get_default_hand_landmarks_style(), + mp_drawing_styles.get_default_hand_connections_style()) cv2.imwrite( '/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1)) + # Draw hand world landmarks. + if not results.multi_hand_world_landmarks: + continue + for hand_world_landmarks in results.multi_hand_world_landmarks: + mp_drawing.plot_landmarks( + hand_world_landmarks, mp_hands.HAND_CONNECTIONS, azimuth=5) # For webcam input: cap = cv2.VideoCapture(0) with mp_hands.Hands( + model_complexity=0, min_detection_confidence=0.5, min_tracking_confidence=0.5) as hands: while cap.isOpened(): @@ -264,12 +295,10 @@ with mp_hands.Hands( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = hands.process(image) # Draw the hand annotations on the image. @@ -278,8 +307,13 @@ with mp_hands.Hands( if results.multi_hand_landmarks: for hand_landmarks in results.multi_hand_landmarks: mp_drawing.draw_landmarks( - image, hand_landmarks, mp_hands.HAND_CONNECTIONS) - cv2.imshow('MediaPipe Hands', image) + image, + hand_landmarks, + mp_hands.HAND_CONNECTIONS, + mp_drawing_styles.get_default_hand_landmarks_style(), + mp_drawing_styles.get_default_hand_connections_style()) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Hands', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() @@ -294,6 +328,7 @@ and a [fun application], and the following usage example. Supported configuration options: * [maxNumHands](#max_num_hands) +* [modelComplexity](#model_complexity) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -343,6 +378,7 @@ const hands = new Hands({locateFile: (file) => { }}); hands.setOptions({ maxNumHands: 2, + modelComplexity: 1, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 }); @@ -359,6 +395,207 @@ camera.start(); ``` +### Android Solution API + +Please first follow general +[instructions](../getting_started/android_solutions.md) to add MediaPipe Gradle +dependencies and try the Android Solution API in the companion +[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/hands), +and learn more in the usage example below. 
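+The camera and video snippets below create a `SolutionGlSurfaceView` for
+rendering but leave out how it is attached to the activity's layout. A minimal
+sketch of that step follows; `R.id.preview_display_layout` is a placeholder id
+for whatever `ViewGroup` hosts the preview in your app, not an identifier
+defined by MediaPipe.
+
+```java
+// Attach the solution's GL surface view to an existing container in the
+// activity layout so the rendered results become visible.
+FrameLayout frameLayout = findViewById(R.id.preview_display_layout);
+frameLayout.removeAllViews();
+frameLayout.addView(glSurfaceView);
+glSurfaceView.setVisibility(View.VISIBLE);
+frameLayout.requestLayout();
+```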
+ +Supported configuration options: + +* [staticImageMode](#static_image_mode) +* [maxNumHands](#max_num_hands) +* runOnGpu: Run the pipeline and the model inference on GPU or CPU. + +#### Camera Input + +```java +// For camera input and result rendering with OpenGL. +HandsOptions handsOptions = + HandsOptions.builder() + .setStaticImageMode(false) + .setMaxNumHands(2) + .setRunOnGpu(true).build(); +Hands hands = new Hands(this, handsOptions); +hands.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + +// Initializes a new CameraInput instance and connects it to MediaPipe Hands Solution. +CameraInput cameraInput = new CameraInput(this); +cameraInput.setNewFrameListener( + textureFrame -> hands.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java +// as an example. +SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, hands.getGlContext(), hands.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new HandsResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +hands.setResultListener( + handsResult -> { + if (result.multiHandLandmarks().isEmpty()) { + return; + } + NormalizedLandmark wristLandmark = + handsResult.multiHandLandmarks().get(0).getLandmarkList().get(HandLandmark.WRIST); + Log.i( + TAG, + String.format( + "MediaPipe Hand wrist normalized coordinates (value range: [0, 1]): x=%f, y=%f", + wristLandmark.getX(), wristLandmark.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(handsResult); + glSurfaceView.requestRender(); + }); + +// The runnable to start camera after the GLSurfaceView is attached. +glSurfaceView.post( + () -> + cameraInput.start( + this, + hands.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); +``` + +#### Image Input + +```java +// For reading images from gallery and drawing the output in an ImageView. +HandsOptions handsOptions = + HandsOptions.builder() + .setStaticImageMode(true) + .setMaxNumHands(2) + .setRunOnGpu(true).build(); +Hands hands = new Hands(this, handsOptions); + +// Connects MediaPipe Hands Solution to the user-defined ImageView instance that +// allows users to have the custom drawing of the output landmarks on it. +// See mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java +// as an example. +HandsResultImageView imageView = new HandsResultImageView(this); +hands.setResultListener( + handsResult -> { + if (result.multiHandLandmarks().isEmpty()) { + return; + } + int width = handsResult.inputBitmap().getWidth(); + int height = handsResult.inputBitmap().getHeight(); + NormalizedLandmark wristLandmark = + handsResult.multiHandLandmarks().get(0).getLandmarkList().get(HandLandmark.WRIST); + Log.i( + TAG, + String.format( + "MediaPipe Hand wrist coordinates (pixel values): x=%f, y=%f", + wristLandmark.getX() * width, wristLandmark.getY() * height)); + // Request canvas drawing. + imageView.setHandsResult(handsResult); + runOnUiThread(() -> imageView.update()); + }); +hands.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + +// ActivityResultLauncher to get an image from the gallery as Bitmap. 
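+// (Because this example configures setStaticImageMode(true) above, each Bitmap
+// is treated as an independent still image rather than a frame in a stream.)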
+ActivityResultLauncher imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null && result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData()); + // Please also rotate the Bitmap based on its orientation. + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + if (bitmap != null) { + hands.send(bitmap); + } + } + }); +Intent pickImageIntent = new Intent(Intent.ACTION_PICK); +pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); +imageGetter.launch(pickImageIntent); +``` + +#### Video Input + +```java +// For video input and result rendering with OpenGL. +HandsOptions handsOptions = + HandsOptions.builder() + .setStaticImageMode(false) + .setMaxNumHands(2) + .setRunOnGpu(true).build(); +Hands hands = new Hands(this, handsOptions); +hands.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + +// Initializes a new VideoInput instance and connects it to MediaPipe Hands Solution. +VideoInput videoInput = new VideoInput(this); +videoInput.setNewFrameListener( + textureFrame -> hands.send(textureFrame)); + +// Initializes a new GlSurfaceView with a ResultGlRenderer instance +// that provides the interfaces to run user-defined OpenGL rendering code. +// See mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java +// as an example. +SolutionGlSurfaceView glSurfaceView = + new SolutionGlSurfaceView<>( + this, hands.getGlContext(), hands.getGlMajorVersion()); +glSurfaceView.setSolutionResultRenderer(new HandsResultGlRenderer()); +glSurfaceView.setRenderInputImage(true); + +hands.setResultListener( + handsResult -> { + if (result.multiHandLandmarks().isEmpty()) { + return; + } + NormalizedLandmark wristLandmark = + handsResult.multiHandLandmarks().get(0).getLandmarkList().get(HandLandmark.WRIST); + Log.i( + TAG, + String.format( + "MediaPipe Hand wrist normalized coordinates (value range: [0, 1]): x=%f, y=%f", + wristLandmark.getX(), wristLandmark.getY())); + // Request GL rendering. + glSurfaceView.setRenderData(handsResult); + glSurfaceView.requestRender(); + }); + +ActivityResultLauncher videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + hands.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); +Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); +pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); +videoGetter.launch(pickVideoIntent); +``` + ## Example Apps Please first see general instructions for diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 1ae8034bf..d0ab0b801 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -147,6 +147,23 @@ If set to `true`, the solution filters pose landmarks across different input images to reduce jitter, but ignored if [static_image_mode](#static_image_mode) is also set to `true`. Default to `true`. 
+#### enable_segmentation + +If set to `true`, in addition to the pose, face and hand landmarks the solution +also generates the segmentation mask. Default to `false`. + +#### smooth_segmentation + +If set to `true`, the solution filters segmentation masks across different input +images to reduce jitter. Ignored if [enable_segmentation](#enable_segmentation) +is `false` or [static_image_mode](#static_image_mode) is `true`. Default to +`true`. + +#### refine_face_landmarks + +Whether to further refine the landmark coordinates around the eyes and lips, and +output additional landmarks around the irises. Default to `false`. + #### min_detection_confidence Minimum confidence value (`[0.0, 1.0]`) from the person-detection model for the @@ -207,6 +224,15 @@ the camera. The magnitude of `z` uses roughly the same scale as `x`. A list of 21 hand landmarks on the right hand, in the same representation as [left_hand_landmarks](#left_hand_landmarks). +#### segmentation_mask + +The output segmentation mask, predicted only when +[enable_segmentation](#enable_segmentation) is set to `true`. The mask has the +same width and height as the input image, and contains values in `[0.0, 1.0]` +where `1.0` and `0.0` indicate high certainty of a "human" and "background" +pixel respectively. Please refer to the platform-specific usage examples below +for usage details. + ### Python Solution API Please first follow general [instructions](../getting_started/python.md) to @@ -218,6 +244,9 @@ Supported configuration options: * [static_image_mode](#static_image_mode) * [model_complexity](#model_complexity) * [smooth_landmarks](#smooth_landmarks) +* [enable_segmentation](#enable_segmentation) +* [smooth_segmentation](#smooth_segmentation) +* [refine_face_landmarks](#refine_face_landmarks) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -225,13 +254,16 @@ Supported configuration options: import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles mp_holistic = mp.solutions.holistic # For static images: IMAGE_FILES = [] with mp_holistic.Holistic( static_image_mode=True, - model_complexity=2) as holistic: + model_complexity=2, + enable_segmentation=True, + refine_face_landmarks=True) as holistic: for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) image_height, image_width, _ = image.shape @@ -244,16 +276,29 @@ with mp_holistic.Holistic( f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' ) - # Draw pose, left and right hands, and face landmarks on the image. + annotated_image = image.copy() + # Draw segmentation on the image. + # To improve segmentation around boundaries, consider applying a joint + # bilateral filter to "results.segmentation_mask" with "image". + condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1 + bg_image = np.zeros(image.shape, dtype=np.uint8) + bg_image[:] = BG_COLOR + annotated_image = np.where(condition, annotated_image, bg_image) + # Draw pose, left and right hands, and face landmarks on the image. 
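+    # (The segmentation code earlier in this example assumes numpy is imported
+    # as np and a background color is defined, e.g. "import numpy as np" and
+    # BG_COLOR = (192, 192, 192), as in the Pose and Selfie Segmentation
+    # examples.)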
mp_drawing.draw_landmarks( - annotated_image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) + annotated_image, + results.face_landmarks, + mp_holistic.FACEMESH_TESSELATION, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_tesselation_style()) mp_drawing.draw_landmarks( - annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) + annotated_image, + results.pose_landmarks, + mp_holistic.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing_styles. + get_default_pose_landmarks_style()) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) # Plot pose world landmarks. mp_drawing.plot_landmarks( @@ -271,26 +316,30 @@ with mp_holistic.Holistic( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = holistic.process(image) # Draw landmark annotation on the image. image.flags.writeable = True image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) mp_drawing.draw_landmarks( - image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) + image, + results.face_landmarks, + mp_holistic.FACEMESH_CONTOURS, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_contours_style()) mp_drawing.draw_landmarks( - image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) - cv2.imshow('MediaPipe Holistic', image) + image, + results.pose_landmarks, + mp_holistic.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing_styles + .get_default_pose_landmarks_style()) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Holistic', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() @@ -306,6 +355,9 @@ Supported configuration options: * [modelComplexity](#model_complexity) * [smoothLandmarks](#smooth_landmarks) +* [enableSegmentation](#enable_segmentation) +* [smoothSegmentation](#smooth_segmentation) +* [refineFaceLandmarks](#refineFaceLandmarks) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -338,8 +390,20 @@ const canvasCtx = canvasElement.getContext('2d'); function onResults(results) { canvasCtx.save(); canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height); + canvasCtx.drawImage(results.segmentationMask, 0, 0, + canvasElement.width, canvasElement.height); + + // Only overwrite existing pixels. + canvasCtx.globalCompositeOperation = 'source-in'; + canvasCtx.fillStyle = '#00FF00'; + canvasCtx.fillRect(0, 0, canvasElement.width, canvasElement.height); + + // Only overwrite missing pixels. 
+ canvasCtx.globalCompositeOperation = 'destination-atop'; canvasCtx.drawImage( results.image, 0, 0, canvasElement.width, canvasElement.height); + + canvasCtx.globalCompositeOperation = 'source-over'; drawConnectors(canvasCtx, results.poseLandmarks, POSE_CONNECTIONS, {color: '#00FF00', lineWidth: 4}); drawLandmarks(canvasCtx, results.poseLandmarks, @@ -363,6 +427,9 @@ const holistic = new Holistic({locateFile: (file) => { holistic.setOptions({ modelComplexity: 1, smoothLandmarks: true, + enableSegmentation: true, + smoothSegmentation: true, + refineFaceLandmarks: true, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 }); diff --git a/docs/solutions/models.md b/docs/solutions/models.md index 2f3001722..b2f59a9c8 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -41,7 +41,10 @@ one over the other. * Face landmark model: [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark.tflite), [TF.js model](https://tfhub.dev/mediapipe/facemesh/1) -* [Model card](https://mediapipe.page.link/facemesh-mc) +* Face landmark model w/ attention (aka Attention Mesh): + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark_with_attention.tflite) +* [Model card](https://mediapipe.page.link/facemesh-mc), + [Model card (w/ attention)](https://mediapipe.page.link/attentionmesh-mc) ### [Iris](https://google.github.io/mediapipe/solutions/iris) @@ -52,13 +55,14 @@ one over the other. ### [Hands](https://google.github.io/mediapipe/solutions/hands) * Palm detection model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite), + [TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection_lite.tflite), + [TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection_full.tflite), [TF.js model](https://tfhub.dev/mediapipe/handdetector/1) * Hand landmark model: - [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite), - [TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite), + [TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_lite.tflite), + [TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite), [TF.js model](https://tfhub.dev/mediapipe/handskeleton/1) -* [Model card](https://mediapipe.page.link/handmc), [Model card (sparse)](https://mediapipe.page.link/handmc-sparse) +* [Model card](https://mediapipe.page.link/handmc) ### [Pose](https://google.github.io/mediapipe/solutions/pose) diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index 20dc3cace..25259d678 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -224,29 +224,33 @@ where object detection simply runs on every image. Default to `0.99`. #### model_name -Name of the model to use for predicting 3D bounding box landmarks. Currently supports -`{'Shoe', 'Chair', 'Cup', 'Camera'}`. +Name of the model to use for predicting 3D bounding box landmarks. Currently +supports `{'Shoe', 'Chair', 'Cup', 'Camera'}`. Default to `Shoe`. #### focal_length -Camera focal length `(fx, fy)`, by default is defined in -[NDC space](#ndc-space). 
To use focal length `(fx_pixel, fy_pixel)` in -[pixel space](#pixel-space), users should provide `image_size` = `(image_width, -image_height)` to enable conversions inside the API. For further details about -NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). +By default, camera focal length defined in [NDC space](#ndc-space), i.e., `(fx, +fy)`. Default to `(1.0, 1.0)`. To specify focal length in +[pixel space](#pixel-space) instead, i.e., `(fx_pixel, fy_pixel)`, users should +provide [`image_size`](#image_size) = `(image_width, image_height)` to enable +conversions inside the API. For further details about NDC and pixel space, +please see [Coordinate Systems](#coordinate-systems). #### principal_point -Camera principal point `(px, py)`, by default is defined in -[NDC space](#ndc-space). To use principal point `(px_pixel, py_pixel)` in -[pixel space](#pixel-space), users should provide `image_size` = `(image_width, -image_height)` to enable conversions inside the API. For further details about -NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). +By default, camera principal point defined in [NDC space](#ndc-space), i.e., +`(px, py)`. Default to `(0.0, 0.0)`. To specify principal point in +[pixel space](#pixel-space), i.e.,`(px_pixel, py_pixel)`, users should provide +[`image_size`](#image_size) = `(image_width, image_height)` to enable +conversions inside the API. For further details about NDC and pixel space, +please see [Coordinate Systems](#coordinate-systems). #### image_size -(**Optional**) size `(image_width, image_height)` of the input image, **ONLY** -needed when use `focal_length` and `principal_point` in pixel space. +**Specify only when [`focal_length`](#focal_length) and +[`principal_point`](#principal_point) are specified in pixel space.** + +Size of the input image, i.e., `(image_width, image_height)`. ### Output @@ -334,11 +338,10 @@ with mp_objectron.Objectron(static_image_mode=False, # If loading a video, use 'break' instead of 'continue'. continue - # Convert the BGR image to RGB. - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = objectron.process(image) # Draw the box landmarks on the image. @@ -350,12 +353,96 @@ with mp_objectron.Objectron(static_image_mode=False, image, detected_object.landmarks_2d, mp_objectron.BOX_CONNECTIONS) mp_drawing.draw_axis(image, detected_object.rotation, detected_object.translation) - cv2.imshow('MediaPipe Objectron', image) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Objectron', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() ``` +## JavaScript Solution API + +Please first see general [introduction](../getting_started/javascript.md) on +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. + +Supported configuration options: + +* [staticImageMode](#static_image_mode) +* [maxNumObjects](#max_num_objects) +* [minDetectionConfidence](#min_detection_confidence) +* [minTrackingConfidence](#min_tracking_confidence) +* [modelName](#model_name) +* [focalLength](#focal_length) +* [principalPoint](#principal_point) +* [imageSize](#image_size) + +```html + + + + + + + + + + + + +
+ + +``` + +```javascript + +``` + ## Example Apps Please first see general instructions for @@ -442,7 +529,7 @@ Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](http > ``` > and then run > -> ```build +> ```bash > bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR] > ``` > INPUT_DIR should be the folder with initial asset .obj files to be processed, @@ -561,11 +648,15 @@ py = -py_pixel * 2.0 / image_height + 1.0 [Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html) * Google AI Blog: [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html) -* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in CVPR 2021 +* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the + Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in + CVPR 2021 * Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak Shape Supervision](https://arxiv.org/abs/2003.03522) * Paper: [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8) - ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth Workshop on Computer Vision for AR/VR, CVPR 2020 + ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth + Workshop on Computer Vision for AR/VR, CVPR 2020 * [Models and model cards](./models.md#objectron) +* [Web demo](https://code.mediapipe.dev/codepen/objectron) * [Python Colab](https://mediapipe.page.link/objectron_py_colab) diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 0ae81a858..2ec1b4f4e 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -30,7 +30,8 @@ overlay of digital content and information on top of the physical world in augmented reality. MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring -33 3D landmarks on the whole body from RGB video frames utilizing our +33 3D landmarks and background segmentation mask on the whole body from RGB +video frames utilizing our [BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) research that also powers the [ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection). @@ -49,11 +50,11 @@ The solution utilizes a two-step detector-tracker ML pipeline, proven to be effective in our [MediaPipe Hands](./hands.md) and [MediaPipe Face Mesh](./face_mesh.md) solutions. Using a detector, the pipeline first locates the person/pose region-of-interest (ROI) within the frame. The -tracker subsequently predicts the pose landmarks within the ROI using the -ROI-cropped frame as input. Note that for video use cases the detector is -invoked only as needed, i.e., for the very first frame and when the tracker -could no longer identify body pose presence in the previous frame. For other -frames the pipeline simply derives the ROI from the previous frame’s pose +tracker subsequently predicts the pose landmarks and segmentation mask within +the ROI using the ROI-cropped frame as input. Note that for video use cases the +detector is invoked only as needed, i.e., for the very first frame and when the +tracker could no longer identify body pose presence in the previous frame. 
For +other frames the pipeline simply derives the ROI from the previous frame’s pose landmarks. The pipeline is implemented as a MediaPipe @@ -87,11 +88,11 @@ from [COCO topology](https://cocodataset.org/#keypoints-2020). Method | Yoga
<br/>[`mAP`] | Yoga<br/>[`PCK@0.2`] | Dance<br/>[`mAP`] | Dance<br/>[`PCK@0.2`] | HIIT<br/>[`mAP`] | HIIT<br/>
[`PCK@0.2`] ----------------------------------------------------------------------------------------------------- | -----------------: | ---------------------: | ------------------: | ----------------------: | -----------------: | ---------------------: -BlazePose.Heavy | 68.1 | **96.4** | 73.0 | **97.2** | 74.0 | **97.5** -BlazePose.Full | 62.6 | **95.5** | 67.4 | **96.3** | 68.0 | **95.7** -BlazePose.Lite | 45.0 | **90.2** | 53.6 | **92.5** | 53.8 | **93.5** -[AlphaPose.ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 63.4 | **96.0** | 57.8 | **95.5** | 63.4 | **96.0** -[Apple.Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 32.8 | **82.7** | 36.4 | **91.4** | 44.5 | **88.6** +BlazePose GHUM Heavy | 68.1 | **96.4** | 73.0 | **97.2** | 74.0 | **97.5** +BlazePose GHUM Full | 62.6 | **95.5** | 67.4 | **96.3** | 68.0 | **95.7** +BlazePose GHUM Lite | 45.0 | **90.2** | 53.6 | **92.5** | 53.8 | **93.5** +[AlphaPose ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 63.4 | **96.0** | 57.8 | **95.5** | 63.4 | **96.0** +[Apple Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 32.8 | **82.7** | 36.4 | **91.4** | 44.5 | **88.6** ![pose_tracking_pck_chart.png](../images/mobile/pose_tracking_pck_chart.png) | :--------------------------------------------------------------------------: | @@ -100,11 +101,11 @@ BlazePose.Lite We designed our models specifically for live perception use cases, so all of them work in real-time on the majority of modern devices. -Method | Latency
<br/>Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | Latency<br/>MacBook Pro (15-inch 2017)
---------------- | -------------------------------------------------------------------------------------------: | ---------------------------------------:
-BlazePose.Heavy | 53 ms | 38 ms
-BlazePose.Full | 25 ms | 27 ms
-BlazePose.Lite | 20 ms | 25 ms
+Method | Latency<br/>Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | Latency<br/>
MacBook Pro (15-inch 2017) +-------------------- | -------------------------------------------------------------------------------------------: | ---------------------------------------: +BlazePose GHUM Heavy | 53 ms | 38 ms +BlazePose GHUM Full | 25 ms | 27 ms +BlazePose GHUM Lite | 20 ms | 25 ms ## Models @@ -124,21 +125,24 @@ hip midpoints. :----------------------------------------------------------------------------------------------------: | *Fig 3. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* | -### Pose Landmark Model (BlazePose GHUM 3D) +### Pose Landmark Model (BlazePose [GHUM](https://github.com/google-research/google-research/tree/master/ghum) 3D) The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks (see figure below). -Please find more detail in the -[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), -this [paper](https://arxiv.org/abs/2006.10204) and -[the model card](./models.md#pose), and the attributes in each landmark -[below](#pose_landmarks). - ![pose_tracking_full_body_landmarks.png](../images/mobile/pose_tracking_full_body_landmarks.png) | :----------------------------------------------------------------------------------------------: | *Fig 4. 33 pose landmarks.* | +Optionally, MediaPipe Pose can predicts a full-body +[segmentation mask](#segmentation_mask) represented as a two-class segmentation +(human or background). + +Please find more detail in the +[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), +this [paper](https://arxiv.org/abs/2006.10204), +[the model card](./models.md#pose) and the [Output](#output) section below. + ## Solution APIs ### Cross-platform Configuration Options @@ -167,6 +171,18 @@ If set to `true`, the solution filters pose landmarks across different input images to reduce jitter, but ignored if [static_image_mode](#static_image_mode) is also set to `true`. Default to `true`. +#### enable_segmentation + +If set to `true`, in addition to the pose landmarks the solution also generates +the segmentation mask. Default to `false`. + +#### smooth_segmentation + +If set to `true`, the solution filters segmentation masks across different input +images to reduce jitter. Ignored if [enable_segmentation](#enable_segmentation) +is `false` or [static_image_mode](#static_image_mode) is `true`. Default to +`true`. + #### min_detection_confidence Minimum confidence value (`[0.0, 1.0]`) from the person-detection model for the @@ -211,6 +227,19 @@ the following: * `visibility`: Identical to that defined in the corresponding [pose_landmarks](#pose_landmarks). +#### segmentation_mask + +The output segmentation mask, predicted only when +[enable_segmentation](#enable_segmentation) is set to `true`. The mask has the +same width and height as the input image, and contains values in `[0.0, 1.0]` +where `1.0` and `0.0` indicate high certainty of a "human" and "background" +pixel respectively. Please refer to the platform-specific usage examples below +for usage details. + +*Fig 6. 
Example of MediaPipe Pose segmentation mask.* | +:---------------------------------------------------: | + | + ### Python Solution API Please first follow general [instructions](../getting_started/python.md) to @@ -222,6 +251,8 @@ Supported configuration options: * [static_image_mode](#static_image_mode) * [model_complexity](#model_complexity) * [smooth_landmarks](#smooth_landmarks) +* [enable_segmentation](#enable_segmentation) +* [smooth_segmentation](#smooth_segmentation) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -229,13 +260,16 @@ Supported configuration options: import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles mp_pose = mp.solutions.pose # For static images: IMAGE_FILES = [] +BG_COLOR = (192, 192, 192) # gray with mp_pose.Pose( static_image_mode=True, model_complexity=2, + enable_segmentation=True, min_detection_confidence=0.5) as pose: for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) @@ -247,13 +281,24 @@ with mp_pose.Pose( continue print( f'Nose coordinates: (' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' + f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].x * image_width}, ' + f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].y * image_height})' ) - # Draw pose landmarks on the image. + annotated_image = image.copy() + # Draw segmentation on the image. + # To improve segmentation around boundaries, consider applying a joint + # bilateral filter to "results.segmentation_mask" with "image". + condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1 + bg_image = np.zeros(image.shape, dtype=np.uint8) + bg_image[:] = BG_COLOR + annotated_image = np.where(condition, annotated_image, bg_image) + # Draw pose landmarks on the image. mp_drawing.draw_landmarks( - annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) + annotated_image, + results.pose_landmarks, + mp_pose.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style()) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) # Plot pose world landmarks. mp_drawing.plot_landmarks( @@ -271,20 +316,22 @@ with mp_pose.Pose( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = pose.process(image) # Draw the pose annotation on the image. image.flags.writeable = True image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) mp_drawing.draw_landmarks( - image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) - cv2.imshow('MediaPipe Pose', image) + image, + results.pose_landmarks, + mp_pose.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style()) + # Flip the image horizontally for a selfie-view display. 
+ cv2.imshow('MediaPipe Pose', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() @@ -300,6 +347,8 @@ Supported configuration options: * [modelComplexity](#model_complexity) * [smoothLandmarks](#smooth_landmarks) +* [enableSegmentation](#enable_segmentation) +* [smoothSegmentation](#smooth_segmentation) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -319,6 +368,7 @@ Supported configuration options:
@@ -340,8 +390,20 @@ function onResults(results) { canvasCtx.save(); canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height); + canvasCtx.drawImage(results.segmentationMask, 0, 0, + canvasElement.width, canvasElement.height); + + // Only overwrite existing pixels. + canvasCtx.globalCompositeOperation = 'source-in'; + canvasCtx.fillStyle = '#00FF00'; + canvasCtx.fillRect(0, 0, canvasElement.width, canvasElement.height); + + // Only overwrite missing pixels. + canvasCtx.globalCompositeOperation = 'destination-atop'; canvasCtx.drawImage( results.image, 0, 0, canvasElement.width, canvasElement.height); + + canvasCtx.globalCompositeOperation = 'source-over'; drawConnectors(canvasCtx, results.poseLandmarks, POSE_CONNECTIONS, {color: '#00FF00', lineWidth: 4}); drawLandmarks(canvasCtx, results.poseLandmarks, @@ -357,6 +419,8 @@ const pose = new Pose({locateFile: (file) => { pose.setOptions({ modelComplexity: 1, smoothLandmarks: true, + enableSegmentation: true, + smoothSegmentation: true, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 }); @@ -422,6 +486,7 @@ on how to build MediaPipe examples. [BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204) ([presentation](https://youtu.be/YPpUOTRn5tA)) * [Models and model cards](./models.md#pose) +* [GHUM & GHUML: Generative 3D Human Shape and Articulated Pose Models](https://github.com/google-research/google-research/tree/master/ghum) * [Web demo](https://code.mediapipe.dev/codepen/pose) * [Python Colab](https://mediapipe.page.link/pose_py_colab) diff --git a/docs/solutions/selfie_segmentation.md b/docs/solutions/selfie_segmentation.md index f649bee72..2cb155fb3 100644 --- a/docs/solutions/selfie_segmentation.md +++ b/docs/solutions/selfie_segmentation.md @@ -96,6 +96,7 @@ Supported configuration options: ```python import cv2 import mediapipe as mp +import numpy as np mp_drawing = mp.solutions.drawing_utils mp_selfie_segmentation = mp.solutions.selfie_segmentation @@ -261,7 +262,7 @@ to visualize its associated subgraphs, please see [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1DoeyGzMmWUsjfVgZfGGecrn7GKzYcEAo/view?usp=sharing) [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu:selfiesegmentationgpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu/BUILD) * iOS target: - [`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](http:/mediapipe/examples/ios/selfiesegmentationgpu/BUILD) + [`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/selfiesegmentationgpu/BUILD) ### Desktop diff --git a/docs/solutions/solutions.md b/docs/solutions/solutions.md index 98bafe30e..e9e4cdc38 100644 --- a/docs/solutions/solutions.md +++ b/docs/solutions/solutions.md @@ -13,6 +13,9 @@ has_toc: false {:toc} --- +MediaPipe offers open source cross-platform, customizable ML solutions for live +and streaming media. 
+ @@ -29,7 +32,7 @@ has_toc: false [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | ✅ | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index f1d5805ef..43979d93e 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -140,6 +140,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "graph_profile_calculator_proto", + srcs = ["graph_profile_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + cc_library( name = "add_header_calculator", srcs = ["add_header_calculator.cc"], @@ -521,9 +531,13 @@ cc_test( ":split_vector_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", ], ) @@ -1200,3 +1214,45 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "graph_profile_calculator", + srcs = ["graph_profile_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_profile_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "graph_profile_calculator_test", + srcs = ["graph_profile_calculator_test.cc"], + deps = [ + ":graph_profile_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework:test_calculators", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:threadpool", + "//mediapipe/framework/tool:simulation_clock_executor", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) diff --git a/mediapipe/calculators/core/add_header_calculator_test.cc b/mediapipe/calculators/core/add_header_calculator_test.cc index 4e197918d..bbe9bdd30 100644 --- a/mediapipe/calculators/core/add_header_calculator_test.cc +++ b/mediapipe/calculators/core/add_header_calculator_test.cc @@ 
-24,6 +24,9 @@ namespace mediapipe { +constexpr char kDataTag[] = "DATA"; +constexpr char kHeaderTag[] = "HEADER"; + class AddHeaderCalculatorTest : public ::testing::Test {}; TEST_F(AddHeaderCalculatorTest, HeaderStream) { @@ -36,11 +39,11 @@ TEST_F(AddHeaderCalculatorTest, HeaderStream) { CalculatorRunner runner(node); // Set header and add 5 packets. - runner.MutableInputs()->Tag("HEADER").header = + runner.MutableInputs()->Tag(kHeaderTag).header = Adopt(new std::string("my_header")); for (int i = 0; i < 5; ++i) { Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); - runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet); } // Run calculator. @@ -85,13 +88,14 @@ TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) { CalculatorRunner runner(node); // Set header and add 5 packets. - runner.MutableInputs()->Tag("HEADER").header = + runner.MutableInputs()->Tag(kHeaderTag).header = Adopt(new std::string("my_header")); - runner.MutableInputs()->Tag("HEADER").packets.push_back( - Adopt(new std::string("not allowed"))); + runner.MutableInputs() + ->Tag(kHeaderTag) + .packets.push_back(Adopt(new std::string("not allowed"))); for (int i = 0; i < 5; ++i) { Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); - runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet); } // Run calculator. @@ -108,11 +112,11 @@ TEST_F(AddHeaderCalculatorTest, InputSidePacket) { CalculatorRunner runner(node); // Set header and add 5 packets. - runner.MutableSidePackets()->Tag("HEADER") = + runner.MutableSidePackets()->Tag(kHeaderTag) = Adopt(new std::string("my_header")); for (int i = 0; i < 5; ++i) { Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); - runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet); } // Run calculator. @@ -143,13 +147,13 @@ TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) { CalculatorRunner runner(node); // Set both headers and add 5 packets. - runner.MutableSidePackets()->Tag("HEADER") = + runner.MutableSidePackets()->Tag(kHeaderTag) = Adopt(new std::string("my_header")); - runner.MutableSidePackets()->Tag("HEADER") = + runner.MutableSidePackets()->Tag(kHeaderTag) = Adopt(new std::string("my_header")); for (int i = 0; i < 5; ++i) { Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); - runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet); } // Run should fail because header can only be provided one way. diff --git a/mediapipe/calculators/core/begin_loop_calculator.cc b/mediapipe/calculators/core/begin_loop_calculator.cc index 1d0f7824d..e698e194c 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.cc +++ b/mediapipe/calculators/core/begin_loop_calculator.cc @@ -42,4 +42,13 @@ REGISTER_CALCULATOR(BeginLoopDetectionCalculator); typedef BeginLoopCalculator> BeginLoopMatrixCalculator; REGISTER_CALCULATOR(BeginLoopMatrixCalculator); +// A calculator to process std::vector>. +typedef BeginLoopCalculator>> + BeginLoopMatrixVectorCalculator; +REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator); + +// A calculator to process std::vector. 
+typedef BeginLoopCalculator> BeginLoopUint64tCalculator; +REGISTER_CALCULATOR(BeginLoopUint64tCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/core/counting_source_calculator.cc b/mediapipe/calculators/core/counting_source_calculator.cc index 0b731d9ce..fb75669e9 100644 --- a/mediapipe/calculators/core/counting_source_calculator.cc +++ b/mediapipe/calculators/core/counting_source_calculator.cc @@ -19,6 +19,13 @@ namespace mediapipe { +constexpr char kIncrementTag[] = "INCREMENT"; +constexpr char kInitialValueTag[] = "INITIAL_VALUE"; +constexpr char kBatchSizeTag[] = "BATCH_SIZE"; +constexpr char kErrorCountTag[] = "ERROR_COUNT"; +constexpr char kMaxCountTag[] = "MAX_COUNT"; +constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN"; + // Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of // sequential numbers from INITIAL_VALUE (default 0) with a common // difference of INCREMENT (default 1) between successive numbers (with @@ -33,53 +40,53 @@ class CountingSourceCalculator : public CalculatorBase { static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) { - cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set(); + if (cc->InputSidePackets().HasTag(kErrorOnOpenTag)) { + cc->InputSidePackets().Tag(kErrorOnOpenTag).Set(); } - RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") || - cc->InputSidePackets().HasTag("ERROR_COUNT")); - if (cc->InputSidePackets().HasTag("MAX_COUNT")) { - cc->InputSidePackets().Tag("MAX_COUNT").Set(); + RET_CHECK(cc->InputSidePackets().HasTag(kMaxCountTag) || + cc->InputSidePackets().HasTag(kErrorCountTag)); + if (cc->InputSidePackets().HasTag(kMaxCountTag)) { + cc->InputSidePackets().Tag(kMaxCountTag).Set(); } - if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { - cc->InputSidePackets().Tag("ERROR_COUNT").Set(); + if (cc->InputSidePackets().HasTag(kErrorCountTag)) { + cc->InputSidePackets().Tag(kErrorCountTag).Set(); } - if (cc->InputSidePackets().HasTag("BATCH_SIZE")) { - cc->InputSidePackets().Tag("BATCH_SIZE").Set(); + if (cc->InputSidePackets().HasTag(kBatchSizeTag)) { + cc->InputSidePackets().Tag(kBatchSizeTag).Set(); } - if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) { - cc->InputSidePackets().Tag("INITIAL_VALUE").Set(); + if (cc->InputSidePackets().HasTag(kInitialValueTag)) { + cc->InputSidePackets().Tag(kInitialValueTag).Set(); } - if (cc->InputSidePackets().HasTag("INCREMENT")) { - cc->InputSidePackets().Tag("INCREMENT").Set(); + if (cc->InputSidePackets().HasTag(kIncrementTag)) { + cc->InputSidePackets().Tag(kIncrementTag).Set(); } return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { - if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") && - cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get()) { + if (cc->InputSidePackets().HasTag(kErrorOnOpenTag) && + cc->InputSidePackets().Tag(kErrorOnOpenTag).Get()) { return absl::NotFoundError("expected error"); } - if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { - error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get(); + if (cc->InputSidePackets().HasTag(kErrorCountTag)) { + error_count_ = cc->InputSidePackets().Tag(kErrorCountTag).Get(); RET_CHECK_LE(0, error_count_); } - if (cc->InputSidePackets().HasTag("MAX_COUNT")) { - max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get(); + if (cc->InputSidePackets().HasTag(kMaxCountTag)) { + max_count_ = cc->InputSidePackets().Tag(kMaxCountTag).Get(); RET_CHECK_LE(0, max_count_); } - if 
(cc->InputSidePackets().HasTag("BATCH_SIZE")) { - batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get(); + if (cc->InputSidePackets().HasTag(kBatchSizeTag)) { + batch_size_ = cc->InputSidePackets().Tag(kBatchSizeTag).Get(); RET_CHECK_LT(0, batch_size_); } - if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) { - counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get(); + if (cc->InputSidePackets().HasTag(kInitialValueTag)) { + counter_ = cc->InputSidePackets().Tag(kInitialValueTag).Get(); } - if (cc->InputSidePackets().HasTag("INCREMENT")) { - increment_ = cc->InputSidePackets().Tag("INCREMENT").Get(); + if (cc->InputSidePackets().HasTag(kIncrementTag)) { + increment_ = cc->InputSidePackets().Tag(kIncrementTag).Get(); RET_CHECK_LT(0, increment_); } RET_CHECK(error_count_ >= 0 || max_count_ >= 0); diff --git a/mediapipe/calculators/core/dequantize_byte_array_calculator.cc b/mediapipe/calculators/core/dequantize_byte_array_calculator.cc index 04a7e55a0..a8adefc63 100644 --- a/mediapipe/calculators/core/dequantize_byte_array_calculator.cc +++ b/mediapipe/calculators/core/dequantize_byte_array_calculator.cc @@ -35,11 +35,14 @@ // } namespace mediapipe { +constexpr char kFloatVectorTag[] = "FLOAT_VECTOR"; +constexpr char kEncodedTag[] = "ENCODED"; + class DequantizeByteArrayCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("ENCODED").Set(); - cc->Outputs().Tag("FLOAT_VECTOR").Set>(); + cc->Inputs().Tag(kEncodedTag).Set(); + cc->Outputs().Tag(kFloatVectorTag).Set>(); return absl::OkStatus(); } @@ -66,7 +69,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) final { const std::string& encoded = - cc->Inputs().Tag("ENCODED").Value().Get(); + cc->Inputs().Tag(kEncodedTag).Value().Get(); std::vector float_vector; float_vector.reserve(encoded.length()); for (int i = 0; i < encoded.length(); ++i) { @@ -74,7 +77,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase { static_cast(encoded.at(i)) * scalar_ + bias_); } cc->Outputs() - .Tag("FLOAT_VECTOR") + .Tag(kFloatVectorTag) .AddPacket(MakePacket>(float_vector) .At(cc->InputTimestamp())); return absl::OkStatus(); diff --git a/mediapipe/calculators/core/dequantize_byte_array_calculator_test.cc b/mediapipe/calculators/core/dequantize_byte_array_calculator_test.cc index cf0a8dc15..81b9e5562 100644 --- a/mediapipe/calculators/core/dequantize_byte_array_calculator_test.cc +++ b/mediapipe/calculators/core/dequantize_byte_array_calculator_test.cc @@ -25,6 +25,9 @@ namespace mediapipe { +constexpr char kFloatVectorTag[] = "FLOAT_VECTOR"; +constexpr char kEncodedTag[] = "ENCODED"; + TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { CalculatorGraphConfig::Node node_config = ParseTextProtoOrDie(R"pb( @@ -39,8 +42,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { )pb"); CalculatorRunner runner(node_config); std::string empty_string; - runner.MutableInputs()->Tag("ENCODED").packets.push_back( - MakePacket(empty_string).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kEncodedTag) + .packets.push_back( + MakePacket(empty_string).At(Timestamp(0))); auto status = runner.Run(); EXPECT_FALSE(status.ok()); EXPECT_THAT( @@ -64,8 +69,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) { )pb"); CalculatorRunner runner(node_config); std::string empty_string; - runner.MutableInputs()->Tag("ENCODED").packets.push_back( - MakePacket(empty_string).At(Timestamp(0))); + runner.MutableInputs() + 
->Tag(kEncodedTag) + .packets.push_back( + MakePacket(empty_string).At(Timestamp(0))); auto status = runner.Run(); EXPECT_FALSE(status.ok()); EXPECT_THAT( @@ -89,8 +96,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) { )pb"); CalculatorRunner runner(node_config); std::string empty_string; - runner.MutableInputs()->Tag("ENCODED").packets.push_back( - MakePacket(empty_string).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kEncodedTag) + .packets.push_back( + MakePacket(empty_string).At(Timestamp(0))); auto status = runner.Run(); EXPECT_FALSE(status.ok()); EXPECT_THAT( @@ -114,14 +123,16 @@ TEST(DequantizeByteArrayCalculatorTest, TestDequantization) { )pb"); CalculatorRunner runner(node_config); unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01}; - runner.MutableInputs()->Tag("ENCODED").packets.push_back( - MakePacket( - std::string(reinterpret_cast(input), 4)) - .At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kEncodedTag) + .packets.push_back( + MakePacket( + std::string(reinterpret_cast(input), 4)) + .At(Timestamp(0))); auto status = runner.Run(); MP_ASSERT_OK(runner.Run()); const std::vector& outputs = - runner.Outputs().Tag("FLOAT_VECTOR").packets; + runner.Outputs().Tag(kFloatVectorTag).packets; EXPECT_EQ(1, outputs.size()); const std::vector& result = outputs[0].Get>(); ASSERT_FALSE(result.empty()); diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc index eba621ce3..b365121bc 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -24,6 +24,11 @@ namespace mediapipe { +constexpr char kFinishedTag[] = "FINISHED"; +constexpr char kAllowTag[] = "ALLOW"; +constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT"; +constexpr char kOptionsTag[] = "OPTIONS"; + // FlowLimiterCalculator is used to limit the number of frames in flight // by dropping input frames when necessary. 
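//
// A minimal usage sketch (stream names here are placeholders, not taken from
// this change): the data input is paired with a FINISHED back edge from the
// downstream consumer, MAX_IN_FLIGHT optionally overrides the in-flight
// limit, and ALLOW reports whether each frame was passed or dropped.
//
// node {
//   calculator: "FlowLimiterCalculator"
//   input_stream: "input_frames"
//   input_stream: "FINISHED:downstream_output"
//   input_stream_info: { tag_index: "FINISHED" back_edge: true }
//   input_side_packet: "MAX_IN_FLIGHT:max_in_flight"
//   output_stream: "gated_frames"
//   output_stream: "ALLOW:frame_allowed"
// }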
// @@ -69,16 +74,19 @@ class FlowLimiterCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { auto& side_inputs = cc->InputSidePackets(); - side_inputs.Tag("OPTIONS").Set().Optional(); - cc->Inputs().Tag("OPTIONS").Set().Optional(); + side_inputs.Tag(kOptionsTag).Set().Optional(); + cc->Inputs() + .Tag(kOptionsTag) + .Set() + .Optional(); RET_CHECK_GE(cc->Inputs().NumEntries(""), 1); for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { cc->Inputs().Get("", i).SetAny(); cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i))); } cc->Inputs().Get("FINISHED", 0).SetAny(); - cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set().Optional(); - cc->Outputs().Tag("ALLOW").Set().Optional(); + cc->InputSidePackets().Tag(kMaxInFlightTag).Set().Optional(); + cc->Outputs().Tag(kAllowTag).Set().Optional(); cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetProcessTimestampBounds(true); return absl::OkStatus(); @@ -87,9 +95,9 @@ class FlowLimiterCalculator : public CalculatorBase { absl::Status Open(CalculatorContext* cc) final { options_ = cc->Options(); options_ = tool::RetrieveOptions(options_, cc->InputSidePackets()); - if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { + if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) { options_.set_max_in_flight( - cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get()); + cc->InputSidePackets().Tag(kMaxInFlightTag).Get()); } input_queues_.resize(cc->Inputs().NumEntries("")); RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); @@ -104,8 +112,8 @@ class FlowLimiterCalculator : public CalculatorBase { // Outputs a packet indicating whether a frame was sent or dropped. void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) { - if (cc->Outputs().HasTag("ALLOW")) { - cc->Outputs().Tag("ALLOW").AddPacket(MakePacket(allow).At(ts)); + if (cc->Outputs().HasTag(kAllowTag)) { + cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket(allow).At(ts)); } } @@ -155,7 +163,7 @@ class FlowLimiterCalculator : public CalculatorBase { options_ = tool::RetrieveOptions(options_, cc->Inputs()); // Process the FINISHED input stream. - Packet finished_packet = cc->Inputs().Tag("FINISHED").Value(); + Packet finished_packet = cc->Inputs().Tag(kFinishedTag).Value(); if (finished_packet.Timestamp() == cc->InputTimestamp()) { while (!frames_in_flight_.empty() && frames_in_flight_.front() <= finished_packet.Timestamp()) { @@ -210,8 +218,8 @@ class FlowLimiterCalculator : public CalculatorBase { Timestamp bound = cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream(); SetNextTimestampBound(bound, &cc->Outputs().Get("", 0)); - if (cc->Outputs().HasTag("ALLOW")) { - SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW")); + if (cc->Outputs().HasTag(kAllowTag)) { + SetNextTimestampBound(bound, &cc->Outputs().Tag(kAllowTag)); } } diff --git a/mediapipe/calculators/core/flow_limiter_calculator.proto b/mediapipe/calculators/core/flow_limiter_calculator.proto index 0f7c925ae..a3a71a294 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.proto +++ b/mediapipe/calculators/core/flow_limiter_calculator.proto @@ -30,7 +30,7 @@ message FlowLimiterCalculatorOptions { optional int32 max_in_flight = 1 [default = 1]; // The maximum number of frames queued waiting for processing. - // The default value limits to 1 frame awaiting processing. + // The default value limits to 0 frames awaiting processing. 
optional int32 max_in_queue = 2 [default = 0]; // The maximum time in microseconds to wait for a frame to finish processing. diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index d2294dd48..962b1c81a 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -36,6 +36,13 @@ namespace mediapipe { namespace { + +constexpr char kDropTimestampsTag[] = "DROP_TIMESTAMPS"; +constexpr char kClockTag[] = "CLOCK"; +constexpr char kWarmupTimeTag[] = "WARMUP_TIME"; +constexpr char kSleepTimeTag[] = "SLEEP_TIME"; +constexpr char kPacketTag[] = "PACKET"; + // A simple Semaphore for synchronizing test threads. class AtomicSemaphore { public: @@ -204,17 +211,17 @@ TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) { class SleepCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("PACKET").SetAny(); - cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); - cc->InputSidePackets().Tag("SLEEP_TIME").Set(); - cc->InputSidePackets().Tag("WARMUP_TIME").Set(); - cc->InputSidePackets().Tag("CLOCK").Set(); + cc->Inputs().Tag(kPacketTag).SetAny(); + cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag)); + cc->InputSidePackets().Tag(kSleepTimeTag).Set(); + cc->InputSidePackets().Tag(kWarmupTimeTag).Set(); + cc->InputSidePackets().Tag(kClockTag).Set(); cc->SetTimestampOffset(0); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { - clock_ = cc->InputSidePackets().Tag("CLOCK").Get(); + clock_ = cc->InputSidePackets().Tag(kClockTag).Get(); return absl::OkStatus(); } @@ -222,10 +229,12 @@ class SleepCalculator : public CalculatorBase { ++packet_count; absl::Duration sleep_time = absl::Microseconds( packet_count == 1 - ? cc->InputSidePackets().Tag("WARMUP_TIME").Get() - : cc->InputSidePackets().Tag("SLEEP_TIME").Get()); + ? 
cc->InputSidePackets().Tag(kWarmupTimeTag).Get() + : cc->InputSidePackets().Tag(kSleepTimeTag).Get()); clock_->Sleep(sleep_time); - cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); + cc->Outputs() + .Tag(kPacketTag) + .AddPacket(cc->Inputs().Tag(kPacketTag).Value()); return absl::OkStatus(); } @@ -240,24 +249,27 @@ REGISTER_CALCULATOR(SleepCalculator); class DropCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("PACKET").SetAny(); - cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); - cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set(); + cc->Inputs().Tag(kPacketTag).SetAny(); + cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag)); + cc->InputSidePackets().Tag(kDropTimestampsTag).Set(); cc->SetProcessTimestampBounds(true); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { - if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) { + if (!cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) { ++packet_count; } bool drop = (packet_count == 3); - if (!drop && !cc->Inputs().Tag("PACKET").Value().IsEmpty()) { - cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); + if (!drop && !cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) { + cc->Outputs() + .Tag(kPacketTag) + .AddPacket(cc->Inputs().Tag(kPacketTag).Value()); } - if (!drop || !cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get()) { - cc->Outputs().Tag("PACKET").SetNextTimestampBound( - cc->InputTimestamp().NextAllowedInStream()); + if (!drop || !cc->InputSidePackets().Tag(kDropTimestampsTag).Get()) { + cc->Outputs() + .Tag(kPacketTag) + .SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream()); } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/gate_calculator.cc b/mediapipe/calculators/core/gate_calculator.cc index 189671860..8fdb9e0a3 100644 --- a/mediapipe/calculators/core/gate_calculator.cc +++ b/mediapipe/calculators/core/gate_calculator.cc @@ -21,6 +21,11 @@ namespace mediapipe { namespace { + +constexpr char kStateChangeTag[] = "STATE_CHANGE"; +constexpr char kDisallowTag[] = "DISALLOW"; +constexpr char kAllowTag[] = "ALLOW"; + enum GateState { GATE_UNINITIALIZED, GATE_ALLOW, @@ -59,8 +64,9 @@ std::string ToString(GateState state) { // ALLOW or DISALLOW can also be specified as an input side packet. The rules // for evaluation remain the same as above. // -// ALLOW/DISALLOW inputs must be specified either using input stream or -// via input side packet but not both. +// ALLOW/DISALLOW inputs must be specified either using input stream or via +// input side packet but not both. If neither is specified, the behavior is then +// determined by the "allow" field in the calculator options. // // Intended to be used with the default input stream handler, which synchronizes // all data input streams with the ALLOW/DISALLOW control input stream. @@ -83,30 +89,33 @@ class GateCalculator : public CalculatorBase { GateCalculator() {} static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) { - bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") || - cc->InputSidePackets().HasTag("DISALLOW"); + bool input_via_side_packet = cc->InputSidePackets().HasTag(kAllowTag) || + cc->InputSidePackets().HasTag(kDisallowTag); bool input_via_stream = - cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW"); - // Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW - // input. 
- RET_CHECK(input_via_side_packet ^ input_via_stream); + cc->Inputs().HasTag(kAllowTag) || cc->Inputs().HasTag(kDisallowTag); + // Only one of input_side_packet or input_stream may specify + // ALLOW/DISALLOW input. if (input_via_side_packet) { - RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^ - cc->InputSidePackets().HasTag("DISALLOW")); + RET_CHECK(!input_via_stream); + RET_CHECK(cc->InputSidePackets().HasTag(kAllowTag) ^ + cc->InputSidePackets().HasTag(kDisallowTag)); - if (cc->InputSidePackets().HasTag("ALLOW")) { - cc->InputSidePackets().Tag("ALLOW").Set(); + if (cc->InputSidePackets().HasTag(kAllowTag)) { + cc->InputSidePackets().Tag(kAllowTag).Set().Optional(); } else { - cc->InputSidePackets().Tag("DISALLOW").Set(); + cc->InputSidePackets().Tag(kDisallowTag).Set().Optional(); } - } else { - RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW")); + } + if (input_via_stream) { + RET_CHECK(!input_via_side_packet); + RET_CHECK(cc->Inputs().HasTag(kAllowTag) ^ + cc->Inputs().HasTag(kDisallowTag)); - if (cc->Inputs().HasTag("ALLOW")) { - cc->Inputs().Tag("ALLOW").Set(); + if (cc->Inputs().HasTag(kAllowTag)) { + cc->Inputs().Tag(kAllowTag).Set(); } else { - cc->Inputs().Tag("DISALLOW").Set(); + cc->Inputs().Tag(kDisallowTag).Set(); } } return absl::OkStatus(); @@ -125,23 +134,22 @@ class GateCalculator : public CalculatorBase { cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); } - if (cc->Outputs().HasTag("STATE_CHANGE")) { - cc->Outputs().Tag("STATE_CHANGE").Set(); + if (cc->Outputs().HasTag(kStateChangeTag)) { + cc->Outputs().Tag(kStateChangeTag).Set(); } return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { - use_side_packet_for_allow_disallow_ = false; - if (cc->InputSidePackets().HasTag("ALLOW")) { + if (cc->InputSidePackets().HasTag(kAllowTag)) { use_side_packet_for_allow_disallow_ = true; allow_by_side_packet_decision_ = - cc->InputSidePackets().Tag("ALLOW").Get(); - } else if (cc->InputSidePackets().HasTag("DISALLOW")) { + cc->InputSidePackets().Tag(kAllowTag).Get(); + } else if (cc->InputSidePackets().HasTag(kDisallowTag)) { use_side_packet_for_allow_disallow_ = true; allow_by_side_packet_decision_ = - !cc->InputSidePackets().Tag("DISALLOW").Get(); + !cc->InputSidePackets().Tag(kDisallowTag).Get(); } cc->SetOffset(TimestampDiff(0)); @@ -152,26 +160,34 @@ class GateCalculator : public CalculatorBase { const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>(); empty_packets_as_allow_ = options.empty_packets_as_allow(); + if (!use_side_packet_for_allow_disallow_ && + !cc->Inputs().HasTag(kAllowTag) && !cc->Inputs().HasTag(kDisallowTag)) { + use_option_for_allow_disallow_ = true; + allow_by_option_decision_ = options.allow(); + } + return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { bool allow = empty_packets_as_allow_; - if (use_side_packet_for_allow_disallow_) { + if (use_option_for_allow_disallow_) { + allow = allow_by_option_decision_; + } else if (use_side_packet_for_allow_disallow_) { allow = allow_by_side_packet_decision_; } else { - if (cc->Inputs().HasTag("ALLOW") && - !cc->Inputs().Tag("ALLOW").IsEmpty()) { - allow = cc->Inputs().Tag("ALLOW").Get(); + if (cc->Inputs().HasTag(kAllowTag) && + !cc->Inputs().Tag(kAllowTag).IsEmpty()) { + allow = cc->Inputs().Tag(kAllowTag).Get(); } - if (cc->Inputs().HasTag("DISALLOW") && - !cc->Inputs().Tag("DISALLOW").IsEmpty()) { - allow = !cc->Inputs().Tag("DISALLOW").Get(); + if (cc->Inputs().HasTag(kDisallowTag) && + 
!cc->Inputs().Tag(kDisallowTag).IsEmpty()) { + allow = !cc->Inputs().Tag(kDisallowTag).Get(); } } const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW; - if (cc->Outputs().HasTag("STATE_CHANGE")) { + if (cc->Outputs().HasTag(kStateChangeTag)) { if (last_gate_state_ != GATE_UNINITIALIZED && last_gate_state_ != new_gate_state) { VLOG(2) << "State transition in " << cc->NodeName() << " @ " @@ -179,7 +195,7 @@ class GateCalculator : public CalculatorBase { << ToString(last_gate_state_) << " to " << ToString(new_gate_state); cc->Outputs() - .Tag("STATE_CHANGE") + .Tag(kStateChangeTag) .AddPacket(MakePacket(allow).At(cc->InputTimestamp())); } } @@ -211,8 +227,10 @@ class GateCalculator : public CalculatorBase { GateState last_gate_state_ = GATE_UNINITIALIZED; int num_data_streams_; bool empty_packets_as_allow_; - bool use_side_packet_for_allow_disallow_; + bool use_side_packet_for_allow_disallow_ = false; bool allow_by_side_packet_decision_; + bool use_option_for_allow_disallow_ = false; + bool allow_by_option_decision_; }; REGISTER_CALCULATOR(GateCalculator); diff --git a/mediapipe/calculators/core/gate_calculator.proto b/mediapipe/calculators/core/gate_calculator.proto index 76bacc74e..32402bf28 100644 --- a/mediapipe/calculators/core/gate_calculator.proto +++ b/mediapipe/calculators/core/gate_calculator.proto @@ -29,4 +29,8 @@ message GateCalculatorOptions { // disallowing the corresponding packets in the data input streams. Setting // this option to true inverts that, allowing the data packets to go through. optional bool empty_packets_as_allow = 1; + + // Whether to allow or disallow the input streams to pass when no + // ALLOW/DISALLOW input or side input is specified. + optional bool allow = 2 [default = false]; } diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc index 0b78b9b75..c523bce28 100644 --- a/mediapipe/calculators/core/gate_calculator_test.cc +++ b/mediapipe/calculators/core/gate_calculator_test.cc @@ -22,6 +22,9 @@ namespace mediapipe { namespace { +constexpr char kDisallowTag[] = "DISALLOW"; +constexpr char kAllowTag[] = "ALLOW"; + class GateCalculatorTest : public ::testing::Test { protected: // Helper to run a graph and return status. 
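//
// A minimal sketch of the options-only mode exercised by the new tests below:
// when neither an ALLOW/DISALLOW stream nor a side packet is wired up, the
// gate is driven entirely by the new "allow" option (stream names here are
// placeholders).
//
// node {
//   calculator: "GateCalculator"
//   input_stream: "frames"
//   output_stream: "gated_frames"
//   options: {
//     [mediapipe.GateCalculatorOptions.ext] {
//       allow: true
//     }
//   }
// }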
@@ -110,6 +113,68 @@ TEST_F(GateCalculatorTest, InvalidInputs) { )"))); } +TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + output_stream: "test_output" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(2, output.size()); + EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value()); + EXPECT_EQ(kTimestampValue1, output[1].Timestamp().Value()); + EXPECT_EQ(true, output[0].Get()); + EXPECT_EQ(false, output[1].Get()); +} + +TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + output_stream: "test_output" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: false + } + } + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(0, output.size()); +} + +TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + output_stream: "test_output" + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(0, output.size()); +} + TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) { SetRunner(R"( calculator: "GateCalculator" @@ -117,7 +182,7 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) { input_stream: "test_input" output_stream: "test_output" )"); - runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true)); + runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true)); constexpr int64 kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); @@ -139,7 +204,7 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) { input_stream: "test_input" output_stream: "test_output" )"); - runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false)); + runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false)); constexpr int64 kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); @@ -161,7 +226,7 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) { input_stream: "test_input" output_stream: "test_output" )"); - runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false)); + runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false)); constexpr int64 kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); @@ -179,7 +244,7 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) { input_stream: "test_input" output_stream: "test_output" )"); - runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true)); + runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true)); constexpr int64 kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); diff --git a/mediapipe/calculators/core/graph_profile_calculator.cc b/mediapipe/calculators/core/graph_profile_calculator.cc new file mode 100644 index 000000000..9b9aa3bb7 --- /dev/null +++ 
b/mediapipe/calculators/core/graph_profile_calculator.cc @@ -0,0 +1,70 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/core/graph_profile_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_profile.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +namespace api2 { + +// This calculator periodically copies the GraphProfile from +// mediapipe::GraphProfiler::CaptureProfile to the "PROFILE" output stream. +// +// Example config: +// node { +// calculator: "GraphProfileCalculator" +// output_stream: "FRAME:any_frame" +// output_stream: "PROFILE:graph_profile" +// } +// +class GraphProfileCalculator : public Node { + public: + static constexpr Input::Multiple kFrameIn{"FRAME"}; + static constexpr Output kProfileOut{"PROFILE"}; + + MEDIAPIPE_NODE_CONTRACT(kFrameIn, kProfileOut); + + static absl::Status UpdateContract(CalculatorContract* cc) { + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + auto options = cc->Options<::mediapipe::GraphProfileCalculatorOptions>(); + + if (prev_profile_ts_ == Timestamp::Unset() || + cc->InputTimestamp() - prev_profile_ts_ >= options.profile_interval()) { + prev_profile_ts_ = cc->InputTimestamp(); + GraphProfile result; + MP_RETURN_IF_ERROR(cc->GetProfilingContext()->CaptureProfile(&result)); + kProfileOut(cc).Send(result); + } + return absl::OkStatus(); + } + + private: + Timestamp prev_profile_ts_; +}; + +MEDIAPIPE_REGISTER_NODE(GraphProfileCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/core/graph_profile_calculator.proto b/mediapipe/calculators/core/graph_profile_calculator.proto new file mode 100644 index 000000000..2bcc480c8 --- /dev/null +++ b/mediapipe/calculators/core/graph_profile_calculator.proto @@ -0,0 +1,30 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +option objc_class_prefix = "MediaPipe"; + +message GraphProfileCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional GraphProfileCalculatorOptions ext = 367481815; + } + + // The interval in microseconds between successive reported GraphProfiles. + optional int64 profile_interval = 1 [default = 1000000]; +} diff --git a/mediapipe/calculators/core/graph_profile_calculator_test.cc b/mediapipe/calculators/core/graph_profile_calculator_test.cc new file mode 100644 index 000000000..8a7845b19 --- /dev/null +++ b/mediapipe/calculators/core/graph_profile_calculator_test.cc @@ -0,0 +1,211 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_profile.pb.h" +#include "mediapipe/framework/deps/clock.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/port/threadpool.h" +#include "mediapipe/framework/tool/simulation_clock_executor.h" + +// Tests for GraphProfileCalculator. +using testing::ElementsAre; + +namespace mediapipe { +namespace { + +constexpr char kClockTag[] = "CLOCK"; + +using mediapipe::Clock; + +// A Calculator with a fixed Process call latency. +class SleepCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets().Tag(kClockTag).Set>(); + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + cc->SetTimestampOffset(TimestampDiff(0)); + return absl::OkStatus(); + } + absl::Status Open(CalculatorContext* cc) final { + clock_ = + cc->InputSidePackets().Tag(kClockTag).Get>(); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + clock_->Sleep(absl::Milliseconds(5)); + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + return absl::OkStatus(); + } + std::shared_ptr<::mediapipe::Clock> clock_ = nullptr; +}; +REGISTER_CALCULATOR(SleepCalculator); + +// Tests showing GraphProfileCalculator reporting GraphProfile output packets. 
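//
// Outside of tests, the PROFILE stream can be consumed like any other output
// stream. A minimal sketch, assuming a CalculatorGraph `graph` whose config
// contains the GraphProfileCalculator node shown above and names its profile
// output "graph_profile" (all names here are placeholders):
//
//   ASSIGN_OR_RETURN(OutputStreamPoller poller,
//                    graph.AddOutputStreamPoller("graph_profile"));
//   MP_RETURN_IF_ERROR(graph.StartRun({}));
//   Packet packet;
//   while (poller.Next(&packet)) {
//     const GraphProfile& profile = packet.Get<GraphProfile>();
//     LOG(INFO) << "calculators profiled: "
//               << profile.calculator_profiles_size();
//   }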
+class GraphProfileCalculatorTest : public ::testing::Test { + protected: + void SetUpProfileGraph() { + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( + input_stream: "input_packets_0" + node { + calculator: 'SleepCalculator' + input_side_packet: 'CLOCK:sync_clock' + input_stream: 'input_packets_0' + output_stream: 'output_packets_1' + } + node { + calculator: "GraphProfileCalculator" + options: { + [mediapipe.GraphProfileCalculatorOptions.ext]: { + profile_interval: 25000 + } + } + input_stream: "FRAME:output_packets_1" + output_stream: "PROFILE:output_packets_0" + } + )", + &graph_config_)); + } + + static Packet PacketAt(int64 ts) { + return Adopt(new int64(999)).At(Timestamp(ts)); + } + static Packet None() { return Packet().At(Timestamp::OneOverPostStream()); } + static bool IsNone(const Packet& packet) { + return packet.Timestamp() == Timestamp::OneOverPostStream(); + } + // Return the values of the timestamps of a vector of Packets. + static std::vector TimestampValues( + const std::vector& packets) { + std::vector result; + for (const Packet& p : packets) { + result.push_back(p.Timestamp().Value()); + } + return result; + } + + // Runs a CalculatorGraph with a series of packet sets. + // Returns a vector of packets from each graph output stream. + void RunGraph(const std::vector>& input_sets, + std::vector* output_packets) { + // Register output packet observers. + tool::AddVectorSink("output_packets_0", &graph_config_, output_packets); + + // Start running the graph. + std::shared_ptr executor( + new SimulationClockExecutor(3 /*num_threads*/)); + CalculatorGraph graph; + MP_ASSERT_OK(graph.SetExecutor("", executor)); + graph.profiler()->SetClock(executor->GetClock()); + MP_ASSERT_OK(graph.Initialize(graph_config_)); + executor->GetClock()->ThreadStart(); + MP_ASSERT_OK(graph.StartRun({ + {"sync_clock", + Adopt(new std::shared_ptr<::mediapipe::Clock>(executor->GetClock()))}, + })); + + // Send each packet to the graph in the specified order. + for (int t = 0; t < input_sets.size(); t++) { + const std::vector& input_set = input_sets[t]; + for (int i = 0; i < input_set.size(); i++) { + const Packet& packet = input_set[i]; + if (!IsNone(packet)) { + MP_EXPECT_OK(graph.AddPacketToInputStream( + absl::StrCat("input_packets_", i), packet)); + } + executor->GetClock()->Sleep(absl::Milliseconds(10)); + } + } + MP_ASSERT_OK(graph.CloseAllInputStreams()); + executor->GetClock()->Sleep(absl::Milliseconds(100)); + executor->GetClock()->ThreadFinish(); + MP_ASSERT_OK(graph.WaitUntilDone()); + } + + CalculatorGraphConfig graph_config_; +}; + +TEST_F(GraphProfileCalculatorTest, GraphProfile) { + SetUpProfileGraph(); + auto profiler_config = graph_config_.mutable_profiler_config(); + profiler_config->set_enable_profiler(true); + profiler_config->set_trace_enabled(false); + profiler_config->set_trace_log_disabled(true); + profiler_config->set_enable_stream_latency(true); + profiler_config->set_calculator_filter(".*Calculator"); + + // Run the graph with a series of packet sets. + std::vector> input_sets = { + {PacketAt(10000)}, // + {PacketAt(20000)}, // + {PacketAt(30000)}, // + {PacketAt(40000)}, + }; + std::vector output_packets; + RunGraph(input_sets, &output_packets); + + // Validate the output packets. 
+ EXPECT_THAT(TimestampValues(output_packets), // + ElementsAre(10000, 40000)); + + GraphProfile expected_profile = + mediapipe::ParseTextProtoOrDie(R"pb( + calculator_profiles { + name: "GraphProfileCalculator" + open_runtime: 0 + process_runtime { total: 0 count: 3 } + process_input_latency { total: 15000 count: 3 } + process_output_latency { total: 15000 count: 3 } + input_stream_profiles { + name: "output_packets_1" + back_edge: false + latency { total: 0 count: 3 } + } + } + calculator_profiles { + name: "SleepCalculator" + open_runtime: 0 + process_runtime { total: 15000 count: 3 } + process_input_latency { total: 0 count: 3 } + process_output_latency { total: 15000 count: 3 } + input_stream_profiles { + name: "input_packets_0" + back_edge: false + latency { total: 0 count: 3 } + } + })pb"); + + EXPECT_THAT(output_packets[1].Get(), + mediapipe::EqualsProto(expected_profile)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc index 0bbf94dc8..45a0e1cd3 100644 --- a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc +++ b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc @@ -29,6 +29,9 @@ namespace mediapipe { namespace { +constexpr char kMinuendTag[] = "MINUEND"; +constexpr char kSubtrahendTag[] = "SUBTRAHEND"; + // A 3x4 Matrix of random integers in [0,1000). const char kMatrixText[] = "rows: 3\n" @@ -104,12 +107,13 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromInput) { CalculatorRunner runner(node_config); Matrix* side_matrix = new Matrix(); MatrixFromTextProto(kMatrixText, side_matrix); - runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix); + runner.MutableSidePackets()->Tag(kSubtrahendTag) = Adopt(side_matrix); Matrix* input_matrix = new Matrix(); MatrixFromTextProto(kMatrixText2, input_matrix); - runner.MutableInputs()->Tag("MINUEND").packets.push_back( - Adopt(input_matrix).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kMinuendTag) + .packets.push_back(Adopt(input_matrix).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); EXPECT_EQ(1, runner.Outputs().Index(0).packets.size()); @@ -133,12 +137,12 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) { CalculatorRunner runner(node_config); Matrix* side_matrix = new Matrix(); MatrixFromTextProto(kMatrixText, side_matrix); - runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix); + runner.MutableSidePackets()->Tag(kMinuendTag) = Adopt(side_matrix); Matrix* input_matrix = new Matrix(); MatrixFromTextProto(kMatrixText2, input_matrix); runner.MutableInputs() - ->Tag("SUBTRAHEND") + ->Tag(kSubtrahendTag) .packets.push_back(Adopt(input_matrix).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index e2dd74553..86d2fab42 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -14,7 +14,11 @@ #include +#include "absl/status/status.h" +#include "absl/types/optional.h" #include "mediapipe/calculators/core/split_vector_calculator.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/port/gtest.h" @@ -301,4 +305,99 @@ TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) { } } // namespace + 
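// The two helper nodes below drive the timestamp-bound-update test that
// follows: PassThroughAndTsBoundUpdateNode forwards each input packet on VALUE
// while advancing only the timestamp bound on TS_BOUND_UPDATE, and
// ToOptionalNode reports, for every TICK, whether a VALUE packet actually
// arrived (sending absl::nullopt when only the bound advanced).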
+class PassThroughAndTsBoundUpdateNode : public mediapipe::api2::Node { + public: + static constexpr mediapipe::api2::Input kInValue{"VALUE"}; + static constexpr mediapipe::api2::Output kOutValue{"VALUE"}; + static constexpr mediapipe::api2::Output kOutTsBoundUpdate{ + "TS_BOUND_UPDATE"}; + MEDIAPIPE_NODE_CONTRACT(kInValue, kOutValue, kOutTsBoundUpdate); + + absl::Status Process(CalculatorContext* cc) override { + kOutValue(cc).Send(kInValue(cc)); + kOutTsBoundUpdate(cc).SetNextTimestampBound( + cc->InputTimestamp().NextAllowedInStream()); + return absl::OkStatus(); + } +}; +MEDIAPIPE_REGISTER_NODE(PassThroughAndTsBoundUpdateNode); + +class ToOptionalNode : public mediapipe::api2::Node { + public: + static constexpr mediapipe::api2::Input kTick{"TICK"}; + static constexpr mediapipe::api2::Input kInValue{"VALUE"}; + static constexpr mediapipe::api2::Output> kOutValue{ + "OUTPUT"}; + MEDIAPIPE_NODE_CONTRACT(kTick, kInValue, kOutValue); + + absl::Status Process(CalculatorContext* cc) override { + if (kInValue(cc).IsEmpty()) { + kOutValue(cc).Send(absl::nullopt); + } else { + kOutValue(cc).Send({kInValue(cc).Get()}); + } + return absl::OkStatus(); + } +}; +MEDIAPIPE_REGISTER_NODE(ToOptionalNode); + +namespace { + +TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + R"pb( + input_stream: "select" + node { + calculator: "PassThroughAndTsBoundUpdateNode" + input_stream: "VALUE:select" + output_stream: "VALUE:select_ps" + output_stream: "TS_BOUND_UPDATE:ts_bound_update" + } + node { + calculator: "MuxCalculator" + input_stream: "INPUT:0:select_ps" + input_stream: "INPUT:1:ts_bound_update" + input_stream: "SELECT:select" + output_stream: "OUTPUT:select_or_ts_bound_update" + } + node { + calculator: "ToOptionalNode" + input_stream: "TICK:select" + input_stream: "VALUE:select_or_ts_bound_update" + output_stream: "OUTPUT:output" + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + + auto send_value_fn = [&](int value, Timestamp ts) -> absl::Status { + MP_RETURN_IF_ERROR( + graph.AddPacketToInputStream("select", MakePacket(value).At(ts))); + return graph.WaitUntilIdle(); + }; + + MP_ASSERT_OK(send_value_fn(0, Timestamp(1))); + ASSERT_EQ(output_packets.size(), 1); + EXPECT_EQ(output_packets[0].Get>(), 0); + + MP_ASSERT_OK(send_value_fn(1, Timestamp(2))); + ASSERT_EQ(output_packets.size(), 2); + EXPECT_EQ(output_packets[1].Get>(), absl::nullopt); + + MP_ASSERT_OK(send_value_fn(0, Timestamp(3))); + ASSERT_EQ(output_packets.size(), 3); + EXPECT_EQ(output_packets[2].Get>(), 0); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +} // namespace + } // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_cloner_calculator.cc b/mediapipe/calculators/core/packet_cloner_calculator.cc index 41bddbfa7..ff55a87e7 100644 --- a/mediapipe/calculators/core/packet_cloner_calculator.cc +++ b/mediapipe/calculators/core/packet_cloner_calculator.cc @@ -60,7 +60,10 @@ class PacketClonerCalculator : public CalculatorBase { const auto calculator_options = cc->Options(); output_only_when_all_inputs_received_ = - calculator_options.output_only_when_all_inputs_received(); + calculator_options.output_only_when_all_inputs_received() || + calculator_options.output_packets_only_when_all_inputs_received(); + output_empty_packets_before_all_inputs_received_ 
= + calculator_options.output_packets_only_when_all_inputs_received(); // Parse input streams. tick_signal_index_ = cc->Inputs().NumEntries() - 1; @@ -88,6 +91,9 @@ class PacketClonerCalculator : public CalculatorBase { // Return if one of the input is null. for (int i = 0; i < tick_signal_index_; ++i) { if (current_[i].IsEmpty()) { + if (output_empty_packets_before_all_inputs_received_) { + SetAllNextTimestampBounds(cc); + } return absl::OkStatus(); } } @@ -107,9 +113,17 @@ class PacketClonerCalculator : public CalculatorBase { } private: + void SetAllNextTimestampBounds(CalculatorContext* cc) { + for (int j = 0; j < tick_signal_index_; ++j) { + cc->Outputs().Index(j).SetNextTimestampBound( + cc->InputTimestamp().NextAllowedInStream()); + } + } + std::vector current_; int tick_signal_index_; bool output_only_when_all_inputs_received_; + bool output_empty_packets_before_all_inputs_received_; }; REGISTER_CALCULATOR(PacketClonerCalculator); diff --git a/mediapipe/calculators/core/packet_cloner_calculator.proto b/mediapipe/calculators/core/packet_cloner_calculator.proto index e30672fab..82bfa9c7a 100644 --- a/mediapipe/calculators/core/packet_cloner_calculator.proto +++ b/mediapipe/calculators/core/packet_cloner_calculator.proto @@ -28,4 +28,9 @@ message PacketClonerCalculatorOptions { // When true, this calculator will drop received TICK packets if any input // stream hasn't received a packet yet. optional bool output_only_when_all_inputs_received = 1 [default = false]; + + // Similar with above, but also transmit empty packet for all streams before + // all inputs are received. + optional bool output_packets_only_when_all_inputs_received = 2 + [default = false]; } diff --git a/mediapipe/calculators/core/packet_presence_calculator.cc b/mediapipe/calculators/core/packet_presence_calculator.cc index cb119a76d..ac723ad8a 100644 --- a/mediapipe/calculators/core/packet_presence_calculator.cc +++ b/mediapipe/calculators/core/packet_presence_calculator.cc @@ -17,6 +17,9 @@ namespace mediapipe { +constexpr char kPresenceTag[] = "PRESENCE"; +constexpr char kPacketTag[] = "PACKET"; + // For each non empty input packet, emits a single output packet containing a // boolean value "true", "false" in response to empty packets (a.k.a. timestamp // bound updates) This can be used to "flag" the presence of an arbitrary packet @@ -58,8 +61,8 @@ namespace mediapipe { class PacketPresenceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("PACKET").SetAny(); - cc->Outputs().Tag("PRESENCE").Set(); + cc->Inputs().Tag(kPacketTag).SetAny(); + cc->Outputs().Tag(kPresenceTag).Set(); // Process() function is invoked in response to input stream timestamp // bound updates. 
cc->SetProcessTimestampBounds(true); @@ -73,8 +76,8 @@ class PacketPresenceCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) final { cc->Outputs() - .Tag("PRESENCE") - .AddPacket(MakePacket(!cc->Inputs().Tag("PACKET").IsEmpty()) + .Tag(kPresenceTag) + .AddPacket(MakePacket(!cc->Inputs().Tag(kPacketTag).IsEmpty()) .At(cc->InputTimestamp())); return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc index 43253520a..81ccdbe65 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator.cc @@ -39,6 +39,11 @@ namespace mediapipe { REGISTER_CALCULATOR(PacketResamplerCalculator); namespace { + +constexpr char kSeedTag[] = "SEED"; +constexpr char kVideoHeaderTag[] = "VIDEO_HEADER"; +constexpr char kOptionsTag[] = "OPTIONS"; + // Returns a TimestampDiff (assuming microseconds) corresponding to the // given time in seconds. TimestampDiff TimestampDiffFromSeconds(double seconds) { @@ -50,16 +55,16 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) { absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { const auto& resampler_options = cc->Options(); - if (cc->InputSidePackets().HasTag("OPTIONS")) { - cc->InputSidePackets().Tag("OPTIONS").Set(); + if (cc->InputSidePackets().HasTag(kOptionsTag)) { + cc->InputSidePackets().Tag(kOptionsTag).Set(); } CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0); if (!input_data_id.IsValid()) { input_data_id = cc->Inputs().GetId("", 0); } cc->Inputs().Get(input_data_id).SetAny(); - if (cc->Inputs().HasTag("VIDEO_HEADER")) { - cc->Inputs().Tag("VIDEO_HEADER").Set(); + if (cc->Inputs().HasTag(kVideoHeaderTag)) { + cc->Inputs().Tag(kVideoHeaderTag).Set(); } CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0); @@ -67,15 +72,15 @@ absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { output_data_id = cc->Outputs().GetId("", 0); } cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id)); - if (cc->Outputs().HasTag("VIDEO_HEADER")) { - cc->Outputs().Tag("VIDEO_HEADER").Set(); + if (cc->Outputs().HasTag(kVideoHeaderTag)) { + cc->Outputs().Tag(kVideoHeaderTag).Set(); } if (resampler_options.jitter() != 0.0) { RET_CHECK_GT(resampler_options.jitter(), 0.0); RET_CHECK_LE(resampler_options.jitter(), 1.0); - RET_CHECK(cc->InputSidePackets().HasTag("SEED")); - cc->InputSidePackets().Tag("SEED").Set(); + RET_CHECK(cc->InputSidePackets().HasTag(kSeedTag)); + cc->InputSidePackets().Tag(kSeedTag).Set(); } return absl::OkStatus(); } @@ -143,9 +148,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream() && - cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") && - !cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) { - video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get(); + cc->Inputs().UsesTags() && cc->Inputs().HasTag(kVideoHeaderTag) && + !cc->Inputs().Tag(kVideoHeaderTag).IsEmpty()) { + video_header_ = cc->Inputs().Tag(kVideoHeaderTag).Get(); video_header_.frame_rate = frame_rate_; if (cc->Inputs().Get(input_data_id_).IsEmpty()) { return absl::OkStatus(); @@ -234,7 +239,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) { "ignored, because we are adding jitter."; } - const auto& seed = 
cc->InputSidePackets().Tag("SEED").Get(); + const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get(); random_ = CreateSecureRandom(seed); if (random_ == nullptr) { return absl::InvalidArgumentError( @@ -357,7 +362,7 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Open( "ignored, because we are adding jitter."; } - const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get(); random_ = CreateSecureRandom(seed); if (random_ == nullptr) { return absl::InvalidArgumentError( @@ -504,7 +509,7 @@ absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) { "ignored, because we are adding jitter."; } - const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get(); random_ = CreateSecureRandom(seed); if (random_ == nullptr) { return absl::InvalidArgumentError( @@ -635,9 +640,9 @@ absl::Status NoJitterStrategy::Process(CalculatorContext* cc) { base_timestamp_ + TimestampDiffFromSeconds(first_index / calculator_->frame_rate_); } - if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) { + if (cc->Outputs().UsesTags() && cc->Outputs().HasTag(kVideoHeaderTag)) { cc->Outputs() - .Tag("VIDEO_HEADER") + .Tag(kVideoHeaderTag) .Add(new VideoHeader(calculator_->video_header_), Timestamp::PreStream()); } diff --git a/mediapipe/calculators/core/packet_resampler_calculator_test.cc b/mediapipe/calculators/core/packet_resampler_calculator_test.cc index 191e1d842..f02da0d18 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator_test.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator_test.cc @@ -32,6 +32,12 @@ namespace mediapipe { using ::testing::ElementsAre; namespace { + +constexpr char kOptionsTag[] = "OPTIONS"; +constexpr char kSeedTag[] = "SEED"; +constexpr char kVideoHeaderTag[] = "VIDEO_HEADER"; +constexpr char kDataTag[] = "DATA"; + // A simple version of CalculatorRunner with built-in convenience // methods for setting inputs from a vector and checking outputs // against expected outputs (both timestamps and contents). 
@@ -464,7 +470,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) { )pb")); for (const int64 ts : {0, 5000, 10010, 15001, 19990}) { - runner.MutableInputs()->Tag("DATA").packets.push_back( + runner.MutableInputs()->Tag(kDataTag).packets.push_back( Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); } VideoHeader video_header_in; @@ -474,16 +480,16 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) { video_header_in.duration = 1.0; video_header_in.format = ImageFormat::SRGB; runner.MutableInputs() - ->Tag("VIDEO_HEADER") + ->Tag(kVideoHeaderTag) .packets.push_back( Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream())); MP_ASSERT_OK(runner.Run()); - ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size()); + ASSERT_EQ(1, runner.Outputs().Tag(kVideoHeaderTag).packets.size()); EXPECT_EQ(Timestamp::PreStream(), - runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp()); + runner.Outputs().Tag(kVideoHeaderTag).packets[0].Timestamp()); const VideoHeader& video_header_out = - runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get(); + runner.Outputs().Tag(kVideoHeaderTag).packets[0].Get(); EXPECT_EQ(video_header_in.width, video_header_out.width); EXPECT_EQ(video_header_in.height, video_header_out.height); EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate); @@ -725,7 +731,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) { [mediapipe.PacketResamplerCalculatorOptions.ext] { frame_rate: 30 })pb")); - runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options); + runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options); runner.SetInput({-222, 15000, 32000, 49999, 150000}); MP_ASSERT_OK(runner.Run()); EXPECT_EQ(6, runner.Outputs().Index(0).packets.size()); @@ -740,7 +746,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) { frame_rate: 30 base_timestamp: 0 })pb")); - runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options); + runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options); runner.SetInput({-222, 15000, 32000, 49999, 150000}); MP_ASSERT_OK(runner.Run()); diff --git a/mediapipe/calculators/core/packet_thinner_calculator.cc b/mediapipe/calculators/core/packet_thinner_calculator.cc index d3d391b61..1d94d886b 100644 --- a/mediapipe/calculators/core/packet_thinner_calculator.cc +++ b/mediapipe/calculators/core/packet_thinner_calculator.cc @@ -217,6 +217,7 @@ absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) { header->format = video_header.format; header->width = video_header.width; header->height = video_header.height; + header->duration = video_header.duration; header->frame_rate = new_frame_rate; cc->Outputs().Index(0).SetHeader(Adopt(header.release())); } else { diff --git a/mediapipe/calculators/core/packet_thinner_calculator_test.cc b/mediapipe/calculators/core/packet_thinner_calculator_test.cc index 86fcc00f9..3522488e7 100644 --- a/mediapipe/calculators/core/packet_thinner_calculator_test.cc +++ b/mediapipe/calculators/core/packet_thinner_calculator_test.cc @@ -29,6 +29,8 @@ namespace mediapipe { namespace { +constexpr char kPeriodTag[] = "PERIOD"; + // A simple version of CalculatorRunner with built-in convenience methods for // setting inputs from a vector and checking outputs against a vector of // expected outputs. 
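//
// A minimal usage sketch of side-packet-driven thinning as exercised below
// (placeholder stream names; the PERIOD side packet supplies the thinning
// period at graph start, and further behavior such as the thinner type is
// configured via PacketThinnerCalculatorOptions):
//
// node {
//   calculator: "PacketThinnerCalculator"
//   input_stream: "in_stream"
//   output_stream: "thinned_stream"
//   input_side_packet: "PERIOD:thinning_period"
// }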
@@ -121,7 +123,7 @@ TEST(PacketThinnerCalculatorTest, ASyncUniformStreamThinningTestBySidePacket) { SimpleRunner runner(node); runner.SetInput({2, 4, 6, 8, 10, 12, 14}); - runner.MutableSidePackets()->Tag("PERIOD") = MakePacket(5); + runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket(5); MP_ASSERT_OK(runner.Run()); const std::vector expected_timestamps = {2, 8, 14}; @@ -160,7 +162,7 @@ TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTestBySidePacket1) { SimpleRunner runner(node); runner.SetInput({2, 4, 6, 8, 10, 12, 14}); - runner.MutableSidePackets()->Tag("PERIOD") = MakePacket(5); + runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket(5); MP_ASSERT_OK(runner.Run()); const std::vector expected_timestamps = {2, 6, 10, 14}; diff --git a/mediapipe/calculators/core/previous_loopback_calculator_test.cc b/mediapipe/calculators/core/previous_loopback_calculator_test.cc index c9d431d1c..563417669 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator_test.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator_test.cc @@ -39,6 +39,8 @@ using ::testing::Pair; using ::testing::Value; namespace { +constexpr char kDisallowTag[] = "DISALLOW"; + // Returns the timestamp values for a vector of Packets. // TODO: puth this kind of test util in a common place. std::vector TimestampValues(const std::vector& packets) { @@ -702,14 +704,14 @@ class DroppingGateCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); - cc->Inputs().Tag("DISALLOW").Set(); + cc->Inputs().Tag(kDisallowTag).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { if (!cc->Inputs().Index(0).IsEmpty() && - !cc->Inputs().Tag("DISALLOW").Get()) { + !cc->Inputs().Tag(kDisallowTag).Get()) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator.cc b/mediapipe/calculators/core/quantize_float_vector_calculator.cc index e95509298..e1df66c1a 100644 --- a/mediapipe/calculators/core/quantize_float_vector_calculator.cc +++ b/mediapipe/calculators/core/quantize_float_vector_calculator.cc @@ -41,11 +41,14 @@ // } namespace mediapipe { +constexpr char kEncodedTag[] = "ENCODED"; +constexpr char kFloatVectorTag[] = "FLOAT_VECTOR"; + class QuantizeFloatVectorCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("FLOAT_VECTOR").Set>(); - cc->Outputs().Tag("ENCODED").Set(); + cc->Inputs().Tag(kFloatVectorTag).Set>(); + cc->Outputs().Tag(kEncodedTag).Set(); return absl::OkStatus(); } @@ -70,7 +73,7 @@ class QuantizeFloatVectorCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) final { const std::vector& float_vector = - cc->Inputs().Tag("FLOAT_VECTOR").Value().Get>(); + cc->Inputs().Tag(kFloatVectorTag).Value().Get>(); int feature_size = float_vector.size(); std::string encoded_features; encoded_features.reserve(feature_size); @@ -86,8 +89,10 @@ class QuantizeFloatVectorCalculator : public CalculatorBase { (old_value - min_quantized_value_) * (255.0 / range_)); encoded_features += encoded; } - cc->Outputs().Tag("ENCODED").AddPacket( - MakePacket(encoded_features).At(cc->InputTimestamp())); + cc->Outputs() + .Tag(kEncodedTag) + .AddPacket( + MakePacket(encoded_features).At(cc->InputTimestamp())); return absl::OkStatus(); } diff --git 
a/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc b/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc index 8f23437b6..a3a410565 100644 --- a/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc +++ b/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc @@ -25,6 +25,9 @@ namespace mediapipe { +constexpr char kEncodedTag[] = "ENCODED"; +constexpr char kFloatVectorTag[] = "FLOAT_VECTOR"; + TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { CalculatorGraphConfig::Node node_config = ParseTextProtoOrDie(R"pb( @@ -40,7 +43,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { CalculatorRunner runner(node_config); std::vector empty_vector; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(empty_vector).At(Timestamp(0))); auto status = runner.Run(); @@ -67,7 +70,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) { CalculatorRunner runner(node_config); std::vector empty_vector; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(empty_vector).At(Timestamp(0))); auto status = runner.Run(); @@ -94,7 +97,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) { CalculatorRunner runner(node_config); std::vector empty_vector; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(empty_vector).At(Timestamp(0))); auto status = runner.Run(); @@ -121,11 +124,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) { CalculatorRunner runner(node_config); std::vector empty_vector; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(empty_vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); - const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + const std::vector& outputs = + runner.Outputs().Tag(kEncodedTag).packets; EXPECT_EQ(1, outputs.size()); EXPECT_TRUE(outputs[0].Get().empty()); EXPECT_EQ(Timestamp(0), outputs[0].Timestamp()); @@ -147,11 +151,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) { CalculatorRunner runner(node_config); std::vector vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f}; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); - const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + const std::vector& outputs = + runner.Outputs().Tag(kEncodedTag).packets; EXPECT_EQ(1, outputs.size()); const std::string& result = outputs[0].Get(); ASSERT_FALSE(result.empty()); @@ -185,11 +190,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) { CalculatorRunner runner(node_config); std::vector vector = {-65.0f, 65.0f}; runner.MutableInputs() - ->Tag("FLOAT_VECTOR") + ->Tag(kFloatVectorTag) .packets.push_back( MakePacket>(vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); - const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + const std::vector& outputs = + runner.Outputs().Tag(kEncodedTag).packets; EXPECT_EQ(1, outputs.size()); const std::string& result = outputs[0].Get(); ASSERT_FALSE(result.empty()); diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc index 277f83fe2..ef3cb9896 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -23,6 
+23,9 @@ namespace mediapipe { +constexpr char kAllowTag[] = "ALLOW"; +constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT"; + // RealTimeFlowLimiterCalculator is used to limit the number of pipelined // processing operations in a section of the graph. // @@ -86,11 +89,11 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase { cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i))); } cc->Inputs().Get("FINISHED", 0).SetAny(); - if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { - cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set(); + if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) { + cc->InputSidePackets().Tag(kMaxInFlightTag).Set(); } - if (cc->Outputs().HasTag("ALLOW")) { - cc->Outputs().Tag("ALLOW").Set(); + if (cc->Outputs().HasTag(kAllowTag)) { + cc->Outputs().Tag(kAllowTag).Set(); } cc->SetInputStreamHandler("ImmediateInputStreamHandler"); @@ -101,8 +104,8 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase { absl::Status Open(CalculatorContext* cc) final { finished_id_ = cc->Inputs().GetId("FINISHED", 0); max_in_flight_ = 1; - if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { - max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get(); + if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) { + max_in_flight_ = cc->InputSidePackets().Tag(kMaxInFlightTag).Get(); } RET_CHECK_GE(max_in_flight_, 1); num_in_flight_ = 0; diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc index fe4785860..7fddd7fdf 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc @@ -33,6 +33,9 @@ namespace mediapipe { namespace { + +constexpr char kFinishedTag[] = "FINISHED"; + // A simple Semaphore for synchronizing test threads. class AtomicSemaphore { public: @@ -112,7 +115,7 @@ TEST(RealTimeFlowLimiterCalculator, BasicTest) { Timestamp timestamp = Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond); runner.MutableInputs() - ->Tag("FINISHED") + ->Tag(kFinishedTag) .packets.push_back(MakePacket(true).At(timestamp)); } diff --git a/mediapipe/calculators/core/sequence_shift_calculator_test.cc b/mediapipe/calculators/core/sequence_shift_calculator_test.cc index 23ad57225..8c749904c 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator_test.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator_test.cc @@ -22,6 +22,8 @@ namespace mediapipe { namespace { +constexpr char kPacketOffsetTag[] = "PACKET_OFFSET"; + // Adds packets containing integers equal to their original timestamp. 
void AddPackets(CalculatorRunner* runner) { for (int i = 0; i < 10; ++i) { @@ -111,7 +113,7 @@ TEST(SequenceShiftCalculatorTest, SidePacketOffset) { CalculatorRunner runner(node); AddPackets(&runner); - runner.MutableSidePackets()->Tag("PACKET_OFFSET") = Adopt(new int(-2)); + runner.MutableSidePackets()->Tag(kPacketOffsetTag) = Adopt(new int(-2)); MP_ASSERT_OK(runner.Run()); const std::vector& input_packets = runner.MutableInputs()->Index(0).packets; diff --git a/mediapipe/calculators/core/split_vector_calculator.cc b/mediapipe/calculators/core/split_vector_calculator.cc index c8f1177d5..a80136be7 100644 --- a/mediapipe/calculators/core/split_vector_calculator.cc +++ b/mediapipe/calculators/core/split_vector_calculator.cc @@ -80,4 +80,7 @@ typedef SplitVectorCalculator SplitClassificationListVectorCalculator; REGISTER_CALCULATOR(SplitClassificationListVectorCalculator); +typedef SplitVectorCalculator SplitUint64tVectorCalculator; +REGISTER_CALCULATOR(SplitUint64tVectorCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 507b6f0ff..0bbfadd05 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -661,3 +661,138 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", ], ) + +cc_library( + name = "affine_transformation", + hdrs = ["affine_transformation.h"], + deps = ["@com_google_absl//absl/status:statusor"], +) + +cc_library( + name = "affine_transformation_runner_gl", + srcs = ["affine_transformation_runner_gl.cc"], + hdrs = ["affine_transformation_runner_gl.h"], + deps = [ + ":affine_transformation", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_origin_cc_proto", + "//mediapipe/gpu:shader_util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@eigen_archive//:eigen3", + ], +) + +cc_library( + name = "affine_transformation_runner_opencv", + srcs = ["affine_transformation_runner_opencv.cc"], + hdrs = ["affine_transformation_runner_opencv.h"], + deps = [ + ":affine_transformation", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@eigen_archive//:eigen3", + ], +) + +mediapipe_proto_library( + name = "warp_affine_calculator_proto", + srcs = ["warp_affine_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/gpu:gpu_origin_proto", + ], +) + +cc_library( + name = "warp_affine_calculator", + srcs = ["warp_affine_calculator.cc"], + hdrs = ["warp_affine_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + ":affine_transformation", + ":affine_transformation_runner_opencv", + ":warp_affine_calculator_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + 
"//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + ":affine_transformation_runner_gl", + ], + }), + alwayslink = 1, +) + +cc_test( + name = "warp_affine_calculator_test", + srcs = ["warp_affine_calculator_test.cc"], + data = [ + "//mediapipe/calculators/tensor:testdata/image_to_tensor/input.jpg", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_keep_aspect.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_keep_aspect_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/large_sub_rect_keep_aspect_with_rotation_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png", + ], + tags = ["desktop_only_test"], + deps = [ + ":affine_transformation", + ":warp_affine_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_converter", + "//mediapipe/calculators/tensor:image_to_tensor_utils", + "//mediapipe/calculators/util:from_image_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", + "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/calculators/image/affine_transformation.h b/mediapipe/calculators/image/affine_transformation.h new file mode 100644 index 000000000..40793e7a1 --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation.h @@ -0,0 +1,55 @@ +// Copyright 2021 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_H_ +#define MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_H_ + +#include + +#include "absl/status/statusor.h" + +namespace mediapipe { + +class AffineTransformation { + public: + // Pixel extrapolation method. + // When converting image to tensor it may happen that tensor needs to read + // pixels outside image boundaries. Border mode helps to specify how such + // pixels will be calculated. + enum class BorderMode { kZero, kReplicate }; + + struct Size { + int width; + int height; + }; + + template + class Runner { + public: + virtual ~Runner() = default; + + // Transforms input into output using @matrix as following: + // output(x, y) = input(matrix[0] * x + matrix[1] * y + matrix[3], + // matrix[4] * x + matrix[5] * y + matrix[7]) + // where x and y ranges are defined by @output_size. + virtual absl::StatusOr Run(const InputT& input, + const std::array& matrix, + const Size& output_size, + BorderMode border_mode) = 0; + }; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_H_ diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.cc b/mediapipe/calculators/image/affine_transformation_runner_gl.cc new file mode 100644 index 000000000..c38fc8e07 --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.cc @@ -0,0 +1,354 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
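The Runner interface above only pins down the matrix layout through the formula in its comment; as a minimal standalone sketch (no MediaPipe dependency, names hypothetical), and assuming the matrix maps normalized output coordinates to normalized input coordinates as the OpenCV runner below describes, a centered half-size crop looks like this:

#include <array>
#include <cstdio>

// Row-major 4x4 affine matrix in the layout AffineTransformation::Runner
// expects: the output pixel at normalized coordinates (x, y) is sampled from
// the input at (m[0]*x + m[1]*y + m[3], m[4]*x + m[5]*y + m[7]).
std::array<float, 16> CenteredHalfCrop() {
  // Scale by 0.5 and shift by 0.25 so the output covers the central
  // half-size sub-rect of the input.
  return {0.5f, 0.0f, 0.0f, 0.25f,
          0.0f, 0.5f, 0.0f, 0.25f,
          0.0f, 0.0f, 1.0f, 0.0f,
          0.0f, 0.0f, 0.0f, 1.0f};
}

int main() {
  const std::array<float, 16> m = CenteredHalfCrop();
  // Map the output corners back to the input locations they sample from.
  const float corners[4][2] = {{0, 0}, {1, 0}, {0, 1}, {1, 1}};
  for (const auto& c : corners) {
    const float in_x = m[0] * c[0] + m[1] * c[1] + m[3];
    const float in_y = m[4] * c[0] + m[5] * c[1] + m[7];
    std::printf("output (%.0f, %.0f) <- input (%.2f, %.2f)\n", c[0], c[1],
                in_x, in_y);
  }
  return 0;
}

With this matrix the output corner (0, 0) reads from input (0.25, 0.25) and (1, 1) from (0.75, 0.75), which is the mapping both the OpenCV and GL runners below implement.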
+ +#include "mediapipe/calculators/image/affine_transformation_runner_gl.h" + +#include +#include + +#include "Eigen/Core" +#include "Eigen/Geometry" +#include "Eigen/LU" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_origin.pb.h" +#include "mediapipe/gpu/shader_util.h" + +namespace mediapipe { + +namespace { + +using mediapipe::GlCalculatorHelper; +using mediapipe::GlhCreateProgram; +using mediapipe::GlTexture; +using mediapipe::GpuBuffer; +using mediapipe::GpuOrigin; + +bool IsMatrixVerticalFlipNeeded(GpuOrigin::Mode gpu_origin) { + switch (gpu_origin) { + case GpuOrigin::DEFAULT: + case GpuOrigin::CONVENTIONAL: +#ifdef __APPLE__ + return false; +#else + return true; +#endif // __APPLE__ + case GpuOrigin::TOP_LEFT: + return false; + } +} + +#ifdef __APPLE__ +#define GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED 0 +#else +#define GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED 1 +#endif // __APPLE__ + +bool IsGlClampToBorderSupported(const mediapipe::GlContext& gl_context) { + return gl_context.gl_major_version() > 3 || + (gl_context.gl_major_version() == 3 && + gl_context.gl_minor_version() >= 2); +} + +constexpr int kAttribVertex = 0; +constexpr int kAttribTexturePosition = 1; +constexpr int kNumAttributes = 2; + +class GlTextureWarpAffineRunner + : public AffineTransformation::Runner> { + public: + GlTextureWarpAffineRunner(std::shared_ptr gl_helper, + GpuOrigin::Mode gpu_origin) + : gl_helper_(gl_helper), gpu_origin_(gpu_origin) {} + absl::Status Init() { + return gl_helper_->RunInGlContext([this]() -> absl::Status { + const GLint attr_location[kNumAttributes] = { + kAttribVertex, + kAttribTexturePosition, + }; + const GLchar* attr_name[kNumAttributes] = { + "position", + "texture_coordinate", + }; + + constexpr GLchar kVertShader[] = R"( + in vec4 position; + in mediump vec4 texture_coordinate; + out mediump vec2 sample_coordinate; + uniform mat4 transform_matrix; + + void main() { + gl_Position = position; + vec4 tc = transform_matrix * texture_coordinate; + sample_coordinate = tc.xy; + } + )"; + + constexpr GLchar kFragShader[] = R"( + DEFAULT_PRECISION(mediump, float) + in vec2 sample_coordinate; + uniform sampler2D input_texture; + + #ifdef GL_ES + #define fragColor gl_FragColor + #else + out vec4 fragColor; + #endif // defined(GL_ES); + + void main() { + vec4 color = texture2D(input_texture, sample_coordinate); + #ifdef CUSTOM_ZERO_BORDER_MODE + float out_of_bounds = + float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || + sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0); + color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); + #endif // defined(CUSTOM_ZERO_BORDER_MODE) + fragColor = color; + } + )"; + + // Create program and set parameters. 
+ auto create_fn = [&](const std::string& vs, + const std::string& fs) -> absl::StatusOr { + GLuint program = 0; + GlhCreateProgram(vs.c_str(), fs.c_str(), kNumAttributes, &attr_name[0], + attr_location, &program); + + RET_CHECK(program) << "Problem initializing warp affine program."; + glUseProgram(program); + glUniform1i(glGetUniformLocation(program, "input_texture"), 1); + GLint matrix_id = glGetUniformLocation(program, "transform_matrix"); + return Program{.id = program, .matrix_id = matrix_id}; + }; + + const std::string vert_src = + absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader); + + const std::string frag_src = absl::StrCat( + mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader); + + ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src)); + + auto create_custom_zero_fn = [&]() -> absl::StatusOr { + std::string custom_zero_border_mode_def = R"( + #define CUSTOM_ZERO_BORDER_MODE + )"; + const std::string frag_custom_zero_src = + absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, + custom_zero_border_mode_def, kFragShader); + return create_fn(vert_src, frag_custom_zero_src); + }; +#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + if (!IsGlClampToBorderSupported(gl_helper_->GetGlContext())) { + ASSIGN_OR_RETURN(program_custom_zero_, create_custom_zero_fn()); + } +#else + ASSIGN_OR_RETURN(program_custom_zero_, create_custom_zero_fn()); +#endif // GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + + glGenFramebuffers(1, &framebuffer_); + + // vertex storage + glGenBuffers(2, vbo_); + glGenVertexArrays(1, &vao_); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[0]); + glBufferData(GL_ARRAY_BUFFER, sizeof(mediapipe::kBasicSquareVertices), + mediapipe::kBasicSquareVertices, GL_STATIC_DRAW); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[1]); + glBufferData(GL_ARRAY_BUFFER, sizeof(mediapipe::kBasicTextureVertices), + mediapipe::kBasicTextureVertices, GL_STATIC_DRAW); + + glBindBuffer(GL_ARRAY_BUFFER, 0); + + return absl::OkStatus(); + }); + } + + absl::StatusOr> Run( + const GpuBuffer& input, const std::array& matrix, + const AffineTransformation::Size& size, + AffineTransformation::BorderMode border_mode) override { + std::unique_ptr gpu_buffer; + MP_RETURN_IF_ERROR( + gl_helper_->RunInGlContext([this, &input, &matrix, &size, &border_mode, + &gpu_buffer]() -> absl::Status { + auto input_texture = gl_helper_->CreateSourceTexture(input); + auto output_texture = gl_helper_->CreateDestinationTexture( + size.width, size.height, input.format()); + + MP_RETURN_IF_ERROR( + RunInternal(input_texture, matrix, border_mode, &output_texture)); + gpu_buffer = output_texture.GetFrame(); + return absl::OkStatus(); + })); + + return gpu_buffer; + } + + absl::Status RunInternal(const GlTexture& texture, + const std::array& matrix, + AffineTransformation::BorderMode border_mode, + GlTexture* output) { + glDisable(GL_DEPTH_TEST); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glViewport(0, 0, output->width(), output->height()); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, output->name()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + output->name(), 0); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(texture.target(), texture.name()); + + // a) Filtering. + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + + // b) Clamping. 
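+ // kReplicate maps to GL_CLAMP_TO_EDGE. kZero prefers GL_CLAMP_TO_BORDER
+ // with a transparent border color where the context supports it (3.2+ per
+ // IsGlClampToBorderSupported); otherwise it switches to the
+ // CUSTOM_ZERO_BORDER_MODE program, whose fragment shader zeroes
+ // out-of-range samples.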
+ std::optional program = program_; + switch (border_mode) { + case AffineTransformation::BorderMode::kReplicate: { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + break; + } + case AffineTransformation::BorderMode::kZero: { +#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + if (program_custom_zero_) { + program = program_custom_zero_; + } else { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_BORDER); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_BORDER); + glTexParameterfv(GL_TEXTURE_2D, GL_TEXTURE_BORDER_COLOR, + std::array{0.0f, 0.0f, 0.0f, 0.0f}.data()); + } +#else + RET_CHECK(program_custom_zero_) + << "Program must have been initialized."; + program = program_custom_zero_; +#endif // GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + break; + } + } + glUseProgram(program->id); + + Eigen::Matrix eigen_mat(matrix.data()); + if (IsMatrixVerticalFlipNeeded(gpu_origin_)) { + // @matrix describes affine transformation in terms of TOP LEFT origin, so + // in some cases/on some platforms an extra flipping should be done before + // and after. + const Eigen::Matrix flip_y( + {{1.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, -1.0f, 0.0f, 1.0f}, + {0.0f, 0.0f, 1.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 1.0f}}); + eigen_mat = flip_y * eigen_mat * flip_y; + } + + // If GL context is ES2, then GL_FALSE must be used for 'transpose' + // GLboolean in glUniformMatrix4fv, or else INVALID_VALUE error is reported. + // Hence, transposing the matrix and always passing transposed. + eigen_mat.transposeInPlace(); + glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data()); + + // vao + glBindVertexArray(vao_); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[0]); + glEnableVertexAttribArray(kAttribVertex); + glVertexAttribPointer(kAttribVertex, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo_[1]); + glEnableVertexAttribArray(kAttribTexturePosition); + glVertexAttribPointer(kAttribTexturePosition, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // Resetting to MediaPipe texture param defaults. + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + + glDisableVertexAttribArray(kAttribVertex); + glDisableVertexAttribArray(kAttribTexturePosition); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, 0); + + return absl::OkStatus(); + } + + ~GlTextureWarpAffineRunner() override { + gl_helper_->RunInGlContext([this]() { + // Release OpenGL resources. 
+ if (framebuffer_ != 0) glDeleteFramebuffers(1, &framebuffer_); + if (program_.id != 0) glDeleteProgram(program_.id); + if (program_custom_zero_ && program_custom_zero_->id != 0) { + glDeleteProgram(program_custom_zero_->id); + } + if (vao_ != 0) glDeleteVertexArrays(1, &vao_); + glDeleteBuffers(2, vbo_); + }); + } + + private: + struct Program { + GLuint id; + GLint matrix_id; + }; + std::shared_ptr gl_helper_; + GpuOrigin::Mode gpu_origin_; + GLuint vao_ = 0; + GLuint vbo_[2] = {0, 0}; + Program program_; + std::optional program_custom_zero_; + GLuint framebuffer_ = 0; +}; + +#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED + +} // namespace + +absl::StatusOr>>> +CreateAffineTransformationGlRunner( + std::shared_ptr gl_helper, GpuOrigin::Mode gpu_origin) { + auto runner = + absl::make_unique(gl_helper, gpu_origin); + MP_RETURN_IF_ERROR(runner->Init()); + return runner; +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.h b/mediapipe/calculators/image/affine_transformation_runner_gl.h new file mode 100644 index 000000000..677e0720d --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.h @@ -0,0 +1,36 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_GL_H_ +#define MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_GL_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_origin.pb.h" + +namespace mediapipe { + +absl::StatusOr>>> +CreateAffineTransformationGlRunner( + std::shared_ptr gl_helper, + mediapipe::GpuOrigin::Mode gpu_origin); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_GL_H_ diff --git a/mediapipe/calculators/image/affine_transformation_runner_opencv.cc b/mediapipe/calculators/image/affine_transformation_runner_opencv.cc new file mode 100644 index 000000000..46026a987 --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation_runner_opencv.cc @@ -0,0 +1,160 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
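The flip handling in the GL runner above (flip_y * matrix * flip_y when the GPU origin is bottom-left) is easiest to see in isolation; a standalone Eigen sketch with a hypothetical top-half crop, not part of the change:

#include <cstdio>
#include "Eigen/Core"

// Conjugating a normalized-coordinate affine matrix with a y-flip converts a
// transform expressed for a top-left origin into the equivalent transform for
// a bottom-left origin, mirroring IsMatrixVerticalFlipNeeded() above.
int main() {
  Eigen::Matrix4f flip_y;
  flip_y << 1,  0, 0, 0,
            0, -1, 0, 1,  // y -> 1 - y in normalized coordinates
            0,  0, 1, 0,
            0,  0, 0, 1;

  // Example transform in the top-left convention: scale y by 0.5 with zero
  // offset, so the output reads the top half of the input.
  Eigen::Matrix4f top_left = Eigen::Matrix4f::Identity();
  top_left(1, 1) = 0.5f;  // y scale
  top_left(1, 3) = 0.0f;  // y offset

  const Eigen::Matrix4f bottom_left = flip_y * top_left * flip_y;
  // In the flipped convention the same crop now starts at y = 0.5:
  // prints "y scale: 0.50, y offset: 0.50".
  std::printf("y scale: %.2f, y offset: %.2f\n", bottom_left(1, 1),
              bottom_left(1, 3));
  return 0;
}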
+ +#include "mediapipe/calculators/image/affine_transformation_runner_opencv.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +cv::BorderTypes GetBorderModeForOpenCv( + AffineTransformation::BorderMode border_mode) { + switch (border_mode) { + case AffineTransformation::BorderMode::kZero: + return cv::BORDER_CONSTANT; + case AffineTransformation::BorderMode::kReplicate: + return cv::BORDER_REPLICATE; + } +} + +class OpenCvRunner + : public AffineTransformation::Runner { + public: + absl::StatusOr Run( + const ImageFrame& input, const std::array& matrix, + const AffineTransformation::Size& size, + AffineTransformation::BorderMode border_mode) override { + // OpenCV warpAffine works in absolute coordinates, so the transfom (which + // accepts and produces relative coordinates) should be adjusted to first + // normalize coordinates and then scale them. + // clang-format off + cv::Matx44f normalize_dst_coordinate({ + 1.0f / size.width, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f / size.height, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}); + cv::Matx44f scale_src_coordinate({ + 1.0f * input.Width(), 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f * input.Height(), 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}); + // clang-format on + cv::Matx44f adjust_dst_coordinate; + cv::Matx44f adjust_src_coordinate; + // TODO: update to always use accurate implementation. + constexpr bool kOpenCvCompatibility = true; + if (kOpenCvCompatibility) { + adjust_dst_coordinate = normalize_dst_coordinate; + adjust_src_coordinate = scale_src_coordinate; + } else { + // To do an accurate affine image transformation and make "on-cpu" and + // "on-gpu" calculations aligned - extra offset is required to select + // correct pixels. + // + // Each destination pixel corresponds to some pixels region from source + // image.(In case of downscaling there can be more than one pixel.) The + // offset for x and y is calculated in the way, so pixel in the middle of + // the region is selected. + // + // For simplicity sake, let's consider downscaling from 100x50 to 10x10 + // without a rotation: + // 1. Each destination pixel corresponds to 10x5 region + // X range: [0, .. , 9] + // Y range: [0, .. , 4] + // 2. Considering we have __discrete__ pixels, the center of the region is + // between (4, 2) and (5, 2) pixels, let's assume it's a "pixel" + // (4.5, 2). + // 3. When using the above as an offset for every pixel select while + // downscaling, resulting pixels are: + // (4.5, 2), (14.5, 2), .. , (94.5, 2) + // (4.5, 7), (14.5, 7), .. , (94.5, 7) + // .. + // (4.5, 47), (14.5, 47), .., (94.5, 47) + // instead of: + // (0, 0), (10, 0), .. , (90, 0) + // (0, 5), (10, 7), .. , (90, 5) + // .. + // (0, 45), (10, 45), .., (90, 45) + // The latter looks shifted. + // + // Offsets are needed, so that __discrete__ pixel at (0, 0) corresponds to + // the same pixel as would __non discrete__ pixel at (0.5, 0.5). Hence, + // transformation matrix should shift coordinates by (0.5, 0.5) as the + // very first step. 
+ // + // Due to the above shift, transformed coordinates would be valid for + // float coordinates where pixel (0, 0) spans [0.0, 1.0) x [0.0, 1.0). + // T0 make it valid for __discrete__ pixels, transformation matrix should + // shift coordinate by (-0.5f, -0.5f) as the very last step. (E.g. if we + // get (0.5f, 0.5f), then it's (0, 0) __discrete__ pixel.) + // clang-format off + cv::Matx44f shift_dst({1.0f, 0.0f, 0.0f, 0.5f, + 0.0f, 1.0f, 0.0f, 0.5f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}); + cv::Matx44f shift_src({1.0f, 0.0f, 0.0f, -0.5f, + 0.0f, 1.0f, 0.0f, -0.5f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}); + // clang-format on + adjust_dst_coordinate = normalize_dst_coordinate * shift_dst; + adjust_src_coordinate = shift_src * scale_src_coordinate; + } + + cv::Matx44f transform(matrix.data()); + cv::Matx44f transform_absolute = + adjust_src_coordinate * transform * adjust_dst_coordinate; + + cv::Mat in_mat = formats::MatView(&input); + + cv::Mat cv_affine_transform(2, 3, CV_32F); + cv_affine_transform.at(0, 0) = transform_absolute.val[0]; + cv_affine_transform.at(0, 1) = transform_absolute.val[1]; + cv_affine_transform.at(0, 2) = transform_absolute.val[3]; + cv_affine_transform.at(1, 0) = transform_absolute.val[4]; + cv_affine_transform.at(1, 1) = transform_absolute.val[5]; + cv_affine_transform.at(1, 2) = transform_absolute.val[7]; + + ImageFrame out_image(input.Format(), size.width, size.height); + cv::Mat out_mat = formats::MatView(&out_image); + + cv::warpAffine(in_mat, out_mat, cv_affine_transform, + cv::Size(out_mat.cols, out_mat.rows), + /*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP, + GetBorderModeForOpenCv(border_mode)); + + return out_image; + } +}; + +} // namespace + +absl::StatusOr< + std::unique_ptr>> +CreateAffineTransformationOpenCvRunner() { + return absl::make_unique(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/affine_transformation_runner_opencv.h b/mediapipe/calculators/image/affine_transformation_runner_opencv.h new file mode 100644 index 000000000..200281c95 --- /dev/null +++ b/mediapipe/calculators/image/affine_transformation_runner_opencv.h @@ -0,0 +1,32 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
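The OpenCV runner above converts the normalized 4x4 transform into the absolute 2x3 matrix cv::warpAffine expects; a compact standalone sketch of that conversion (following the kOpenCvCompatibility branch, with hypothetical names and sizes):

#include <array>
#include <cstdio>
#include "opencv2/core.hpp"

// absolute = scale_src * transform * normalize_dst, then the top-left 2x3
// block (with the translation column) is handed to cv::warpAffine, which is
// invoked with WARP_INVERSE_MAP so the matrix is read as output -> input.
cv::Mat ToAbsoluteAffine(const std::array<float, 16>& matrix, int input_width,
                         int input_height, int output_width,
                         int output_height) {
  const cv::Matx44f normalize_dst(
      1.0f / output_width, 0.0f, 0.0f, 0.0f,
      0.0f, 1.0f / output_height, 0.0f, 0.0f,
      0.0f, 0.0f, 1.0f, 0.0f,
      0.0f, 0.0f, 0.0f, 1.0f);
  const cv::Matx44f scale_src(
      static_cast<float>(input_width), 0.0f, 0.0f, 0.0f,
      0.0f, static_cast<float>(input_height), 0.0f, 0.0f,
      0.0f, 0.0f, 1.0f, 0.0f,
      0.0f, 0.0f, 0.0f, 1.0f);
  const cv::Matx44f transform(matrix.data());
  const cv::Matx44f absolute = scale_src * transform * normalize_dst;

  cv::Mat affine(2, 3, CV_32F);
  affine.at<float>(0, 0) = absolute.val[0];
  affine.at<float>(0, 1) = absolute.val[1];
  affine.at<float>(0, 2) = absolute.val[3];
  affine.at<float>(1, 0) = absolute.val[4];
  affine.at<float>(1, 1) = absolute.val[5];
  affine.at<float>(1, 2) = absolute.val[7];
  return affine;
}

int main() {
  // Identity transform, 100x50 input warped into a 10x10 output: the absolute
  // matrix scales x by 10 and y by 5.
  std::array<float, 16> identity = {1, 0, 0, 0, 0, 1, 0, 0,
                                    0, 0, 1, 0, 0, 0, 0, 1};
  cv::Mat affine =
      ToAbsoluteAffine(identity, /*input_width=*/100, /*input_height=*/50,
                       /*output_width=*/10, /*output_height=*/10);
  std::printf("x scale %.1f, y scale %.1f\n", affine.at<float>(0, 0),
              affine.at<float>(1, 1));
  return 0;
}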
+ +#ifndef MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_OPENCV_H_ +#define MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_OPENCV_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/framework/formats/image_frame.h" + +namespace mediapipe { + +absl::StatusOr< + std::unique_ptr>> +CreateAffineTransformationOpenCvRunner(); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_IMAGE_AFFINE_TRANSFORMATION_RUNNER_OPENCV_H_ diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc index 3d878bffc..6bb43dc00 100644 --- a/mediapipe/calculators/image/bilateral_filter_calculator.cc +++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc @@ -240,7 +240,7 @@ absl::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) { auto input_mat = mediapipe::formats::MatView(&input_frame); // Only 1 or 3 channel images supported by OpenCV. - if ((input_mat.channels() == 1 || input_mat.channels() == 3)) { + if (!(input_mat.channels() == 1 || input_mat.channels() == 3)) { return absl::InternalError( "CPU filtering supports only 1 or 3 channel input images."); } diff --git a/mediapipe/calculators/image/image_clone_calculator.cc b/mediapipe/calculators/image/image_clone_calculator.cc index 107c42b92..1e76848b1 100644 --- a/mediapipe/calculators/image/image_clone_calculator.cc +++ b/mediapipe/calculators/image/image_clone_calculator.cc @@ -36,7 +36,7 @@ using GpuBuffer = mediapipe::GpuBuffer; // stored on the target storage (CPU vs GPU) specified in the calculator option. // // The clone shares ownership of the input pixel data on the existing storage. -// If the target storage is diffrent from the existing one, then the data is +// If the target storage is different from the existing one, then the data is // further copied there. // // Example usage: diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index 07f7d5f46..8c9305ffb 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -480,8 +480,7 @@ RectSpec ImageCroppingCalculator::GetCropSpecs(const CalculatorContext* cc, if (cc->Inputs().HasTag(kRectTag)) { const auto& rect = cc->Inputs().Tag(kRectTag).Get(); // Only use the rect if it is valid. - if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 && - rect.y_center() >= 0) { + if (rect.width() > 0 && rect.height() > 0) { x_center = rect.x_center(); y_center = rect.y_center(); crop_width = rect.width(); diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index 60873ae9f..76cc845e2 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -102,6 +102,10 @@ mediapipe::ScaleMode_Mode ParseScaleMode( // IMAGE: ImageFrame representing the input image. // IMAGE_GPU: GpuBuffer representing the input image. // +// OUTPUT_DIMENSIONS (optional): The output width and height in pixels as +// pair. If set, it will override corresponding field in calculator +// options and input side packet. +// // ROTATION_DEGREES (optional): The counterclockwise rotation angle in // degrees. This allows different rotation angles for different frames. It has // to be a multiple of 90 degrees. 
If provided, it overrides the @@ -221,6 +225,10 @@ absl::Status ImageTransformationCalculator::GetContract( } #endif // !MEDIAPIPE_DISABLE_GPU + if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) { + cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set>(); + } + if (cc->Inputs().HasTag("ROTATION_DEGREES")) { cc->Inputs().Tag("ROTATION_DEGREES").Set(); } @@ -329,6 +337,16 @@ absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) { !cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) { flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get(); } + if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) { + if (cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) { + return absl::OkStatus(); + } else { + const auto& image_size = + cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get>(); + output_width_ = image_size.first; + output_height_ = image_size.second; + } + } if (use_gpu_) { #if !MEDIAPIPE_DISABLE_GPU @@ -491,6 +509,14 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) { ComputeOutputDimensions(input_width, input_height, &output_width, &output_height); + if (scale_mode_ == mediapipe::ScaleMode_Mode_FILL_AND_CROP) { + const float scale = + std::min(static_cast(output_width_) / input_width, + static_cast(output_height_) / input_height); + output_width = std::round(input_width * scale); + output_height = std::round(input_height * scale); + } + if (cc->Outputs().HasTag("LETTERBOX_PADDING")) { auto padding = absl::make_unique>(); ComputeOutputLetterboxPadding(input_width, input_height, output_width, diff --git a/mediapipe/calculators/image/scale_image_calculator.cc b/mediapipe/calculators/image/scale_image_calculator.cc index 575268da5..0669f5322 100644 --- a/mediapipe/calculators/image/scale_image_calculator.cc +++ b/mediapipe/calculators/image/scale_image_calculator.cc @@ -262,6 +262,7 @@ absl::Status ScaleImageCalculator::InitializeFrameInfo(CalculatorContext* cc) { scale_image::FindOutputDimensions(crop_width_, crop_height_, // options_.target_width(), // options_.target_height(), // + options_.target_max_area(), // options_.preserve_aspect_ratio(), // options_.scale_to_multiple_of(), // &output_width_, &output_height_)); diff --git a/mediapipe/calculators/image/scale_image_calculator.proto b/mediapipe/calculators/image/scale_image_calculator.proto index e51ccafaa..2b7572d56 100644 --- a/mediapipe/calculators/image/scale_image_calculator.proto +++ b/mediapipe/calculators/image/scale_image_calculator.proto @@ -28,6 +28,11 @@ message ScaleImageCalculatorOptions { optional int32 target_width = 1; optional int32 target_height = 2; + // If set, then automatically calculates a target_width and target_height that + // has an area below the target max area. Aspect ratio preservation cannot be + // disabled. + optional int32 target_max_area = 15; + // If true, the image is scaled up or down proportionally so that it // fits inside the box represented by target_width and target_height. 
// Otherwise it is scaled to fit target_width and target_height diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc index 738e83da0..490d0336a 100644 --- a/mediapipe/calculators/image/scale_image_utils.cc +++ b/mediapipe/calculators/image/scale_image_utils.cc @@ -92,12 +92,21 @@ absl::Status FindOutputDimensions(int input_width, // int input_height, // int target_width, // int target_height, // + int target_max_area, // bool preserve_aspect_ratio, // int scale_to_multiple_of, // int* output_width, int* output_height) { CHECK(output_width); CHECK(output_height); + if (target_max_area > 0 && input_width * input_height > target_max_area) { + preserve_aspect_ratio = true; + target_height = static_cast(sqrt(static_cast(target_max_area) / + (static_cast(input_width) / + static_cast(input_height)))); + target_width = -1; // Resize width to preserve aspect ratio. + } + if (preserve_aspect_ratio) { RET_CHECK(scale_to_multiple_of == 2) << "FindOutputDimensions always outputs width and height that are " @@ -164,5 +173,17 @@ absl::Status FindOutputDimensions(int input_width, // << "Unable to set output dimensions based on target dimensions."; } +absl::Status FindOutputDimensions(int input_width, // + int input_height, // + int target_width, // + int target_height, // + bool preserve_aspect_ratio, // + int scale_to_multiple_of, // + int* output_width, int* output_height) { + return FindOutputDimensions( + input_width, input_height, target_width, target_height, -1, + preserve_aspect_ratio, scale_to_multiple_of, output_width, output_height); +} + } // namespace scale_image } // namespace mediapipe diff --git a/mediapipe/calculators/image/scale_image_utils.h b/mediapipe/calculators/image/scale_image_utils.h index c2c0b8f7c..e7fccd8dc 100644 --- a/mediapipe/calculators/image/scale_image_utils.h +++ b/mediapipe/calculators/image/scale_image_utils.h @@ -34,15 +34,25 @@ absl::Status FindCropDimensions(int input_width, int input_height, // int* crop_width, int* crop_height, // int* col_start, int* row_start); -// Given an input width and height, a target width and height, whether to -// preserve the aspect ratio, and whether to round-down to the multiple of a -// given number nearest to the targets, determine the output width and height. -// If target_width or target_height is non-positive, then they will be set to -// the input_width and input_height respectively. If scale_to_multiple_of is -// less than 1, it will be treated like 1. The output_width and -// output_height will be reduced as necessary to preserve_aspect_ratio if the -// option is specified. If preserving the aspect ratio is desired, you must set -// scale_to_multiple_of to 2. +// Given an input width and height, a target width and height or max area, +// whether to preserve the aspect ratio, and whether to round-down to the +// multiple of a given number nearest to the targets, determine the output width +// and height. If target_width or target_height is non-positive, then they will +// be set to the input_width and input_height respectively. If target_area is +// non-positive, then it will be ignored. If scale_to_multiple_of is less than +// 1, it will be treated like 1. The output_width and output_height will be +// reduced as necessary to preserve_aspect_ratio if the option is specified. If +// preserving the aspect ratio is desired, you must set scale_to_multiple_of +// to 2. 
+absl::Status FindOutputDimensions(int input_width, int input_height, // + int target_width, + int target_height, // + int target_max_area, // + bool preserve_aspect_ratio, // + int scale_to_multiple_of, // + int* output_width, int* output_height); + +// Backwards compatible helper. absl::Status FindOutputDimensions(int input_width, int input_height, // int target_width, int target_height, // diff --git a/mediapipe/calculators/image/scale_image_utils_test.cc b/mediapipe/calculators/image/scale_image_utils_test.cc index 14a58e762..bda1fa4d6 100644 --- a/mediapipe/calculators/image/scale_image_utils_test.cc +++ b/mediapipe/calculators/image/scale_image_utils_test.cc @@ -79,49 +79,49 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsPreserveRatio) { int output_width; int output_height; // Not scale. - MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(200, output_width); EXPECT_EQ(100, output_height); // Not scale with odd input size. - MP_ASSERT_OK(FindOutputDimensions(201, 101, -1, -1, false, 1, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(201, 101, -1, -1, -1, false, 1, + &output_width, &output_height)); EXPECT_EQ(201, output_width); EXPECT_EQ(101, output_height); // Scale down by 1/2. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(50, output_height); // Scale up, doubling dimensions. - MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(400, output_width); EXPECT_EQ(200, output_height); // Fits a 2:1 image into a 150 x 150 box. Output dimensions are always // visible by 2. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 150, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 150, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(150, output_width); EXPECT_EQ(74, output_height); // Fits a 2:1 image into a 400 x 50 box. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 400, 50, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 400, 50, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(50, output_height); // Scale to multiple number with odd targe size. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(50, output_height); // Scale to multiple number with odd targe size. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(50, output_height); // Scale to odd size. 
- MP_ASSERT_OK(FindOutputDimensions(200, 100, 151, 101, false, 1, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 151, 101, -1, false, 1, + &output_width, &output_height)); EXPECT_EQ(151, output_width); EXPECT_EQ(101, output_height); } @@ -131,18 +131,18 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsNoAspectRatio) { int output_width; int output_height; // Scale width only. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, false, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, -1, false, 2, + &output_width, &output_height)); EXPECT_EQ(100, output_width); EXPECT_EQ(100, output_height); // Scale height only. - MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, false, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, -1, false, 2, + &output_width, &output_height)); EXPECT_EQ(200, output_width); EXPECT_EQ(200, output_height); // Scale both dimensions. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, false, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, -1, false, 2, + &output_width, &output_height)); EXPECT_EQ(150, output_width); EXPECT_EQ(200, output_height); } @@ -152,41 +152,78 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsDownScaleToMultipleOf) { int output_width; int output_height; // Set no targets, downscale to a multiple of 8. - MP_ASSERT_OK(FindOutputDimensions(100, 100, -1, -1, false, 8, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(100, 100, -1, -1, -1, false, 8, + &output_width, &output_height)); EXPECT_EQ(96, output_width); EXPECT_EQ(96, output_height); // Set width target, downscale to a multiple of 8. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, false, 8, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, -1, false, 8, + &output_width, &output_height)); EXPECT_EQ(96, output_width); EXPECT_EQ(96, output_height); // Set height target, downscale to a multiple of 8. - MP_ASSERT_OK(FindOutputDimensions(201, 101, -1, 201, false, 8, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(201, 101, -1, 201, -1, false, 8, + &output_width, &output_height)); EXPECT_EQ(200, output_width); EXPECT_EQ(200, output_height); // Set both targets, downscale to a multiple of 8. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, false, 8, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, -1, false, 8, + &output_width, &output_height)); EXPECT_EQ(144, output_width); EXPECT_EQ(200, output_height); // Doesn't throw error if keep aspect is true and downscale multiple is 2. - MP_ASSERT_OK(FindOutputDimensions(200, 100, 400, 200, true, 2, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 400, 200, -1, true, 2, + &output_width, &output_height)); EXPECT_EQ(400, output_width); EXPECT_EQ(200, output_height); // Throws error if keep aspect is true, but downscale multiple is not 2. - ASSERT_THAT(FindOutputDimensions(200, 100, 400, 200, true, 4, &output_width, - &output_height), + ASSERT_THAT(FindOutputDimensions(200, 100, 400, 200, -1, true, 4, + &output_width, &output_height), testing::Not(testing::status::IsOk())); // Downscaling to multiple ignored if multiple is less than 2. 
- MP_ASSERT_OK(FindOutputDimensions(200, 100, 401, 201, false, 1, &output_width, - &output_height)); + MP_ASSERT_OK(FindOutputDimensions(200, 100, 401, 201, -1, false, 1, + &output_width, &output_height)); EXPECT_EQ(401, output_width); EXPECT_EQ(201, output_height); } +// Tests scaling without keeping the aspect ratio fixed. +TEST(ScaleImageUtilsTest, FindOutputDimensionsMaxArea) { + int output_width; + int output_height; + // Smaller area. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, 9000, false, 2, + &output_width, &output_height)); + EXPECT_NEAR( + 200 / 100, + static_cast(output_width) / static_cast(output_height), + 0.1f); + EXPECT_LE(output_width * output_height, 9000); + // Close to original area. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, 19999, false, 2, + &output_width, &output_height)); + EXPECT_NEAR( + 200.0 / 100.0, + static_cast(output_width) / static_cast(output_height), + 0.1f); + EXPECT_LE(output_width * output_height, 19999); + // Don't scale with larger area. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, 20001, false, 2, + &output_width, &output_height)); + EXPECT_EQ(200, output_width); + EXPECT_EQ(100, output_height); + // Don't scale with equal area. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, 20000, false, 2, + &output_width, &output_height)); + EXPECT_EQ(200, output_width); + EXPECT_EQ(100, output_height); + // Don't scale at all. + MP_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, -1, false, 2, + &output_width, &output_height)); + EXPECT_EQ(200, output_width); + EXPECT_EQ(100, output_height); +} + } // namespace } // namespace scale_image } // namespace mediapipe diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index 87a661be6..04c3b2cf6 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -53,7 +53,7 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; // The alpha channel can be set to a single value, or come from an image mask. // If the input image has an alpha channel, it will be updated. // If the input image doesn't have an alpha channel, one will be added. -// Adding alpha channel to a Grayscale (single channel) input is not suported. +// Adding alpha channel to a Grayscale (single channel) input is not supported. // // Inputs: // One of the following two IMAGE tags: diff --git a/mediapipe/calculators/image/warp_affine_calculator.cc b/mediapipe/calculators/image/warp_affine_calculator.cc new file mode 100644 index 000000000..e3d017a35 --- /dev/null +++ b/mediapipe/calculators/image/warp_affine_calculator.cc @@ -0,0 +1,211 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
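The target_max_area clamp added to FindOutputDimensions in scale_image_utils above is small enough to show in isolation; a standalone sketch of just that step (the real utility additionally runs the result through the preserve-aspect-ratio path, which rounds to a multiple of 2):

#include <cmath>
#include <cstdio>

// Mirrors the target_max_area clamp: when the input area exceeds the cap, the
// target height is derived from the cap and the input aspect ratio, and the
// width follows from aspect-ratio preservation.
void ClampToMaxArea(int input_width, int input_height, int target_max_area,
                    int* target_width, int* target_height) {
  if (target_max_area > 0 && input_width * input_height > target_max_area) {
    const double aspect =
        static_cast<double>(input_width) / static_cast<double>(input_height);
    *target_height = static_cast<int>(std::sqrt(target_max_area / aspect));
    *target_width = static_cast<int>(aspect * *target_height);
  } else {
    *target_width = input_width;
    *target_height = input_height;
  }
}

int main() {
  int w = 0, h = 0;
  // 200x100 input clamped to at most 9000 pixels: 134x67, area 8978 <= 9000,
  // aspect ratio still 2:1, matching the expectations in the new test above.
  ClampToMaxArea(200, 100, 9000, &w, &h);
  std::printf("%dx%d (area %d)\n", w, h, w * h);
  // A cap above the input area (20001) leaves the size untouched.
  ClampToMaxArea(200, 100, 20001, &w, &h);
  std::printf("%dx%d\n", w, h);
  return 0;
}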
+ +#include "mediapipe/calculators/image/warp_affine_calculator.h" + +#include +#include +#include + +#include "mediapipe/calculators/image/affine_transformation.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/calculators/image/affine_transformation_runner_gl.h" +#endif // !MEDIAPIPE_DISABLE_GPU +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/image/affine_transformation_runner_opencv.h" +#include "mediapipe/calculators/image/warp_affine_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/ret_check.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +namespace { + +AffineTransformation::BorderMode GetBorderMode( + mediapipe::WarpAffineCalculatorOptions::BorderMode border_mode) { + switch (border_mode) { + case mediapipe::WarpAffineCalculatorOptions::BORDER_ZERO: + return AffineTransformation::BorderMode::kZero; + case mediapipe::WarpAffineCalculatorOptions::BORDER_UNSPECIFIED: + case mediapipe::WarpAffineCalculatorOptions::BORDER_REPLICATE: + return AffineTransformation::BorderMode::kReplicate; + } +} + +template +class WarpAffineRunnerHolder {}; + +template <> +class WarpAffineRunnerHolder { + public: + using RunnerType = AffineTransformation::Runner; + absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); } + absl::StatusOr GetRunner() { + if (!runner_) { + ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner()); + } + return runner_.get(); + } + + private: + std::unique_ptr runner_; +}; + +#if !MEDIAPIPE_DISABLE_GPU +template <> +class WarpAffineRunnerHolder { + public: + using RunnerType = + AffineTransformation::Runner>; + absl::Status Open(CalculatorContext* cc) { + gpu_origin_ = + cc->Options().gpu_origin(); + gl_helper_ = std::make_shared(); + return gl_helper_->Open(cc); + } + absl::StatusOr GetRunner() { + if (!runner_) { + ASSIGN_OR_RETURN( + runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_)); + } + return runner_.get(); + } + + private: + mediapipe::GpuOrigin::Mode gpu_origin_; + std::shared_ptr gl_helper_; + std::unique_ptr runner_; +}; +#endif // !MEDIAPIPE_DISABLE_GPU + +template <> +class WarpAffineRunnerHolder { + public: + absl::Status Open(CalculatorContext* cc) { return runner_.Open(cc); } + absl::StatusOr< + AffineTransformation::Runner*> + GetRunner() { + return &runner_; + } + + private: + class Runner : public AffineTransformation::Runner { + public: + absl::Status Open(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(cpu_holder_.Open(cc)); +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_holder_.Open(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + return absl::OkStatus(); + } + absl::StatusOr Run( + const mediapipe::Image& input, const std::array& matrix, + const AffineTransformation::Size& size, + AffineTransformation::BorderMode border_mode) override { + if (input.UsesGpu()) { +#if !MEDIAPIPE_DISABLE_GPU + ASSIGN_OR_RETURN(auto* runner, gpu_holder_.GetRunner()); + ASSIGN_OR_RETURN(auto result, runner->Run(input.GetGpuBuffer(), matrix, + size, border_mode)); + return mediapipe::Image(*result); +#else + return absl::UnavailableError("GPU support is disabled"); +#endif // !MEDIAPIPE_DISABLE_GPU + } + ASSIGN_OR_RETURN(auto* runner, 
cpu_holder_.GetRunner()); + const auto& frame_ptr = input.GetImageFrameSharedPtr(); + // Wrap image into image frame. + const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(), + frame_ptr->Height(), frame_ptr->WidthStep(), + const_cast(frame_ptr->PixelData()), + [](uint8* data) {}); + ASSIGN_OR_RETURN(auto result, + runner->Run(image_frame, matrix, size, border_mode)); + return mediapipe::Image(std::make_shared(std::move(result))); + } + + private: + WarpAffineRunnerHolder cpu_holder_; +#if !MEDIAPIPE_DISABLE_GPU + WarpAffineRunnerHolder gpu_holder_; +#endif // !MEDIAPIPE_DISABLE_GPU + }; + + Runner runner_; +}; + +template +class WarpAffineCalculatorImpl : public mediapipe::api2::NodeImpl { + public: +#if !MEDIAPIPE_DISABLE_GPU + static absl::Status UpdateContract(CalculatorContract* cc) { + if constexpr (std::is_same_v || + std::is_same_v) { + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + } + return absl::OkStatus(); + } +#endif // !MEDIAPIPE_DISABLE_GPU + + absl::Status Open(CalculatorContext* cc) override { return holder_.Open(cc); } + + absl::Status Process(CalculatorContext* cc) override { + if (InterfaceT::kInImage(cc).IsEmpty() || + InterfaceT::kMatrix(cc).IsEmpty() || + InterfaceT::kOutputSize(cc).IsEmpty()) { + return absl::OkStatus(); + } + const std::array& transform = *InterfaceT::kMatrix(cc); + auto [out_width, out_height] = *InterfaceT::kOutputSize(cc); + AffineTransformation::Size output_size; + output_size.width = out_width; + output_size.height = out_height; + ASSIGN_OR_RETURN(auto* runner, holder_.GetRunner()); + ASSIGN_OR_RETURN( + auto result, + runner->Run( + *InterfaceT::kInImage(cc), transform, output_size, + GetBorderMode(cc->Options() + .border_mode()))); + InterfaceT::kOutImage(cc).Send(std::move(result)); + + return absl::OkStatus(); + } + + private: + WarpAffineRunnerHolder + holder_; +}; + +} // namespace + +MEDIAPIPE_NODE_IMPLEMENTATION( + WarpAffineCalculatorImpl); +#if !MEDIAPIPE_DISABLE_GPU +MEDIAPIPE_NODE_IMPLEMENTATION( + WarpAffineCalculatorImpl); +#endif // !MEDIAPIPE_DISABLE_GPU +MEDIAPIPE_NODE_IMPLEMENTATION(WarpAffineCalculatorImpl); + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/warp_affine_calculator.h b/mediapipe/calculators/image/warp_affine_calculator.h new file mode 100644 index 000000000..4a1b07030 --- /dev/null +++ b/mediapipe/calculators/image/warp_affine_calculator.h @@ -0,0 +1,94 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_IMAGE_WARP_AFFINE_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_IMAGE_WARP_AFFINE_CALCULATOR_H_ + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_buffer.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +// Runs affine transformation. 
+// +// Input: +// IMAGE - Image/ImageFrame/GpuBuffer +// +// MATRIX - std::array +// Used as following: +// output(x, y) = input(matrix[0] * x + matrix[1] * y + matrix[3], +// matrix[4] * x + matrix[5] * y + matrix[7]) +// where x and y ranges are defined by @OUTPUT_SIZE. +// +// OUTPUT_SIZE - std::pair +// Size of the output image. +// +// Output: +// IMAGE - Image/ImageFrame/GpuBuffer +// +// Note: +// - Output image type and format are the same as the input one. +// +// Usage example: +// node { +// calculator: "WarpAffineCalculator(Cpu|Gpu)" +// input_stream: "IMAGE:image" +// input_stream: "MATRIX:matrix" +// input_stream: "OUTPUT_SIZE:size" +// output_stream: "IMAGE:transformed_image" +// options: { +// [mediapipe.WarpAffineCalculatorOptions.ext] { +// border_mode: BORDER_ZERO +// } +// } +// } +template +class WarpAffineCalculatorIntf : public mediapipe::api2::NodeIntf { + public: + static constexpr mediapipe::api2::Input kInImage{"IMAGE"}; + static constexpr mediapipe::api2::Input> kMatrix{ + "MATRIX"}; + static constexpr mediapipe::api2::Input> kOutputSize{ + "OUTPUT_SIZE"}; + static constexpr mediapipe::api2::Output kOutImage{"IMAGE"}; +}; + +class WarpAffineCalculatorCpu : public WarpAffineCalculatorIntf { + public: + MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculatorCpu, kInImage, kMatrix, + kOutputSize, kOutImage); +}; +#if !MEDIAPIPE_DISABLE_GPU +class WarpAffineCalculatorGpu + : public WarpAffineCalculatorIntf { + public: + MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculatorGpu, kInImage, kMatrix, + kOutputSize, kOutImage); +}; +#endif // !MEDIAPIPE_DISABLE_GPU +class WarpAffineCalculator : public WarpAffineCalculatorIntf { + public: + MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculator, kInImage, kMatrix, kOutputSize, + kOutImage); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_IMAGE_WARP_AFFINE_CALCULATOR_H_ diff --git a/mediapipe/calculators/image/warp_affine_calculator.proto b/mediapipe/calculators/image/warp_affine_calculator.proto new file mode 100644 index 000000000..20e6c1b07 --- /dev/null +++ b/mediapipe/calculators/image/warp_affine_calculator.proto @@ -0,0 +1,46 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/gpu/gpu_origin.proto"; + +message WarpAffineCalculatorOptions { + extend CalculatorOptions { + optional WarpAffineCalculatorOptions ext = 373693895; + } + + // Pixel extrapolation methods. See @border_mode. + enum BorderMode { + BORDER_UNSPECIFIED = 0; + BORDER_ZERO = 1; + BORDER_REPLICATE = 2; + } + + // Pixel extrapolation method. + // When converting image to tensor it may happen that tensor needs to read + // pixels outside image boundaries. Border mode helps to specify how such + // pixels will be calculated. + // + // BORDER_REPLICATE is used by default. 
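+  //
+  // For illustration, assuming a single-channel row of pixels [5, 7, 9]:
+  // sampling one pixel past the right edge yields 0 under BORDER_ZERO and 9
+  // (the replicated edge value) under BORDER_REPLICATE.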
+ optional BorderMode border_mode = 1; + + // For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs + // to be flipped vertically as tensors are expected to start at top. + // (DEFAULT or unset interpreted as CONVENTIONAL.) + optional GpuOrigin.Mode gpu_origin = 2; +} diff --git a/mediapipe/calculators/image/warp_affine_calculator_test.cc b/mediapipe/calculators/image/warp_affine_calculator_test.cc new file mode 100644 index 000000000..959912cc9 --- /dev/null +++ b/mediapipe/calculators/image/warp_affine_calculator_test.cc @@ -0,0 +1,615 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/image/affine_transformation.h" +#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" +#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +cv::Mat GetRgb(absl::string_view path) { + cv::Mat bgr = cv::imread(file::JoinPath("./", path)); + cv::Mat rgb(bgr.rows, bgr.cols, CV_8UC3); + int from_to[] = {0, 2, 1, 1, 2, 0}; + cv::mixChannels(&bgr, 1, &rgb, 1, from_to, 3); + return rgb; +} + +cv::Mat GetRgba(absl::string_view path) { + cv::Mat bgr = cv::imread(file::JoinPath("./", path)); + cv::Mat rgba(bgr.rows, bgr.cols, CV_8UC4, cv::Scalar(0, 0, 0, 0)); + int from_to[] = {0, 2, 1, 1, 2, 0}; + cv::mixChannels(&bgr, 1, &bgr, 1, from_to, 3); + return bgr; +} + +// Test template. +// No processing/assertions should be done after the function is invoked. 
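+// For example, when border_mode is AffineTransformation::BorderMode::kZero,
+// the "$0" placeholder in graph_text is substituted (via absl::Substitute)
+// with the literal options line
+//   border_mode: BORDER_ZERO
+// before the graph config is parsed; when no border mode is specified, the
+// placeholder is replaced with nothing.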
+void RunTest(const std::string& graph_text, const std::string& tag, + const cv::Mat& input, cv::Mat expected_result, + float similarity_threshold, std::array matrix, + int out_width, int out_height, + absl::optional border_mode) { + std::string border_mode_str; + if (border_mode) { + switch (*border_mode) { + case AffineTransformation::BorderMode::kReplicate: + border_mode_str = "border_mode: BORDER_REPLICATE"; + break; + case AffineTransformation::BorderMode::kZero: + border_mode_str = "border_mode: BORDER_ZERO"; + break; + } + } + auto graph_config = mediapipe::ParseTextProtoOrDie( + absl::Substitute(graph_text, /*$0=*/border_mode_str)); + + std::vector output_packets; + tool::AddVectorSink("output_image", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + ImageFrame input_image( + input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB, + input.cols, input.rows, input.step, input.data, [](uint8*) {}); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_image", + MakePacket(std::move(input_image)).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "matrix", + MakePacket>(std::move(matrix)).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "output_size", MakePacket>( + std::pair(out_width, out_height)) + .At(Timestamp(0)))); + + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_THAT(output_packets, testing::SizeIs(1)); + + // Get and process results. + const ImageFrame& out_frame = output_packets[0].Get(); + cv::Mat result = formats::MatView(&out_frame); + double similarity = + 1.0 - cv::norm(result, expected_result, cv::NORM_RELATIVE | cv::NORM_L2); + EXPECT_GE(similarity, similarity_threshold); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("input_image")); + MP_ASSERT_OK(graph.CloseInputStream("matrix")); + MP_ASSERT_OK(graph.CloseInputStream("output_size")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +enum class InputType { kImageFrame, kImage }; + +// Similarity is checked against OpenCV results always, and due to differences +// on how OpenCV and GL treats pixels there are two thresholds. +// TODO: update to have just one threshold when OpenCV +// implementation is updated. 
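+// Concretely, similarity = 1 - ||result - expected|| / ||expected||
+// (cv::NORM_RELATIVE | cv::NORM_L2 above), so a threshold of 0.99 tolerates
+// roughly 1% relative pixel difference.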
+struct SimilarityConfig { + double threshold_on_cpu; + double threshold_on_gpu; +}; + +void RunTest(cv::Mat input, cv::Mat expected_result, + const SimilarityConfig& similarity, std::array matrix, + int out_width, int out_height, + absl::optional border_mode) { + RunTest(R"( + input_stream: "input_image" + input_stream: "output_size" + input_stream: "matrix" + node { + calculator: "WarpAffineCalculatorCpu" + input_stream: "IMAGE:input_image" + input_stream: "MATRIX:matrix" + input_stream: "OUTPUT_SIZE:output_size" + output_stream: "IMAGE:output_image" + options { + [mediapipe.WarpAffineCalculatorOptions.ext] { + $0 # border mode + } + } + } + )", + "cpu", input, expected_result, similarity.threshold_on_cpu, matrix, + out_width, out_height, border_mode); + + RunTest(R"( + input_stream: "input_image" + input_stream: "output_size" + input_stream: "matrix" + node { + calculator: "ToImageCalculator" + input_stream: "IMAGE_CPU:input_image" + output_stream: "IMAGE:input_image_unified" + } + node { + calculator: "WarpAffineCalculator" + input_stream: "IMAGE:input_image_unified" + input_stream: "MATRIX:matrix" + input_stream: "OUTPUT_SIZE:output_size" + output_stream: "IMAGE:output_image_unified" + options { + [mediapipe.WarpAffineCalculatorOptions.ext] { + $0 # border mode + } + } + } + node { + calculator: "FromImageCalculator" + input_stream: "IMAGE:output_image_unified" + output_stream: "IMAGE_CPU:output_image" + } + )", + "cpu_image", input, expected_result, similarity.threshold_on_cpu, + matrix, out_width, out_height, border_mode); + + RunTest(R"( + input_stream: "input_image" + input_stream: "output_size" + input_stream: "matrix" + node { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "input_image" + output_stream: "input_image_gpu" + } + node { + calculator: "WarpAffineCalculatorGpu" + input_stream: "IMAGE:input_image_gpu" + input_stream: "MATRIX:matrix" + input_stream: "OUTPUT_SIZE:output_size" + output_stream: "IMAGE:output_image_gpu" + options { + [mediapipe.WarpAffineCalculatorOptions.ext] { + $0 # border mode + gpu_origin: TOP_LEFT + } + } + } + node { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "output_image_gpu" + output_stream: "output_image" + } + )", + "gpu", input, expected_result, similarity.threshold_on_gpu, matrix, + out_width, out_height, border_mode); + + RunTest(R"( + input_stream: "input_image" + input_stream: "output_size" + input_stream: "matrix" + node { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "input_image" + output_stream: "input_image_gpu" + } + node { + calculator: "ToImageCalculator" + input_stream: "IMAGE_GPU:input_image_gpu" + output_stream: "IMAGE:input_image_unified" + } + node { + calculator: "WarpAffineCalculator" + input_stream: "IMAGE:input_image_unified" + input_stream: "MATRIX:matrix" + input_stream: "OUTPUT_SIZE:output_size" + output_stream: "IMAGE:output_image_unified" + options { + [mediapipe.WarpAffineCalculatorOptions.ext] { + $0 # border mode + gpu_origin: TOP_LEFT + } + } + } + node { + calculator: "FromImageCalculator" + input_stream: "IMAGE:output_image_unified" + output_stream: "IMAGE_GPU:output_image_gpu" + } + node { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "output_image_gpu" + output_stream: "output_image" + } + )", + "gpu_image", input, expected_result, similarity.threshold_on_gpu, + matrix, out_width, out_height, border_mode); +} + +std::array GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi, + bool keep_aspect_ratio, int out_width, + int out_height) { 
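+  // Builds the same 4x4 row-major transform that the image_to_tensor
+  // utilities use: the normalized ROI is converted to an absolute RotatedRect
+  // (GetRoi), optionally padded to the requested output aspect ratio (PadRoi),
+  // and then converted with GetRotatedSubRectToRectTransformMatrix into the
+  // matrix consumed by the MATRIX input of WarpAffineCalculator.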
+ std::array transform_mat; + mediapipe::RotatedRect roi_absolute = + mediapipe::GetRoi(input.cols, input.rows, roi); + mediapipe::PadRoi(out_width, out_height, keep_aspect_ratio, &roi_absolute) + .IgnoreError(); + mediapipe::GetRotatedSubRectToRectTransformMatrix( + roi_absolute, input.cols, input.rows, + /*flip_horizontaly=*/false, &transform_mat); + return transform_mat; +} + +TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = true; + std::optional border_mode = {}; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_border_zero.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * 90.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_with_rotation.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * 90.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 
0.99, .threshold_on_gpu = 0.75}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * -45.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * -45.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_with_rotation_border_zero.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + 
roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(0); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_border_zero.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + auto input = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = {}; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + auto input = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation_border_zero.png"); + int out_width = 128; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, NoOp) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.0f); + roi.set_height(1.0f); + roi.set_rotation(0); + auto input = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgba( + "/mediapipe/calculators/" + 
"tensor/testdata/image_to_tensor/noop_except_range.png"); + int out_width = 64; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kReplicate; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +TEST(WarpAffineCalculatorTest, NoOpBorderZero) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.0f); + roi.set_height(1.0f); + roi.set_rotation(0); + auto input = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgba( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/noop_except_range.png"); + int out_width = 64; + int out_height = 128; + bool keep_aspect_ratio = true; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 71be05f6c..72c2f5181 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -26,6 +26,11 @@ licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +exports_files( + glob(["testdata/image_to_tensor/*"]), + visibility = ["//mediapipe/calculators/image:__subpackages__"], +) + selects.config_setting_group( name = "compute_shader_unavailable", match_any = [ @@ -351,6 +356,57 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "landmarks_to_tensor_calculator_proto", + srcs = ["landmarks_to_tensor_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "landmarks_to_tensor_calculator", + srcs = ["landmarks_to_tensor_calculator.cc"], + hdrs = ["landmarks_to_tensor_calculator.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":landmarks_to_tensor_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +cc_test( + name = "landmarks_to_tensor_calculator_test", + srcs = ["landmarks_to_tensor_calculator_test.cc"], + deps = [ + ":landmarks_to_tensor_calculator", + ":landmarks_to_tensor_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + mediapipe_proto_library( name = "tensors_to_floats_calculator_proto", srcs = ["tensors_to_floats_calculator.proto"], diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc 
b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc
index 5c22d734b..b579f0474 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc
@@ -87,9 +87,9 @@ using GpuBuffer = mediapipe::GpuBuffer;
 // TENSORS - std::vector<Tensor>
 //   Vector containing a single Tensor populated with an extracted RGB image.
 // MATRIX - std::array<float, 16> @Optional
-//   An std::array<float, 16> representing a 4x4 row-major-order matrix which
-//   can be used to map a point on the output tensor to a point on the input
-//   image.
+//   An std::array<float, 16> representing a 4x4 row-major-order matrix that
+//   maps a point on the input image to a point on the output tensor, and
+//   can be used to reverse the mapping by inverting the matrix.
 // LETTERBOX_PADDING - std::array<float, 4> @Optional
 //   An std::array<float, 4> representing the letterbox padding from the 4
 //   sides ([left, top, right, bottom]) of the output image, normalized to
diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc
index 11256a338..46e0f928c 100644
--- a/mediapipe/calculators/tensor/inference_calculator.cc
+++ b/mediapipe/calculators/tensor/inference_calculator.cc
@@ -33,7 +33,7 @@ class InferenceCalculatorSelectorImpl
   absl::StatusOr<CalculatorGraphConfig> GetConfig(
       const CalculatorGraphConfig::Node& subgraph_node) {
     const auto& options =
-        Subgraph::GetOptions<::mediapipe::InferenceCalculatorOptions>(
+        Subgraph::GetOptions<mediapipe::InferenceCalculatorOptions>(
             subgraph_node);
     std::vector<absl::string_view> impls;
     const bool should_use_gpu =
diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h
index 9fe06181c..1c54bc46e 100644
--- a/mediapipe/calculators/tensor/inference_calculator.h
+++ b/mediapipe/calculators/tensor/inference_calculator.h
@@ -99,8 +99,11 @@ class InferenceCalculator : public NodeIntf {
       kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
   static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
   static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
+  static constexpr SideInput<
+      mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{
+      "DELEGATE"};
   MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
-                          kOutTensors);
+                          kOutTensors, kDelegate);
 
  protected:
   using TfLiteDelegatePtr =
diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto
index e0b538a91..04f8d141d 100644
--- a/mediapipe/calculators/tensor/inference_calculator.proto
+++ b/mediapipe/calculators/tensor/inference_calculator.proto
@@ -18,6 +18,9 @@ package mediapipe;
 
 import "mediapipe/framework/calculator.proto";
 
+option java_package = "com.google.mediapipe.calculator.proto";
+option java_outer_classname = "InferenceCalculatorProto";
+
 // Full Example:
 //
 // node {
@@ -31,7 +34,6 @@ import "mediapipe/framework/calculator.proto";
 //     }
 //   }
 // }
-//
 message InferenceCalculatorOptions {
   extend mediapipe.CalculatorOptions {
     optional InferenceCalculatorOptions ext = 336783863;
@@ -66,10 +68,55 @@ message InferenceCalculatorOptions {
       // Load pre-compiled serialized binary cache to accelerate init process.
       // Only available for OpenCL delegate on Android.
      // Kernel caching will only be enabled if this path is set.
+      //
+      // NOTE: binary cache usage may be skipped if valid serialized model,
+      // specified by "serialized_model_dir", exists.
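+      //
+      // Illustrative example (the directory is hypothetical): with
+      //   cached_kernel_path: "/data/local/tmp/"
+      // the calculator reads/writes a kernel cache named after the model,
+      // e.g. "/data/local/tmp/<model basename>.ker".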
+ // + // TODO: update to cached_kernel_dir optional string cached_kernel_path = 2; + + // A dir to load from and save to a pre-compiled serialized model used to + // accelerate init process. + // + // NOTE: available for OpenCL delegate on Android only when + // "use_advanced_gpu_api" is set to true and "model_token" is set + // properly. + // + // NOTE: serialized model takes precedence over binary cache + // specified by "cached_kernel_path", which still can be used if + // serialized model is invalid or missing. + optional string serialized_model_dir = 7; + + // Unique token identifying the model. Used in conjunction with + // "serialized_model_dir". It is the caller's responsibility to ensure + // there is no clash of the tokens. + optional string model_token = 8; + + // Encapsulated compilation/runtime tradeoffs. + enum InferenceUsage { + UNSPECIFIED = 0; + + // InferenceRunner will be used only once. Therefore, it is important to + // minimize bootstrap time as well. + FAST_SINGLE_ANSWER = 1; + + // Prefer maximizing the throughput. Same inference runner will be used + // repeatedly on different inputs. + SUSTAINED_SPEED = 2; + } + optional InferenceUsage usage = 5 [default = SUSTAINED_SPEED]; } + // Android only. - message Nnapi {} + message Nnapi { + // Directory to store compilation cache. If unspecified, NNAPI will not + // try caching the compilation. + optional string cache_dir = 1; + // Unique token identifying the model. It is the caller's responsibility + // to ensure there is no clash of the tokens. If unspecified, NNAPI will + // not try caching the compilation. + optional string model_token = 2; + } message Xnnpack { // Number of threads for XNNPACK delegate. (By default, calculator tries // to choose optimal number of threads depending on the device.) diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index 0299ab526..7d695ad9b 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -50,11 +50,13 @@ int GetXnnpackDefaultNumThreads() { // Returns number of threads to configure XNNPACK delegate with. // Returns user provided value if specified. Otherwise, tries to choose optimal // number of threads depending on the device. -int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) { +int GetXnnpackNumThreads( + const bool opts_has_delegate, + const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) { static constexpr int kDefaultNumThreads = -1; - if (opts.has_delegate() && opts.delegate().has_xnnpack() && - opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) { - return opts.delegate().xnnpack().num_threads(); + if (opts_has_delegate && opts_delegate.has_xnnpack() && + opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) { + return opts_delegate.xnnpack().num_threads(); } return GetXnnpackDefaultNumThreads(); } @@ -73,6 +75,7 @@ class InferenceCalculatorCpuImpl private: absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. 
Packet model_packet_; @@ -91,8 +94,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract( absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(LoadModel(cc)); - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - return absl::OkStatus(); + return LoadDelegateAndAllocateTensors(cc); } absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { @@ -156,34 +158,61 @@ absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) { cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. - CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { const auto& calculator_opts = cc->Options(); - if (calculator_opts.has_delegate() && - calculator_opts.delegate().has_tflite()) { + auto opts_delegate = calculator_opts.delegate(); + if (!kDelegate(cc).IsEmpty()) { + mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate = + kDelegate(cc).Get(); + CHECK(input_side_packet_delegate.has_tflite() || + input_side_packet_delegate.has_xnnpack() || + input_side_packet_delegate.has_nnapi() || + input_side_packet_delegate.delegate_case() == + mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET) + << "inference_calculator_cpu only supports delegate input side packet " + << "for TFLite, XNNPack and Nnapi"; + opts_delegate.MergeFrom(input_side_packet_delegate); + } + const bool opts_has_delegate = + calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty(); + if (opts_has_delegate && opts_delegate.has_tflite()) { // Default tflite inference requeqsted - no need to modify graph. return absl::OkStatus(); } #if defined(MEDIAPIPE_ANDROID) - const bool nnapi_requested = calculator_opts.has_delegate() - ? calculator_opts.delegate().has_nnapi() - : calculator_opts.use_nnapi(); + const bool nnapi_requested = opts_has_delegate ? opts_delegate.has_nnapi() + : calculator_opts.use_nnapi(); if (nnapi_requested) { // Attempt to use NNAPI. // If not supported, the default CPU delegate will be created and used. interpreter_->SetAllowFp16PrecisionForFp32(1); - delegate_ = TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { - // No need to free according to tflite::NnApiDelegate() documentation. - }); + tflite::StatefulNnApiDelegate::Options options; + const auto& nnapi = opts_delegate.nnapi(); + // Set up cache_dir and model_token for NNAPI compilation cache. + options.cache_dir = + nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr; + options.model_token = + nnapi.has_model_token() ? 
nnapi.model_token().c_str() : nullptr; + delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options), + [](TfLiteDelegate*) {}); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); return absl::OkStatus(); @@ -193,13 +222,13 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { #if defined(__EMSCRIPTEN__) const bool use_xnnpack = true; #else - const bool use_xnnpack = calculator_opts.has_delegate() && - calculator_opts.delegate().has_xnnpack(); + const bool use_xnnpack = opts_has_delegate && opts_delegate.has_xnnpack(); #endif // defined(__EMSCRIPTEN__) if (use_xnnpack) { TfLiteXNNPackDelegateOptions xnnpack_opts{}; - xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts); + xnnpack_opts.num_threads = + GetXnnpackNumThreads(opts_has_delegate, opts_delegate); delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), &TfLiteXNNPackDelegateDelete); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index 5769df20e..dda9a1fa1 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -18,7 +18,9 @@ #include #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/util/tflite/config.h" #if MEDIAPIPE_TFLITE_GL_INFERENCE @@ -48,10 +50,11 @@ class InferenceCalculatorGlImpl absl::Status Close(CalculatorContext* cc) override; private: - absl::Status ReadKernelsFromFile(); - absl::Status WriteKernelsToFile(); + absl::Status ReadGpuCaches(); + absl::Status SaveGpuCaches(); absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); absl::Status InitTFLiteGPURunner(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. 
@@ -65,6 +68,8 @@ class InferenceCalculatorGlImpl bool allow_precision_loss_ = false; mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api tflite_gpu_runner_api_; + mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage + tflite_gpu_runner_usage_; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE #if MEDIAPIPE_TFLITE_GPU_SUPPORTED @@ -78,6 +83,8 @@ class InferenceCalculatorGlImpl bool use_kernel_caching_ = false; std::string cached_kernel_filename_; + bool use_serialized_model_ = false; + std::string serialized_model_path_; }; absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { @@ -91,22 +98,43 @@ absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); - use_advanced_gpu_api_ = options.has_delegate() && - options.delegate().has_gpu() && - options.delegate().gpu().use_advanced_gpu_api(); - allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); - tflite_gpu_runner_api_ = options.delegate().gpu().api(); - use_kernel_caching_ = use_advanced_gpu_api_ && - options.delegate().gpu().has_cached_kernel_path(); + mediapipe::InferenceCalculatorOptions::Delegate delegate = options.delegate(); + if (!kDelegate(cc).IsEmpty()) { + mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate = + kDelegate(cc).Get(); + CHECK(input_side_packet_delegate.has_gpu() || + input_side_packet_delegate.delegate_case() == + mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET) + << "inference_calculator_gl only supports delegate input side packet " + << "for Gpu"; + delegate.MergeFrom(input_side_packet_delegate); + } + const bool has_delegate = options.has_delegate() || !kDelegate(cc).IsEmpty(); + use_advanced_gpu_api_ = has_delegate && delegate.has_gpu() && + delegate.gpu().use_advanced_gpu_api(); + allow_precision_loss_ = delegate.gpu().allow_precision_loss(); + tflite_gpu_runner_api_ = delegate.gpu().api(); + tflite_gpu_runner_usage_ = delegate.gpu().usage(); + use_kernel_caching_ = + use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path(); + use_serialized_model_ = use_advanced_gpu_api_ && + delegate.gpu().has_serialized_model_dir() && + delegate.gpu().has_model_token(); use_gpu_delegate_ = !use_advanced_gpu_api_; if (use_kernel_caching_) { #ifdef MEDIAPIPE_ANDROID - cached_kernel_filename_ = options.delegate().gpu().cached_kernel_path() + + cached_kernel_filename_ = delegate.gpu().cached_kernel_path() + mediapipe::File::Basename(options.model_path()) + ".ker"; #endif // MEDIAPIPE_ANDROID } + if (use_serialized_model_) { +#ifdef MEDIAPIPE_ANDROID + serialized_model_path_ = mediapipe::file::JoinPath( + delegate.gpu().serialized_model_dir(), delegate.gpu().model_token()); +#endif // MEDIAPIPE_ANDROID + } // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner // for everything. @@ -115,10 +143,11 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { } MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, - &cc]() -> ::mediapipe::Status { - return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) : LoadDelegate(cc); - })); + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { + return use_advanced_gpu_api_ ? 
InitTFLiteGPURunner(cc) + : LoadDelegateAndAllocateTensors(cc); + })); return absl::OkStatus(); } @@ -193,7 +222,7 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) { return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() { +absl::Status InferenceCalculatorGlImpl::SaveGpuCaches() { #ifdef MEDIAPIPE_ANDROID if (use_kernel_caching_) { // Save kernel file. @@ -203,12 +232,22 @@ absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() { MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } + if (use_serialized_model_) { + // Save serialized model file. + ASSIGN_OR_RETURN(std::vector serialized_model_vec, + tflite_gpu_runner_->GetSerializedModel()); + absl::string_view serialized_model( + reinterpret_cast(serialized_model_vec.data()), + serialized_model_vec.size()); + MP_RETURN_IF_ERROR( + mediapipe::file::SetContents(serialized_model_path_, serialized_model)); + } #endif // MEDIAPIPE_ANDROID return absl::OkStatus(); } absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) { - MP_RETURN_IF_ERROR(WriteKernelsToFile()); + MP_RETURN_IF_ERROR(SaveGpuCaches()); if (use_gpu_delegate_) { MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { gpu_buffers_in_.clear(); @@ -222,17 +261,24 @@ absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) { return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::ReadKernelsFromFile() { +absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() { #ifdef MEDIAPIPE_ANDROID - if (use_kernel_caching_) { + if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) { // Load pre-compiled kernel file. - if (mediapipe::File::Exists(cached_kernel_filename_)) { - std::string cache_str; - MP_RETURN_IF_ERROR( - mediapipe::file::GetContents(cached_kernel_filename_, &cache_str)); - std::vector cache_vec(cache_str.begin(), cache_str.end()); - tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec)); - } + std::string cache_str; + MP_RETURN_IF_ERROR( + mediapipe::file::GetContents(cached_kernel_filename_, &cache_str)); + std::vector cache_vec(cache_str.begin(), cache_str.end()); + tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec)); + } + if (use_serialized_model_ && File::Exists(serialized_model_path_)) { + // Load serialized model file. 
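+    // (When both caches are configured, the serialized model loaded here
+    // takes precedence over the binary kernel cache read above; the kernel
+    // cache is still used if the serialized model is invalid or missing, per
+    // InferenceCalculatorOptions::Delegate::Gpu.)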
+ std::string serialized_model_str; + MP_RETURN_IF_ERROR( + file::GetContents(serialized_model_path_, &serialized_model_str)); + std::vector serialized_model_vec(serialized_model_str.begin(), + serialized_model_str.end()); + tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec)); } #endif // MEDIAPIPE_ANDROID return absl::OkStatus(); @@ -253,9 +299,27 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( : tflite::gpu::InferencePriority::MAX_PRECISION; options.priority2 = tflite::gpu::InferencePriority::AUTO; options.priority3 = tflite::gpu::InferencePriority::AUTO; - options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + switch (tflite_gpu_runner_usage_) { + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu:: + FAST_SINGLE_ANSWER: { + options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER; + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu:: + SUSTAINED_SPEED: { + options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: { + return absl::InternalError("inference usage need to be specified."); + } + } tflite_gpu_runner_ = std::make_unique(options); switch (tflite_gpu_runner_api_) { + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: { + // Do not need to force any specific API. + break; + } case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: { tflite_gpu_runner_->ForceOpenGL(); break; @@ -264,10 +328,6 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( tflite_gpu_runner_->ForceOpenCL(); break; } - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: { - // Do not need to force any specific API. - break; - } } MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( model, op_resolver, /*allow_quant_ops=*/true)); @@ -282,7 +342,7 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( tflite_gpu_runner_->GetOutputShapes()[i].c}; } - MP_RETURN_IF_ERROR(ReadKernelsFromFile()); + MP_RETURN_IF_ERROR(ReadGpuCaches()); MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); @@ -306,11 +366,19 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. - CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index 4bf3525e4..49e042290 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -92,6 +92,7 @@ class InferenceCalculatorMetalImpl private: absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. 
Packet model_packet_; @@ -130,8 +131,7 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) { gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - return absl::OkStatus(); + return LoadDelegateAndAllocateTensors(cc); } absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { @@ -212,11 +212,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) { interpreter_->SetNumThreads( cc->Options().cpu_num_thread()); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. - CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } @@ -236,6 +244,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); + id device = gpu_helper_.mtlDevice; // Get input image sizes. diff --git a/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.cc b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.cc new file mode 100644 index 000000000..8f9323818 --- /dev/null +++ b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.cc @@ -0,0 +1,101 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h" + +#include + +#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { +namespace api2 { + +namespace { + +float GetAttribute( + const Landmark& landmark, + const LandmarksToTensorCalculatorOptions::Attribute& attribute) { + switch (attribute) { + case LandmarksToTensorCalculatorOptions::X: + return landmark.x(); + case LandmarksToTensorCalculatorOptions::Y: + return landmark.y(); + case LandmarksToTensorCalculatorOptions::Z: + return landmark.z(); + case LandmarksToTensorCalculatorOptions::VISIBILITY: + return landmark.visibility(); + case LandmarksToTensorCalculatorOptions::PRESENCE: + return landmark.presence(); + } +} + +} // namespace + +class LandmarksToTensorCalculatorImpl + : public NodeImpl { + public: + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + RET_CHECK(options_.attributes_size() > 0) + << "At least one attribute must be specified"; + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + if (kInLandmarkList(cc).IsEmpty()) { + return absl::OkStatus(); + } + + // Get input landmarks. + const auto& in_landmarks = *kInLandmarkList(cc); + + // Determine tensor shape. + const int n_landmarks = in_landmarks.landmark_size(); + const int n_attributes = options_.attributes_size(); + auto tensor_shape = options_.flatten() + ? Tensor::Shape{1, n_landmarks * n_attributes} + : Tensor::Shape{1, n_landmarks, n_attributes}; + + // Create empty tesnor. + Tensor tensor(Tensor::ElementType::kFloat32, tensor_shape); + auto* buffer = tensor.GetCpuWriteView().buffer(); + + // Fill tensor with landmark attributes. + for (int i = 0; i < n_landmarks; ++i) { + for (int j = 0; j < n_attributes; ++j) { + buffer[i * n_attributes + j] = + GetAttribute(in_landmarks.landmark(i), options_.attributes(j)); + } + } + + // Return vector with a single tensor. + auto result = std::vector(); + result.push_back(std::move(tensor)); + kOutTensors(cc).Send(std::move(result)); + + return absl::OkStatus(); + } + + private: + LandmarksToTensorCalculatorOptions options_; +}; +MEDIAPIPE_NODE_IMPLEMENTATION(LandmarksToTensorCalculatorImpl); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h new file mode 100644 index 000000000..662f1b05f --- /dev/null +++ b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h @@ -0,0 +1,61 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_ + +#include + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { +namespace api2 { + +// A calculator for converting landmars into a Tensor. +// +// Input: +// LANDMARKS - LandmarkList +// Landmarks to be converted into a Tensor. +// +// Output: +// TENSORS - std::vector +// Vector containing a single Tensor populated with landmark values. +// +// Example: +// node { +// calculator: "LandmarksToTensorCalculator" +// input_stream: "LANDMARKS:landmarks" +// output_stream: "TENSORS:tensors" +// options: { +// [mediapipe.LandmarksToTensorCalculatorOptions.ext] { +// attributes: [X, Y, Z, VISIBILITY, PRESENCE] +// # flatten: true +// } +// } +// } +class LandmarksToTensorCalculator : public NodeIntf { + public: + static constexpr Input::Optional kInLandmarkList{"LANDMARKS"}; + static constexpr Output> kOutTensors{"TENSORS"}; + MEDIAPIPE_NODE_INTERFACE(LandmarksToTensorCalculator, kInLandmarkList, + kOutTensors); +}; + +} // namespace api2 +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_ diff --git a/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.proto b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.proto new file mode 100644 index 000000000..6ef1c8d4e --- /dev/null +++ b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator.proto @@ -0,0 +1,44 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the LandmarksToTensorCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message LandmarksToTensorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional LandmarksToTensorCalculatorOptions ext = 394810235; + } + + enum Attribute { + X = 0; + Y = 1; + Z = 2; + VISIBILITY = 3; + PRESENCE = 4; + } + + // Subset and order of attributes as they should appear in the output Tensor. + // Should contain at least one attribute. + repeated Attribute attributes = 1; + + // Collapses all landmark attributes into a one dimensional tensor (i.e. + // switches from (n_landmarks, n_attributes) to (n_landmarks * n_attributes) + // representation). + optional bool flatten = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensor/landmarks_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator_test.cc new file mode 100644 index 000000000..dfda71b55 --- /dev/null +++ b/mediapipe/calculators/tensor/landmarks_to_tensor_calculator_test.cc @@ -0,0 +1,155 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::ParseTextProtoOrDie; +using Node = ::mediapipe::CalculatorGraphConfig::Node; + +void RunLandmarks(mediapipe::CalculatorRunner* runner, + const LandmarkList& landmarks) { + runner->MutableInputs() + ->Tag("LANDMARKS") + .packets.push_back(MakePacket(landmarks).At(Timestamp(0))); + MP_ASSERT_OK(runner->Run()); +} + +const Tensor& GetOutputTensor(mediapipe::CalculatorRunner* runner) { + const auto& output_packets = runner->Outputs().Tag("TENSORS").packets; + EXPECT_EQ(output_packets.size(), 1); + + const auto& tensors = output_packets[0].Get>(); + EXPECT_EQ(tensors.size(), 1); + + return tensors[0]; +} + +void ValidateTensor(const Tensor& tensor, + const std::vector& expected_shape, + const std::vector& expected_values) { + EXPECT_EQ(tensor.shape().dims, expected_shape); + EXPECT_EQ(tensor.shape().num_elements(), expected_values.size()); + + auto* tensor_buffer = tensor.GetCpuReadView().buffer(); + const std::vector tensor_values( + tensor_buffer, tensor_buffer + tensor.shape().num_elements()); + EXPECT_THAT(tensor_values, testing::ElementsAreArray(expected_values)); +} + +TEST(LandmarksToTensorCalculatorTest, AllAttributes) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "LandmarksToTensorCalculator" + input_stream: "LANDMARKS:landmarks" + output_stream: "TENSORS:tensors" + options: { + [mediapipe.LandmarksToTensorCalculatorOptions.ext] { + attributes: [ X, Y, Z, VISIBILITY, PRESENCE ] + } + } + )pb")); + + LandmarkList landmarks; + auto* landmark1 = landmarks.add_landmark(); + landmark1->set_x(1.0f); + landmark1->set_y(2.0f); + landmark1->set_z(3.0f); + landmark1->set_visibility(4.0f); + landmark1->set_presence(5.0f); + auto* landmark2 = landmarks.add_landmark(); + landmark2->set_x(6.0f); + landmark2->set_y(7.0f); + landmark2->set_z(8.0f); + landmark2->set_visibility(9.0f); + landmark2->set_presence(10.0f); + + RunLandmarks(&runner, landmarks); + const auto& tensor = GetOutputTensor(&runner); + ValidateTensor(tensor, /*expected_shape=*/{1, 2, 5}, /*expected_values=*/ + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}); +} + +TEST(LandmarksToTensorCalculatorTest, XYZAttributes) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "LandmarksToTensorCalculator" + input_stream: "LANDMARKS:landmarks" + output_stream: "TENSORS:tensors" + options: { + [mediapipe.LandmarksToTensorCalculatorOptions.ext] { + attributes: [ X, Y, Z ] + } + } + )pb")); + + LandmarkList landmarks; + auto* 
landmark1 = landmarks.add_landmark(); + landmark1->set_x(1.0f); + landmark1->set_y(2.0f); + landmark1->set_z(3.0f); + auto* landmark2 = landmarks.add_landmark(); + landmark2->set_x(6.0f); + landmark2->set_y(7.0f); + landmark2->set_z(8.0f); + + RunLandmarks(&runner, landmarks); + const auto& tensor = GetOutputTensor(&runner); + ValidateTensor(tensor, /*expected_shape=*/{1, 2, 3}, /*expected_values=*/ + {1.0f, 2.0f, 3.0f, 6.0f, 7.0f, 8.0f}); +} + +TEST(LandmarksToTensorCalculatorTest, XYZAttributes_Flatten) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "LandmarksToTensorCalculator" + input_stream: "LANDMARKS:landmarks" + output_stream: "TENSORS:tensors" + options: { + [mediapipe.LandmarksToTensorCalculatorOptions.ext] { + attributes: [ X, Y, Z ] + flatten: true + } + } + )pb")); + + LandmarkList landmarks; + auto* landmark1 = landmarks.add_landmark(); + landmark1->set_x(1.0f); + landmark1->set_y(2.0f); + landmark1->set_z(3.0f); + auto* landmark2 = landmarks.add_landmark(); + landmark2->set_x(6.0f); + landmark2->set_y(7.0f); + landmark2->set_z(8.0f); + + RunLandmarks(&runner, landmarks); + const auto& tensor = GetOutputTensor(&runner); + ValidateTensor(tensor, /*expected_shape=*/{1, 6}, /*expected_values=*/ + {1.0f, 2.0f, 3.0f, 6.0f, 7.0f, 8.0f}); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc index 82180fe52..f3c7c7b09 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc @@ -517,8 +517,8 @@ absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { uniform sampler2D frame; void main() { - $1 // flip - vec4 pixel = texture2D(frame, sample_coordinate); + vec2 coord = $1 + vec4 pixel = texture2D(frame, coord); $2 // normalize [-1,1] fragColor.r = pixel.r; // r channel $3 // g & b channels @@ -526,8 +526,9 @@ absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { })", /*$0=*/single_channel ? "vec1" : "vec4", /*$1=*/ - flip_vertically_ ? "sample_coordinate.y = 1.0 - sample_coordinate.y;" - : "", + flip_vertically_ + ? "vec2(sample_coordinate.x, 1.0 - sample_coordinate.y);" + : "sample_coordinate;", /*$2=*/output_range_.has_value() ? absl::Substitute("pixel = pixel * float($0) + float($1);", (output_range_->second - output_range_->first), diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index f161127f5..498036c12 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -670,7 +670,8 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections( detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], detection_scores[i], detection_classes[i], options_.flip_vertically()); const auto& bbox = detection.location_data().relative_bounding_box(); - if (bbox.width() < 0 || bbox.height() < 0) { + if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || + std::isnan(bbox.height())) { // Decoded detection boxes could have negative values for width/height due // to model prediction. Filter out those boxes since some downstream // calculators may assume non-negative values. 
(b/171391719) diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 45e242f3c..ffc96b2e4 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -138,7 +138,6 @@ using ::tflite::gpu::gl::GlShader; // } // } // -// Currently only OpenGLES 3.1 and CPU backends supported. // TODO Refactor and add support for other backends/platforms. // class TensorsToSegmentationCalculator : public CalculatorBase { diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 0dbbd57da..ac058610a 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -88,6 +88,13 @@ proto_library( deps = ["//mediapipe/framework:calculator_proto"], ) +proto_library( + name = "tensor_to_vector_string_calculator_options_proto", + srcs = ["tensor_to_vector_string_calculator_options.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + proto_library( name = "unpack_media_sequence_calculator_proto", srcs = ["unpack_media_sequence_calculator.proto"], @@ -257,6 +264,14 @@ mediapipe_cc_proto_library( deps = [":tensor_to_vector_float_calculator_options_proto"], ) +mediapipe_cc_proto_library( + name = "tensor_to_vector_string_calculator_options_cc_proto", + srcs = ["tensor_to_vector_string_calculator_options.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":tensor_to_vector_string_calculator_options_proto"], +) + mediapipe_cc_proto_library( name = "unpack_media_sequence_calculator_cc_proto", srcs = ["unpack_media_sequence_calculator.proto"], @@ -572,9 +587,21 @@ cc_library( "//mediapipe/framework/port:ret_check", ] + select({ "//conditions:default": [ - "//mediapipe/framework/port:file_helpers", ], - }), + "//mediapipe:android": [], + }) + select( + { + "//conditions:default": [ + ], + }, + ) + select( + { + "//conditions:default": [ + ], + "//mediapipe:android": [ + ], + }, + ), alwayslink = 1, ) @@ -694,6 +721,26 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tensor_to_vector_string_calculator", + srcs = ["tensor_to_vector_string_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ":tensor_to_vector_string_calculator_options_cc_proto", + ] + select({ + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:framework", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", + ], + }), + alwayslink = 1, +) + cc_library( name = "unpack_media_sequence_calculator", srcs = ["unpack_media_sequence_calculator.cc"], @@ -864,6 +911,7 @@ cc_test( "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", @@ -1058,6 +1106,20 @@ cc_test( ], ) +cc_test( + name = "tensor_to_vector_string_calculator_test", + srcs = ["tensor_to_vector_string_calculator_test.cc"], + deps = [ + ":tensor_to_vector_string_calculator", + 
":tensor_to_vector_string_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + cc_test( name = "unpack_media_sequence_calculator_test", srcs = ["unpack_media_sequence_calculator_test.cc"], diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index ddb042e6a..3991f645d 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -37,6 +37,7 @@ const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; const char kImageTag[] = "IMAGE"; const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_"; const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; +const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_"; const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; const char kBBoxTag[] = "BBOX"; const char kKeypointsTag[] = "KEYPOINTS"; @@ -153,6 +154,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { cc->Inputs().Tag(tag).Set>(); } + if (absl::StartsWith(tag, kBytesFeaturePrefixTag)) { + cc->Inputs().Tag(tag).Set>(); + } } CHECK(cc->Outputs().HasTag(kSequenceExampleTag) || @@ -231,6 +235,13 @@ class PackMediaSequenceCalculator : public CalculatorBase { mpms::ClearFeatureFloats(key, sequence_.get()); mpms::ClearFeatureTimestamp(key, sequence_.get()); } + if (absl::StartsWith(tag, kBytesFeaturePrefixTag)) { + std::string key = tag.substr(sizeof(kBytesFeaturePrefixTag) / + sizeof(*kBytesFeaturePrefixTag) - + 1); + mpms::ClearFeatureBytes(key, sequence_.get()); + mpms::ClearFeatureTimestamp(key, sequence_.get()); + } if (absl::StartsWith(tag, kKeypointsTag)) { std::string key = tag.substr(sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1); @@ -243,11 +254,6 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } - if (cc->Outputs().HasTag(kSequenceExampleTag)) { - cc->Outputs() - .Tag(kSequenceExampleTag) - .SetNextTimestampBound(Timestamp::Max()); - } return absl::OkStatus(); } @@ -305,7 +311,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (cc->Outputs().HasTag(kSequenceExampleTag)) { cc->Outputs() .Tag(kSequenceExampleTag) - .Add(sequence_.release(), Timestamp::PostStream()); + .Add(sequence_.release(), options.output_as_zero_timestamp() + ? 
Timestamp(0ll) + : Timestamp::PostStream()); } sequence_.reset(); @@ -408,6 +416,17 @@ class PackMediaSequenceCalculator : public CalculatorBase { cc->Inputs().Tag(tag).Get>(), sequence_.get()); } + if (absl::StartsWith(tag, kBytesFeaturePrefixTag) && + !cc->Inputs().Tag(tag).IsEmpty()) { + std::string key = tag.substr(sizeof(kBytesFeaturePrefixTag) / + sizeof(*kBytesFeaturePrefixTag) - + 1); + mpms::AddFeatureTimestamp(key, cc->InputTimestamp().Value(), + sequence_.get()); + mpms::AddFeatureBytes( + key, cc->Inputs().Tag(tag).Get>(), + sequence_.get()); + } if (absl::StartsWith(tag, kBBoxTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = ""; if (tag != kBBoxTag) { diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto index 695eb6b5e..6ba09fb16 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto @@ -65,4 +65,7 @@ message PackMediaSequenceCalculatorOptions { // If true, will return an error status if an output sequence would be too // many bytes to serialize. optional bool skip_large_sequences = 7 [default = true]; + + // If true/false, outputs the SequenceExample at timestamp 0/PostStream. + optional bool output_as_zero_timestamp = 8 [default = false]; } diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index c163cebcd..b39a0bac0 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -29,6 +29,7 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" #include "mediapipe/util/sequence/media_sequence.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" @@ -39,12 +40,33 @@ namespace { namespace tf = ::tensorflow; namespace mpms = mediapipe::mediasequence; +constexpr char kBboxTag[] = "BBOX"; +constexpr char kEncodedMediaStartTimestampTag[] = + "ENCODED_MEDIA_START_TIMESTAMP"; +constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA"; +constexpr char kClassSegmentationTag[] = "CLASS_SEGMENTATION"; +constexpr char kKeypointsTestTag[] = "KEYPOINTS_TEST"; +constexpr char kBboxPredictedTag[] = "BBOX_PREDICTED"; +constexpr char kAudioOtherTag[] = "AUDIO_OTHER"; +constexpr char kAudioTestTag[] = "AUDIO_TEST"; +constexpr char kBytesFeatureOtherTag[] = "BYTES_FEATURE_OTHER"; +constexpr char kBytesFeatureTestTag[] = "BYTES_FEATURE_TEST"; +constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; +constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER"; +constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST"; +constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER"; +constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST"; +constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; +constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; +constexpr char kImageTag[] = "IMAGE"; + class PackMediaSequenceCalculatorTest : public ::testing::Test { protected: void SetUpCalculator(const std::vector& input_streams, const tf::Features& features, - bool output_only_if_all_present, - bool replace_instead_of_append) { + const bool 
output_only_if_all_present, + const bool replace_instead_of_append, + const bool output_as_zero_timestamp = false) { CalculatorGraphConfig::Node config; config.set_calculator("PackMediaSequenceCalculator"); config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence"); @@ -57,6 +79,7 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test { *options->mutable_context_feature_map() = features; options->set_output_only_if_all_present(output_only_if_all_present); options->set_replace_data_instead_of_append(replace_instead_of_append); + options->set_output_as_zero_timestamp(output_as_zero_timestamp); runner_ = ::absl::make_unique(config); } @@ -80,17 +103,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { for (int i = 0; i < num_images; ++i) { auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -124,17 +147,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { auto image_ptr = ::absl::make_unique(encoded_image); runner_->MutableInputs() - ->Tag("IMAGE_PREFIX") + ->Tag(kImagePrefixTag) .packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -158,21 +181,21 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) { for (int i = 0; i < num_timesteps; ++i) { auto vf_ptr = ::absl::make_unique>(2, 2 << i); runner_->MutableInputs() - ->Tag("FLOAT_FEATURE_TEST") + ->Tag(kFloatFeatureTestTag) .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); vf_ptr = ::absl::make_unique>(2, 2 << i); runner_->MutableInputs() - ->Tag("FLOAT_FEATURE_OTHER") + ->Tag(kFloatFeatureOtherTag) .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -194,20 +217,65 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) { } } -TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { - SetUpCalculator( - {"FLOAT_CONTEXT_FEATURE_TEST:test", "FLOAT_CONTEXT_FEATURE_OTHER:test2"}, - {}, false, true); - auto input_sequence = absl::make_unique(); +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBytesLists) { + SetUpCalculator({"BYTES_FEATURE_TEST:test", "BYTES_FEATURE_OTHER:test2"}, {}, 
+ false, true); + auto input_sequence = ::absl::make_unique(); - auto vf_ptr = absl::make_unique>(2, 3); - runner_->MutableInputs() - ->Tag("FLOAT_CONTEXT_FEATURE_TEST") - .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); - vf_ptr = absl::make_unique>(2, 4); - runner_->MutableInputs() - ->Tag("FLOAT_CONTEXT_FEATURE_OTHER") - .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + auto vs_ptr = ::absl::make_unique>( + 2, absl::StrCat("foo", 2 << i)); + runner_->MutableInputs() + ->Tag(kBytesFeatureTestTag) + .packets.push_back(Adopt(vs_ptr.release()).At(Timestamp(i))); + vs_ptr = ::absl::make_unique>( + 2, absl::StrCat("bar", 2 << i)); + runner_->MutableInputs() + ->Tag(kBytesFeatureOtherTag) + .packets.push_back(Adopt(vs_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureBytesSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("OTHER", output_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureBytesSize("OTHER", output_sequence)); + for (int i = 0; i < num_timesteps; ++i) { + ASSERT_EQ(i, mpms::GetFeatureTimestampAt("TEST", output_sequence, i)); + ASSERT_THAT(mpms::GetFeatureBytesAt("TEST", output_sequence, i), + ::testing::ElementsAreArray( + std::vector(2, absl::StrCat("foo", 2 << i)))); + ASSERT_EQ(i, mpms::GetFeatureTimestampAt("OTHER", output_sequence, i)); + ASSERT_THAT(mpms::GetFeatureBytesAt("OTHER", output_sequence, i), + ::testing::ElementsAreArray( + std::vector(2, absl::StrCat("bar", 2 << i)))); + } +} + +TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) { + SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true); + auto input_sequence = ::absl::make_unique(); + + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + auto vf_ptr = ::absl::make_unique>(2, 2 << i); + runner_->MutableInputs() + ->Tag("FLOAT_FEATURE_TEST") + .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); + } runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = Adopt(input_sequence.release()); @@ -217,6 +285,32 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { const std::vector& output_packets = runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; ASSERT_EQ(1, output_packets.size()); + EXPECT_EQ(output_packets[0].Timestamp().Value(), 0ll); +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { + SetUpCalculator( + {"FLOAT_CONTEXT_FEATURE_TEST:test", "FLOAT_CONTEXT_FEATURE_OTHER:test2"}, + {}, false, true); + auto input_sequence = absl::make_unique(); + + auto vf_ptr = absl::make_unique>(2, 3); + runner_->MutableInputs() + ->Tag(kFloatContextFeatureTestTag) + .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); + vf_ptr = absl::make_unique>(2, 4); + runner_->MutableInputs() + ->Tag(kFloatContextFeatureOtherTag) + .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); + + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + 
MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -233,7 +327,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { SetUpCalculator({"IMAGE:images"}, context, false, true); auto input_sequence = ::absl::make_unique(); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; @@ -242,13 +336,13 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { encoded_image.set_encoded_image(bytes.data(), bytes.size()); auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp(0))); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -281,17 +375,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) { auto flow_ptr = ::absl::make_unique(encoded_flow); runner_->MutableInputs() - ->Tag("FORWARD_FLOW_ENCODED") + ->Tag(kForwardFlowEncodedTag) .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -345,17 +439,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) { detections->push_back(detection); runner_->MutableInputs() - ->Tag("BBOX_PREDICTED") + ->Tag(kBboxPredictedTag) .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -424,11 +518,11 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) { detections->push_back(detection); runner_->MutableInputs() - ->Tag("BBOX_PREDICTED") + ->Tag(kBboxPredictedTag) .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); auto status = runner_->Run(); @@ -472,7 +566,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { detections->push_back(detection); runner_->MutableInputs() - ->Tag("BBOX_PREDICTED") + ->Tag(kBboxPredictedTag) .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); } cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255)); @@ -487,16 +581,16 @@ 
TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { for (int i = 0; i < num_images; ++i) { auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -538,18 +632,18 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) { absl::flat_hash_map>> points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}}; runner_->MutableInputs() - ->Tag("KEYPOINTS_TEST") + ->Tag(kKeypointsTestTag) .packets.push_back(PointToForeign(&points).At(Timestamp(0))); runner_->MutableInputs() - ->Tag("KEYPOINTS_TEST") + ->Tag(kKeypointsTestTag) .packets.push_back(PointToForeign(&points).At(Timestamp(1))); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -589,17 +683,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) { detections->push_back(detection); runner_->MutableInputs() - ->Tag("CLASS_SEGMENTATION") + ->Tag(kClassSegmentationTag) .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -638,17 +732,17 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) { auto flow_ptr = ::absl::make_unique(encoded_flow); runner_->MutableInputs() - ->Tag("FORWARD_FLOW_ENCODED") + ->Tag(kForwardFlowEncodedTag) .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -684,11 +778,11 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) { auto flow_ptr = ::absl::make_unique(encoded_flow); runner_->MutableInputs() - ->Tag("FORWARD_FLOW_ENCODED") + ->Tag(kForwardFlowEncodedTag) .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); absl::Status status = 
runner_->Run(); @@ -705,13 +799,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingImages) { mpms::AddImageTimestamp(1, input_sequence.get()); mpms::AddImageTimestamp(2, input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -731,13 +825,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFlowImages) { mpms::AddForwardFlowTimestamp(1, input_sequence.get()); mpms::AddForwardFlowTimestamp(2, input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -768,13 +862,52 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFloatVectors) { mpms::GetFeatureTimestampSize("OTHER", *input_sequence)); ASSERT_EQ(num_timesteps, mpms::GetFeatureFloatsSize("OTHER", *input_sequence)); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(0, mpms::GetFeatureTimestampSize("TEST", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureFloatsSize("TEST", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureTimestampSize("OTHER", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureFloatsSize("OTHER", output_sequence)); +} + +TEST_F(PackMediaSequenceCalculatorTest, TestReplacingBytesVectors) { + SetUpCalculator({"BYTES_FEATURE_TEST:test", "BYTES_FEATURE_OTHER:test2"}, {}, + false, true); + auto input_sequence = ::absl::make_unique(); + + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + auto vs_ptr = ::absl::make_unique>( + 2, absl::StrCat("foo", 2 << i)); + mpms::AddFeatureBytes("TEST", *vs_ptr, input_sequence.get()); + mpms::AddFeatureTimestamp("TEST", i, input_sequence.get()); + vs_ptr = ::absl::make_unique>( + 2, absl::StrCat("bar", 2 << i)); + mpms::AddFeatureBytes("OTHER", *vs_ptr, input_sequence.get()); + mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get()); + } + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("TEST", *input_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureBytesSize("TEST", *input_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("OTHER", *input_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureBytesSize("OTHER", *input_sequence)); + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = 
output_packets[0].Get(); @@ -800,7 +933,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { for (int i = 0; i < num_images; ++i) { auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10))); } @@ -812,11 +945,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get()); mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + runner_->Outputs().Tag(kSequenceExampleTag).packets; ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); @@ -853,7 +986,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { for (int i = 0; i < num_images; ++i) { auto image_ptr = ::absl::make_unique(encoded_image); - runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( Adopt(image_ptr.release()).At(Timestamp(i))); } @@ -867,7 +1000,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5) .ConvertToProto(detection.mutable_location_data()); detections->push_back(detection); - runner_->MutableInputs()->Tag("BBOX").packets.push_back( + runner_->MutableInputs()->Tag(kBboxTag).packets.push_back( Adopt(detections.release()).At(Timestamp(i))); } @@ -883,7 +1016,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { mpms::AddBBoxTrackIndex({-1}, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); // If the all the previous values aren't cleared, this assert will fail. 
MP_ASSERT_OK(runner_->Run()); @@ -899,11 +1032,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestTooLargeInputFailsSoftly) { for (int i = 0; i < num_timesteps; ++i) { auto vf_ptr = ::absl::make_unique>(1000000, i); runner_->MutableInputs() - ->Tag("FLOAT_FEATURE_TEST") + ->Tag(kFloatFeatureTestTag) .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); ASSERT_FALSE(runner_->Run().ok()); } diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc index fce24b9b9..67ba5e90a 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc @@ -26,6 +26,8 @@ namespace mediapipe { namespace tf = ::tensorflow; namespace { +constexpr char kReferenceTag[] = "REFERENCE"; + constexpr char kMatrix[] = "MATRIX"; constexpr char kTensor[] = "TENSOR"; @@ -68,7 +70,8 @@ class TensorToMatrixCalculatorTest : public ::testing::Test { if (include_rate) { header->set_packet_rate(1.0); } - runner_->MutableInputs()->Tag("REFERENCE").header = Adopt(header.release()); + runner_->MutableInputs()->Tag(kReferenceTag).header = + Adopt(header.release()); } std::unique_ptr runner_; diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator.cc new file mode 100644 index 000000000..2c9e14d4b --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator.cc @@ -0,0 +1,118 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Calculator converts from one-dimensional Tensor of DT_STRING to +// vector OR from (batched) two-dimensional Tensor of DT_STRING to +// vector. + +#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; + +class TensorToVectorStringCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + TensorToVectorStringCalculatorOptions options_; +}; +REGISTER_CALCULATOR(TensorToVectorStringCalculator); + +absl::Status TensorToVectorStringCalculator::GetContract( + CalculatorContract* cc) { + // Start with only one input packet. 
+ RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + cc->Inputs().Index(0).Set( + // Input Tensor + ); + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + const auto& options = cc->Options(); + if (options.tensor_is_2d()) { + RET_CHECK(!options.flatten_nd()); + cc->Outputs().Index(0).Set>>( + /* "Output vector>." */); + } else { + cc->Outputs().Index(0).Set>( + // Output vector. + ); + } + return absl::OkStatus(); +} + +absl::Status TensorToVectorStringCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + + // Inform mediapipe that this calculator produces an output at time t for + // each input received at time t (i.e. this calculator does not buffer + // inputs). This enables mediapipe to propagate time of arrival estimates in + // mediapipe graphs through this calculator. + cc->SetOffset(/*offset=*/0); + + return absl::OkStatus(); +} + +absl::Status TensorToVectorStringCalculator::Process(CalculatorContext* cc) { + const tf::Tensor& input_tensor = + cc->Inputs().Index(0).Value().Get(); + RET_CHECK(tf::DT_STRING == input_tensor.dtype()) + << "expected DT_STRING input but got " + << tensorflow::DataTypeString(input_tensor.dtype()); + + if (options_.tensor_is_2d()) { + RET_CHECK(2 == input_tensor.dims()) + << "Expected 2-dimensional Tensor, but the tensor shape is: " + << input_tensor.shape().DebugString(); + auto output = absl::make_unique>>( + input_tensor.dim_size(0), + std::vector(input_tensor.dim_size(1))); + for (int i = 0; i < input_tensor.dim_size(0); ++i) { + auto& instance_output = output->at(i); + const auto& slice = + input_tensor.Slice(i, i + 1).unaligned_flat(); + for (int j = 0; j < input_tensor.dim_size(1); ++j) { + instance_output.at(j) = slice(j); + } + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else { + if (!options_.flatten_nd()) { + RET_CHECK(1 == input_tensor.dims()) + << "`flatten_nd` is not set. Expected 1-dimensional Tensor, but the " + << "tensor shape is: " << input_tensor.shape().DebugString(); + } + auto output = + absl::make_unique>(input_tensor.NumElements()); + const auto& tensor_values = input_tensor.flat(); + for (int i = 0; i < input_tensor.NumElements(); ++i) { + output->at(i) = tensor_values(i); + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } + + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.proto b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.proto new file mode 100644 index 000000000..74df1be69 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.proto @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
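The options proto below selects between the three unpacking modes implemented above. As a rough, framework-free sketch of the 2-D ("batched") case, here is the same row-major unpacking written against plain std containers instead of tensorflow::Tensor (illustrative only; the function name is not part of the calculator):

#include <cstddef>
#include <string>
#include <vector>

// Mirrors the tensor_is_2d branch: a row-major {rows, cols} string tensor
// becomes one inner vector per row, preserving element order within a row.
std::vector<std::vector<std::string>> UnpackBatched(
    const std::vector<std::string>& flat_values, std::size_t rows,
    std::size_t cols) {
  std::vector<std::vector<std::string>> output(
      rows, std::vector<std::string>(cols));
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      output[i][j] = flat_values[i * cols + j];
    }
  }
  return output;
}
// With tensor_is_2d=false the calculator instead emits a single
// std::vector<std::string> from a 1-D tensor; with flatten_nd=true it does
// the same for any rank, flattening in row-major order. The two options are
// mutually exclusive, as the RET_CHECK in GetContract enforces.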
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorToVectorStringCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorToVectorStringCalculatorOptions ext = 386534187; + } + + // If true, unpack a 2d tensor (matrix) into a vector>. If + // false, convert a 1d tensor (vector) into a vector. + optional bool tensor_is_2d = 1 [default = false]; + + // If true, an N-D tensor will be flattened to a vector. This is + // exclusive with tensor_is_2d. + optional bool flatten_nd = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_test.cc new file mode 100644 index 000000000..94dd9374d --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_test.cc @@ -0,0 +1,130 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +class TensorToVectorStringCalculatorTest : public ::testing::Test { + protected: + void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorToVectorStringCalculator"); + config.add_input_stream("input_tensor"); + config.add_output_stream("output_tensor"); + auto options = config.mutable_options()->MutableExtension( + TensorToVectorStringCalculatorOptions::ext); + options->set_tensor_is_2d(tensor_is_2d); + options->set_flatten_nd(flatten_nd); + runner_ = absl::make_unique(config); + } + + std::unique_ptr runner_; +}; + +TEST_F(TensorToVectorStringCalculatorTest, ConvertsToVectorFloat) { + SetUpRunner(false, false); + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_STRING, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + tensor_vec(i) = absl::StrCat("foo", i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector& output_vector = + output_packets[0].Get>(); + + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const std::string expected = absl::StrCat("foo", i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +TEST_F(TensorToVectorStringCalculatorTest, ConvertsBatchedToVectorVectorFloat) { + 
SetUpRunner(true, false); + const tf::TensorShape tensor_shape(std::vector{1, 5}); + auto tensor = absl::make_unique(tf::DT_STRING, tensor_shape); + auto slice = tensor->Slice(0, 1).flat(); + for (int i = 0; i < 5; ++i) { + slice(i) = absl::StrCat("foo", i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector>& output_vectors = + output_packets[0].Get>>(); + ASSERT_EQ(1, output_vectors.size()); + const std::vector& output_vector = output_vectors[0]; + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const std::string expected = absl::StrCat("foo", i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +TEST_F(TensorToVectorStringCalculatorTest, FlattenShouldTakeAllDimensions) { + SetUpRunner(false, true); + const tf::TensorShape tensor_shape(std::vector{2, 2, 2}); + auto tensor = absl::make_unique(tf::DT_STRING, tensor_shape); + auto slice = tensor->flat(); + for (int i = 0; i < 2 * 2 * 2; ++i) { + slice(i) = absl::StrCat("foo", i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector& output_vector = + output_packets[0].Get>(); + EXPECT_EQ(2 * 2 * 2, output_vector.size()); + for (int i = 0; i < 2 * 2 * 2; ++i) { + const std::string expected = absl::StrCat("foo", i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc index fc81207a0..a8ecb847d 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -49,6 +49,11 @@ namespace tf = ::tensorflow; namespace mediapipe { namespace { + +constexpr char kRecurrentInitTensorsTag[] = "RECURRENT_INIT_TENSORS"; +constexpr char kSessionTag[] = "SESSION"; +constexpr char kSessionBundleTag[] = "SESSION_BUNDLE"; + // This is a simple implementation of a semaphore using standard C++ libraries. // It is supposed to be used only by TensorflowInferenceCalculator to throttle // the concurrent calls of Tensorflow Session::Run. This is useful when multiple @@ -252,10 +257,10 @@ class TensorFlowInferenceCalculator : public CalculatorBase { } // A mediapipe::TensorFlowSession with a model loaded and ready for use. // For this calculator it must include a tag_to_tensor_map. 
- cc->InputSidePackets().Tag("SESSION").Set(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { + cc->InputSidePackets().Tag(kSessionTag).Set(); + if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag)) { cc->InputSidePackets() - .Tag("RECURRENT_INIT_TENSORS") + .Tag(kRecurrentInitTensorsTag) .Set>>(); } return absl::OkStatus(); @@ -265,11 +270,11 @@ class TensorFlowInferenceCalculator : public CalculatorBase { ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { std::unique_ptr inference_state = absl::make_unique(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && - !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { + if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag) && + !cc->InputSidePackets().Tag(kRecurrentInitTensorsTag).IsEmpty()) { std::map* init_tensor_map; init_tensor_map = GetFromUniquePtr>( - cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); + cc->InputSidePackets().Tag(kRecurrentInitTensorsTag)); for (const auto& p : *init_tensor_map) { inference_state->input_tensor_batches_[p.first].emplace_back(p.second); } @@ -280,13 +285,13 @@ class TensorFlowInferenceCalculator : public CalculatorBase { absl::Status Open(CalculatorContext* cc) override { options_ = cc->Options(); - RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); + RET_CHECK(cc->InputSidePackets().HasTag(kSessionTag)); session_ = cc->InputSidePackets() - .Tag("SESSION") + .Tag(kSessionTag) .Get() .session.get(); tag_to_tensor_map_ = cc->InputSidePackets() - .Tag("SESSION") + .Tag(kSessionTag) .Get() .tag_to_tensor_map; diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc index 6a931679d..cc1d15043 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc @@ -41,6 +41,11 @@ namespace mediapipe { namespace tf = ::tensorflow; namespace { + +constexpr char kMultipliedTag[] = "MULTIPLIED"; +constexpr char kBTag[] = "B"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetGraphDefPath() { #ifdef __APPLE__ char path[1024]; @@ -86,8 +91,8 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test { MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options, input_side_packets, &output_side_packets)); - runner_->MutableSidePackets()->Tag("SESSION") = - output_side_packets.Tag("SESSION"); + runner_->MutableSidePackets()->Tag(kSessionTag) = + output_side_packets.Tag(kSessionTag); } Packet CreateTensorPacket(const std::vector& input, int64 time) { @@ -140,7 +145,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_b = - runner_->Outputs().Tag("B").packets; + runner_->Outputs().Tag(kBTag).packets; ASSERT_EQ(output_packets_b.size(), 1); const tf::Tensor& tensor_b = output_packets_b[0].Get(); tf::TensorShape expected_shape({1, 3}); @@ -148,7 +153,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) { tf::test::ExpectTensorEqual(expected_tensor, tensor_b); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(1, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); expected_tensor = tf::test::AsTensor({0, 0, 0}, expected_shape); @@ -181,7 +186,7 @@ 
TEST_F(TensorflowInferenceCalculatorTest, GetComputed) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(1, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); tf::TensorShape expected_shape({3}); @@ -220,7 +225,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(1, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); tf::TensorShape expected_shape({3}); @@ -274,7 +279,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -311,7 +316,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed_MaxInFlight) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -351,7 +356,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(3, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -392,7 +397,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -430,7 +435,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -481,7 +486,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetBatchComputed_MaxInFlight) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(5, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({6, 8, 10}); @@ -528,7 +533,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const 
tf::Tensor& tensor_mult = output_packets_mult[0].Get(); LOG(INFO) << "timestamp: " << 0; @@ -569,7 +574,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(2, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); LOG(INFO) << "timestamp: " << 0; @@ -662,7 +667,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature_Skip) { MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(0, output_packets_mult.size()); } @@ -691,7 +696,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets_mult = - runner_->Outputs().Tag("MULTIPLIED").packets; + runner_->Outputs().Tag(kMultipliedTag).packets; ASSERT_EQ(1, output_packets_mult.size()); const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); auto expected_tensor = tf::test::AsTensor({9, 12, 15}); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc index 2c1d169bc..794a8a732 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc @@ -47,6 +47,11 @@ namespace mediapipe { namespace tf = ::tensorflow; namespace { + +constexpr char kSessionTag[] = "SESSION"; +constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH"; +constexpr char kStringModelTag[] = "STRING_MODEL"; + // Updates the graph nodes to use the device as specified by device_id. void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { for (auto& node : *graph_def->mutable_node()) { @@ -64,30 +69,32 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { cc->Options(); bool has_exactly_one_model = !options.graph_proto_path().empty() - ? !(cc->InputSidePackets().HasTag("STRING_MODEL") | - cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) - : (cc->InputSidePackets().HasTag("STRING_MODEL") ^ - cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")); + ? !(cc->InputSidePackets().HasTag(kStringModelTag) | + cc->InputSidePackets().HasTag(kStringModelFilePathTag)) + : (cc->InputSidePackets().HasTag(kStringModelTag) ^ + cc->InputSidePackets().HasTag(kStringModelFilePathTag)); RET_CHECK(has_exactly_one_model) << "Must have exactly one of graph_proto_path in options or " "input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH"; - if (cc->InputSidePackets().HasTag("STRING_MODEL")) { + if (cc->InputSidePackets().HasTag(kStringModelTag)) { cc->InputSidePackets() - .Tag("STRING_MODEL") + .Tag(kStringModelTag) .Set( // String model from embedded path ); - } else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) { + } else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) { cc->InputSidePackets() - .Tag("STRING_MODEL_FILE_PATH") + .Tag(kStringModelFilePathTag) .Set( // Filename of std::string model. ); } - cc->OutputSidePackets().Tag("SESSION").Set( - // A TensorFlow model loaded and ready for use along with - // a map from tags to tensor names. 
- ); + cc->OutputSidePackets() + .Tag(kSessionTag) + .Set( + // A TensorFlow model loaded and ready for use along with + // a map from tags to tensor names. + ); RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); return absl::OkStatus(); } @@ -111,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { session->session.reset(tf::NewSession(session_options)); std::string graph_def_serialized; - if (cc->InputSidePackets().HasTag("STRING_MODEL")) { + if (cc->InputSidePackets().HasTag(kStringModelTag)) { graph_def_serialized = - cc->InputSidePackets().Tag("STRING_MODEL").Get(); - } else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) { + cc->InputSidePackets().Tag(kStringModelTag).Get(); + } else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) { const std::string& frozen_graph = cc->InputSidePackets() - .Tag("STRING_MODEL_FILE_PATH") + .Tag(kStringModelFilePathTag) .Get(); RET_CHECK_OK( mediapipe::file::GetContents(frozen_graph, &graph_def_serialized)); @@ -147,7 +154,7 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); } - cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); + cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release())); const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); LOG(INFO) << "Loaded frozen model in: " << end_time - start_time << " microseconds."; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc index bdf90dcbb..f0f8928db 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc @@ -37,6 +37,10 @@ namespace { namespace tf = ::tensorflow; +constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH"; +constexpr char kStringModelTag[] = "STRING_MODEL"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetGraphDefPath() { return mediapipe::file::JoinPath("./", "mediapipe/calculators/tensorflow/" @@ -112,7 +116,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); VerifySignatureMap(session); } @@ -190,12 +194,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - runner.MutableSidePackets()->Tag("STRING_MODEL") = + runner.MutableSidePackets()->Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); VerifySignatureMap(session); } @@ -213,12 +217,12 @@ TEST_F( } })", calculator_options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = + runner.MutableSidePackets()->Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); VerifySignatureMap(session); } @@ -234,7 +238,7 @@ 
TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, } })", calculator_options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = + runner.MutableSidePackets()->Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); auto run_status = runner.Run(); EXPECT_THAT( @@ -255,12 +259,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, } })", calculator_options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = + runner.MutableSidePackets()->Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - runner.MutableSidePackets()->Tag("STRING_MODEL") = + runner.MutableSidePackets()->Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); auto run_status = runner.Run(); EXPECT_THAT( @@ -282,12 +286,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, } })", calculator_options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = + runner.MutableSidePackets()->Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - runner.MutableSidePackets()->Tag("STRING_MODEL") = + runner.MutableSidePackets()->Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); auto run_status = runner.Run(); EXPECT_THAT( @@ -310,7 +314,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); VerifySignatureMap(session); } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc index 9f5b9e06b..09985bcf3 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc @@ -43,6 +43,11 @@ namespace mediapipe { namespace tf = ::tensorflow; namespace { + +constexpr char kSessionTag[] = "SESSION"; +constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH"; +constexpr char kStringModelTag[] = "STRING_MODEL"; + // Updates the graph nodes to use the device as specified by device_id. void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { for (auto& node : *graph_def->mutable_node()) { @@ -64,28 +69,29 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { TensorFlowSessionFromFrozenGraphGeneratorOptions::ext); bool has_exactly_one_model = !options.graph_proto_path().empty() - ? !(input_side_packets->HasTag("STRING_MODEL") | - input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) - : (input_side_packets->HasTag("STRING_MODEL") ^ - input_side_packets->HasTag("STRING_MODEL_FILE_PATH")); + ? 
!(input_side_packets->HasTag(kStringModelTag) | + input_side_packets->HasTag(kStringModelFilePathTag)) + : (input_side_packets->HasTag(kStringModelTag) ^ + input_side_packets->HasTag(kStringModelFilePathTag)); RET_CHECK(has_exactly_one_model) << "Must have exactly one of graph_proto_path in options or " "input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH"; - if (input_side_packets->HasTag("STRING_MODEL")) { - input_side_packets->Tag("STRING_MODEL") + if (input_side_packets->HasTag(kStringModelTag)) { + input_side_packets->Tag(kStringModelTag) .Set( // String model from embedded path ); - } else if (input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) { - input_side_packets->Tag("STRING_MODEL_FILE_PATH") + } else if (input_side_packets->HasTag(kStringModelFilePathTag)) { + input_side_packets->Tag(kStringModelFilePathTag) .Set( // Filename of std::string model. ); } - output_side_packets->Tag("SESSION").Set( - // A TensorFlow model loaded and ready for use along with - // a map from tags to tensor names. - ); + output_side_packets->Tag(kSessionTag) + .Set( + // A TensorFlow model loaded and ready for use along with + // a map from tags to tensor names. + ); RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); return absl::OkStatus(); } @@ -112,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { session->session.reset(tf::NewSession(session_options)); std::string graph_def_serialized; - if (input_side_packets.HasTag("STRING_MODEL")) { + if (input_side_packets.HasTag(kStringModelTag)) { graph_def_serialized = - input_side_packets.Tag("STRING_MODEL").Get(); - } else if (input_side_packets.HasTag("STRING_MODEL_FILE_PATH")) { + input_side_packets.Tag(kStringModelTag).Get(); + } else if (input_side_packets.HasTag(kStringModelFilePathTag)) { const std::string& frozen_graph = - input_side_packets.Tag("STRING_MODEL_FILE_PATH").Get(); + input_side_packets.Tag(kStringModelFilePathTag).Get(); RET_CHECK_OK( mediapipe::file::GetContents(frozen_graph, &graph_def_serialized)); } else { @@ -147,7 +153,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); } - output_side_packets->Tag("SESSION") = Adopt(session.release()); + output_side_packets->Tag(kSessionTag) = Adopt(session.release()); const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); LOG(INFO) << "Loaded frozen model in: " << end_time - start_time << " microseconds."; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc index 34d7e8828..83f947a0c 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc @@ -37,6 +37,10 @@ namespace { namespace tf = ::tensorflow; +constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH"; +constexpr char kStringModelTag[] = "STRING_MODEL"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetGraphDefPath() { return mediapipe::file::JoinPath("./", "mediapipe/calculators/tensorflow/" @@ -72,7 +76,7 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test { void VerifySignatureMap(PacketSet* output_side_packets) { const TensorFlowSession& session = - output_side_packets->Tag("SESSION").Get(); + output_side_packets->Tag(kSessionTag).Get(); // Session must be set. 
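// The has_exactly_one_model expression in GetContract() above is compact but
// easy to misread: it requires exactly one of the three possible model
// sources (graph_proto_path in the options, the STRING_MODEL side packet, or
// the STRING_MODEL_FILE_PATH side packet). An equivalent, more explicit
// reformulation; the helper name is illustrative and not used by the code
// above:
bool HasExactlyOneModelSource(bool has_graph_proto_path_option,
                              bool has_string_model_packet,
                              bool has_string_model_file_path_packet) {
  const int provided = static_cast<int>(has_graph_proto_path_option) +
                       static_cast<int>(has_string_model_packet) +
                       static_cast<int>(has_string_model_file_path_packet);
  return provided == 1;
}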
ASSERT_NE(session.session, nullptr); @@ -179,7 +183,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); generator_options_->clear_graph_proto_path(); - input_side_packets.Tag("STRING_MODEL") = + input_side_packets.Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, @@ -196,7 +200,7 @@ TEST_F( PacketSet output_side_packets( tool::CreateTagMap({"SESSION:session"}).value()); generator_options_->clear_graph_proto_path(); - input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + input_side_packets.Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, @@ -211,7 +215,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value()); PacketSet output_side_packets( tool::CreateTagMap({"SESSION:session"}).value()); - input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + input_side_packets.Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, @@ -233,9 +237,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - input_side_packets.Tag("STRING_MODEL") = + input_side_packets.Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); - input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + input_side_packets.Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); absl::Status run_status = tool::RunGenerateAndValidateTypes( @@ -258,9 +262,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); - input_side_packets.Tag("STRING_MODEL") = + input_side_packets.Tag(kStringModelTag) = Adopt(new std::string(serialized_graph_contents)); - input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + input_side_packets.Tag(kStringModelFilePathTag) = Adopt(new std::string(GetGraphDefPath())); generator_options_->clear_graph_proto_path(); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index c169c6b1e..de600de31 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -31,6 +31,9 @@ namespace mediapipe { namespace { + +constexpr char kSessionTag[] = "SESSION"; + static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; // Given the path to a directory containing multiple tensorflow saved models @@ -108,7 +111,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { cc->InputSidePackets().Tag(kStringSavedModelPath).Set(); } // A TensorFlow model loaded and ready for use along with tensor - cc->OutputSidePackets().Tag("SESSION").Set(); + cc->OutputSidePackets().Tag(kSessionTag).Set(); return absl::OkStatus(); } @@ -160,7 +163,7 @@ class 
TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { output_signature.first, options)] = output_signature.second.name(); } - cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); + cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release())); return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc index 7016f14bb..52cd9e0bb 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc @@ -35,6 +35,9 @@ namespace { namespace tf = ::tensorflow; +constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetSavedModelDir() { std::string out_path = file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/", @@ -79,7 +82,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, options_->DebugString())); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); @@ -119,11 +122,11 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, } })", options_->DebugString())); - runner.MutableSidePackets()->Tag("STRING_SAVED_MODEL_PATH") = + runner.MutableSidePackets()->Tag(kStringSavedModelPathTag) = MakePacket(GetSavedModelDir()); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); } @@ -201,7 +204,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, options_->DebugString())); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); } @@ -224,7 +227,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, options_->DebugString())); MP_ASSERT_OK(runner.Run()); const TensorFlowSession& session = - runner.OutputSidePackets().Tag("SESSION").Get(); + runner.OutputSidePackets().Tag(kSessionTag).Get(); // Session must be set. 
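// The side packets in these tests are built in three different ways, and the
// difference is ownership rather than style. A small self-contained sketch
// (the values are illustrative):
#include <string>

#include "mediapipe/framework/packet.h"

void PacketOwnershipExamples() {
  // Adopt(): the packet takes ownership of a heap-allocated object, as in
  // Adopt(new std::string(...)) throughout these tests.
  mediapipe::Packet owned = mediapipe::Adopt(new std::string("model bytes"));

  // MakePacket<T>(): constructs the value inside the packet; usually the
  // simplest choice for copyable values such as a path string.
  mediapipe::Packet made =
      mediapipe::MakePacket<std::string>("/path/to/saved_model");

  // PointToForeign(): the packet only references the object and never frees
  // it, so the caller must keep it alive while the packet is in use (the
  // DATASET_ROOT packet in the unpack_media_sequence tests below relies on
  // this).
  static const std::string root = "test_root";
  mediapipe::Packet borrowed = mediapipe::PointToForeign(&root);

  (void)owned;
  (void)made;
  (void)borrowed;
}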
ASSERT_NE(session.session, nullptr); std::vector devices; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index 6489b0267..9b2e16a88 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -33,6 +33,9 @@ namespace mediapipe { namespace { + +constexpr char kSessionTag[] = "SESSION"; + static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; // Given the path to a directory containing multiple tensorflow saved models @@ -100,7 +103,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { input_side_packets->Tag(kStringSavedModelPath).Set(); } // A TensorFlow model loaded and ready for use along with tensor - output_side_packets->Tag("SESSION").Set(); + output_side_packets->Tag(kSessionTag).Set(); return absl::OkStatus(); } @@ -153,7 +156,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { output_signature.first, options)] = output_signature.second.name(); } - output_side_packets->Tag("SESSION") = Adopt(session.release()); + output_side_packets->Tag(kSessionTag) = Adopt(session.release()); return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc index aca506f0b..46cbf41cb 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -34,6 +34,9 @@ namespace { namespace tf = ::tensorflow; +constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH"; +constexpr char kSessionTag[] = "SESSION"; + std::string GetSavedModelDir() { std::string out_path = file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/", @@ -75,7 +78,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); const TensorFlowSession& session = - output_side_packets.Tag("SESSION").Get(); + output_side_packets.Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); @@ -107,7 +110,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, generator_options_->clear_saved_model_path(); PacketSet input_side_packets( tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value()); - input_side_packets.Tag("STRING_SAVED_MODEL_PATH") = + input_side_packets.Tag(kStringSavedModelPathTag) = Adopt(new std::string(GetSavedModelDir())); PacketSet output_side_packets( tool::CreateTagMap({"SESSION:session"}).value()); @@ -116,7 +119,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); const TensorFlowSession& session = - output_side_packets.Tag("SESSION").Get(); + output_side_packets.Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); } @@ -192,7 +195,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); const TensorFlowSession& session = - output_side_packets.Tag("SESSION").Get(); + output_side_packets.Tag(kSessionTag).Get(); // Session must be set. 
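// The session loader exists both as a CalculatorBase and as a PacketGenerator,
// and this block of hunks touches both. The only difference relevant to these
// edits is how the SESSION output side packet is emitted; the two forms are
// quoted here side by side purely for contrast:
//
//   // CalculatorBase (set on the CalculatorContext, typically in Open()):
//   cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
//
//   // PacketGenerator::Generate(), which is handed the PacketSet directly:
//   output_side_packets->Tag(kSessionTag) = Adopt(session.release());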
ASSERT_NE(session.session, nullptr); } @@ -213,7 +216,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); const TensorFlowSession& session = - output_side_packets.Tag("SESSION").Get(); + output_side_packets.Tag(kSessionTag).Get(); // Session must be set. ASSERT_NE(session.session, nullptr); std::vector devices; diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index e8e40bad3..d12f91741 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -33,6 +33,33 @@ namespace { namespace tf = ::tensorflow; namespace mpms = mediapipe::mediasequence; +constexpr char kImageFrameRateTag[] = "IMAGE_FRAME_RATE"; +constexpr char kEncodedMediaStartTimestampTag[] = + "ENCODED_MEDIA_START_TIMESTAMP"; +constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA"; +constexpr char kResamplerOptionsTag[] = "RESAMPLER_OPTIONS"; +constexpr char kSandboxedDecoderOptionsTag[] = "SANDBOXED_DECODER_OPTIONS"; +constexpr char kDecoderOptionsTag[] = "DECODER_OPTIONS"; +constexpr char kAudioDecoderOptionsTag[] = "AUDIO_DECODER_OPTIONS"; +constexpr char kDataPathTag[] = "DATA_PATH"; +constexpr char kDatasetRootTag[] = "DATASET_ROOT"; +constexpr char kMediaIdTag[] = "MEDIA_ID"; +constexpr char kFloatFeatureFdenseMaxTag[] = "FLOAT_FEATURE_FDENSE_MAX"; +constexpr char kFloatFeatureFdenseAvgTag[] = "FLOAT_FEATURE_FDENSE_AVG"; +constexpr char kAudioOtherTag[] = "AUDIO_OTHER"; +constexpr char kAudioTestTag[] = "AUDIO_TEST"; +constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER"; +constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST"; +constexpr char kBboxPrefixTag[] = "BBOX_PREFIX"; +constexpr char kKeypointsTag[] = "KEYPOINTS"; +constexpr char kBboxTag[] = "BBOX"; +constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; +constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER"; +constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST"; +constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; + class UnpackMediaSequenceCalculatorTest : public ::testing::Test { protected: void SetUpCalculator(const std::vector& output_streams, @@ -95,13 +122,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneImage) { mpms::AddImageEncoded(test_image_string, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE").packets; + runner_->Outputs().Tag(kImageTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 0; i < num_images; ++i) { @@ -124,13 +151,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoImages) { mpms::AddImageEncoded(test_image_string, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE").packets; + runner_->Outputs().Tag(kImageTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 
0; i < num_images; ++i) { @@ -154,13 +181,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPrefixedImages) { mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE_PREFIX").packets; + runner_->Outputs().Tag(kImagePrefixTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 0; i < num_images; ++i) { @@ -182,12 +209,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneForwardFlowImage) { mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets; + runner_->Outputs().Tag(kForwardFlowEncodedTag).packets; ASSERT_EQ(num_forward_flow_images, output_packets.size()); for (int i = 0; i < num_forward_flow_images; ++i) { @@ -211,12 +238,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoForwardFlowImages) { mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets; + runner_->Outputs().Tag(kForwardFlowEncodedTag).packets; ASSERT_EQ(num_forward_flow_images, output_packets.size()); for (int i = 0; i < num_forward_flow_images; ++i) { @@ -240,13 +267,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksBBoxes) { mpms::AddBBoxTimestamp(i, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("BBOX").packets; + runner_->Outputs().Tag(kBboxTag).packets; ASSERT_EQ(bboxes.size(), output_packets.size()); for (int i = 0; i < bboxes.size(); ++i) { @@ -274,13 +301,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPrefixedBBoxes) { mpms::AddBBoxTimestamp(prefix, i, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("BBOX_PREFIX").packets; + runner_->Outputs().Tag(kBboxPrefixTag).packets; ASSERT_EQ(bboxes.size(), output_packets.size()); for (int i = 0; i < bboxes.size(); ++i) { @@ -306,13 +333,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) { mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("FLOAT_FEATURE_TEST").packets; + runner_->Outputs().Tag(kFloatFeatureTestTag).packets; ASSERT_EQ(num_float_lists, output_packets.size()); for (int i = 0; i < num_float_lists; ++i) { @@ -322,7 +349,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, 
UnpacksTwoFloatLists) { } const std::vector& output_packets_other = - runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets; + runner_->Outputs().Tag(kFloatFeatureOtherTag).packets; ASSERT_EQ(num_float_lists, output_packets_other.size()); for (int i = 0; i < num_float_lists; ++i) { @@ -352,12 +379,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) { mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get()); } - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE").packets; + runner_->Outputs().Tag(kImageTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 0; i < num_images; ++i) { @@ -366,7 +393,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) { } const std::vector& output_packets_other = - runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets; + runner_->Outputs().Tag(kFloatFeatureOtherTag).packets; ASSERT_EQ(num_float_lists, output_packets_other.size()); for (int i = 0; i < num_float_lists; ++i) { @@ -389,12 +416,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) { mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& fdense_avg_packets = - runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_AVG").packets; + runner_->Outputs().Tag(kFloatFeatureFdenseAvgTag).packets; ASSERT_EQ(fdense_avg_packets.size(), 1); const auto& fdense_avg_vector = fdense_avg_packets[0].Get>(); @@ -403,7 +430,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) { ::testing::Eq(Timestamp::PostStream())); const std::vector& fdense_max_packets = - runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets; + runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets; ASSERT_EQ(fdense_max_packets.size(), 1); const auto& fdense_max_vector = fdense_max_packets[0].Get>(); @@ -430,13 +457,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) { mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& output_packets = - runner_->Outputs().Tag("IMAGE").packets; + runner_->Outputs().Tag(kImageTag).packets; ASSERT_EQ(num_images, output_packets.size()); for (int i = 0; i < num_images; ++i) { @@ -463,13 +490,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) { mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), input_sequence.get()); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); MP_ASSERT_OK(runner_->Run()); const std::vector& fdense_max_packets = - runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets; + runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets; ASSERT_EQ(fdense_max_packets.size(), 1); const auto& fdense_max_vector = fdense_max_packets[0].Get>(); @@ -481,17 +508,17 @@ TEST_F(UnpackMediaSequenceCalculatorTest, 
UnpacksPostStreamFloatListWithImage) { TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) { SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"}); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); std::string root = "test_root"; - runner_->MutableSidePackets()->Tag("DATASET_ROOT") = PointToForeign(&root); + runner_->MutableSidePackets()->Tag(kDatasetRootTag) = PointToForeign(&root); MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->OutputSidePackets() - .Tag("DATA_PATH") + .Tag(kDataPathTag) .ValidateAsType()); - ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get(), root + "/" + data_path_); } @@ -501,28 +528,28 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromOptions) { options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) ->set_dataset_root_directory(root); SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->OutputSidePackets() - .Tag("DATA_PATH") + .Tag(kDataPathTag) .ValidateAsType()); - ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get(), root + "/" + data_path_); } TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) { SetUpCalculator({}, {"DATA_PATH:data_path"}); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->OutputSidePackets() - .Tag("DATA_PATH") + .Tag(kDataPathTag) .ValidateAsType()); - ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get(), data_path_); } @@ -534,20 +561,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptions) { ->set_padding_after_label(2); SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {}, &options); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_EXPECT_OK(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .ValidateAsType()); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .Get() .start_time(), 2.0, 1e-5); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .Get() .end_time(), 7.0, 1e-5); @@ -563,20 +590,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) { ->set_force_decoding_from_start_of_media(true); SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {}, &options); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_EXPECT_OK(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .ValidateAsType()); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .Get() .start_time(), 0.0, 1e-5); EXPECT_NEAR(runner_->OutputSidePackets() - 
.Tag("AUDIO_DECODER_OPTIONS") + .Tag(kAudioDecoderOptionsTag) .Get() .end_time(), 7.0, 1e-5); @@ -594,27 +621,27 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { ->mutable_base_packet_resampler_options() ->set_frame_rate(1.0); SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_EXPECT_OK(runner_->OutputSidePackets() - .Tag("RESAMPLER_OPTIONS") + .Tag(kResamplerOptionsTag) .ValidateAsType()); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("RESAMPLER_OPTIONS") + .Tag(kResamplerOptionsTag) .Get() .GetExtension(PacketResamplerCalculatorOptions::ext) .start_time(), 2000000, 1); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("RESAMPLER_OPTIONS") + .Tag(kResamplerOptionsTag) .Get() .GetExtension(PacketResamplerCalculatorOptions::ext) .end_time(), 7000000, 1); EXPECT_NEAR(runner_->OutputSidePackets() - .Tag("RESAMPLER_OPTIONS") + .Tag(kResamplerOptionsTag) .Get() .GetExtension(PacketResamplerCalculatorOptions::ext) .frame_rate(), @@ -623,13 +650,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) { SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"}); - runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(sequence_.release()); MP_ASSERT_OK(runner_->Run()); MP_EXPECT_OK(runner_->OutputSidePackets() - .Tag("IMAGE_FRAME_RATE") + .Tag(kImageFrameRateTag) .ValidateAsType()); - EXPECT_EQ(runner_->OutputSidePackets().Tag("IMAGE_FRAME_RATE").Get(), + EXPECT_EQ(runner_->OutputSidePackets().Tag(kImageFrameRateTag).Get(), image_frame_rate_); } diff --git a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc index 369c09660..a7f1a9e7f 100644 --- a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc @@ -26,6 +26,10 @@ namespace { namespace tf = ::tensorflow; +constexpr char kSingleIntTag[] = "SINGLE_INT"; +constexpr char kTensorOutTag[] = "TENSOR_OUT"; +constexpr char kVectorIntTag[] = "VECTOR_INT"; + class VectorIntToTensorCalculatorTest : public ::testing::Test { protected: void SetUpRunner( @@ -61,13 +65,13 @@ class VectorIntToTensorCalculatorTest : public ::testing::Test { const int64 time = 1234; runner_->MutableInputs() - ->Tag("VECTOR_INT") + ->Tag(kVectorIntTag) .packets.push_back(Adopt(input.release()).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); @@ -95,13 +99,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) { tensorflow::DT_INT32, false, true); const int64 time = 1234; runner_->MutableInputs() - ->Tag("SINGLE_INT") + ->Tag(kSingleIntTag) .packets.push_back(MakePacket(1).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, 
output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); @@ -121,13 +125,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TesOneDim) { } const int64 time = 1234; runner_->MutableInputs() - ->Tag("VECTOR_INT") + ->Tag(kVectorIntTag) .packets.push_back(Adopt(input.release()).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); @@ -152,13 +156,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestInt64) { tensorflow::DT_INT64, false, true); const int64 time = 1234; runner_->MutableInputs() - ->Tag("SINGLE_INT") + ->Tag(kSingleIntTag) .packets.push_back(MakePacket(1LL << 31).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); @@ -179,13 +183,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestUint8) { } const int64 time = 1234; runner_->MutableInputs() - ->Tag("VECTOR_INT") + ->Tag(kVectorIntTag) .packets.push_back(Adopt(input.release()).At(Timestamp(time))); EXPECT_TRUE(runner_->Run().ok()); const std::vector& output_packets = - runner_->Outputs().Tag("TENSOR_OUT").packets; + runner_->Outputs().Tag(kTensorOutTag).packets; EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(time, output_packets[0].Timestamp().Value()); const tf::Tensor& output_tensor = output_packets[0].Get(); diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 2d1037d20..55616bb83 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -162,6 +162,27 @@ selects.config_setting_group( ], ) +config_setting( + name = "edge_tpu_usb", + define_values = { + "MEDIAPIPE_EDGE_TPU": "usb", + }, +) + +config_setting( + name = "edge_tpu_pci", + define_values = { + "MEDIAPIPE_EDGE_TPU": "pci", + }, +) + +config_setting( + name = "edge_tpu_all", + define_values = { + "MEDIAPIPE_EDGE_TPU": "all", + }, +) + cc_library( name = "tflite_inference_calculator", srcs = ["tflite_inference_calculator.cc"], @@ -172,6 +193,12 @@ cc_library( ], "//conditions:default": [], }), + defines = select({ + "//conditions:default": [], + ":edge_tpu_usb": ["MEDIAPIPE_EDGE_TPU=usb"], + ":edge_tpu_pci": ["MEDIAPIPE_EDGE_TPU=pci"], + ":edge_tpu_all": ["MEDIAPIPE_EDGE_TPU=all"], + }), linkopts = select({ "//mediapipe:ios": [ "-framework CoreVideo", @@ -223,6 +250,20 @@ cc_library( "//conditions:default": [ "//mediapipe/util:cpu_util", ], + }) + select({ + "//conditions:default": [], + ":edge_tpu_usb": [ + "@libedgetpu//tflite/public:edgetpu", + "@libedgetpu//tflite/public:oss_edgetpu_direct_usb", + ], + ":edge_tpu_pci": [ + "@libedgetpu//tflite/public:edgetpu", + "@libedgetpu//tflite/public:oss_edgetpu_direct_pci", + ], + ":edge_tpu_all": [ + "@libedgetpu//tflite/public:edgetpu", + "@libedgetpu//tflite/public:oss_edgetpu_direct_all", + ], }), alwayslink = 1, ) diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.proto b/mediapipe/calculators/tflite/ssd_anchors_calculator.proto index c89248822..911e4ac92 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator.proto +++ 
b/mediapipe/calculators/tflite/ssd_anchors_calculator.proto
@@ -24,20 +24,20 @@ message SsdAnchorsCalculatorOptions {
     optional SsdAnchorsCalculatorOptions ext = 247258239;
   }
   // Size of input images.
-  required int32 input_size_width = 1;
-  required int32 input_size_height = 2;
+  optional int32 input_size_width = 1;   // required
+  optional int32 input_size_height = 2;  // required
   // Min and max scales for generating anchor boxes on feature maps.
-  required float min_scale = 3;
-  required float max_scale = 4;
+  optional float min_scale = 3;  // required
+  optional float max_scale = 4;  // required
 
   // The offset for the center of anchors. The value is in the scale of stride.
   // E.g. 0.5 meaning 0.5 * |current_stride| in pixels.
-  required float anchor_offset_x = 5 [default = 0.5];
-  required float anchor_offset_y = 6 [default = 0.5];
+  optional float anchor_offset_x = 5 [default = 0.5];  // required
+  optional float anchor_offset_y = 6 [default = 0.5];  // required
 
   // Number of output feature maps to generate the anchors on.
-  required int32 num_layers = 7;
+  optional int32 num_layers = 7;  // required
 
   // Sizes of output feature maps to create anchors. Either feature_map size or
   // stride should be provided.
   repeated int32 feature_map_width = 8;
diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
index 9ec556987..8e83f3e44 100644
--- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
@@ -85,7 +85,22 @@ constexpr char kTensorsGpuTag[] = "TENSORS_GPU";
 }  // namespace
 
 #if defined(MEDIAPIPE_EDGE_TPU)
-#include "edgetpu.h"
+#include "tflite/public/edgetpu.h"
+
+// Checks whether the model contains an Edge TPU custom op.
+bool ContainsEdgeTpuCustomOp(const tflite::FlatBufferModel& model) {
+  const auto* opcodes = model.GetModel()->operator_codes();
+  for (const auto* subgraph : *model.GetModel()->subgraphs()) {
+    for (const auto* op : *subgraph->operators()) {
+      const auto* opcode = opcodes->Get(op->opcode_index());
+      if (opcode->custom_code() &&
+          opcode->custom_code()->str() == edgetpu::kCustomOp) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
 
 // Creates and returns an Edge TPU interpreter to run the given edgetpu model.
 std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
@@ -94,14 +109,9 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
     edgetpu::EdgeTpuContext* edgetpu_context) {
   resolver->AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
   std::unique_ptr<tflite::Interpreter> interpreter;
-  if (tflite::InterpreterBuilder(model, *resolver)(&interpreter) != kTfLiteOk) {
-    std::cerr << "Failed to build edge TPU interpreter." << std::endl;
-  }
+  CHECK_EQ(tflite::InterpreterBuilder(model, *resolver)(&interpreter),
+           kTfLiteOk);
   interpreter->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context);
-  interpreter->SetNumThreads(1);
-  if (interpreter->AllocateTensors() != kTfLiteOk) {
-    std::cerr << "Failed to allocate edge TPU tensors."
<< std::endl; - } return interpreter; } #endif // MEDIAPIPE_EDGE_TPU @@ -279,8 +289,7 @@ class TfLiteInferenceCalculator : public CalculatorBase { #endif // MEDIAPIPE_TFLITE_GL_INFERENCE #if defined(MEDIAPIPE_EDGE_TPU) - std::shared_ptr edgetpu_context_ = - edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice(); + std::shared_ptr edgetpu_context_; #endif bool gpu_inference_ = false; @@ -292,6 +301,8 @@ class TfLiteInferenceCalculator : public CalculatorBase { bool allow_precision_loss_ = false; mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::Api tflite_gpu_runner_api_; + mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::InferenceUsage + tflite_gpu_runner_usage_; bool use_kernel_caching_ = false; std::string cached_kernel_filename_; @@ -301,6 +312,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); // Calculator Core Section namespace { + +constexpr char kCustomOpResolverTag[] = "CUSTOM_OP_RESOLVER"; +constexpr char kModelTag[] = "MODEL"; + template bool ShouldUseGpu(CC* cc) { #if MEDIAPIPE_TFLITE_GPU_SUPPORTED @@ -325,7 +340,7 @@ absl::Status TfLiteInferenceCalculator::GetContract(CalculatorContract* cc) { const auto& options = cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>(); RET_CHECK(!options.model_path().empty() ^ - cc->InputSidePackets().HasTag("MODEL")) + cc->InputSidePackets().HasTag(kModelTag)) << "Either model as side packet or model path in options is required."; if (cc->Inputs().HasTag(kTensorsTag)) @@ -338,13 +353,13 @@ absl::Status TfLiteInferenceCalculator::GetContract(CalculatorContract* cc) { if (cc->Outputs().HasTag(kTensorsGpuTag)) cc->Outputs().Tag(kTensorsGpuTag).Set>(); - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { + if (cc->InputSidePackets().HasTag(kCustomOpResolverTag)) { cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") + .Tag(kCustomOpResolverTag) .Set(); } - if (cc->InputSidePackets().HasTag("MODEL")) { - cc->InputSidePackets().Tag("MODEL").Set(); + if (cc->InputSidePackets().HasTag(kModelTag)) { + cc->InputSidePackets().Tag(kModelTag).Set(); } if (ShouldUseGpu(cc)) { @@ -377,6 +392,7 @@ absl::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { options.delegate().gpu().use_advanced_gpu_api(); allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); tflite_gpu_runner_api_ = options.delegate().gpu().api(); + tflite_gpu_runner_usage_ = options.delegate().gpu().usage(); use_kernel_caching_ = use_advanced_gpu_api_ && options.delegate().gpu().has_cached_kernel_path(); @@ -483,8 +499,8 @@ absl::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { MP_RETURN_IF_ERROR(WriteKernelsToFile()); return RunInContextIfNeeded([this]() -> absl::Status { + interpreter_ = nullptr; if (delegate_) { - interpreter_ = nullptr; delegate_ = nullptr; #if MEDIAPIPE_TFLITE_GPU_SUPPORTED if (gpu_inference_) { @@ -498,7 +514,7 @@ absl::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { #endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED } #if defined(MEDIAPIPE_EDGE_TPU) - edgetpu_context_.reset(); + edgetpu_context_ = nullptr; #endif return absl::OkStatus(); }); @@ -720,9 +736,9 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( auto op_resolver_ptr = static_cast( &default_op_resolver); - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { + if (cc->InputSidePackets().HasTag(kCustomOpResolverTag)) { op_resolver_ptr = &(cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") + .Tag(kCustomOpResolverTag) .Get()); } @@ -733,7 +749,23 @@ absl::Status 
TfLiteInferenceCalculator::InitTFLiteGPURunner( : tflite::gpu::InferencePriority::MAX_PRECISION; options.priority2 = tflite::gpu::InferencePriority::AUTO; options.priority3 = tflite::gpu::InferencePriority::AUTO; - options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + switch (tflite_gpu_runner_usage_) { + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu:: + FAST_SINGLE_ANSWER: { + options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER; + break; + } + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu:: + SUSTAINED_SPEED: { + options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + break; + } + case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu:: + UNSPECIFIED: { + return absl::InternalError("inference usage need to be specified."); + } + } + tflite_gpu_runner_ = std::make_unique(options); switch (tflite_gpu_runner_api_) { case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::OPENGL: { @@ -806,21 +838,26 @@ absl::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) { tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates default_op_resolver; - auto op_resolver_ptr = - static_cast( - &default_op_resolver); - - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { - op_resolver_ptr = &(cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") - .Get()); - } - #if defined(MEDIAPIPE_EDGE_TPU) - interpreter_ = - BuildEdgeTpuInterpreter(model, op_resolver_ptr, edgetpu_context_.get()); -#else - tflite::InterpreterBuilder(model, *op_resolver_ptr)(&interpreter_); + if (ContainsEdgeTpuCustomOp(model)) { + edgetpu_context_ = edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice(); + interpreter_ = BuildEdgeTpuInterpreter(model, &default_op_resolver, + edgetpu_context_.get()); + } else { +#endif // MEDIAPIPE_EDGE_TPU + auto op_resolver_ptr = + static_cast( + &default_op_resolver); + + if (cc->InputSidePackets().HasTag(kCustomOpResolverTag)) { + op_resolver_ptr = &(cc->InputSidePackets() + .Tag(kCustomOpResolverTag) + .Get()); + } + + tflite::InterpreterBuilder(model, *op_resolver_ptr)(&interpreter_); +#if defined(MEDIAPIPE_EDGE_TPU) + } #endif // MEDIAPIPE_EDGE_TPU RET_CHECK(interpreter_); @@ -853,8 +890,8 @@ absl::StatusOr TfLiteInferenceCalculator::GetModelAsPacket( if (!options.model_path().empty()) { return TfLiteModelLoader::LoadFromPath(options.model_path()); } - if (cc.InputSidePackets().HasTag("MODEL")) { - return cc.InputSidePackets().Tag("MODEL"); + if (cc.InputSidePackets().HasTag(kModelTag)) { + return cc.InputSidePackets().Tag(kModelTag); } return absl::Status(absl::StatusCode::kNotFound, "Must specify TFLite model as path or loaded model."); @@ -878,11 +915,15 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { // Attempt to use NNAPI. // If not supported, the default CPU delegate will be created and used. interpreter_->SetAllowFp16PrecisionForFp32(1); - delegate_ = - TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { - // No need to free according to tflite::NnApiDelegate() - // documentation. - }); + tflite::StatefulNnApiDelegate::Options options; + const auto& nnapi = calculator_opts.delegate().nnapi(); + // Set up cache_dir and model_token for NNAPI compilation cache. 
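// NNAPI compilation caching only takes effect when both fields are provided,
// which is why the guard below checks has_cache_dir() && has_model_token();
// with either one missing, NNAPI recompiles the model on every process start.
// A reduced sketch of the underlying TFLite API this block drives; the header
// path, directory, and token below are assumptions for illustration only:
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"

TfLiteDelegate* MakeCachingNnApiDelegate() {
  tflite::StatefulNnApiDelegate::Options options;
  options.cache_dir = "/data/local/tmp/nnapi_cache";  // app-writable directory
  options.model_token = "face_detector_v1";           // stable per model build
  // The caller owns the returned delegate and must keep it alive until the
  // interpreter that was modified with it is destroyed.
  return new tflite::StatefulNnApiDelegate(options);
}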
+ if (nnapi.has_cache_dir() && nnapi.has_model_token()) { + options.cache_dir = nnapi.cache_dir().c_str(); + options.model_token = nnapi.model_token().c_str(); + } + delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options), + [](TfLiteDelegate*) {}); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); return absl::OkStatus(); @@ -906,6 +947,8 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { kTfLiteOk); return absl::OkStatus(); } +#else + (void)use_xnnpack; #endif // !EDGETPU // Return and use default tflite infernece (on CPU). No need for GPU diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.proto b/mediapipe/calculators/tflite/tflite_inference_calculator.proto index 02dc20831..3b4d2896e 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.proto @@ -67,9 +67,31 @@ message TfLiteInferenceCalculatorOptions { // Only available for OpenCL delegate on Android. // Kernel caching will only be enabled if this path is set. optional string cached_kernel_path = 2; + + // Encapsulated compilation/runtime tradeoffs. + enum InferenceUsage { + UNSPECIFIED = 0; + + // InferenceRunner will be used only once. Therefore, it is important to + // minimize bootstrap time as well. + FAST_SINGLE_ANSWER = 1; + + // Prefer maximizing the throughput. Same inference runner will be used + // repeatedly on different inputs. + SUSTAINED_SPEED = 2; + } + optional InferenceUsage usage = 5 [default = SUSTAINED_SPEED]; } // Android only. - message Nnapi {} + message Nnapi { + // Directory to store compilation cache. If unspecified, NNAPI will not + // try caching the compilation. + optional string cache_dir = 1; + // Unique token identifying the model. It is the caller's responsibility + // to ensure there is no clash of the tokens. If unspecified, NNAPI will + // not try caching the compilation. + optional string model_token = 2; + } message Xnnpack { // Number of threads for XNNPACK delegate. (By default, calculator tries // to choose optimal number of threads depending on the device.) diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto index ef494c2cc..41ad903de 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto @@ -26,12 +26,12 @@ message TfLiteTensorsToDetectionsCalculatorOptions { } // The number of output classes predicted by the detection model. - required int32 num_classes = 1; + optional int32 num_classes = 1; // required // The number of output boxes predicted by the detection model. - required int32 num_boxes = 2; + optional int32 num_boxes = 2; // required // The number of output values per boxes predicted by the detection model. The // values contain bounding boxes, keypoints, etc. - required int32 num_coords = 3; + optional int32 num_coords = 3; // required // The offset of keypoint coordinates in the location tensor. 
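// Throughout these .proto hunks, proto2 `required` fields are relaxed to
// `optional` with a trailing `// required` comment. Field numbers and the
// wire format are unchanged; what changes is that a missing field no longer
// makes the whole options message fail to parse, so presence becomes a
// documented expectation for the consuming calculator to verify. A hedged
// sketch of such a check; the helper name and its placement are illustrative
// and not part of the change:
#include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

absl::Status ValidateRequiredDetectionOptions(
    const mediapipe::TfLiteTensorsToDetectionsCalculatorOptions& options) {
  // Fail fast during graph setup instead of producing malformed detections.
  RET_CHECK(options.has_num_classes());
  RET_CHECK(options.has_num_boxes());
  RET_CHECK(options.has_num_coords());
  return absl::OkStatus();
}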
optional int32 keypoint_coord_offset = 9; diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc index ef2946c32..94cfaece8 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc @@ -18,6 +18,10 @@ namespace mediapipe { +constexpr char kFloatsTag[] = "FLOATS"; +constexpr char kFloatTag[] = "FLOAT"; +constexpr char kTensorsTag[] = "TENSORS"; + // A calculator for converting TFLite tensors to to a float or a float vector. // // Input: @@ -48,15 +52,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator); absl::Status TfLiteTensorsToFloatsCalculator::GetContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("TENSORS")); - RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT")); + RET_CHECK(cc->Inputs().HasTag(kTensorsTag)); + RET_CHECK(cc->Outputs().HasTag(kFloatsTag) || + cc->Outputs().HasTag(kFloatTag)); - cc->Inputs().Tag("TENSORS").Set>(); - if (cc->Outputs().HasTag("FLOATS")) { - cc->Outputs().Tag("FLOATS").Set>(); + cc->Inputs().Tag(kTensorsTag).Set>(); + if (cc->Outputs().HasTag(kFloatsTag)) { + cc->Outputs().Tag(kFloatsTag).Set>(); } - if (cc->Outputs().HasTag("FLOAT")) { - cc->Outputs().Tag("FLOAT").Set(); + if (cc->Outputs().HasTag(kFloatTag)) { + cc->Outputs().Tag(kFloatTag).Set(); } return absl::OkStatus(); @@ -69,10 +74,10 @@ absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) { } absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) { - RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty()); + RET_CHECK(!cc->Inputs().Tag(kTensorsTag).IsEmpty()); const auto& input_tensors = - cc->Inputs().Tag("TENSORS").Get>(); + cc->Inputs().Tag(kTensorsTag).Get>(); // TODO: Add option to specify which tensor to take from. const TfLiteTensor* raw_tensor = &input_tensors[0]; const float* raw_floats = raw_tensor->data.f; @@ -82,18 +87,19 @@ absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) { num_values *= raw_tensor->dims->data[i]; } - if (cc->Outputs().HasTag("FLOAT")) { + if (cc->Outputs().HasTag(kFloatTag)) { // TODO: Could add an index in the option to specifiy returning one // value of a float array. RET_CHECK_EQ(num_values, 1); - cc->Outputs().Tag("FLOAT").AddPacket( + cc->Outputs().Tag(kFloatTag).AddPacket( MakePacket(raw_floats[0]).At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("FLOATS")) { + if (cc->Outputs().HasTag(kFloatsTag)) { auto output_floats = absl::make_unique>( raw_floats, raw_floats + num_values); - cc->Outputs().Tag("FLOATS").Add(output_floats.release(), - cc->InputTimestamp()); + cc->Outputs() + .Tag(kFloatsTag) + .Add(output_floats.release(), cc->InputTimestamp()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto index 793639a53..ea3902473 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto @@ -31,7 +31,7 @@ message TfLiteTensorsToLandmarksCalculatorOptions { } // Number of landmarks from the output of the model. - required int32 num_landmarks = 1; + optional int32 num_landmarks = 1; // required // Size of the input image for the model. These options are used only when // normalized landmarks are needed. 
Z coordinate is scaled as X assuming diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto index d04aa562b..fa768efe7 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto @@ -24,9 +24,9 @@ message TfLiteTensorsToSegmentationCalculatorOptions { } // Dimensions of input segmentation tensor to process. - required int32 tensor_width = 1; - required int32 tensor_height = 2; - required int32 tensor_channels = 3; + optional int32 tensor_width = 1; // required + optional int32 tensor_height = 2; // required + optional int32 tensor_channels = 3; // required // How much to use previous mask when computing current one; range [0-1]. // This is a tradeoff between responsiveness (0.0) and accuracy (1.0). diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 869b4387e..dc90485a1 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -71,6 +71,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "filter_detections_calculator_proto", + srcs = ["filter_detections_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + mediapipe_proto_library( name = "timed_box_list_id_to_label_calculator_proto", srcs = ["timed_box_list_id_to_label_calculator.proto"], @@ -172,6 +182,21 @@ cc_test( ], ) +cc_test( + name = "filter_detections_calculator_test", + size = "small", + srcs = ["filter_detections_calculator_test.cc"], + deps = [ + ":filter_detections_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], @@ -386,6 +411,20 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "filter_detections_calculator", + srcs = ["filter_detections_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":filter_detections_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/memory", + ], + alwayslink = 1, +) + cc_library( name = "landmarks_to_detection_calculator", srcs = ["landmarks_to_detection_calculator.cc"], @@ -1144,6 +1183,7 @@ cc_library( "//mediapipe/framework:collection_item_id", "//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:status", + "//mediapipe/util:rectangle_util", "@com_google_absl//absl/memory", ], alwayslink = 1, @@ -1359,6 +1399,32 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "landmarks_refinement_calculator_proto", + srcs = ["landmarks_refinement_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "landmarks_refinement_calculator", + srcs = ["landmarks_refinement_calculator.cc"], + hdrs = ["landmarks_refinement_calculator.h"], + deps = [ + ":landmarks_refinement_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + 
"//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/memory", + ], + alwayslink = 1, +) + cc_test( name = "refine_landmarks_from_heatmap_calculator_test", srcs = ["refine_landmarks_from_heatmap_calculator_test.cc"], @@ -1367,3 +1433,34 @@ cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +cc_library( + name = "inverse_matrix_calculator", + srcs = ["inverse_matrix_calculator.cc"], + hdrs = ["inverse_matrix_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "@com_google_absl//absl/status", + "@eigen_archive//:eigen3", + ], + alwayslink = True, +) + +cc_test( + name = "inverse_matrix_calculator_test", + srcs = ["inverse_matrix_calculator_test.cc"], + tags = ["desktop_only_test"], + deps = [ + ":inverse_matrix_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/calculators/util/association_calculator.h b/mediapipe/calculators/util/association_calculator.h index 6e5b480ce..037ea838c 100644 --- a/mediapipe/calculators/util/association_calculator.h +++ b/mediapipe/calculators/util/association_calculator.h @@ -26,20 +26,10 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/rectangle.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/util/rectangle_util.h" namespace mediapipe { -// Computes the overlap similarity based on Intersection over Union (IoU) of -// two rectangles. -inline float OverlapSimilarity(const Rectangle_f& rect1, - const Rectangle_f& rect2) { - if (!rect1.Intersects(rect2)) return 0.0f; - // Compute IoU similarity score. - const float intersection_area = Rectangle_f(rect1).Intersect(rect2).Area(); - const float normalization = rect1.Area() + rect2.Area() - intersection_area; - return normalization > 0.0f ? intersection_area / normalization : 0.0f; -} - // AssocationCalculator accepts multiple inputs of vectors of type T that can // be converted to Rectangle_f. The output is a vector of type T that contains // elements from the input vectors that don't overlap with each other. 
When @@ -187,7 +177,7 @@ class AssociationCalculator : public CalculatorBase { for (auto uit = current->begin(); uit != current->end();) { ASSIGN_OR_RETURN(auto prev_rect, GetRectangle(*uit)); - if (OverlapSimilarity(cur_rect, prev_rect) > + if (CalculateIou(cur_rect, prev_rect) > options_.min_similarity_threshold()) { std::pair prev_id = GetId(*uit); // If prev_id.first is false when some element doesn't have an ID, @@ -232,7 +222,7 @@ class AssociationCalculator : public CalculatorBase { } const Rectangle_f& prev_rect = get_prev_rectangle.value(); - if (OverlapSimilarity(cur_rect, prev_rect) > + if (CalculateIou(cur_rect, prev_rect) > options_.min_similarity_threshold()) { std::pair prev_id = GetId(prev_input_vec[ui]); // If prev_id.first is false when some element doesn't have an ID, diff --git a/mediapipe/calculators/util/clock_timestamp_calculator.cc b/mediapipe/calculators/util/clock_timestamp_calculator.cc index 4ba56cfd0..324bc4ac7 100644 --- a/mediapipe/calculators/util/clock_timestamp_calculator.cc +++ b/mediapipe/calculators/util/clock_timestamp_calculator.cc @@ -87,7 +87,7 @@ absl::Status ClockTimestampCalculator::Open(CalculatorContext* cc) { // Initialize the clock. if (cc->InputSidePackets().HasTag(kClockTag)) { clock_ = cc->InputSidePackets() - .Tag("CLOCK") + .Tag(kClockTag) .Get>(); } else { clock_.reset( diff --git a/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc b/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc index 805ad495d..62eb1d8ae 100644 --- a/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc +++ b/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc @@ -27,6 +27,8 @@ namespace mediapipe { +constexpr char kIterableTag[] = "ITERABLE"; + typedef CollectionHasMinSizeCalculator> TestIntCollectionHasMinSizeCalculator; REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator); @@ -34,7 +36,7 @@ REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator); void AddInputVector(const std::vector& input, int64 timestamp, CalculatorRunner* runner) { runner->MutableInputs() - ->Tag("ITERABLE") + ->Tag(kIterableTag) .packets.push_back( MakePacket>(input).At(Timestamp(timestamp))); } diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc index 8f8025576..ca85a267e 100644 --- a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc @@ -144,7 +144,7 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase { } cc->Outputs() - .Tag("DETECTIONS") + .Tag(kDetectionsTag) .Add(output_detections.release(), cc->InputTimestamp()); return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc index 343ccea4f..c4f084363 100644 --- a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc @@ -25,6 +25,9 @@ namespace mediapipe { +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; +constexpr char kDetectionsTag[] = "DETECTIONS"; + LocationData CreateRelativeLocationData(double xmin, double ymin, double width, double height) { LocationData location_data; @@ -76,19 +79,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingLeftRight) { detections->push_back( CreateDetection({label}, 
{}, {0.3f}, location_data, "feature_tag")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); auto padding = absl::make_unique>( std::array{0.2f, 0.f, 0.3f, 0.f}); runner.MutableInputs() - ->Tag("LETTERBOX_PADDING") + ->Tag(kLetterboxPaddingTag) .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("DETECTIONS").packets; + runner.Outputs().Tag(kDetectionsTag).packets; ASSERT_EQ(1, output.size()); const auto& output_detections = output[0].Get>(); @@ -124,19 +127,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingTopBottom) { detections->push_back( CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); auto padding = absl::make_unique>( std::array{0.f, 0.2f, 0.f, 0.3f}); runner.MutableInputs() - ->Tag("LETTERBOX_PADDING") + ->Tag(kLetterboxPaddingTag) .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("DETECTIONS").packets; + runner.Outputs().Tag(kDetectionsTag).packets; ASSERT_EQ(1, output.size()); const auto& output_detections = output[0].Get>(); diff --git a/mediapipe/calculators/util/detection_projection_calculator_test.cc b/mediapipe/calculators/util/detection_projection_calculator_test.cc index 4cc85acee..176054e43 100644 --- a/mediapipe/calculators/util/detection_projection_calculator_test.cc +++ b/mediapipe/calculators/util/detection_projection_calculator_test.cc @@ -31,6 +31,9 @@ namespace mediapipe { namespace { +constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; +constexpr char kDetectionsTag[] = "DETECTIONS"; + using ::testing::ElementsAre; using ::testing::FloatNear; @@ -74,19 +77,19 @@ absl::StatusOr RunProjectionCalculator( )pb")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back(MakePacket>( std::vector({std::move(detection)})) .At(Timestamp::PostStream())); runner.MutableInputs() - ->Tag("PROJECTION_MATRIX") + ->Tag(kProjectionMatrixTag) .packets.push_back( MakePacket>(std::move(project_mat)) .At(Timestamp::PostStream())); MP_RETURN_IF_ERROR(runner.Run()); const std::vector& output = - runner.Outputs().Tag("DETECTIONS").packets; + runner.Outputs().Tag(kDetectionsTag).packets; RET_CHECK_EQ(output.size(), 1); const auto& output_detections = output[0].Get>(); diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index 3eae1af9d..a45048d40 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -32,6 +32,14 @@ namespace mediapipe { namespace { +constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr char kRectsTag[] = "RECTS"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kRectTag[] = "RECT"; +constexpr char kDetectionTag[] = "DETECTION"; + MATCHER_P4(RectEq, x_center, y_center, width, height, "") { return testing::Value(arg.x_center(), testing::Eq(x_center)) && testing::Value(arg.y_center(), 
testing::Eq(y_center)) && @@ -94,12 +102,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRect) { DetectionWithLocationData(100, 200, 300, 400)); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back( Adopt(detection.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("RECT").packets; + const std::vector& output = runner.Outputs().Tag(kRectTag).packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); @@ -120,16 +128,16 @@ absl::StatusOr RunDetectionKeyPointsToRectCalculation( )pb")); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back(MakePacket(std::move(detection)) .At(Timestamp::PostStream())); runner.MutableInputs() - ->Tag("IMAGE_SIZE") + ->Tag(kImageSizeTag) .packets.push_back(MakePacket>(image_size) .At(Timestamp::PostStream())); MP_RETURN_IF_ERROR(runner.Run()); - const std::vector& output = runner.Outputs().Tag("RECT").packets; + const std::vector& output = runner.Outputs().Tag(kRectTag).packets; RET_CHECK_EQ(output.size(), 1); return output[0].Get(); } @@ -176,12 +184,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) { DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back( Adopt(detection.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + const std::vector& output = + runner.Outputs().Tag(kNormRectTag).packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); @@ -201,12 +210,13 @@ absl::StatusOr RunDetectionKeyPointsToNormRectCalculation( )pb")); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back(MakePacket(std::move(detection)) .At(Timestamp::PostStream())); MP_RETURN_IF_ERROR(runner.Run()); - const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + const std::vector& output = + runner.Outputs().Tag(kNormRectTag).packets; RET_CHECK_EQ(output.size(), 1); return output[0].Get(); } @@ -248,12 +258,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) { detections->push_back(DetectionWithLocationData(200, 300, 400, 500)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("RECT").packets; + const std::vector& output = runner.Outputs().Tag(kRectTag).packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); @@ -271,12 +281,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) { detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + const std::vector& output = + runner.Outputs().Tag(kNormRectTag).packets; ASSERT_EQ(1, output.size()); const auto& rect = output[0].Get(); EXPECT_THAT(rect, NormRectEq(0.25f, 
0.4f, 0.3f, 0.4f)); @@ -294,12 +305,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) { detections->push_back(DetectionWithLocationData(200, 300, 400, 500)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("RECTS").packets; + const std::vector& output = runner.Outputs().Tag(kRectsTag).packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); ASSERT_EQ(rects.size(), 2); @@ -319,13 +330,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) { detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("NORM_RECTS").packets; + runner.Outputs().Tag(kNormRectsTag).packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); ASSERT_EQ(rects.size(), 2); @@ -344,12 +355,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRects) { DetectionWithLocationData(100, 200, 300, 400)); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back( Adopt(detection.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("RECTS").packets; + const std::vector& output = runner.Outputs().Tag(kRectsTag).packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); EXPECT_EQ(rects.size(), 1); @@ -367,13 +378,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) { DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); runner.MutableInputs() - ->Tag("DETECTION") + ->Tag(kDetectionTag) .packets.push_back( Adopt(detection.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("NORM_RECTS").packets; + runner.Outputs().Tag(kNormRectsTag).packets; ASSERT_EQ(1, output.size()); const auto& rects = output[0].Get>(); ASSERT_EQ(rects.size(), 1); @@ -391,7 +402,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToRect) { detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); @@ -411,7 +422,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToNormalizedRect) { detections->push_back(DetectionWithLocationData(100, 200, 300, 400)); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc index ea4bfc484..04d8b5bcd 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc @@ -30,6 +30,10 @@ namespace mediapipe { +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kRenderDataTag[] = "RENDER_DATA"; +constexpr char kDetectionListTag[] = "DETECTION_LIST"; + using ::testing::DoubleNear; // Error tolerance for 
pixels, distances, etc. @@ -97,13 +101,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionList) { CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"); runner.MutableInputs() - ->Tag("DETECTION_LIST") + ->Tag(kDetectionListTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("RENDER_DATA").packets; + runner.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(1, output.size()); const auto& actual = output[0].Get(); EXPECT_EQ(actual.render_annotations_size(), 3); @@ -131,13 +135,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) { CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = - runner.Outputs().Tag("RENDER_DATA").packets; + runner.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(1, output.size()); const auto& actual = output[0].Get(); EXPECT_EQ(actual.render_annotations_size(), 3); @@ -165,7 +169,7 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) { *(detection_list->add_detection()) = CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1"); runner.MutableInputs() - ->Tag("DETECTION_LIST") + ->Tag(kDetectionListTag) .packets.push_back( Adopt(detection_list.release()).At(Timestamp::PostStream())); @@ -174,13 +178,13 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) { detections->push_back( CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2")); runner.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& actual = - runner.Outputs().Tag("RENDER_DATA").packets; + runner.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(1, actual.size()); // Check the feature tag for item from detection list. EXPECT_EQ( @@ -209,19 +213,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { auto detection_list1(absl::make_unique()); runner1.MutableInputs() - ->Tag("DETECTION_LIST") + ->Tag(kDetectionListTag) .packets.push_back( Adopt(detection_list1.release()).At(Timestamp::PostStream())); auto detections1(absl::make_unique>()); runner1.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections1.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner1.Run()) << "Calculator execution failed."; const std::vector& exact1 = - runner1.Outputs().Tag("RENDER_DATA").packets; + runner1.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(0, exact1.size()); // Check when produce_empty_packet is true. 
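The string-tag churn in these test hunks follows one mechanical pattern applied across the whole patch: each stream tag is hoisted into a file-level constexpr constant and every Tag(...) call references the constant instead of repeating the literal, so a mistyped tag becomes a compile error rather than a silently mismatched stream. A minimal sketch of the pattern, not taken from the patch (ExampleCalculator and kExampleTag are illustrative names):

#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

// Declared once per file; every Tag() call below refers to this constant.
constexpr char kExampleTag[] = "EXAMPLE";

class ExampleCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag(kExampleTag).Set<int>();
    cc->Outputs().Tag(kExampleTag).Set<int>();
    return absl::OkStatus();
  }
  absl::Status Process(CalculatorContext* cc) override {
    // Forward the input packet unchanged; the point is the shared tag constant.
    cc->Outputs()
        .Tag(kExampleTag)
        .AddPacket(cc->Inputs().Tag(kExampleTag).Value());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(ExampleCalculator);

}  // namespace mediapipe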
@@ -240,19 +244,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { auto detection_list2(absl::make_unique()); runner2.MutableInputs() - ->Tag("DETECTION_LIST") + ->Tag(kDetectionListTag) .packets.push_back( Adopt(detection_list2.release()).At(Timestamp::PostStream())); auto detections2(absl::make_unique>()); runner2.MutableInputs() - ->Tag("DETECTIONS") + ->Tag(kDetectionsTag) .packets.push_back( Adopt(detections2.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner2.Run()) << "Calculator execution failed."; const std::vector& exact2 = - runner2.Outputs().Tag("RENDER_DATA").packets; + runner2.Outputs().Tag(kRenderDataTag).packets; ASSERT_EQ(1, exact2.size()); EXPECT_EQ(exact2[0].Get().render_annotations_size(), 0); } diff --git a/mediapipe/calculators/util/filter_detections_calculator.cc b/mediapipe/calculators/util/filter_detections_calculator.cc new file mode 100644 index 000000000..a1f23ba83 --- /dev/null +++ b/mediapipe/calculators/util/filter_detections_calculator.cc @@ -0,0 +1,81 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/util/filter_detections_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +const char kInputDetectionsTag[] = "INPUT_DETECTIONS"; +const char kOutputDetectionsTag[] = "OUTPUT_DETECTIONS"; + +// +// Calculator to filter out detections that do not meet the criteria specified +// in options. +// +class FilterDetectionsCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kInputDetectionsTag)); + RET_CHECK(cc->Outputs().HasTag(kOutputDetectionsTag)); + + cc->Inputs().Tag(kInputDetectionsTag).Set>(); + cc->Outputs().Tag(kOutputDetectionsTag).Set>(); + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + options_ = cc->Options(); + + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + const auto& input_detections = + cc->Inputs().Tag(kInputDetectionsTag).Get>(); + + auto output_detections = absl::make_unique>(); + + for (const Detection& detection : input_detections) { + RET_CHECK_GT(detection.score_size(), 0); + // Note: only score at index 0 supported. 
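+  // A detection with an empty score list trips the RET_CHECK above and fails
+  // Process(); all other detections are kept only if score(0) reaches
+  // min_score.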
+ if (detection.score(0) >= options_.min_score()) { + output_detections->push_back(detection); + } + } + + cc->Outputs() + .Tag(kOutputDetectionsTag) + .Add(output_detections.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } + + private: + mediapipe::FilterDetectionsCalculatorOptions options_; +}; + +REGISTER_CALCULATOR(FilterDetectionsCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/filter_detections_calculator.proto b/mediapipe/calculators/util/filter_detections_calculator.proto new file mode 100644 index 000000000..e16898c79 --- /dev/null +++ b/mediapipe/calculators/util/filter_detections_calculator.proto @@ -0,0 +1,28 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message FilterDetectionsCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional FilterDetectionsCalculatorOptions ext = 395478132; + } + + // Detections lower than this score get filtered out. + optional float min_score = 1; +} diff --git a/mediapipe/calculators/util/filter_detections_calculator_test.cc b/mediapipe/calculators/util/filter_detections_calculator_test.cc new file mode 100644 index 000000000..515a8b7df --- /dev/null +++ b/mediapipe/calculators/util/filter_detections_calculator_test.cc @@ -0,0 +1,100 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
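Taken together, the new calculator and its options proto are driven entirely from the graph config. A minimal sketch of how a graph might instantiate it (stream names are illustrative; the calculator name, the INPUT_DETECTIONS/OUTPUT_DETECTIONS tags, and the min_score field come from the files above):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Builds a one-node graph that drops detections whose top score is below 0.5.
CalculatorGraphConfig MakeFilterDetectionsGraph() {
  return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
    input_stream: "input_detections"
    output_stream: "filtered_detections"
    node {
      calculator: "FilterDetectionsCalculator"
      input_stream: "INPUT_DETECTIONS:input_detections"
      output_stream: "OUTPUT_DETECTIONS:filtered_detections"
      options {
        [mediapipe.FilterDetectionsCalculatorOptions.ext] { min_score: 0.5 }
      }
    }
  )pb");
}

}  // namespace mediapipe

The test file that follows exercises the same node through CalculatorRunner with an equivalent single-node config.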
+ +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::testing::ElementsAre; + +absl::Status RunGraph(std::vector& input_detections, + std::vector* output_detections) { + CalculatorRunner runner(R"pb( + calculator: "FilterDetectionsCalculator" + input_stream: "INPUT_DETECTIONS:input_detections" + output_stream: "OUTPUT_DETECTIONS:output_detections" + options { + [mediapipe.FilterDetectionsCalculatorOptions.ext] { min_score: 0.5 } + } + )pb"); + + const Timestamp input_timestamp = Timestamp(0); + runner.MutableInputs() + ->Tag("INPUT_DETECTIONS") + .packets.push_back(MakePacket>(input_detections) + .At(input_timestamp)); + MP_RETURN_IF_ERROR(runner.Run()) << "Calculator run failed."; + + const std::vector& output_packets = + runner.Outputs().Tag("OUTPUT_DETECTIONS").packets; + RET_CHECK_EQ(output_packets.size(), 1); + + *output_detections = output_packets[0].Get>(); + return absl::OkStatus(); +} + +TEST(FilterDetectionsCalculatorTest, TestFilterDetections) { + std::vector input_detections; + Detection d1, d2; + d1.add_score(0.2); + d2.add_score(0.8); + input_detections.push_back(d1); + input_detections.push_back(d2); + + std::vector output_detections; + MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + + EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d2))); +} + +TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMultiple) { + std::vector input_detections; + Detection d1, d2, d3, d4; + d1.add_score(0.3); + d2.add_score(0.4); + d3.add_score(0.5); + d4.add_score(0.6); + input_detections.push_back(d1); + input_detections.push_back(d2); + input_detections.push_back(d3); + input_detections.push_back(d4); + + std::vector output_detections; + MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + + EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d3), + mediapipe::EqualsProto(d4))); +} + +TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsEmpty) { + std::vector input_detections; + + std::vector output_detections; + MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + + EXPECT_EQ(output_detections.size(), 0); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/from_image_calculator.cc b/mediapipe/calculators/util/from_image_calculator.cc index 7484d9257..0ddb342eb 100644 --- a/mediapipe/calculators/util/from_image_calculator.cc +++ b/mediapipe/calculators/util/from_image_calculator.cc @@ -33,6 +33,7 @@ namespace { constexpr char kImageFrameTag[] = "IMAGE_CPU"; constexpr char kGpuBufferTag[] = "IMAGE_GPU"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kSourceOnGpuTag[] = "SOURCE_ON_GPU"; } // namespace // A calculator for converting the unified image container into @@ -46,6 +47,8 @@ constexpr char kImageTag[] = "IMAGE"; // IMAGE_CPU: An ImageFrame containing output image. // IMAGE_GPU: A GpuBuffer containing output image. // +// SOURCE_ON_GPU: The source Image is stored on GPU or CPU. +// // Note: // Data is automatically transferred to/from the CPU or GPU // depending on output type. 
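The new SOURCE_ON_GPU output is a per-packet bool derived from Image::UsesGpu(), so a graph can branch on where the unified Image actually lives. A sketch of a node using it alongside the existing CPU output (stream names are illustrative; the calculator name and tags come from the file above):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// FromImageCalculator node that unpacks an Image to an ImageFrame and also
// reports whether the source was GPU-backed, e.g. to drive a downstream switch.
CalculatorGraphConfig::Node MakeFromImageNode() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "FromImageCalculator"
    input_stream: "IMAGE:image"
    output_stream: "IMAGE_CPU:image_frame"
    output_stream: "SOURCE_ON_GPU:is_gpu_frame"
  )pb");
}

}  // namespace mediapipe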
@@ -66,6 +69,7 @@ class FromImageCalculator : public CalculatorBase { absl::Status RenderGpu(CalculatorContext* cc); absl::Status RenderCpu(CalculatorContext* cc); + bool check_image_source_ = false; bool gpu_output_ = false; bool gpu_initialized_ = false; #if !MEDIAPIPE_DISABLE_GPU @@ -102,6 +106,9 @@ absl::Status FromImageCalculator::GetContract(CalculatorContract* cc) { #endif // !MEDIAPIPE_DISABLE_GPU } + if (cc->Outputs().HasTag(kSourceOnGpuTag)) { + cc->Outputs().Tag(kSourceOnGpuTag).Set(); + } return absl::OkStatus(); } @@ -111,7 +118,9 @@ absl::Status FromImageCalculator::Open(CalculatorContext* cc) { if (cc->Outputs().HasTag(kGpuBufferTag)) { gpu_output_ = true; } - + if (cc->Outputs().HasTag(kSourceOnGpuTag)) { + check_image_source_ = true; + } if (gpu_output_) { #if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); @@ -122,6 +131,13 @@ absl::Status FromImageCalculator::Open(CalculatorContext* cc) { } absl::Status FromImageCalculator::Process(CalculatorContext* cc) { + if (check_image_source_) { + auto& input = cc->Inputs().Tag(kImageTag).Get(); + cc->Outputs() + .Tag(kSourceOnGpuTag) + .AddPacket(MakePacket(input.UsesGpu()).At(cc->InputTimestamp())); + } + if (gpu_output_) { #if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([&cc]() -> absl::Status { diff --git a/mediapipe/calculators/util/inverse_matrix_calculator.cc b/mediapipe/calculators/util/inverse_matrix_calculator.cc new file mode 100644 index 000000000..5809623c0 --- /dev/null +++ b/mediapipe/calculators/util/inverse_matrix_calculator.cc @@ -0,0 +1,50 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/util/inverse_matrix_calculator.h" + +#include "Eigen/Core" +#include "Eigen/Geometry" +#include "Eigen/LU" +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" + +namespace mediapipe { +namespace api2 { + +class InverseMatrixCalculatorImpl : public NodeImpl { + absl::Status Process(mediapipe::CalculatorContext* cc) override { + if (kInputMatrix(cc).IsEmpty()) { + return absl::OkStatus(); + } + Eigen::Matrix matrix( + kInputMatrix(cc).Get().data()); + + Eigen::Matrix inverse_matrix; + bool inverse_check; + matrix.computeInverseWithCheck(inverse_matrix, inverse_check); + RET_CHECK(inverse_check) << "Inverse matrix cannot be calculated."; + + std::array output; + Eigen::Map>( + output.data(), 4, 4) = inverse_matrix.matrix(); + kOutputMatrix(cc).Send(std::move(output)); + return absl::OkStatus(); + } +}; +MEDIAPIPE_NODE_IMPLEMENTATION(InverseMatrixCalculatorImpl); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/util/inverse_matrix_calculator.h b/mediapipe/calculators/util/inverse_matrix_calculator.h new file mode 100644 index 000000000..ba1657348 --- /dev/null +++ b/mediapipe/calculators/util/inverse_matrix_calculator.h @@ -0,0 +1,51 @@ +// Copyright 2021 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_UTIL_INVERSE_MATRIX_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_UTIL_INVERSE_MATRIX_CALCULATOR_H_ + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" + +namespace mediapipe { + +// Runs affine transformation. +// +// Input: +// MATRIX - std::array +// Row major 4x4 matrix to inverse. +// +// Output: +// MATRIX - std::array +// Row major 4x4 inversed matrix. +// +// Usage example: +// node { +// calculator: "dishti.aimatter.InverseMatrixCalculator" +// input_stream: "MATRIX:input_matrix" +// output_stream: "MATRIX:output_matrix" +// } +class InverseMatrixCalculator : public mediapipe::api2::NodeIntf { + public: + static constexpr mediapipe::api2::Input> kInputMatrix{ + "MATRIX"}; + static constexpr mediapipe::api2::Output> kOutputMatrix{ + "MATRIX"}; + MEDIAPIPE_NODE_INTERFACE(InverseMatrixCalculator, kInputMatrix, + kOutputMatrix); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_UTIL_INVERSE_MATRIX_CALCULATOR_H_ diff --git a/mediapipe/calculators/util/inverse_matrix_calculator_test.cc b/mediapipe/calculators/util/inverse_matrix_calculator_test.cc new file mode 100644 index 000000000..d3b629c78 --- /dev/null +++ b/mediapipe/calculators/util/inverse_matrix_calculator_test.cc @@ -0,0 +1,126 @@ +#include "mediapipe/calculators/util/inverse_matrix_calculator.h" + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +void RunTest(const std::array& matrix, + const std::array& expected_inverse_matrix) { + auto graph_config = mediapipe::ParseTextProtoOrDie( + R"pb( + input_stream: "matrix" + node { + calculator: "InverseMatrixCalculator" + input_stream: "MATRIX:matrix" + output_stream: "MATRIX:inverse_matrix" + } + )pb"); + + std::vector output_packets; + tool::AddVectorSink("inverse_matrix", &graph_config, &output_packets); + + // Run the graph. + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "matrix", + MakePacket>(std::move(matrix)).At(Timestamp(0)))); + + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_THAT(output_packets, testing::SizeIs(1)); + + const auto& inverse_matrix = output_packets[0].Get>(); + + EXPECT_THAT(inverse_matrix, testing::Eq(expected_inverse_matrix)); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). 
+ MP_ASSERT_OK(graph.CloseInputStream("matrix")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST(InverseMatrixCalculatorTest, Identity) { + // clang-format off + std::array matrix = { + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + std::array expected_inverse_matrix = { + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + RunTest(matrix, expected_inverse_matrix); +} + +TEST(InverseMatrixCalculatorTest, Translation) { + // clang-format off + std::array matrix = { + 1.0f, 0.0f, 0.0f, 2.0f, + 0.0f, 1.0f, 0.0f, -5.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + std::array expected_inverse_matrix = { + 1.0f, 0.0f, 0.0f, -2.0f, + 0.0f, 1.0f, 0.0f, 5.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + RunTest(matrix, expected_inverse_matrix); +} + +TEST(InverseMatrixCalculatorTest, Scale) { + // clang-format off + std::array matrix = { + 5.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + std::array expected_inverse_matrix = { + 0.2f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.5f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + RunTest(matrix, expected_inverse_matrix); +} + +TEST(InverseMatrixCalculatorTest, Rotation90) { + // clang-format off + std::array matrix = { + 0.0f, -1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + std::array expected_inverse_matrix = { + 0.0f, 1.0f, 0.0f, 0.0f, + -1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + }; + // clang-format on + RunTest(matrix, expected_inverse_matrix); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index 099bdc7e6..4aab3b676 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -32,6 +32,12 @@ namespace mediapipe { +constexpr char kRenderDataTag[] = "RENDER_DATA"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kScoresTag[] = "SCORES"; +constexpr char kLabelsTag[] = "LABELS"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; + constexpr float kFontHeightScale = 1.25f; // A calculator takes in pairs of labels and scores or classifications, outputs @@ -74,20 +80,20 @@ class LabelsToRenderDataCalculator : public CalculatorBase { REGISTER_CALCULATOR(LabelsToRenderDataCalculator); absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag("CLASSIFICATIONS")) { - cc->Inputs().Tag("CLASSIFICATIONS").Set(); + if (cc->Inputs().HasTag(kClassificationsTag)) { + cc->Inputs().Tag(kClassificationsTag).Set(); } else { - RET_CHECK(cc->Inputs().HasTag("LABELS")) + RET_CHECK(cc->Inputs().HasTag(kLabelsTag)) << "Must provide input stream \"LABELS\""; - cc->Inputs().Tag("LABELS").Set>(); - if (cc->Inputs().HasTag("SCORES")) { - cc->Inputs().Tag("SCORES").Set>(); + cc->Inputs().Tag(kLabelsTag).Set>(); + if (cc->Inputs().HasTag(kScoresTag)) { + cc->Inputs().Tag(kScoresTag).Set>(); } } - if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { - cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); + if (cc->Inputs().HasTag(kVideoPrestreamTag)) { + cc->Inputs().Tag(kVideoPrestreamTag).Set(); } - 
cc->Outputs().Tag("RENDER_DATA").Set(); + cc->Outputs().Tag(kRenderDataTag).Set(); return absl::OkStatus(); } @@ -100,10 +106,10 @@ absl::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) { } absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { - if (cc->Inputs().HasTag("VIDEO_PRESTREAM") && + if (cc->Inputs().HasTag(kVideoPrestreamTag) && cc->InputTimestamp() == Timestamp::PreStream()) { const VideoHeader& video_header = - cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); + cc->Inputs().Tag(kVideoPrestreamTag).Get(); video_width_ = video_header.width; video_height_ = video_header.height; return absl::OkStatus(); @@ -114,9 +120,9 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { std::vector labels; std::vector scores; - if (cc->Inputs().HasTag("CLASSIFICATIONS")) { + if (cc->Inputs().HasTag(kClassificationsTag)) { const ClassificationList& classifications = - cc->Inputs().Tag("CLASSIFICATIONS").Get(); + cc->Inputs().Tag(kClassificationsTag).Get(); labels.resize(classifications.classification_size()); scores.resize(classifications.classification_size()); for (int i = 0; i < classifications.classification_size(); ++i) { @@ -129,15 +135,15 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { } } else { const std::vector& label_vector = - cc->Inputs().Tag("LABELS").Get>(); + cc->Inputs().Tag(kLabelsTag).Get>(); labels.resize(label_vector.size()); for (int i = 0; i < label_vector.size(); ++i) { labels[i] = label_vector[i]; } - if (cc->Inputs().HasTag("SCORES")) { + if (cc->Inputs().HasTag(kScoresTag)) { std::vector score_vector = - cc->Inputs().Tag("SCORES").Get>(); + cc->Inputs().Tag(kScoresTag).Get>(); CHECK_EQ(label_vector.size(), score_vector.size()); scores.resize(label_vector.size()); for (int i = 0; i < label_vector.size(); ++i) { @@ -169,7 +175,8 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { auto* text = label_annotation->mutable_text(); std::string display_text = labels[i]; - if (cc->Inputs().HasTag("SCORES")) { + if (cc->Inputs().HasTag(kScoresTag) || + options_.display_classification_score()) { absl::StrAppend(&display_text, ":", scores[i]); } text->set_display_text(display_text); @@ -179,7 +186,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { text->set_font_face(options_.font_face()); } cc->Outputs() - .Tag("RENDER_DATA") + .Tag(kRenderDataTag) .AddPacket(MakePacket(render_data).At(cc->InputTimestamp())); return absl::OkStatus(); diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.proto b/mediapipe/calculators/util/labels_to_render_data_calculator.proto index c5012ce85..cf0ada9c2 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.proto @@ -62,4 +62,7 @@ message LabelsToRenderDataCalculatorOptions { // Uses Classification.display_name field instead of Classification.label. optional bool use_display_name = 9 [default = false]; + + // Displays Classification score if enabled. 
+ optional bool display_classification_score = 10 [default = false]; } diff --git a/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc b/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc index 556d5673d..05827220e 100644 --- a/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc +++ b/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc @@ -24,6 +24,9 @@ namespace mediapipe { +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; +constexpr char kLandmarksTag[] = "LANDMARKS"; + NormalizedLandmark CreateLandmark(float x, float y) { NormalizedLandmark landmark; landmark.set_x(x); @@ -48,18 +51,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingLeftRight) { *landmarks->add_landmark() = CreateLandmark(0.2f, 0.2f); *landmarks->add_landmark() = CreateLandmark(0.7f, 0.7f); runner.MutableInputs() - ->Tag("LANDMARKS") + ->Tag(kLandmarksTag) .packets.push_back( Adopt(landmarks.release()).At(Timestamp::PostStream())); auto padding = absl::make_unique>( std::array{0.2f, 0.f, 0.3f, 0.f}); runner.MutableInputs() - ->Tag("LETTERBOX_PADDING") + ->Tag(kLetterboxPaddingTag) .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("LANDMARKS").packets; + const std::vector& output = + runner.Outputs().Tag(kLandmarksTag).packets; ASSERT_EQ(1, output.size()); const auto& output_landmarks = output[0].Get(); @@ -84,18 +88,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingTopBottom) { landmark = landmarks->add_landmark(); *landmark = CreateLandmark(0.7f, 0.7f); runner.MutableInputs() - ->Tag("LANDMARKS") + ->Tag(kLandmarksTag) .packets.push_back( Adopt(landmarks.release()).At(Timestamp::PostStream())); auto padding = absl::make_unique>( std::array{0.0f, 0.2f, 0.0f, 0.3f}); runner.MutableInputs() - ->Tag("LETTERBOX_PADDING") + ->Tag(kLetterboxPaddingTag) .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - const std::vector& output = runner.Outputs().Tag("LANDMARKS").packets; + const std::vector& output = + runner.Outputs().Tag(kLandmarksTag).packets; ASSERT_EQ(1, output.size()); const auto& output_landmarks = output[0].Get(); diff --git a/mediapipe/calculators/util/landmark_projection_calculator_test.cc b/mediapipe/calculators/util/landmark_projection_calculator_test.cc index b15bb0f0c..2e919c30e 100644 --- a/mediapipe/calculators/util/landmark_projection_calculator_test.cc +++ b/mediapipe/calculators/util/landmark_projection_calculator_test.cc @@ -16,6 +16,10 @@ namespace mediapipe { namespace { +constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; + absl::StatusOr RunCalculator( mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) { mediapipe::CalculatorRunner runner( @@ -26,17 +30,17 @@ absl::StatusOr RunCalculator( output_stream: "NORM_LANDMARKS:projected_landmarks" )pb")); runner.MutableInputs() - ->Tag("NORM_LANDMARKS") + ->Tag(kNormLandmarksTag) .packets.push_back( MakePacket(std::move(input)) .At(Timestamp(1))); runner.MutableInputs() - ->Tag("NORM_RECT") + ->Tag(kNormRectTag) .packets.push_back(MakePacket(std::move(rect)) .At(Timestamp(1))); MP_RETURN_IF_ERROR(runner.Run()); - const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets; + const 
auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets; RET_CHECK_EQ(output_packets.size(), 1); return output_packets[0].Get(); } @@ -104,17 +108,17 @@ absl::StatusOr RunCalculator( output_stream: "NORM_LANDMARKS:projected_landmarks" )pb")); runner.MutableInputs() - ->Tag("NORM_LANDMARKS") + ->Tag(kNormLandmarksTag) .packets.push_back( MakePacket(std::move(input)) .At(Timestamp(1))); runner.MutableInputs() - ->Tag("PROJECTION_MATRIX") + ->Tag(kProjectionMatrixTag) .packets.push_back(MakePacket>(std::move(matrix)) .At(Timestamp(1))); MP_RETURN_IF_ERROR(runner.Run()); - const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets; + const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets; RET_CHECK_EQ(output_packets.size(), 1); return output_packets[0].Get(); } diff --git a/mediapipe/calculators/util/landmarks_refinement_calculator.cc b/mediapipe/calculators/util/landmarks_refinement_calculator.cc new file mode 100644 index 000000000..8f734ac88 --- /dev/null +++ b/mediapipe/calculators/util/landmarks_refinement_calculator.cc @@ -0,0 +1,197 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/util/landmarks_refinement_calculator.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/util/landmarks_refinement_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace api2 { + +namespace { + +absl::StatusOr GetNumberOfRefinedLandmarks( + const proto_ns::RepeatedPtrField< + LandmarksRefinementCalculatorOptions::Refinement>& refinements) { + // Gather all used indexes. + std::set idxs; + for (int i = 0; i < refinements.size(); ++i) { + const auto& refinement = refinements.Get(i); + for (int i = 0; i < refinement.indexes_mapping_size(); ++i) { + idxs.insert(refinement.indexes_mapping(i)); + } + } + + // Check that indxes start with 0 and there is no gaps between min and max + // indexes. 
+ RET_CHECK(!idxs.empty()) + << "There should be at least one landmark in indexes mapping"; + int idxs_min = *idxs.begin(); + int idxs_max = *idxs.rbegin(); + int n_idxs = idxs.size(); + RET_CHECK_EQ(idxs_min, 0) + << "Indexes are expected to start with 0 instead of " << idxs_min; + RET_CHECK_EQ(idxs_max, n_idxs - 1) + << "Indexes should have no gaps but " << idxs_max - n_idxs + 1 + << " indexes are missing"; + + return n_idxs; +} + +void RefineXY(const proto_ns::RepeatedField& indexes_mapping, + const NormalizedLandmarkList& landmarks, + NormalizedLandmarkList* refined_landmarks) { + for (int i = 0; i < landmarks.landmark_size(); ++i) { + const auto& landmark = landmarks.landmark(i); + auto* refined_landmark = + refined_landmarks->mutable_landmark(indexes_mapping.Get(i)); + refined_landmark->set_x(landmark.x()); + refined_landmark->set_y(landmark.y()); + } +} + +float GetZAverage(const NormalizedLandmarkList& landmarks, + const proto_ns::RepeatedField& indexes) { + double z_sum = 0; + for (int i = 0; i < indexes.size(); ++i) { + z_sum += landmarks.landmark(indexes.Get(i)).z(); + } + return z_sum / indexes.size(); +} + +void RefineZ( + const proto_ns::RepeatedField& indexes_mapping, + const LandmarksRefinementCalculatorOptions::ZRefinement& z_refinement, + const NormalizedLandmarkList& landmarks, + NormalizedLandmarkList* refined_landmarks) { + if (z_refinement.has_none()) { + // Do nothing and keep Z that is already in refined landmarks. + } else if (z_refinement.has_copy()) { + for (int i = 0; i < landmarks.landmark_size(); ++i) { + refined_landmarks->mutable_landmark(indexes_mapping.Get(i)) + ->set_z(landmarks.landmark(i).z()); + } + } else if (z_refinement.has_assign_average()) { + const float z_average = + GetZAverage(*refined_landmarks, + z_refinement.assign_average().indexes_for_average()); + for (int i = 0; i < indexes_mapping.size(); ++i) { + refined_landmarks->mutable_landmark(indexes_mapping.Get(i)) + ->set_z(z_average); + } + } else { + CHECK(false) << "Z refinement is either not specified or not supported"; + } +} + +} // namespace + +class LandmarksRefinementCalculatorImpl + : public NodeImpl { + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + + // Validate refinements. + for (int i = 0; i < options_.refinement_size(); ++i) { + const auto& refinement = options_.refinement(i); + RET_CHECK_GT(refinement.indexes_mapping_size(), 0) + << "Refinement " << i << " has no indexes mapping"; + RET_CHECK(refinement.has_z_refinement()) + << "Refinement " << i << " has no Z refinement specified"; + RET_CHECK(refinement.z_refinement().has_none() ^ + refinement.z_refinement().has_copy() ^ + refinement.z_refinement().has_assign_average()) + << "Exactly one Z refinement should be specified"; + + const auto z_refinement = refinement.z_refinement(); + if (z_refinement.has_assign_average()) { + RET_CHECK_GT(z_refinement.assign_average().indexes_for_average_size(), + 0) + << "When using assign average Z refinement at least one index for " + "averagin should be specified"; + } + } + + // Validate indexes mapping and get total number of refined landmarks. + ASSIGN_OR_RETURN(n_refined_landmarks_, + GetNumberOfRefinedLandmarks(options_.refinement())); + + // Validate that number of refinements and landmark streams is the same. 
+ RET_CHECK_EQ(kLandmarks(cc).Count(), options_.refinement_size()) + << "There are " << options_.refinement_size() << " refinements while " + << kLandmarks(cc).Count() << " landmark streams"; + + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + // If any of the refinement landmarks is missing - refinement won't happen. + for (const auto& landmarks_stream : kLandmarks(cc)) { + if (landmarks_stream.IsEmpty()) { + return absl::OkStatus(); + } + } + + // Initialize refined landmarks list. + auto refined_landmarks = absl::make_unique(); + for (int i = 0; i < n_refined_landmarks_; ++i) { + refined_landmarks->add_landmark(); + } + + // Apply input landmarks to outpu refined landmarks in provided order. + for (int i = 0; i < kLandmarks(cc).Count(); ++i) { + const auto& landmarks = kLandmarks(cc)[i].Get(); + const auto& refinement = options_.refinement(i); + + // Check number of landmarks in mapping and stream are the same. + RET_CHECK_EQ(landmarks.landmark_size(), refinement.indexes_mapping_size()) + << "There are " << landmarks.landmark_size() + << " refinement landmarks while mapping has " + << refinement.indexes_mapping_size(); + + // Refine X and Y. + RefineXY(refinement.indexes_mapping(), landmarks, + refined_landmarks.get()); + + // Refine Z. + RefineZ(refinement.indexes_mapping(), refinement.z_refinement(), + landmarks, refined_landmarks.get()); + + // Visibility and presence are not currently refined and are left as `0`. + } + + kRefinedLandmarks(cc).Send(std::move(refined_landmarks)); + return absl::OkStatus(); + } + + private: + LandmarksRefinementCalculatorOptions options_; + int n_refined_landmarks_ = 0; +}; + +MEDIAPIPE_NODE_IMPLEMENTATION(LandmarksRefinementCalculatorImpl); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_refinement_calculator.h b/mediapipe/calculators/util/landmarks_refinement_calculator.h new file mode 100644 index 000000000..1edadcd5b --- /dev/null +++ b/mediapipe/calculators/util/landmarks_refinement_calculator.h @@ -0,0 +1,85 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_REFINEMENT_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_REFINEMENT_CALCULATOR_H_ + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe { + +namespace api2 { + +// A calculator to refine one set of landmarks with another. +// +// Inputs: +// LANDMARKS: Multiple NormalizedLandmarkList to use for +// refinement. They will be applied to the resulting REFINED_LANDMARKS in +// the provided order. Each list should be non empty and contain the same +// amount of landmarks as indexes in mapping. Number of lists should be the +// same as number of refinements in options. +// +// Outputs: +// REFINED_LANDMARKS: A NormalizedLandmarkList with refined landmarks. 
Number +// of produced landmarks is equal to the maximum index mapping number in +// calculator options (calculator verifies that there are no gaps in the +// mapping). +// +// Example config: +// node { +// calculator: "LandmarksRefinementCalculator" +// input_stream: "LANDMARKS:0:mesh_landmarks" +// input_stream: "LANDMARKS:1:lips_landmarks" +// input_stream: "LANDMARKS:2:left_eye_landmarks" +// input_stream: "LANDMARKS:3:right_eye_landmarks" +// output_stream: "REFINED_LANDMARKS:landmarks" +// options: { +// [mediapipe.LandmarksRefinementCalculatorOptions.ext] { +// refinement: { +// indexes_mapping: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +// z_refinement: { copy {} } +// } +// refinement: { +// indexes_mapping: [0, 1, 2, 3] +// z_refinement: { none {} } +// } +// refinement: { +// indexes_mapping: [4, 5] +// z_refinement: { none {} } +// } +// refinement: { +// indexes_mapping: [6, 7] +// z_refinement: { none {} } +// } +// } +// } +// } +// +class LandmarksRefinementCalculator : public NodeIntf { + public: + static constexpr Input<::mediapipe::NormalizedLandmarkList>::Multiple + kLandmarks{"LANDMARKS"}; + static constexpr Output<::mediapipe::NormalizedLandmarkList> + kRefinedLandmarks{"REFINED_LANDMARKS"}; + + MEDIAPIPE_NODE_INTERFACE(LandmarksRefinementCalculator, kLandmarks, + kRefinedLandmarks); +}; + +} // namespace api2 +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_REFINEMENT_CALCULATOR_H_ diff --git a/mediapipe/calculators/util/landmarks_refinement_calculator.proto b/mediapipe/calculators/util/landmarks_refinement_calculator.proto new file mode 100644 index 000000000..e5234e713 --- /dev/null +++ b/mediapipe/calculators/util/landmarks_refinement_calculator.proto @@ -0,0 +1,71 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message LandmarksRefinementCalculatorOptions { + extend CalculatorOptions { + optional LandmarksRefinementCalculatorOptions ext = 381914658; + } + + // Do nothing and keep those Z that are already present in the resulting set + // of landmarks. + message ZRefinementNone {} + + // Simply copy Z values from the given set of landmarks to the resulting set + // of landmarks. + message ZRefinementCopy {} + + // Calculate average of the specified set of landmarks in the resulting set + // and use it as Z for all given landmarks when assigning their values to the + // resulting set of landmarks. + message ZRefinementAssignAverage { + // Indexes of the resulting landmarks to use for average. Should be non + // empty. + repeated int32 indexes_for_average = 1; + } + + // Specifies the set of instructions on assigning z value from the given set + // of landmarks to the resulting set of landmarks. + message ZRefinement { + // Exactly one Z refinement option should be specified.
+ oneof z_refinement_options { + ZRefinementNone none = 1; + ZRefinementCopy copy = 2; + ZRefinementAssignAverage assign_average = 3; + } + } + + // Specifies the set of instructions on assigning values to the resulting set + // of landmarks. + message Refinement { + // Maps indexes of the given set of landmarks to indexes of the resulting + // set of landmarks. Should be non empty and contain the same number of + // indexes as landmarks in the corresponding input stream. + repeated int32 indexes_mapping = 1; + + // Z refinement instructions. + optional ZRefinement z_refinement = 2; + } + + // Refinement instructions for every landmarks input stream. Applied in the + // same order as defined. Should be the same number of refinements as landmark + // input streams in the calculator. Union of index mappings should start with + // 0 and cover a continuous range. + repeated Refinement refinement = 1; +} diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc index f2cec3ae3..263ef85c6 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc @@ -86,11 +86,11 @@ inline void GetMinMaxZ(const LandmarkListType& landmarks, float* z_min, } template -bool IsLandmarkVisibileAndPresent(const LandmarkType& landmark, - bool utilize_visibility, - float visibility_threshold, - bool utilize_presence, - float presence_threshold) { +bool IsLandmarkVisibleAndPresent(const LandmarkType& landmark, + bool utilize_visibility, + float visibility_threshold, + bool utilize_presence, + float presence_threshold) { if (utilize_visibility && landmark.has_visibility() && landmark.visibility() < visibility_threshold) { return false; @@ -153,12 +153,16 @@ void AddConnectionsWithDepth(const LandmarkListType& landmarks, const Color& max_depth_line_color, RenderData* render_data) { for (int i = 0; i < landmark_connections.size(); i += 2) { + if (landmark_connections[i] >= landmarks.landmark_size() || + landmark_connections[i + 1] >= landmarks.landmark_size()) { + continue; + } const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); - if (!IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( ld0, utilize_visibility, visibility_threshold, utilize_presence, presence_threshold) || - !IsLandmarkVisibileAndPresent( + !IsLandmarkVisibleAndPresent( ld1, utilize_visibility, visibility_threshold, utilize_presence, presence_threshold)) { continue; @@ -196,12 +200,16 @@ void AddConnections(const LandmarkListType& landmarks, const Color& connection_color, float thickness, bool normalized, RenderData* render_data) { for (int i = 0; i < landmark_connections.size(); i += 2) { + if (landmark_connections[i] >= landmarks.landmark_size() || + landmark_connections[i + 1] >= landmarks.landmark_size()) { + continue; + } const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); - if (!IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( ld0, utilize_visibility, visibility_threshold, utilize_presence, presence_threshold) || - !IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( ld1, utilize_visibility, visibility_threshold, utilize_presence, presence_threshold)) { continue; @@ -317,7 +325,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < 
landmarks.landmark_size(); ++i) { const Landmark& landmark = landmarks.landmark(i); - if (!IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( landmark, options_.utilize_visibility(), options_.visibility_threshold(), options_.utilize_presence(), options_.presence_threshold())) { @@ -363,7 +371,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < landmarks.landmark_size(); ++i) { const NormalizedLandmark& landmark = landmarks.landmark(i); - if (!IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( landmark, options_.utilize_visibility(), options_.visibility_threshold(), options_.utilize_presence(), options_.presence_threshold())) { diff --git a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc index fcba83a49..9cd460114 100644 --- a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc @@ -20,6 +20,11 @@ #include "mediapipe/framework/port/status.h" namespace mediapipe { + +constexpr char kContentsTag[] = "CONTENTS"; +constexpr char kFileSuffixTag[] = "FILE_SUFFIX"; +constexpr char kFileDirectoryTag[] = "FILE_DIRECTORY"; + // The calculator takes the path to local directory and desired file suffix to // mach as input side packets, and outputs the contents of those files that // match the pattern. Those matched files will be sent sequentially through the @@ -35,16 +40,16 @@ namespace mediapipe { class LocalFilePatternContentsCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->InputSidePackets().Tag("FILE_DIRECTORY").Set(); - cc->InputSidePackets().Tag("FILE_SUFFIX").Set(); - cc->Outputs().Tag("CONTENTS").Set(); + cc->InputSidePackets().Tag(kFileDirectoryTag).Set(); + cc->InputSidePackets().Tag(kFileSuffixTag).Set(); + cc->Outputs().Tag(kContentsTag).Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory( - cc->InputSidePackets().Tag("FILE_DIRECTORY").Get(), - cc->InputSidePackets().Tag("FILE_SUFFIX").Get(), + cc->InputSidePackets().Tag(kFileDirectoryTag).Get(), + cc->InputSidePackets().Tag(kFileSuffixTag).Get(), &filenames_)); return absl::OkStatus(); } @@ -57,7 +62,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase { filenames_[current_output_], contents.get())); ++current_output_; cc->Outputs() - .Tag("CONTENTS") + .Tag(kContentsTag) .Add(contents.release(), Timestamp(current_output_)); } else { return tool::StatusStop(); diff --git a/mediapipe/calculators/util/packet_latency_calculator.cc b/mediapipe/calculators/util/packet_latency_calculator.cc index 35e415505..0e5b2e885 100644 --- a/mediapipe/calculators/util/packet_latency_calculator.cc +++ b/mediapipe/calculators/util/packet_latency_calculator.cc @@ -217,7 +217,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { // Initialize the clock. 
if (cc->InputSidePackets().HasTag(kClockTag)) { clock_ = cc->InputSidePackets() - .Tag("CLOCK") + .Tag(kClockTag) .Get>(); } else { clock_ = std::shared_ptr<::mediapipe::Clock>( diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index e0a759bdb..15bb26826 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -36,7 +36,7 @@ inline float NormalizeRadians(float angle) { } // namespace // Performs geometric transformation to the input Rect or NormalizedRect, -// correpsonding to input stream RECT or NORM_RECT respectively. When the input +// corresponding to input stream RECT or NORM_RECT respectively. When the input // is NORM_RECT, an addition input stream IMAGE_SIZE is required, which is a // std::pair representing the image width and height. // diff --git a/mediapipe/calculators/util/thresholding_calculator.cc b/mediapipe/calculators/util/thresholding_calculator.cc index c86e6ca52..a89d8253f 100644 --- a/mediapipe/calculators/util/thresholding_calculator.cc +++ b/mediapipe/calculators/util/thresholding_calculator.cc @@ -17,6 +17,12 @@ namespace mediapipe { +constexpr char kThresholdTag[] = "THRESHOLD"; +constexpr char kRejectTag[] = "REJECT"; +constexpr char kAcceptTag[] = "ACCEPT"; +constexpr char kFlagTag[] = "FLAG"; +constexpr char kFloatTag[] = "FLOAT"; + // Applies a threshold on a stream of numeric values and outputs a flag and/or // accept/reject stream. The threshold can be specified by one of the following: // 1) Input stream. @@ -61,24 +67,24 @@ class ThresholdingCalculator : public CalculatorBase { REGISTER_CALCULATOR(ThresholdingCalculator); absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("FLOAT")); - cc->Inputs().Tag("FLOAT").Set(); + RET_CHECK(cc->Inputs().HasTag(kFloatTag)); + cc->Inputs().Tag(kFloatTag).Set(); - if (cc->Outputs().HasTag("FLAG")) { - cc->Outputs().Tag("FLAG").Set(); + if (cc->Outputs().HasTag(kFlagTag)) { + cc->Outputs().Tag(kFlagTag).Set(); } - if (cc->Outputs().HasTag("ACCEPT")) { - cc->Outputs().Tag("ACCEPT").Set(); + if (cc->Outputs().HasTag(kAcceptTag)) { + cc->Outputs().Tag(kAcceptTag).Set(); } - if (cc->Outputs().HasTag("REJECT")) { - cc->Outputs().Tag("REJECT").Set(); + if (cc->Outputs().HasTag(kRejectTag)) { + cc->Outputs().Tag(kRejectTag).Set(); } - if (cc->Inputs().HasTag("THRESHOLD")) { - cc->Inputs().Tag("THRESHOLD").Set(); + if (cc->Inputs().HasTag(kThresholdTag)) { + cc->Inputs().Tag(kThresholdTag).Set(); } - if (cc->InputSidePackets().HasTag("THRESHOLD")) { - cc->InputSidePackets().Tag("THRESHOLD").Set(); - RET_CHECK(!cc->Inputs().HasTag("THRESHOLD")) + if (cc->InputSidePackets().HasTag(kThresholdTag)) { + cc->InputSidePackets().Tag(kThresholdTag).Set(); + RET_CHECK(!cc->Inputs().HasTag(kThresholdTag)) << "Using both the threshold input side packet and input stream is not " "supported."; } @@ -92,43 +98,45 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) { const auto& options = cc->Options<::mediapipe::ThresholdingCalculatorOptions>(); if (options.has_threshold()) { - RET_CHECK(!cc->Inputs().HasTag("THRESHOLD")) + RET_CHECK(!cc->Inputs().HasTag(kThresholdTag)) << "Using both the threshold option and input stream is not supported."; - RET_CHECK(!cc->InputSidePackets().HasTag("THRESHOLD")) + RET_CHECK(!cc->InputSidePackets().HasTag(kThresholdTag)) << "Using both the threshold option and input 
side packet is not " "supported."; threshold_ = options.threshold(); } - if (cc->InputSidePackets().HasTag("THRESHOLD")) { - threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get(); + if (cc->InputSidePackets().HasTag(kThresholdTag)) { + threshold_ = cc->InputSidePackets().Tag(kThresholdTag).Get(); } return absl::OkStatus(); } absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) { - if (cc->Inputs().HasTag("THRESHOLD") && - !cc->Inputs().Tag("THRESHOLD").IsEmpty()) { - threshold_ = cc->Inputs().Tag("THRESHOLD").Get(); + if (cc->Inputs().HasTag(kThresholdTag) && + !cc->Inputs().Tag(kThresholdTag).IsEmpty()) { + threshold_ = cc->Inputs().Tag(kThresholdTag).Get(); } bool accept = false; - RET_CHECK(!cc->Inputs().Tag("FLOAT").IsEmpty()); - accept = - static_cast(cc->Inputs().Tag("FLOAT").Get()) > threshold_; + RET_CHECK(!cc->Inputs().Tag(kFloatTag).IsEmpty()); + accept = static_cast(cc->Inputs().Tag(kFloatTag).Get()) > + threshold_; - if (cc->Outputs().HasTag("FLAG")) { - cc->Outputs().Tag("FLAG").AddPacket( + if (cc->Outputs().HasTag(kFlagTag)) { + cc->Outputs().Tag(kFlagTag).AddPacket( MakePacket(accept).At(cc->InputTimestamp())); } - if (accept && cc->Outputs().HasTag("ACCEPT")) { - cc->Outputs().Tag("ACCEPT").AddPacket( - MakePacket(true).At(cc->InputTimestamp())); + if (accept && cc->Outputs().HasTag(kAcceptTag)) { + cc->Outputs() + .Tag(kAcceptTag) + .AddPacket(MakePacket(true).At(cc->InputTimestamp())); } - if (!accept && cc->Outputs().HasTag("REJECT")) { - cc->Outputs().Tag("REJECT").AddPacket( - MakePacket(false).At(cc->InputTimestamp())); + if (!accept && cc->Outputs().HasTag(kRejectTag)) { + cc->Outputs() + .Tag(kRejectTag) + .AddPacket(MakePacket(false).At(cc->InputTimestamp())); } return absl::OkStatus(); diff --git a/mediapipe/calculators/util/top_k_scores_calculator.cc b/mediapipe/calculators/util/top_k_scores_calculator.cc index 37d1b2ab2..42ec5715e 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator.cc @@ -39,6 +39,14 @@ namespace mediapipe { +constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION"; +constexpr char kSummaryTag[] = "SUMMARY"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kTopKLabelsTag[] = "TOP_K_LABELS"; +constexpr char kTopKScoresTag[] = "TOP_K_SCORES"; +constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES"; +constexpr char kScoresTag[] = "SCORES"; + // A calculator that takes a vector of scores and returns the indexes, scores, // labels of the top k elements, classification protos, and summary std::string // (in csv format). 
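Editor's note: the class comment above describes the behaviour this calculator implements, selecting the indexes and scores of the top k entries of a score vector, optionally gated by a threshold. As a minimal standalone sketch of that selection step (not part of the patch; the function name TopKIndexes and the strict greater-than threshold semantics are assumptions, and the sample values mirror the test vector used later in this diff):

// Standalone illustration of top-k selection over a score vector.
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

// Returns the indexes of the top-k scores in descending order, keeping only
// entries whose score is strictly above `threshold`.
std::vector<int> TopKIndexes(const std::vector<float>& scores, int k,
                             float threshold) {
  std::vector<int> idx(scores.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Drop indexes that do not pass the threshold.
  idx.erase(std::remove_if(idx.begin(), idx.end(),
                           [&](int i) { return scores[i] <= threshold; }),
            idx.end());
  // Partially sort so that the first top_k entries are the largest scores.
  const int top_k = std::min<int>(k, idx.size());
  std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                    [&](int a, int b) { return scores[a] > scores[b]; });
  idx.resize(top_k);
  return idx;
}

int main() {
  const std::vector<float> scores = {0.9f, 0.2f, 0.3f, 1.0f, 0.1f};
  for (int i : TopKIndexes(scores, /*k=*/2, /*threshold=*/0.0f)) {
    std::cout << i << ": " << scores[i] << "\n";  // Prints 3: 1.0 then 0: 0.9.
  }
}

With k = 2 and the score vector {0.9, 0.2, 0.3, 1.0, 0.1}, this yields indexes 3 and 0, which matches the expectations in top_k_scores_calculator_test.cc below.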
@@ -79,22 +87,22 @@ class TopKScoresCalculator : public CalculatorBase { REGISTER_CALCULATOR(TopKScoresCalculator); absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("SCORES")); - cc->Inputs().Tag("SCORES").Set>(); - if (cc->Outputs().HasTag("TOP_K_INDEXES")) { - cc->Outputs().Tag("TOP_K_INDEXES").Set>(); + RET_CHECK(cc->Inputs().HasTag(kScoresTag)); + cc->Inputs().Tag(kScoresTag).Set>(); + if (cc->Outputs().HasTag(kTopKIndexesTag)) { + cc->Outputs().Tag(kTopKIndexesTag).Set>(); } - if (cc->Outputs().HasTag("TOP_K_SCORES")) { - cc->Outputs().Tag("TOP_K_SCORES").Set>(); + if (cc->Outputs().HasTag(kTopKScoresTag)) { + cc->Outputs().Tag(kTopKScoresTag).Set>(); } - if (cc->Outputs().HasTag("TOP_K_LABELS")) { - cc->Outputs().Tag("TOP_K_LABELS").Set>(); + if (cc->Outputs().HasTag(kTopKLabelsTag)) { + cc->Outputs().Tag(kTopKLabelsTag).Set>(); } - if (cc->Outputs().HasTag("CLASSIFICATIONS")) { - cc->Outputs().Tag("CLASSIFICATIONS").Set(); + if (cc->Outputs().HasTag(kClassificationsTag)) { + cc->Outputs().Tag(kClassificationsTag).Set(); } - if (cc->Outputs().HasTag("SUMMARY")) { - cc->Outputs().Tag("SUMMARY").Set(); + if (cc->Outputs().HasTag(kSummaryTag)) { + cc->Outputs().Tag(kSummaryTag).Set(); } return absl::OkStatus(); } @@ -114,7 +122,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) { if (options.has_label_map_path()) { MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path())); } - if (cc->Outputs().HasTag("TOP_K_LABELS")) { + if (cc->Outputs().HasTag(kTopKLabelsTag)) { RET_CHECK(!label_map_.empty()); } return absl::OkStatus(); @@ -122,7 +130,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) { absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { const std::vector& input_vector = - cc->Inputs().Tag("SCORES").Get>(); + cc->Inputs().Tag(kScoresTag).Get>(); std::vector top_k_indexes; std::vector top_k_scores; @@ -166,26 +174,26 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { top_k_labels.push_back(label_map_[index]); } } - if (cc->Outputs().HasTag("TOP_K_INDEXES")) { + if (cc->Outputs().HasTag(kTopKIndexesTag)) { cc->Outputs() - .Tag("TOP_K_INDEXES") + .Tag(kTopKIndexesTag) .AddPacket(MakePacket>(top_k_indexes) .At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("TOP_K_SCORES")) { + if (cc->Outputs().HasTag(kTopKScoresTag)) { cc->Outputs() - .Tag("TOP_K_SCORES") + .Tag(kTopKScoresTag) .AddPacket(MakePacket>(top_k_scores) .At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("TOP_K_LABELS")) { + if (cc->Outputs().HasTag(kTopKLabelsTag)) { cc->Outputs() - .Tag("TOP_K_LABELS") + .Tag(kTopKLabelsTag) .AddPacket(MakePacket>(top_k_labels) .At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("SUMMARY")) { + if (cc->Outputs().HasTag(kSummaryTag)) { std::vector results; for (int index = 0; index < top_k_indexes.size(); ++index) { if (label_map_loaded_) { @@ -196,12 +204,13 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { absl::StrCat(top_k_indexes[index], ":", top_k_scores[index])); } } - cc->Outputs().Tag("SUMMARY").AddPacket( - MakePacket(absl::StrJoin(results, ",")) - .At(cc->InputTimestamp())); + cc->Outputs() + .Tag(kSummaryTag) + .AddPacket(MakePacket(absl::StrJoin(results, ",")) + .At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) { + if (cc->Outputs().HasTag(kTopKClassificationTag)) { auto classification_list = absl::make_unique(); for (int index = 0; index < top_k_indexes.size(); ++index) { 
Classification* classification = diff --git a/mediapipe/calculators/util/top_k_scores_calculator_test.cc b/mediapipe/calculators/util/top_k_scores_calculator_test.cc index 6e6a2ebad..e5a17af28 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator_test.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator_test.cc @@ -23,6 +23,10 @@ namespace mediapipe { +constexpr char kTopKScoresTag[] = "TOP_K_SCORES"; +constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES"; +constexpr char kScoresTag[] = "SCORES"; + TEST(TopKScoresCalculatorTest, TestNodeConfig) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TopKScoresCalculator" @@ -55,19 +59,21 @@ TEST(TopKScoresCalculatorTest, TestTopKOnly) { std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; - runner.MutableInputs()->Tag("SCORES").packets.push_back( - MakePacket>(score_vector).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kScoresTag) + .packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); const std::vector& indexes_outputs = - runner.Outputs().Tag("TOP_K_INDEXES").packets; + runner.Outputs().Tag(kTopKIndexesTag).packets; ASSERT_EQ(1, indexes_outputs.size()); const auto& indexes = indexes_outputs[0].Get>(); EXPECT_EQ(2, indexes.size()); EXPECT_EQ(3, indexes[0]); EXPECT_EQ(0, indexes[1]); const std::vector& scores_outputs = - runner.Outputs().Tag("TOP_K_SCORES").packets; + runner.Outputs().Tag(kTopKScoresTag).packets; ASSERT_EQ(1, scores_outputs.size()); const auto& scores = scores_outputs[0].Get>(); EXPECT_EQ(2, scores.size()); @@ -88,12 +94,14 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) { std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; - runner.MutableInputs()->Tag("SCORES").packets.push_back( - MakePacket>(score_vector).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kScoresTag) + .packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); const std::vector& indexes_outputs = - runner.Outputs().Tag("TOP_K_INDEXES").packets; + runner.Outputs().Tag(kTopKIndexesTag).packets; ASSERT_EQ(1, indexes_outputs.size()); const auto& indexes = indexes_outputs[0].Get>(); EXPECT_EQ(4, indexes.size()); @@ -102,7 +110,7 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) { EXPECT_EQ(2, indexes[2]); EXPECT_EQ(1, indexes[3]); const std::vector& scores_outputs = - runner.Outputs().Tag("TOP_K_SCORES").packets; + runner.Outputs().Tag(kTopKScoresTag).packets; ASSERT_EQ(1, scores_outputs.size()); const auto& scores = scores_outputs[0].Get>(); EXPECT_EQ(4, scores.size()); @@ -125,12 +133,14 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) { std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; - runner.MutableInputs()->Tag("SCORES").packets.push_back( - MakePacket>(score_vector).At(Timestamp(0))); + runner.MutableInputs() + ->Tag(kScoresTag) + .packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()); const std::vector& indexes_outputs = - runner.Outputs().Tag("TOP_K_INDEXES").packets; + runner.Outputs().Tag(kTopKIndexesTag).packets; ASSERT_EQ(1, indexes_outputs.size()); const auto& indexes = indexes_outputs[0].Get>(); EXPECT_EQ(3, indexes.size()); @@ -138,7 +148,7 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) { EXPECT_EQ(0, indexes[1]); EXPECT_EQ(2, indexes[2]); const std::vector& scores_outputs = - runner.Outputs().Tag("TOP_K_SCORES").packets; + runner.Outputs().Tag(kTopKScoresTag).packets; ASSERT_EQ(1, scores_outputs.size()); const auto& scores = scores_outputs[0].Get>(); EXPECT_EQ(3, 
scores.size()); diff --git a/mediapipe/calculators/video/box_detector_calculator.cc b/mediapipe/calculators/video/box_detector_calculator.cc index b7b91d253..55b5c458b 100644 --- a/mediapipe/calculators/video/box_detector_calculator.cc +++ b/mediapipe/calculators/video/box_detector_calculator.cc @@ -47,6 +47,21 @@ namespace mediapipe { +constexpr char kFrameAlignmentTag[] = "FRAME_ALIGNMENT"; +constexpr char kOutputIndexFilenameTag[] = "OUTPUT_INDEX_FILENAME"; +constexpr char kIndexProtoStringTag[] = "INDEX_PROTO_STRING"; +constexpr char kVizTag[] = "VIZ"; +constexpr char kBoxesTag[] = "BOXES"; +constexpr char kReacqSwitchTag[] = "REACQ_SWITCH"; +constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID"; +constexpr char kAddIndexTag[] = "ADD_INDEX"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kDescriptorsTag[] = "DESCRIPTORS"; +constexpr char kFeaturesTag[] = "FEATURES"; +constexpr char kVideoTag[] = "VIDEO"; +constexpr char kTrackedBoxesTag[] = "TRACKED_BOXES"; +constexpr char kTrackingTag[] = "TRACKING"; + // A calculator to detect reappeared box positions from single frame. // // Input stream: @@ -110,66 +125,66 @@ class BoxDetectorCalculator : public CalculatorBase { REGISTER_CALCULATOR(BoxDetectorCalculator); absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag("TRACKING")) { - cc->Inputs().Tag("TRACKING").Set(); + if (cc->Inputs().HasTag(kTrackingTag)) { + cc->Inputs().Tag(kTrackingTag).Set(); } - if (cc->Inputs().HasTag("TRACKED_BOXES")) { - cc->Inputs().Tag("TRACKED_BOXES").Set(); + if (cc->Inputs().HasTag(kTrackedBoxesTag)) { + cc->Inputs().Tag(kTrackedBoxesTag).Set(); } - if (cc->Inputs().HasTag("VIDEO")) { - cc->Inputs().Tag("VIDEO").Set(); + if (cc->Inputs().HasTag(kVideoTag)) { + cc->Inputs().Tag(kVideoTag).Set(); } - if (cc->Inputs().HasTag("FEATURES")) { - RET_CHECK(cc->Inputs().HasTag("DESCRIPTORS")) + if (cc->Inputs().HasTag(kFeaturesTag)) { + RET_CHECK(cc->Inputs().HasTag(kDescriptorsTag)) << "FEATURES and DESCRIPTORS need to be specified together."; - cc->Inputs().Tag("FEATURES").Set>(); + cc->Inputs().Tag(kFeaturesTag).Set>(); } - if (cc->Inputs().HasTag("DESCRIPTORS")) { - RET_CHECK(cc->Inputs().HasTag("FEATURES")) + if (cc->Inputs().HasTag(kDescriptorsTag)) { + RET_CHECK(cc->Inputs().HasTag(kFeaturesTag)) << "FEATURES and DESCRIPTORS need to be specified together."; - cc->Inputs().Tag("DESCRIPTORS").Set>(); + cc->Inputs().Tag(kDescriptorsTag).Set>(); } - if (cc->Inputs().HasTag("IMAGE_SIZE")) { - cc->Inputs().Tag("IMAGE_SIZE").Set>(); + if (cc->Inputs().HasTag(kImageSizeTag)) { + cc->Inputs().Tag(kImageSizeTag).Set>(); } - if (cc->Inputs().HasTag("ADD_INDEX")) { - cc->Inputs().Tag("ADD_INDEX").Set(); + if (cc->Inputs().HasTag(kAddIndexTag)) { + cc->Inputs().Tag(kAddIndexTag).Set(); } - if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) { - cc->Inputs().Tag("CANCEL_OBJECT_ID").Set(); + if (cc->Inputs().HasTag(kCancelObjectIdTag)) { + cc->Inputs().Tag(kCancelObjectIdTag).Set(); } - if (cc->Inputs().HasTag("REACQ_SWITCH")) { - cc->Inputs().Tag("REACQ_SWITCH").Set(); + if (cc->Inputs().HasTag(kReacqSwitchTag)) { + cc->Inputs().Tag(kReacqSwitchTag).Set(); } - if (cc->Outputs().HasTag("BOXES")) { - cc->Outputs().Tag("BOXES").Set(); + if (cc->Outputs().HasTag(kBoxesTag)) { + cc->Outputs().Tag(kBoxesTag).Set(); } - if (cc->Outputs().HasTag("VIZ")) { - RET_CHECK(cc->Inputs().HasTag("VIDEO")) + if (cc->Outputs().HasTag(kVizTag)) { + RET_CHECK(cc->Inputs().HasTag(kVideoTag)) << "Output stream VIZ requires VIDEO to be 
present."; - cc->Outputs().Tag("VIZ").Set(); + cc->Outputs().Tag(kVizTag).Set(); } - if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) { - cc->InputSidePackets().Tag("INDEX_PROTO_STRING").Set(); + if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) { + cc->InputSidePackets().Tag(kIndexProtoStringTag).Set(); } - if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) { - cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Set(); + if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) { + cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Set(); } - if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) { - cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Set(); + if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) { + cc->InputSidePackets().Tag(kFrameAlignmentTag).Set(); } return absl::OkStatus(); @@ -179,10 +194,10 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); box_detector_ = BoxDetectorInterface::Create(options_.detector_options()); - if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) { + if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) { BoxDetectorIndex predefined_index; if (!predefined_index.ParseFromString(cc->InputSidePackets() - .Tag("INDEX_PROTO_STRING") + .Tag(kIndexProtoStringTag) .Get())) { LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING"; } @@ -202,12 +217,13 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { box_detector_->AddBoxDetectorIndex(predefined_index); } - if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) { + if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) { write_index_ = true; } - if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) { - frame_alignment_ = cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Get(); + if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) { + frame_alignment_ = + cc->InputSidePackets().Tag(kFrameAlignmentTag).Get(); } return absl::OkStatus(); @@ -218,16 +234,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { const int64 timestamp_msec = timestamp.Value() / 1000; InputStream* cancel_object_id_stream = - cc->Inputs().HasTag("CANCEL_OBJECT_ID") - ? &(cc->Inputs().Tag("CANCEL_OBJECT_ID")) + cc->Inputs().HasTag(kCancelObjectIdTag) + ? &(cc->Inputs().Tag(kCancelObjectIdTag)) : nullptr; if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) { const int cancel_object_id = cancel_object_id_stream->Get(); box_detector_->CancelBoxDetection(cancel_object_id); } - InputStream* add_index_stream = cc->Inputs().HasTag("ADD_INDEX") - ? &(cc->Inputs().Tag("ADD_INDEX")) + InputStream* add_index_stream = cc->Inputs().HasTag(kAddIndexTag) + ? &(cc->Inputs().Tag(kAddIndexTag)) : nullptr; if (add_index_stream && !add_index_stream->IsEmpty()) { BoxDetectorIndex predefined_index; @@ -238,8 +254,8 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { box_detector_->AddBoxDetectorIndex(predefined_index); } - InputStream* reacq_switch_stream = cc->Inputs().HasTag("REACQ_SWITCH") - ? &(cc->Inputs().Tag("REACQ_SWITCH")) + InputStream* reacq_switch_stream = cc->Inputs().HasTag(kReacqSwitchTag) + ? &(cc->Inputs().Tag(kReacqSwitchTag)) : nullptr; if (reacq_switch_stream && !reacq_switch_stream->IsEmpty()) { detector_switch_ = reacq_switch_stream->Get(); @@ -249,16 +265,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } - InputStream* track_stream = cc->Inputs().HasTag("TRACKING") - ? 
&(cc->Inputs().Tag("TRACKING")) + InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag) + ? &(cc->Inputs().Tag(kTrackingTag)) : nullptr; InputStream* video_stream = - cc->Inputs().HasTag("VIDEO") ? &(cc->Inputs().Tag("VIDEO")) : nullptr; - InputStream* feature_stream = cc->Inputs().HasTag("FEATURES") - ? &(cc->Inputs().Tag("FEATURES")) + cc->Inputs().HasTag(kVideoTag) ? &(cc->Inputs().Tag(kVideoTag)) : nullptr; + InputStream* feature_stream = cc->Inputs().HasTag(kFeaturesTag) + ? &(cc->Inputs().Tag(kFeaturesTag)) : nullptr; - InputStream* descriptor_stream = cc->Inputs().HasTag("DESCRIPTORS") - ? &(cc->Inputs().Tag("DESCRIPTORS")) + InputStream* descriptor_stream = cc->Inputs().HasTag(kDescriptorsTag) + ? &(cc->Inputs().Tag(kDescriptorsTag)) : nullptr; CHECK(track_stream != nullptr || video_stream != nullptr || @@ -266,9 +282,10 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { << "One and only one of {tracking_data, input image frame, " "feature/descriptor} need to be valid."; - InputStream* tracked_boxes_stream = cc->Inputs().HasTag("TRACKED_BOXES") - ? &(cc->Inputs().Tag("TRACKED_BOXES")) - : nullptr; + InputStream* tracked_boxes_stream = + cc->Inputs().HasTag(kTrackedBoxesTag) + ? &(cc->Inputs().Tag(kTrackedBoxesTag)) + : nullptr; std::unique_ptr detected_boxes(new TimedBoxProtoList()); if (track_stream != nullptr) { @@ -309,7 +326,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { } const auto& image_size = - cc->Inputs().Tag("IMAGE_SIZE").Get>(); + cc->Inputs().Tag(kImageSizeTag).Get>(); float inv_scale = 1.0f / std::max(image_size.first, image_size.second); TimedBoxProtoList tracked_boxes; @@ -359,7 +376,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { detected_boxes.get()); } - if (cc->Outputs().HasTag("VIZ")) { + if (cc->Outputs().HasTag(kVizTag)) { cv::Mat viz_view; std::unique_ptr viz_frame; if (video_stream != nullptr && !video_stream->IsEmpty()) { @@ -370,11 +387,11 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { for (const auto& box : detected_boxes->box()) { RenderBox(box, &viz_view); } - cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp); + cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp); } - if (cc->Outputs().HasTag("BOXES")) { - cc->Outputs().Tag("BOXES").Add(detected_boxes.release(), timestamp); + if (cc->Outputs().HasTag(kBoxesTag)) { + cc->Outputs().Tag(kBoxesTag).Add(detected_boxes.release(), timestamp); } return absl::OkStatus(); @@ -384,7 +401,7 @@ absl::Status BoxDetectorCalculator::Close(CalculatorContext* cc) { if (write_index_) { BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex(); MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents( - cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Get(), + cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Get(), index.SerializeAsString())); } return absl::OkStatus(); diff --git a/mediapipe/calculators/video/box_tracker_calculator.cc b/mediapipe/calculators/video/box_tracker_calculator.cc index 7d04d9765..d3acc322a 100644 --- a/mediapipe/calculators/video/box_tracker_calculator.cc +++ b/mediapipe/calculators/video/box_tracker_calculator.cc @@ -293,6 +293,22 @@ const int BoxTrackerCalculator::kMotionBoxPathMinQueueSize = 2; namespace { +constexpr char kCacheDirTag[] = "CACHE_DIR"; +constexpr char kInitialPosTag[] = "INITIAL_POS"; +constexpr char kRaBoxesTag[] = "RA_BOXES"; +constexpr char kBoxesTag[] = "BOXES"; +constexpr char kVizTag[] = "VIZ"; +constexpr char kRaTrackProtoStringTag[] = 
"RA_TRACK_PROTO_STRING"; +constexpr char kRaTrackTag[] = "RA_TRACK"; +constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID"; +constexpr char kRestartPosTag[] = "RESTART_POS"; +constexpr char kStartPosProtoStringTag[] = "START_POS_PROTO_STRING"; +constexpr char kStartPosTag[] = "START_POS"; +constexpr char kStartTag[] = "START"; +constexpr char kVideoTag[] = "VIDEO"; +constexpr char kTrackTimeTag[] = "TRACK_TIME"; +constexpr char kTrackingTag[] = "TRACKING"; + // Convert box position according to rotation angle in degrees. void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom, float in_right, int rotation, float* out_top, @@ -374,78 +390,78 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec, } // namespace. absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag("TRACKING")) { - cc->Inputs().Tag("TRACKING").Set(); + if (cc->Inputs().HasTag(kTrackingTag)) { + cc->Inputs().Tag(kTrackingTag).Set(); } - if (cc->Inputs().HasTag("TRACK_TIME")) { - RET_CHECK(cc->Inputs().HasTag("TRACKING")) + if (cc->Inputs().HasTag(kTrackTimeTag)) { + RET_CHECK(cc->Inputs().HasTag(kTrackingTag)) << "TRACK_TIME needs TRACKING input"; - cc->Inputs().Tag("TRACK_TIME").SetAny(); + cc->Inputs().Tag(kTrackTimeTag).SetAny(); } - if (cc->Inputs().HasTag("VIDEO")) { - cc->Inputs().Tag("VIDEO").Set(); + if (cc->Inputs().HasTag(kVideoTag)) { + cc->Inputs().Tag(kVideoTag).Set(); } - if (cc->Inputs().HasTag("START")) { + if (cc->Inputs().HasTag(kStartTag)) { // Actual packet content does not matter. - cc->Inputs().Tag("START").SetAny(); + cc->Inputs().Tag(kStartTag).SetAny(); } - if (cc->Inputs().HasTag("START_POS")) { - cc->Inputs().Tag("START_POS").Set(); + if (cc->Inputs().HasTag(kStartPosTag)) { + cc->Inputs().Tag(kStartPosTag).Set(); } - if (cc->Inputs().HasTag("START_POS_PROTO_STRING")) { - cc->Inputs().Tag("START_POS_PROTO_STRING").Set(); + if (cc->Inputs().HasTag(kStartPosProtoStringTag)) { + cc->Inputs().Tag(kStartPosProtoStringTag).Set(); } - if (cc->Inputs().HasTag("RESTART_POS")) { - cc->Inputs().Tag("RESTART_POS").Set(); + if (cc->Inputs().HasTag(kRestartPosTag)) { + cc->Inputs().Tag(kRestartPosTag).Set(); } - if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) { - cc->Inputs().Tag("CANCEL_OBJECT_ID").Set(); + if (cc->Inputs().HasTag(kCancelObjectIdTag)) { + cc->Inputs().Tag(kCancelObjectIdTag).Set(); } - if (cc->Inputs().HasTag("RA_TRACK")) { - cc->Inputs().Tag("RA_TRACK").Set(); + if (cc->Inputs().HasTag(kRaTrackTag)) { + cc->Inputs().Tag(kRaTrackTag).Set(); } - if (cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")) { - cc->Inputs().Tag("RA_TRACK_PROTO_STRING").Set(); + if (cc->Inputs().HasTag(kRaTrackProtoStringTag)) { + cc->Inputs().Tag(kRaTrackProtoStringTag).Set(); } - if (cc->Outputs().HasTag("VIZ")) { - RET_CHECK(cc->Inputs().HasTag("VIDEO")) + if (cc->Outputs().HasTag(kVizTag)) { + RET_CHECK(cc->Inputs().HasTag(kVideoTag)) << "Output stream VIZ requires VIDEO to be present."; - cc->Outputs().Tag("VIZ").Set(); + cc->Outputs().Tag(kVizTag).Set(); } - if (cc->Outputs().HasTag("BOXES")) { - cc->Outputs().Tag("BOXES").Set(); + if (cc->Outputs().HasTag(kBoxesTag)) { + cc->Outputs().Tag(kBoxesTag).Set(); } - if (cc->Outputs().HasTag("RA_BOXES")) { - cc->Outputs().Tag("RA_BOXES").Set(); + if (cc->Outputs().HasTag(kRaBoxesTag)) { + cc->Outputs().Tag(kRaBoxesTag).Set(); } #if defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__) - RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS")) + 
RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag)) << "Unsupported on mobile"; #else - if (cc->InputSidePackets().HasTag("INITIAL_POS")) { - cc->InputSidePackets().Tag("INITIAL_POS").Set(); + if (cc->InputSidePackets().HasTag(kInitialPosTag)) { + cc->InputSidePackets().Tag(kInitialPosTag).Set(); } #endif // defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__) - if (cc->InputSidePackets().HasTag("CACHE_DIR")) { - cc->InputSidePackets().Tag("CACHE_DIR").Set(); + if (cc->InputSidePackets().HasTag(kCacheDirTag)) { + cc->InputSidePackets().Tag(kCacheDirTag).Set(); } - RET_CHECK(cc->Inputs().HasTag("TRACKING") != - cc->InputSidePackets().HasTag("CACHE_DIR")) + RET_CHECK(cc->Inputs().HasTag(kTrackingTag) != + cc->InputSidePackets().HasTag(kCacheDirTag)) << "Either TRACKING or CACHE_DIR needs to be specified."; if (cc->InputSidePackets().HasTag(kOptionsTag)) { @@ -459,7 +475,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { options_ = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag); - RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS") || + RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag) || !options_.has_initial_position()) << "Can not specify initial position as side packet and via options"; @@ -468,11 +484,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { } #if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__) - if (cc->InputSidePackets().HasTag("INITIAL_POS")) { + if (cc->InputSidePackets().HasTag(kInitialPosTag)) { LOG(INFO) << "Parsing: " - << cc->InputSidePackets().Tag("INITIAL_POS").Get(); + << cc->InputSidePackets().Tag(kInitialPosTag).Get(); initial_pos_ = ParseTextProtoOrDie( - cc->InputSidePackets().Tag("INITIAL_POS").Get()); + cc->InputSidePackets().Tag(kInitialPosTag).Get()); } #endif // !defined(__ANDROID__) && !defined(__APPLE__) && // !defined(__EMSCRIPTEN__) @@ -484,10 +500,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { } visualize_tracking_data_ = - options_.visualize_tracking_data() && cc->Outputs().HasTag("VIZ"); - visualize_state_ = options_.visualize_state() && cc->Outputs().HasTag("VIZ"); + options_.visualize_tracking_data() && cc->Outputs().HasTag(kVizTag); + visualize_state_ = + options_.visualize_state() && cc->Outputs().HasTag(kVizTag); visualize_internal_state_ = - options_.visualize_internal_state() && cc->Outputs().HasTag("VIZ"); + options_.visualize_internal_state() && cc->Outputs().HasTag(kVizTag); // Force recording of internal state for rendering. 
if (visualize_internal_state_) { @@ -500,8 +517,8 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { options_.mutable_tracker_options()->set_record_path_states(true); } - if (cc->InputSidePackets().HasTag("CACHE_DIR")) { - cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get(); + if (cc->InputSidePackets().HasTag(kCacheDirTag)) { + cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get(); RET_CHECK(!cache_dir_.empty()); box_tracker_.reset(new BoxTracker(cache_dir_, options_.tracker_options())); } else { @@ -511,7 +528,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { } if (options_.streaming_track_data_cache_size() > 0) { - RET_CHECK(!cc->InputSidePackets().HasTag("CACHE_DIR")) + RET_CHECK(!cc->InputSidePackets().HasTag(kCacheDirTag)) << "Streaming mode not compatible with cache dir."; } @@ -533,11 +550,11 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } - InputStream* track_stream = cc->Inputs().HasTag("TRACKING") - ? &(cc->Inputs().Tag("TRACKING")) + InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag) + ? &(cc->Inputs().Tag(kTrackingTag)) : nullptr; - InputStream* track_time_stream = cc->Inputs().HasTag("TRACK_TIME") - ? &(cc->Inputs().Tag("TRACK_TIME")) + InputStream* track_time_stream = cc->Inputs().HasTag(kTrackTimeTag) + ? &(cc->Inputs().Tag(kTrackTimeTag)) : nullptr; // Cache tracking data if possible. @@ -562,8 +579,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } } - InputStream* start_pos_stream = cc->Inputs().HasTag("START_POS") - ? &(cc->Inputs().Tag("START_POS")) + InputStream* start_pos_stream = cc->Inputs().HasTag(kStartPosTag) + ? &(cc->Inputs().Tag(kStartPosTag)) : nullptr; MotionBoxMap fast_forward_boxes; @@ -575,8 +592,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } InputStream* start_pos_proto_string_stream = - cc->Inputs().HasTag("START_POS_PROTO_STRING") - ? &(cc->Inputs().Tag("START_POS_PROTO_STRING")) + cc->Inputs().HasTag(kStartPosProtoStringTag) + ? &(cc->Inputs().Tag(kStartPosProtoStringTag)) : nullptr; if (start_pos_stream == nullptr || start_pos_stream->IsEmpty()) { if (start_pos_proto_string_stream && @@ -589,8 +606,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } } - InputStream* restart_pos_stream = cc->Inputs().HasTag("RESTART_POS") - ? &(cc->Inputs().Tag("RESTART_POS")) + InputStream* restart_pos_stream = cc->Inputs().HasTag(kRestartPosTag) + ? &(cc->Inputs().Tag(kRestartPosTag)) : nullptr; if (restart_pos_stream && !restart_pos_stream->IsEmpty()) { @@ -600,8 +617,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } InputStream* cancel_object_id_stream = - cc->Inputs().HasTag("CANCEL_OBJECT_ID") - ? &(cc->Inputs().Tag("CANCEL_OBJECT_ID")) + cc->Inputs().HasTag(kCancelObjectIdTag) + ? 
&(cc->Inputs().Tag(kCancelObjectIdTag)) : nullptr; if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) { const int cancel_object_id = cancel_object_id_stream->Get(); @@ -616,8 +633,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { TrackingData track_data_to_render; - if (cc->Outputs().HasTag("VIZ")) { - InputStream* video_stream = &(cc->Inputs().Tag("VIDEO")); + if (cc->Outputs().HasTag(kVizTag)) { + InputStream* video_stream = &(cc->Inputs().Tag(kVideoTag)); if (!video_stream->IsEmpty()) { input_view = formats::MatView(&video_stream->Get()); @@ -745,7 +762,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { ++frame_num_since_reset_; // Generate results for queued up request. - if (cc->Outputs().HasTag("BOXES") && !queued_track_requests_.empty()) { + if (cc->Outputs().HasTag(kBoxesTag) && !queued_track_requests_.empty()) { for (int j = 0; j < queued_track_requests_.size(); ++j) { const Timestamp& past_time = queued_track_requests_[j]; RET_CHECK(past_time.Value() < timestamp.Value()) @@ -770,7 +787,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } // Output for every time. - cc->Outputs().Tag("BOXES").Add(past_box_list.release(), past_time); + cc->Outputs().Tag(kBoxesTag).Add(past_box_list.release(), past_time); } queued_track_requests_.clear(); @@ -845,8 +862,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } // Handle random access track requests. - InputStream* ra_track_stream = cc->Inputs().HasTag("RA_TRACK") - ? &(cc->Inputs().Tag("RA_TRACK")) + InputStream* ra_track_stream = cc->Inputs().HasTag(kRaTrackTag) + ? &(cc->Inputs().Tag(kRaTrackTag)) : nullptr; if (ra_track_stream && !ra_track_stream->IsEmpty()) { @@ -861,8 +878,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { } InputStream* ra_track_proto_string_stream = - cc->Inputs().HasTag("RA_TRACK_PROTO_STRING") - ? &(cc->Inputs().Tag("RA_TRACK_PROTO_STRING")) + cc->Inputs().HasTag(kRaTrackProtoStringTag) + ? &(cc->Inputs().Tag(kRaTrackProtoStringTag)) : nullptr; if (ra_track_stream == nullptr || ra_track_stream->IsEmpty()) { if (ra_track_proto_string_stream && @@ -881,15 +898,15 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { // Always output in batch, only output in streaming if tracking data // is present (might be in fast forward mode instead). 
- if (cc->Outputs().HasTag("BOXES") && + if (cc->Outputs().HasTag(kBoxesTag) && (box_tracker_ || !track_stream->IsEmpty())) { std::unique_ptr boxes(new TimedBoxProtoList()); *boxes = std::move(box_track_list); - cc->Outputs().Tag("BOXES").Add(boxes.release(), timestamp); + cc->Outputs().Tag(kBoxesTag).Add(boxes.release(), timestamp); } if (viz_frame) { - cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp); + cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp); } return absl::OkStatus(); @@ -1001,7 +1018,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack( } cc->Outputs() - .Tag("RA_BOXES") + .Tag(kRaBoxesTag) .Add(result_list.release(), cc->InputTimestamp()); } diff --git a/mediapipe/calculators/video/flow_packager_calculator.cc b/mediapipe/calculators/video/flow_packager_calculator.cc index a57105928..2965cd8e6 100644 --- a/mediapipe/calculators/video/flow_packager_calculator.cc +++ b/mediapipe/calculators/video/flow_packager_calculator.cc @@ -29,6 +29,13 @@ namespace mediapipe { +constexpr char kCacheDirTag[] = "CACHE_DIR"; +constexpr char kCompleteTag[] = "COMPLETE"; +constexpr char kTrackingChunkTag[] = "TRACKING_CHUNK"; +constexpr char kTrackingTag[] = "TRACKING"; +constexpr char kCameraTag[] = "CAMERA"; +constexpr char kFlowTag[] = "FLOW"; + using mediapipe::CameraMotion; using mediapipe::FlowPackager; using mediapipe::RegionFlowFeatureList; @@ -91,27 +98,27 @@ class FlowPackagerCalculator : public CalculatorBase { REGISTER_CALCULATOR(FlowPackagerCalculator); absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) { - if (!cc->Inputs().HasTag("FLOW")) { + if (!cc->Inputs().HasTag(kFlowTag)) { return tool::StatusFail("No input flow was specified."); } - cc->Inputs().Tag("FLOW").Set(); + cc->Inputs().Tag(kFlowTag).Set(); - if (cc->Inputs().HasTag("CAMERA")) { - cc->Inputs().Tag("CAMERA").Set(); + if (cc->Inputs().HasTag(kCameraTag)) { + cc->Inputs().Tag(kCameraTag).Set(); } - if (cc->Outputs().HasTag("TRACKING")) { - cc->Outputs().Tag("TRACKING").Set(); + if (cc->Outputs().HasTag(kTrackingTag)) { + cc->Outputs().Tag(kTrackingTag).Set(); } - if (cc->Outputs().HasTag("TRACKING_CHUNK")) { - cc->Outputs().Tag("TRACKING_CHUNK").Set(); + if (cc->Outputs().HasTag(kTrackingChunkTag)) { + cc->Outputs().Tag(kTrackingChunkTag).Set(); } - if (cc->Outputs().HasTag("COMPLETE")) { - cc->Outputs().Tag("COMPLETE").Set(); + if (cc->Outputs().HasTag(kCompleteTag)) { + cc->Outputs().Tag(kCompleteTag).Set(); } - if (cc->InputSidePackets().HasTag("CACHE_DIR")) { - cc->InputSidePackets().Tag("CACHE_DIR").Set(); + if (cc->InputSidePackets().HasTag(kCacheDirTag)) { + cc->InputSidePackets().Tag(kCacheDirTag).Set(); } return absl::OkStatus(); @@ -122,24 +129,24 @@ absl::Status FlowPackagerCalculator::Open(CalculatorContext* cc) { flow_packager_.reset(new FlowPackager(options_.flow_packager_options())); - use_caching_ = cc->InputSidePackets().HasTag("CACHE_DIR"); - build_chunk_ = use_caching_ || cc->Outputs().HasTag("TRACKING_CHUNK"); + use_caching_ = cc->InputSidePackets().HasTag(kCacheDirTag); + build_chunk_ = use_caching_ || cc->Outputs().HasTag(kTrackingChunkTag); if (use_caching_) { - cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get(); + cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get(); } return absl::OkStatus(); } absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { - InputStream* flow_stream = &(cc->Inputs().Tag("FLOW")); + InputStream* flow_stream = &(cc->Inputs().Tag(kFlowTag)); const RegionFlowFeatureList& flow = 
flow_stream->Get(); const Timestamp timestamp = flow_stream->Value().Timestamp(); const CameraMotion* camera_motion = nullptr; - if (cc->Inputs().HasTag("CAMERA")) { - InputStream* camera_stream = &(cc->Inputs().Tag("CAMERA")); + if (cc->Inputs().HasTag(kCameraTag)) { + InputStream* camera_stream = &(cc->Inputs().Tag(kCameraTag)); camera_motion = &camera_stream->Get(); } @@ -161,7 +168,7 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { if (frame_idx_ > 0) { item->set_prev_timestamp_usec(prev_timestamp_.Value()); } - if (cc->Outputs().HasTag("TRACKING")) { + if (cc->Outputs().HasTag(kTrackingTag)) { // Need to copy as output is requested. *item->mutable_tracking_data() = *tracking_data; } else { @@ -172,9 +179,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { options_.caching_chunk_size_msec() * (chunk_idx_ + 1); if (timestamp.Value() / 1000 >= next_chunk_msec) { - if (cc->Outputs().HasTag("TRACKING_CHUNK")) { + if (cc->Outputs().HasTag(kTrackingChunkTag)) { cc->Outputs() - .Tag("TRACKING_CHUNK") + .Tag(kTrackingChunkTag) .Add(new TrackingDataChunk(tracking_chunk_), Timestamp(tracking_chunk_.item(0).timestamp_usec())); } @@ -185,9 +192,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { } } - if (cc->Outputs().HasTag("TRACKING")) { + if (cc->Outputs().HasTag(kTrackingTag)) { cc->Outputs() - .Tag("TRACKING") + .Tag(kTrackingTag) .Add(tracking_data.release(), flow_stream->Value().Timestamp()); } @@ -199,9 +206,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { if (frame_idx_ > 0) { tracking_chunk_.set_last_chunk(true); - if (cc->Outputs().HasTag("TRACKING_CHUNK")) { + if (cc->Outputs().HasTag(kTrackingChunkTag)) { cc->Outputs() - .Tag("TRACKING_CHUNK") + .Tag(kTrackingChunkTag) .Add(new TrackingDataChunk(tracking_chunk_), Timestamp(tracking_chunk_.item(0).timestamp_usec())); } @@ -211,8 +218,8 @@ absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { } } - if (cc->Outputs().HasTag("COMPLETE")) { - cc->Outputs().Tag("COMPLETE").Add(new bool(true), Timestamp::PreStream()); + if (cc->Outputs().HasTag(kCompleteTag)) { + cc->Outputs().Tag(kCompleteTag).Add(new bool(true), Timestamp::PreStream()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/video/motion_analysis_calculator.cc b/mediapipe/calculators/video/motion_analysis_calculator.cc index 4e8ddac41..6217d3be9 100644 --- a/mediapipe/calculators/video/motion_analysis_calculator.cc +++ b/mediapipe/calculators/video/motion_analysis_calculator.cc @@ -38,6 +38,18 @@ namespace mediapipe { +constexpr char kDownsampleTag[] = "DOWNSAMPLE"; +constexpr char kCsvFileTag[] = "CSV_FILE"; +constexpr char kGrayVideoOutTag[] = "GRAY_VIDEO_OUT"; +constexpr char kVideoOutTag[] = "VIDEO_OUT"; +constexpr char kDenseFgTag[] = "DENSE_FG"; +constexpr char kVizTag[] = "VIZ"; +constexpr char kSaliencyTag[] = "SALIENCY"; +constexpr char kCameraTag[] = "CAMERA"; +constexpr char kFlowTag[] = "FLOW"; +constexpr char kSelectionTag[] = "SELECTION"; +constexpr char kVideoTag[] = "VIDEO"; + using mediapipe::AffineAdapter; using mediapipe::CameraMotion; using mediapipe::FrameSelectionResult; @@ -190,55 +202,56 @@ class MotionAnalysisCalculator : public CalculatorBase { REGISTER_CALCULATOR(MotionAnalysisCalculator); absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag("VIDEO")) { - cc->Inputs().Tag("VIDEO").Set(); + if 
(cc->Inputs().HasTag(kVideoTag)) { + cc->Inputs().Tag(kVideoTag).Set(); } // Optional input stream from frame selection calculator. - if (cc->Inputs().HasTag("SELECTION")) { - cc->Inputs().Tag("SELECTION").Set(); + if (cc->Inputs().HasTag(kSelectionTag)) { + cc->Inputs().Tag(kSelectionTag).Set(); } - RET_CHECK(cc->Inputs().HasTag("VIDEO") || cc->Inputs().HasTag("SELECTION")) + RET_CHECK(cc->Inputs().HasTag(kVideoTag) || + cc->Inputs().HasTag(kSelectionTag)) << "Either VIDEO, SELECTION must be specified."; - if (cc->Outputs().HasTag("FLOW")) { - cc->Outputs().Tag("FLOW").Set(); + if (cc->Outputs().HasTag(kFlowTag)) { + cc->Outputs().Tag(kFlowTag).Set(); } - if (cc->Outputs().HasTag("CAMERA")) { - cc->Outputs().Tag("CAMERA").Set(); + if (cc->Outputs().HasTag(kCameraTag)) { + cc->Outputs().Tag(kCameraTag).Set(); } - if (cc->Outputs().HasTag("SALIENCY")) { - cc->Outputs().Tag("SALIENCY").Set(); + if (cc->Outputs().HasTag(kSaliencyTag)) { + cc->Outputs().Tag(kSaliencyTag).Set(); } - if (cc->Outputs().HasTag("VIZ")) { - cc->Outputs().Tag("VIZ").Set(); + if (cc->Outputs().HasTag(kVizTag)) { + cc->Outputs().Tag(kVizTag).Set(); } - if (cc->Outputs().HasTag("DENSE_FG")) { - cc->Outputs().Tag("DENSE_FG").Set(); + if (cc->Outputs().HasTag(kDenseFgTag)) { + cc->Outputs().Tag(kDenseFgTag).Set(); } - if (cc->Outputs().HasTag("VIDEO_OUT")) { - cc->Outputs().Tag("VIDEO_OUT").Set(); + if (cc->Outputs().HasTag(kVideoOutTag)) { + cc->Outputs().Tag(kVideoOutTag).Set(); } - if (cc->Outputs().HasTag("GRAY_VIDEO_OUT")) { + if (cc->Outputs().HasTag(kGrayVideoOutTag)) { // We only output grayscale video if we're actually performing full region- // flow analysis on the video. - RET_CHECK(cc->Inputs().HasTag("VIDEO") && - !cc->Inputs().HasTag("SELECTION")); - cc->Outputs().Tag("GRAY_VIDEO_OUT").Set(); + RET_CHECK(cc->Inputs().HasTag(kVideoTag) && + !cc->Inputs().HasTag(kSelectionTag)); + cc->Outputs().Tag(kGrayVideoOutTag).Set(); } - if (cc->InputSidePackets().HasTag("CSV_FILE")) { - cc->InputSidePackets().Tag("CSV_FILE").Set(); + if (cc->InputSidePackets().HasTag(kCsvFileTag)) { + cc->InputSidePackets().Tag(kCsvFileTag).Set(); } - if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) { - cc->InputSidePackets().Tag("DOWNSAMPLE").Set(); + if (cc->InputSidePackets().HasTag(kDownsampleTag)) { + cc->InputSidePackets().Tag(kDownsampleTag).Set(); } if (cc->InputSidePackets().HasTag(kOptionsTag)) { @@ -253,16 +266,16 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag); - video_input_ = cc->Inputs().HasTag("VIDEO"); - selection_input_ = cc->Inputs().HasTag("SELECTION"); - region_flow_feature_output_ = cc->Outputs().HasTag("FLOW"); - camera_motion_output_ = cc->Outputs().HasTag("CAMERA"); - saliency_output_ = cc->Outputs().HasTag("SALIENCY"); - visualize_output_ = cc->Outputs().HasTag("VIZ"); - dense_foreground_output_ = cc->Outputs().HasTag("DENSE_FG"); - video_output_ = cc->Outputs().HasTag("VIDEO_OUT"); - grayscale_output_ = cc->Outputs().HasTag("GRAY_VIDEO_OUT"); - csv_file_input_ = cc->InputSidePackets().HasTag("CSV_FILE"); + video_input_ = cc->Inputs().HasTag(kVideoTag); + selection_input_ = cc->Inputs().HasTag(kSelectionTag); + region_flow_feature_output_ = cc->Outputs().HasTag(kFlowTag); + camera_motion_output_ = cc->Outputs().HasTag(kCameraTag); + saliency_output_ = cc->Outputs().HasTag(kSaliencyTag); + visualize_output_ = cc->Outputs().HasTag(kVizTag); + dense_foreground_output_ = cc->Outputs().HasTag(kDenseFgTag); + 
video_output_ = cc->Outputs().HasTag(kVideoOutTag); + grayscale_output_ = cc->Outputs().HasTag(kGrayVideoOutTag); + csv_file_input_ = cc->InputSidePackets().HasTag(kCsvFileTag); hybrid_meta_analysis_ = options_.meta_analysis() == MotionAnalysisCalculatorOptions::META_ANALYSIS_HYBRID; @@ -310,7 +323,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { if (csv_file_input_) { // Read from file and parse. const std::string filename = - cc->InputSidePackets().Tag("CSV_FILE").Get(); + cc->InputSidePackets().Tag(kCsvFileTag).Get(); std::string file_contents; std::ifstream input_file(filename, std::ios::in); @@ -327,11 +340,12 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { // Get video header from video or selection input if present. const VideoHeader* video_header = nullptr; - if (video_input_ && !cc->Inputs().Tag("VIDEO").Header().IsEmpty()) { - video_header = &(cc->Inputs().Tag("VIDEO").Header().Get()); + if (video_input_ && !cc->Inputs().Tag(kVideoTag).Header().IsEmpty()) { + video_header = &(cc->Inputs().Tag(kVideoTag).Header().Get()); } else if (selection_input_ && - !cc->Inputs().Tag("SELECTION").Header().IsEmpty()) { - video_header = &(cc->Inputs().Tag("SELECTION").Header().Get()); + !cc->Inputs().Tag(kSelectionTag).Header().IsEmpty()) { + video_header = + &(cc->Inputs().Tag(kSelectionTag).Header().Get()); } else { LOG(WARNING) << "No input video header found. Downstream calculators " "expecting video headers are likely to fail."; @@ -339,7 +353,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { with_saliency_ = options_.analysis_options().compute_motion_saliency(); // Force computation of saliency if requested as output. - if (cc->Outputs().HasTag("SALIENCY")) { + if (cc->Outputs().HasTag(kSaliencyTag)) { with_saliency_ = true; if (!options_.analysis_options().compute_motion_saliency()) { LOG(WARNING) << "Enable saliency computation. 
Set " @@ -353,11 +367,11 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); } - if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) { + if (cc->InputSidePackets().HasTag(kDownsampleTag)) { options_.mutable_analysis_options() ->mutable_flow_options() ->set_downsample_factor( - cc->InputSidePackets().Tag("DOWNSAMPLE").Get()); + cc->InputSidePackets().Tag(kDownsampleTag).Get()); } // If no video header is provided, just return and initialize on the first @@ -369,30 +383,33 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { ////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE /////////////// if (visualize_output_) { - cc->Outputs().Tag("VIZ").SetHeader(Adopt(new VideoHeader(*video_header))); + cc->Outputs().Tag(kVizTag).SetHeader(Adopt(new VideoHeader(*video_header))); } if (video_output_) { cc->Outputs() - .Tag("VIDEO_OUT") + .Tag(kVideoOutTag) .SetHeader(Adopt(new VideoHeader(*video_header))); } - if (cc->Outputs().HasTag("DENSE_FG")) { + if (cc->Outputs().HasTag(kDenseFgTag)) { std::unique_ptr foreground_header( new VideoHeader(*video_header)); foreground_header->format = ImageFormat::GRAY8; - cc->Outputs().Tag("DENSE_FG").SetHeader(Adopt(foreground_header.release())); - } - - if (cc->Outputs().HasTag("CAMERA")) { - cc->Outputs().Tag("CAMERA").SetHeader( - Adopt(new VideoHeader(*video_header))); - } - - if (cc->Outputs().HasTag("SALIENCY")) { cc->Outputs() - .Tag("SALIENCY") + .Tag(kDenseFgTag) + .SetHeader(Adopt(foreground_header.release())); + } + + if (cc->Outputs().HasTag(kCameraTag)) { + cc->Outputs() + .Tag(kCameraTag) + .SetHeader(Adopt(new VideoHeader(*video_header))); + } + + if (cc->Outputs().HasTag(kSaliencyTag)) { + cc->Outputs() + .Tag(kSaliencyTag) .SetHeader(Adopt(new VideoHeader(*video_header))); } @@ -405,9 +422,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { } InputStream* video_stream = - video_input_ ? &(cc->Inputs().Tag("VIDEO")) : nullptr; + video_input_ ? &(cc->Inputs().Tag(kVideoTag)) : nullptr; InputStream* selection_stream = - selection_input_ ? &(cc->Inputs().Tag("SELECTION")) : nullptr; + selection_input_ ? &(cc->Inputs().Tag(kSelectionTag)) : nullptr; // Checked on Open. CHECK(video_stream || selection_stream); @@ -425,8 +442,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { CameraMotion output_motion = meta_motions_.front(); meta_motions_.pop_front(); output_motion.set_timestamp_usec(timestamp.Value()); - cc->Outputs().Tag("CAMERA").Add(new CameraMotion(output_motion), - timestamp); + cc->Outputs() + .Tag(kCameraTag) + .Add(new CameraMotion(output_motion), timestamp); } if (region_flow_feature_output_) { @@ -435,8 +453,8 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { meta_features_.pop_front(); output_features.set_timestamp_usec(timestamp.Value()); - cc->Outputs().Tag("FLOW").Add(new RegionFlowFeatureList(output_features), - timestamp); + cc->Outputs().Tag(kFlowTag).Add( + new RegionFlowFeatureList(output_features), timestamp); } ++frame_idx_; @@ -478,16 +496,17 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION) { // Output concatenated results, nothing to compute here. 
if (camera_motion_output_) { - cc->Outputs().Tag("CAMERA").Add( - frame_selection_result->release_camera_motion(), timestamp); + cc->Outputs() + .Tag(kCameraTag) + .Add(frame_selection_result->release_camera_motion(), timestamp); } if (region_flow_feature_output_) { - cc->Outputs().Tag("FLOW").Add(frame_selection_result->release_features(), - timestamp); + cc->Outputs().Tag(kFlowTag).Add( + frame_selection_result->release_features(), timestamp); } if (video_output_) { - cc->Outputs().Tag("VIDEO_OUT").AddPacket(video_stream->Value()); + cc->Outputs().Tag(kVideoOutTag).AddPacket(video_stream->Value()); } return absl::OkStatus(); @@ -565,7 +584,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { grayscale_mat.copyTo(image_frame_mat); cc->Outputs() - .Tag("GRAY_VIDEO_OUT") + .Tag(kGrayVideoOutTag) .Add(grayscale_image.release(), timestamp); } @@ -640,7 +659,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( *feature_list, *camera_motion, with_saliency_ ? saliency[k].get() : nullptr, &visualization); - cc->Outputs().Tag("VIZ").Add(visualization_frame.release(), timestamp); + cc->Outputs().Tag(kVizTag).Add(visualization_frame.release(), timestamp); } // Output dense foreground mask. @@ -650,26 +669,26 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( cv::Mat foreground = formats::MatView(foreground_frame.get()); motion_analysis_->ComputeDenseForeground(*feature_list, *camera_motion, &foreground); - cc->Outputs().Tag("DENSE_FG").Add(foreground_frame.release(), timestamp); + cc->Outputs().Tag(kDenseFgTag).Add(foreground_frame.release(), timestamp); } // Output flow features if requested. if (region_flow_feature_output_) { - cc->Outputs().Tag("FLOW").Add(feature_list.release(), timestamp); + cc->Outputs().Tag(kFlowTag).Add(feature_list.release(), timestamp); } // Output camera motion. if (camera_motion_output_) { - cc->Outputs().Tag("CAMERA").Add(camera_motion.release(), timestamp); + cc->Outputs().Tag(kCameraTag).Add(camera_motion.release(), timestamp); } if (video_output_) { - cc->Outputs().Tag("VIDEO_OUT").AddPacket(packet_buffer_[k]); + cc->Outputs().Tag(kVideoOutTag).AddPacket(packet_buffer_[k]); } // Output saliency. if (saliency_output_) { - cc->Outputs().Tag("SALIENCY").Add(saliency[k].release(), timestamp); + cc->Outputs().Tag(kSaliencyTag).Add(saliency[k].release(), timestamp); } } diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc index bf7ed3e8a..94ddbb836 100644 --- a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc @@ -27,6 +27,12 @@ namespace mediapipe { namespace { + +constexpr char kSavedAudioPathTag[] = "SAVED_AUDIO_PATH"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kVideoTag[] = "VIDEO"; +constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH"; + // cv::VideoCapture set data type to unsigned char by default. Therefore, the // image format is only related to the number of channles the cv::Mat has. 
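// Editor's note: illustrative sketch only, not part of this patch. The body of
// GetImageFormat() lies outside the hunks shown here; given the comment above
// (the format is derived solely from the channel count of the cv::Mat), a
// mapping along the following lines would be consistent with it. The function
// name and exact cases below are assumptions, not copied from the real file.
ImageFormat::Format GetImageFormatSketch(int num_channels) {
  switch (num_channels) {
    case 1:
      return ImageFormat::GRAY8;    // single-channel, 8-bit grayscale
    case 3:
      return ImageFormat::SRGB;     // three-channel color
    case 4:
      return ImageFormat::SRGBA;    // four-channel color with alpha
    default:
      return ImageFormat::UNKNOWN;  // unsupported channel count
  }
}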
ImageFormat::Format GetImageFormat(int num_channels) { @@ -87,20 +93,20 @@ ImageFormat::Format GetImageFormat(int num_channels) { class OpenCvVideoDecoderCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set(); - cc->Outputs().Tag("VIDEO").Set(); - if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { - cc->Outputs().Tag("VIDEO_PRESTREAM").Set(); + cc->InputSidePackets().Tag(kInputFilePathTag).Set(); + cc->Outputs().Tag(kVideoTag).Set(); + if (cc->Outputs().HasTag(kVideoPrestreamTag)) { + cc->Outputs().Tag(kVideoPrestreamTag).Set(); } - if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) { - cc->OutputSidePackets().Tag("SAVED_AUDIO_PATH").Set(); + if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) { + cc->OutputSidePackets().Tag(kSavedAudioPathTag).Set(); } return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { const std::string& input_file_path = - cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get(); + cc->InputSidePackets().Tag(kInputFilePathTag).Get(); cap_ = absl::make_unique(input_file_path); if (!cap_->isOpened()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) @@ -140,16 +146,16 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { header->frame_rate = fps; header->duration = frame_count_ / fps; - if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { + if (cc->Outputs().HasTag(kVideoPrestreamTag)) { cc->Outputs() - .Tag("VIDEO_PRESTREAM") + .Tag(kVideoPrestreamTag) .Add(header.release(), Timestamp::PreStream()); - cc->Outputs().Tag("VIDEO_PRESTREAM").Close(); + cc->Outputs().Tag(kVideoPrestreamTag).Close(); } // Rewind to the very first frame. cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0); - if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) { + if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) { #ifdef HAVE_FFMPEG std::string saved_audio_path = std::tmpnam(nullptr); std::string ffmpeg_command = @@ -159,14 +165,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { int status_code = system(absl::StrCat("ls ", saved_audio_path).c_str()); if (status_code == 0) { cc->OutputSidePackets() - .Tag("SAVED_AUDIO_PATH") + .Tag(kSavedAudioPathTag) .Set(MakePacket(saved_audio_path)); } else { LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path << " by executing the following command: " << ffmpeg_command; cc->OutputSidePackets() - .Tag("SAVED_AUDIO_PATH") + .Tag(kSavedAudioPathTag) .Set(MakePacket(std::string())); } #else @@ -208,7 +214,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { // If the timestamp of the current frame is not greater than the one of the // previous frame, the new frame will be discarded. 
if (prev_timestamp_ < timestamp) { - cc->Outputs().Tag("VIDEO").Add(image_frame.release(), timestamp); + cc->Outputs().Tag(kVideoTag).Add(image_frame.release(), timestamp); prev_timestamp_ = timestamp; decoded_frames_++; } diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc index 03d27b6fe..035e5a8c9 100644 --- a/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc @@ -29,6 +29,10 @@ namespace mediapipe { namespace { +constexpr char kVideoTag[] = "VIDEO"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH"; + TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { CalculatorGraphConfig::Node node_config = ParseTextProtoOrDie(R"pb( @@ -37,19 +41,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { output_stream: "VIDEO:video" output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); CalculatorRunner runner(node_config); - runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket( file::JoinPath("./", "/mediapipe/calculators/video/" "testdata/format_MP4_AVC720P_AAC.video")); MP_EXPECT_OK(runner.Run()); - EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); MP_EXPECT_OK(runner.Outputs() - .Tag("VIDEO_PRESTREAM") + .Tag(kVideoPrestreamTag) .packets[0] .ValidateAsType()); const mediapipe::VideoHeader& header = - runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get(); EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(1280, header.width); EXPECT_EQ(640, header.height); @@ -58,10 +62,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { // The number of the output packets should be 180. // Some OpenCV version returns the first two frames with the same timestamp on // macos and we might miss one frame here. 
- int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size(); + int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size(); EXPECT_GE(num_of_packets, 179); for (int i = 0; i < num_of_packets; ++i) { - Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i]; cv::Mat output_mat = formats::MatView(&(image_frame_packet.Get())); EXPECT_EQ(1280, output_mat.size().width); @@ -83,19 +87,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) { output_stream: "VIDEO:video" output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); CalculatorRunner runner(node_config); - runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket( file::JoinPath("./", "/mediapipe/calculators/video/" "testdata/format_FLV_H264_AAC.video")); MP_EXPECT_OK(runner.Run()); - EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); MP_EXPECT_OK(runner.Outputs() - .Tag("VIDEO_PRESTREAM") + .Tag(kVideoPrestreamTag) .packets[0] .ValidateAsType()); const mediapipe::VideoHeader& header = - runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get(); EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(640, header.width); EXPECT_EQ(320, header.height); @@ -103,9 +107,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) { // can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4). // EXPECT_FLOAT_EQ(6.0f, header.duration); // EXPECT_FLOAT_EQ(30.0f, header.frame_rate); - EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size()); + EXPECT_EQ(180, runner.Outputs().Tag(kVideoTag).packets.size()); for (int i = 0; i < 180; ++i) { - Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i]; cv::Mat output_mat = formats::MatView(&(image_frame_packet.Get())); EXPECT_EQ(640, output_mat.size().width); @@ -127,19 +131,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) { output_stream: "VIDEO:video" output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); CalculatorRunner runner(node_config); - runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket( file::JoinPath("./", "/mediapipe/calculators/video/" "testdata/format_MKV_VP8_VORBIS.video")); MP_EXPECT_OK(runner.Run()); - EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); MP_EXPECT_OK(runner.Outputs() - .Tag("VIDEO_PRESTREAM") + .Tag(kVideoPrestreamTag) .packets[0] .ValidateAsType()); const mediapipe::VideoHeader& header = - runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get(); EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(640, header.width); EXPECT_EQ(320, header.height); @@ -148,10 +152,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) { // The number of the output packets should be 180. // Some OpenCV version returns the first two frames with the same timestamp on // macos and we might miss one frame here. 
- int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size(); + int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size(); EXPECT_GE(num_of_packets, 179); for (int i = 0; i < num_of_packets; ++i) { - Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i]; cv::Mat output_mat = formats::MatView(&(image_frame_packet.Get())); EXPECT_EQ(640, output_mat.size().width); diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc index 9a74fb710..4af8c5955 100644 --- a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc @@ -36,6 +36,11 @@ namespace mediapipe { +constexpr char kAudioFilePathTag[] = "AUDIO_FILE_PATH"; +constexpr char kOutputFilePathTag[] = "OUTPUT_FILE_PATH"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kVideoTag[] = "VIDEO"; + // Encodes the input video stream and produces a media file. // The media file can be output to the output_file_path specified as a side // packet. Currently, the calculator only supports one video stream (in @@ -90,15 +95,15 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { }; absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("VIDEO")); - cc->Inputs().Tag("VIDEO").Set(); - if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { - cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); + RET_CHECK(cc->Inputs().HasTag(kVideoTag)); + cc->Inputs().Tag(kVideoTag).Set(); + if (cc->Inputs().HasTag(kVideoPrestreamTag)) { + cc->Inputs().Tag(kVideoPrestreamTag).Set(); } - RET_CHECK(cc->InputSidePackets().HasTag("OUTPUT_FILE_PATH")); - cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Set(); - if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) { - cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Set(); + RET_CHECK(cc->InputSidePackets().HasTag(kOutputFilePathTag)); + cc->InputSidePackets().Tag(kOutputFilePathTag).Set(); + if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) { + cc->InputSidePackets().Tag(kAudioFilePathTag).Set(); } return absl::OkStatus(); } @@ -116,7 +121,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { << "Video format must be specified in " "OpenCvVideoEncoderCalculatorOptions"; output_file_path_ = - cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Get(); + cc->InputSidePackets().Tag(kOutputFilePathTag).Get(); std::vector splited_file_path = absl::StrSplit(output_file_path_, '.'); RET_CHECK(splited_file_path.size() >= 2 && @@ -126,7 +131,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { // If the video header will be available, the video metadata will be fetched // from the video header directly. The calculator will receive the video // header packet at timestamp prestream. 
- if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { + if (cc->Inputs().HasTag(kVideoPrestreamTag)) { return absl::OkStatus(); } return SetUpVideoWriter(options.fps(), options.width(), options.height()); @@ -135,13 +140,13 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream()) { const VideoHeader& video_header = - cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); + cc->Inputs().Tag(kVideoPrestreamTag).Get(); return SetUpVideoWriter(video_header.frame_rate, video_header.width, video_header.height); } const ImageFrame& image_frame = - cc->Inputs().Tag("VIDEO").Value().Get(); + cc->Inputs().Tag(kVideoTag).Value().Get(); ImageFormat::Format format = image_frame.Format(); cv::Mat frame; if (format == ImageFormat::GRAY8) { @@ -149,7 +154,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { if (frame.empty()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Receive empty frame at timestamp " - << cc->Inputs().Tag("VIDEO").Value().Timestamp() + << cc->Inputs().Tag(kVideoTag).Value().Timestamp() << " in OpenCvVideoEncoderCalculator::Process()"; } } else { @@ -157,7 +162,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { if (tmp_frame.empty()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Receive empty frame at timestamp " - << cc->Inputs().Tag("VIDEO").Value().Timestamp() + << cc->Inputs().Tag(kVideoTag).Value().Timestamp() << " in OpenCvVideoEncoderCalculator::Process()"; } if (format == ImageFormat::SRGB) { @@ -177,10 +182,10 @@ absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { if (writer_ && writer_->isOpened()) { writer_->release(); } - if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) { + if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) { #ifdef HAVE_FFMPEG const std::string& audio_file_path = - cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Get(); + cc->InputSidePackets().Tag(kAudioFilePathTag).Get(); if (audio_file_path.empty()) { LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the " "audio tracks to the generated video because the audio " diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc index cf00da1f7..56f3253e2 100644 --- a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc +++ b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc @@ -23,6 +23,11 @@ namespace mediapipe { namespace { +constexpr char kBackwardFlowTag[] = "BACKWARD_FLOW"; +constexpr char kForwardFlowTag[] = "FORWARD_FLOW"; +constexpr char kSecondFrameTag[] = "SECOND_FRAME"; +constexpr char kFirstFrameTag[] = "FIRST_FRAME"; + // Checks that img1 and img2 have the same dimensions. bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) { return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height()); @@ -94,19 +99,19 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase { }; absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) { - if (!cc->Inputs().HasTag("FIRST_FRAME") || - !cc->Inputs().HasTag("SECOND_FRAME")) { + if (!cc->Inputs().HasTag(kFirstFrameTag) || + !cc->Inputs().HasTag(kSecondFrameTag)) { return absl::InvalidArgumentError( "Missing required input streams. 
Both FIRST_FRAME and SECOND_FRAME " "must be specified."); } - cc->Inputs().Tag("FIRST_FRAME").Set(); - cc->Inputs().Tag("SECOND_FRAME").Set(); - if (cc->Outputs().HasTag("FORWARD_FLOW")) { - cc->Outputs().Tag("FORWARD_FLOW").Set(); + cc->Inputs().Tag(kFirstFrameTag).Set(); + cc->Inputs().Tag(kSecondFrameTag).Set(); + if (cc->Outputs().HasTag(kForwardFlowTag)) { + cc->Outputs().Tag(kForwardFlowTag).Set(); } - if (cc->Outputs().HasTag("BACKWARD_FLOW")) { - cc->Outputs().Tag("BACKWARD_FLOW").Set(); + if (cc->Outputs().HasTag(kBackwardFlowTag)) { + cc->Outputs().Tag(kBackwardFlowTag).Set(); } return absl::OkStatus(); } @@ -116,10 +121,10 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { absl::MutexLock lock(&mutex_); tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1()); } - if (cc->Outputs().HasTag("FORWARD_FLOW")) { + if (cc->Outputs().HasTag(kForwardFlowTag)) { forward_requested_ = true; } - if (cc->Outputs().HasTag("BACKWARD_FLOW")) { + if (cc->Outputs().HasTag(kBackwardFlowTag)) { backward_requested_ = true; } @@ -128,15 +133,15 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { const ImageFrame& first_frame = - cc->Inputs().Tag("FIRST_FRAME").Value().Get(); + cc->Inputs().Tag(kFirstFrameTag).Value().Get(); const ImageFrame& second_frame = - cc->Inputs().Tag("SECOND_FRAME").Value().Get(); + cc->Inputs().Tag(kSecondFrameTag).Value().Get(); if (forward_requested_) { auto forward_optical_flow_field = absl::make_unique(); MP_RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame, forward_optical_flow_field.get())); cc->Outputs() - .Tag("FORWARD_FLOW") + .Tag(kForwardFlowTag) .Add(forward_optical_flow_field.release(), cc->InputTimestamp()); } if (backward_requested_) { @@ -144,7 +149,7 @@ absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { MP_RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame, backward_optical_flow_field.get())); cc->Outputs() - .Tag("BACKWARD_FLOW") + .Tag(kBackwardFlowTag) .Add(backward_optical_flow_field.release(), cc->InputTimestamp()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/video/video_pre_stream_calculator.cc b/mediapipe/calculators/video/video_pre_stream_calculator.cc index ab9cd22a4..317d4baad 100644 --- a/mediapipe/calculators/video/video_pre_stream_calculator.cc +++ b/mediapipe/calculators/video/video_pre_stream_calculator.cc @@ -19,6 +19,9 @@ namespace mediapipe { +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; +constexpr char kFrameTag[] = "FRAME"; + // Sets up VideoHeader based on the 1st ImageFrame and emits it with timestamp // PreStream. Note that this calculator only fills in format, width, and height, // i.e. 
frame_rate and duration will not be filled, unless: @@ -64,8 +67,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) { if (!cc->Inputs().UsesTags()) { cc->Inputs().Index(0).Set(); } else { - cc->Inputs().Tag("FRAME").Set(); - cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); + cc->Inputs().Tag(kFrameTag).Set(); + cc->Inputs().Tag(kVideoPrestreamTag).Set(); } cc->Outputs().Index(0).Set(); return absl::OkStatus(); @@ -73,8 +76,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) { absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) { frame_rate_in_prestream_ = cc->Inputs().UsesTags() && - cc->Inputs().HasTag("FRAME") && - cc->Inputs().HasTag("VIDEO_PRESTREAM"); + cc->Inputs().HasTag(kFrameTag) && + cc->Inputs().HasTag(kVideoPrestreamTag); header_ = absl::make_unique(); return absl::OkStatus(); } @@ -82,15 +85,15 @@ absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream( CalculatorContext* cc) { cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment(); if (cc->InputTimestamp() == Timestamp::PreStream()) { - RET_CHECK(cc->Inputs().Tag("FRAME").IsEmpty()); - RET_CHECK(!cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty()); - *header_ = cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); + RET_CHECK(cc->Inputs().Tag(kFrameTag).IsEmpty()); + RET_CHECK(!cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty()); + *header_ = cc->Inputs().Tag(kVideoPrestreamTag).Get(); RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero"; } else { - RET_CHECK(cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty()) + RET_CHECK(cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty()) << "Packet on VIDEO_PRESTREAM must come in at Timestamp::PreStream()."; - RET_CHECK(!cc->Inputs().Tag("FRAME").IsEmpty()); - const auto& frame = cc->Inputs().Tag("FRAME").Get(); + RET_CHECK(!cc->Inputs().Tag(kFrameTag).IsEmpty()); + const auto& frame = cc->Inputs().Tag(kFrameTag).Get(); header_->format = frame.Format(); header_->width = frame.Width(); header_->height = frame.Height(); diff --git a/mediapipe/examples/android/solutions/BUILD b/mediapipe/examples/android/solutions/BUILD new file mode 100644 index 000000000..1ba23afe6 --- /dev/null +++ b/mediapipe/examples/android/solutions/BUILD @@ -0,0 +1,21 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +filegroup( + name = "resource_files", + srcs = glob(["res/**"]), + visibility = ["//mediapipe/examples/android/solutions:__subpackages__"], +) diff --git a/mediapipe/examples/android/solutions/build.gradle b/mediapipe/examples/android/solutions/build.gradle new file mode 100644 index 000000000..691e41013 --- /dev/null +++ b/mediapipe/examples/android/solutions/build.gradle @@ -0,0 +1,24 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. 
+buildscript { + repositories { + google() + mavenCentral() + } + dependencies { + classpath "com.android.tools.build:gradle:4.2.0" + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + mavenCentral() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/mediapipe/examples/android/solutions/create_win_symlinks.bat b/mediapipe/examples/android/solutions/create_win_symlinks.bat new file mode 100644 index 000000000..57bafeb2b --- /dev/null +++ b/mediapipe/examples/android/solutions/create_win_symlinks.bat @@ -0,0 +1,23 @@ +@rem Remove the current res dir symlinks that are for Linux and macOS and recreate res dir symlinks for Windows. +@rem This script needs administrator permission. Must run this script as administrator. + +@rem for hands example app. +cd /d %~dp0 +cd hands\src\main +rm res +mklink /d res ..\..\..\res + +@rem for facemesh example app. +cd /d %~dp0 +cd facemesh\src\main +rm res +mklink /d res ..\..\..\res + +@rem for face detection example app. +cd /d %~dp0 +cd facedetection\src\main +rm res +mklink /d res ..\..\..\res + +dir +pause diff --git a/mediapipe/examples/android/solutions/facedetection/build.gradle b/mediapipe/examples/android/solutions/facedetection/build.gradle new file mode 100644 index 000000000..c3ebd94ac --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/build.gradle @@ -0,0 +1,41 @@ +plugins { + id 'com.android.application' +} + +android { + compileSdkVersion 30 + buildToolsVersion "30.0.3" + + defaultConfig { + applicationId "com.google.mediapipe.apps.facedetection" + minSdkVersion 21 + targetSdkVersion 30 + versionCode 1 + versionName "1.0" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar']) + implementation 'androidx.appcompat:appcompat:1.3.0' + implementation 'com.google.android.material:material:1.3.0' + implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + implementation 'androidx.exifinterface:exifinterface:1.3.3' + testImplementation 'junit:junit:4.+' + androidTestImplementation 'androidx.test.ext:junit:1.1.2' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' + // MediaPipe Face Detection Solution. + implementation 'com.google.mediapipe:solution-core:latest.release' + implementation 'com.google.mediapipe:facedetection:latest.release' +} diff --git a/mediapipe/examples/android/solutions/facedetection/proguard-rules.pro b/mediapipe/examples/android/solutions/facedetection/proguard-rules.pro new file mode 100644 index 000000000..f1b424510 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. 
+# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/AndroidManifest.xml b/mediapipe/examples/android/solutions/facedetection/src/main/AndroidManifest.xml new file mode 100644 index 000000000..f150bd41a --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/AndroidManifest.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/BUILD b/mediapipe/examples/android/solutions/facedetection/src/main/BUILD new file mode 100644 index 000000000..5044b55ed --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/BUILD @@ -0,0 +1,46 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +android_binary( + name = "facedetection", + srcs = glob(["**/*.java"]), + custom_package = "com.google.mediapipe.examples.facedetection", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.examples.facedetection", + }, + multidex = "native", + resource_files = ["//mediapipe/examples/android/solutions:resource_files"], + deps = [ + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input", + "//mediapipe/java/com/google/mediapipe/solutioncore:mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering", + "//mediapipe/java/com/google/mediapipe/solutioncore:video_input", + "//mediapipe/java/com/google/mediapipe/solutions/facedetection", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", + "@maven//:androidx_fragment_fragment", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java new file mode 100644 index 000000000..df1847178 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java @@ -0,0 +1,146 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facedetection; + +import android.opengl.GLES20; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import com.google.mediapipe.solutioncore.ResultGlRenderer; +import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; +import com.google.mediapipe.solutions.facedetection.FaceKeypoint; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; + +/** A custom implementation of {@link ResultGlRenderer} to render {@link FaceDetectionResult}. 
*/ +public class FaceDetectionResultGlRenderer implements ResultGlRenderer { + private static final String TAG = "FaceDetectionResultGlRenderer"; + + private static final float[] KEYPOINT_COLOR = new float[] {1f, 0f, 0f, 1f}; + private static final float KEYPOINT_SIZE = 16f; + private static final float[] BBOX_COLOR = new float[] {0f, 1f, 0f, 1f}; + private static final int BBOX_THICKNESS = 8; + private static final String VERTEX_SHADER = + "uniform mat4 uProjectionMatrix;\n" + + "uniform float uPointSize;\n" + + "attribute vec4 vPosition;\n" + + "void main() {\n" + + " gl_Position = uProjectionMatrix * vPosition;\n" + + " gl_PointSize = uPointSize;" + + "}"; + private static final String FRAGMENT_SHADER = + "precision mediump float;\n" + + "uniform vec4 uColor;\n" + + "void main() {\n" + + " gl_FragColor = uColor;\n" + + "}"; + private int program; + private int positionHandle; + private int pointSizeHandle; + private int projectionMatrixHandle; + private int colorHandle; + + private int loadShader(int type, String shaderCode) { + int shader = GLES20.glCreateShader(type); + GLES20.glShaderSource(shader, shaderCode); + GLES20.glCompileShader(shader); + return shader; + } + + @Override + public void setupRendering() { + program = GLES20.glCreateProgram(); + int vertexShader = loadShader(GLES20.GL_VERTEX_SHADER, VERTEX_SHADER); + int fragmentShader = loadShader(GLES20.GL_FRAGMENT_SHADER, FRAGMENT_SHADER); + GLES20.glAttachShader(program, vertexShader); + GLES20.glAttachShader(program, fragmentShader); + GLES20.glLinkProgram(program); + positionHandle = GLES20.glGetAttribLocation(program, "vPosition"); + pointSizeHandle = GLES20.glGetUniformLocation(program, "uPointSize"); + projectionMatrixHandle = GLES20.glGetUniformLocation(program, "uProjectionMatrix"); + colorHandle = GLES20.glGetUniformLocation(program, "uColor"); + } + + @Override + public void renderResult(FaceDetectionResult result, float[] projectionMatrix) { + if (result == null) { + return; + } + GLES20.glUseProgram(program); + GLES20.glUniformMatrix4fv(projectionMatrixHandle, 1, false, projectionMatrix, 0); + GLES20.glUniform1f(pointSizeHandle, KEYPOINT_SIZE); + int numDetectedFaces = result.multiFaceDetections().size(); + for (int i = 0; i < numDetectedFaces; ++i) { + drawDetection(result.multiFaceDetections().get(i)); + } + } + + /** + * Deletes the shader program. + * + *
<p>
This is only necessary if one wants to release the program while keeping the context around. + */ + public void release() { + GLES20.glDeleteProgram(program); + } + + private void drawDetection(Detection detection) { + if (!detection.hasLocationData()) { + return; + } + // Draw keypoints. + float[] points = new float[FaceKeypoint.NUM_KEY_POINTS * 2]; + for (int i = 0; i < FaceKeypoint.NUM_KEY_POINTS; ++i) { + points[2 * i] = detection.getLocationData().getRelativeKeypoints(i).getX(); + points[2 * i + 1] = detection.getLocationData().getRelativeKeypoints(i).getY(); + } + GLES20.glUniform4fv(colorHandle, 1, KEYPOINT_COLOR, 0); + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(points.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(points); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_POINTS, 0, FaceKeypoint.NUM_KEY_POINTS); + if (!detection.getLocationData().hasRelativeBoundingBox()) { + return; + } + // Draw bounding box. + float left = detection.getLocationData().getRelativeBoundingBox().getXmin(); + float top = detection.getLocationData().getRelativeBoundingBox().getYmin(); + float right = left + detection.getLocationData().getRelativeBoundingBox().getWidth(); + float bottom = top + detection.getLocationData().getRelativeBoundingBox().getHeight(); + drawLine(top, left, top, right); + drawLine(bottom, left, bottom, right); + drawLine(top, left, bottom, left); + drawLine(top, right, bottom, right); + } + + private void drawLine(float y1, float x1, float y2, float x2) { + GLES20.glUniform4fv(colorHandle, 1, BBOX_COLOR, 0); + GLES20.glLineWidth(BBOX_THICKNESS); + float[] vertex = {x1, y1, x2, y2}; + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(vertex.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(vertex); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2); + } +} diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java new file mode 100644 index 000000000..3da3a467a --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java @@ -0,0 +1,108 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.examples.facedetection; + +import static java.lang.Math.min; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; +import com.google.mediapipe.solutions.facedetection.FaceKeypoint; + +/** An ImageView implementation for displaying {@link FaceDetectionResult}. */ +public class FaceDetectionResultImageView extends AppCompatImageView { + private static final String TAG = "FaceDetectionResultImageView"; + + private static final int KEYPOINT_COLOR = Color.RED; + private static final int KEYPOINT_RADIUS = 8; // Pixels + private static final int BBOX_COLOR = Color.GREEN; + private static final int BBOX_THICKNESS = 5; // Pixels + private Bitmap latest; + + public FaceDetectionResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets a {@link FaceDetectionResult} to render. + * + * @param result a {@link FaceDetectionResult} object that contains the solution outputs and the + * input {@link Bitmap}. + */ + public void setFaceDetectionResult(FaceDetectionResult result) { + if (result == null) { + return; + } + Bitmap bmInput = result.inputBitmap(); + int width = bmInput.getWidth(); + int height = bmInput.getHeight(); + latest = Bitmap.createBitmap(width, height, bmInput.getConfig()); + Canvas canvas = new Canvas(latest); + + canvas.drawBitmap(bmInput, new Matrix(), null); + int numDetectedFaces = result.multiFaceDetections().size(); + for (int i = 0; i < numDetectedFaces; ++i) { + drawDetectionOnCanvas(result.multiFaceDetections().get(i), canvas, width, height); + } + } + + /** Updates the image view with the latest {@link FaceDetectionResult}. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + private void drawDetectionOnCanvas(Detection detection, Canvas canvas, int width, int height) { + if (!detection.hasLocationData()) { + return; + } + // Draw keypoints. + Paint keypointPaint = new Paint(); + keypointPaint.setColor(KEYPOINT_COLOR); + for (int i = 0; i < FaceKeypoint.NUM_KEY_POINTS; ++i) { + int xPixel = + min( + (int) (detection.getLocationData().getRelativeKeypoints(i).getX() * width), + width - 1); + int yPixel = + min( + (int) (detection.getLocationData().getRelativeKeypoints(i).getY() * height), + height - 1); + canvas.drawCircle(xPixel, yPixel, KEYPOINT_RADIUS, keypointPaint); + } + if (!detection.getLocationData().hasRelativeBoundingBox()) { + return; + } + // Draw bounding box. 
+ Paint bboxPaint = new Paint(); + bboxPaint.setColor(BBOX_COLOR); + bboxPaint.setStyle(Paint.Style.STROKE); + bboxPaint.setStrokeWidth(BBOX_THICKNESS); + float left = detection.getLocationData().getRelativeBoundingBox().getXmin() * width; + float top = detection.getLocationData().getRelativeBoundingBox().getYmin() * height; + float right = left + detection.getLocationData().getRelativeBoundingBox().getWidth() * width; + float bottom = top + detection.getLocationData().getRelativeBoundingBox().getHeight() * height; + canvas.drawRect(left, top, right, bottom, bboxPaint); + } +} diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/MainActivity.java b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/MainActivity.java new file mode 100644 index 000000000..b274ce289 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/MainActivity.java @@ -0,0 +1,364 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facedetection; + +import android.content.Intent; +import android.graphics.Bitmap; +import android.graphics.Matrix; +import android.os.Bundle; +import android.provider.MediaStore; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.FrameLayout; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency +import com.google.mediapipe.solutioncore.CameraInput; +import com.google.mediapipe.solutioncore.SolutionGlSurfaceView; +import com.google.mediapipe.solutioncore.VideoInput; +import com.google.mediapipe.solutions.facedetection.FaceDetection; +import com.google.mediapipe.solutions.facedetection.FaceDetectionOptions; +import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; +import com.google.mediapipe.solutions.facedetection.FaceKeypoint; +import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.RelativeKeypoint; +import java.io.IOException; +import java.io.InputStream; + +/** Main activity of MediaPipe Face Detection app. */ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private FaceDetection faceDetection; + + private enum InputSource { + UNKNOWN, + IMAGE, + VIDEO, + CAMERA, + } + private InputSource inputSource = InputSource.UNKNOWN; + + // Image demo UI and image loader components. + private ActivityResultLauncher imageGetter; + private FaceDetectionResultImageView imageView; + // Video demo UI and video loader components. + private VideoInput videoInput; + private ActivityResultLauncher videoGetter; + // Live camera demo UI and camera components. 
+ private CameraInput cameraInput; + + private SolutionGlSurfaceView glSurfaceView; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + setupStaticImageDemoUiComponents(); + setupVideoDemoUiComponents(); + setupLiveDemoUiComponents(); + } + + @Override + protected void onResume() { + super.onResume(); + if (inputSource == InputSource.CAMERA) { + // Restarts the camera and the opengl surface rendering. + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); + glSurfaceView.post(this::startCamera); + glSurfaceView.setVisibility(View.VISIBLE); + } else if (inputSource == InputSource.VIDEO) { + videoInput.resume(); + } + } + + @Override + protected void onPause() { + super.onPause(); + if (inputSource == InputSource.CAMERA) { + glSurfaceView.setVisibility(View.GONE); + cameraInput.close(); + } else if (inputSource == InputSource.VIDEO) { + videoInput.pause(); + } + } + + private Bitmap downscaleBitmap(Bitmap originalBitmap) { + double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight(); + int width = imageView.getWidth(); + int height = imageView.getHeight(); + if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) { + width = (int) (height * aspectRatio); + } else { + height = (int) (width / aspectRatio); + } + return Bitmap.createScaledBitmap(originalBitmap, width, height, false); + } + + private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException { + int orientation = + new ExifInterface(imageData) + .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); + if (orientation == ExifInterface.ORIENTATION_NORMAL) { + return inputBitmap; + } + Matrix matrix = new Matrix(); + switch (orientation) { + case ExifInterface.ORIENTATION_ROTATE_90: + matrix.postRotate(90); + break; + case ExifInterface.ORIENTATION_ROTATE_180: + matrix.postRotate(180); + break; + case ExifInterface.ORIENTATION_ROTATE_270: + matrix.postRotate(270); + break; + default: + matrix.postRotate(0); + } + return Bitmap.createBitmap( + inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true); + } + + /** Sets up the UI components for the static image demo. */ + private void setupStaticImageDemoUiComponents() { + // The Intent to access gallery and read images as bitmap. + imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + downscaleBitmap( + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData())); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + bitmap = rotateBitmap(bitmap, imageData); + } catch (IOException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } + if (bitmap != null) { + faceDetection.send(bitmap); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + stopCurrentPipeline(); + setupStaticImageModePipeline(); + } + // Reads images from gallery. 
+ Intent pickImageIntent = new Intent(Intent.ACTION_PICK); + pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); + imageGetter.launch(pickImageIntent); + }); + imageView = new FaceDetectionResultImageView(this); + } + + /** Sets up core workflow for static image mode. */ + private void setupStaticImageModePipeline() { + this.inputSource = InputSource.IMAGE; + // Initializes a new MediaPipe Face Detection solution instance in the static image mode. + faceDetection = + new FaceDetection( + this, + FaceDetectionOptions.builder() + .setStaticImageMode(true) + .setModelSelection(0) + .setMinDetectionConfidence(0.5f) + .build()); + + // Connects MediaPipe Face Detection solution to the user-defined FaceDetectionResultImageView. + faceDetection.setResultListener( + faceDetectionResult -> { + logNoseTipKeypoint(faceDetectionResult, /*faceIndex=*/ 0, /*showPixelValues=*/ true); + imageView.setFaceDetectionResult(faceDetectionResult); + runOnUiThread(() -> imageView.update()); + }); + faceDetection.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + /** Sets up the UI components for the video demo. */ + private void setupVideoDemoUiComponents() { + // The Intent to access gallery and read a video file. + videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + faceDetection.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); + Button loadVideoButton = findViewById(R.id.button_load_video); + loadVideoButton.setOnClickListener( + v -> { + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.VIDEO); + // Reads video from gallery. + Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); + pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); + videoGetter.launch(pickVideoIntent); + }); + } + + /** Sets up the UI components for the live demo with camera input. */ + private void setupLiveDemoUiComponents() { + Button startCameraButton = findViewById(R.id.button_start_camera); + startCameraButton.setOnClickListener( + v -> { + if (inputSource == InputSource.CAMERA) { + return; + } + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.CAMERA); + }); + } + + /** Sets up core workflow for streaming mode. */ + private void setupStreamingModePipeline(InputSource inputSource) { + this.inputSource = inputSource; + // Initializes a new MediaPipe Face Detection solution instance in the streaming mode. 
+ faceDetection = + new FaceDetection( + this, + FaceDetectionOptions.builder().setStaticImageMode(false).setModelSelection(0).build()); + faceDetection.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); + + if (inputSource == InputSource.CAMERA) { + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); + } else if (inputSource == InputSource.VIDEO) { + videoInput = new VideoInput(this); + videoInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); + } + + // Initializes a new Gl surface view with a user-defined FaceDetectionResultGlRenderer. + glSurfaceView = + new SolutionGlSurfaceView<>( + this, faceDetection.getGlContext(), faceDetection.getGlMajorVersion()); + glSurfaceView.setSolutionResultRenderer(new FaceDetectionResultGlRenderer()); + glSurfaceView.setRenderInputImage(true); + faceDetection.setResultListener( + faceDetectionResult -> { + logNoseTipKeypoint(faceDetectionResult, /*faceIndex=*/ 0, /*showPixelValues=*/ false); + glSurfaceView.setRenderData(faceDetectionResult); + glSurfaceView.requestRender(); + }); + + // The runnable to start camera after the gl surface view is attached. + // For video input source, videoInput.start() will be called when the video uri is available. + if (inputSource == InputSource.CAMERA) { + glSurfaceView.post(this::startCamera); + } + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + imageView.setVisibility(View.GONE); + frameLayout.removeAllViewsInLayout(); + frameLayout.addView(glSurfaceView); + glSurfaceView.setVisibility(View.VISIBLE); + frameLayout.requestLayout(); + } + + private void startCamera() { + cameraInput.start( + this, + faceDetection.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight()); + } + + private void stopCurrentPipeline() { + if (cameraInput != null) { + cameraInput.setNewFrameListener(null); + cameraInput.close(); + } + if (videoInput != null) { + videoInput.setNewFrameListener(null); + videoInput.close(); + } + if (glSurfaceView != null) { + glSurfaceView.setVisibility(View.GONE); + } + if (faceDetection != null) { + faceDetection.close(); + } + } + + private void logNoseTipKeypoint( + FaceDetectionResult result, int faceIndex, boolean showPixelValues) { + if (result.multiFaceDetections().isEmpty()) { + return; + } + RelativeKeypoint noseTip = + result + .multiFaceDetections() + .get(faceIndex) + .getLocationData() + .getRelativeKeypoints(FaceKeypoint.NOSE_TIP); + // For Bitmaps, show the pixel values. For texture inputs, show the normalized coordinates. 
+ if (showPixelValues) { + int width = result.inputBitmap().getWidth(); + int height = result.inputBitmap().getHeight(); + Log.i( + TAG, + String.format( + "MediaPipe Face Detection nose tip coordinates (pixel values): x=%f, y=%f", + noseTip.getX() * width, noseTip.getY() * height)); + } else { + Log.i( + TAG, + String.format( + "MediaPipe Face Detection nose tip normalized coordinates (value range: [0, 1]):" + + " x=%f, y=%f", + noseTip.getX(), noseTip.getY())); + } + } +} diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/res b/mediapipe/examples/android/solutions/facedetection/src/main/res new file mode 120000 index 000000000..fc8850136 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/res @@ -0,0 +1 @@ +../../../res \ No newline at end of file diff --git a/mediapipe/examples/android/solutions/facemesh/build.gradle b/mediapipe/examples/android/solutions/facemesh/build.gradle new file mode 100644 index 000000000..b8cf3b288 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/build.gradle @@ -0,0 +1,41 @@ +plugins { + id 'com.android.application' +} + +android { + compileSdkVersion 30 + buildToolsVersion "30.0.3" + + defaultConfig { + applicationId "com.google.mediapipe.apps.facemesh" + minSdkVersion 21 + targetSdkVersion 30 + versionCode 1 + versionName "1.0" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar']) + implementation 'androidx.appcompat:appcompat:1.3.0' + implementation 'com.google.android.material:material:1.3.0' + implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + implementation 'androidx.exifinterface:exifinterface:1.3.3' + testImplementation 'junit:junit:4.+' + androidTestImplementation 'androidx.test.ext:junit:1.1.2' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' + // MediaPipe Face Mesh Solution. + implementation 'com.google.mediapipe:solution-core:latest.release' + implementation 'com.google.mediapipe:facemesh:latest.release' +} diff --git a/mediapipe/examples/android/solutions/facemesh/proguard-rules.pro b/mediapipe/examples/android/solutions/facemesh/proguard-rules.pro new file mode 100644 index 000000000..f1b424510 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/AndroidManifest.xml b/mediapipe/examples/android/solutions/facemesh/src/main/AndroidManifest.xml new file mode 100644 index 000000000..f3d4be207 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/AndroidManifest.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/BUILD b/mediapipe/examples/android/solutions/facemesh/src/main/BUILD new file mode 100644 index 000000000..515f03b6b --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/BUILD @@ -0,0 +1,45 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +android_binary( + name = "facemesh", + srcs = glob(["**/*.java"]), + custom_package = "com.google.mediapipe.examples.facemesh", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.examples.facemesh", + }, + multidex = "native", + resource_files = ["//mediapipe/examples/android/solutions:resource_files"], + deps = [ + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input", + "//mediapipe/java/com/google/mediapipe/solutioncore:mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering", + "//mediapipe/java/com/google/mediapipe/solutioncore:video_input", + "//mediapipe/java/com/google/mediapipe/solutions/facemesh", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", + "@maven//:androidx_fragment_fragment", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java new file mode 100644 index 000000000..1b7eca9d6 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java @@ -0,0 +1,176 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facemesh; + +import android.opengl.GLES20; +import com.google.common.collect.ImmutableSet; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutioncore.ResultGlRenderer; +import com.google.mediapipe.solutions.facemesh.FaceMesh; +import com.google.mediapipe.solutions.facemesh.FaceMeshConnections; +import com.google.mediapipe.solutions.facemesh.FaceMeshResult; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.List; + +/** A custom implementation of {@link ResultGlRenderer} to render {@link FaceMeshResult}. */ +public class FaceMeshResultGlRenderer implements ResultGlRenderer { + private static final String TAG = "FaceMeshResultGlRenderer"; + + private static final float[] TESSELATION_COLOR = new float[] {0.75f, 0.75f, 0.75f, 0.5f}; + private static final int TESSELATION_THICKNESS = 5; + private static final float[] RIGHT_EYE_COLOR = new float[] {1f, 0.2f, 0.2f, 1f}; + private static final int RIGHT_EYE_THICKNESS = 8; + private static final float[] RIGHT_EYEBROW_COLOR = new float[] {1f, 0.2f, 0.2f, 1f}; + private static final int RIGHT_EYEBROW_THICKNESS = 8; + private static final float[] LEFT_EYE_COLOR = new float[] {0.2f, 1f, 0.2f, 1f}; + private static final int LEFT_EYE_THICKNESS = 8; + private static final float[] LEFT_EYEBROW_COLOR = new float[] {0.2f, 1f, 0.2f, 1f}; + private static final int LEFT_EYEBROW_THICKNESS = 8; + private static final float[] FACE_OVAL_COLOR = new float[] {0.9f, 0.9f, 0.9f, 1f}; + private static final int FACE_OVAL_THICKNESS = 8; + private static final float[] LIPS_COLOR = new float[] {0.9f, 0.9f, 0.9f, 1f}; + private static final int LIPS_THICKNESS = 8; + private static final String VERTEX_SHADER = + "uniform mat4 uProjectionMatrix;\n" + + "attribute vec4 vPosition;\n" + + "void main() {\n" + + " gl_Position = uProjectionMatrix * vPosition;\n" + + "}"; + private static final String FRAGMENT_SHADER = + "precision mediump float;\n" + + "uniform vec4 uColor;\n" + + "void main() {\n" + + " gl_FragColor = uColor;\n" + + "}"; + private int program; + private int positionHandle; + private int projectionMatrixHandle; + private int colorHandle; + + private int loadShader(int type, String shaderCode) { + int shader = GLES20.glCreateShader(type); + GLES20.glShaderSource(shader, shaderCode); + GLES20.glCompileShader(shader); + return shader; + } + + @Override + public void setupRendering() { + program = GLES20.glCreateProgram(); + int vertexShader = loadShader(GLES20.GL_VERTEX_SHADER, VERTEX_SHADER); + int fragmentShader = loadShader(GLES20.GL_FRAGMENT_SHADER, FRAGMENT_SHADER); + GLES20.glAttachShader(program, vertexShader); + GLES20.glAttachShader(program, fragmentShader); + GLES20.glLinkProgram(program); + positionHandle = GLES20.glGetAttribLocation(program, "vPosition"); + projectionMatrixHandle = GLES20.glGetUniformLocation(program, "uProjectionMatrix"); + colorHandle = GLES20.glGetUniformLocation(program, "uColor"); + } + + @Override + public void renderResult(FaceMeshResult result, float[] projectionMatrix) { + if (result == null) { + return; + } + GLES20.glUseProgram(program); + GLES20.glUniformMatrix4fv(projectionMatrixHandle, 1, false, projectionMatrix, 0); + + int numFaces = result.multiFaceLandmarks().size(); + for (int i = 0; i < numFaces; ++i) { + drawLandmarks( + 
result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_TESSELATION, + TESSELATION_COLOR, + TESSELATION_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_EYE, + RIGHT_EYE_COLOR, + RIGHT_EYE_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_EYEBROW, + RIGHT_EYEBROW_COLOR, + RIGHT_EYEBROW_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_EYE, + LEFT_EYE_COLOR, + LEFT_EYE_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_EYEBROW, + LEFT_EYEBROW_COLOR, + LEFT_EYEBROW_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_FACE_OVAL, + FACE_OVAL_COLOR, + FACE_OVAL_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LIPS, + LIPS_COLOR, + LIPS_THICKNESS); + if (result.multiFaceLandmarks().get(i).getLandmarkCount() + == FaceMesh.FACEMESH_NUM_LANDMARKS_WITH_IRISES) { + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_IRIS, + RIGHT_EYE_COLOR, + RIGHT_EYE_THICKNESS); + drawLandmarks( + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_IRIS, + LEFT_EYE_COLOR, + LEFT_EYE_THICKNESS); + } + } + } + + /** + * Deletes the shader program. + * + *
<p>
This is only necessary if one wants to release the program while keeping the context around. + */ + public void release() { + GLES20.glDeleteProgram(program); + } + + private void drawLandmarks( + List faceLandmarkList, + ImmutableSet connections, + float[] colorArray, + int thickness) { + GLES20.glUniform4fv(colorHandle, 1, colorArray, 0); + GLES20.glLineWidth(thickness); + for (FaceMeshConnections.Connection c : connections) { + NormalizedLandmark start = faceLandmarkList.get(c.start()); + NormalizedLandmark end = faceLandmarkList.get(c.end()); + float[] vertex = {start.getX(), start.getY(), end.getX(), end.getY()}; + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(vertex.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(vertex); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2); + } + } +} diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java new file mode 100644 index 000000000..3b2a1b7be --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java @@ -0,0 +1,175 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facemesh; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import android.util.Size; +import com.google.common.collect.ImmutableSet; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutions.facemesh.FaceMesh; +import com.google.mediapipe.solutions.facemesh.FaceMeshConnections; +import com.google.mediapipe.solutions.facemesh.FaceMeshResult; +import java.util.List; + +/** An ImageView implementation for displaying {@link FaceMeshResult}. 
*/ +public class FaceMeshResultImageView extends AppCompatImageView { + private static final String TAG = "FaceMeshResultImageView"; + + private static final int TESSELATION_COLOR = Color.parseColor("#70C0C0C0"); + private static final int TESSELATION_THICKNESS = 3; // Pixels + private static final int RIGHT_EYE_COLOR = Color.parseColor("#FF3030"); + private static final int RIGHT_EYE_THICKNESS = 5; // Pixels + private static final int RIGHT_EYEBROW_COLOR = Color.parseColor("#FF3030"); + private static final int RIGHT_EYEBROW_THICKNESS = 5; // Pixels + private static final int LEFT_EYE_COLOR = Color.parseColor("#30FF30"); + private static final int LEFT_EYE_THICKNESS = 5; // Pixels + private static final int LEFT_EYEBROW_COLOR = Color.parseColor("#30FF30"); + private static final int LEFT_EYEBROW_THICKNESS = 5; // Pixels + private static final int FACE_OVAL_COLOR = Color.parseColor("#E0E0E0"); + private static final int FACE_OVAL_THICKNESS = 5; // Pixels + private static final int LIPS_COLOR = Color.parseColor("#E0E0E0"); + private static final int LIPS_THICKNESS = 5; // Pixels + private Bitmap latest; + + public FaceMeshResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets a {@link FaceMeshResult} to render. + * + * @param result a {@link FaceMeshResult} object that contains the solution outputs and the input + * {@link Bitmap}. + */ + public void setFaceMeshResult(FaceMeshResult result) { + if (result == null) { + return; + } + Bitmap bmInput = result.inputBitmap(); + int width = bmInput.getWidth(); + int height = bmInput.getHeight(); + latest = Bitmap.createBitmap(width, height, bmInput.getConfig()); + Canvas canvas = new Canvas(latest); + Size imageSize = new Size(width, height); + canvas.drawBitmap(bmInput, new Matrix(), null); + int numFaces = result.multiFaceLandmarks().size(); + for (int i = 0; i < numFaces; ++i) { + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_TESSELATION, + imageSize, + TESSELATION_COLOR, + TESSELATION_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_EYE, + imageSize, + RIGHT_EYE_COLOR, + RIGHT_EYE_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_EYEBROW, + imageSize, + RIGHT_EYEBROW_COLOR, + RIGHT_EYEBROW_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_EYE, + imageSize, + LEFT_EYE_COLOR, + LEFT_EYE_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_EYEBROW, + imageSize, + LEFT_EYEBROW_COLOR, + LEFT_EYEBROW_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_FACE_OVAL, + imageSize, + FACE_OVAL_COLOR, + FACE_OVAL_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LIPS, + imageSize, + LIPS_COLOR, + LIPS_THICKNESS); + if (result.multiFaceLandmarks().get(i).getLandmarkCount() + == FaceMesh.FACEMESH_NUM_LANDMARKS_WITH_IRISES) { + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_RIGHT_IRIS, + imageSize, + RIGHT_EYE_COLOR, + 
RIGHT_EYE_THICKNESS); + drawLandmarksOnCanvas( + canvas, + result.multiFaceLandmarks().get(i).getLandmarkList(), + FaceMeshConnections.FACEMESH_LEFT_IRIS, + imageSize, + LEFT_EYE_COLOR, + LEFT_EYE_THICKNESS); + } + } + } + + /** Updates the image view with the latest {@link FaceMeshResult}. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + private void drawLandmarksOnCanvas( + Canvas canvas, + List faceLandmarkList, + ImmutableSet connections, + Size imageSize, + int color, + int thickness) { + // Draw connections. + for (FaceMeshConnections.Connection c : connections) { + Paint connectionPaint = new Paint(); + connectionPaint.setColor(color); + connectionPaint.setStrokeWidth(thickness); + NormalizedLandmark start = faceLandmarkList.get(c.start()); + NormalizedLandmark end = faceLandmarkList.get(c.end()); + canvas.drawLine( + start.getX() * imageSize.getWidth(), + start.getY() * imageSize.getHeight(), + end.getX() * imageSize.getWidth(), + end.getY() * imageSize.getHeight(), + connectionPaint); + } + } +} diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java new file mode 100644 index 000000000..93039870c --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java @@ -0,0 +1,359 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facemesh; + +import android.content.Intent; +import android.graphics.Bitmap; +import android.graphics.Matrix; +import android.os.Bundle; +import android.provider.MediaStore; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.FrameLayout; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutioncore.CameraInput; +import com.google.mediapipe.solutioncore.SolutionGlSurfaceView; +import com.google.mediapipe.solutioncore.VideoInput; +import com.google.mediapipe.solutions.facemesh.FaceMesh; +import com.google.mediapipe.solutions.facemesh.FaceMeshOptions; +import com.google.mediapipe.solutions.facemesh.FaceMeshResult; +import java.io.IOException; +import java.io.InputStream; + +/** Main activity of MediaPipe Face Mesh app. */ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private FaceMesh facemesh; + // Run the pipeline and the model inference on GPU or CPU. 
+ private static final boolean RUN_ON_GPU = true; + + private enum InputSource { + UNKNOWN, + IMAGE, + VIDEO, + CAMERA, + } + private InputSource inputSource = InputSource.UNKNOWN; + // Image demo UI and image loader components. + private ActivityResultLauncher imageGetter; + private FaceMeshResultImageView imageView; + // Video demo UI and video loader components. + private VideoInput videoInput; + private ActivityResultLauncher videoGetter; + // Live camera demo UI and camera components. + private CameraInput cameraInput; + + private SolutionGlSurfaceView glSurfaceView; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + // TODO: Add a toggle to switch between the original face mesh and attention mesh. + setupStaticImageDemoUiComponents(); + setupVideoDemoUiComponents(); + setupLiveDemoUiComponents(); + } + + @Override + protected void onResume() { + super.onResume(); + if (inputSource == InputSource.CAMERA) { + // Restarts the camera and the opengl surface rendering. + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> facemesh.send(textureFrame)); + glSurfaceView.post(this::startCamera); + glSurfaceView.setVisibility(View.VISIBLE); + } else if (inputSource == InputSource.VIDEO) { + videoInput.resume(); + } + } + + @Override + protected void onPause() { + super.onPause(); + if (inputSource == InputSource.CAMERA) { + glSurfaceView.setVisibility(View.GONE); + cameraInput.close(); + } else if (inputSource == InputSource.VIDEO) { + videoInput.pause(); + } + } + + private Bitmap downscaleBitmap(Bitmap originalBitmap) { + double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight(); + int width = imageView.getWidth(); + int height = imageView.getHeight(); + if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) { + width = (int) (height * aspectRatio); + } else { + height = (int) (width / aspectRatio); + } + return Bitmap.createScaledBitmap(originalBitmap, width, height, false); + } + + private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException { + int orientation = + new ExifInterface(imageData) + .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); + if (orientation == ExifInterface.ORIENTATION_NORMAL) { + return inputBitmap; + } + Matrix matrix = new Matrix(); + switch (orientation) { + case ExifInterface.ORIENTATION_ROTATE_90: + matrix.postRotate(90); + break; + case ExifInterface.ORIENTATION_ROTATE_180: + matrix.postRotate(180); + break; + case ExifInterface.ORIENTATION_ROTATE_270: + matrix.postRotate(270); + break; + default: + matrix.postRotate(0); + } + return Bitmap.createBitmap( + inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true); + } + + /** Sets up the UI components for the static image demo. */ + private void setupStaticImageDemoUiComponents() { + // The Intent to access gallery and read images as bitmap. 
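+ // The selected picture is downscaled to fit the view and rotated according to its EXIF orientation before it is sent to Face Mesh for inference.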
+ imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + downscaleBitmap( + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData())); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + bitmap = rotateBitmap(bitmap, imageData); + } catch (IOException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } + if (bitmap != null) { + facemesh.send(bitmap); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + stopCurrentPipeline(); + setupStaticImageModePipeline(); + } + // Reads images from gallery. + Intent pickImageIntent = new Intent(Intent.ACTION_PICK); + pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); + imageGetter.launch(pickImageIntent); + }); + imageView = new FaceMeshResultImageView(this); + } + + /** Sets up core workflow for static image mode. */ + private void setupStaticImageModePipeline() { + this.inputSource = InputSource.IMAGE; + // Initializes a new MediaPipe Face Mesh solution instance in the static image mode. + facemesh = + new FaceMesh( + this, + FaceMeshOptions.builder() + .setStaticImageMode(true) + .setRefineLandmarks(true) + .setRunOnGpu(RUN_ON_GPU) + .build()); + + // Connects MediaPipe Face Mesh solution to the user-defined FaceMeshResultImageView. + facemesh.setResultListener( + faceMeshResult -> { + logNoseLandmark(faceMeshResult, /*showPixelValues=*/ true); + imageView.setFaceMeshResult(faceMeshResult); + runOnUiThread(() -> imageView.update()); + }); + facemesh.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + /** Sets up the UI components for the video demo. */ + private void setupVideoDemoUiComponents() { + // The Intent to access gallery and read a video file. + videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + facemesh.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); + Button loadVideoButton = findViewById(R.id.button_load_video); + loadVideoButton.setOnClickListener( + v -> { + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.VIDEO); + // Reads video from gallery. + Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); + pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); + videoGetter.launch(pickVideoIntent); + }); + } + + /** Sets up the UI components for the live demo with camera input. 
*/ + private void setupLiveDemoUiComponents() { + Button startCameraButton = findViewById(R.id.button_start_camera); + startCameraButton.setOnClickListener( + v -> { + if (inputSource == InputSource.CAMERA) { + return; + } + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.CAMERA); + }); + } + + /** Sets up core workflow for streaming mode. */ + private void setupStreamingModePipeline(InputSource inputSource) { + this.inputSource = inputSource; + // Initializes a new MediaPipe Face Mesh solution instance in the streaming mode. + facemesh = + new FaceMesh( + this, + FaceMeshOptions.builder() + .setStaticImageMode(false) + .setRefineLandmarks(true) + .setRunOnGpu(RUN_ON_GPU) + .build()); + facemesh.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); + + if (inputSource == InputSource.CAMERA) { + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> facemesh.send(textureFrame)); + } else if (inputSource == InputSource.VIDEO) { + videoInput = new VideoInput(this); + videoInput.setNewFrameListener(textureFrame -> facemesh.send(textureFrame)); + } + + // Initializes a new Gl surface view with a user-defined FaceMeshResultGlRenderer. + glSurfaceView = + new SolutionGlSurfaceView<>(this, facemesh.getGlContext(), facemesh.getGlMajorVersion()); + glSurfaceView.setSolutionResultRenderer(new FaceMeshResultGlRenderer()); + glSurfaceView.setRenderInputImage(true); + facemesh.setResultListener( + faceMeshResult -> { + logNoseLandmark(faceMeshResult, /*showPixelValues=*/ false); + glSurfaceView.setRenderData(faceMeshResult); + glSurfaceView.requestRender(); + }); + + // The runnable to start camera after the gl surface view is attached. + // For video input source, videoInput.start() will be called when the video uri is available. + if (inputSource == InputSource.CAMERA) { + glSurfaceView.post(this::startCamera); + } + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + imageView.setVisibility(View.GONE); + frameLayout.removeAllViewsInLayout(); + frameLayout.addView(glSurfaceView); + glSurfaceView.setVisibility(View.VISIBLE); + frameLayout.requestLayout(); + } + + private void startCamera() { + cameraInput.start( + this, + facemesh.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight()); + } + + private void stopCurrentPipeline() { + if (cameraInput != null) { + cameraInput.setNewFrameListener(null); + cameraInput.close(); + } + if (videoInput != null) { + videoInput.setNewFrameListener(null); + videoInput.close(); + } + if (glSurfaceView != null) { + glSurfaceView.setVisibility(View.GONE); + } + if (facemesh != null) { + facemesh.close(); + } + } + + private void logNoseLandmark(FaceMeshResult result, boolean showPixelValues) { + if (result == null || result.multiFaceLandmarks().isEmpty()) { + return; + } + NormalizedLandmark noseLandmark = result.multiFaceLandmarks().get(0).getLandmarkList().get(1); + // For Bitmaps, show the pixel values. For texture inputs, show the normalized coordinates. 
+ if (showPixelValues) { + int width = result.inputBitmap().getWidth(); + int height = result.inputBitmap().getHeight(); + Log.i( + TAG, + String.format( + "MediaPipe Face Mesh nose coordinates (pixel values): x=%f, y=%f", + noseLandmark.getX() * width, noseLandmark.getY() * height)); + } else { + Log.i( + TAG, + String.format( + "MediaPipe Face Mesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + noseLandmark.getX(), noseLandmark.getY())); + } + } +} diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/res b/mediapipe/examples/android/solutions/facemesh/src/main/res new file mode 120000 index 000000000..fc8850136 --- /dev/null +++ b/mediapipe/examples/android/solutions/facemesh/src/main/res @@ -0,0 +1 @@ +../../../res \ No newline at end of file diff --git a/mediapipe/examples/android/solutions/gradle.properties b/mediapipe/examples/android/solutions/gradle.properties new file mode 100644 index 000000000..c09e1e3b0 --- /dev/null +++ b/mediapipe/examples/android/solutions/gradle.properties @@ -0,0 +1,17 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app"s APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 000000000..e708b1c02 Binary files /dev/null and b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar differ diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 000000000..442d9132e --- /dev/null +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/mediapipe/examples/android/solutions/gradlew b/mediapipe/examples/android/solutions/gradlew new file mode 100755 index 000000000..4f906e0c8 --- /dev/null +++ b/mediapipe/examples/android/solutions/gradlew @@ -0,0 +1,185 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/mediapipe/examples/android/solutions/gradlew.bat b/mediapipe/examples/android/solutions/gradlew.bat new file mode 100755 index 000000000..ac1b06f93 --- /dev/null +++ b/mediapipe/examples/android/solutions/gradlew.bat @@ -0,0 +1,89 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. 
+@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/mediapipe/examples/android/solutions/hands/build.gradle b/mediapipe/examples/android/solutions/hands/build.gradle new file mode 100644 index 000000000..6c21109d5 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/build.gradle @@ -0,0 +1,41 @@ +plugins { + id 'com.android.application' +} + +android { + compileSdkVersion 30 + buildToolsVersion "30.0.3" + + defaultConfig { + applicationId "com.google.mediapipe.apps.hands" + minSdkVersion 21 + targetSdkVersion 30 + versionCode 1 + versionName "1.0" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar']) + implementation 'androidx.appcompat:appcompat:1.3.0' + implementation 'com.google.android.material:material:1.3.0' + implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + implementation 'androidx.exifinterface:exifinterface:1.3.3' + testImplementation 'junit:junit:4.+' + androidTestImplementation 'androidx.test.ext:junit:1.1.2' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' + // MediaPipe Hands Solution. 
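+ // solution-core supplies the shared CameraInput, VideoInput, and SolutionGlSurfaceView utilities that this example imports.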
+ implementation 'com.google.mediapipe:solution-core:latest.release' + implementation 'com.google.mediapipe:hands:latest.release' +} diff --git a/mediapipe/examples/android/solutions/hands/proguard-rules.pro b/mediapipe/examples/android/solutions/hands/proguard-rules.pro new file mode 100644 index 000000000..f1b424510 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/mediapipe/examples/android/solutions/hands/src/main/AndroidManifest.xml b/mediapipe/examples/android/solutions/hands/src/main/AndroidManifest.xml new file mode 100644 index 000000000..2344ff9a8 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/AndroidManifest.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/solutions/hands/src/main/BUILD b/mediapipe/examples/android/solutions/hands/src/main/BUILD new file mode 100644 index 000000000..d3c304b57 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/BUILD @@ -0,0 +1,45 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +android_binary( + name = "hands", + srcs = glob(["**/*.java"]), + custom_package = "com.google.mediapipe.examples.hands", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.examples.hands", + }, + multidex = "native", + resource_files = ["//mediapipe/examples/android/solutions:resource_files"], + deps = [ + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input", + "//mediapipe/java/com/google/mediapipe/solutioncore:mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering", + "//mediapipe/java/com/google/mediapipe/solutioncore:video_input", + "//mediapipe/java/com/google/mediapipe/solutions/hands", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", + "@maven//:androidx_fragment_fragment", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java new file mode 100644 index 000000000..d136cc22b --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java @@ -0,0 +1,181 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.hands; + +import android.opengl.GLES20; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutioncore.ResultGlRenderer; +import com.google.mediapipe.solutions.hands.Hands; +import com.google.mediapipe.solutions.hands.HandsResult; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.List; + +/** A custom implementation of {@link ResultGlRenderer} to render {@link HandsResult}. 
*/ +public class HandsResultGlRenderer implements ResultGlRenderer { + private static final String TAG = "HandsResultGlRenderer"; + + private static final float[] LEFT_HAND_CONNECTION_COLOR = new float[] {0.2f, 1f, 0.2f, 1f}; + private static final float[] RIGHT_HAND_CONNECTION_COLOR = new float[] {1f, 0.2f, 0.2f, 1f}; + private static final float CONNECTION_THICKNESS = 25.0f; + private static final float[] LEFT_HAND_HOLLOW_CIRCLE_COLOR = new float[] {0.2f, 1f, 0.2f, 1f}; + private static final float[] RIGHT_HAND_HOLLOW_CIRCLE_COLOR = new float[] {1f, 0.2f, 0.2f, 1f}; + private static final float HOLLOW_CIRCLE_RADIUS = 0.01f; + private static final float[] LEFT_HAND_LANDMARK_COLOR = new float[] {1f, 0.2f, 0.2f, 1f}; + private static final float[] RIGHT_HAND_LANDMARK_COLOR = new float[] {0.2f, 1f, 0.2f, 1f}; + private static final float LANDMARK_RADIUS = 0.008f; + private static final int NUM_SEGMENTS = 120; + private static final String VERTEX_SHADER = + "uniform mat4 uProjectionMatrix;\n" + + "attribute vec4 vPosition;\n" + + "void main() {\n" + + " gl_Position = uProjectionMatrix * vPosition;\n" + + "}"; + private static final String FRAGMENT_SHADER = + "precision mediump float;\n" + + "uniform vec4 uColor;\n" + + "void main() {\n" + + " gl_FragColor = uColor;\n" + + "}"; + private int program; + private int positionHandle; + private int projectionMatrixHandle; + private int colorHandle; + + private int loadShader(int type, String shaderCode) { + int shader = GLES20.glCreateShader(type); + GLES20.glShaderSource(shader, shaderCode); + GLES20.glCompileShader(shader); + return shader; + } + + @Override + public void setupRendering() { + program = GLES20.glCreateProgram(); + int vertexShader = loadShader(GLES20.GL_VERTEX_SHADER, VERTEX_SHADER); + int fragmentShader = loadShader(GLES20.GL_FRAGMENT_SHADER, FRAGMENT_SHADER); + GLES20.glAttachShader(program, vertexShader); + GLES20.glAttachShader(program, fragmentShader); + GLES20.glLinkProgram(program); + positionHandle = GLES20.glGetAttribLocation(program, "vPosition"); + projectionMatrixHandle = GLES20.glGetUniformLocation(program, "uProjectionMatrix"); + colorHandle = GLES20.glGetUniformLocation(program, "uColor"); + } + + @Override + public void renderResult(HandsResult result, float[] projectionMatrix) { + if (result == null) { + return; + } + GLES20.glUseProgram(program); + GLES20.glUniformMatrix4fv(projectionMatrixHandle, 1, false, projectionMatrix, 0); + GLES20.glLineWidth(CONNECTION_THICKNESS); + + int numHands = result.multiHandLandmarks().size(); + for (int i = 0; i < numHands; ++i) { + boolean isLeftHand = result.multiHandedness().get(i).getLabel().equals("Left"); + drawConnections( + result.multiHandLandmarks().get(i).getLandmarkList(), + isLeftHand ? LEFT_HAND_CONNECTION_COLOR : RIGHT_HAND_CONNECTION_COLOR); + for (NormalizedLandmark landmark : result.multiHandLandmarks().get(i).getLandmarkList()) { + // Draws the landmark. + drawCircle( + landmark.getX(), + landmark.getY(), + isLeftHand ? LEFT_HAND_LANDMARK_COLOR : RIGHT_HAND_LANDMARK_COLOR); + // Draws a hollow circle around the landmark. + drawHollowCircle( + landmark.getX(), + landmark.getY(), + isLeftHand ? LEFT_HAND_HOLLOW_CIRCLE_COLOR : RIGHT_HAND_HOLLOW_CIRCLE_COLOR); + } + } + } + + /** + * Deletes the shader program. + * + *
<p>
This is only necessary if one wants to release the program while keeping the context around. + */ + public void release() { + GLES20.glDeleteProgram(program); + } + + private void drawConnections(List handLandmarkList, float[] colorArray) { + GLES20.glUniform4fv(colorHandle, 1, colorArray, 0); + for (Hands.Connection c : Hands.HAND_CONNECTIONS) { + NormalizedLandmark start = handLandmarkList.get(c.start()); + NormalizedLandmark end = handLandmarkList.get(c.end()); + float[] vertex = {start.getX(), start.getY(), end.getX(), end.getY()}; + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(vertex.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(vertex); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2); + } + } + + private void drawCircle(float x, float y, float[] colorArray) { + GLES20.glUniform4fv(colorHandle, 1, colorArray, 0); + int vertexCount = NUM_SEGMENTS + 2; + float[] vertices = new float[vertexCount * 3]; + vertices[0] = x; + vertices[1] = y; + vertices[2] = 0; + for (int i = 1; i < vertexCount; i++) { + float angle = 2.0f * i * (float) Math.PI / NUM_SEGMENTS; + int currentIndex = 3 * i; + vertices[currentIndex] = x + (float) (LANDMARK_RADIUS * Math.cos(angle)); + vertices[currentIndex + 1] = y + (float) (LANDMARK_RADIUS * Math.sin(angle)); + vertices[currentIndex + 2] = 0; + } + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(vertices.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(vertices); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 3, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_TRIANGLE_FAN, 0, vertexCount); + } + + private void drawHollowCircle(float x, float y, float[] colorArray) { + GLES20.glUniform4fv(colorHandle, 1, colorArray, 0); + int vertexCount = NUM_SEGMENTS + 1; + float[] vertices = new float[vertexCount * 3]; + for (int i = 0; i < vertexCount; i++) { + float angle = 2.0f * i * (float) Math.PI / NUM_SEGMENTS; + int currentIndex = 3 * i; + vertices[currentIndex] = x + (float) (HOLLOW_CIRCLE_RADIUS * Math.cos(angle)); + vertices[currentIndex + 1] = y + (float) (HOLLOW_CIRCLE_RADIUS * Math.sin(angle)); + vertices[currentIndex + 2] = 0; + } + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(vertices.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(vertices); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 3, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_LINE_STRIP, 0, vertexCount); + } +} diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java new file mode 100644 index 000000000..91b508ee6 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java @@ -0,0 +1,127 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.hands; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import com.google.mediapipe.formats.proto.LandmarkProto; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutions.hands.Hands; +import com.google.mediapipe.solutions.hands.HandsResult; +import java.util.List; + +/** An ImageView implementation for displaying {@link HandsResult}. */ +public class HandsResultImageView extends AppCompatImageView { + private static final String TAG = "HandsResultImageView"; + + private static final int LEFT_HAND_CONNECTION_COLOR = Color.parseColor("#30FF30"); + private static final int RIGHT_HAND_CONNECTION_COLOR = Color.parseColor("#FF3030"); + private static final int CONNECTION_THICKNESS = 8; // Pixels + private static final int LEFT_HAND_HOLLOW_CIRCLE_COLOR = Color.parseColor("#30FF30"); + private static final int RIGHT_HAND_HOLLOW_CIRCLE_COLOR = Color.parseColor("#FF3030"); + private static final int HOLLOW_CIRCLE_WIDTH = 5; // Pixels + private static final int LEFT_HAND_LANDMARK_COLOR = Color.parseColor("#FF3030"); + private static final int RIGHT_HAND_LANDMARK_COLOR = Color.parseColor("#30FF30"); + private static final int LANDMARK_RADIUS = 10; // Pixels + private Bitmap latest; + + public HandsResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets a {@link HandsResult} to render. + * + * @param result a {@link HandsResult} object that contains the solution outputs and the input + * {@link Bitmap}. + */ + public void setHandsResult(HandsResult result) { + if (result == null) { + return; + } + Bitmap bmInput = result.inputBitmap(); + int width = bmInput.getWidth(); + int height = bmInput.getHeight(); + latest = Bitmap.createBitmap(width, height, bmInput.getConfig()); + Canvas canvas = new Canvas(latest); + + canvas.drawBitmap(bmInput, new Matrix(), null); + int numHands = result.multiHandLandmarks().size(); + for (int i = 0; i < numHands; ++i) { + drawLandmarksOnCanvas( + result.multiHandLandmarks().get(i).getLandmarkList(), + result.multiHandedness().get(i).getLabel().equals("Left"), + canvas, + width, + height); + } + } + + /** Updates the image view with the latest {@link HandsResult}. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + private void drawLandmarksOnCanvas( + List handLandmarkList, + boolean isLeftHand, + Canvas canvas, + int width, + int height) { + // Draw connections. + for (Hands.Connection c : Hands.HAND_CONNECTIONS) { + Paint connectionPaint = new Paint(); + connectionPaint.setColor( + isLeftHand ? 
LEFT_HAND_CONNECTION_COLOR : RIGHT_HAND_CONNECTION_COLOR); + connectionPaint.setStrokeWidth(CONNECTION_THICKNESS); + NormalizedLandmark start = handLandmarkList.get(c.start()); + NormalizedLandmark end = handLandmarkList.get(c.end()); + canvas.drawLine( + start.getX() * width, + start.getY() * height, + end.getX() * width, + end.getY() * height, + connectionPaint); + } + Paint landmarkPaint = new Paint(); + landmarkPaint.setColor(isLeftHand ? LEFT_HAND_LANDMARK_COLOR : RIGHT_HAND_LANDMARK_COLOR); + // Draws landmarks. + for (LandmarkProto.NormalizedLandmark landmark : handLandmarkList) { + canvas.drawCircle( + landmark.getX() * width, landmark.getY() * height, LANDMARK_RADIUS, landmarkPaint); + } + // Draws hollow circles around landmarks. + landmarkPaint.setColor( + isLeftHand ? LEFT_HAND_HOLLOW_CIRCLE_COLOR : RIGHT_HAND_HOLLOW_CIRCLE_COLOR); + landmarkPaint.setStrokeWidth(HOLLOW_CIRCLE_WIDTH); + landmarkPaint.setStyle(Paint.Style.STROKE); + for (LandmarkProto.NormalizedLandmark landmark : handLandmarkList) { + canvas.drawCircle( + landmark.getX() * width, + landmark.getY() * height, + LANDMARK_RADIUS + HOLLOW_CIRCLE_WIDTH, + landmarkPaint); + } + } +} diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java new file mode 100644 index 000000000..d93f9b1e3 --- /dev/null +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java @@ -0,0 +1,373 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.hands; + +import android.content.Intent; +import android.graphics.Bitmap; +import android.graphics.Matrix; +import android.os.Bundle; +import android.provider.MediaStore; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.FrameLayout; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency +import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.solutioncore.CameraInput; +import com.google.mediapipe.solutioncore.SolutionGlSurfaceView; +import com.google.mediapipe.solutioncore.VideoInput; +import com.google.mediapipe.solutions.hands.HandLandmark; +import com.google.mediapipe.solutions.hands.Hands; +import com.google.mediapipe.solutions.hands.HandsOptions; +import com.google.mediapipe.solutions.hands.HandsResult; +import java.io.IOException; +import java.io.InputStream; + +/** Main activity of MediaPipe Hands app. 
*/
+public class MainActivity extends AppCompatActivity {
+  private static final String TAG = "MainActivity";
+
+  private Hands hands;
+  // Run the pipeline and the model inference on GPU or CPU.
+  private static final boolean RUN_ON_GPU = true;
+
+  private enum InputSource {
+    UNKNOWN,
+    IMAGE,
+    VIDEO,
+    CAMERA,
+  }
+  private InputSource inputSource = InputSource.UNKNOWN;
+
+  // Image demo UI and image loader components.
+  private ActivityResultLauncher<Intent> imageGetter;
+  private HandsResultImageView imageView;
+  // Video demo UI and video loader components.
+  private VideoInput videoInput;
+  private ActivityResultLauncher<Intent> videoGetter;
+  // Live camera demo UI and camera components.
+  private CameraInput cameraInput;
+
+  private SolutionGlSurfaceView<HandsResult> glSurfaceView;
+
+  @Override
+  protected void onCreate(Bundle savedInstanceState) {
+    super.onCreate(savedInstanceState);
+    setContentView(R.layout.activity_main);
+    setupStaticImageDemoUiComponents();
+    setupVideoDemoUiComponents();
+    setupLiveDemoUiComponents();
+  }
+
+  @Override
+  protected void onResume() {
+    super.onResume();
+    if (inputSource == InputSource.CAMERA) {
+      // Restarts the camera and the opengl surface rendering.
+      cameraInput = new CameraInput(this);
+      cameraInput.setNewFrameListener(textureFrame -> hands.send(textureFrame));
+      glSurfaceView.post(this::startCamera);
+      glSurfaceView.setVisibility(View.VISIBLE);
+    } else if (inputSource == InputSource.VIDEO) {
+      videoInput.resume();
+    }
+  }
+
+  @Override
+  protected void onPause() {
+    super.onPause();
+    if (inputSource == InputSource.CAMERA) {
+      glSurfaceView.setVisibility(View.GONE);
+      cameraInput.close();
+    } else if (inputSource == InputSource.VIDEO) {
+      videoInput.pause();
+    }
+  }
+
+  // Scales the picked image down to the ImageView size while preserving its aspect ratio.
+  private Bitmap downscaleBitmap(Bitmap originalBitmap) {
+    double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
+    int width = imageView.getWidth();
+    int height = imageView.getHeight();
+    if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
+      width = (int) (height * aspectRatio);
+    } else {
+      height = (int) (width / aspectRatio);
+    }
+    return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
+  }
+
+  // Rotates the bitmap according to its EXIF orientation so the landmarks line up with the
+  // image as displayed.
+  private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
+    int orientation =
+        new ExifInterface(imageData)
+            .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
+    if (orientation == ExifInterface.ORIENTATION_NORMAL) {
+      return inputBitmap;
+    }
+    Matrix matrix = new Matrix();
+    switch (orientation) {
+      case ExifInterface.ORIENTATION_ROTATE_90:
+        matrix.postRotate(90);
+        break;
+      case ExifInterface.ORIENTATION_ROTATE_180:
+        matrix.postRotate(180);
+        break;
+      case ExifInterface.ORIENTATION_ROTATE_270:
+        matrix.postRotate(270);
+        break;
+      default:
+        matrix.postRotate(0);
+    }
+    return Bitmap.createBitmap(
+        inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
+  }
+
+  /** Sets up the UI components for the static image demo. */
+  private void setupStaticImageDemoUiComponents() {
+    // The Intent to access gallery and read images as bitmap.
+ imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + downscaleBitmap( + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData())); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + bitmap = rotateBitmap(bitmap, imageData); + } catch (IOException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } + if (bitmap != null) { + hands.send(bitmap); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + stopCurrentPipeline(); + setupStaticImageModePipeline(); + } + // Reads images from gallery. + Intent pickImageIntent = new Intent(Intent.ACTION_PICK); + pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); + imageGetter.launch(pickImageIntent); + }); + imageView = new HandsResultImageView(this); + } + + /** Sets up core workflow for static image mode. */ + private void setupStaticImageModePipeline() { + this.inputSource = InputSource.IMAGE; + // Initializes a new MediaPipe Hands solution instance in the static image mode. + hands = + new Hands( + this, + HandsOptions.builder() + .setStaticImageMode(true) + .setMaxNumHands(2) + .setRunOnGpu(RUN_ON_GPU) + .build()); + + // Connects MediaPipe Hands solution to the user-defined HandsResultImageView. + hands.setResultListener( + handsResult -> { + logWristLandmark(handsResult, /*showPixelValues=*/ true); + imageView.setHandsResult(handsResult); + runOnUiThread(() -> imageView.update()); + }); + hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + /** Sets up the UI components for the video demo. */ + private void setupVideoDemoUiComponents() { + // The Intent to access gallery and read a video file. + videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + hands.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); + Button loadVideoButton = findViewById(R.id.button_load_video); + loadVideoButton.setOnClickListener( + v -> { + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.VIDEO); + // Reads video from gallery. + Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); + pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); + videoGetter.launch(pickVideoIntent); + }); + } + + /** Sets up the UI components for the live demo with camera input. 
*/ + private void setupLiveDemoUiComponents() { + Button startCameraButton = findViewById(R.id.button_start_camera); + startCameraButton.setOnClickListener( + v -> { + if (inputSource == InputSource.CAMERA) { + return; + } + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.CAMERA); + }); + } + + /** Sets up core workflow for streaming mode. */ + private void setupStreamingModePipeline(InputSource inputSource) { + this.inputSource = inputSource; + // Initializes a new MediaPipe Hands solution instance in the streaming mode. + hands = + new Hands( + this, + HandsOptions.builder() + .setStaticImageMode(false) + .setMaxNumHands(2) + .setRunOnGpu(RUN_ON_GPU) + .build()); + hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); + + if (inputSource == InputSource.CAMERA) { + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> hands.send(textureFrame)); + } else if (inputSource == InputSource.VIDEO) { + videoInput = new VideoInput(this); + videoInput.setNewFrameListener(textureFrame -> hands.send(textureFrame)); + } + + // Initializes a new Gl surface view with a user-defined HandsResultGlRenderer. + glSurfaceView = + new SolutionGlSurfaceView<>(this, hands.getGlContext(), hands.getGlMajorVersion()); + glSurfaceView.setSolutionResultRenderer(new HandsResultGlRenderer()); + glSurfaceView.setRenderInputImage(true); + hands.setResultListener( + handsResult -> { + logWristLandmark(handsResult, /*showPixelValues=*/ false); + glSurfaceView.setRenderData(handsResult); + glSurfaceView.requestRender(); + }); + + // The runnable to start camera after the gl surface view is attached. + // For video input source, videoInput.start() will be called when the video uri is available. + if (inputSource == InputSource.CAMERA) { + glSurfaceView.post(this::startCamera); + } + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + imageView.setVisibility(View.GONE); + frameLayout.removeAllViewsInLayout(); + frameLayout.addView(glSurfaceView); + glSurfaceView.setVisibility(View.VISIBLE); + frameLayout.requestLayout(); + } + + private void startCamera() { + cameraInput.start( + this, + hands.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight()); + } + + private void stopCurrentPipeline() { + if (cameraInput != null) { + cameraInput.setNewFrameListener(null); + cameraInput.close(); + } + if (videoInput != null) { + videoInput.setNewFrameListener(null); + videoInput.close(); + } + if (glSurfaceView != null) { + glSurfaceView.setVisibility(View.GONE); + } + if (hands != null) { + hands.close(); + } + } + + private void logWristLandmark(HandsResult result, boolean showPixelValues) { + if (result.multiHandLandmarks().isEmpty()) { + return; + } + NormalizedLandmark wristLandmark = + result.multiHandLandmarks().get(0).getLandmarkList().get(HandLandmark.WRIST); + // For Bitmaps, show the pixel values. For texture inputs, show the normalized coordinates. 
+    if (showPixelValues) {
+      int width = result.inputBitmap().getWidth();
+      int height = result.inputBitmap().getHeight();
+      Log.i(
+          TAG,
+          String.format(
+              "MediaPipe Hand wrist coordinates (pixel values): x=%f, y=%f",
+              wristLandmark.getX() * width, wristLandmark.getY() * height));
+    } else {
+      Log.i(
+          TAG,
+          String.format(
+              "MediaPipe Hand wrist normalized coordinates (value range: [0, 1]): x=%f, y=%f",
+              wristLandmark.getX(), wristLandmark.getY()));
+    }
+    if (result.multiHandWorldLandmarks().isEmpty()) {
+      return;
+    }
+    Landmark wristWorldLandmark =
+        result.multiHandWorldLandmarks().get(0).getLandmarkList().get(HandLandmark.WRIST);
+    Log.i(
+        TAG,
+        String.format(
+            "MediaPipe Hand wrist world coordinates (in meters with the origin at the hand's"
+                + " approximate geometric center): x=%f m, y=%f m, z=%f m",
+            wristWorldLandmark.getX(), wristWorldLandmark.getY(), wristWorldLandmark.getZ()));
+  }
+}
diff --git a/mediapipe/examples/android/solutions/hands/src/main/res b/mediapipe/examples/android/solutions/hands/src/main/res
new file mode 120000
index 000000000..fc8850136
--- /dev/null
+++ b/mediapipe/examples/android/solutions/hands/src/main/res
@@ -0,0 +1 @@
+../../../res
\ No newline at end of file
diff --git a/mediapipe/examples/android/solutions/res/drawable-v24/ic_launcher_foreground.xml b/mediapipe/examples/android/solutions/res/drawable-v24/ic_launcher_foreground.xml
new file mode 100644
index 000000000..c7bd21dbd
--- /dev/null
+++ b/mediapipe/examples/android/solutions/res/drawable-v24/ic_launcher_foreground.xml
@@ -0,0 +1,34 @@
+  [34 lines of vector drawable markup; XML stripped in this view]
diff --git a/mediapipe/examples/android/solutions/res/drawable/ic_launcher_background.xml b/mediapipe/examples/android/solutions/res/drawable/ic_launcher_background.xml
new file mode 100644
index 000000000..01f0af0ad
--- /dev/null
+++ b/mediapipe/examples/android/solutions/res/drawable/ic_launcher_background.xml
@@ -0,0 +1,74 @@
+  [74 lines of vector drawable markup; XML stripped in this view]
diff --git a/mediapipe/examples/android/solutions/res/layout/activity_main.xml b/mediapipe/examples/android/solutions/res/layout/activity_main.xml
new file mode 100644
index 000000000..834e9a3e6
--- /dev/null
+++ b/mediapipe/examples/android/solutions/res/layout/activity_main.xml
@@ -0,0 +1,40 @@
+  [40 lines of layout markup; XML stripped in this view]
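Note: the three resource files above lost their XML bodies in this text-only rendering of the patch, so the shipped markup cannot be reproduced here. Purely as a hypothetical sketch, a minimal activity_main.xml consistent with the view lookups in MainActivity.java (R.id.button_load_picture, R.id.button_load_video, R.id.button_start_camera, R.id.preview_display_layout) could look like the following; the actual 40-line file in the PR may differ in structure, styling, and button labels.

```xml
<?xml version="1.0" encoding="utf-8"?>
<!-- Hypothetical reconstruction, not the original activity_main.xml from this PR. -->
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical">

  <LinearLayout
      android:layout_width="match_parent"
      android:layout_height="wrap_content"
      android:orientation="horizontal">

    <!-- Buttons looked up in setupStaticImageDemoUiComponents(),
         setupVideoDemoUiComponents(), and setupLiveDemoUiComponents().
         Text labels here are placeholders. -->
    <Button
        android:id="@+id/button_load_picture"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Load Picture" />

    <Button
        android:id="@+id/button_load_video"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Load Video" />

    <Button
        android:id="@+id/button_start_camera"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Start Camera" />
  </LinearLayout>

  <!-- Preview container: MainActivity adds either the HandsResultImageView or the
       SolutionGlSurfaceView to this layout at runtime. -->
  <FrameLayout
      android:id="@+id/preview_display_layout"
      android:layout_width="match_parent"
      android:layout_height="match_parent" />
</LinearLayout>
```

The FrameLayout is deliberately empty because MainActivity swaps the image view (static image mode) or the GL surface view (video and camera modes) into it when a pipeline is set up.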