diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3703a7014..d7a1e1877 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,15 +1,17 @@ # Contributing guidelines -## Pull Request Checklist +## What type of pull request do we accept into MediaPipe repository? + +* Bug fixes +* Documentation fixes + +For new feature additions (e.g., new graphs and calculators), we are currently not planning to accept new feature pull requests into the MediaPipe repository. Instead, we like to get contributors to create their own repositories of the new feature and list it at [Awesome MediaPipe](https://mediapipe.org). This will allow contributors to more quickly get their code out to the community. Before sending your pull requests, make sure you followed this list. - Read [contributing guidelines](CONTRIBUTING.md). - Read [Code of Conduct](CODE_OF_CONDUCT.md). - Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). -- Check if my changes are consistent with the [guidelines](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). -- Changes are consistent with the [Coding Style](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#c-coding-style). -- Run [Unit Tests](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#running-unit-tests). ## How to become a contributor and submit your own code @@ -28,100 +30,7 @@ Follow either of the two links above to access the appropriate CLA and instructi ### Contributing code -If you have improvements to MediaPipe, send us your pull requests! For those +If you have bug fixes and documentation fixes to MediaPipe, send us your pull requests! For those just getting started, GitHub has a [howto](https://help.github.com/articles/using-pull-requests/). -MediaPipe team members will be assigned to review your pull requests. Once the -pull requests are approved and pass continuous integration checks, a MediaPipe -team member will apply `ready to pull` label to your change. This means we are -working on getting your pull request submitted to our internal repository. After -the change has been submitted internally, your pull request will be merged -automatically on GitHub. - -If you want to contribute but you're not sure where to start, take a look at the -[issues with the "contributions welcome" label](https://github.com/google/mediapipe/labels/stat%3Acontributions%20welcome). -These are issues that we believe are particularly well suited for outside -contributions, often because we probably won't get to them right now. If you -decide to start on an issue, leave a comment so that other people know that -you're working on it. If you want to help out, but not alone, use the issue -comment thread to coordinate. - -### Contribution guidelines and standards - -Before sending your pull request for -[review](https://github.com/google/mediapipe/pulls), -make sure your changes are consistent with the guidelines and follow the -MediaPipe coding style. - -#### General guidelines and philosophy for contribution - -* Include unit tests when you contribute new features, as they help to a) - prove that your code works correctly, and b) guard against future breaking - changes to lower the maintenance cost. -* Bug fixes also generally require unit tests, because the presence of bugs - usually indicates insufficient test coverage. -* Keep API compatibility in mind when you change code in MediaPipe framework - e.g., code in - [mediapipe/framework](https://github.com/google/mediapipe/tree/master/mediapipe/framework) - and - [mediapipe/calculators](https://github.com/google/mediapipe/tree/master/mediapipe/calculators). - Once MediaPipe has reached version 1 and we will not make - non-backward-compatible API changes without a major release. Reviewers of - your pull request will comment on any API compatibility issues. -* When you contribute a new feature to MediaPipe, the maintenance burden is - (by default) transferred to the MediaPipe team. This means that benefit of - the contribution must be compared against the cost of maintaining the - feature. -* Full new features (e.g., a new op implementing a cutting-edge algorithm) - typically will live in - [mediapipe/addons](https://github.com/google/mediapipe/addons) to get some - airtime before decision is made regarding whether they are to be migrated to - the core. - -#### License - -Include a license at the top of new files. - -* [C/C++ license example](https://github.com/google/mediapipe/blob/master/mediapipe/framework/calculator_base.cc#L1) -* [Java license example](https://github.com/google/mediapipe/blob/master/mediapipe/java/com/google/mediapipe/components/CameraHelper.java) - -Bazel BUILD files also need to include a license section, e.g., -[BUILD example](https://github.com/google/mediapipe/blob/master/mediapipe/framework/BUILD#L61). - -#### C++ coding style - -Changes to MediaPipe C++ code should conform to -[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). - -Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: - -```bash -apt-get install -y clang-tidy -``` - -You can check a C/C++ file by doing: - - -```bash -clang-format --style=google > /tmp/my_cc_file.cc -diff /tmp/my_cc_file.cc -``` - -#### Coding style for other languages - -* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) -* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) -* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) -* [Google Objective-C Style Guide](https://google.github.io/styleguide/objcguide.html) - -#### Running sanity check - -If you have Docker installed on your system, you can perform a sanity check on -your changes by running the command: - -```bash -mediapipe/tools/ci_build/ci_build.sh CPU mediapipe/tools/ci_build/ci_sanity.sh -``` - -This will catch most license, Python coding style and BUILD file issues that -may exist in your changes. +MediaPipe team members will be assigned to review your pull requests. Once the bug/documentation fixes are verified, a MediaPipe team member will acknowledge your contribution in the pull request comments, manually merge the fixes into our internal codebase upstream, and apply the `to be closed` label to the pull request. These fixes will later be pushed to GitHub in the next release, and a MediaPipe team member will then close the pull request. diff --git a/README.md b/README.md index 66323f988..0a96c42c8 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ run code search using * [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome MediaPipe related frameworks, libraries and software -* [Slack community](https://https://mediapipe.page.link/joinslack) for MediaPipe users +* [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe diff --git a/docs/data/visualizer/sample_trace.binarypb b/docs/data/visualizer/sample_trace.binarypb index fce47d8e5..fc2562934 100644 Binary files a/docs/data/visualizer/sample_trace.binarypb and b/docs/data/visualizer/sample_trace.binarypb differ diff --git a/docs/framework_concepts/calculators.md b/docs/framework_concepts/calculators.md index 66aefb7b1..116f17d3b 100644 --- a/docs/framework_concepts/calculators.md +++ b/docs/framework_concepts/calculators.md @@ -207,8 +207,8 @@ class SomeAudioVideoCalculator : public CalculatorBase { // particular type. SetAny() has the same effect as explicitly // setting the type to be the stream's type. cc->Outputs().Tag("VIDEO").Set(); - cc->Outputs().Get("AUDIO", 0).Set; - cc->Outputs().Get("AUDIO", 1).Set; + cc->Outputs().Get("AUDIO", 0).Set(); + cc->Outputs().Get("AUDIO", 1).Set(); return ::mediapipe::OkStatus(); } ``` @@ -400,13 +400,13 @@ node { ``` The diagram below shows how the `PacketClonerCalculator` defines its output -packets based on its series of input packets. +packets (bottom) based on its series of input packets (top). | ![Graph using | : PacketClonerCalculator](../images/packet_cloner_calculator.png) : | :--------------------------------------------------------------------------: | | *Each time it receives a packet on its TICK input stream, the | : PacketClonerCalculator outputs the most recent packet from each of its input : -: streams. The sequence of output packets is determined by the sequene of : -: input packets and their timestamps. The timestamps are shows along the right : -: side of the diagram.* : +: streams. The sequence of output packets (bottom) is determined by the : +: sequence of input packets (top) and their timestamps. The timestamps are : +: shown along the right side of the diagram.* : diff --git a/docs/getting_started/building_examples.md b/docs/getting_started/building_examples.md index 2c3b6e77c..089a1fefe 100644 --- a/docs/getting_started/building_examples.md +++ b/docs/getting_started/building_examples.md @@ -184,8 +184,8 @@ app: ### Prerequisite -1. Install [Xcode](https://developer.apple.com/xcode/), and additionally - install the Command Line Tools by: +1. Install [Xcode](https://developer.apple.com/xcode/), then install the + Command Line Tools using: ```bash xcode-select --install @@ -196,74 +196,38 @@ app: We recommend using [Homebrew](https://brew.sh/) to get the latest version. 3. Set Python 3.7 as the default Python version and install the Python "six" - library. - - To make Mediapipe work with TensorFlow, please set Python 3.7 as the default - Python version and install the Python "six" library. + library. This is needed for TensorFlow. ```bash pip3 install --user six ``` -4. Follow - [Apple's instructions](https://developer.apple.com/support/certificates/) to - obtain the required development certificates and provisioning profiles for - your iOS device. - - Tip: You can the following command to see the provisioning profiles you have - previously downloaded using Xcode: `open - ~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate - and download a profile on - [Apple's developer site](https://developer.apple.com/account/resources/). - -5. Clone the MediaPipe repository. +4. Clone the MediaPipe repository. ```bash git clone https://github.com/google/mediapipe.git ``` -6. In the cloned MediaPipe repository, symlink or copy your provisioning profile - to `mediapipe/provisioning_profile.mobileprovision`, e.g., +### Set up a bundle ID prefix - ```bash - cd mediapipe - ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision - ``` +All iOS apps must have a bundle ID, and you must have a provisioning profile +that lets you install an app with that ID onto your phone. To avoid clashes +between different MediaPipe users, you need to configure a unique prefix for the +bundle IDs of our iOS demo apps. -### Option 1: Build with Bazel in Command Line +If you have a custom provisioning profile, see +[Custom provisioning](#custom-provisioning) below. -1. Modify the `bundle_id` field of the app's `ios_application` build target to - use your own identifier. For instance, for - [MediaPipe Hands](../solutions/hands.md), the `bundle_id` is in the - `HandTrackingGpuApp` target in the - [BUILD](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/handtrackinggpu/BUILD) - file. +Otherwise, run this command to generate a unique prefix: -2. Again using [MediaPipe Hands](../solutions/hands.md) for example, run: +```bash +python3 mediapipe/examples/ios/link_local_profiles.py +``` - ```bash - bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp - ``` +### Create an Xcode project - You may see a permission request from `codesign` in order to sign the app. - - Tip: You can run this - [script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh) - to build all MediaPipe iOS example apps. - -3. In Xcode, open the `Devices and Simulators` window (command-shift-2). - -4. Make sure your device is connected. You will see a list of installed apps. - Press the "+" button under the list, and select the `.ipa` file built by - Bazel. - -5. You can now run the app on your device. - -### Option 2: Build in Xcode - -Note: This workflow requires a separate tool in addition to Bazel. If it fails -to work for some reason, please resort to the command-line build instructions in -the previous section. +This allows you to edit and debug one of the example apps in Xcode. It also +allows you to make use of automatic provisioning (see later section). 1. We will use a tool called [Tulsi](https://tulsi.bazel.build/) for generating Xcode projects from Bazel build configurations. @@ -283,25 +247,138 @@ the previous section. 2. Open `mediapipe/Mediapipe.tulsiproj` using the Tulsi app. - Important: If Tulsi displays an error saying "Bazel could not be found", - press the "Bazel..." button in the Packages tab and select the `bazel` - executable in your homebrew `/bin/` directory. + Tip: If Tulsi displays an error saying "Bazel could not be found", press the + "Bazel..." button in the Packages tab and select the `bazel` executable in + your homebrew `/bin/` directory. 3. Select the MediaPipe config in the Configs tab, then press the Generate button below. You will be asked for a location to save the Xcode project. Once the project is generated, it will be opened in Xcode. -4. You can now select any of the MediaPipe demos in the target menu, and build + If you get an error about bundle IDs, see the + [previous section](#set-up-a-bundle-id-prefix). + +### Set up provisioning + +To install applications on an iOS device, you need a provisioning profile. There +are two options: + +1. Automatic provisioning. This allows you to build and install an app to your + personal device. The provisining profile is managed by Xcode, and has to be + updated often (it is valid for about a week). + +2. Custom provisioning. This uses a provisioning profile associated with an + Apple developer account. These profiles have a longer validity period and + can target multiple devices, but you need a paid developer account with + Apple to obtain one. + +#### Automatic provisioning + +1. Create an Xcode project for MediaPipe, as discussed + [earlier](#create-an-xcode-project). + +2. In the project navigator in the left sidebar, select the "Mediapipe" + project. + +3. Select the "Signing & Capabilities" tab. + +4. Select one of the application targets, e.g. HandTrackingGpuApp. + +5. Check "Automatically manage signing", and confirm the dialog box. + +6. Select "_Your Name_ (Personal Team)" in the Team pop-up menu. + +7. This set-up needs to be done once for each application you want to install. + Repeat steps 4-6 as needed. + +This generates provisioning profiles for each app you have selected. Now we need +to tell Bazel to use them. We have provided a script to make this easier. + +1. In the terminal, to the `mediapipe` directory where you cloned the + repository. + +2. Run this command: + + ```bash + python3 mediapipe/examples/ios/link_local_profiles.py + ``` + +This will find and link the provisioning profile for all applications for which +you have enabled automatic provisioning in Xcode. + +Note: once a profile expires, Xcode will generate a new one; you must then run +this script again to link the updated profiles. + +#### Custom provisioning + +1. Obtain a provisioning profile from Apple. + +Tip: You can use this command to see the provisioning profiles you have +previously downloaded using Xcode: `open ~/Library/MobileDevice/"Provisioning +Profiles"`. If there are none, generate and download a profile on +[Apple's developer site](https://developer.apple.com/account/resources/). + +1. Symlink or copy your provisioning profile to + `mediapipe/mediapipe/provisioning_profile.mobileprovision`. + + ```bash + cd mediapipe + ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision + ``` + +Note: if you had previously set up automatic provisioning, you should remove the +`provisioning_profile.mobileprovision` symlink in each example's directory, +since it will take precedence over the common one. You can also overwrite it +with you own profile if you need a different profile for different apps. + +1. Open `mediapipe/examples/ios/bundle_id.bzl`, and change the + `BUNDLE_ID_PREFIX` to a prefix associated with your provisioning profile. + +### Build and run an app using Xcode + +1. Create the Xcode project, and make sure you have set up either automatic or + custom provisioning. + +2. You can now select any of the MediaPipe demos in the target menu, and build and run them as normal. - Note: When you ask Xcode to run an app, by default it will use the Debug - configuration. Some of our demos are computationally heavy; you may want to - use the Release configuration for better performance. +Note: When you ask Xcode to run an app, by default it will use the Debug +configuration. Some of our demos are computationally heavy; you may want to use +the Release configuration for better performance. - Tip: To switch build configuration in Xcode, click on the target menu, - choose "Edit Scheme...", select the Run action, and switch the Build - Configuration from Debug to Release. Note that this is set independently for - each target. +Tip: To switch build configuration in Xcode, click on the target menu, choose +"Edit Scheme...", select the Run action, and switch the Build Configuration from +Debug to Release. Note that this is set independently for each target. + +Tip: On the device, in Settings > General > Device Management, make sure the +developer (yourself) is trusted. + +### Build an app using the command line + +1. Make sure you have set up either automatic or custom provisioning. + +2. Using [MediaPipe Hands](../solutions/hands.md) for example, run: + + ```bash + bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp + ``` + + You may see a permission request from `codesign` in order to sign the app. + + Tip: If you are using custom provisioning, you can run this + [script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh) + to build all MediaPipe iOS example apps. + +3. In Xcode, open the `Devices and Simulators` window (command-shift-2). + +4. Make sure your device is connected. You will see a list of installed apps. + Press the "+" button under the list, and select the `.ipa` file built by + Bazel. + +5. You can now run the app on your device. + +Tip: On the device, in Settings > General > Device Management, make sure the +developer (yourself) is trusted. ## Desktop diff --git a/docs/getting_started/hello_world_android.md b/docs/getting_started/hello_world_android.md index 2794ea4f8..e4e8286f7 100644 --- a/docs/getting_started/hello_world_android.md +++ b/docs/getting_started/hello_world_android.md @@ -43,8 +43,8 @@ We will be using the following graph, [`edge_detection_mobile_gpu.pbtxt`]: ``` # MediaPipe graph that performs GPU Sobel edge detection on a live video stream. -# Used in the examples -# mediapipe/examples/android/src/java/com/mediapipe/apps/basic. +# Used in the examples in +# mediapipe/examples/android/src/java/com/mediapipe/apps/basic and # mediapipe/examples/ios/edgedetectiongpu. # Images coming into and out of the graph. @@ -764,7 +764,7 @@ If you ran into any issues, please see the full code of the tutorial [CameraX]:https://developer.android.com/training/camerax [`CameraXPreviewHelper`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java [developer options]:https://developer.android.com/studio/debug/dev-options -[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt +[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt [`EglManager`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/glutil/EglManager.java [`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java [`FrameLayout`]:https://developer.android.com/reference/android/widget/FrameLayout diff --git a/docs/getting_started/hello_world_desktop.md b/docs/getting_started/hello_world_desktop.md index 28a9aea8f..61e9b6471 100644 --- a/docs/getting_started/hello_world_desktop.md +++ b/docs/getting_started/hello_world_desktop.md @@ -18,7 +18,7 @@ nav_order: 5 2. To run the [`hello world`] example: ```bash - $ git clone https://github.com/google/mediapipe/mediapipe.git + $ git clone https://github.com/google/mediapipe.git $ cd mediapipe $ export GLOG_logtostderr=1 @@ -92,10 +92,10 @@ nav_order: 5 ```c++ CalculatorGraph graph; - RETURN_IF_ERROR(graph.Initialize(config)); - ASSIGN_OR_RETURN(OutputStreamPoller poller, - graph.AddOutputStreamPoller("out")); - RETURN_IF_ERROR(graph.StartRun({})); + MP_RETURN_IF_ERROR(graph.Initialize(config)); + MP_ASSIGN_OR_RETURN(OutputStreamPoller poller, + graph.AddOutputStreamPoller("out")); + MP_RETURN_IF_ERROR(graph.StartRun({})); ``` 5. The example then creates 10 packets (each packet contains a string "Hello @@ -105,9 +105,10 @@ nav_order: 5 ```c++ for (int i = 0; i < 10; ++i) { - RETURN_IF_ERROR(graph.AddPacketToInputStream("in", MakePacket("Hello World!").At(Timestamp(i)))); + MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", + MakePacket("Hello World!").At(Timestamp(i)))); } - RETURN_IF_ERROR(graph.CloseInputStream("in")); + MP_RETURN_IF_ERROR(graph.CloseInputStream("in")); ``` 6. Through the `OutputStreamPoller` object the example then retrieves all 10 diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md index 1c6c44961..2fdb028ce 100644 --- a/docs/getting_started/hello_world_ios.md +++ b/docs/getting_started/hello_world_ios.md @@ -56,7 +56,7 @@ node: { output_stream: "luma_video" } -# Applies the Sobel filter to luminance images sotred in RGB format. +# Applies the Sobel filter to luminance images stored in RGB format. node: { calculator: "SobelEdgesCalculator" input_stream: "luma_video" diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index 7374e244b..b9be6e498 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -70,9 +70,15 @@ apps, see these [instructions](./building_examples.md#ios). libopencv-imgproc-dev libopencv-video-dev ``` - [`opencv_linux.BUILD`] is configured for x86_64 by default. For Nvidia - Jetson and Raspberry Pi devices with ARM Ubuntu, the lib paths need to be - modified. + Debian 9 and Ubuntu 18.04 install the packages in + `/usr/lib/x86_64-linux-gnu`. MediaPipe's [`opencv_linux.BUILD`] and + [`ffmpeg_linux.BUILD`] are configured for this library path. Ubuntu 20.04 + may install the OpenCV and FFmpeg packages in `/usr/local`, Please follow + the option 3 below to modify the [`WORKSPACE`], [`opencv_linux.BUILD`] and + [`ffmpeg_linux.BUILD`] files accordingly. + + Moreover, for Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, the + library path needs to be modified like the following: ```bash sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD @@ -85,11 +91,13 @@ apps, see these [instructions](./building_examples.md#ios). [documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html) to manually build OpenCV from source code. - Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to - point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed - in "/usr/local/", you need to update the "linux_opencv" new_local_repository - rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`] - like the following: + Note: You may need to modify [`WORKSPACE`], [`opencv_linux.BUILD`] and + [`ffmpeg_linux.BUILD`] to point MediaPipe to your own OpenCV and FFmpeg + libraries. For example if OpenCV and FFmpeg are both manually installed in + "/usr/local/", you will need to update: (1) the "linux_opencv" and + "linux_ffmpeg" new_local_repository rules in [`WORKSPACE`], (2) the "opencv" + cc_library rule in [`opencv_linux.BUILD`], and (3) the "libffmpeg" + cc_library rule in [`ffmpeg_linux.BUILD`]. These 3 changes are shown below: ```bash new_local_repository( @@ -98,6 +106,12 @@ apps, see these [instructions](./building_examples.md#ios). path = "/usr/local", ) + new_local_repository( + name = "linux_ffmpeg", + build_file = "@//third_party:ffmpeg_linux.BUILD", + path = "/usr/local", + ) + cc_library( name = "opencv", srcs = glob( @@ -110,8 +124,36 @@ apps, see these [instructions](./building_examples.md#ios). "lib/libopencv_videoio.so", ], ), - hdrs = glob(["include/opencv4/**/*.h*"]), - includes = ["include/opencv4/"], + hdrs = glob([ + # For OpenCV 3.x + "include/opencv2/**/*.h*", + # For OpenCV 4.x + # "include/opencv4/opencv2/**/*.h*", + ]), + includes = [ + # For OpenCV 3.x + "include/", + # For OpenCV 4.x + # "include/opencv4/", + ], + linkstatic = 1, + visibility = ["//visibility:public"], + ) + + cc_library( + name = "libffmpeg", + srcs = glob( + [ + "lib/libav*.so", + ], + ), + hdrs = glob(["include/libav*/*.h"]), + includes = ["include"], + linkopts = [ + "-lavcodec", + "-lavformat", + "-lavutil", + ], linkstatic = 1, visibility = ["//visibility:public"], ) @@ -158,6 +200,10 @@ apps, see these [instructions](./building_examples.md#ios). # Hello World! ``` +If you run into a build error, please read +[Troubleshooting](./troubleshooting.md) to find the solutions of several common +build issues. + ## Installing on CentOS **Disclaimer**: Running MediaPipe on CentOS is experimental. @@ -190,11 +236,13 @@ apps, see these [instructions](./building_examples.md#ios). Option 2. Build OpenCV from source code. - Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to - point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed - in "/usr/local/", you need to update the "linux_opencv" new_local_repository - rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`] - like the following: + Note: You may need to modify [`WORKSPACE`], [`opencv_linux.BUILD`] and + [`ffmpeg_linux.BUILD`] to point MediaPipe to your own OpenCV and FFmpeg + libraries. For example if OpenCV and FFmpeg are both manually installed in + "/usr/local/", you will need to update: (1) the "linux_opencv" and + "linux_ffmpeg" new_local_repository rules in [`WORKSPACE`], (2) the "opencv" + cc_library rule in [`opencv_linux.BUILD`], and (3) the "libffmpeg" + cc_library rule in [`ffmpeg_linux.BUILD`]. These 3 changes are shown below: ```bash new_local_repository( @@ -203,6 +251,12 @@ apps, see these [instructions](./building_examples.md#ios). path = "/usr/local", ) + new_local_repository( + name = "linux_ffmpeg", + build_file = "@//third_party:ffmpeg_linux.BUILD", + path = "/usr/local", + ) + cc_library( name = "opencv", srcs = glob( @@ -215,8 +269,36 @@ apps, see these [instructions](./building_examples.md#ios). "lib/libopencv_videoio.so", ], ), - hdrs = glob(["include/opencv4/**/*.h*"]), - includes = ["include/opencv4/"], + hdrs = glob([ + # For OpenCV 3.x + "include/opencv2/**/*.h*", + # For OpenCV 4.x + # "include/opencv4/opencv2/**/*.h*", + ]), + includes = [ + # For OpenCV 3.x + "include/", + # For OpenCV 4.x + # "include/opencv4/", + ], + linkstatic = 1, + visibility = ["//visibility:public"], + ) + + cc_library( + name = "libffmpeg", + srcs = glob( + [ + "lib/libav*.so", + ], + ), + hdrs = glob(["include/libav*/*.h"]), + includes = ["include"], + linkopts = [ + "-lavcodec", + "-lavformat", + "-lavutil", + ], linkstatic = 1, visibility = ["//visibility:public"], ) @@ -243,6 +325,10 @@ apps, see these [instructions](./building_examples.md#ios). # Hello World! ``` +If you run into a build error, please read +[Troubleshooting](./troubleshooting.md) to find the solutions of several common +build issues. + ## Installing on macOS 1. Prework: @@ -375,6 +461,10 @@ apps, see these [instructions](./building_examples.md#ios). # Hello World! ``` +If you run into a build error, please read +[Troubleshooting](./troubleshooting.md) to find the solutions of several common +build issues. + ## Installing on Windows **Disclaimer**: Running MediaPipe on Windows is experimental. @@ -454,13 +544,13 @@ next section. 9. Run the [Hello World desktop example](./hello_world_desktop.md). Note: For building MediaPipe on Windows, please add `--action_env - PYTHON_BIN_PATH="C:/path/to/python.exe"` to the build command. + PYTHON_BIN_PATH="C://path//to//python.exe"` to the build command. Alternatively, you can follow [issue 724](https://github.com/google/mediapipe/issues/724) to fix the python configuration manually. ``` - C:\Users\Username\mediapipe_repo>bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 --action_env PYTHON_BIN_PATH="C:/python_36/python.exe" mediapipe/examples/desktop/hello_world + C:\Users\Username\mediapipe_repo>bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 --action_env PYTHON_BIN_PATH="C://python_36//python.exe" mediapipe/examples/desktop/hello_world C:\Users\Username\mediapipe_repo>set GLOG_logtostderr=1 @@ -480,6 +570,10 @@ next section. ``` +If you run into a build error, please read +[Troubleshooting](./troubleshooting.md) to find the solutions of several common +build issues. + ## Installing on Windows Subsystem for Linux (WSL) Note: The pre-built OpenCV packages don't support cameras in WSL. Unless you @@ -603,6 +697,10 @@ cameras. Alternatively, you use a video file as input. # Hello World! ``` +If you run into a build error, please +read [Troubleshooting](./troubleshooting.md) to find the solutions of several +common build issues. + ## Installing using Docker This will use a Docker image that will isolate mediapipe's installation from the rest of the system. @@ -653,6 +751,10 @@ This will use a Docker image that will isolate mediapipe's installation from the # Hello World! ``` +If you run into a build error, please +read [Troubleshooting](./troubleshooting.md) to find the solutions of several +common build issues. + 4. Build a MediaPipe Android example. ```bash @@ -692,6 +794,7 @@ This will use a Docker image that will isolate mediapipe's installation from the [`WORKSPACE`]: https://github.com/google/mediapipe/blob/master/WORKSPACE [`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD +[`ffmpeg_linux.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_linux.BUILD [`opencv_macos.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_macos.BUILD [`ffmpeg_macos.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_macos.BUILD [`setup_opencv.sh`]: https://github.com/google/mediapipe/blob/master/setup_opencv.sh diff --git a/docs/getting_started/troubleshooting.md b/docs/getting_started/troubleshooting.md index 9d1bedac4..6d1a0e96e 100644 --- a/docs/getting_started/troubleshooting.md +++ b/docs/getting_started/troubleshooting.md @@ -12,6 +12,90 @@ nav_order: 10 {:toc} --- +## Missing Python binary path + +The error message: + +``` +ERROR: An error occurred during the fetch of repository 'local_execution_config_python': + Traceback (most recent call last): + File "/sandbox_path/external/org_tensorflow/third_party/py/python_configure.bzl", line 208 + get_python_bin(repository_ctx) + ... +Repository command failed +``` + +usually indicates that Bazel fails to find the local Python binary. To solve +this issue, please first find where the python binary is and then add +`--action_env PYTHON_BIN_PATH=` to the Bazel command like +the following: + +``` +bazel build -c opt \ + --define MEDIAPIPE_DISABLE_GPU=1 \ + --action_env PYTHON_BIN_PATH="/path/to/python" \ + mediapipe/examples/desktop/hello_world +``` + +## Missing necessary Python packages + +The error message: + +``` +ImportError: No module named numpy +Is numpy installed? +``` + +usually indicates that certain Python packages are not installed. Please run +`pip install` or `pip3 install` depending on your Python binary version to +install those packages. + +## Fail to fetch remote dependency repositories + +The error message: + +``` +ERROR: An error occurred during the fetch of repository 'org_tensorflow': + java.io.IOException: Error downloading [https://mirror.bazel.build/github.com/tensorflow/tensorflow/archive/77e9ffb9b2bfb1a4f7056e62d84039626923e328.tar.gz, https://github.com/tensorflow/tensorflow/archive/77e9ffb9b2bfb1a4f7056e62d84039626923e328.tar.gz] to /sandbox_path/external/org_tensorflow/77e9ffb9b2bfb1a4f7056e62d84039626923e328.tar.gz: Tried to reconnect at offset 9,944,151 but server didn't support it + +or + +WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz failed: class java.net.ConnectException Connection timed out (Connection timed out) +``` + +usually indicates that Bazel fails to download necessary dependency repositories +that MediaPipe needs. MedaiPipe has several dependency repositories that are +hosted by Google sites. In some regions, you may need to set up a network proxy +or use a VPN to access those resources. You may also need to append +`--host_jvm_args "-DsocksProxyHost= -DsocksProxyPort="` +to the Bazel command. See +[this GitHub issue](https://github.com/google/mediapipe/issues/581#issuecomment-610356857) +for more details. + +If you believe that it's not a network issue, another possibility is that some +resources could be temporarily unavailable, please run `bazel clean --expunge` +and retry it later. If it's still not working, please file a GitHub issue with +the detailed error message. + +## Incorrect MediaPipe OpenCV config + +The error message: + +``` +error: undefined reference to 'cv::String::deallocate()' +error: undefined reference to 'cv::String::allocate(unsigned long)' +error: undefined reference to 'cv::VideoCapture::VideoCapture(cv::String const&)' +... +error: undefined reference to 'cv::putText(cv::InputOutputArray const&, cv::String const&, cv::Point, int, double, cv::Scalar, int, int, bool)' +``` + +usually indicates that OpenCV is not properly configured for MediaPipe. Please +take a look at the "Install OpenCV and FFmpeg" sections in +[Installation](./install.md) to see how to modify MediaPipe's WORKSPACE and +linux_opencv/macos_opencv/windows_opencv.BUILD files for your local opencv +libraries. [This GitHub issue](https://github.com/google/mediapipe/issues/666) +may also help. + ## Native method not found The error message: diff --git a/docs/index.md b/docs/index.md index 39ea05b42..ea6c2feb3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -123,7 +123,7 @@ run code search using * [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome MediaPipe related frameworks, libraries and software -* [Slack community](https://https://mediapipe.page.link/joinslack) for MediaPipe users +* [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe diff --git a/docs/solutions/box_tracking.md b/docs/solutions/box_tracking.md index 84da8565d..007376168 100644 --- a/docs/solutions/box_tracking.md +++ b/docs/solutions/box_tracking.md @@ -112,7 +112,7 @@ examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ### Mobile diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index a8e844df4..4b9534b22 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -43,7 +43,7 @@ examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ### Mobile diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index e81ac0f08..9026eac8d 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -65,7 +65,7 @@ from the Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ## Models @@ -109,7 +109,7 @@ Please first see general instructions for Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ### Mobile diff --git a/docs/solutions/hair_segmentation.md b/docs/solutions/hair_segmentation.md index 87361040a..94cabce24 100644 --- a/docs/solutions/hair_segmentation.md +++ b/docs/solutions/hair_segmentation.md @@ -24,7 +24,7 @@ examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ### Mobile diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md index 4ba33f861..f4d93d840 100644 --- a/docs/solutions/hands.md +++ b/docs/solutions/hands.md @@ -66,7 +66,7 @@ and a Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ## Models @@ -132,7 +132,7 @@ examples. Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ### Mobile diff --git a/docs/solutions/knift.md b/docs/solutions/knift.md index ec2eec154..15d6f5d30 100644 --- a/docs/solutions/knift.md +++ b/docs/solutions/knift.md @@ -72,7 +72,7 @@ Please first see general instructions for Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). * Graph: [`mediapipe/graphs/template_matching/template_matching_mobile_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/template_matching/template_matching_mobile_cpu.pbtxt) diff --git a/docs/solutions/object_detection.md b/docs/solutions/object_detection.md index a92e57e7d..fb0bff2b1 100644 --- a/docs/solutions/object_detection.md +++ b/docs/solutions/object_detection.md @@ -19,7 +19,7 @@ nav_order: 5 Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ### Mobile diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index c142bfdf9..f4db179e2 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -156,7 +156,7 @@ Please first see general instructions for Note: To visualize a graph, copy the graph and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how to visualize its associated subgraphs, please see -[visualizer documentation](../visualizer.md). +[visualizer documentation](../tools/visualizer.md). ### Objectron for Shoes diff --git a/docs/tools/tracing_and_profiling.md b/docs/tools/tracing_and_profiling.md index 472e52a7d..a0188836b 100644 --- a/docs/tools/tracing_and_profiling.md +++ b/docs/tools/tracing_and_profiling.md @@ -33,13 +33,12 @@ command line option: `--define MEDIAPIPE_PROFILING=0`. To enable tracing and profiling, the `CalculatorGraphConfig` (in [calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto)) representing the graph must have a `profiler_config` message at its root. Here -is a simple setup that turns on a few extra options: +is a simple setup that turns on tracing and keeps 100 seconds of timing events: ``` profiler_config { - enable_profiler: true trace_enabled: true - trace_log_count: 5 + trace_log_interval_count: 200 } ``` diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 6196bed5b..1d053081c 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -450,6 +450,21 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "mux_calculator_test", + srcs = ["mux_calculator_test.cc"], + deps = [ + ":mux_calculator", + ":round_robin_demux_calculator", + ":split_vector_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) + cc_library( name = "packet_cloner_calculator", srcs = ["packet_cloner_calculator.cc"], @@ -947,7 +962,6 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/port:status", ], ) diff --git a/mediapipe/calculators/core/gate_calculator.cc b/mediapipe/calculators/core/gate_calculator.cc index aedd01b64..ea0f7b81b 100644 --- a/mediapipe/calculators/core/gate_calculator.cc +++ b/mediapipe/calculators/core/gate_calculator.cc @@ -56,12 +56,19 @@ std::string ToString(GateState state) { // disallowing the corresponding packets in other input streams. The behavior // can be inverted with a calculator option. // +// ALLOW or DISALLOW can also be specified as an input side packet. The rules +// for evaluation remain the same as above. +// +// ALLOW/DISALLOW inputs must be specified either using input stream or +// via input side packet but not both. +// // Intended to be used with the default input stream handler, which synchronizes // all data input streams with the ALLOW/DISALLOW control input stream. // // Example config: // node { // calculator: "GateCalculator" +// input_side_packet: "ALLOW:allow" or "DISALLOW:disallow" // input_stream: "input_stream0" // input_stream: "input_stream1" // input_stream: "input_streamN" @@ -75,10 +82,40 @@ class GateCalculator : public CalculatorBase { public: GateCalculator() {} + static ::mediapipe::Status CheckAndInitAllowDisallowInputs( + CalculatorContract* cc) { + bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") || + cc->InputSidePackets().HasTag("DISALLOW"); + bool input_via_stream = + cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW"); + // Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW + // input. + RET_CHECK(input_via_side_packet ^ input_via_stream); + + if (input_via_side_packet) { + RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^ + cc->InputSidePackets().HasTag("DISALLOW")); + + if (cc->InputSidePackets().HasTag("ALLOW")) { + cc->InputSidePackets().Tag("ALLOW").Set(); + } else { + cc->InputSidePackets().Tag("DISALLOW").Set(); + } + } else { + RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW")); + + if (cc->Inputs().HasTag("ALLOW")) { + cc->Inputs().Tag("ALLOW").Set(); + } else { + cc->Inputs().Tag("DISALLOW").Set(); + } + } + return ::mediapipe::OkStatus(); + } + static ::mediapipe::Status GetContract(CalculatorContract* cc) { - // Assume that input streams do not have a tag and that gating signal is - // tagged either ALLOW or DISALLOW. - RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW")); + RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc)); + const int num_data_streams = cc->Inputs().NumEntries(""); RET_CHECK_GE(num_data_streams, 1); RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams) @@ -88,11 +125,6 @@ class GateCalculator : public CalculatorBase { cc->Inputs().Get("", i).SetAny(); cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); } - if (cc->Inputs().HasTag("ALLOW")) { - cc->Inputs().Tag("ALLOW").Set(); - } else { - cc->Inputs().Tag("DISALLOW").Set(); - } if (cc->Outputs().HasTag("STATE_CHANGE")) { cc->Outputs().Tag("STATE_CHANGE").Set(); @@ -102,6 +134,17 @@ class GateCalculator : public CalculatorBase { } ::mediapipe::Status Open(CalculatorContext* cc) final { + use_side_packet_for_allow_disallow_ = false; + if (cc->InputSidePackets().HasTag("ALLOW")) { + use_side_packet_for_allow_disallow_ = true; + allow_by_side_packet_decision_ = + cc->InputSidePackets().Tag("ALLOW").Get(); + } else if (cc->InputSidePackets().HasTag("DISALLOW")) { + use_side_packet_for_allow_disallow_ = true; + allow_by_side_packet_decision_ = + !cc->InputSidePackets().Tag("DISALLOW").Get(); + } + cc->SetOffset(TimestampDiff(0)); num_data_streams_ = cc->Inputs().NumEntries(""); last_gate_state_ = GATE_UNINITIALIZED; @@ -115,14 +158,18 @@ class GateCalculator : public CalculatorBase { ::mediapipe::Status Process(CalculatorContext* cc) final { bool allow = empty_packets_as_allow_; - if (cc->Inputs().HasTag("ALLOW") && !cc->Inputs().Tag("ALLOW").IsEmpty()) { - allow = cc->Inputs().Tag("ALLOW").Get(); + if (use_side_packet_for_allow_disallow_) { + allow = allow_by_side_packet_decision_; + } else { + if (cc->Inputs().HasTag("ALLOW") && + !cc->Inputs().Tag("ALLOW").IsEmpty()) { + allow = cc->Inputs().Tag("ALLOW").Get(); + } + if (cc->Inputs().HasTag("DISALLOW") && + !cc->Inputs().Tag("DISALLOW").IsEmpty()) { + allow = !cc->Inputs().Tag("DISALLOW").Get(); + } } - if (cc->Inputs().HasTag("DISALLOW") && - !cc->Inputs().Tag("DISALLOW").IsEmpty()) { - allow = !cc->Inputs().Tag("DISALLOW").Get(); - } - const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW; if (cc->Outputs().HasTag("STATE_CHANGE")) { @@ -157,6 +204,8 @@ class GateCalculator : public CalculatorBase { GateState last_gate_state_ = GATE_UNINITIALIZED; int num_data_streams_; bool empty_packets_as_allow_; + bool use_side_packet_for_allow_disallow_; + bool allow_by_side_packet_decision_; }; REGISTER_CALCULATOR(GateCalculator); diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc index 8a7272416..fc34f6e97 100644 --- a/mediapipe/calculators/core/gate_calculator_test.cc +++ b/mediapipe/calculators/core/gate_calculator_test.cc @@ -24,6 +24,21 @@ namespace { class GateCalculatorTest : public ::testing::Test { protected: + // Helper to run a graph and return status. + static ::mediapipe::Status RunGraph(const std::string& proto) { + auto runner = absl::make_unique( + ParseTextProtoOrDie(proto)); + return runner->Run(); + } + + // Use this when ALLOW/DISALLOW input is provided as a side packet. + void RunTimeStep(int64 timestamp, bool stream_payload) { + runner_->MutableInputs()->Get("", 0).packets.push_back( + MakePacket(stream_payload).At(Timestamp(timestamp))); + MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed."; + } + + // Use this when ALLOW/DISALLOW input is provided as an input stream. void RunTimeStep(int64 timestamp, const std::string& control_tag, bool control) { runner_->MutableInputs()->Get("", 0).packets.push_back( @@ -31,7 +46,6 @@ class GateCalculatorTest : public ::testing::Test { runner_->MutableInputs() ->Tag(control_tag) .packets.push_back(MakePacket(control).At(Timestamp(timestamp))); - MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed."; } @@ -46,6 +60,136 @@ class GateCalculatorTest : public ::testing::Test { std::unique_ptr runner_; }; +TEST_F(GateCalculatorTest, InvalidInputs) { + EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "ALLOW:gating_stream" + input_stream: "DISALLOW:gating_stream" + output_stream: "test_output" + )"))); + + EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_side_packet: "ALLOW:gating_stream" + input_side_packet: "DISALLOW:gating_stream" + output_stream: "test_output" + )"))); + + EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "ALLOW:gating_stream" + input_side_packet: "ALLOW:gating_stream" + output_stream: "test_output" + )"))); + + EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "DISALLOW:gating_stream" + input_side_packet: "DISALLOW:gating_stream" + output_stream: "test_output" + )"))); + + EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "ALLOW:gating_stream" + input_side_packet: "DISALLOW:gating_stream" + output_stream: "test_output" + )"))); + + EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "DISALLOW:gating_stream" + input_side_packet: "ALLOW:gating_stream" + output_stream: "test_output" + )"))); +} + +TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) { + SetRunner(R"( + calculator: "GateCalculator" + input_side_packet: "ALLOW:gating_stream" + input_stream: "test_input" + output_stream: "test_output" + )"); + runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true)); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(2, output.size()); + EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value()); + EXPECT_EQ(kTimestampValue1, output[1].Timestamp().Value()); + EXPECT_EQ(true, output[0].Get()); + EXPECT_EQ(false, output[1].Get()); +} + +TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) { + SetRunner(R"( + calculator: "GateCalculator" + input_side_packet: "DISALLOW:gating_stream" + input_stream: "test_input" + output_stream: "test_output" + )"); + runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false)); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(2, output.size()); + EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value()); + EXPECT_EQ(kTimestampValue1, output[1].Timestamp().Value()); + EXPECT_EQ(true, output[0].Get()); + EXPECT_EQ(false, output[1].Get()); +} + +TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) { + SetRunner(R"( + calculator: "GateCalculator" + input_side_packet: "ALLOW:gating_stream" + input_stream: "test_input" + output_stream: "test_output" + )"); + runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false)); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(0, output.size()); +} + +TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) { + SetRunner(R"( + calculator: "GateCalculator" + input_side_packet: "DISALLOW:gating_stream" + input_stream: "test_input" + output_stream: "test_output" + )"); + runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true)); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(0, output.size()); +} + TEST_F(GateCalculatorTest, Allow) { SetRunner(R"( calculator: "GateCalculator" diff --git a/mediapipe/calculators/core/immediate_mux_calculator.cc b/mediapipe/calculators/core/immediate_mux_calculator.cc index cb930bed7..007fbf73e 100644 --- a/mediapipe/calculators/core/immediate_mux_calculator.cc +++ b/mediapipe/calculators/core/immediate_mux_calculator.cc @@ -37,6 +37,10 @@ namespace mediapipe { // the RoundRobinDemuxCalculator. Therefore, packets from different // input streams are normally not expected to have the same timestamp. // +// NOTE: this calculator can drop packets non-deterministically, depending on +// how fast the input streams are fed. In most cases, MuxCalculator should be +// preferred. In particular, dropping packets can interfere with rate limiting +// mechanisms. class ImmediateMuxCalculator : public CalculatorBase { public: // This calculator combines any set of input streams into a single @@ -76,6 +80,9 @@ REGISTER_CALCULATOR(ImmediateMuxCalculator); if (!packet.IsEmpty()) { if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) { cc->Outputs().Index(0).AddPacket(packet); + } else { + LOG_FIRST_N(WARNING, 5) + << "Dropping a packet with timestamp " << packet.Timestamp(); } if (cc->Outputs().NumEntries() >= 2) { Timestamp output_timestamp = std::max( diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index 1d1ae1904..8ca25bdd0 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -17,28 +17,49 @@ namespace mediapipe { +namespace { +constexpr char kSelectTag[] = "SELECT"; +constexpr char kInputTag[] = "INPUT"; +} // namespace + // A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ..., -// using the integer value (0, 1, ...) in the packet on the "SELECT" input +// using the integer value (0, 1, ...) in the packet on the kSelectTag input // stream, and passes the packet on the selected input stream to the "OUTPUT" // output stream. +// The kSelectTag input can also be passed in as an input side packet, instead +// of as an input stream. Either of input stream or input side packet must be +// specified but not both. // // Note that this calculator defaults to use MuxInputStreamHandler, which is -// required for this calculator. +// required for this calculator. However, it can be overridden to work with +// other InputStreamHandlers. Check out the unit tests on for an example usage +// with DefaultInputStreamHandler. class MuxCalculator : public CalculatorBase { public: + static ::mediapipe::Status CheckAndInitAllowDisallowInputs( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kSelectTag) ^ + cc->InputSidePackets().HasTag(kSelectTag)); + if (cc->Inputs().HasTag(kSelectTag)) { + cc->Inputs().Tag(kSelectTag).Set(); + } else { + cc->InputSidePackets().Tag(kSelectTag).Set(); + } + return ::mediapipe::OkStatus(); + } + static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Tag("SELECT").Set(); - CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT"); + RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc)); + CollectionItemId data_input_id = cc->Inputs().BeginId(kInputTag); PacketType* data_input0 = &cc->Inputs().Get(data_input_id); data_input0->SetAny(); ++data_input_id; - for (; data_input_id < cc->Inputs().EndId("INPUT"); ++data_input_id) { + for (; data_input_id < cc->Inputs().EndId(kInputTag); ++data_input_id) { cc->Inputs().Get(data_input_id).SetSameAs(data_input0); } RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0); - // Assign this calculator's default InputStreamHandler. cc->SetInputStreamHandler("MuxInputStreamHandler"); MediaPipeOptions options; cc->SetInputStreamHandlerOptions(options); @@ -47,16 +68,24 @@ class MuxCalculator : public CalculatorBase { } ::mediapipe::Status Open(CalculatorContext* cc) final { - select_input_ = cc->Inputs().GetId("SELECT", 0); - data_input_base_ = cc->Inputs().GetId("INPUT", 0); - num_data_inputs_ = cc->Inputs().NumEntries("INPUT"); + use_side_packet_select_ = false; + if (cc->InputSidePackets().HasTag(kSelectTag)) { + use_side_packet_select_ = true; + selected_index_ = cc->InputSidePackets().Tag(kSelectTag).Get(); + } else { + select_input_ = cc->Inputs().GetId(kSelectTag, 0); + } + data_input_base_ = cc->Inputs().GetId(kInputTag, 0); + num_data_inputs_ = cc->Inputs().NumEntries(kInputTag); output_ = cc->Outputs().GetId("OUTPUT", 0); cc->SetOffset(TimestampDiff(0)); return ::mediapipe::OkStatus(); } ::mediapipe::Status Process(CalculatorContext* cc) final { - int select = cc->Inputs().Get(select_input_).Get(); + int select = use_side_packet_select_ + ? selected_index_ + : cc->Inputs().Get(select_input_).Get(); RET_CHECK(0 <= select && select < num_data_inputs_); if (!cc->Inputs().Get(data_input_base_ + select).IsEmpty()) { cc->Outputs().Get(output_).AddPacket( @@ -70,6 +99,8 @@ class MuxCalculator : public CalculatorBase { CollectionItemId data_input_base_; int num_data_inputs_ = 0; CollectionItemId output_; + bool use_side_packet_select_; + int selected_index_; }; REGISTER_CALCULATOR(MuxCalculator); diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc new file mode 100644 index 000000000..ac6f7d6ee --- /dev/null +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -0,0 +1,237 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/core/split_vector_calculator.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +typedef SplitVectorCalculator SplitIntVectorCalculator; +REGISTER_CALCULATOR(SplitIntVectorCalculator); + +namespace { + +// Graph with default input stream handler, and the input selection is driven +// by an input stream. All MuxCalculator inputs are present at each timestamp. +constexpr char kTestGraphConfig1[] = R"proto( + input_stream: "input" + output_stream: "test_output" + node { + calculator: "SplitIntVectorCalculator" + input_stream: "input" + output_stream: "stream0" + output_stream: "stream1" + output_stream: "stream2" + output_stream: "input_select" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + ranges: { begin: 2 end: 3 } + ranges: { begin: 3 end: 4 } + element_only: true + } + } + } + node { + calculator: "MuxCalculator" + input_stream: "INPUT:0:stream0" + input_stream: "INPUT:1:stream1" + input_stream: "INPUT:2:stream2" + input_stream: "SELECT:input_select" + output_stream: "OUTPUT:test_output" + input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" } + } +)proto"; + +// Graph with default input stream handler, and the input selection is driven +// by an input side packet. All MuxCalculator inputs are present at each +// timestamp. +constexpr char kTestGraphConfig2[] = R"proto( + input_side_packet: "input_selector" + input_stream: "input" + output_stream: "test_output" + node { + calculator: "SplitIntVectorCalculator" + input_stream: "input" + output_stream: "stream0" + output_stream: "stream1" + output_stream: "stream2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + ranges: { begin: 2 end: 3 } + element_only: true + } + } + } + node { + calculator: "MuxCalculator" + input_stream: "INPUT:0:stream0" + input_stream: "INPUT:1:stream1" + input_stream: "INPUT:2:stream2" + input_side_packet: "SELECT:input_selector" + output_stream: "OUTPUT:test_output" + input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" } + } +)proto"; + +// Graph with mux input stream handler, and the input selection is driven +// by an input stream. Only one MuxCalculator input is present at each +// timestamp. +constexpr char kTestGraphConfig3[] = R"proto( + input_stream: "input" + output_stream: "test_output" + node { + calculator: "RoundRobinDemuxCalculator" + input_stream: "input" + output_stream: "OUTPUT:0:stream0" + output_stream: "OUTPUT:1:stream1" + output_stream: "OUTPUT:2:stream2" + output_stream: "SELECT:input_select" + } + node { + calculator: "MuxCalculator" + input_stream: "INPUT:0:stream0" + input_stream: "INPUT:1:stream1" + input_stream: "INPUT:2:stream2" + input_stream: "SELECT:input_select" + output_stream: "OUTPUT:test_output" + } +)proto"; + +constexpr char kOutputName[] = "test_output"; +constexpr char kInputName[] = "input"; +constexpr char kInputSelector[] = "input_selector"; + +// Helper to run a graph with the given inputs and generate outputs, asserting +// each step along the way. +// Inputs: +// graph_config_proto - graph config protobuf +// extra_side_packets - input side packets name to value map +// input_stream_name - name of the input +void RunGraph(const std::string& graph_config_proto, + const std::map& extra_side_packets, + const std::string& input_stream_name, int num_input_packets, + std::function input_fn, + const std::string& output_stream_name, + std::function<::mediapipe::Status(const Packet&)> output_fn) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie( + graph_config_proto); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream(output_stream_name, output_fn)); + MP_ASSERT_OK(graph.StartRun(extra_side_packets)); + for (int i = 0; i < num_input_packets; ++i) { + MP_ASSERT_OK(graph.AddPacketToInputStream(input_stream_name, input_fn(i))); + } + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST(MuxCalculatorTest, InputStreamSelector_DefaultInputStreamHandler) { + // Input and handling. + std::vector> input_packets = { + {1, 1, 2, 1}, {3, 5, 8, 2}, {13, 21, 34, 0}, + {55, 89, 144, 2}, {233, 377, 610, 0}, {987, 1597, 2584, 1}, + {4181, 6765, 10946, 2}, + }; + int packet_time_stamp = 22; + // This function will return the i-th input packet. + auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet { + return MakePacket>(input_packets[i]) + .At(Timestamp(packet_time_stamp++)); + }; + + // Output and handling. + std::vector output; + // This function collects the output from the packet. + auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status { + output.push_back(p.Get()); + return ::mediapipe::OkStatus(); + }; + + RunGraph(kTestGraphConfig1, {}, kInputName, input_packets.size(), input_fn, + kOutputName, output_fn); + EXPECT_THAT(output, testing::ElementsAre(1, 8, 13, 144, 233, 1597, 10946)); +} + +TEST(MuxCalculatorTest, InputSidePacketSelector_DefaultInputStreamHandler) { + // Input and handling. + std::vector> input_packets = { + {1, 1, 2}, {3, 5, 8}, {13, 21, 34}, {55, 89, 144}, + {233, 377, 610}, {987, 1597, 2584}, {4181, 6765, 10946}, + }; + int packet_time_stamp = 22; + // This function will return the i-th input packet. + auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet { + return MakePacket>(input_packets[i]) + .At(Timestamp(packet_time_stamp++)); + }; + + // Output and handling. + std::vector output; + // This function collects the output from the packet. + auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status { + output.push_back(p.Get()); + return ::mediapipe::OkStatus(); + }; + + RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket(0)}}, + kInputName, input_packets.size(), input_fn, kOutputName, output_fn); + EXPECT_THAT(output, testing::ElementsAre(1, 3, 13, 55, 233, 987, 4181)); + + output.clear(); + RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket(1)}}, + kInputName, input_packets.size(), input_fn, kOutputName, output_fn); + EXPECT_THAT(output, testing::ElementsAre(1, 5, 21, 89, 377, 1597, 6765)); + + output.clear(); + RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket(2)}}, + kInputName, input_packets.size(), input_fn, kOutputName, output_fn); + EXPECT_THAT(output, testing::ElementsAre(2, 8, 34, 144, 610, 2584, 10946)); +} + +TEST(MuxCalculatorTest, InputStreamSelector_MuxInputStreamHandler) { + // Input and handling. + std::vector input_packets = {1, 1, 2, 3, 5, 8, 13, + 21, 34, 55, 89, 144, 233, 377, + 610, 987, 1597, 2584, 4181, 6765, 10946}; + int packet_time_stamp = 22; + // This function will return the i-th input packet. + auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet { + return MakePacket(input_packets[i]).At(Timestamp(packet_time_stamp++)); + }; + + // Output and handling. + std::vector output; + // This function collects the output from the packet. + auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status { + output.push_back(p.Get()); + return ::mediapipe::OkStatus(); + }; + + RunGraph(kTestGraphConfig3, {}, kInputName, input_packets.size(), input_fn, + kOutputName, output_fn); + EXPECT_EQ(output, input_packets); +} +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/previous_loopback_calculator.cc b/mediapipe/calculators/core/previous_loopback_calculator.cc index 9d14ec956..8cbf04410 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator.cc @@ -128,11 +128,17 @@ class PreviousLoopbackCalculator : public CalculatorBase { loop_packets_.pop_front(); main_packet_specs_.pop_front(); } + + // We can close PREV_LOOP output stream as soon as we processed last + // possible MAIN packet. That can happen in two cases: + // a) Non-empty MAIN packet has been received with Timestamp::Max() + // b) Empty MAIN packet has been received with Timestamp::Max() indicating + // MAIN is done. + if (main_spec.timestamp == Timestamp::Done().PreviousAllowedInStream()) { + prev_loop.Close(); + } } - if (main_packet_specs_.empty() && cc->Inputs().Get(main_id_).IsDone()) { - prev_loop.Close(); - } return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/core/previous_loopback_calculator_test.cc b/mediapipe/calculators/core/previous_loopback_calculator_test.cc index 0fabacd57..ef469b43a 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator_test.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator_test.cc @@ -228,6 +228,104 @@ TEST(PreviousLoopbackCalculator, ClosesCorrectly) { MP_EXPECT_OK(graph_.WaitUntilDone()); } +TEST(PreviousLoopbackCalculator, ProcessesMaxTimestamp) { + std::vector out_and_previous_packets; + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"( + input_stream: 'in' + node { + calculator: 'PreviousLoopbackCalculator' + input_stream: 'MAIN:in' + input_stream: 'LOOP:out' + input_stream_info: { tag_index: 'LOOP' back_edge: true } + output_stream: 'PREV_LOOP:previous' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'in' + input_stream: 'previous' + output_stream: 'out' + output_stream: 'previous2' + } + node { + calculator: 'MakePairCalculator' + input_stream: 'out' + input_stream: 'previous' + output_stream: 'out_and_previous' + } + )"); + tool::AddVectorSink("out_and_previous", &graph_config, + &out_and_previous_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in", MakePacket(1).At(Timestamp::Max()))); + + MP_EXPECT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT(out_and_previous_packets, + ElementsAre(PairPacket(Timestamp::Max(), + Pair(IntPacket(1), EmptyPacket())))); + + MP_EXPECT_OK(graph.CloseAllInputStreams()); + MP_EXPECT_OK(graph.WaitUntilIdle()); + MP_EXPECT_OK(graph.WaitUntilDone()); +} + +TEST(PreviousLoopbackCalculator, ProcessesMaxTimestampNonEmptyPrevious) { + std::vector out_and_previous_packets; + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"( + input_stream: 'in' + node { + calculator: 'PreviousLoopbackCalculator' + input_stream: 'MAIN:in' + input_stream: 'LOOP:out' + input_stream_info: { tag_index: 'LOOP' back_edge: true } + output_stream: 'PREV_LOOP:previous' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'in' + input_stream: 'previous' + output_stream: 'out' + output_stream: 'previous2' + } + node { + calculator: 'MakePairCalculator' + input_stream: 'out' + input_stream: 'previous' + output_stream: 'out_and_previous' + } + )"); + tool::AddVectorSink("out_and_previous", &graph_config, + &out_and_previous_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in", MakePacket(1).At(Timestamp::Min()))); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in", MakePacket(2).At(Timestamp::Max()))); + + MP_EXPECT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT( + out_and_previous_packets, + ElementsAre( + PairPacket(Timestamp::Min(), Pair(IntPacket(1), EmptyPacket())), + PairPacket(Timestamp::Max(), Pair(IntPacket(2), IntPacket(1))))); + + MP_EXPECT_OK(graph.CloseAllInputStreams()); + MP_EXPECT_OK(graph.WaitUntilIdle()); + MP_EXPECT_OK(graph.WaitUntilDone()); +} + // Demonstrates that downstream calculators won't be blocked by // always-empty-LOOP-stream. TEST(PreviousLoopbackCalculator, EmptyLoopForever) { diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc index d7df7530b..47c3f624b 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc @@ -34,6 +34,8 @@ constexpr char kTagAtPostStream[] = "AT_POSTSTREAM"; constexpr char kTagAtZero[] = "AT_ZERO"; constexpr char kTagAtTick[] = "AT_TICK"; constexpr char kTagTick[] = "TICK"; +constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP"; +constexpr char kTagSideInputTimestamp[] = "TIMESTAMP"; static std::map* kTimestampMap = []() { auto* res = new std::map(); @@ -41,6 +43,7 @@ static std::map* kTimestampMap = []() { res->emplace(kTagAtPostStream, Timestamp::PostStream()); res->emplace(kTagAtZero, Timestamp(0)); res->emplace(kTagAtTick, Timestamp::Unset()); + res->emplace(kTagAtTimestamp, Timestamp::Unset()); return res; }(); @@ -56,9 +59,10 @@ std::string GetOutputTag(const CC& cc) { // timestamp, depending on the tag used to define output stream(s). (One tag can // be used only.) // -// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO and AT_TICK and -// corresponding timestamps are Timestamp::PreStream(), Timestamp::PostStream(), -// Timestamp(0) and timestamp of a packet received in TICK input. +// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_TIMESTAMP +// and corresponding timestamps are Timestamp::PreStream(), +// Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK +// input, and timestamp received from a side input. // // Examples: // node { @@ -73,6 +77,13 @@ std::string GetOutputTag(const CC& cc) { // input_side_packet: "side_packet" // output_stream: "AT_TICK:packet" // } +// +// node { +// calculator: "SidePacketToStreamCalculator" +// input_side_packet: "TIMESTAMP:timestamp" +// input_side_packet: "side_packet" +// output_stream: "AT_TIMESTAMP:packet" +// } class SidePacketToStreamCalculator : public CalculatorBase { public: SidePacketToStreamCalculator() = default; @@ -93,16 +104,29 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator); CalculatorContract* cc) { const auto& tags = cc->Outputs().GetTags(); RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1) - << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO and AT_TICK tags is " - "allowed and required to specify output stream(s)."; + << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " + "AT_TIMESTAMP tags is allowed and required to specify output " + "stream(s)."; RET_CHECK( (cc->Outputs().HasTag(kTagAtTick) && cc->Inputs().HasTag(kTagTick)) || (!cc->Outputs().HasTag(kTagAtTick) && !cc->Inputs().HasTag(kTagTick))) << "Either both of TICK and AT_TICK should be used or none of them."; + RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) && + cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) || + (!cc->Outputs().HasTag(kTagAtTimestamp) && + !cc->InputSidePackets().HasTag(kTagSideInputTimestamp))) + << "Either both TIMESTAMP and AT_TIMESTAMP should be used or none of " + "them."; const std::string output_tag = GetOutputTag(*cc); const int num_entries = cc->Outputs().NumEntries(output_tag); - RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries()) - << "Same number of input side packets and output streams is required."; + if (cc->Outputs().HasTag(kTagAtTimestamp)) { + RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries()) + << "For AT_TIMESTAMP tag, 2 input side packets are required."; + cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set(); + } else { + RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries()) + << "Same number of input side packets and output streams is required."; + } for (int i = 0; i < num_entries; ++i) { cc->InputSidePackets().Index(i).SetAny(); cc->Outputs() @@ -147,13 +171,22 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator); } ::mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { - if (!cc->Outputs().HasTag(kTagAtTick)) { + if (!cc->Outputs().HasTag(kTagAtTick) && + !cc->Outputs().HasTag(kTagAtTimestamp)) { const auto& timestamp = kTimestampMap->at(output_tag_); for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) { cc->Outputs() .Get(output_tag_, i) .AddPacket(cc->InputSidePackets().Index(i).At(timestamp)); } + } else if (cc->Outputs().HasTag(kTagAtTimestamp)) { + int64 timestamp = + cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get(); + for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) { + cc->Outputs() + .Get(output_tag_, i) + .AddPacket(cc->InputSidePackets().Index(i).At(Timestamp(timestamp))); + } } return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc index 078055f07..e7195e03b 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc @@ -51,6 +51,27 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) { "Either both of TICK and AT_TICK should be used or none of them."); } +TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"( + input_stream: "timestamp" + input_side_packet: "side_packet" + output_stream: "packet" + node { + calculator: "SidePacketToStreamCalculator" + input_side_packet: "side_packet" + output_stream: "AT_TIMESTAMP:packet" + } + )"); + CalculatorGraph graph; + auto status = graph.Initialize(graph_config); + EXPECT_FALSE(status.ok()); + EXPECT_PRED2( + absl::StrContains, status.message(), + "Either both TIMESTAMP and AT_TIMESTAMP should be used or none of them."); +} + TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( @@ -68,8 +89,9 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); EXPECT_PRED2(absl::StrContains, status.message(), - "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO and AT_TICK " - "tags is allowed and required to specify output stream(s)."); + "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " + "AT_TIMESTAMP tags is allowed and required to specify output " + "stream(s)."); } TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { @@ -91,8 +113,9 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); EXPECT_PRED2(absl::StrContains, status.message(), - "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO and AT_TICK " - "tags is allowed and required to specify output stream(s)."); + "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " + "AT_TIMESTAMP tags is allowed and required to specify output " + "stream(s)."); } TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) { @@ -271,5 +294,79 @@ TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) { tick_and_verify(/*at_timestamp=*/1025); } +TEST(SidePacketToStreamCalculator, AtTimestamp) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"( + input_side_packet: "timestamp" + input_side_packet: "side_packet" + output_stream: "packet" + node { + calculator: "SidePacketToStreamCalculator" + input_side_packet: "TIMESTAMP:timestamp" + input_side_packet: "side_packet" + output_stream: "AT_TIMESTAMP:packet" + } + )"); + std::vector output_packets; + tool::AddVectorSink("packet", &graph_config, &output_packets); + CalculatorGraph graph; + + MP_ASSERT_OK(graph.Initialize(graph_config)); + const int expected_value = 20; + const int64 expected_timestamp = 5; + MP_ASSERT_OK( + graph.StartRun({{"side_packet", MakePacket(expected_value)}, + {"timestamp", MakePacket(expected_timestamp)}})); + + MP_ASSERT_OK(graph.WaitUntilDone()); + + ASSERT_FALSE(output_packets.empty()); + EXPECT_EQ(Timestamp(expected_timestamp), output_packets.back().Timestamp()); + EXPECT_EQ(expected_value, output_packets.back().Get()); +} + +TEST(SidePacketToStreamCalculator, AtTimestamp_MultipleOutputs) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"( + input_side_packet: "timestamp" + input_side_packet: "side_packet0" + input_side_packet: "side_packet1" + output_stream: "packet" + node { + calculator: "SidePacketToStreamCalculator" + input_side_packet: "TIMESTAMP:timestamp" + input_side_packet: "side_packet0" + input_side_packet: "side_packet1" + output_stream: "AT_TIMESTAMP:0:packet0" + output_stream: "AT_TIMESTAMP:1:packet1" + } + )"); + std::vector output_packets0; + tool::AddVectorSink("packet0", &graph_config, &output_packets0); + std::vector output_packets1; + tool::AddVectorSink("packet1", &graph_config, &output_packets1); + CalculatorGraph graph; + + MP_ASSERT_OK(graph.Initialize(graph_config)); + const int expected_value0 = 20; + const int expected_value1 = 15; + const int64 expected_timestamp = 5; + MP_ASSERT_OK( + graph.StartRun({{"side_packet0", MakePacket(expected_value0)}, + {"side_packet1", MakePacket(expected_value1)}, + {"timestamp", MakePacket(expected_timestamp)}})); + + MP_ASSERT_OK(graph.WaitUntilDone()); + + ASSERT_FALSE(output_packets0.empty()); + EXPECT_EQ(Timestamp(expected_timestamp), output_packets0.back().Timestamp()); + EXPECT_EQ(expected_value0, output_packets0.back().Get()); + ASSERT_FALSE(output_packets1.empty()); + EXPECT_EQ(Timestamp(expected_timestamp), output_packets1.back().Timestamp()); + EXPECT_EQ(expected_value1, output_packets1.back().Get()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index cb5f6419e..37539d814 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -449,19 +449,15 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); switch (rotation_) { case mediapipe::RotationMode_Mode_UNKNOWN: case mediapipe::RotationMode_Mode_ROTATION_0: - LOG(ERROR) << "Not rotating image."; rotated_mat = input_mat; break; case mediapipe::RotationMode_Mode_ROTATION_90: - LOG(ERROR) << "Rotating image by 90 degrees ccw."; cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE); break; case mediapipe::RotationMode_Mode_ROTATION_180: - LOG(ERROR) << "Rotating image by 180 degrees."; cv::rotate(input_mat, rotated_mat, cv::ROTATE_180); break; case mediapipe::RotationMode_Mode_ROTATION_270: - LOG(ERROR) << "Rotating image by 90 degrees cw."; cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE); break; } diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index ea3fcc715..f934bd5a4 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -57,22 +57,6 @@ proto_library( deps = ["//mediapipe/framework:calculator_proto"], ) -proto_library( - name = "tensorflow_session_from_saved_model_generator_proto", - srcs = ["tensorflow_session_from_saved_model_generator.proto"], - visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:packet_generator_proto"], -) - -proto_library( - name = "tensorflow_session_from_saved_model_calculator_proto", - srcs = ["tensorflow_session_from_saved_model_calculator.proto"], - visibility = ["//visibility:public"], - deps = [ - "//mediapipe/framework:calculator_proto", - ], -) - proto_library( name = "tensor_squeeze_dimensions_calculator_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], @@ -212,7 +196,10 @@ mediapipe_cc_proto_library( mediapipe_cc_proto_library( name = "tensorflow_session_from_saved_model_generator_cc_proto", srcs = ["tensorflow_session_from_saved_model_generator.proto"], - cc_deps = ["//mediapipe/framework:packet_generator_cc_proto"], + cc_deps = [ + "//mediapipe/framework:packet_generator_cc_proto", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_generator_proto"], ) @@ -220,7 +207,10 @@ mediapipe_cc_proto_library( mediapipe_cc_proto_library( name = "tensorflow_session_from_saved_model_calculator_cc_proto", srcs = ["tensorflow_session_from_saved_model_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_calculator_proto"], ) @@ -488,6 +478,8 @@ cc_library( "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", ] + select({ @@ -518,6 +510,8 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/tool:status_util", "//mediapipe/framework/port:status", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", ] + select({ "//conditions:default": [ @@ -929,6 +923,7 @@ cc_test( "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:all_kernels", "@org_tensorflow//tensorflow/core:direct_session", + "@org_tensorflow//tensorflow/core:protos_all_cc", ], ) @@ -954,6 +949,7 @@ cc_test( "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:all_kernels", "@org_tensorflow//tensorflow/core:direct_session", + "@org_tensorflow//tensorflow/core:protos_all_cc", ], ) diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc index 53f6b70f8..7975d4c9d 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc @@ -26,9 +26,14 @@ #include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/clock.h" +#include "mediapipe/framework/deps/monotonic_clock.h" +#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/tool/status_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/public/session_options.h" #if defined(MEDIAPIPE_MOBILE) @@ -41,6 +46,17 @@ namespace mediapipe { namespace tf = ::tensorflow; +namespace { +// Updates the graph nodes to use the device as specified by device_id. +void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { + for (auto& node : *graph_def->mutable_node()) { + if (node.device().empty()) { + node.set_device(device_id); + } + } +} +} // namespace + class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { public: static ::mediapipe::Status GetContract(CalculatorContract* cc) { @@ -77,6 +93,9 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { } ::mediapipe::Status Open(CalculatorContext* cc) override { + auto clock = std::unique_ptr( + mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); + const uint64 start_time = absl::ToUnixMicros(clock->TimeNow()); const auto& options = cc->Options(); // Output bundle packet. @@ -108,6 +127,12 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { tensorflow::GraphDef graph_def; RET_CHECK(graph_def.ParseFromString(graph_def_serialized)); + + // Update the graph nodes to use the preferred device, if set. + if (!options.preferred_device_id().empty()) { + SetPreferredDevice(&graph_def, options.preferred_device_id()); + } + const tf::Status tf_status = session->session->Create(graph_def); RET_CHECK(tf_status.ok()) << "Create failed: " << tf_status.ToString(); @@ -123,6 +148,9 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { } cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); + const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); + LOG(INFO) << "Loaded frozen model in: " << end_time - start_time + << " microseconds."; return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.proto index 3921d2016..87b2304ad 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.proto @@ -69,4 +69,12 @@ message TensorFlowSessionFromFrozenGraphCalculatorOptions { // Graph nodes to run to initialize the model. Any output of these ops is // ignored. repeated string initialization_op_names = 4; + + // The id of the device you would prefer to execute the graph nodes on. + // If set, all graph nodes without a previously specified device, will be set + // to run on preferred_device_id. Example values include: + // ["/device:GPU:0","/device:CPU:0", ...] + // NOTE: If config.allow_soft_placement = false, and the device is not found, + // an error will be thrown. + optional string preferred_device_id = 5; } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc index c2b774278..5277eb348 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc @@ -66,6 +66,7 @@ class TensorFlowSessionFromFrozenGraphCalculatorTest : public ::testing::Test { (*calculator_options_->mutable_tag_to_tensor_names())["B"] = "b:0"; calculator_options_->mutable_config()->set_intra_op_parallelism_threads(1); calculator_options_->mutable_config()->set_inter_op_parallelism_threads(2); + calculator_options_->set_preferred_device_id("/device:CPU:0"); } void VerifySignatureMap(const TensorFlowSession& session) { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc index 4f71336f2..0cb4a70b5 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc @@ -27,16 +27,32 @@ #include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/clock.h" +#include "mediapipe/framework/deps/monotonic_clock.h" #include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/tool/status_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/public/session_options.h" namespace mediapipe { namespace tf = ::tensorflow; +namespace { +// Updates the graph nodes to use the device as specified by device_id. +void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { + for (auto& node : *graph_def->mutable_node()) { + if (node.device().empty()) { + node.set_device(device_id); + } + } +} +} // namespace + class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { public: static ::mediapipe::Status FillExpectations( @@ -77,6 +93,9 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { static ::mediapipe::Status Generate( const PacketGeneratorOptions& packet_generator_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { + auto clock = std::unique_ptr( + mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); + const uint64 start_time = absl::ToUnixMicros(clock->TimeNow()); const TensorFlowSessionFromFrozenGraphGeneratorOptions& options = packet_generator_options.GetExtension( TensorFlowSessionFromFrozenGraphGeneratorOptions::ext); @@ -108,6 +127,12 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { tensorflow::GraphDef graph_def; RET_CHECK(graph_def.ParseFromString(graph_def_serialized)); + + // Update the graph nodes to use the preferred device, if set. + if (!options.preferred_device_id().empty()) { + SetPreferredDevice(&graph_def, options.preferred_device_id()); + } + const tf::Status tf_status = session->session->Create(graph_def); RET_CHECK(tf_status.ok()) << "Create failed: " << tf_status.ToString(); @@ -123,6 +148,9 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { } output_side_packets->Tag("SESSION") = Adopt(session.release()); + const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); + LOG(INFO) << "Loaded frozen model in: " << end_time - start_time + << " microseconds."; return ::mediapipe::OkStatus(); } }; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.proto index 183b5a5a5..4643b4d60 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.proto @@ -69,4 +69,12 @@ message TensorFlowSessionFromFrozenGraphGeneratorOptions { // Graph nodes to run to initialize the model. Any output of these ops is // ignored. repeated string initialization_op_names = 4; + + // The id of the device you would prefer to execute the graph nodes on. + // If set, all graph nodes without a previously specified device, will be set + // to run on preferred_device_id. Example values include: + // ["/device:GPU:0","/device:CPU:0", ...] + // NOTE: If config.allow_soft_placement = false, and the device is not found, + // an error will be thrown. + optional string preferred_device_id = 5; } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc index d11007299..e2b968217 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc @@ -66,6 +66,7 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test { (*generator_options_->mutable_tag_to_tensor_names())["B"] = "b:0"; generator_options_->mutable_config()->set_intra_op_parallelism_threads(1); generator_options_->mutable_config()->set_inter_op_parallelism_threads(2); + generator_options_->set_preferred_device_id("/device:CPU:0"); } void VerifySignatureMap(PacketSet* output_side_packets) { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index b54976478..55709bcd9 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -134,8 +134,8 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { } tensorflow::RunOptions run_options; - // In the future, could construct session options from the options proto. tensorflow::SessionOptions session_options; + session_options.config = options.session_config(); auto saved_model = absl::make_unique(); ::tensorflow::Status status = tensorflow::LoadSavedModel( session_options, run_options, path, tags_set, saved_model.get()); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto index 66e03d893..a8839ef52 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto @@ -17,6 +17,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "tensorflow/core/protobuf/config.proto"; message TensorFlowSessionFromSavedModelCalculatorOptions { extend mediapipe.CalculatorOptions { @@ -55,4 +56,7 @@ message TensorFlowSessionFromSavedModelCalculatorOptions { // If no tag is specified, then use "serve" as the default. Note that in order // to use TPU accelerator hardware, the tag "tpu" needs to be specified. repeated string saved_model_tag = 6; + + // Tensorflow session config options. + optional tensorflow.ConfigProto session_config = 7; } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc index fee0da0fb..d6064d862 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc @@ -26,6 +26,7 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/tool/tag_map_helper.h" #include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/core/framework/device_attributes.pb.h" namespace mediapipe { @@ -204,5 +205,31 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, ASSERT_NE(session.session, nullptr); } +TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, + ConfiguresSessionGivenConfig) { + options_->set_saved_model_path( + std::string(file::SplitPath(GetSavedModelDir()).first)); + options_->set_load_latest_model(true); + options_->mutable_session_config()->mutable_device_count()->insert( + {"CPU", 10}); + CalculatorRunner runner(absl::Substitute(R"( + calculator: "TensorFlowSessionFromSavedModelCalculator" + output_side_packet: "SESSION:tf_model" + options { + [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: { + $0 + } + })", + options_->DebugString())); + MP_ASSERT_OK(runner.Run()); + const TensorFlowSession& session = + runner.OutputSidePackets().Tag("SESSION").Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); + std::vector devices; + ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK()); + EXPECT_THAT(devices.size(), 10); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index 6e1a29e59..73ffc6497 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -129,8 +129,8 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { } tensorflow::RunOptions run_options; - // In the future, could construct session options from the options proto. tensorflow::SessionOptions session_options; + session_options.config = options.session_config(); auto saved_model = absl::make_unique(); ::tensorflow::Status status = tensorflow::LoadSavedModel( session_options, run_options, path, tags_set, saved_model.get()); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto index 2dab09242..88ce93435 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto @@ -17,6 +17,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/packet_generator.proto"; +import "tensorflow/core/protobuf/config.proto"; message TensorFlowSessionFromSavedModelGeneratorOptions { extend mediapipe.PacketGeneratorOptions { @@ -55,4 +56,7 @@ message TensorFlowSessionFromSavedModelGeneratorOptions { // If no tag is specified, then use "serve" as the default. Note that in order // to use TPU accelerator hardware, the tag "tpu" needs to be specified. repeated string saved_model_tag = 6; + + // Tensorflow session config options. + optional tensorflow.ConfigProto session_config = 9; } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc index d12fee12a..792c3841b 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -25,6 +25,7 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/tool/tag_map_helper.h" #include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/core/framework/device_attributes.pb.h" namespace mediapipe { @@ -196,5 +197,29 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, ASSERT_NE(session.session, nullptr); } +TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, + ConfiguresSessionGivenConfig) { + generator_options_->set_saved_model_path( + std::string(file::SplitPath(GetSavedModelDir()).first)); + generator_options_->set_load_latest_model(true); + generator_options_->mutable_session_config()->mutable_device_count()->insert( + {"CPU", 10}); + + PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromSavedModelGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MP_EXPECT_OK(run_status) << run_status.message(); + const TensorFlowSession& session = + output_side_packets.Tag("SESSION").Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); + std::vector devices; + ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK()); + EXPECT_THAT(devices.size(), 10); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc index 068be5714..f7c041788 100644 --- a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc @@ -91,11 +91,11 @@ REGISTER_CALCULATOR(VectorFloatToTensorCalculator); cc->Inputs().Index(0).Value().Get>>(); const int32 rows = input.size(); - CHECK_GE(rows, 1); + RET_CHECK_GE(rows, 1); const int32 cols = input[0].size(); - CHECK_GE(cols, 1); + RET_CHECK_GE(cols, 1); for (int i = 1; i < rows; ++i) { - CHECK_EQ(input[i].size(), cols); + RET_CHECK_EQ(input[i].size(), cols); } if (options_.transpose()) { tensor_shape = tf::TensorShape({cols, rows}); @@ -116,7 +116,7 @@ REGISTER_CALCULATOR(VectorFloatToTensorCalculator); } else if (options_.input_size() == INPUT_1D) { const std::vector& input = cc->Inputs().Index(0).Value().Get>(); - CHECK_GE(input.size(), 1); + RET_CHECK_GE(input.size(), 1); const int32 length = input.size(); tensor_shape = tf::TensorShape({length}); auto output = ::absl::make_unique(tf::DT_FLOAT, tensor_shape); diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index f1101a009..2c4bb637b 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -196,13 +196,6 @@ cc_test( ], ) -cc_library( - name = "util", - hdrs = ["util.h"], - visibility = ["//visibility:public"], - alwayslink = 1, -) - selects.config_setting_group( name = "gpu_inference_disabled", match_any = [ @@ -229,7 +222,6 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - ":util", ":tflite_inference_calculator_cc_proto", "@com_google_absl//absl/memory", "//mediapipe/framework:calculator_framework", @@ -295,7 +287,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/util/tflite:config", - ":util", ":tflite_converter_calculator_cc_proto", "//mediapipe/util:resource_util", "//mediapipe/framework:calculator_framework", @@ -334,7 +325,6 @@ cc_library( srcs = ["tflite_model_calculator.cc"], visibility = ["//visibility:public"], deps = [ - ":util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", "//mediapipe/framework/port:ret_check", @@ -348,7 +338,6 @@ cc_library( srcs = ["tflite_tensors_to_segmentation_calculator.cc"], visibility = ["//visibility:public"], deps = [ - ":util", ":tflite_tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -418,7 +407,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/util/tflite:config", - ":util", ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -551,6 +539,7 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/lite:framework", ], ) diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc index 6a3011141..e81354242 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc @@ -16,7 +16,6 @@ #include #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" -#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/matrix.h" @@ -146,8 +145,7 @@ class TfLiteConverterCalculator : public CalculatorBase { ::mediapipe::Status LoadOptions(CalculatorContext* cc); template ::mediapipe::Status NormalizeImage(const ImageFrame& image_frame, - bool zero_center, bool flip_vertically, - float* tensor_ptr); + bool flip_vertically, float* tensor_ptr); ::mediapipe::Status CopyMatrixToTensor(const Matrix& matrix, float* tensor_ptr); ::mediapipe::Status ProcessCPU(CalculatorContext* cc); @@ -165,10 +163,7 @@ class TfLiteConverterCalculator : public CalculatorBase { bool initialized_ = false; bool use_gpu_ = false; - bool zero_center_ = true; // normalize range to [-1,1] | otherwise [0,1] - bool use_custom_normalization_ = false; - float custom_div_ = -1.0f; - float custom_sub_ = -1.0f; + absl::optional> output_range_; bool flip_vertically_ = false; bool row_major_matrix_ = false; bool use_quantized_tensors_ = false; @@ -362,11 +357,11 @@ bool ShouldUseGpu(CC* cc) { float* tensor_buffer = tensor->data.f; RET_CHECK(tensor_buffer); if (image_frame.ByteDepth() == 1) { - MP_RETURN_IF_ERROR(NormalizeImage( - image_frame, zero_center_, flip_vertically_, tensor_buffer)); + MP_RETURN_IF_ERROR(NormalizeImage(image_frame, flip_vertically_, + tensor_buffer)); } else if (image_frame.ByteDepth() == 4) { - MP_RETURN_IF_ERROR(NormalizeImage( - image_frame, zero_center_, flip_vertically_, tensor_buffer)); + MP_RETURN_IF_ERROR(NormalizeImage(image_frame, flip_vertically_, + tensor_buffer)); } else { return ::mediapipe::InternalError( "Only byte-based (8 bit) and float (32 bit) images supported."); @@ -427,11 +422,11 @@ bool ShouldUseGpu(CC* cc) { auto src = gpu_helper_.CreateSourceTexture(input); glActiveTexture(GL_TEXTURE0 + 0); glBindTexture(GL_TEXTURE_2D, src.name()); - RET_CHECK_CALL(gpu_data_out_->buffer.BindToIndex(1)); + MP_RETURN_IF_ERROR(gpu_data_out_->buffer.BindToIndex(1)); const tflite::gpu::uint3 workgroups = { NumGroups(input.width(), kWorkgroupSize), NumGroups(input.height(), kWorkgroupSize), 1}; - RET_CHECK_CALL(gpu_data_out_->program.Dispatch(workgroups)); + MP_RETURN_IF_ERROR(gpu_data_out_->program.Dispatch(workgroups)); glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); glBindTexture(GL_TEXTURE_2D, 0); src.Release(); @@ -445,9 +440,9 @@ bool ShouldUseGpu(CC* cc) { output_tensors->resize(1); { GpuTensor& tensor = output_tensors->at(0); - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( gpu_data_out_->elements, &tensor)); - RET_CHECK_CALL(CopyBuffer(gpu_data_out_->buffer, tensor)); + MP_RETURN_IF_ERROR(CopyBuffer(gpu_data_out_->buffer, tensor)); } return ::mediapipe::OkStatus(); })); @@ -521,7 +516,7 @@ bool ShouldUseGpu(CC* cc) { MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status { // Device memory. - RET_CHECK_CALL( + MP_RETURN_IF_ERROR( ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( gpu_data_out_->elements, &gpu_data_out_->buffer)); @@ -544,7 +539,13 @@ bool ShouldUseGpu(CC* cc) { $6 // alpha channel })", /*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(), - /*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "", + /*$3=*/ + output_range_.has_value() + ? absl::Substitute( + "pixel = pixel * float($0) + float($1);", + (output_range_->second - output_range_->first), + output_range_->first) + : "", /*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y", /*$5=*/ single_channel @@ -555,10 +556,10 @@ bool ShouldUseGpu(CC* cc) { include_alpha ? "output_data.elements[linear_index + 3] = pixel.w;" : "", /*$7=*/max_num_channels_); - RET_CHECK_CALL(GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, - &gpu_data_out_->shader)); - RET_CHECK_CALL(GlProgram::CreateWithShader(gpu_data_out_->shader, - &gpu_data_out_->program)); + MP_RETURN_IF_ERROR(GlShader::CompileShader( + GL_COMPUTE_SHADER, shader_source, &gpu_data_out_->shader)); + MP_RETURN_IF_ERROR(GlProgram::CreateWithShader( + gpu_data_out_->shader, &gpu_data_out_->program)); return ::mediapipe::OkStatus(); })); @@ -599,7 +600,12 @@ bool ShouldUseGpu(CC* cc) { )", /*$0=*/include_alpha ? "float4" : "float3", /*$1=*/include_alpha ? "rgba" : "rgb", - /*$2=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "", + /*$2=*/ + output_range_.has_value() + ? absl::Substitute("pixel = pixel * float($0) + float($1);", + (output_range_->second - output_range_->first), + output_range_->first) + : "", /*$3=*/flip_vertically_ ? "(in_tex.get_height() - 1 - gid.y)" : "gid.y", /*$4=*/include_alpha ? 4 : 3, /*$5=*/include_alpha ? "out_buf[linear_index + 3] = pixel.w;" : ""); @@ -630,13 +636,27 @@ bool ShouldUseGpu(CC* cc) { const auto& options = cc->Options<::mediapipe::TfLiteConverterCalculatorOptions>(); - // Get data normalization mode. - zero_center_ = options.zero_center(); + // if zero_center, set output float range to match [-1, 1] as specified in + // calculator proto. + if (options.zero_center()) { + output_range_.emplace(std::pair(-1.0, 1.0)); + } + + // Custom output_tensor_float_range values. + // If the float range is specified in pb text, use the specified values + // instead. + if (options.has_output_tensor_float_range()) { + output_range_.emplace(options.output_tensor_float_range().min(), + options.output_tensor_float_range().max()); + CHECK_GT(output_range_->second, output_range_->first); + } // Custom div and sub values. - use_custom_normalization_ = options.use_custom_normalization(); - custom_div_ = options.custom_div(); - custom_sub_ = options.custom_sub(); + if (options.use_custom_normalization()) { + output_range_.emplace(std::pair( + -options.custom_sub(), + -options.custom_sub() + 255.0 / options.custom_div())); + } // Get y-flip mode. flip_vertically_ = options.flip_vertically(); @@ -664,40 +684,46 @@ bool ShouldUseGpu(CC* cc) { template ::mediapipe::Status TfLiteConverterCalculator::NormalizeImage( - const ImageFrame& image_frame, bool zero_center, bool flip_vertically, - float* tensor_ptr) { + const ImageFrame& image_frame, bool flip_vertically, float* tensor_ptr) { const int height = image_frame.Height(); const int width = image_frame.Width(); const int channels = image_frame.NumberOfChannels(); const int channels_preserved = std::min(channels, max_num_channels_); const int channels_ignored = channels - channels_preserved; - float div, sub; + if (output_range_.has_value()) { + // If the output float range is set and we are not using custom + // normalization, normalize the pixel values from [0, 255] to the specified + // output range. + RET_CHECK_NE(output_range_->first, output_range_->second); + const float scale = (output_range_->second - output_range_->first) / 255.0f; + const float bias = output_range_->first; - if (use_custom_normalization_) { - RET_CHECK_GT(custom_div_, 0.0f); - RET_CHECK_GE(custom_sub_, 0.0f); - div = custom_div_; - sub = custom_sub_; - } else if (zero_center) { - // [-1,1] - div = 127.5f; - sub = 1.0f; - } else { - // [0,1] - div = 255.0f; - sub = 0.0f; - } - - for (int i = 0; i < height; ++i) { - const T* image_ptr = reinterpret_cast( - image_frame.PixelData() + - (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep()); - for (int j = 0; j < width; ++j) { - for (int c = 0; c < channels_preserved; ++c) { - *tensor_ptr++ = *image_ptr++ / div - sub; + for (int i = 0; i < height; ++i) { + const T* image_ptr = reinterpret_cast( + image_frame.PixelData() + + (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep()); + for (int j = 0; j < width; ++j) { + for (int c = 0; c < channels_preserved; ++c) { + *tensor_ptr++ = *image_ptr++ * scale + bias; + } + image_ptr += channels_ignored; + } + } + } else { + // [0,1], scale only (bias == 0) + // Verified that there are no precision issues with 1.0f / 255.0f expression + const float scale = 1.0f / 255.0f; + for (int i = 0; i < height; ++i) { + const T* image_ptr = reinterpret_cast( + image_frame.PixelData() + + (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep()); + for (int j = 0; j < width; ++j) { + for (int c = 0; c < channels_preserved; ++c) { + *tensor_ptr++ = *image_ptr++ * scale; + } + image_ptr += channels_ignored; } - image_ptr += channels_ignored; } } diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.proto b/mediapipe/calculators/tflite/tflite_converter_calculator.proto index 4d468c851..5ed70879d 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.proto @@ -56,4 +56,14 @@ message TfLiteConverterCalculatorOptions { // Quantization option (CPU only). // When true, output kTfLiteUInt8 tensor instead of kTfLiteFloat32. optional bool use_quantized_tensors = 5 [default = false]; + + // Normalization option. + // Setting normalization_range results in the values normalized to + // the range [output_tensor_float_range.min, output_tensor_float_range.max]. + optional TensorFloatRange output_tensor_float_range = 9; + + message TensorFloatRange { + optional float min = 1; + optional float max = 2; + } } diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc b/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc index cecf84e6f..c8762b09b 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc @@ -16,6 +16,7 @@ #include #include "absl/memory/memory.h" +#include "absl/strings/substitute.h" #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" @@ -40,6 +41,7 @@ constexpr char kTransposeOptionsString[] = } // namespace using RandomEngine = std::mt19937_64; +using testing::Eq; const uint32 kSeed = 1234; const int kNumSizes = 8; const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, @@ -232,7 +234,6 @@ TEST_F(TfLiteConverterCalculatorTest, CustomDivAndSub) { // Wait until the calculator done processing. MP_ASSERT_OK(graph.WaitUntilIdle()); - EXPECT_EQ(1, output_packets.size()); // Get and process results. const std::vector& tensor_vec = @@ -249,4 +250,70 @@ TEST_F(TfLiteConverterCalculatorTest, CustomDivAndSub) { MP_ASSERT_OK(graph.WaitUntilDone()); } +TEST_F(TfLiteConverterCalculatorTest, SetOutputRange) { + std::vector> range_values = { + std::make_pair(0.0, 1.0), std::make_pair(-1.0, 1.0), + std::make_pair(-0.5, 0.5)}; + for (std::pair range : range_values) { + CalculatorGraph graph; + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + absl::Substitute(R"( + input_stream: "input_image" + node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:input_image" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TfLiteConverterCalculatorOptions.ext] { + output_tensor_float_range { + min: $0 + max: $1 + } + } + } + } + )", + /*$0=*/range.first, + /*$1=*/range.second)); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + auto input_image = absl::make_unique(ImageFormat::GRAY8, 1, 1); + cv::Mat mat = ::mediapipe::formats::MatView(input_image.get()); + mat.at(0, 0) = 200; + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_image", Adopt(input_image.release()).At(Timestamp(0)))); + + // Wait until the calculator finishes processing. + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets.size(), Eq(1)); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + EXPECT_THAT(tensor_vec.size(), Eq(1)); + + const TfLiteTensor* tensor = &tensor_vec[0]; + + // Calculate the expected normalized value: + float normalized_value = + range.first + (200 * (range.second - range.first)) / 255.0; + + EXPECT_THAT(tensor->type, Eq(kTfLiteFloat32)); + EXPECT_THAT(normalized_value, + testing::FloatNear(*tensor->data.f, + 2.0f * std::abs(*tensor->data.f) * + std::numeric_limits::epsilon())); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("input_image")); + MP_ASSERT_OK(graph.WaitUntilDone()); + } +} + } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index cd881102d..8ed8a7ae8 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -19,7 +19,6 @@ #include "absl/memory/memory.h" #include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" -#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/tflite/config.h" @@ -496,7 +495,7 @@ bool ShouldUseGpu(CC* cc) { output_tensors_gpu->resize(gpu_data_out_.size()); for (int i = 0; i < gpu_data_out_.size(); ++i) { GpuTensor& tensor = output_tensors_gpu->at(i); - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( gpu_data_out_[i]->elements, &tensor)); MP_RETURN_IF_ERROR( tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i)); @@ -518,7 +517,7 @@ bool ShouldUseGpu(CC* cc) { // Explicit copy input. gpu_data_in_.resize(input_tensors.size()); for (int i = 0; i < input_tensors.size(); ++i) { - RET_CHECK_CALL(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer)); + MP_RETURN_IF_ERROR(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer)); } #elif MEDIAPIPE_TFLITE_METAL_INFERENCE const auto& input_tensors = @@ -582,7 +581,7 @@ bool ShouldUseGpu(CC* cc) { for (int i = 0; i < tensor_indexes.size(); ++i) { TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]); std::vector gpu_data(tensor->bytes / sizeof(float)); - RET_CHECK_CALL(gpu_data_out_[i]->buffer.Read( + MP_RETURN_IF_ERROR(gpu_data_out_[i]->buffer.Read( absl::MakeSpan(tensor->data.f, tensor->bytes))); output_tensors_cpu->emplace_back(*tensor); } @@ -599,9 +598,9 @@ bool ShouldUseGpu(CC* cc) { for (int i = 0; i < gpu_data_out_.size(); ++i) { GpuTensor& tensor = output_tensors_gpu->at(i); // Allocate output tensor. - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( gpu_data_out_[i]->elements, &tensor)); - RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor)); + MP_RETURN_IF_ERROR(CopyBuffer(gpu_data_out_[i]->buffer, tensor)); } cc->Outputs() .Tag(kTensorsGpuTag) @@ -655,7 +654,8 @@ bool ShouldUseGpu(CC* cc) { options.priority3 = tflite::gpu::InferencePriority::AUTO; options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; tflite_gpu_runner_ = std::make_unique(options); - RET_CHECK_CALL(tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); + MP_RETURN_IF_ERROR( + tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); // Allocate interpreter memory for cpu output. if (!gpu_output_) { @@ -688,10 +688,11 @@ bool ShouldUseGpu(CC* cc) { ASSIGN_OR_RETURN(gpu_data_out_[i]->elements, tflite_gpu_runner_->GetOutputElements(i)); // Create and bind input buffer. - RET_CHECK_CALL(::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( - gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer)); + MP_RETURN_IF_ERROR( + ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( + gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer)); } - RET_CHECK_CALL(tflite_gpu_runner_->Build()); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE return ::mediapipe::OkStatus(); @@ -841,7 +842,7 @@ bool ShouldUseGpu(CC* cc) { gpu_data_in_[i]->elements *= tensor->dims->data[d]; } // Create and bind input buffer. - RET_CHECK_CALL( + MP_RETURN_IF_ERROR( ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( gpu_data_in_[i]->elements, &gpu_data_in_[i]->buffer)); RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( @@ -866,7 +867,7 @@ bool ShouldUseGpu(CC* cc) { // Create and bind output buffers. interpreter_->SetAllowBufferHandleOutput(true); for (int i = 0; i < gpu_data_out_.size(); ++i) { - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer)); RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( delegate_.get(), gpu_data_out_[i]->buffer.id(), diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc index 412c07125..ec07aab98 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc @@ -18,7 +18,6 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h" -#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/detection.pb.h" @@ -404,8 +403,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); &output_detections]() -> ::mediapipe::Status { // Copy inputs. - RET_CHECK_CALL(CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer)); - RET_CHECK_CALL(CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer)); + MP_RETURN_IF_ERROR( + CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer)); + MP_RETURN_IF_ERROR( + CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer)); if (!anchors_init_) { if (side_packet_anchors_) { CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); @@ -413,11 +414,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); cc->InputSidePackets().Tag("ANCHORS").Get>(); std::vector raw_anchors(num_boxes_ * kNumCoordsPerBox); ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data()); - RET_CHECK_CALL(gpu_data_->raw_anchors_buffer.Write( + MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.Write( absl::MakeSpan(raw_anchors))); } else { CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); - RET_CHECK_CALL( + MP_RETURN_IF_ERROR( CopyBuffer(input_tensors[2], gpu_data_->raw_anchors_buffer)); } anchors_init_ = true; @@ -425,23 +426,24 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); // Run shaders. // Decode boxes. - RET_CHECK_CALL(gpu_data_->decoded_boxes_buffer.BindToIndex(0)); - RET_CHECK_CALL(gpu_data_->raw_boxes_buffer.BindToIndex(1)); - RET_CHECK_CALL(gpu_data_->raw_anchors_buffer.BindToIndex(2)); + MP_RETURN_IF_ERROR(gpu_data_->decoded_boxes_buffer.BindToIndex(0)); + MP_RETURN_IF_ERROR(gpu_data_->raw_boxes_buffer.BindToIndex(1)); + MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.BindToIndex(2)); const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1}; - RET_CHECK_CALL(gpu_data_->decode_program.Dispatch(decode_workgroups)); + MP_RETURN_IF_ERROR(gpu_data_->decode_program.Dispatch(decode_workgroups)); // Score boxes. - RET_CHECK_CALL(gpu_data_->scored_boxes_buffer.BindToIndex(0)); - RET_CHECK_CALL(gpu_data_->raw_scores_buffer.BindToIndex(1)); + MP_RETURN_IF_ERROR(gpu_data_->scored_boxes_buffer.BindToIndex(0)); + MP_RETURN_IF_ERROR(gpu_data_->raw_scores_buffer.BindToIndex(1)); const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1}; - RET_CHECK_CALL(gpu_data_->score_program.Dispatch(score_workgroups)); + MP_RETURN_IF_ERROR(gpu_data_->score_program.Dispatch(score_workgroups)); // Copy decoded boxes from GPU to CPU. std::vector boxes(num_boxes_ * num_coords_); - RET_CHECK_CALL(gpu_data_->decoded_boxes_buffer.Read(absl::MakeSpan(boxes))); + MP_RETURN_IF_ERROR( + gpu_data_->decoded_boxes_buffer.Read(absl::MakeSpan(boxes))); std::vector score_class_id_pairs(num_boxes_ * 2); - RET_CHECK_CALL(gpu_data_->scored_boxes_buffer.Read( + MP_RETURN_IF_ERROR(gpu_data_->scored_boxes_buffer.Read( absl::MakeSpan(score_class_id_pairs))); // TODO: b/138851969. Is it possible to output a float vector @@ -802,20 +804,20 @@ void main() { // Shader program GlShader decode_shader; - RET_CHECK_CALL( + MP_RETURN_IF_ERROR( GlShader::CompileShader(GL_COMPUTE_SHADER, decode_src, &decode_shader)); - RET_CHECK_CALL(GpuProgram::CreateWithShader(decode_shader, - &gpu_data_->decode_program)); + MP_RETURN_IF_ERROR(GpuProgram::CreateWithShader( + decode_shader, &gpu_data_->decode_program)); // Outputs size_t decoded_boxes_length = num_boxes_ * num_coords_; - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( decoded_boxes_length, &gpu_data_->decoded_boxes_buffer)); // Inputs size_t raw_boxes_length = num_boxes_ * num_coords_; - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( raw_boxes_length, &gpu_data_->raw_boxes_buffer)); size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox; - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( raw_anchors_length, &gpu_data_->raw_anchors_buffer)); // Parameters glUseProgram(gpu_data_->decode_program.id()); @@ -896,17 +898,17 @@ void main() { // Shader program GlShader score_shader; - RET_CHECK_CALL( + MP_RETURN_IF_ERROR( GlShader::CompileShader(GL_COMPUTE_SHADER, score_src, &score_shader)); - RET_CHECK_CALL( + MP_RETURN_IF_ERROR( GpuProgram::CreateWithShader(score_shader, &gpu_data_->score_program)); // Outputs size_t scored_boxes_length = num_boxes_ * 2; // score, class - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( scored_boxes_length, &gpu_data_->scored_boxes_buffer)); // Inputs size_t raw_scores_length = num_boxes_ * num_classes_; - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( raw_scores_length, &gpu_data_->raw_scores_buffer)); return ::mediapipe::OkStatus(); diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc index 23a85276d..3369840e4 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc @@ -17,7 +17,6 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" @@ -400,7 +399,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); // Create initial working mask texture. ::tflite::gpu::gl::GlTexture small_mask_texture; - RET_CHECK_CALL(CreateReadWriteRgbaImageTexture( + MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture( tflite::gpu::DataType::UINT8, // GL_RGBA8 {tensor_width_, tensor_height_}, &small_mask_texture)); @@ -410,7 +409,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); : mediapipe::GlTexture(); // Copy input tensor. - RET_CHECK_CALL(CopyBuffer(input_tensors[0], *tensor_buffer_)); + MP_RETURN_IF_ERROR(CopyBuffer(input_tensors[0], *tensor_buffer_)); // Run shader, process mask tensor. // Run softmax over tensor output and blend with previous mask. @@ -418,18 +417,18 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); const int output_index = 0; glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_RGBA8); - RET_CHECK_CALL(tensor_buffer_->BindToIndex(2)); + MP_RETURN_IF_ERROR(tensor_buffer_->BindToIndex(2)); const tflite::gpu::uint3 workgroups = { NumGroups(tensor_width_, kWorkgroupSize), NumGroups(tensor_height_, kWorkgroupSize), 1}; if (!has_prev_mask) { - RET_CHECK_CALL(mask_program_no_prev_->Dispatch(workgroups)); + MP_RETURN_IF_ERROR(mask_program_no_prev_->Dispatch(workgroups)); } else { glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, input_mask_texture.name()); - RET_CHECK_CALL(mask_program_with_prev_->Dispatch(workgroups)); + MP_RETURN_IF_ERROR(mask_program_with_prev_->Dispatch(workgroups)); glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, 0); } @@ -622,22 +621,22 @@ void main() { // Shader programs. GlShader shader_without_previous; - RET_CHECK_CALL(GlShader::CompileShader( + MP_RETURN_IF_ERROR(GlShader::CompileShader( GL_COMPUTE_SHADER, shader_src_no_previous, &shader_without_previous)); mask_program_no_prev_ = absl::make_unique(); - RET_CHECK_CALL(GlProgram::CreateWithShader(shader_without_previous, - mask_program_no_prev_.get())); + MP_RETURN_IF_ERROR(GlProgram::CreateWithShader( + shader_without_previous, mask_program_no_prev_.get())); GlShader shader_with_previous; - RET_CHECK_CALL(GlShader::CompileShader( + MP_RETURN_IF_ERROR(GlShader::CompileShader( GL_COMPUTE_SHADER, shader_src_with_previous, &shader_with_previous)); mask_program_with_prev_ = absl::make_unique(); - RET_CHECK_CALL(GlProgram::CreateWithShader(shader_with_previous, - mask_program_with_prev_.get())); + MP_RETURN_IF_ERROR(GlProgram::CreateWithShader( + shader_with_previous, mask_program_with_prev_.get())); // Buffer storage for input tensor. size_t tensor_length = tensor_width_ * tensor_height_ * tensor_channels_; tensor_buffer_ = absl::make_unique(); - RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( tensor_length, tensor_buffer_.get())); // Parameters. diff --git a/mediapipe/calculators/tflite/util.h b/mediapipe/calculators/tflite/util.h deleted file mode 100644 index 3a60441cb..000000000 --- a/mediapipe/calculators/tflite/util.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_ -#define MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_ - -#define RET_CHECK_CALL(call) \ - do { \ - const auto status = (call); \ - if (ABSL_PREDICT_FALSE(!status.ok())) \ - return ::mediapipe::InternalError(status.message()); \ - } while (0); - -#endif // MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_ diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 7223ad44d..376b608b0 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -700,6 +700,8 @@ mediapipe_cc_proto_library( deps = [":rect_to_render_data_calculator_proto"], ) +# TODO: What is that one for? + mediapipe_cc_proto_library( name = "detections_to_render_data_calculator_cc_proto", srcs = ["detections_to_render_data_calculator.proto"], diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index 812522f7a..8e6fe977e 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -160,6 +160,8 @@ class AnnotationOverlayCalculator : public CalculatorBase { GLuint image_mat_tex_ = 0; // Overlay drawing image for GPU. int width_ = 0; int height_ = 0; + int width_gpu_ = 0; // Size of overlay drawing texture. + int height_gpu_ = 0; #endif // MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(AnnotationOverlayCalculator); @@ -389,7 +391,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); glBindTexture(GL_TEXTURE_2D, image_mat_tex_); - glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width_, height_, GL_RGB, + glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width_gpu_, height_gpu_, GL_RGB, GL_UNSIGNED_BYTE, overlay_image); glBindTexture(GL_TEXTURE_2D, 0); } @@ -492,12 +494,12 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); if (format != mediapipe::ImageFormat::SRGBA && format != mediapipe::ImageFormat::SRGB) RET_CHECK_FAIL() << "Unsupported GPU input format: " << format; - image_mat = absl::make_unique(height_, width_, CV_8UC3); + image_mat = absl::make_unique(height_gpu_, width_gpu_, CV_8UC3); memset(image_mat->data, kAnnotationBackgroundColor, - height_ * width_ * image_mat->elemSize()); + height_gpu_ * width_gpu_ * image_mat->elemSize()); } else { image_mat = absl::make_unique( - options_.canvas_height_px(), options_.canvas_width_px(), CV_8UC3, + height_gpu_, width_gpu_, CV_8UC3, cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), options_.canvas_color().b())); } @@ -632,18 +634,28 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); kAnnotationBackgroundColor / 255.0, kAnnotationBackgroundColor / 255.0); - // Init texture for opencv rendered frame. - const auto& input_frame = - cc->Inputs().Tag(kInputFrameTagGpu).Get(); // Ensure GPU texture is divisible by 4. See b/138751944 for more info. - width_ = - RoundUp(input_frame.width(), ImageFrame::kGlDefaultAlignmentBoundary); - height_ = - RoundUp(input_frame.height(), ImageFrame::kGlDefaultAlignmentBoundary); + const float alignment = ImageFrame::kGlDefaultAlignmentBoundary; + const float scale_factor = options_.gpu_scale_factor(); + if (image_frame_available_) { + const auto& input_frame = + cc->Inputs().Tag(kInputFrameTagGpu).Get(); + width_ = RoundUp(input_frame.width(), alignment); + height_ = RoundUp(input_frame.height(), alignment); + } else { + width_ = RoundUp(options_.canvas_width_px(), alignment); + height_ = RoundUp(options_.canvas_height_px(), alignment); + } + width_gpu_ = RoundUp(width_ * scale_factor, alignment); + height_gpu_ = RoundUp(height_ * scale_factor, alignment); + + // Init texture for opencv rendered frame. { glGenTextures(1, &image_mat_tex_); glBindTexture(GL_TEXTURE_2D, image_mat_tex_); - glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB8, width_, height_, 0, GL_RGB, + // TODO + // OpenCV only renders to RGB images, not RGBA. Ideally this should be RGBA. + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB8, width_gpu_, height_gpu_, 0, GL_RGB, GL_UNSIGNED_BYTE, nullptr); glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.proto b/mediapipe/calculators/util/annotation_overlay_calculator.proto index 4391a1f2a..b34f2c1ae 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.proto +++ b/mediapipe/calculators/util/annotation_overlay_calculator.proto @@ -45,4 +45,12 @@ message AnnotationOverlayCalculatorOptions { // origin. (Historically, OpenGL uses bottom left origin, but most MediaPipe // examples expect textures to have top-left origin.) optional bool gpu_uses_top_left_origin = 6 [default = true]; + + // Scale factor for intermediate image for GPU rendering. + // This can be used to speed up annotation by drawing the annotation on an + // intermediate image with a reduced scale, e.g. 0.5 (of the input image width + // and height), before resizing and overlaying it on top of the input image. + // Should only be used if *all* render data uses normalized coordinates + // (or absolute coordinates are updated to scale accordingly). + optional float gpu_scale_factor = 7 [default = 1.0]; } diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator.cc b/mediapipe/calculators/util/detections_to_render_data_calculator.cc index 731994d4f..5082cd363 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator.cc @@ -235,9 +235,10 @@ void DetectionsToRenderDataCalculator::AddLabels( const Detection& detection, const DetectionsToRenderDataCalculatorOptions& options, float text_line_height, RenderData* render_data) { - CHECK(detection.label().empty() || detection.label_id().empty()) - << "Either std::string or integer labels must be used for detection " - "but not both at the same time."; + CHECK(detection.label().empty() || detection.label_id().empty() || + detection.label_size() == detection.label_id_size()) + << "String or integer labels should be of same size. Or only one of them " + "is present."; const auto num_labels = std::max(detection.label_size(), detection.label_id_size()); CHECK_EQ(detection.score_size(), num_labels) diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index d171f8af1..da76a1536 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -316,6 +316,7 @@ cc_library( "//mediapipe/util/tracking", "//mediapipe/util/tracking:box_tracker", "//mediapipe/util/tracking:tracking_visualization_utilities", + "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/mediapipe/calculators/video/box_tracker_calculator.cc b/mediapipe/calculators/video/box_tracker_calculator.cc index 9963d864a..4fef7cc8e 100644 --- a/mediapipe/calculators/video/box_tracker_calculator.cc +++ b/mediapipe/calculators/video/box_tracker_calculator.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/container/node_hash_set.h" #include "absl/strings/numbers.h" #include "mediapipe/calculators/video/box_tracker_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -193,12 +194,12 @@ class BoxTrackerCalculator : public CalculatorBase { TimedBoxProtoList initial_pos_; // Keeps tracks boxes that have already been initialized. - std::unordered_set initialized_ids_; + absl::node_hash_set initialized_ids_; // Non empty for batch mode tracking. std::string cache_dir_; // Ids to be tracked in batch_mode. - std::unordered_set batch_track_ids_; + absl::node_hash_set batch_track_ids_; int frame_num_ = 0; diff --git a/mediapipe/examples/android/README.md b/mediapipe/examples/android/README.md index 8ce927727..bc32c24da 100644 --- a/mediapipe/examples/android/README.md +++ b/mediapipe/examples/android/README.md @@ -1 +1 @@ -This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev)for details. +This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev) for details. diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/MainActivity.java index 065e88f07..82c1f4478 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/MainActivity.java @@ -43,19 +43,23 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { inputSidePackets.put(INPUT_NUM_FACES_SIDE_PACKET_NAME, packetCreator.createInt32(NUM_FACES)); processor.setInputSidePackets(inputSidePackets); - processor.addPacketCallback( + // To show verbose logging, run: + // adb shell setprop log.tag.MainActivity VERBOSE + if (Log.isLoggable(TAG, Log.VERBOSE)) { + processor.addPacketCallback( OUTPUT_LANDMARKS_STREAM_NAME, (packet) -> { - Log.d(TAG, "Received multi face landmarks packet."); + Log.v(TAG, "Received multi face landmarks packet."); List multiFaceLandmarks = PacketGetter.getProtoVector(packet, NormalizedLandmarkList.parser()); - Log.d( + Log.v( TAG, "[TS:" + packet.getTimestamp() + "] " + getMultiFaceLandmarksDebugString(multiFaceLandmarks)); }); + } } private static String getMultiFaceLandmarksDebugString( diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java index 7305c9ef5..e45510c1c 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/MainActivity.java @@ -43,29 +43,33 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { } }); - processor.addPacketCallback( + // To show verbose logging, run: + // adb shell setprop log.tag.MainActivity VERBOSE + if (Log.isLoggable(TAG, Log.VERBOSE)) { + processor.addPacketCallback( OUTPUT_LANDMARKS_STREAM_NAME, (packet) -> { byte[] landmarksRaw = PacketGetter.getProtoBytes(packet); try { NormalizedLandmarkList landmarks = NormalizedLandmarkList.parseFrom(landmarksRaw); if (landmarks == null) { - Log.d(TAG, "[TS:" + packet.getTimestamp() + "] No hand landmarks."); + Log.v(TAG, "[TS:" + packet.getTimestamp() + "] No hand landmarks."); return; } // Note: If hand_presence is false, these landmarks are useless. - Log.d( + Log.v( TAG, "[TS:" + packet.getTimestamp() + "] #Landmarks for hand: " + landmarks.getLandmarkCount()); - Log.d(TAG, getLandmarksDebugString(landmarks)); + Log.v(TAG, getLandmarksDebugString(landmarks)); } catch (InvalidProtocolBufferException e) { Log.e(TAG, "Couldn't Exception received - " + e); return; } }); + } } private static String getLandmarksDebugString(NormalizedLandmarkList landmarks) { diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/MainActivity.java index 4aee88768..0d4dfde7f 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/MainActivity.java @@ -31,19 +31,23 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); - processor.addPacketCallback( + // To show verbose logging, run: + // adb shell setprop log.tag.MainActivity VERBOSE + if (Log.isLoggable(TAG, Log.VERBOSE)) { + processor.addPacketCallback( OUTPUT_LANDMARKS_STREAM_NAME, (packet) -> { - Log.d(TAG, "Received multi-hand landmarks packet."); + Log.v(TAG, "Received multi-hand landmarks packet."); List multiHandLandmarks = PacketGetter.getProtoVector(packet, NormalizedLandmarkList.parser()); - Log.d( + Log.v( TAG, "[TS:" + packet.getTimestamp() + "] " + getMultiHandLandmarksDebugString(multiHandLandmarks)); }); + } } private String getMultiHandLandmarksDebugString(List multiHandLandmarks) { diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc index ee403a5d0..bb922d92a 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc @@ -324,6 +324,19 @@ void MakeStaticFeatures(const int top_border, const int bottom_border, int path_offset_y; MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y)); + // Prevent box from extending beyond the image after camera smoothing. + if (path_offset_y - ceil(path_height / 2.0) < 0) { + path_offset_y = ceil(path_height / 2.0); + } else if (path_offset_y + ceil(path_height / 2.0) > frame_height_) { + path_offset_y = frame_height_ - ceil(path_height / 2.0); + } + int path_width = path_height * target_aspect_; + if (path_offset_x - ceil(path_width / 2.0) < 0) { + path_offset_x = ceil(path_width / 2.0); + } else if (path_offset_x + ceil(path_width / 2.0) > frame_width_) { + path_offset_x = frame_width_ - ceil(path_width / 2.0); + } + // Convert to top/bottom borders to remove. int path_top = path_offset_y - path_height / 2; int path_bottom = frame_height_ - (path_offset_y + path_height / 2); diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc index a37e09c57..ed3a10c9e 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc @@ -344,6 +344,28 @@ TEST(ContentZoomingCalculatorTest, ZoomTestPairSize) { CheckBorder(static_features, 1000, 1000, 495, 395); } +TEST(ContentZoomingCalculatorTest, ZoomTestNearOutsideBorder) { + auto runner = ::absl::make_unique( + ParseTextProtoOrDie(kConfigD)); + AddDetection(cv::Rect_(.95, .95, .05, .05), 0, runner.get()); + AddDetection(cv::Rect_(.9, .9, .1, .1), 1000000, runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(972, 972, 55, 55, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(958, 958, 83, 83, 1, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, ZoomTestNearInsideBorder) { + auto runner = ::absl::make_unique( + ParseTextProtoOrDie(kConfigD)); + AddDetection(cv::Rect_(0, 0, .05, .05), 0, runner.get()); + AddDetection(cv::Rect_(0, 0, .1, .1), 1000000, runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(28, 28, 55, 55, 0, runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(42, 42, 83, 83, 1, runner->Outputs().Tag("CROP_RECT").packets); +} + } // namespace } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc index 340c4b253..3d37541cf 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc @@ -10,6 +10,10 @@ namespace autoflip { current_time_ = time_us; initialized_ = true; current_velocity_deg_per_s_ = 0; + RET_CHECK_GT(pixels_per_degree_, 0) + << "pixels_per_degree must be larger than 0."; + RET_CHECK_GE(options_.min_motion_to_reframe(), options_.reframe_window()) + << "Reframe window cannot exceed min_motion_to_reframe."; return ::mediapipe::OkStatus(); } @@ -22,6 +26,14 @@ namespace autoflip { if (abs(delta_degs) < options_.min_motion_to_reframe()) { position = current_position_px_; delta_degs = 0; + } else if (delta_degs > 0) { + // Apply new position, less the reframe window size. + position = position - pixels_per_degree_ * options_.reframe_window(); + delta_degs = (position - current_position_px_) / pixels_per_degree_; + } else { + // Apply new position, plus the reframe window size. + position = position + pixels_per_degree_ * options_.reframe_window(); + delta_degs = (position - current_position_px_) / pixels_per_degree_; } // Time and position updates. diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto index eda04c4b1..552ead0d9 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto @@ -10,4 +10,9 @@ message KinematicOptions { optional double max_velocity = 2 [default = 18]; // Min motion (in degrees) to react in pixels. optional float min_motion_to_reframe = 3 [default = 1.8]; + // When motion exceeds min_motion_to_reframe, move within this distance of the + // camera from the starting direction. Setting this value non-zero reduces + // total reframe distance on average. Value cannot exceed + // min_motion_to_reframe value. + optional float reframe_window = 4 [default = 0]; } diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc index 5d5717589..d751bd1e3 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc @@ -27,6 +27,12 @@ namespace mediapipe { namespace autoflip { namespace { +TEST(KinematicPathSolverTest, FailZeroPixelsPerDegree) { + KinematicOptions options; + KinematicPathSolver solver(options, 0, 1000, 0); + EXPECT_FALSE(solver.AddObservation(500, kMicroSecInSec * 0).ok()); +} + TEST(KinematicPathSolverTest, FailNotInitializedState) { KinematicOptions options; KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); @@ -109,6 +115,38 @@ TEST(KinematicPathSolverTest, PassEnoughMotionSmallImg) { EXPECT_EQ(state, 410); } +TEST(KinematicPathSolverTest, FailReframeWindowSetting) { + KinematicOptions options; + // Set min motion to 1deg + options.set_min_motion_to_reframe(1.0); + options.set_update_rate(1); + options.set_max_velocity(1000); + // Set reframe window size to .75 for test. + options.set_reframe_window(1.1); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + ASSERT_FALSE(solver.AddObservation(500, kMicroSecInSec * 0).ok()); +} + +TEST(KinematicPathSolverTest, PassReframeWindow) { + KinematicOptions options; + // Set min motion to 1deg + options.set_min_motion_to_reframe(1.0); + options.set_update_rate(1); + options.set_max_velocity(1000); + // Set reframe window size to .75 for test. + options.set_reframe_window(0.75); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + int state; + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + // Move target by 20px / 16.6 = 1.2deg + MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to move 1.2-.75 deg, * 16.6 = 7.47px + 500 = + EXPECT_EQ(state, 507); +} + TEST(KinematicPathSolverTest, PassUpdateRate) { KinematicOptions options; options.set_min_motion_to_reframe(1.0); diff --git a/mediapipe/examples/desktop/youtube8m/generate_input_sequence_example.py b/mediapipe/examples/desktop/youtube8m/generate_input_sequence_example.py index a639e1056..205834cc8 100644 --- a/mediapipe/examples/desktop/youtube8m/generate_input_sequence_example.py +++ b/mediapipe/examples/desktop/youtube8m/generate_input_sequence_example.py @@ -30,7 +30,7 @@ SECONDS_TO_MICROSECONDS = 1000000 def bytes23(string): - """Creates a bytes string in either Python 2 or 3.""" + """Creates a bytes string in either Python 2 or 3.""" if sys.version_info >= (3, 0): return bytes(string, 'utf8') else: diff --git a/mediapipe/examples/ios/bundle_id.bzl b/mediapipe/examples/ios/bundle_id.bzl new file mode 100644 index 000000000..4866b07c6 --- /dev/null +++ b/mediapipe/examples/ios/bundle_id.bzl @@ -0,0 +1,26 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration helper for iOS app bundle ids and provisioning profiles. +""" + +BUNDLE_ID_PREFIX = "*SEE_IOS_INSTRUCTIONS*.mediapipe.examples" + +# Look for a provisioning profile in the example's directory first, +# otherwise look for a common one. +def example_provisioning(): + local_profile = native.glob(["provisioning_profile.mobileprovision"]) + if local_profile: + return local_profile[0] + return "//mediapipe/examples/ios:provisioning_profile" diff --git a/mediapipe/examples/ios/edgedetectiongpu/BUILD b/mediapipe/examples/ios/edgedetectiongpu/BUILD index 66ea1b066..46fb32a94 100644 --- a/mediapipe/examples/ios/edgedetectiongpu/BUILD +++ b/mediapipe/examples/ios/edgedetectiongpu/BUILD @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -licenses(["notice"]) # Apache 2.0 - -MIN_IOS_VERSION = "10.0" - load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" alias( name = "edgedetectiongpu", @@ -28,14 +33,14 @@ alias( ios_application( name = "EdgeDetectionGpuApp", - bundle_id = "com.google.mediapipe.EdgeDetectionGpu", + bundle_id = BUNDLE_ID_PREFIX + ".EdgeDetectionGpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [":EdgeDetectionGpuAppLibrary"], ) diff --git a/mediapipe/examples/ios/facedetectioncpu/BUILD b/mediapipe/examples/ios/facedetectioncpu/BUILD index 1e8488b34..0387ae8a4 100644 --- a/mediapipe/examples/ios/facedetectioncpu/BUILD +++ b/mediapipe/examples/ios/facedetectioncpu/BUILD @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -licenses(["notice"]) # Apache 2.0 - -MIN_IOS_VERSION = "10.0" - load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" alias( name = "facedetectioncpu", @@ -28,14 +33,14 @@ alias( ios_application( name = "FaceDetectionCpuApp", - bundle_id = "com.google.mediapipe.FaceDetectionCpu", + bundle_id = BUNDLE_ID_PREFIX + ".FaceDetectionCpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [ ":FaceDetectionCpuAppLibrary", "@ios_opencv//:OpencvFramework", diff --git a/mediapipe/examples/ios/facedetectiongpu/BUILD b/mediapipe/examples/ios/facedetectiongpu/BUILD index b6fce8791..87f0d7894 100644 --- a/mediapipe/examples/ios/facedetectiongpu/BUILD +++ b/mediapipe/examples/ios/facedetectiongpu/BUILD @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -licenses(["notice"]) # Apache 2.0 - -MIN_IOS_VERSION = "10.0" - load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" alias( name = "facedetectiongpu", @@ -28,14 +33,14 @@ alias( ios_application( name = "FaceDetectionGpuApp", - bundle_id = "com.google.mediapipe.FaceDetectionGpu", + bundle_id = BUNDLE_ID_PREFIX + ".FaceDetectionGpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [ ":FaceDetectionGpuAppLibrary", "@ios_opencv//:OpencvFramework", diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index a892510ff..b1e169bf7 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -16,6 +16,11 @@ load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) licenses(["notice"]) # Apache 2.0 @@ -28,14 +33,14 @@ alias( ios_application( name = "FaceMeshGpuApp", - bundle_id = "com.google.mediapipe.FaceMeshGpu", + bundle_id = BUNDLE_ID_PREFIX + ".FaceMeshGpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [ ":FaceMeshGpuAppLibrary", "@ios_opencv//:OpencvFramework", diff --git a/mediapipe/examples/ios/handdetectiongpu/BUILD b/mediapipe/examples/ios/handdetectiongpu/BUILD index 162166a42..9507e81cc 100644 --- a/mediapipe/examples/ios/handdetectiongpu/BUILD +++ b/mediapipe/examples/ios/handdetectiongpu/BUILD @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -licenses(["notice"]) # Apache 2.0 - -MIN_IOS_VERSION = "10.0" - load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" alias( name = "handdetectiongpu", @@ -28,14 +33,14 @@ alias( ios_application( name = "HandDetectionGpuApp", - bundle_id = "com.google.mediapipe.HandDetectionGpu", + bundle_id = BUNDLE_ID_PREFIX + ".HandDetectionGpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [ ":HandDetectionGpuAppLibrary", "@ios_opencv//:OpencvFramework", diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index 72965cef3..bfccddd04 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -16,6 +16,11 @@ load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) licenses(["notice"]) # Apache 2.0 @@ -28,14 +33,14 @@ alias( ios_application( name = "HandTrackingGpuApp", - bundle_id = "com.google.mediapipe.HandTrackingGpu", + bundle_id = BUNDLE_ID_PREFIX + ".HandTrackingGpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [ ":HandTrackingGpuAppLibrary", "@ios_opencv//:OpencvFramework", diff --git a/mediapipe/examples/ios/link_local_profiles.py b/mediapipe/examples/ios/link_local_profiles.py new file mode 100755 index 000000000..9814009e1 --- /dev/null +++ b/mediapipe/examples/ios/link_local_profiles.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""This script is used to set up automatic provisioning for iOS examples. + +It scans the provisioning profiles used by Xcode, looking for one matching the +application identifier for each example app. If found, it symlinks the profile +in the appropriate location for Bazel to find it. + +It also checks whether the bundle_id.bzl file has a placeholder bundle ID, and +replaces it with a unique ID if so. +""" + +import os +import plistlib +import re +import subprocess +import uuid + +# This script is meant to be located in the MediaPipe iOS examples directory +# root. The logic below will have to be changed if the directory structure is +# reorganized. +examples_ios = os.path.dirname(os.path.realpath(__file__)) +example_names = { + f for f in os.listdir(examples_ios) + if os.path.isdir(os.path.join(examples_ios, f)) +} + + +def configure_bundle_id_prefix( + bundle_id_bzl=os.path.join(examples_ios, "bundle_id.bzl")) -> str: + """Configures the bundle id prefix to use. + + Gets the bundle id prefix in use from bundle_id.bzl; sets up a unique + prefix if not already set. + + Args: + bundle_id_bzl: Path to the bzl file where the bundle id is stored. + + Returns: + The bundle id prefix to use. + + Raises: + Exception: If parsing of bundle_id.bzl fails. + """ + bundle_id_re = re.compile( + r'^BUNDLE_ID_PREFIX\s*=\s*"(.*)"', flags=re.MULTILINE) + + with open(bundle_id_bzl, "r") as f: + contents = f.read() + match = bundle_id_re.search(contents) + if not match: + raise Exception("could not find BUNDLE_ID_PREFIX") + + bundle_id_prefix = match.group(1) + # The default value contains a *, which is an invalid character in bundle IDs. + if "*" in bundle_id_prefix: + bundle_id_prefix = str(uuid.uuid4()) + ".mediapipe.examples" + contents = contents[:match.start(1)] + bundle_id_prefix + contents[match + .end(1):] + with open(bundle_id_bzl, "w") as f: + f.write(contents) + print("Set up a unique bundle ID prefix: " + bundle_id_prefix) + + return bundle_id_prefix + + +def get_app_id(profile_path) -> str: + try: + plist = subprocess.check_output( + ["security", "cms", "-D", "-i", profile_path]) + profile = plistlib.loads(plist) + return profile["Entitlements"]["application-identifier"] + except Exception: # pylint: disable=broad-except + return None + + +def update_symlink(target_path, link_path): + if os.path.islink(link_path): + print(f" Removing existing symlink at {link_path}") + os.remove(link_path) + elif os.path.exists(link_path): + print(f" Unexpected existing file at {link_path}; skipping") + return + os.symlink(target_path, link_path) + print(f" Created symlink to {target_path} at {link_path}") + + +def process_profile(profile_path, our_app_id_re): + """Processes one mobileprovision file. + + Checks if its app ID matches one of our example apps, and symlinks it in the + appropriate location if so. + + Args: + profile_path: Path to the mobileprovision file. + our_app_id_re: Regular expression to extract the example name from one of + out app ids. + """ + app_id = get_app_id(profile_path) + if not app_id: + print(f"Could not parse '{profile_path}', skipping") + return + match = our_app_id_re.match(app_id) + if not match: + return + app_name = match.group(1) + app_dir_name = app_name.lower() + if app_dir_name not in example_names: + print(f"The app id '{app_id}' has our prefix, but does not seem to match" + + "any of our examples. Skipping.") + return + + print(f"Found profile for {app_name}") + + link_path = os.path.join(examples_ios, app_dir_name, + "provisioning_profile.mobileprovision") + update_symlink(profile_path, link_path) + + +def main(): + bundle_id_prefix = configure_bundle_id_prefix() + our_app_id_re = re.compile(r"[0-9A-Z]+\." + re.escape(bundle_id_prefix) + + r"\.(.*)") + + profile_dir = os.path.expanduser( + "~/Library/MobileDevice/Provisioning Profiles") + if not os.path.isdir(profile_dir): + print(f"Could not find provisioning profiles directory at {profile_dir}") + return 2 + + print( + f"Looking for profiles for app ids with prefix '{bundle_id_prefix}' in '{profile_dir}'" + ) + + for name in os.listdir(profile_dir): + if not name.endswith(".mobileprovision"): + continue + profile_path = os.path.join(profile_dir, name) + process_profile(profile_path, our_app_id_re) + + +if __name__ == "__main__": + main() diff --git a/mediapipe/examples/ios/multihandtrackinggpu/BUILD b/mediapipe/examples/ios/multihandtrackinggpu/BUILD index be718d3e9..cadc390c9 100644 --- a/mediapipe/examples/ios/multihandtrackinggpu/BUILD +++ b/mediapipe/examples/ios/multihandtrackinggpu/BUILD @@ -16,6 +16,11 @@ load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) licenses(["notice"]) # Apache 2.0 @@ -28,14 +33,14 @@ alias( ios_application( name = "MultiHandTrackingGpuApp", - bundle_id = "com.google.mediapipe.MultiHandTrackingGpu", + bundle_id = BUNDLE_ID_PREFIX + ".MultiHandTrackingGpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [ ":MultiHandTrackingGpuAppLibrary", "@ios_opencv//:OpencvFramework", diff --git a/mediapipe/examples/ios/objectdetectioncpu/BUILD b/mediapipe/examples/ios/objectdetectioncpu/BUILD index 0efb96316..91a5a9331 100644 --- a/mediapipe/examples/ios/objectdetectioncpu/BUILD +++ b/mediapipe/examples/ios/objectdetectioncpu/BUILD @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -licenses(["notice"]) # Apache 2.0 - -MIN_IOS_VERSION = "10.0" - load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" alias( name = "objectdetectioncpu", @@ -28,14 +33,14 @@ alias( ios_application( name = "ObjectDetectionCpuApp", - bundle_id = "com.google.mediapipe.ObjectDetectionCpu", + bundle_id = BUNDLE_ID_PREFIX + ".ObjectDetectionCpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [ ":ObjectDetectionCpuAppLibrary", "@ios_opencv//:OpencvFramework", diff --git a/mediapipe/examples/ios/objectdetectiongpu/BUILD b/mediapipe/examples/ios/objectdetectiongpu/BUILD index 288273ac0..19715532e 100644 --- a/mediapipe/examples/ios/objectdetectiongpu/BUILD +++ b/mediapipe/examples/ios/objectdetectiongpu/BUILD @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -licenses(["notice"]) # Apache 2.0 - -MIN_IOS_VERSION = "10.0" - load( "@build_bazel_rules_apple//apple:ios.bzl", "ios_application", ) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" alias( name = "objectdetectiongpu", @@ -28,14 +33,14 @@ alias( ios_application( name = "ObjectDetectionGpuApp", - bundle_id = "com.google.mediapipe.ObjectDetectionGpu", + bundle_id = BUNDLE_ID_PREFIX + ".ObjectDetectionGpu", families = [ "iphone", "ipad", ], infoplists = ["Info.plist"], minimum_os_version = MIN_IOS_VERSION, - provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + provisioning_profile = example_provisioning(), deps = [ ":ObjectDetectionGpuAppLibrary", "@ios_opencv//:OpencvFramework", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index f5b170ab9..2140144cd 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -112,7 +112,7 @@ mediapipe_proto_library( mediapipe_proto_library( name = "stream_handler_proto", srcs = ["stream_handler.proto"], - visibility = ["//mediapipe/framework:__subpackages__"], + visibility = [":mediapipe_internal"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) @@ -130,7 +130,7 @@ mediapipe_proto_library( mediapipe_proto_library( name = "thread_pool_executor_proto", srcs = ["thread_pool_executor.proto"], - visibility = ["//mediapipe/framework:__subpackages__"], + visibility = [":mediapipe_internal"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) @@ -923,6 +923,15 @@ cc_library( ], ) +# When --copt=-fno-rtti is set, MEDIAPIPE_HAS_RTTI is cleared in port.h. +# To explicitly clear MEDIAPIPE_HAS_RTTI, compile with: +# bazel build --define=disable_rtti_and_exceptions=true +config_setting( + name = "disable_rtti_and_exceptions", + define_values = {"disable_rtti_and_exceptions": "true"}, + visibility = ["//visibility:public"], +) + cc_library( name = "port", hdrs = ["port.h"], @@ -931,6 +940,11 @@ cc_library( }) + select({ "//conditions:default": [], "//mediapipe/gpu:disable_gpu": ["MEDIAPIPE_DISABLE_GPU"], + }) + select({ + "//conditions:default": [], + "//mediapipe/framework:disable_rtti_and_exceptions": [ + "MEDIAPIPE_HAS_RTTI=0", + ], }), visibility = [ "//mediapipe/framework:__subpackages__", @@ -1167,6 +1181,7 @@ cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:tag_map_helper", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/mediapipe/framework/calculator_base_test.cc b/mediapipe/framework/calculator_base_test.cc index 48dce7074..fcb4ebf37 100644 --- a/mediapipe/framework/calculator_base_test.cc +++ b/mediapipe/framework/calculator_base_test.cc @@ -15,6 +15,7 @@ #include "mediapipe/framework/calculator_base.h" // TODO: Move protos in another CL after the C++ code migration. +#include "absl/container/flat_hash_set.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_context_manager.h" @@ -192,8 +193,8 @@ TEST(CalculatorTest, CreateByName) { // Tests registration of a calculator within a whitelisted namespace. TEST(CalculatorTest, CreateByNameWhitelisted) { // Reset the registration namespace whitelist. - *const_cast*>( - &NamespaceWhitelist::TopNamespaces()) = std::unordered_set{ + *const_cast*>( + &NamespaceWhitelist::TopNamespaces()) = absl::flat_hash_set{ "mediapipe::test_ns::whitelisted_ns", "mediapipe", }; diff --git a/mediapipe/framework/calculator_graph_event_loop_test.cc b/mediapipe/framework/calculator_graph_event_loop_test.cc index c8e018f7b..eb7b1d866 100644 --- a/mediapipe/framework/calculator_graph_event_loop_test.cc +++ b/mediapipe/framework/calculator_graph_event_loop_test.cc @@ -375,7 +375,7 @@ TEST_F(CalculatorGraphEventLoopTest, TryToAddPacketToInputStream) { this, std::placeholders::_1))}, {"blocking_mutex", mutex_side_packet}})); - constexpr int kNumInputPackets = 2; + constexpr int kNumInputPackets = 20; constexpr int kMaxQueueSize = 1; // Lock the mutex so that the BlockingPassThroughCalculator cannot read any of diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index 1adf2c0ac..a70ee02e1 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -1828,14 +1828,14 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) { status = graph->Initialize(config); EXPECT_THAT(status.message(), - testing::AllOf(testing::HasSubstr("StringStatusHandler"), - // The problematic input side packet. - testing::HasSubstr("generated_by_generator"), - // Actual type. - testing::HasSubstr("string"), - // Expected type. - testing::HasSubstr( - MediaPipeTypeStringOrDemangled()))); + testing::AllOf( + testing::HasSubstr("StringStatusHandler"), + // The problematic input side packet. + testing::HasSubstr("generated_by_generator"), + // Actual type. + testing::HasSubstr(MediaPipeTypeStringOrDemangled()), + // Expected type. + testing::HasSubstr("string"))); } TEST(CalculatorGraph, GenerateInInitialize) { diff --git a/mediapipe/framework/calculator_node.cc b/mediapipe/framework/calculator_node.cc index d9edb4b69..4605082bf 100644 --- a/mediapipe/framework/calculator_node.cc +++ b/mediapipe/framework/calculator_node.cc @@ -405,7 +405,7 @@ namespace { // Returns the Packet sent to an OutputSidePacket, or an empty packet // if none available. const Packet GetPacket(const OutputSidePacket& out) { - auto impl = dynamic_cast(&out); + auto impl = static_cast(&out); return (impl == nullptr) ? Packet() : impl->GetPacket(); } diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 1573fa1f7..1cf51ddea 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -209,6 +209,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/framework/deps/registration.cc b/mediapipe/framework/deps/registration.cc index c467b700b..c6a3ffa0d 100644 --- a/mediapipe/framework/deps/registration.cc +++ b/mediapipe/framework/deps/registration.cc @@ -14,6 +14,8 @@ #include "mediapipe/framework/deps/registration.h" +#include "absl/container/flat_hash_set.h" + namespace mediapipe { namespace { @@ -34,9 +36,9 @@ inline size_t array_size(T (&arr)[SIZE]) { } // namespace /*static*/ -const std::unordered_set& NamespaceWhitelist::TopNamespaces() { - static std::unordered_set* result = - new std::unordered_set( +const absl::flat_hash_set& NamespaceWhitelist::TopNamespaces() { + static absl::flat_hash_set* result = + new absl::flat_hash_set( kTopNamespaces, kTopNamespaces + array_size(kTopNamespaces)); return *result; } diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index e5634bc45..66845ad06 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -26,6 +26,7 @@ #include "absl/base/macros.h" #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -145,7 +146,7 @@ struct WrapStatusOr<::mediapipe::StatusOr> { class NamespaceWhitelist { public: - static const std::unordered_set& TopNamespaces(); + static const absl::flat_hash_set& TopNamespaces(); }; template diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 3d4b71cd7..ba22168b9 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -95,18 +95,23 @@ cc_library( hdrs = ["image_frame.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:port", "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/port:aligned_malloc_and_free", - "//mediapipe/framework/port:core_proto", - "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", - "//mediapipe/framework/port:source_location", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - ], + "//mediapipe/framework:port", + "//mediapipe/framework/port:aligned_malloc_and_free", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/tool:type_util", + ] + select({ + "//conditions:default": [ + ], + "//mediapipe/framework:disable_rtti_and_exceptions": [], + }), ) cc_library( diff --git a/mediapipe/framework/formats/image_frame.h b/mediapipe/framework/formats/image_frame.h index ff877cd16..6fcefbd38 100644 --- a/mediapipe/framework/formats/image_frame.h +++ b/mediapipe/framework/formats/image_frame.h @@ -42,6 +42,9 @@ #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/tool/type_util.h" + +#define IMAGE_FRAME_RAW_IMAGE MEDIAPIPE_HAS_RTTI namespace mediapipe { diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index f6b177454..ee9a85aeb 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -407,7 +407,7 @@ StatusOr> ConvertToVectorOfProtoMessageLitePtrs(const T* data, /*is_proto_vector=*/std::false_type) { return ::mediapipe::InvalidArgumentError(absl::StrCat( - "The Packet stores \"", typeid(T).name(), "\"", + "The Packet stores \"", tool::TypeId().name(), "\"", "which is not convertible to vector.")); } diff --git a/mediapipe/framework/packet_test.cc b/mediapipe/framework/packet_test.cc index 039ccedf7..d63578668 100644 --- a/mediapipe/framework/packet_test.cc +++ b/mediapipe/framework/packet_test.cc @@ -147,6 +147,9 @@ struct UnregisteredPairStruct { }; MEDIAPIPE_REGISTER_TYPE(::mediapipe::RegisteredPairStruct, "::mediapipe::RegisteredPairStruct", nullptr, nullptr); +MEDIAPIPE_REGISTER_TYPE(int, "int", nullptr, nullptr); +MEDIAPIPE_REGISTER_TYPE(float, "float", nullptr, nullptr); +constexpr bool kHaveUnregisteredTypeNames = MEDIAPIPE_HAS_RTTI; TEST(PacketTest, TypeRegistrationDebugString) { // Test registered type. @@ -159,9 +162,13 @@ TEST(PacketTest, TypeRegistrationDebugString) { // Unregistered type. UnregisteredPairStruct u{"s", true}; Packet packet2 = MakePacket(u); + std::string expected_type_name = + (kHaveUnregisteredTypeNames) + ? "mediapipe::(anonymous namespace)::UnregisteredPairStruct" + : ""; EXPECT_EQ(packet2.DebugString(), - "mediapipe::Packet with timestamp: Timestamp::Unset() and type: " - "mediapipe::(anonymous namespace)::UnregisteredPairStruct"); + "mediapipe::Packet with timestamp: Timestamp::Unset() and type: " + + expected_type_name); } TEST(PacketTest, ReturnGenericProtobufMessage) { diff --git a/mediapipe/framework/port.h b/mediapipe/framework/port.h index bd5639599..e1a96bea4 100644 --- a/mediapipe/framework/port.h +++ b/mediapipe/framework/port.h @@ -80,4 +80,17 @@ #endif #endif +#ifndef MEDIAPIPE_HAS_RTTI +// Detect if RTTI is disabled in the compiler. +#if defined(__clang__) && defined(__has_feature) +#define MEDIAPIPE_HAS_RTTI __has_feature(cxx_rtti) +#elif defined(__GNUC__) && !defined(__GXX_RTTI) +#define MEDIAPIPE_HAS_RTTI 0 +#elif defined(_MSC_VER) && !defined(_CPPRTTI) +#define MEDIAPIPE_HAS_RTTI 0 +#else +#define MEDIAPIPE_HAS_RTTI 1 +#endif +#endif // MEDIAPIPE_HAS_RTTI + #endif // MEDIAPIPE_FRAMEWORK_PORT_H_ diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index 68e158efd..2fc6be528 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -307,11 +307,10 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":core_proto", ":logging", "//mediapipe/framework:port", - ] + select({ - "//conditions:default": ["@com_google_protobuf//:protobuf"], - }), + ], ) cc_library( diff --git a/mediapipe/framework/port/parse_text_proto.h b/mediapipe/framework/port/parse_text_proto.h index 86343b840..c352d4f01 100644 --- a/mediapipe/framework/port/parse_text_proto.h +++ b/mediapipe/framework/port/parse_text_proto.h @@ -15,16 +15,21 @@ #ifndef MEDIAPIPE_PORT_PARSE_TEXT_PROTO_H_ #define MEDIAPIPE_PORT_PARSE_TEXT_PROTO_H_ -#include "google/protobuf/text_format.h" +#include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" namespace mediapipe { +template +bool ParseTextProto(const std::string& input, T* proto) { + return proto_ns::TextFormat::ParseFromString(input, proto); +} + template T ParseTextProtoOrDie(const std::string& input) { T result; - CHECK(google::protobuf::TextFormat::ParseFromString(input, &result)); + CHECK(ParseTextProto(input, &result)); return result; } diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index aa770829a..86007b016 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -261,6 +261,7 @@ cc_test( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:threadpool", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", ], diff --git a/mediapipe/framework/profiler/sharded_map_test.cc b/mediapipe/framework/profiler/sharded_map_test.cc index be01ed870..a589ac42a 100644 --- a/mediapipe/framework/profiler/sharded_map_test.cc +++ b/mediapipe/framework/profiler/sharded_map_test.cc @@ -16,6 +16,7 @@ #include +#include "absl/container/node_hash_map.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" @@ -122,7 +123,7 @@ absl::Duration time(const std::function& f) { // With bazel build -c opt, the ShardedMap reduces CPU time by 60%. TEST(ShardedMapTest, TestParallelAccess) { absl::Duration simple_time = time([] { - std::unordered_map simple_map; + absl::node_hash_map simple_map; TestParallelAccess(simple_map, 1); }); absl::Duration safe_time = time([] { diff --git a/mediapipe/framework/stream_handler.proto b/mediapipe/framework/stream_handler.proto index b443034ec..e0731d9e6 100644 --- a/mediapipe/framework/stream_handler.proto +++ b/mediapipe/framework/stream_handler.proto @@ -22,6 +22,10 @@ package mediapipe; import "mediapipe/framework/mediapipe_options.proto"; +option java_package = "com.google.mediapipe.proto"; +option java_outer_classname = "StreamHandlerProto"; +option objc_class_prefix = "MediaPipe"; + // Settings specifying an input stream handler. message InputStreamHandlerConfig { // Name of the registered input stream handler class. diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 5412e24ee..694e2e3a1 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -167,6 +167,7 @@ cc_library( "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework/port:any_proto", + "//mediapipe/framework/tool:type_util", ], ) @@ -371,6 +372,9 @@ cc_library( name = "type_util", hdrs = ["type_util.h"], visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + "//mediapipe/framework:port", + ], ) cc_library( diff --git a/mediapipe/framework/tool/options_util.h b/mediapipe/framework/tool/options_util.h index 171aec681..3a97e6bd1 100644 --- a/mediapipe/framework/tool/options_util.h +++ b/mediapipe/framework/tool/options_util.h @@ -15,13 +15,12 @@ #ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_UTIL_H_ #define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_UTIL_H_ -#include - #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/port/any_proto.h" +#include "mediapipe/framework/tool/type_util.h" namespace mediapipe { @@ -54,18 +53,18 @@ class TypeMap { public: template bool Has() const { - return content_.count(typeid(T)) > 0; + return content_.count(TypeId()) > 0; } template T* Get() const { if (!Has()) { - content_[typeid(T)] = std::make_shared(); + content_[TypeId()] = std::make_shared(); } - return static_cast(content_[typeid(T)].get()); + return static_cast(content_[TypeId()].get()); } private: - mutable std::map> content_; + mutable std::map> content_; }; template +#include +#include #include +#include "mediapipe/framework/port.h" + namespace mediapipe { namespace tool { + +#if !MEDIAPIPE_HAS_RTTI +// A unique identifier for type T. +class TypeInfo { + public: + size_t hash_code() const { return reinterpret_cast(this); } + bool operator==(const TypeInfo& other) const { return &other == this; } + bool operator<(const TypeInfo& other) const { return &other < this; } + const char* name() const { return ""; } + template + static const TypeInfo& Get() { + static TypeInfo* static_type_info = new TypeInfo; + return *static_type_info; + } + + private: + TypeInfo() {} + TypeInfo(const TypeInfo&) = delete; +}; + +#else // MEDIAPIPE_HAS_RTTI +// The std unique identifier for type T. +class TypeInfo { + public: + size_t hash_code() const { return info_.hash_code(); } + bool operator==(const TypeInfo& o) const { return info_ == o.info_; } + bool operator<(const TypeInfo& o) const { return info_.before(o.info_); } + const char* name() const { return info_.name(); } + template + static const TypeInfo& Get() { + static TypeInfo* static_type_info = new TypeInfo(typeid(T)); + return *static_type_info; + } + + private: + TypeInfo(const std::type_info& info) : info_(info) {} + TypeInfo(const TypeInfo&) = delete; + + private: + const std::type_info& info_; + friend class TypeIndex; +}; +#endif + +// An associative key for TypeInfo. +class TypeIndex { + public: + TypeIndex(const TypeInfo& info) : info_(info) {} + size_t hash_code() const { return info_.hash_code(); } + bool operator==(const TypeIndex& other) const { return info_ == other.info_; } + bool operator<(const TypeIndex& other) const { return info_ < other.info_; } + + private: + const TypeInfo& info_; +}; + +// Returns a unique identifier for type T. +template +const TypeInfo& TypeId() { + return TypeInfo::Get(); +} + // Helper method that returns a hash code of the given type. This allows for // typeid testing across multiple binaries, unlike FastTypeId which used a // memory location that only works within the same binary. Moreover, we use this @@ -30,7 +96,7 @@ namespace tool { // as much as possible. template size_t GetTypeHash() { - return typeid(T).hash_code(); + return TypeId().hash_code(); } } // namespace tool diff --git a/mediapipe/framework/type_map.h b/mediapipe/framework/type_map.h index c94cc28ae..366da8d54 100644 --- a/mediapipe/framework/type_map.h +++ b/mediapipe/framework/type_map.h @@ -383,7 +383,7 @@ const std::string MediaPipeTypeStringOrDemangled() { if (type_string) { return *type_string; } else { - return ::mediapipe::Demangle(typeid(T).name()); + return ::mediapipe::Demangle(tool::TypeId().name()); } } diff --git a/mediapipe/graphs/hand_tracking/BUILD b/mediapipe/graphs/hand_tracking/BUILD index 3c419d8be..a84f5d941 100644 --- a/mediapipe/graphs/hand_tracking/BUILD +++ b/mediapipe/graphs/hand_tracking/BUILD @@ -49,6 +49,13 @@ cc_library( ], ) +mediapipe_binary_graph( + name = "hand_tracking_desktop_live_binary_graph", + graph = "hand_tracking_desktop_live.pbtxt", + output_name = "hand_tracking_desktop_live.binarypb", + deps = [":desktop_tflite_calculators"], +) + cc_library( name = "mobile_calculators", deps = [ diff --git a/mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt b/mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt index 2843a6f11..3106e9041 100644 --- a/mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt +++ b/mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt @@ -6,6 +6,9 @@ # Images coming into and out of the graph. input_stream: "input_video" output_stream: "output_video" +# Hand landmarks and palm detection info. +output_stream: "palm_detections" +output_stream: "hand_landmarks" # Caches a hand-presence decision fed back from HandLandmarkSubgraph, and upon # the arrival of the next input image sends out the cached decision with the diff --git a/mediapipe/java/com/google/mediapipe/framework/BUILD b/mediapipe/java/com/google/mediapipe/framework/BUILD index 3e2c7bf6d..893b6ea8d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/BUILD @@ -79,12 +79,20 @@ android_library( "AssetCache.java", "AssetCacheDbHelper.java", "MediaPipeRunner.java", + "PacketUtil.java", + "TypeNameRegistry.java", + "TypeNameRegistryLite.java", + "TypeNameRegistryFull.java", + "MediaPipeException.java", ], ), + exports = [ + ":framework_proto_lite", + ":mediapipe_exception_android", + ], deps = [ - "//mediapipe/framework:calculator_java_proto_lite", - "//mediapipe/framework:calculator_profile_java_proto_lite", - "//mediapipe/framework/tool:calculator_graph_template_java_proto_lite", + ":framework_proto_lite", + ":mediapipe_exception_android", "@maven//:com_google_code_findbugs_jsr305", "@maven//:com_google_flogger_flogger", "@maven//:com_google_flogger_flogger_system_backend", @@ -92,6 +100,39 @@ android_library( ], ) +android_library( + name = "framework_proto_lite", + srcs = [ + "PacketUtil.java", + "TypeNameRegistry.java", + "TypeNameRegistryLite.java", + ], + exports = [ + "//mediapipe/framework:calculator_java_proto_lite", + "//mediapipe/framework:calculator_profile_java_proto_lite", + "//mediapipe/framework:stream_handler_java_proto_lite", + "//mediapipe/framework/tool:calculator_graph_template_java_proto_lite", + ], + deps = [ + ":mediapipe_exception_android", + "//mediapipe/framework:calculator_java_proto_lite", + "//mediapipe/framework:calculator_profile_java_proto_lite", + "//mediapipe/framework:stream_handler_java_proto_lite", + "//mediapipe/framework/tool:calculator_graph_template_java_proto_lite", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "mediapipe_exception_android", + srcs = [ + "MediaPipeException.java", + ], + deps = [ + "@maven//:com_google_guava_guava", + ], +) + # Expose the java source files for building mediapipe AAR. filegroup( name = "java_src", diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index d1e8089ce..d5cfb1c69 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -14,6 +14,7 @@ package com.google.mediapipe.framework; +import com.google.mediapipe.framework.PacketUtil.SerializedMessage; import com.google.protobuf.MessageLite; import java.nio.ByteBuffer; import java.nio.FloatBuffer; @@ -263,6 +264,13 @@ public class PacketCreator { nativeCreateCalculatorOptions(mediapipeGraph.getNativeHandle(), message.toByteArray())); } + /** Creates a {@link Packet} containing a protobuf MessageLite. */ + public Packet createProto(MessageLite message) { + SerializedMessage serialized = PacketUtil.pack(message); + return Packet.create( + nativeCreateProto(mediapipeGraph.getNativeHandle(), serialized)); + } + /** Creates a {@link Packet} containing the given camera intrinsics. */ public Packet createCameraIntrinsics( float fx, float fy, float cx, float cy, float width, float height) { @@ -359,6 +367,7 @@ public class PacketCreator { private native long nativeCreateInt32Array(long context, int[] data); private native long nativeCreateFloat32Array(long context, float[] data); private native long nativeCreateStringFromByteArray(long context, byte[] data); + private native long nativeCreateProto(long context, SerializedMessage data); private native long nativeCreateCalculatorOptions(long context, byte[] data); diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index aae0adc6d..67fa955b4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -16,7 +16,9 @@ package com.google.mediapipe.framework; import com.google.common.base.Preconditions; import com.google.common.flogger.FluentLogger; +import com.google.mediapipe.framework.PacketUtil.SerializedMessage; import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.MessageLite; import com.google.protobuf.Parser; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -117,6 +119,13 @@ public final class PacketGetter { return nativeGetProtoBytes(packet.getNativeHandle()); } + public static T getProto(final Packet packet, Class clazz) + throws InvalidProtocolBufferException { + SerializedMessage result = new SerializedMessage(); + nativeGetProto(packet.getNativeHandle(), result); + return PacketUtil.unpack(result, clazz); + } + public static short[] getInt16Vector(final Packet packet) { return nativeGetInt16Vector(packet.getNativeHandle()); } @@ -295,6 +304,7 @@ public final class PacketGetter { private static native String nativeGetString(long nativePacketHandle); private static native byte[] nativeGetBytes(long nativePacketHandle); private static native byte[] nativeGetProtoBytes(long nativePacketHandle); + private static native void nativeGetProto(long nativePacketHandle, SerializedMessage result); private static native short[] nativeGetInt16Vector(long nativePacketHandle); private static native int[] nativeGetInt32Vector(long nativePacketHandle); private static native long[] nativeGetInt64Vector(long nativePacketHandle); diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketUtil.java b/mediapipe/java/com/google/mediapipe/framework/PacketUtil.java new file mode 100644 index 000000000..25a013e48 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/PacketUtil.java @@ -0,0 +1,85 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.framework; + +import com.google.protobuf.ExtensionRegistryLite; +import com.google.protobuf.Internal; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.MessageLite; +import java.util.NoSuchElementException; + +/** Utility functions for translating MediaPipe packet data between languages. */ +final class PacketUtil { + /** Records the protobuf type name for a Java Class. */ + public static void registerTypeName(Class clazz, String typeName) { + typeNameRegistry.registerTypeName(clazz, typeName); + } + + /** Returns the protobuf type name for a Java Class. */ + public static String getTypeName(Class clazz) { + return typeNameRegistry.getTypeName(clazz); + } + + /** Returns the best available ExtensionRegistry */ + public static ExtensionRegistryLite getExtensionRegistry() { + return ExtensionRegistryLite.getEmptyRegistry(); + } + + /** Serializes a MessageLite into a SerializedMessage object. */ + public static SerializedMessage pack(T message) { + SerializedMessage result = new SerializedMessage(); + result.typeName = getTypeName(message.getClass()); + if (result.typeName == null) { + throw new NoSuchElementException( + "Cannot determine the protobuf package name for class: " + message.getClass()); + } + result.value = message.toByteArray(); + return result; + } + + /** Deserializes a MessageLite from a SerializedMessage object. */ + public static T unpack( + SerializedMessage serialized, java.lang.Class clazz) + throws InvalidProtocolBufferException { + T defaultInstance = Internal.getDefaultInstance(clazz); + String expectedType = PacketUtil.getTypeName(defaultInstance.getClass()); + if (!serialized.typeName.equals(expectedType)) { + throw new InvalidProtocolBufferException( + "Message type does not match the expected type. Expected: " + + expectedType + + " Got: " + + serialized.typeName); + } + // Specifying the ExtensionRegistry is recommended. The ExtensionRegistry is + // needed to deserialize any nested proto2 extension Messages. + @SuppressWarnings("unchecked") // The type_url indicates type T. + T result = + (T) + defaultInstance + .getParserForType() + .parseFrom(serialized.value, PacketUtil.getExtensionRegistry()); + return result; + } + + /** A singleton to find protobuf full type names. */ + static TypeNameRegistry typeNameRegistry = new TypeNameRegistryConcrete(); + + private PacketUtil() {} + + static class SerializedMessage { + public String typeName; + public byte[] value; + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistry.java b/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistry.java new file mode 100644 index 000000000..e5f64b3a6 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistry.java @@ -0,0 +1,15 @@ +package com.google.mediapipe.framework; + +import com.google.protobuf.MessageLite; + +/** + * Utility interface for retrieving the protobuf type name for a MessageLite class. + */ +interface TypeNameRegistry { + + /** Returns the protobuf type name for a Java Class. */ + public String getTypeName(Class clazz); + + /** Records the protobuf type name for a Java Class. */ + public void registerTypeName(Class clazz, String typeName); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistryFull.java b/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistryFull.java new file mode 100644 index 000000000..b4a48d22f --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistryFull.java @@ -0,0 +1,27 @@ +package com.google.mediapipe.framework; + +import com.google.protobuf.MessageLite; +import com.google.protobuf.contrib.MessageUtils; + +/** + * Utility class for retrieving the protobuf type name for a MessageLite class. This implementation + * uses the full-protobuf Message and Descriptor library. + * + *

This class is defined in separate source files for "full" or for "lite" dependencies. + */ +final class TypeNameRegistryConcrete implements TypeNameRegistry { + + /** Returns the protobuf type name for a Java Class. */ + @Override + public String getTypeName(Class clazz) { + return MessageUtils.getProtoTypeName(clazz); + } + + /** Records the protobuf type name for a Java Class. */ + @Override + public void registerTypeName(Class clazz, String typeName) {} +} + +/** Satisfies Java file name convention. */ +@SuppressWarnings("TopLevel") +final class TypeNameRegistryFull {} diff --git a/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistryLite.java b/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistryLite.java new file mode 100644 index 000000000..d62957352 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/TypeNameRegistryLite.java @@ -0,0 +1,40 @@ +package com.google.mediapipe.framework; + +import com.google.protobuf.MessageLite; +import java.util.HashMap; +import java.util.Map; + +/** + * Utility class for retrieving the protobuf type name for a MessageLite class. This implementation + * uses the mediapipe protobuf type names registry. + * + *

This class is defined in separate source files for "full" or for "lite" dependencies. + */ +final class TypeNameRegistryConcrete implements TypeNameRegistry { + + TypeNameRegistryConcrete() {} + + /** Returns the protobuf type name for a Java Class. */ + @Override + public String getTypeName(Class javaClass) { + return typeNames.get(javaClass); + } + + /** Records the protobuf type name for a Java Class. */ + @Override + public void registerTypeName(Class clazz, String typeName) { + if (typeNames.containsKey(clazz) && !typeNames.get(clazz).equals(typeName)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.ALREADY_EXISTS.ordinal(), + "Protobuf type name: " + typeName + " conflicts with: " + typeNames.get(clazz)); + } + typeNames.put(clazz, typeName); + } + + /** A mapping from java package names to proto package names. */ + private final Map, String> typeNames = new HashMap<>(); +} + +/** Satisfies Java file name convention. */ +@SuppressWarnings("TopLevel") +final class TypeNameRegistryLite {} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc index 6060e4ea1..a90f56f03 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc @@ -24,6 +24,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" using mediapipe::android::JStringToStdString; +using mediapipe::android::ThrowIfError; namespace { mediapipe::Status AddSidePacketsIntoGraph( @@ -70,15 +71,6 @@ mediapipe::Status AddStreamHeadersIntoGraph( return mediapipe::OkStatus(); } -// Throws a MediaPipeException for any non-ok mediapipe::Status. -// Note that the exception is thrown after execution returns to Java. -bool ThrowIfError(JNIEnv* env, mediapipe::Status status) { - if (!status.ok()) { - env->Throw(mediapipe::android::CreateMediaPipeException(env, status)); - return true; - } - return false; -} } // namespace JNIEXPORT jlong JNICALL GRAPH_METHOD(nativeCreateGraph)(JNIEnv* env, diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc index 079767512..1bddc0166 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc @@ -131,6 +131,21 @@ jthrowable CreateMediaPipeException(JNIEnv* env, mediapipe::Status status) { env->NewObject(status_cls, status_ctr, status.code(), message_bytes)); } +bool ThrowIfError(JNIEnv* env, mediapipe::Status status) { + if (!status.ok()) { + env->Throw(mediapipe::android::CreateMediaPipeException(env, status)); + return true; + } + return false; +} + +SerializedMessageIds::SerializedMessageIds(JNIEnv* env, jobject data) { + jclass j_class = reinterpret_cast(env->NewGlobalRef(env->FindClass( + "com/google/mediapipe/framework/PacketUtil$SerializedMessage"))); + type_name_id = env->GetFieldID(j_class, "typeName", "Ljava/lang/String;"); + value_id = env->GetFieldID(j_class, "value", "[B"); +} + } // namespace android namespace java { diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h index 9efa28304..f52e142ee 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h @@ -30,6 +30,19 @@ std::string JStringToStdString(JNIEnv* env, jstring jstr); // Creates a java MediaPipeException object for a mediapipe::Status. jthrowable CreateMediaPipeException(JNIEnv* env, mediapipe::Status status); +// Throws a MediaPipeException for any non-ok mediapipe::Status. +// Note that the exception is thrown after execution returns to Java. +bool ThrowIfError(JNIEnv* env, mediapipe::Status status); + +// The Jni ids for Java class SerializedMessage. +class SerializedMessageIds { + public: + SerializedMessageIds(JNIEnv* env, jobject data); + jclass j_class; + jfieldID type_name_id; + jfieldID value_id; +}; + } // namespace android namespace java { diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 45a4955b3..cb8acf536 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -17,12 +17,14 @@ #include #include +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/camera_intrinsics.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" @@ -32,6 +34,8 @@ #endif // !defined(MEDIAPIPE_DISABLE_GPU) namespace { +using mediapipe::android::SerializedMessageIds; +using mediapipe::android::ThrowIfError; template int64_t CreatePacketScalar(jlong context, const T& value) { @@ -49,7 +53,6 @@ int64_t CreatePacketWithContext(jlong context, reinterpret_cast(context); return mediapipe_graph->WrapPacketIntoContext(packet); } - } // namespace JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( @@ -412,6 +415,30 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)( return CreatePacketWithContext(context, packet); } +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateProto)(JNIEnv* env, + jobject thiz, + jlong context, + jobject data) { + // Convert type_name and value from Java data. + static SerializedMessageIds ids(env, data); + jstring j_type_name = (jstring)env->GetObjectField(data, ids.type_name_id); + std::string type_name = + mediapipe::android::JStringToStdString(env, j_type_name); + jbyteArray value_array = (jbyteArray)env->GetObjectField(data, ids.value_id); + jsize value_len = env->GetArrayLength(value_array); + jbyte* value_ref = env->GetByteArrayElements(value_array, nullptr); + + // Create the C++ MessageLite and Packet. + mediapipe::Packet packet; + auto packet_or = mediapipe::packet_internal::PacketFromDynamicProto( + type_name, std::string((char*)value_ref, value_len)); + if (!ThrowIfError(env, packet_or.status())) { + packet = packet_or.ValueOrDie(); + } + env->ReleaseByteArrayElements(value_array, value_ref, JNI_ABORT); + return CreatePacketWithContext(context, packet); +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCameraIntrinsics)( JNIEnv* env, jobject thiz, jlong context, jfloat fx, jfloat fy, jfloat cx, jfloat cy, jfloat width, jfloat height) { diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index 9d0e73165..e7866382a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -113,6 +113,11 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateStringFromByteArray)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)( JNIEnv* env, jobject thiz, jlong context, jbyteArray data); +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateProto)(JNIEnv* env, + jobject thiz, + jlong context, + jobject data); + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCameraIntrinsics)( JNIEnv* env, jobject thiz, jlong context, jfloat fx, jfloat fy, jfloat cx, jfloat cy, jfloat width, jfloat height); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index e9de00e2b..0b1cd788d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -14,6 +14,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/time_series_header.pb.h" @@ -28,6 +29,8 @@ #endif // !defined(MEDIAPIPE_DISABLE_GPU) namespace { +using mediapipe::android::SerializedMessageIds; +using mediapipe::android::ThrowIfError; template const T& GetFromNativeHandle(int64_t packet_handle) { @@ -143,6 +146,32 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetProtoBytes)( return data; } +JNIEXPORT void JNICALL PACKET_GETTER_METHOD(nativeGetProto)(JNIEnv* env, + jobject thiz, + jlong packet, + jobject result) { + mediapipe::Packet mediapipe_packet = + mediapipe::android::Graph::GetPacketFromHandle(packet); + mediapipe::Status status = mediapipe_packet.ValidateAsProtoMessageLite(); + if (!ThrowIfError(env, status)) { + // Convert type_name and value to Java data. + const auto& proto_message = mediapipe_packet.GetProtoMessageLite(); + std::string type_name = proto_message.GetTypeName(); + jstring j_type_name = env->NewStringUTF(type_name.c_str()); + std::string proto_bytes; + proto_message.SerializeToString(&proto_bytes); + jbyteArray j_proto_bytes = env->NewByteArray(proto_bytes.length()); + env->SetByteArrayRegion( + j_proto_bytes, 0, proto_bytes.length(), + reinterpret_cast(proto_bytes.c_str())); + + // Set type_name and value in the result Java object. + static SerializedMessageIds ids(env, result); + env->SetObjectField(result, ids.type_name_id, j_type_name); + env->SetObjectField(result, ids.value_id, j_proto_bytes); + } +} + JNIEXPORT jobjectArray JNICALL PACKET_GETTER_METHOD(nativeGetProtoVector)( JNIEnv* env, jobject thiz, jlong packet) { mediapipe::Packet mediapipe_packet = diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h index 72c55935d..14b287158 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h @@ -69,6 +69,11 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetBytes)(JNIEnv* env, JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetProtoBytes)( JNIEnv* env, jobject thiz, jlong packet); +JNIEXPORT void JNICALL PACKET_GETTER_METHOD(nativeGetProto)(JNIEnv* env, + jobject thiz, + jlong packet, + jobject result); + JNIEXPORT jobjectArray JNICALL PACKET_GETTER_METHOD(nativeGetProtoVector)( JNIEnv* env, jobject thiz, jlong packet); diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 702771ae9..5a428460d 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -26,12 +26,14 @@ Finally, import the AAR into Android Studio. load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library") -def mediapipe_aar(name, calculators = []): +def mediapipe_aar(name, calculators = [], assets = [], assets_dir = ""): """Generate MediaPipe AAR. Args: name: the name of the AAR. calculators: the calculator libraries to be compiled into the .so. + assets: additional assets to be included into the archive. + assets_dir: path where the assets will the packaged. """ native.cc_binary( name = "libmediapipe_jni.so", @@ -136,6 +138,8 @@ cat > $(OUTS) <ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + return true; + } + return false; +} + +} // namespace namespace mediapipe { @@ -49,10 +64,7 @@ bool AssetManager::InitializeFromAssetManager( bool AssetManager::InitializeFromContext(JNIEnv* env, jobject context, const std::string& cache_dir_path) { - jni_common::JniHelper jni_helper(env, __LINE__, true); - - int status = env->GetJavaVM(&jvm_); - if (status != 0) { + if (!mediapipe::java::SetJavaVM(env)) { return false; } @@ -70,9 +82,8 @@ bool AssetManager::InitializeFromContext(JNIEnv* env, jobject context, jobject local_asset_manager = env->CallObjectMethod(context_, context_class_get_assets); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); + // TODO: Don't swallow the exception + if (ExceptionPrintClear(env)) { return false; } @@ -112,9 +123,8 @@ bool AssetManager::FileExists(const std::string& filename) { return false; } -bool AssetManager::ReadFile(const std::string& filename, - std::vector* raw_bytes) { - CHECK(raw_bytes); +bool AssetManager::ReadFile(const std::string& filename, std::string* output) { + CHECK(output); if (!asset_manager_) { LOG(ERROR) << "Asset manager was not initialized from JNI"; return false; @@ -126,9 +136,8 @@ bool AssetManager::ReadFile(const std::string& filename, return false; } else { size_t size = AAsset_getLength(asset); - raw_bytes->resize(size); - memcpy(static_cast(&raw_bytes->at(0)), AAsset_getBuffer(asset), - size); + output->resize(size); + memcpy(static_cast(&output->at(0)), AAsset_getBuffer(asset), size); AAsset_close(asset); } return true; @@ -145,7 +154,7 @@ bool AssetManager::ReadFile(const std::string& filename, // For now, since we don't know the app version, we overwrite the cache file // unconditionally. - std::vector asset_data; + std::string asset_data; RET_CHECK(ReadFile(asset_path, &asset_data)) << "could not read asset: " << asset_path; @@ -155,20 +164,19 @@ bool AssetManager::ReadFile(const std::string& filename, std::ofstream output_file(file_path); RET_CHECK(output_file.good()) << "could not open cache file: " << file_path; - output_file.write(reinterpret_cast(asset_data.data()), - asset_data.size()); + output_file << asset_data; RET_CHECK(output_file.good()) << "could not write cache file: " << file_path; return file_path; } -::mediapipe::StatusOr AssetManager::OpenContentUri( - const std::string& content_uri) { - jni_common::JniHelper jni_helper(jvm_, JNI_VERSION_1_6, __LINE__); - JNIEnv* env = jni_helper.GetEnv(); - if (env == nullptr) { - return ::mediapipe::UnavailableError("Couldn't get JNI env."); - } +mediapipe::Status AssetManager::ReadContentUri(const std::string& content_uri, + std::string* output) { + RET_CHECK(mediapipe::java::HasJavaVM()) << "JVM instance not set"; + JNIEnv* env = mediapipe::java::GetJNIEnv(); + RET_CHECK(env != nullptr) << "Unable to retrieve JNIEnv"; + + RET_CHECK(context_ != nullptr) << "Android context not initialized"; // ContentResolver contentResolver = context.getContentResolver(); jclass context_class = env->FindClass("android/content/Context"); @@ -187,29 +195,54 @@ bool AssetManager::ReadFile(const std::string& filename, jobject uri = env->CallStaticObjectMethod( uri_class, uri_parse, env->NewStringUTF(content_uri.c_str())); - // ParcelFileDescriptor descriptor = + // AssetFileDescriptor descriptor = // contentResolver.openAssetFileDescriptor(uri, "r"); - jmethodID content_resolver_open_file_descriptor = env->GetMethodID( - content_resolver_class, "openFileDescriptor", - "(Landroid/net/Uri;Ljava/lang/String;)Landroid/os/ParcelFileDescriptor;"); - jobject parcel_file_descriptor = env->CallObjectMethod( + jmethodID content_resolver_open_file_descriptor = + env->GetMethodID(content_resolver_class, "openAssetFileDescriptor", + "(Landroid/net/Uri;Ljava/lang/String;)" + "Landroid/content/res/AssetFileDescriptor;"); + jobject descriptor = env->CallObjectMethod( content_resolver, content_resolver_open_file_descriptor, uri, env->NewStringUTF("r")); - // int fd = parcelDescriptor.detachFd(); - jclass parcel_descriptor_class = - env->FindClass("android/os/ParcelFileDescriptor"); - jmethodID parcel_class_detach_fd = - env->GetMethodID(parcel_descriptor_class, "detachFd", "()I"); - jint fd = env->CallIntMethod(parcel_file_descriptor, parcel_class_detach_fd); + RET_CHECK(!ExceptionPrintClear(env)) << "unable to open content URI"; - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); - return ::mediapipe::NotFoundError("Content URI not found"); - } + // long size = descriptor.getLength(); + jclass asset_file_descriptor_class = + env->FindClass("android/content/res/AssetFileDescriptor"); + jmethodID get_length_method = + env->GetMethodID(asset_file_descriptor_class, "getLength", "()J"); + jlong size = env->CallLongMethod(descriptor, get_length_method); - return static_cast(fd); + // byte[] data = new byte[size]; + jbyteArray data = env->NewByteArray(size); + + // FileInputStream stream = descriptor.createInputStream(); + jmethodID create_input_stream_method = + env->GetMethodID(asset_file_descriptor_class, "createInputStream", + "()Ljava/io/FileInputStream;"); + jobject stream = + env->CallObjectMethod(descriptor, create_input_stream_method); + + RET_CHECK(!ExceptionPrintClear(env)) << "failed to create input stream"; + + // stream.read(data); + jclass input_stream_class = env->FindClass("java/io/InputStream"); + jmethodID read_method = env->GetMethodID(input_stream_class, "read", "([B)I"); + env->CallIntMethod(stream, read_method, data); + + RET_CHECK(!ExceptionPrintClear(env)) << "failed to read input stream"; + + // stream.close(); + jmethodID close_method = env->GetMethodID(input_stream_class, "close", "()V"); + env->CallVoidMethod(stream, close_method); + + output->resize(size); + env->GetByteArrayRegion(data, 0, size, + reinterpret_cast(&output->at(0))); + RET_CHECK(!ExceptionPrintClear(env)) << "failed to copy array data"; + + return mediapipe::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/util/android/asset_manager_util.h b/mediapipe/util/android/asset_manager_util.h index 4644c4825..677d7c42c 100644 --- a/mediapipe/util/android/asset_manager_util.h +++ b/mediapipe/util/android/asset_manager_util.h @@ -15,6 +15,9 @@ #ifndef MEDIAPIPE_ANDROID_UTIL_ASSET_MANAGER_UTIL_H_ #define MEDIAPIPE_ANDROID_UTIL_ASSET_MANAGER_UTIL_H_ +#include +#include + #ifdef __ANDROID__ #include #include @@ -65,12 +68,13 @@ class AssetManager { // Checks if a file exists. Returns true on success, false otherwise. bool FileExists(const std::string& filename); - // Reads a file into raw_bytes. Returns true on success, false otherwise. - bool ReadFile(const std::string& filename, std::vector* raw_bytes); + // Reads a file into output. Returns true on success, false otherwise. + bool ReadFile(const std::string& filename, std::string* output); - // Returns the open file descriptor from an Android content URI, the caller - // is responsible to close the file descriptor. - ::mediapipe::StatusOr OpenContentUri(const std::string& content_uri); + // Reads the raw bytes referred to by the supplied content URI. Returns true + // on success, false otherwise. + mediapipe::Status ReadContentUri(const std::string& content_uri, + std::string* output); // Returns the path to the Android cache directory. Will be empty if // InitializeFromActivity has not been called. @@ -92,9 +96,6 @@ class AssetManager { // The context from which assets should be loaded. jobject context_; - // Pointer to the JVM, used to get the JNIEnv on background threads. - JavaVM* jvm_; - // Path to the Android cache directory for our context. std::string cache_dir_path_; diff --git a/mediapipe/util/android/jni_helper.cc b/mediapipe/util/android/jni_helper.cc deleted file mode 100644 index fb171ff59..000000000 --- a/mediapipe/util/android/jni_helper.cc +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2020 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mediapipe/util/android/jni_helper.h" - -#include "mediapipe/util/android/logging.h" - -namespace mediapipe { -namespace jni_common { - -JniHelper::JniHelper(JavaVM* vm, jint version, int caller_line, - bool enable_logging) - : vm_(vm), - env_(nullptr), - need_to_detach_(false), - caller_line_(caller_line), - enable_logging_(enable_logging) { - JNI_COMMON_CHECK(vm_); - const int code = vm_->GetEnv(reinterpret_cast(&env_), version); - if (code == JNI_OK) { - if (0 != env_->PushLocalFrame(0)) { // Create new stack frame. - ExceptionPrintClear(env_); - if (enable_logging_) { - JNI_COMMON_LOG(VERBOSE, "JniHelper: failed to push local frame."); - } - env_ = nullptr; - } - } else if (code == JNI_EDETACHED) { - if (vm_->AttachCurrentThread(&env_, nullptr) == JNI_OK) { - if (enable_logging_) { - JNI_COMMON_LOG(VERBOSE, - "JniHelper: attached thread (Called from line %d).", - caller_line_); - } - need_to_detach_ = true; - } else { - if (enable_logging_) { - JNI_COMMON_LOG( - ERROR, - "JniHelper: couldn't attach current thread (Called from line %d).", - caller_line_); - } - env_ = nullptr; - } - } else { - if (enable_logging_) { - JNI_COMMON_LOG(ERROR, - "JniHelper: couldn't get env (Called from line %d).", - caller_line_); - } - env_ = nullptr; - } -} - -JniHelper::JniHelper(JNIEnv* env, int caller_line, bool enable_logging) - : vm_(nullptr), - env_(env), - need_to_detach_(false), - caller_line_(caller_line), - enable_logging_(enable_logging) { - JNI_COMMON_CHECK(env_); - if (0 != env_->PushLocalFrame(0)) { // Create new stack frame. - ExceptionPrintClear(env_); - if (enable_logging_) { - JNI_COMMON_LOG( - VERBOSE, - "JniHelper: failed to push local frame (Called from line %d).", - caller_line_); - } - env_ = nullptr; - } -} - -JniHelper::~JniHelper() { - if (need_to_detach_) { - if (enable_logging_) { - JNI_COMMON_LOG( - VERBOSE, "~JniHelper: about to detach thread (Called from line %d).", - caller_line_); - } - if (vm_->DetachCurrentThread() == JNI_OK) { - if (enable_logging_) { - JNI_COMMON_LOG(VERBOSE, - "~JniHelper: detached thread (Called from line %d).", - caller_line_); - } - } else { - if (enable_logging_) { - JNI_COMMON_LOG( - ERROR, "~JniHelper: couldn't detach thread (Called from line %d).", - caller_line_); - } - } - } else { - if (env_ != nullptr) { - env_->PopLocalFrame(nullptr); // Clean up local references. - } - } -} - -JNIEnv* JniHelper::GetEnv() const { return env_; } - -} // namespace jni_common -} // namespace mediapipe diff --git a/mediapipe/util/android/jni_helper.h b/mediapipe/util/android/jni_helper.h deleted file mode 100644 index 0e603e931..000000000 --- a/mediapipe/util/android/jni_helper.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_UTIL_ANDROID_JNI_HELPER_H_ -#define MEDIAPIPE_UTIL_ANDROID_JNI_HELPER_H_ - -#include - -namespace mediapipe { -namespace jni_common { - -inline bool ExceptionPrintClear(JNIEnv* env) { - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); - return true; - } - return false; -} - -class JniHelper { - public: - // This constructor should be used when a JavaVM pointer is available, and the - // JNIEnv needs to be obtained using AttachCurrentThread. This will also push - // a local stack frame, and pop it when this object is destroyed. If - // enable_logging is true, it will log verbosely in the constructor and - // destructor. - JniHelper(JavaVM* vm, jint version, int caller_line, - bool enable_logging = true); - - // This constructor should be used then the JNIEnv pointer itself is - // available, and the only thing that needs to be taken care of is pushing and - // popping the stack frames. If enable_logging is true, it will log verbosely - // in the constructor and destructor. - JniHelper(JNIEnv* env, int caller_line, bool enable_logging = true); - - // Detaches the current thread, if necessary, and pops the local stack frame - // that was pushed during construction. - ~JniHelper(); - - // Copy and assignment are disallowed because it could cause double-detaching. - JniHelper(const JniHelper& other) = delete; - JniHelper& operator=(const JniHelper& other) = delete; - - JNIEnv* GetEnv() const; - - private: - JavaVM* vm_; - JNIEnv* env_; - bool need_to_detach_; - const int caller_line_; - const bool enable_logging_; -}; - -} // namespace jni_common -} // namespace mediapipe - -#endif // MEDIAPIPE_UTIL_ANDROID_JNI_HELPER_H_ diff --git a/mediapipe/util/resource_util.cc b/mediapipe/util/resource_util.cc index 5fa53d57b..298217dbb 100644 --- a/mediapipe/util/resource_util.cc +++ b/mediapipe/util/resource_util.cc @@ -14,14 +14,23 @@ #include "mediapipe/util/resource_util.h" +#include "absl/flags/flag.h" +#include "absl/strings/str_split.h" +#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/ret_check.h" + +ABSL_FLAG( + std::string, resource_root_dir, "", + "The absolute path to the resource directory." + "If specified, resource_root_dir will be prepended to the original path."); namespace mediapipe { -// Trivial implementation for Linux. For now just returns the path. ::mediapipe::StatusOr PathToResourceAsFile( const std::string& path) { - return path; + return ::mediapipe::file::JoinPath(FLAGS_resource_root_dir.CurrentValue(), + path); } ::mediapipe::Status GetResourceContents(const std::string& path, diff --git a/mediapipe/util/resource_util_android.cc b/mediapipe/util/resource_util_android.cc index 449fbf376..04dde6e25 100644 --- a/mediapipe/util/resource_util_android.cc +++ b/mediapipe/util/resource_util_android.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "absl/strings/match.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/singleton.h" @@ -62,22 +64,13 @@ namespace { } if (absl::StartsWith(path, "content://")) { - auto fd_status = Singleton::get()->OpenContentUri(path); - if (!fd_status.ok()) { - return ::mediapipe::Status(mediapipe::StatusCode::kUnknown, - "Failed to open file: " + std::string(path)); - } - int fd = fd_status.ValueOrDie(); - auto status = file::GetContents(fd, output); - - close(fd); - return status; + MP_RETURN_IF_ERROR( + Singleton::get()->ReadContentUri(path, output)); + return ::mediapipe::OkStatus(); } - std::vector data; - RET_CHECK(Singleton::get()->ReadFile(path, &data)) + RET_CHECK(Singleton::get()->ReadFile(path, output)) << "could not read asset: " << path; - output->assign(reinterpret_cast(data.data()), data.size()); return ::mediapipe::OkStatus(); } diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index 9d8a6f3db..85a8a6e69 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -22,6 +22,7 @@ package(default_visibility = [ cc_library( name = "config", hdrs = ["config.h"], + features = ["-parse_headers"], deps = [ "//mediapipe/framework:calculator_framework", ], diff --git a/third_party/opencv_linux.BUILD b/third_party/opencv_linux.BUILD index 2f21cbe92..26978d532 100644 --- a/third_party/opencv_linux.BUILD +++ b/third_party/opencv_linux.BUILD @@ -25,8 +25,18 @@ cc_library( "lib/x86_64-linux-gnu/libopencv_videoio.so", ], ), - hdrs = glob(["include/opencv2/**/*.h*"]), - includes = ["include/"], + hdrs = glob([ + # For OpenCV 3.x + "include/opencv2/**/*.h*", + # For OpenCV 4.x + # "include/opencv4/opencv2/**/*.h*", + ]), + includes = [ + # For OpenCV 3.x + "include/", + # For OpenCV 4.x + # "include/opencv4/", + ], linkstatic = 1, visibility = ["//visibility:public"], )