Project import generated by Copybara.

GitOrigin-RevId: b2062656e5b3d33264e28ed0cbca31c4b93fe1bf
This commit is contained in:
MediaPipe Team 2020-07-29 20:33:39 -04:00 committed by chuoling
parent e9fbe868e5
commit bdfdaef305
132 changed files with 2568 additions and 820 deletions

View File

@@ -1,15 +1,17 @@

# Contributing guidelines

## What type of pull request do we accept into the MediaPipe repository?

*   Bug fixes
*   Documentation fixes

For new feature additions (e.g., new graphs and calculators), we are currently
not planning to accept new feature pull requests into the MediaPipe repository.
Instead, we encourage contributors to create their own repositories for new
features and list them at [Awesome MediaPipe](https://mediapipe.org). This
allows contributors to get their code out to the community more quickly.

Before sending your pull requests, make sure you followed this list.

-   Read the [contributing guidelines](CONTRIBUTING.md).
-   Read the [Code of Conduct](CODE_OF_CONDUCT.md).
-   Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/).
-   Check that your changes are consistent with the [guidelines](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution).
-   Ensure changes are consistent with the [Coding Style](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#c-coding-style).
-   Run the [Unit Tests](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#running-unit-tests).
## How to become a contributor and submit your own code
@@ -28,100 +30,7 @@ Follow either of the two links above to access the appropriate CLA and instructions

### Contributing code

If you have bug fixes or documentation fixes for MediaPipe, send us your pull
requests! For those just getting started, GitHub has a
[howto](https://help.github.com/articles/using-pull-requests/).

MediaPipe team members will be assigned to review your pull requests. Once the
bug/documentation fixes are verified, a MediaPipe team member will acknowledge
your contribution in the pull request comments, manually merge the fixes into
our internal codebase upstream, and apply the `to be closed` label to the pull
request. These fixes will later be pushed to GitHub in the next release, and a
MediaPipe team member will then close the pull request.
If you want to contribute but you're not sure where to start, take a look at the
[issues with the "contributions welcome" label](https://github.com/google/mediapipe/labels/stat%3Acontributions%20welcome).
These are issues that we believe are particularly well suited for outside
contributions, often because we probably won't get to them right now. If you
decide to start on an issue, leave a comment so that other people know that
you're working on it. If you want to help out, but not alone, use the issue
comment thread to coordinate.
### Contribution guidelines and standards
Before sending your pull request for
[review](https://github.com/google/mediapipe/pulls),
make sure your changes are consistent with the guidelines and follow the
MediaPipe coding style.
#### General guidelines and philosophy for contribution
* Include unit tests when you contribute new features, as they help to a)
prove that your code works correctly, and b) guard against future breaking
changes to lower the maintenance cost.
* Bug fixes also generally require unit tests, because the presence of bugs
usually indicates insufficient test coverage.
*   Keep API compatibility in mind when you change code in the MediaPipe
    framework, e.g., code in
[mediapipe/framework](https://github.com/google/mediapipe/tree/master/mediapipe/framework)
and
[mediapipe/calculators](https://github.com/google/mediapipe/tree/master/mediapipe/calculators).
    Once MediaPipe has reached version 1, we will not make
    non-backward-compatible API changes without a major release. Reviewers of
    your pull request will comment on any API compatibility issues.
* When you contribute a new feature to MediaPipe, the maintenance burden is
    (by default) transferred to the MediaPipe team. This means that the benefit
    of the contribution must be compared against the cost of maintaining the
the contribution must be compared against the cost of maintaining the
feature.
* Full new features (e.g., a new op implementing a cutting-edge algorithm)
typically will live in
    [mediapipe/addons](https://github.com/google/mediapipe/addons) to get some
    airtime before a decision is made regarding whether they are to be migrated to
the core.
#### License
Include a license at the top of new files.
* [C/C++ license example](https://github.com/google/mediapipe/blob/master/mediapipe/framework/calculator_base.cc#L1)
* [Java license example](https://github.com/google/mediapipe/blob/master/mediapipe/java/com/google/mediapipe/components/CameraHelper.java)
Bazel BUILD files also need to include a license section, e.g.,
[BUILD example](https://github.com/google/mediapipe/blob/master/mediapipe/framework/BUILD#L61).
#### C++ coding style
Changes to MediaPipe C++ code should conform to
[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
Use `clang-format` to check your C/C++ changes. To install `clang-format` on
Ubuntu 16.04, do:

```bash
apt-get install -y clang-format
```
You can check a C/C++ file by doing:
```bash
clang-format <my_cc_file> --style=google > /tmp/my_cc_file.cc
diff <my_cc_file> /tmp/my_cc_file.cc
```
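The two commands above can be wrapped in a small helper that reports, per file, whether it already matches Google style. This is a sketch: `check_style` and the demo file below are illustrative and not part of the repository, and the helper degrades gracefully when `clang-format` is not installed.

```shell
# Sketch: report whether each given C/C++ file matches Google style.
# check_style and the demo file are illustrative, not part of MediaPipe.
check_style() {
  f="$1"
  if ! command -v clang-format >/dev/null 2>&1; then
    echo "skipped (clang-format not installed): $f"
    return 0
  fi
  # Compare the formatted output against the file on disk.
  if clang-format --style=google "$f" | diff -q - "$f" >/dev/null; then
    echo "ok: $f"
  else
    echo "needs formatting: $f"
  fi
}

# Demo on a synthetic file.
printf 'int main() { return 0; }\n' > /tmp/style_demo.cc
check_style /tmp/style_demo.cc
```

Run it over several files with `for f in mediapipe/**/*.cc; do check_style "$f"; done` (path pattern is an example).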
#### Coding style for other languages
* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html)
* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html)
* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml)
* [Google Objective-C Style Guide](https://google.github.io/styleguide/objcguide.html)
#### Running sanity check
If you have Docker installed on your system, you can perform a sanity check on
your changes by running the command:
```bash
mediapipe/tools/ci_build/ci_build.sh CPU mediapipe/tools/ci_build/ci_sanity.sh
```
This will catch most license, Python coding style and BUILD file issues that
may exist in your changes.

View File

@@ -123,7 +123,7 @@ run code search using

*   [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
    MediaPipe related frameworks, libraries and software
*   [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users
*   [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
    community discussion around MediaPipe

View File

@@ -207,8 +207,8 @@ class SomeAudioVideoCalculator : public CalculatorBase {
    // particular type. SetAny() has the same effect as explicitly
    // setting the type to be the stream's type.
    cc->Outputs().Tag("VIDEO").Set<ImageFrame>();
    cc->Outputs().Get("AUDIO", 0).Set<Matrix>();
    cc->Outputs().Get("AUDIO", 1).Set<Matrix>();
    return ::mediapipe::OkStatus();
  }
```
@@ -400,13 +400,13 @@ node {
```

The diagram below shows how the `PacketClonerCalculator` defines its output
packets (bottom) based on its series of input packets (top).

| ![Graph using                                                                |
: PacketClonerCalculator](../images/packet_cloner_calculator.png)              :
| :--------------------------------------------------------------------------: |
| *Each time it receives a packet on its TICK input stream, the                |
: PacketClonerCalculator outputs the most recent packet from each of its input :
: streams. The sequence of output packets (bottom) is determined by the        :
: sequence of input packets (top) and their timestamps. The timestamps are     :
: shown along the right side of the diagram.*                                  :

View File

@@ -184,8 +184,8 @@ app:

### Prerequisite

1.  Install [Xcode](https://developer.apple.com/xcode/), then install the
    Command Line Tools using:

    ```bash
    xcode-select --install
    ```

@@ -196,74 +196,38 @@ app:

    We recommend using [Homebrew](https://brew.sh/) to get the latest version.

3.  Set Python 3.7 as the default Python version and install the Python "six"
    library. This is needed for TensorFlow.

    ```bash
    pip3 install --user six
    ```

4.  Clone the MediaPipe repository.

    ```bash
    git clone https://github.com/google/mediapipe.git
    ```

### Set up a bundle ID prefix

All iOS apps must have a bundle ID, and you must have a provisioning profile
that lets you install an app with that ID onto your phone. To avoid clashes
between different MediaPipe users, you need to configure a unique prefix for the
bundle IDs of our iOS demo apps.

If you have a custom provisioning profile, see
[Custom provisioning](#custom-provisioning) below.

Otherwise, run this command to generate a unique prefix:

```bash
python3 mediapipe/examples/ios/link_local_profiles.py
```

### Create an Xcode project

This allows you to edit and debug one of the example apps in Xcode. It also
allows you to make use of automatic provisioning (see later section).
1.  We will use a tool called [Tulsi](https://tulsi.bazel.build/) for generating
    Xcode projects from Bazel build configurations.
@@ -283,25 +247,138 @@ the previous section.
2.  Open `mediapipe/Mediapipe.tulsiproj` using the Tulsi app.

    Tip: If Tulsi displays an error saying "Bazel could not be found", press the
    "Bazel..." button in the Packages tab and select the `bazel` executable in
    your homebrew `/bin/` directory.

3.  Select the MediaPipe config in the Configs tab, then press the Generate
    button below. You will be asked for a location to save the Xcode project.
    Once the project is generated, it will be opened in Xcode.

    If you get an error about bundle IDs, see the
    [previous section](#set-up-a-bundle-id-prefix).
### Set up provisioning
To install applications on an iOS device, you need a provisioning profile. There
are two options:
1.  Automatic provisioning. This allows you to build and install an app on your
    personal device. The provisioning profile is managed by Xcode, and has to be
    updated often (it is valid for about a week).
2. Custom provisioning. This uses a provisioning profile associated with an
Apple developer account. These profiles have a longer validity period and
can target multiple devices, but you need a paid developer account with
Apple to obtain one.
#### Automatic provisioning
1. Create an Xcode project for MediaPipe, as discussed
[earlier](#create-an-xcode-project).
2. In the project navigator in the left sidebar, select the "Mediapipe"
project.
3. Select the "Signing & Capabilities" tab.
4. Select one of the application targets, e.g. HandTrackingGpuApp.
5. Check "Automatically manage signing", and confirm the dialog box.
6. Select "_Your Name_ (Personal Team)" in the Team pop-up menu.
7. This set-up needs to be done once for each application you want to install.
Repeat steps 4-6 as needed.
This generates provisioning profiles for each app you have selected. Now we need
to tell Bazel to use them. We have provided a script to make this easier.
1.  In the terminal, go to the `mediapipe` directory where you cloned the
    repository.
2. Run this command:
```bash
python3 mediapipe/examples/ios/link_local_profiles.py
```
This will find and link the provisioning profile for all applications for which
you have enabled automatic provisioning in Xcode.
Note: once a profile expires, Xcode will generate a new one; you must then run
this script again to link the updated profiles.
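After running the linking script, it can be handy to see at a glance which example apps have a profile in place. This is a sketch against a synthetic directory layout (real apps live under `mediapipe/examples/ios/`; the demo paths and app names here are placeholders):

```shell
# Sketch: list which example app directories have a linked provisioning
# profile. The /tmp layout below is a synthetic stand-in for the repo.
mkdir -p /tmp/mp_demo/examples/ios/handtrackinggpu
touch /tmp/mp_demo/examples/ios/handtrackinggpu/provisioning_profile.mobileprovision
mkdir -p /tmp/mp_demo/examples/ios/facedetectiongpu

for d in /tmp/mp_demo/examples/ios/*/; do
  if [ -e "$d/provisioning_profile.mobileprovision" ]; then
    echo "linked: $(basename "$d")"
  else
    echo "missing: $(basename "$d")"
  fi
done
```

Apps reported as "missing" either have not been set up for automatic provisioning in Xcode yet, or their profile has not been linked.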
#### Custom provisioning
1. Obtain a provisioning profile from Apple.
Tip: You can use this command to see the provisioning profiles you have
previously downloaded using Xcode: `open ~/Library/MobileDevice/"Provisioning
Profiles"`. If there are none, generate and download a profile on
[Apple's developer site](https://developer.apple.com/account/resources/).
1. Symlink or copy your provisioning profile to
`mediapipe/mediapipe/provisioning_profile.mobileprovision`.
```bash
cd mediapipe
ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision
```
Note: If you had previously set up automatic provisioning, you should remove the
`provisioning_profile.mobileprovision` symlink in each example's directory,
since it will take precedence over the common one. You can also overwrite it
with your own profile if you need a different profile for different apps.
1. Open `mediapipe/examples/ios/bundle_id.bzl`, and change the
`BUNDLE_ID_PREFIX` to a prefix associated with your provisioning profile.
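The `BUNDLE_ID_PREFIX` change above is a one-line edit; it can also be scripted with `sed`. The sketch below works on a copy of the file in `/tmp`, and both the file contents and the `com.example.myteam` prefix are illustrative placeholders:

```shell
# Sketch: update BUNDLE_ID_PREFIX in a copy of bundle_id.bzl.
# File contents and the new prefix are illustrative placeholders.
mkdir -p /tmp/ios_demo
cat > /tmp/ios_demo/bundle_id.bzl <<'EOF'
"""Configuration helper for iOS app bundle ids."""

BUNDLE_ID_PREFIX = "com.google.mediapipe.user"
EOF

# Point the prefix at one associated with your provisioning profile.
sed -i.bak 's/^BUNDLE_ID_PREFIX = .*/BUNDLE_ID_PREFIX = "com.example.myteam"/' \
  /tmp/ios_demo/bundle_id.bzl
grep BUNDLE_ID_PREFIX /tmp/ios_demo/bundle_id.bzl
# → BUNDLE_ID_PREFIX = "com.example.myteam"
```

The `-i.bak` form of in-place editing works with both GNU and BSD/macOS `sed`.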
### Build and run an app using Xcode
1. Create the Xcode project, and make sure you have set up either automatic or
custom provisioning.
2.  You can now select any of the MediaPipe demos in the target menu, and build
    and run them as normal.

Note: When you ask Xcode to run an app, by default it will use the Debug
configuration. Some of our demos are computationally heavy; you may want to use
the Release configuration for better performance.

Tip: To switch build configuration in Xcode, click on the target menu, choose
"Edit Scheme...", select the Run action, and switch the Build Configuration from
Debug to Release. Note that this is set independently for each target.
Tip: On the device, in Settings > General > Device Management, make sure the
developer (yourself) is trusted.
### Build an app using the command line
1. Make sure you have set up either automatic or custom provisioning.
2. Using [MediaPipe Hands](../solutions/hands.md) for example, run:
```bash
bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp
```
You may see a permission request from `codesign` in order to sign the app.
Tip: If you are using custom provisioning, you can run this
[script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh)
to build all MediaPipe iOS example apps.
3. In Xcode, open the `Devices and Simulators` window (command-shift-2).
4. Make sure your device is connected. You will see a list of installed apps.
Press the "+" button under the list, and select the `.ipa` file built by
Bazel.
5. You can now run the app on your device.
Tip: On the device, in Settings > General > Device Management, make sure the
developer (yourself) is trusted.
## Desktop ## Desktop

View File

@@ -43,8 +43,8 @@ We will be using the following graph, [`edge_detection_mobile_gpu.pbtxt`]:

```
# MediaPipe graph that performs GPU Sobel edge detection on a live video stream.
# Used in the examples in
# mediapipe/examples/android/src/java/com/mediapipe/apps/basic and
# mediapipe/examples/ios/edgedetectiongpu.

# Images coming into and out of the graph.
@@ -764,7 +764,7 @@ If you ran into any issues, please see the full code of the tutorial
[CameraX]:https://developer.android.com/training/camerax
[`CameraXPreviewHelper`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java
[developer options]:https://developer.android.com/studio/debug/dev-options
[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt
[`EglManager`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/glutil/EglManager.java
[`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java
[`FrameLayout`]:https://developer.android.com/reference/android/widget/FrameLayout

View File

@@ -18,7 +18,7 @@ nav_order: 5

2.  To run the [`hello world`] example:

    ```bash
    $ git clone https://github.com/google/mediapipe.git
    $ cd mediapipe
    $ export GLOG_logtostderr=1
@@ -92,10 +92,10 @@ nav_order: 5
    ```c++
    CalculatorGraph graph;
    MP_RETURN_IF_ERROR(graph.Initialize(config));
    MP_ASSIGN_OR_RETURN(OutputStreamPoller poller,
                        graph.AddOutputStreamPoller("out"));
    MP_RETURN_IF_ERROR(graph.StartRun({}));
    ```
5.  The example then creates 10 packets (each packet contains a string "Hello
@@ -105,9 +105,10 @@ nav_order: 5
    ```c++
    for (int i = 0; i < 10; ++i) {
      MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in",
          MakePacket<std::string>("Hello World!").At(Timestamp(i))));
    }
    MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
    ```

6.  Through the `OutputStreamPoller` object the example then retrieves all 10

View File

@@ -56,7 +56,7 @@ node: {
  output_stream: "luma_video"
}

# Applies the Sobel filter to luminance images stored in RGB format.
node: {
  calculator: "SobelEdgesCalculator"
  input_stream: "luma_video"

View File

@@ -70,9 +70,15 @@ apps, see these [instructions](./building_examples.md#ios).

        libopencv-imgproc-dev libopencv-video-dev
    ```

    Debian 9 and Ubuntu 18.04 install the packages in
    `/usr/lib/x86_64-linux-gnu`. MediaPipe's [`opencv_linux.BUILD`] and
    [`ffmpeg_linux.BUILD`] are configured for this library path. Ubuntu 20.04
    may install the OpenCV and FFmpeg packages in `/usr/local`. Please follow
    option 3 below to modify the [`WORKSPACE`], [`opencv_linux.BUILD`] and
    [`ffmpeg_linux.BUILD`] files accordingly.

    Moreover, for Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, the
    library path needs to be modified like the following:

    ```bash
    sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD
@@ -85,11 +91,13 @@ apps, see these [instructions](./building_examples.md#ios).
    [documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
    to manually build OpenCV from source code.
    Note: You may need to modify [`WORKSPACE`], [`opencv_linux.BUILD`] and
    [`ffmpeg_linux.BUILD`] to point MediaPipe to your own OpenCV and FFmpeg
    libraries. For example, if OpenCV and FFmpeg are both manually installed in
    "/usr/local/", you will need to update: (1) the "linux_opencv" and
    "linux_ffmpeg" new_local_repository rules in [`WORKSPACE`], (2) the "opencv"
    cc_library rule in [`opencv_linux.BUILD`], and (3) the "libffmpeg"
    cc_library rule in [`ffmpeg_linux.BUILD`]. These 3 changes are shown below:
    ```bash
    new_local_repository(

@@ -98,6 +106,12 @@ apps, see these [instructions](./building_examples.md#ios).

        path = "/usr/local",
    )
new_local_repository(
name = "linux_ffmpeg",
build_file = "@//third_party:ffmpeg_linux.BUILD",
path = "/usr/local",
)
    cc_library(
        name = "opencv",
        srcs = glob(

@@ -110,8 +124,36 @@ apps, see these [instructions](./building_examples.md#ios).

            "lib/libopencv_videoio.so",
            ],
        ),
        hdrs = glob([
            # For OpenCV 3.x
            "include/opencv2/**/*.h*",
            # For OpenCV 4.x
            # "include/opencv4/opencv2/**/*.h*",
        ]),
        includes = [
            # For OpenCV 3.x
            "include/",
            # For OpenCV 4.x
            # "include/opencv4/",
        ],
linkstatic = 1,
visibility = ["//visibility:public"],
)
cc_library(
name = "libffmpeg",
srcs = glob(
[
"lib/libav*.so",
],
),
hdrs = glob(["include/libav*/*.h"]),
includes = ["include"],
linkopts = [
"-lavcodec",
"-lavformat",
"-lavutil",
],
        linkstatic = 1,
        visibility = ["//visibility:public"],
    )
@@ -158,6 +200,10 @@ apps, see these [instructions](./building_examples.md#ios).
    # Hello World!
    ```
If you run into a build error, please read
[Troubleshooting](./troubleshooting.md) to find solutions to several common
build issues.
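The commented include paths in the `opencv_linux.BUILD` snippet above differ between OpenCV 3.x and 4.x. A quick way to check which layout a given install prefix has is sketched below; `detect_opencv_layout` is a made-up helper, and the `/tmp` prefixes are a synthetic demo (a real check would pass `/usr` or `/usr/local`):

```shell
# Sketch: detect whether a prefix has the OpenCV 3.x or 4.x include layout,
# to decide which lines to uncomment in third_party/opencv_linux.BUILD.
# detect_opencv_layout is a made-up helper; the /tmp paths are synthetic.
detect_opencv_layout() {
  root="$1"
  if [ -d "$root/include/opencv4/opencv2" ]; then
    echo "opencv4"    # use the "include/opencv4/" lines
  elif [ -d "$root/include/opencv2" ]; then
    echo "opencv3"    # use the plain "include/" lines
  else
    echo "not-found"
  fi
}

# Demo against synthetic prefixes.
mkdir -p /tmp/cv4_demo/include/opencv4/opencv2
mkdir -p /tmp/cv3_demo/include/opencv2
detect_opencv_layout /tmp/cv4_demo   # → opencv4
detect_opencv_layout /tmp/cv3_demo   # → opencv3
```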
## Installing on CentOS

**Disclaimer**: Running MediaPipe on CentOS is experimental.
@@ -190,11 +236,13 @@ apps, see these [instructions](./building_examples.md#ios).
Option 2. Build OpenCV from source code.
    Note: You may need to modify [`WORKSPACE`], [`opencv_linux.BUILD`] and
    [`ffmpeg_linux.BUILD`] to point MediaPipe to your own OpenCV and FFmpeg
    libraries. For example, if OpenCV and FFmpeg are both manually installed in
    "/usr/local/", you will need to update: (1) the "linux_opencv" and
    "linux_ffmpeg" new_local_repository rules in [`WORKSPACE`], (2) the "opencv"
    cc_library rule in [`opencv_linux.BUILD`], and (3) the "libffmpeg"
    cc_library rule in [`ffmpeg_linux.BUILD`]. These 3 changes are shown below:
    ```bash
    new_local_repository(

@@ -203,6 +251,12 @@ apps, see these [instructions](./building_examples.md#ios).

        path = "/usr/local",
    )
new_local_repository(
name = "linux_ffmpeg",
build_file = "@//third_party:ffmpeg_linux.BUILD",
path = "/usr/local",
)
    cc_library(
        name = "opencv",
        srcs = glob(

@@ -215,8 +269,36 @@ apps, see these [instructions](./building_examples.md#ios).

            "lib/libopencv_videoio.so",
            ],
        ),
        hdrs = glob([
            # For OpenCV 3.x
            "include/opencv2/**/*.h*",
            # For OpenCV 4.x
            # "include/opencv4/opencv2/**/*.h*",
        ]),
        includes = [
            # For OpenCV 3.x
            "include/",
            # For OpenCV 4.x
            # "include/opencv4/",
        ],
linkstatic = 1,
visibility = ["//visibility:public"],
)
cc_library(
name = "libffmpeg",
srcs = glob(
[
"lib/libav*.so",
],
),
hdrs = glob(["include/libav*/*.h"]),
includes = ["include"],
linkopts = [
"-lavcodec",
"-lavformat",
"-lavutil",
],
        linkstatic = 1,
        visibility = ["//visibility:public"],
    )
@@ -243,6 +325,10 @@ apps, see these [instructions](./building_examples.md#ios).
    # Hello World!
    ```
If you run into a build error, please read
[Troubleshooting](./troubleshooting.md) to find solutions to several common
build issues.
## Installing on macOS

1.  Prework:
@@ -375,6 +461,10 @@ apps, see these [instructions](./building_examples.md#ios).
    # Hello World!
    ```
If you run into a build error, please read
[Troubleshooting](./troubleshooting.md) to find solutions to several common
build issues.
## Installing on Windows

**Disclaimer**: Running MediaPipe on Windows is experimental.
@@ -454,13 +544,13 @@ next section.
9.  Run the [Hello World desktop example](./hello_world_desktop.md).
    Note: For building MediaPipe on Windows, please add `--action_env
    PYTHON_BIN_PATH="C://path//to//python.exe"` to the build command.
    Alternatively, you can follow
    [issue 724](https://github.com/google/mediapipe/issues/724) to fix the
    python configuration manually.
    ```
    C:\Users\Username\mediapipe_repo>bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 --action_env PYTHON_BIN_PATH="C://python_36//python.exe" mediapipe/examples/desktop/hello_world

    C:\Users\Username\mediapipe_repo>set GLOG_logtostderr=1
@@ -480,6 +570,10 @@ next section.
    ```
If you run into a build error, please read
[Troubleshooting](./troubleshooting.md) to find solutions to several common
build issues.
## Installing on Windows Subsystem for Linux (WSL) ## Installing on Windows Subsystem for Linux (WSL)
Note: The pre-built OpenCV packages don't support cameras in WSL. Unless you Note: The pre-built OpenCV packages don't support cameras in WSL. Unless you
@ -603,6 +697,10 @@ cameras. Alternatively, you use a video file as input.
# Hello World! # Hello World!
``` ```
If you run into a build error, please
read [Troubleshooting](./troubleshooting.md) to find solutions to several
common build issues.
## Installing using Docker ## Installing using Docker
This will use a Docker image that will isolate mediapipe's installation from the rest of the system. This will use a Docker image that will isolate mediapipe's installation from the rest of the system.
@ -653,6 +751,10 @@ This will use a Docker image that will isolate mediapipe's installation from the
# Hello World! # Hello World!
``` ```
If you run into a build error, please
read [Troubleshooting](./troubleshooting.md) to find solutions to several
common build issues.
4. Build a MediaPipe Android example. 4. Build a MediaPipe Android example.
```bash ```bash
@ -692,6 +794,7 @@ This will use a Docker image that will isolate mediapipe's installation from the
[`WORKSPACE`]: https://github.com/google/mediapipe/blob/master/WORKSPACE [`WORKSPACE`]: https://github.com/google/mediapipe/blob/master/WORKSPACE
[`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD [`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD
[`ffmpeg_linux.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_linux.BUILD
[`opencv_macos.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_macos.BUILD [`opencv_macos.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_macos.BUILD
[`ffmpeg_macos.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_macos.BUILD [`ffmpeg_macos.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_macos.BUILD
[`setup_opencv.sh`]: https://github.com/google/mediapipe/blob/master/setup_opencv.sh [`setup_opencv.sh`]: https://github.com/google/mediapipe/blob/master/setup_opencv.sh

View File

@ -12,6 +12,90 @@ nav_order: 10
{:toc} {:toc}
--- ---
## Missing Python binary path
The error message:
```
ERROR: An error occurred during the fetch of repository 'local_execution_config_python':
Traceback (most recent call last):
File "/sandbox_path/external/org_tensorflow/third_party/py/python_configure.bzl", line 208
get_python_bin(repository_ctx)
...
Repository command failed
```
usually indicates that Bazel fails to find the local Python binary. To solve
this issue, first find the path to your Python binary and then add
`--action_env PYTHON_BIN_PATH=<path to python binary>` to the Bazel command like
the following:
```
bazel build -c opt \
--define MEDIAPIPE_DISABLE_GPU=1 \
--action_env PYTHON_BIN_PATH="/path/to/python" \
mediapipe/examples/desktop/hello_world
```
## Missing necessary Python packages
The error message:
```
ImportError: No module named numpy
Is numpy installed?
```
usually indicates that certain Python packages are not installed. Please run
`pip install` or `pip3 install`, depending on your Python version, to install
those packages.
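As a quick way to confirm which packages are actually missing from the
interpreter Bazel uses, a small shell sketch (the package list here is just an
example; extend it as needed):

```shell
# Report each required package that the current python3 cannot import,
# together with the pip command that would install it.
for pkg in numpy; do
  python3 -c "import ${pkg}" 2>/dev/null \
    || echo "missing: ${pkg} (run: pip3 install ${pkg})"
done
```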
## Fail to fetch remote dependency repositories
The error message:
```
ERROR: An error occurred during the fetch of repository 'org_tensorflow':
java.io.IOException: Error downloading [https://mirror.bazel.build/github.com/tensorflow/tensorflow/archive/77e9ffb9b2bfb1a4f7056e62d84039626923e328.tar.gz, https://github.com/tensorflow/tensorflow/archive/77e9ffb9b2bfb1a4f7056e62d84039626923e328.tar.gz] to /sandbox_path/external/org_tensorflow/77e9ffb9b2bfb1a4f7056e62d84039626923e328.tar.gz: Tried to reconnect at offset 9,944,151 but server didn't support it
or
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz failed: class java.net.ConnectException Connection timed out (Connection timed out)
```
usually indicates that Bazel fails to download the dependency repositories that
MediaPipe needs. MediaPipe has several dependency repositories hosted on Google
sites. In some regions, you may need to set up a network proxy
or use a VPN to access those resources. You may also need to append
`--host_jvm_args "-DsocksProxyHost=<ip address> -DsocksProxyPort=<port number>"`
to the Bazel command. See
[this GitHub issue](https://github.com/google/mediapipe/issues/581#issuecomment-610356857)
for more details.
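Note that `--host_jvm_args` is a Bazel startup option, so it goes before the
`build` command rather than after it. A sketch of the full invocation, with the
proxy address and port left as placeholders:

```
bazel --host_jvm_args "-DsocksProxyHost=<ip address> -DsocksProxyPort=<port number>" \
  build -c opt \
  --define MEDIAPIPE_DISABLE_GPU=1 \
  mediapipe/examples/desktop/hello_world
```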
If you believe that it's not a network issue, some resources could instead be
temporarily unavailable. Please run `bazel clean --expunge` and retry later. If
it's still not working, please file a GitHub issue with the detailed error
message.
## Incorrect MediaPipe OpenCV config
The error message:
```
error: undefined reference to 'cv::String::deallocate()'
error: undefined reference to 'cv::String::allocate(unsigned long)'
error: undefined reference to 'cv::VideoCapture::VideoCapture(cv::String const&)'
...
error: undefined reference to 'cv::putText(cv::InputOutputArray const&, cv::String const&, cv::Point, int, double, cv::Scalar, int, int, bool)'
```
usually indicates that OpenCV is not properly configured for MediaPipe. Please
take a look at the "Install OpenCV and FFmpeg" sections in
[Installation](./install.md) to see how to modify MediaPipe's WORKSPACE and
linux_opencv/macos_opencv/windows_opencv.BUILD files for your local OpenCV
libraries. [This GitHub issue](https://github.com/google/mediapipe/issues/666)
may also help.
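For reference, the WORKSPACE rule those sections modify follows this shape; the
`path` below is an example and must match the prefix where your local OpenCV is
installed:

```
new_local_repository(
    name = "linux_opencv",
    build_file = "@//third_party:opencv_linux.BUILD",
    path = "/usr",
)
```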
## Native method not found ## Native method not found
The error message: The error message:

View File

@ -123,7 +123,7 @@ run code search using
* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome * [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
MediaPipe related frameworks, libraries and software MediaPipe related frameworks, libraries and software
* [Slack community](https://https://mediapipe.page.link/joinslack) for MediaPipe users * [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
community discussion around MediaPipe community discussion around MediaPipe

View File

@ -112,7 +112,7 @@ examples.
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
### Mobile ### Mobile

View File

@ -43,7 +43,7 @@ examples.
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
### Mobile ### Mobile

View File

@ -65,7 +65,7 @@ from the
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
## Models ## Models
@ -109,7 +109,7 @@ Please first see general instructions for
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
### Mobile ### Mobile

View File

@ -24,7 +24,7 @@ examples.
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
### Mobile ### Mobile

View File

@ -66,7 +66,7 @@ and a
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
## Models ## Models
@ -132,7 +132,7 @@ examples.
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
### Mobile ### Mobile

View File

@ -72,7 +72,7 @@ Please first see general instructions for
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
* Graph: * Graph:
[`mediapipe/graphs/template_matching/template_matching_mobile_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/template_matching/template_matching_mobile_cpu.pbtxt) [`mediapipe/graphs/template_matching/template_matching_mobile_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/template_matching/template_matching_mobile_cpu.pbtxt)

View File

@ -19,7 +19,7 @@ nav_order: 5
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
### Mobile ### Mobile

View File

@ -156,7 +156,7 @@ Please first see general instructions for
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../visualizer.md). [visualizer documentation](../tools/visualizer.md).
### Objectron for Shoes ### Objectron for Shoes

View File

@ -33,13 +33,12 @@ command line option: `--define MEDIAPIPE_PROFILING=0`.
To enable tracing and profiling, the `CalculatorGraphConfig` (in To enable tracing and profiling, the `CalculatorGraphConfig` (in
[calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto)) [calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto))
representing the graph must have a `profiler_config` message at its root. Here representing the graph must have a `profiler_config` message at its root. Here
is a simple setup that turns on a few extra options: is a simple setup that turns on tracing and keeps 100 seconds of timing events:
``` ```
profiler_config { profiler_config {
enable_profiler: true
trace_enabled: true trace_enabled: true
trace_log_count: 5 trace_log_interval_count: 200
} }
``` ```

View File

@ -450,6 +450,21 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "mux_calculator_test",
srcs = ["mux_calculator_test.cc"],
deps = [
":mux_calculator",
":round_robin_demux_calculator",
":split_vector_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)
cc_library( cc_library(
name = "packet_cloner_calculator", name = "packet_cloner_calculator",
srcs = ["packet_cloner_calculator.cc"], srcs = ["packet_cloner_calculator.cc"],
@ -947,7 +962,6 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
], ],
) )

View File

@ -56,12 +56,19 @@ std::string ToString(GateState state) {
// disallowing the corresponding packets in other input streams. The behavior // disallowing the corresponding packets in other input streams. The behavior
// can be inverted with a calculator option. // can be inverted with a calculator option.
// //
// ALLOW or DISALLOW can also be specified as an input side packet. The rules
// for evaluation remain the same as above.
//
// The ALLOW/DISALLOW input must be specified either via an input stream or
// via an input side packet, but not both.
//
// Intended to be used with the default input stream handler, which synchronizes // Intended to be used with the default input stream handler, which synchronizes
// all data input streams with the ALLOW/DISALLOW control input stream. // all data input streams with the ALLOW/DISALLOW control input stream.
// //
// Example config: // Example config:
// node { // node {
// calculator: "GateCalculator" // calculator: "GateCalculator"
// input_side_packet: "ALLOW:allow" or "DISALLOW:disallow"
// input_stream: "input_stream0" // input_stream: "input_stream0"
// input_stream: "input_stream1" // input_stream: "input_stream1"
// input_stream: "input_streamN" // input_stream: "input_streamN"
@ -75,10 +82,40 @@ class GateCalculator : public CalculatorBase {
public: public:
GateCalculator() {} GateCalculator() {}
static ::mediapipe::Status CheckAndInitAllowDisallowInputs(
CalculatorContract* cc) {
bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") ||
cc->InputSidePackets().HasTag("DISALLOW");
bool input_via_stream =
cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW");
// Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW
// input.
RET_CHECK(input_via_side_packet ^ input_via_stream);
if (input_via_side_packet) {
RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^
cc->InputSidePackets().HasTag("DISALLOW"));
if (cc->InputSidePackets().HasTag("ALLOW")) {
cc->InputSidePackets().Tag("ALLOW").Set<bool>();
} else {
cc->InputSidePackets().Tag("DISALLOW").Set<bool>();
}
} else {
RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
if (cc->Inputs().HasTag("ALLOW")) {
cc->Inputs().Tag("ALLOW").Set<bool>();
} else {
cc->Inputs().Tag("DISALLOW").Set<bool>();
}
}
return ::mediapipe::OkStatus();
}
static ::mediapipe::Status GetContract(CalculatorContract* cc) { static ::mediapipe::Status GetContract(CalculatorContract* cc) {
// Assume that input streams do not have a tag and that gating signal is RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc));
// tagged either ALLOW or DISALLOW.
RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
const int num_data_streams = cc->Inputs().NumEntries(""); const int num_data_streams = cc->Inputs().NumEntries("");
RET_CHECK_GE(num_data_streams, 1); RET_CHECK_GE(num_data_streams, 1);
RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams) RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams)
@ -88,11 +125,6 @@ class GateCalculator : public CalculatorBase {
cc->Inputs().Get("", i).SetAny(); cc->Inputs().Get("", i).SetAny();
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
} }
if (cc->Inputs().HasTag("ALLOW")) {
cc->Inputs().Tag("ALLOW").Set<bool>();
} else {
cc->Inputs().Tag("DISALLOW").Set<bool>();
}
if (cc->Outputs().HasTag("STATE_CHANGE")) { if (cc->Outputs().HasTag("STATE_CHANGE")) {
cc->Outputs().Tag("STATE_CHANGE").Set<bool>(); cc->Outputs().Tag("STATE_CHANGE").Set<bool>();
@ -102,6 +134,17 @@ class GateCalculator : public CalculatorBase {
} }
::mediapipe::Status Open(CalculatorContext* cc) final { ::mediapipe::Status Open(CalculatorContext* cc) final {
use_side_packet_for_allow_disallow_ = false;
if (cc->InputSidePackets().HasTag("ALLOW")) {
use_side_packet_for_allow_disallow_ = true;
allow_by_side_packet_decision_ =
cc->InputSidePackets().Tag("ALLOW").Get<bool>();
} else if (cc->InputSidePackets().HasTag("DISALLOW")) {
use_side_packet_for_allow_disallow_ = true;
allow_by_side_packet_decision_ =
!cc->InputSidePackets().Tag("DISALLOW").Get<bool>();
}
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
num_data_streams_ = cc->Inputs().NumEntries(""); num_data_streams_ = cc->Inputs().NumEntries("");
last_gate_state_ = GATE_UNINITIALIZED; last_gate_state_ = GATE_UNINITIALIZED;
@ -115,14 +158,18 @@ class GateCalculator : public CalculatorBase {
::mediapipe::Status Process(CalculatorContext* cc) final { ::mediapipe::Status Process(CalculatorContext* cc) final {
bool allow = empty_packets_as_allow_; bool allow = empty_packets_as_allow_;
if (cc->Inputs().HasTag("ALLOW") && !cc->Inputs().Tag("ALLOW").IsEmpty()) { if (use_side_packet_for_allow_disallow_) {
allow = cc->Inputs().Tag("ALLOW").Get<bool>(); allow = allow_by_side_packet_decision_;
} else {
if (cc->Inputs().HasTag("ALLOW") &&
!cc->Inputs().Tag("ALLOW").IsEmpty()) {
allow = cc->Inputs().Tag("ALLOW").Get<bool>();
}
if (cc->Inputs().HasTag("DISALLOW") &&
!cc->Inputs().Tag("DISALLOW").IsEmpty()) {
allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
}
} }
if (cc->Inputs().HasTag("DISALLOW") &&
!cc->Inputs().Tag("DISALLOW").IsEmpty()) {
allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
}
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW; const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
if (cc->Outputs().HasTag("STATE_CHANGE")) { if (cc->Outputs().HasTag("STATE_CHANGE")) {
@ -157,6 +204,8 @@ class GateCalculator : public CalculatorBase {
GateState last_gate_state_ = GATE_UNINITIALIZED; GateState last_gate_state_ = GATE_UNINITIALIZED;
int num_data_streams_; int num_data_streams_;
bool empty_packets_as_allow_; bool empty_packets_as_allow_;
bool use_side_packet_for_allow_disallow_;
bool allow_by_side_packet_decision_;
}; };
REGISTER_CALCULATOR(GateCalculator); REGISTER_CALCULATOR(GateCalculator);

View File

@ -24,6 +24,21 @@ namespace {
class GateCalculatorTest : public ::testing::Test { class GateCalculatorTest : public ::testing::Test {
protected: protected:
// Helper to run a graph and return status.
static ::mediapipe::Status RunGraph(const std::string& proto) {
auto runner = absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(proto));
return runner->Run();
}
// Use this when ALLOW/DISALLOW input is provided as a side packet.
void RunTimeStep(int64 timestamp, bool stream_payload) {
runner_->MutableInputs()->Get("", 0).packets.push_back(
MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
}
// Use this when ALLOW/DISALLOW input is provided as an input stream.
void RunTimeStep(int64 timestamp, const std::string& control_tag, void RunTimeStep(int64 timestamp, const std::string& control_tag,
bool control) { bool control) {
runner_->MutableInputs()->Get("", 0).packets.push_back( runner_->MutableInputs()->Get("", 0).packets.push_back(
@ -31,7 +46,6 @@ class GateCalculatorTest : public ::testing::Test {
runner_->MutableInputs() runner_->MutableInputs()
->Tag(control_tag) ->Tag(control_tag)
.packets.push_back(MakePacket<bool>(control).At(Timestamp(timestamp))); .packets.push_back(MakePacket<bool>(control).At(Timestamp(timestamp)));
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
} }
@ -46,6 +60,136 @@ class GateCalculatorTest : public ::testing::Test {
std::unique_ptr<CalculatorRunner> runner_; std::unique_ptr<CalculatorRunner> runner_;
}; };
TEST_F(GateCalculatorTest, InvalidInputs) {
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
input_stream: "DISALLOW:gating_stream"
output_stream: "test_output"
)")));
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_side_packet: "ALLOW:gating_stream"
input_side_packet: "DISALLOW:gating_stream"
output_stream: "test_output"
)")));
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
input_side_packet: "ALLOW:gating_stream"
output_stream: "test_output"
)")));
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "DISALLOW:gating_stream"
input_side_packet: "DISALLOW:gating_stream"
output_stream: "test_output"
)")));
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
input_side_packet: "DISALLOW:gating_stream"
output_stream: "test_output"
)")));
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "DISALLOW:gating_stream"
input_side_packet: "ALLOW:gating_stream"
output_stream: "test_output"
)")));
}
TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
SetRunner(R"(
calculator: "GateCalculator"
input_side_packet: "ALLOW:gating_stream"
input_stream: "test_input"
output_stream: "test_output"
)");
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
ASSERT_EQ(2, output.size());
EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value());
EXPECT_EQ(kTimestampValue1, output[1].Timestamp().Value());
EXPECT_EQ(true, output[0].Get<bool>());
EXPECT_EQ(false, output[1].Get<bool>());
}
TEST_F(GateCalculatorTest, AllowByDISALLOWSidePacketSetToFalse) {
SetRunner(R"(
calculator: "GateCalculator"
input_side_packet: "DISALLOW:gating_stream"
input_stream: "test_input"
output_stream: "test_output"
)");
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
ASSERT_EQ(2, output.size());
EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value());
EXPECT_EQ(kTimestampValue1, output[1].Timestamp().Value());
EXPECT_EQ(true, output[0].Get<bool>());
EXPECT_EQ(false, output[1].Get<bool>());
}
TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
SetRunner(R"(
calculator: "GateCalculator"
input_side_packet: "ALLOW:gating_stream"
input_stream: "test_input"
output_stream: "test_output"
)");
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
ASSERT_EQ(0, output.size());
}
TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
SetRunner(R"(
calculator: "GateCalculator"
input_side_packet: "DISALLOW:gating_stream"
input_stream: "test_input"
output_stream: "test_output"
)");
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
ASSERT_EQ(0, output.size());
}
TEST_F(GateCalculatorTest, Allow) { TEST_F(GateCalculatorTest, Allow) {
SetRunner(R"( SetRunner(R"(
calculator: "GateCalculator" calculator: "GateCalculator"

View File

@ -37,6 +37,10 @@ namespace mediapipe {
// the RoundRobinDemuxCalculator. Therefore, packets from different // the RoundRobinDemuxCalculator. Therefore, packets from different
// input streams are normally not expected to have the same timestamp. // input streams are normally not expected to have the same timestamp.
// //
// NOTE: this calculator can drop packets non-deterministically, depending on
// how fast the input streams are fed. In most cases, MuxCalculator should be
// preferred. In particular, dropping packets can interfere with rate limiting
// mechanisms.
class ImmediateMuxCalculator : public CalculatorBase { class ImmediateMuxCalculator : public CalculatorBase {
public: public:
// This calculator combines any set of input streams into a single // This calculator combines any set of input streams into a single
@ -76,6 +80,9 @@ REGISTER_CALCULATOR(ImmediateMuxCalculator);
if (!packet.IsEmpty()) { if (!packet.IsEmpty()) {
if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) { if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) {
cc->Outputs().Index(0).AddPacket(packet); cc->Outputs().Index(0).AddPacket(packet);
} else {
LOG_FIRST_N(WARNING, 5)
<< "Dropping a packet with timestamp " << packet.Timestamp();
} }
if (cc->Outputs().NumEntries() >= 2) { if (cc->Outputs().NumEntries() >= 2) {
Timestamp output_timestamp = std::max( Timestamp output_timestamp = std::max(

View File

@ -17,28 +17,49 @@
namespace mediapipe { namespace mediapipe {
namespace {
constexpr char kSelectTag[] = "SELECT";
constexpr char kInputTag[] = "INPUT";
} // namespace
// A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ..., // A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ...,
// using the integer value (0, 1, ...) in the packet on the "SELECT" input // using the integer value (0, 1, ...) in the packet on the kSelectTag input
// stream, and passes the packet on the selected input stream to the "OUTPUT" // stream, and passes the packet on the selected input stream to the "OUTPUT"
// output stream. // output stream.
// The kSelectTag input can also be passed in as an input side packet instead
// of an input stream. Either the input stream or the input side packet must be
// specified, but not both.
// //
// Note that this calculator defaults to use MuxInputStreamHandler, which is // Note that this calculator defaults to use MuxInputStreamHandler, which is
// required for this calculator. // required for this calculator. However, it can be overridden to work with
// other InputStreamHandlers. Check out the unit tests on for an example usage
// with DefaultInputStreamHandler.
class MuxCalculator : public CalculatorBase { class MuxCalculator : public CalculatorBase {
public: public:
static ::mediapipe::Status CheckAndInitSelectInput(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kSelectTag) ^
cc->InputSidePackets().HasTag(kSelectTag));
if (cc->Inputs().HasTag(kSelectTag)) {
cc->Inputs().Tag(kSelectTag).Set<int>();
} else {
cc->InputSidePackets().Tag(kSelectTag).Set<int>();
}
return ::mediapipe::OkStatus();
}
static ::mediapipe::Status GetContract(CalculatorContract* cc) { static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("SELECT").Set<int>(); RET_CHECK_OK(CheckAndInitSelectInput(cc));
CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT"); CollectionItemId data_input_id = cc->Inputs().BeginId(kInputTag);
PacketType* data_input0 = &cc->Inputs().Get(data_input_id); PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
data_input0->SetAny(); data_input0->SetAny();
++data_input_id; ++data_input_id;
for (; data_input_id < cc->Inputs().EndId("INPUT"); ++data_input_id) { for (; data_input_id < cc->Inputs().EndId(kInputTag); ++data_input_id) {
cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
}
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0);
// Assign this calculator's default InputStreamHandler.
cc->SetInputStreamHandler("MuxInputStreamHandler");
MediaPipeOptions options;
cc->SetInputStreamHandlerOptions(options);
@@ -47,16 +68,24 @@ class MuxCalculator : public CalculatorBase {
}
::mediapipe::Status Open(CalculatorContext* cc) final {
use_side_packet_select_ = false;
if (cc->InputSidePackets().HasTag(kSelectTag)) {
use_side_packet_select_ = true;
selected_index_ = cc->InputSidePackets().Tag(kSelectTag).Get<int>();
} else {
select_input_ = cc->Inputs().GetId(kSelectTag, 0);
}
data_input_base_ = cc->Inputs().GetId(kInputTag, 0);
num_data_inputs_ = cc->Inputs().NumEntries(kInputTag);
output_ = cc->Outputs().GetId("OUTPUT", 0);
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
int select = use_side_packet_select_
? selected_index_
: cc->Inputs().Get(select_input_).Get<int>();
RET_CHECK(0 <= select && select < num_data_inputs_);
if (!cc->Inputs().Get(data_input_base_ + select).IsEmpty()) {
cc->Outputs().Get(output_).AddPacket(
@@ -70,6 +99,8 @@ class MuxCalculator : public CalculatorBase {
CollectionItemId data_input_base_;
int num_data_inputs_ = 0;
CollectionItemId output_;
bool use_side_packet_select_;
int selected_index_;
};
REGISTER_CALCULATOR(MuxCalculator);

@@ -0,0 +1,237 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/split_vector_calculator.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
typedef SplitVectorCalculator<int, false> SplitIntVectorCalculator;
REGISTER_CALCULATOR(SplitIntVectorCalculator);
namespace {
// Graph with default input stream handler, and the input selection is driven
// by an input stream. All MuxCalculator inputs are present at each timestamp.
constexpr char kTestGraphConfig1[] = R"proto(
input_stream: "input"
output_stream: "test_output"
node {
calculator: "SplitIntVectorCalculator"
input_stream: "input"
output_stream: "stream0"
output_stream: "stream1"
output_stream: "stream2"
output_stream: "input_select"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 1 end: 2 }
ranges: { begin: 2 end: 3 }
ranges: { begin: 3 end: 4 }
element_only: true
}
}
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:stream0"
input_stream: "INPUT:1:stream1"
input_stream: "INPUT:2:stream2"
input_stream: "SELECT:input_select"
output_stream: "OUTPUT:test_output"
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
}
)proto";
// Graph with default input stream handler, and the input selection is driven
// by an input side packet. All MuxCalculator inputs are present at each
// timestamp.
constexpr char kTestGraphConfig2[] = R"proto(
input_side_packet: "input_selector"
input_stream: "input"
output_stream: "test_output"
node {
calculator: "SplitIntVectorCalculator"
input_stream: "input"
output_stream: "stream0"
output_stream: "stream1"
output_stream: "stream2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 1 end: 2 }
ranges: { begin: 2 end: 3 }
element_only: true
}
}
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:stream0"
input_stream: "INPUT:1:stream1"
input_stream: "INPUT:2:stream2"
input_side_packet: "SELECT:input_selector"
output_stream: "OUTPUT:test_output"
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
}
)proto";
// Graph with mux input stream handler, and the input selection is driven
// by an input stream. Only one MuxCalculator input is present at each
// timestamp.
constexpr char kTestGraphConfig3[] = R"proto(
input_stream: "input"
output_stream: "test_output"
node {
calculator: "RoundRobinDemuxCalculator"
input_stream: "input"
output_stream: "OUTPUT:0:stream0"
output_stream: "OUTPUT:1:stream1"
output_stream: "OUTPUT:2:stream2"
output_stream: "SELECT:input_select"
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:stream0"
input_stream: "INPUT:1:stream1"
input_stream: "INPUT:2:stream2"
input_stream: "SELECT:input_select"
output_stream: "OUTPUT:test_output"
}
)proto";
constexpr char kOutputName[] = "test_output";
constexpr char kInputName[] = "input";
constexpr char kInputSelector[] = "input_selector";
// Helper to run a graph with the given inputs, generating outputs and
// asserting each step along the way.
// Inputs:
// graph_config_proto - graph config protobuf
// extra_side_packets - input side packet name-to-value map
// input_stream_name - name of the input stream
// num_input_packets - number of input packets to send
// input_fn - returns the i-th input packet
// output_stream_name - name of the output stream
// output_fn - invoked on each output packet
void RunGraph(const std::string& graph_config_proto,
const std::map<std::string, Packet>& extra_side_packets,
const std::string& input_stream_name, int num_input_packets,
std::function<Packet(int)> input_fn,
const std::string& output_stream_name,
std::function<::mediapipe::Status(const Packet&)> output_fn) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
graph_config_proto);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream(output_stream_name, output_fn));
MP_ASSERT_OK(graph.StartRun(extra_side_packets));
for (int i = 0; i < num_input_packets; ++i) {
MP_ASSERT_OK(graph.AddPacketToInputStream(input_stream_name, input_fn(i)));
}
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxCalculatorTest, InputStreamSelector_DefaultInputStreamHandler) {
// Input and handling.
std::vector<std::vector<int>> input_packets = {
{1, 1, 2, 1}, {3, 5, 8, 2}, {13, 21, 34, 0},
{55, 89, 144, 2}, {233, 377, 610, 0}, {987, 1597, 2584, 1},
{4181, 6765, 10946, 2},
};
int packet_time_stamp = 22;
// This function will return the i-th input packet.
auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet {
return MakePacket<std::vector<int>>(input_packets[i])
.At(Timestamp(packet_time_stamp++));
};
// Output and handling.
std::vector<int> output;
// This function collects the output from the packet.
auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status {
output.push_back(p.Get<int>());
return ::mediapipe::OkStatus();
};
RunGraph(kTestGraphConfig1, {}, kInputName, input_packets.size(), input_fn,
kOutputName, output_fn);
EXPECT_THAT(output, testing::ElementsAre(1, 8, 13, 144, 233, 1597, 10946));
}
TEST(MuxCalculatorTest, InputSidePacketSelector_DefaultInputStreamHandler) {
// Input and handling.
std::vector<std::vector<int>> input_packets = {
{1, 1, 2}, {3, 5, 8}, {13, 21, 34}, {55, 89, 144},
{233, 377, 610}, {987, 1597, 2584}, {4181, 6765, 10946},
};
int packet_time_stamp = 22;
// This function will return the i-th input packet.
auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet {
return MakePacket<std::vector<int>>(input_packets[i])
.At(Timestamp(packet_time_stamp++));
};
// Output and handling.
std::vector<int> output;
// This function collects the output from the packet.
auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status {
output.push_back(p.Get<int>());
return ::mediapipe::OkStatus();
};
RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket<int>(0)}},
kInputName, input_packets.size(), input_fn, kOutputName, output_fn);
EXPECT_THAT(output, testing::ElementsAre(1, 3, 13, 55, 233, 987, 4181));
output.clear();
RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket<int>(1)}},
kInputName, input_packets.size(), input_fn, kOutputName, output_fn);
EXPECT_THAT(output, testing::ElementsAre(1, 5, 21, 89, 377, 1597, 6765));
output.clear();
RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket<int>(2)}},
kInputName, input_packets.size(), input_fn, kOutputName, output_fn);
EXPECT_THAT(output, testing::ElementsAre(2, 8, 34, 144, 610, 2584, 10946));
}
TEST(MuxCalculatorTest, InputStreamSelector_MuxInputStreamHandler) {
// Input and handling.
std::vector<int> input_packets = {1, 1, 2, 3, 5, 8, 13,
21, 34, 55, 89, 144, 233, 377,
610, 987, 1597, 2584, 4181, 6765, 10946};
int packet_time_stamp = 22;
// This function will return the i-th input packet.
auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet {
return MakePacket<int>(input_packets[i]).At(Timestamp(packet_time_stamp++));
};
// Output and handling.
std::vector<int> output;
// This function collects the output from the packet.
auto output_fn = [&output](const Packet& p) -> ::mediapipe::Status {
output.push_back(p.Get<int>());
return ::mediapipe::OkStatus();
};
RunGraph(kTestGraphConfig3, {}, kInputName, input_packets.size(), input_fn,
kOutputName, output_fn);
EXPECT_EQ(output, input_packets);
}
} // namespace
} // namespace mediapipe

@@ -128,11 +128,17 @@ class PreviousLoopbackCalculator : public CalculatorBase {
loop_packets_.pop_front();
main_packet_specs_.pop_front();
}
// We can close the PREV_LOOP output stream as soon as we have processed the
// last possible MAIN packet. That can happen in two cases:
// a) A non-empty MAIN packet has been received with Timestamp::Max().
// b) An empty MAIN packet has been received with Timestamp::Max(),
//    indicating MAIN is done.
if (main_spec.timestamp == Timestamp::Done().PreviousAllowedInStream()) {
prev_loop.Close();
}
}
if (main_packet_specs_.empty() && cc->Inputs().Get(main_id_).IsDone()) {
prev_loop.Close();
}
return ::mediapipe::OkStatus();
}

@@ -228,6 +228,104 @@ TEST(PreviousLoopbackCalculator, ClosesCorrectly) {
MP_EXPECT_OK(graph_.WaitUntilDone());
}
TEST(PreviousLoopbackCalculator, ProcessesMaxTimestamp) {
std::vector<Packet> out_and_previous_packets;
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'out'
input_stream: 'previous'
output_stream: 'out_and_previous'
}
)");
tool::AddVectorSink("out_and_previous", &graph_config,
&out_and_previous_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", MakePacket<int>(1).At(Timestamp::Max())));
MP_EXPECT_OK(graph.WaitUntilIdle());
EXPECT_THAT(out_and_previous_packets,
ElementsAre(PairPacket(Timestamp::Max(),
Pair(IntPacket(1), EmptyPacket()))));
MP_EXPECT_OK(graph.CloseAllInputStreams());
MP_EXPECT_OK(graph.WaitUntilIdle());
MP_EXPECT_OK(graph.WaitUntilDone());
}
TEST(PreviousLoopbackCalculator, ProcessesMaxTimestampNonEmptyPrevious) {
std::vector<Packet> out_and_previous_packets;
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'out'
input_stream: 'previous'
output_stream: 'out_and_previous'
}
)");
tool::AddVectorSink("out_and_previous", &graph_config,
&out_and_previous_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", MakePacket<int>(1).At(Timestamp::Min())));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", MakePacket<int>(2).At(Timestamp::Max())));
MP_EXPECT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
out_and_previous_packets,
ElementsAre(
PairPacket(Timestamp::Min(), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp::Max(), Pair(IntPacket(2), IntPacket(1)))));
MP_EXPECT_OK(graph.CloseAllInputStreams());
MP_EXPECT_OK(graph.WaitUntilIdle());
MP_EXPECT_OK(graph.WaitUntilDone());
}
// Demonstrates that downstream calculators won't be blocked by
// always-empty-LOOP-stream.
TEST(PreviousLoopbackCalculator, EmptyLoopForever) {

@@ -34,6 +34,8 @@ constexpr char kTagAtPostStream[] = "AT_POSTSTREAM";
constexpr char kTagAtZero[] = "AT_ZERO";
constexpr char kTagAtTick[] = "AT_TICK";
constexpr char kTagTick[] = "TICK";
constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP";
constexpr char kTagSideInputTimestamp[] = "TIMESTAMP";
static std::map<std::string, Timestamp>* kTimestampMap = []() {
auto* res = new std::map<std::string, Timestamp>();
@@ -41,6 +43,7 @@ static std::map<std::string, Timestamp>* kTimestampMap = []() {
res->emplace(kTagAtPostStream, Timestamp::PostStream());
res->emplace(kTagAtZero, Timestamp(0));
res->emplace(kTagAtTick, Timestamp::Unset());
res->emplace(kTagAtTimestamp, Timestamp::Unset());
return res;
}();
@@ -56,9 +59,10 @@ std::string GetOutputTag(const CC& cc) {
// timestamp, depending on the tag used to define output stream(s). (One tag can
// be used only.)
//
// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_TIMESTAMP
// and corresponding timestamps are Timestamp::PreStream(),
// Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK
// input, and timestamp received from a side input.
//
// Examples:
// node {
@@ -73,6 +77,13 @@ std::string GetOutputTag(const CC& cc) {
// input_side_packet: "side_packet"
// output_stream: "AT_TICK:packet"
// }
//
// node {
// calculator: "SidePacketToStreamCalculator"
// input_side_packet: "TIMESTAMP:timestamp"
// input_side_packet: "side_packet"
// output_stream: "AT_TIMESTAMP:packet"
// }
class SidePacketToStreamCalculator : public CalculatorBase {
public:
SidePacketToStreamCalculator() = default;
@@ -93,16 +104,29 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator);
CalculatorContract* cc) {
const auto& tags = cc->Outputs().GetTags();
RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1)
<< "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
"AT_TIMESTAMP tags is allowed and required to specify output "
"stream(s).";
RET_CHECK(
(cc->Outputs().HasTag(kTagAtTick) && cc->Inputs().HasTag(kTagTick)) ||
(!cc->Outputs().HasTag(kTagAtTick) && !cc->Inputs().HasTag(kTagTick)))
<< "Either both of TICK and AT_TICK should be used or none of them.";
RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) &&
cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) ||
(!cc->Outputs().HasTag(kTagAtTimestamp) &&
!cc->InputSidePackets().HasTag(kTagSideInputTimestamp)))
<< "Either both TIMESTAMP and AT_TIMESTAMP should be used or none of "
"them.";
const std::string output_tag = GetOutputTag(*cc);
const int num_entries = cc->Outputs().NumEntries(output_tag);
if (cc->Outputs().HasTag(kTagAtTimestamp)) {
RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries())
<< "For AT_TIMESTAMP tag, 2 input side packets are required.";
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64>();
} else {
RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries())
<< "Same number of input side packets and output streams is required.";
}
for (int i = 0; i < num_entries; ++i) {
cc->InputSidePackets().Index(i).SetAny();
cc->Outputs()
@@ -147,13 +171,22 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator);
}
::mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
if (!cc->Outputs().HasTag(kTagAtTick) &&
!cc->Outputs().HasTag(kTagAtTimestamp)) {
const auto& timestamp = kTimestampMap->at(output_tag_);
for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
cc->Outputs()
.Get(output_tag_, i)
.AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
}
} else if (cc->Outputs().HasTag(kTagAtTimestamp)) {
int64 timestamp =
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64>();
for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
cc->Outputs()
.Get(output_tag_, i)
.AddPacket(cc->InputSidePackets().Index(i).At(Timestamp(timestamp)));
}
}
return ::mediapipe::OkStatus();
}

@@ -51,6 +51,27 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) {
"Either both of TICK and AT_TICK should be used or none of them.");
}
TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_stream: "timestamp"
input_side_packet: "side_packet"
output_stream: "packet"
node {
calculator: "SidePacketToStreamCalculator"
input_side_packet: "side_packet"
output_stream: "AT_TIMESTAMP:packet"
}
)");
CalculatorGraph graph;
auto status = graph.Initialize(graph_config);
EXPECT_FALSE(status.ok());
EXPECT_PRED2(
absl::StrContains, status.message(),
"Either both TIMESTAMP and AT_TIMESTAMP should be used or none of them.");
}
TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
@@ -68,8 +89,9 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
auto status = graph.Initialize(graph_config);
EXPECT_FALSE(status.ok());
EXPECT_PRED2(absl::StrContains, status.message(),
"Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
"AT_TIMESTAMP tags is allowed and required to specify output "
"stream(s).");
}
TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
@@ -91,8 +113,9 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
auto status = graph.Initialize(graph_config);
EXPECT_FALSE(status.ok());
EXPECT_PRED2(absl::StrContains, status.message(),
"Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
"AT_TIMESTAMP tags is allowed and required to specify output "
"stream(s).");
}
TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) {
@@ -271,5 +294,79 @@ TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) {
tick_and_verify(/*at_timestamp=*/1025);
}
TEST(SidePacketToStreamCalculator, AtTimestamp) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_side_packet: "timestamp"
input_side_packet: "side_packet"
output_stream: "packet"
node {
calculator: "SidePacketToStreamCalculator"
input_side_packet: "TIMESTAMP:timestamp"
input_side_packet: "side_packet"
output_stream: "AT_TIMESTAMP:packet"
}
)");
std::vector<Packet> output_packets;
tool::AddVectorSink("packet", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
const int expected_value = 20;
const int64 expected_timestamp = 5;
MP_ASSERT_OK(
graph.StartRun({{"side_packet", MakePacket<int>(expected_value)},
{"timestamp", MakePacket<int64>(expected_timestamp)}}));
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_FALSE(output_packets.empty());
EXPECT_EQ(Timestamp(expected_timestamp), output_packets.back().Timestamp());
EXPECT_EQ(expected_value, output_packets.back().Get<int>());
}
TEST(SidePacketToStreamCalculator, AtTimestamp_MultipleOutputs) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_side_packet: "timestamp"
input_side_packet: "side_packet0"
input_side_packet: "side_packet1"
output_stream: "packet"
node {
calculator: "SidePacketToStreamCalculator"
input_side_packet: "TIMESTAMP:timestamp"
input_side_packet: "side_packet0"
input_side_packet: "side_packet1"
output_stream: "AT_TIMESTAMP:0:packet0"
output_stream: "AT_TIMESTAMP:1:packet1"
}
)");
std::vector<Packet> output_packets0;
tool::AddVectorSink("packet0", &graph_config, &output_packets0);
std::vector<Packet> output_packets1;
tool::AddVectorSink("packet1", &graph_config, &output_packets1);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
const int expected_value0 = 20;
const int expected_value1 = 15;
const int64 expected_timestamp = 5;
MP_ASSERT_OK(
graph.StartRun({{"side_packet0", MakePacket<int>(expected_value0)},
{"side_packet1", MakePacket<int>(expected_value1)},
{"timestamp", MakePacket<int64>(expected_timestamp)}}));
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_FALSE(output_packets0.empty());
EXPECT_EQ(Timestamp(expected_timestamp), output_packets0.back().Timestamp());
EXPECT_EQ(expected_value0, output_packets0.back().Get<int>());
ASSERT_FALSE(output_packets1.empty());
EXPECT_EQ(Timestamp(expected_timestamp), output_packets1.back().Timestamp());
EXPECT_EQ(expected_value1, output_packets1.back().Get<int>());
}
} // namespace
} // namespace mediapipe

@@ -449,19 +449,15 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
switch (rotation_) {
case mediapipe::RotationMode_Mode_UNKNOWN:
case mediapipe::RotationMode_Mode_ROTATION_0:
rotated_mat = input_mat;
break;
case mediapipe::RotationMode_Mode_ROTATION_90:
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE);
break;
case mediapipe::RotationMode_Mode_ROTATION_180:
cv::rotate(input_mat, rotated_mat, cv::ROTATE_180);
break;
case mediapipe::RotationMode_Mode_ROTATION_270:
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE);
break;
}

@@ -57,22 +57,6 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "tensor_squeeze_dimensions_calculator_proto",
srcs = ["tensor_squeeze_dimensions_calculator.proto"],
@@ -212,7 +196,10 @@ mediapipe_cc_proto_library(
mediapipe_cc_proto_library(
name = "tensorflow_session_from_saved_model_generator_cc_proto",
srcs = ["tensorflow_session_from_saved_model_generator.proto"],
cc_deps = [
"//mediapipe/framework:packet_generator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
visibility = ["//visibility:public"],
deps = [":tensorflow_session_from_saved_model_generator_proto"],
)
@@ -220,7 +207,10 @@ mediapipe_cc_proto_library(
mediapipe_cc_proto_library(
name = "tensorflow_session_from_saved_model_calculator_cc_proto",
srcs = ["tensorflow_session_from_saved_model_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
visibility = ["//visibility:public"],
deps = [":tensorflow_session_from_saved_model_calculator_proto"],
)
@@ -488,6 +478,8 @@ cc_library(
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/deps:clock",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
] + select({
@@ -518,6 +510,8 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/port:status",
"//mediapipe/framework/deps:clock",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [
@@ -929,6 +923,7 @@ cc_test(
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:all_kernels",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
@@ -954,6 +949,7 @@ cc_test(
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:all_kernels",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)

@@ -26,9 +26,14 @@
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/monotonic_clock.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/public/session_options.h"
#if defined(MEDIAPIPE_MOBILE)
@@ -41,6 +46,17 @@ namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
// Updates the graph nodes to use the device as specified by device_id.
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
for (auto& node : *graph_def->mutable_node()) {
if (node.device().empty()) {
node.set_device(device_id);
}
}
}
} // namespace
class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
@@ -77,6 +93,9 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
}
::mediapipe::Status Open(CalculatorContext* cc) override {
auto clock = std::unique_ptr<mediapipe::Clock>(
mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
const uint64 start_time = absl::ToUnixMicros(clock->TimeNow());
const auto& options =
cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>();
// Output bundle packet.
@@ -108,6 +127,12 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
tensorflow::GraphDef graph_def;
RET_CHECK(graph_def.ParseFromString(graph_def_serialized));
// Update the graph nodes to use the preferred device, if set.
if (!options.preferred_device_id().empty()) {
SetPreferredDevice(&graph_def, options.preferred_device_id());
}
const tf::Status tf_status = session->session->Create(graph_def);
RET_CHECK(tf_status.ok()) << "Create failed: " << tf_status.ToString();
@@ -123,6 +148,9 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
}
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release()));
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
<< " microseconds.";
return ::mediapipe::OkStatus();
}

View File

@@ -69,4 +69,12 @@ message TensorFlowSessionFromFrozenGraphCalculatorOptions {
   // Graph nodes to run to initialize the model. Any output of these ops is
   // ignored.
   repeated string initialization_op_names = 4;
+
+  // The id of the device you would prefer to execute the graph nodes on.
+  // If set, all graph nodes without a previously specified device will be set
+  // to run on preferred_device_id. Example values include:
+  // ["/device:GPU:0", "/device:CPU:0", ...]
+  // NOTE: If config.allow_soft_placement = false and the device is not found,
+  // an error will be thrown.
+  optional string preferred_device_id = 5;
 }

View File

@@ -66,6 +66,7 @@ class TensorFlowSessionFromFrozenGraphCalculatorTest : public ::testing::Test {
     (*calculator_options_->mutable_tag_to_tensor_names())["B"] = "b:0";
     calculator_options_->mutable_config()->set_intra_op_parallelism_threads(1);
     calculator_options_->mutable_config()->set_inter_op_parallelism_threads(2);
+    calculator_options_->set_preferred_device_id("/device:CPU:0");
   }

   void VerifySignatureMap(const TensorFlowSession& session) {

View File

@@ -27,16 +27,32 @@
 #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/deps/clock.h"
+#include "mediapipe/framework/deps/monotonic_clock.h"
 #include "mediapipe/framework/port/file_helpers.h"
+#include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/tool/status_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/public/session_options.h"

 namespace mediapipe {

 namespace tf = ::tensorflow;

+namespace {
+// Updates the graph nodes to use the device as specified by device_id.
+void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
+  for (auto& node : *graph_def->mutable_node()) {
+    if (node.device().empty()) {
+      node.set_device(device_id);
+    }
+  }
+}
+}  // namespace
+
 class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
  public:
   static ::mediapipe::Status FillExpectations(
@@ -77,6 +93,9 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
   static ::mediapipe::Status Generate(
       const PacketGeneratorOptions& packet_generator_options,
       const PacketSet& input_side_packets, PacketSet* output_side_packets) {
+    auto clock = std::unique_ptr<mediapipe::Clock>(
+        mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
+    const uint64 start_time = absl::ToUnixMicros(clock->TimeNow());
     const TensorFlowSessionFromFrozenGraphGeneratorOptions& options =
         packet_generator_options.GetExtension(
             TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
@@ -108,6 +127,12 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
     tensorflow::GraphDef graph_def;
     RET_CHECK(graph_def.ParseFromString(graph_def_serialized));

+    // Update the graph nodes to use the preferred device, if set.
+    if (!options.preferred_device_id().empty()) {
+      SetPreferredDevice(&graph_def, options.preferred_device_id());
+    }
+
     const tf::Status tf_status = session->session->Create(graph_def);
     RET_CHECK(tf_status.ok()) << "Create failed: " << tf_status.ToString();
@@ -123,6 +148,9 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
     }
     output_side_packets->Tag("SESSION") = Adopt(session.release());

+    const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
+    LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
+              << " microseconds.";
     return ::mediapipe::OkStatus();
   }
 };

View File

@@ -69,4 +69,12 @@ message TensorFlowSessionFromFrozenGraphGeneratorOptions {
   // Graph nodes to run to initialize the model. Any output of these ops is
   // ignored.
   repeated string initialization_op_names = 4;
+
+  // The id of the device you would prefer to execute the graph nodes on.
+  // If set, all graph nodes without a previously specified device will be set
+  // to run on preferred_device_id. Example values include:
+  // ["/device:GPU:0", "/device:CPU:0", ...]
+  // NOTE: If config.allow_soft_placement = false and the device is not found,
+  // an error will be thrown.
+  optional string preferred_device_id = 5;
 }

View File

@@ -66,6 +66,7 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test {
     (*generator_options_->mutable_tag_to_tensor_names())["B"] = "b:0";
     generator_options_->mutable_config()->set_intra_op_parallelism_threads(1);
     generator_options_->mutable_config()->set_inter_op_parallelism_threads(2);
+    generator_options_->set_preferred_device_id("/device:CPU:0");
   }

   void VerifySignatureMap(PacketSet* output_side_packets) {

View File

@@ -134,8 +134,8 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
     }

     tensorflow::RunOptions run_options;
-    // In the future, could construct session options from the options proto.
     tensorflow::SessionOptions session_options;
+    session_options.config = options.session_config();
     auto saved_model = absl::make_unique<tensorflow::SavedModelBundle>();
     ::tensorflow::Status status = tensorflow::LoadSavedModel(
         session_options, run_options, path, tags_set, saved_model.get());

View File

@@ -17,6 +17,7 @@ syntax = "proto2";
 package mediapipe;

 import "mediapipe/framework/calculator.proto";
+import "tensorflow/core/protobuf/config.proto";

 message TensorFlowSessionFromSavedModelCalculatorOptions {
   extend mediapipe.CalculatorOptions {
@@ -55,4 +56,7 @@ message TensorFlowSessionFromSavedModelCalculatorOptions {
   // If no tag is specified, then use "serve" as the default. Note that in order
   // to use TPU accelerator hardware, the tag "tpu" needs to be specified.
   repeated string saved_model_tag = 6;
+
+  // Tensorflow session config options.
+  optional tensorflow.ConfigProto session_config = 7;
 }

View File

@@ -26,6 +26,7 @@
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/framework/tool/tag_map_helper.h"
 #include "mediapipe/framework/tool/validate_type.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"

 namespace mediapipe {
@@ -204,5 +205,31 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
   ASSERT_NE(session.session, nullptr);
 }

+TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
+       ConfiguresSessionGivenConfig) {
+  options_->set_saved_model_path(
+      std::string(file::SplitPath(GetSavedModelDir()).first));
+  options_->set_load_latest_model(true);
+  options_->mutable_session_config()->mutable_device_count()->insert(
+      {"CPU", 10});
+  CalculatorRunner runner(absl::Substitute(R"(
+        calculator: "TensorFlowSessionFromSavedModelCalculator"
+        output_side_packet: "SESSION:tf_model"
+        options {
+          [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
+            $0
+          }
+        })",
+                                           options_->DebugString()));
+  MP_ASSERT_OK(runner.Run());
+  const TensorFlowSession& session =
+      runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
+  // Session must be set.
+  ASSERT_NE(session.session, nullptr);
+  std::vector<tensorflow::DeviceAttributes> devices;
+  ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK());
+  EXPECT_THAT(devices.size(), 10);
+}
+
 }  // namespace
 }  // namespace mediapipe

View File

@@ -129,8 +129,8 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
     }

     tensorflow::RunOptions run_options;
-    // In the future, could construct session options from the options proto.
     tensorflow::SessionOptions session_options;
+    session_options.config = options.session_config();
     auto saved_model = absl::make_unique<tensorflow::SavedModelBundle>();
     ::tensorflow::Status status = tensorflow::LoadSavedModel(
         session_options, run_options, path, tags_set, saved_model.get());

View File

@@ -17,6 +17,7 @@ syntax = "proto2";
 package mediapipe;

 import "mediapipe/framework/packet_generator.proto";
+import "tensorflow/core/protobuf/config.proto";

 message TensorFlowSessionFromSavedModelGeneratorOptions {
   extend mediapipe.PacketGeneratorOptions {
@@ -55,4 +56,7 @@ message TensorFlowSessionFromSavedModelGeneratorOptions {
   // If no tag is specified, then use "serve" as the default. Note that in order
   // to use TPU accelerator hardware, the tag "tpu" needs to be specified.
   repeated string saved_model_tag = 6;
+
+  // Tensorflow session config options.
+  optional tensorflow.ConfigProto session_config = 9;
 }

View File

@@ -25,6 +25,7 @@
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/framework/tool/tag_map_helper.h"
 #include "mediapipe/framework/tool/validate_type.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"

 namespace mediapipe {
@@ -196,5 +197,29 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
   ASSERT_NE(session.session, nullptr);
 }

+TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
+       ConfiguresSessionGivenConfig) {
+  generator_options_->set_saved_model_path(
+      std::string(file::SplitPath(GetSavedModelDir()).first));
+  generator_options_->set_load_latest_model(true);
+  generator_options_->mutable_session_config()->mutable_device_count()->insert(
+      {"CPU", 10});
+  PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie());
+  PacketSet output_side_packets(
+      tool::CreateTagMap({"SESSION:session"}).ValueOrDie());
+  ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
+      "TensorFlowSessionFromSavedModelGenerator", extendable_options_,
+      input_side_packets, &output_side_packets);
+  MP_EXPECT_OK(run_status) << run_status.message();
+  const TensorFlowSession& session =
+      output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
+  // Session must be set.
+  ASSERT_NE(session.session, nullptr);
+  std::vector<tensorflow::DeviceAttributes> devices;
+  ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK());
+  EXPECT_THAT(devices.size(), 10);
+}
+
 }  // namespace
 }  // namespace mediapipe

View File

@@ -91,11 +91,11 @@ REGISTER_CALCULATOR(VectorFloatToTensorCalculator);
         cc->Inputs().Index(0).Value().Get<std::vector<std::vector<float>>>();

     const int32 rows = input.size();
-    CHECK_GE(rows, 1);
+    RET_CHECK_GE(rows, 1);
     const int32 cols = input[0].size();
-    CHECK_GE(cols, 1);
+    RET_CHECK_GE(cols, 1);
     for (int i = 1; i < rows; ++i) {
-      CHECK_EQ(input[i].size(), cols);
+      RET_CHECK_EQ(input[i].size(), cols);
     }
     if (options_.transpose()) {
       tensor_shape = tf::TensorShape({cols, rows});
@@ -116,7 +116,7 @@ REGISTER_CALCULATOR(VectorFloatToTensorCalculator);
   } else if (options_.input_size() == INPUT_1D) {
     const std::vector<float>& input =
         cc->Inputs().Index(0).Value().Get<std::vector<float>>();
-    CHECK_GE(input.size(), 1);
+    RET_CHECK_GE(input.size(), 1);
     const int32 length = input.size();
     tensor_shape = tf::TensorShape({length});
     auto output = ::absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);

View File

@@ -196,13 +196,6 @@ cc_test(
     ],
 )

-cc_library(
-    name = "util",
-    hdrs = ["util.h"],
-    visibility = ["//visibility:public"],
-    alwayslink = 1,
-)
-
 selects.config_setting_group(
     name = "gpu_inference_disabled",
     match_any = [
@@ -229,7 +222,6 @@ cc_library(
     }),
     visibility = ["//visibility:public"],
     deps = [
-        ":util",
         ":tflite_inference_calculator_cc_proto",
         "@com_google_absl//absl/memory",
         "//mediapipe/framework:calculator_framework",
@@ -295,7 +287,6 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/util/tflite:config",
-        ":util",
         ":tflite_converter_calculator_cc_proto",
         "//mediapipe/util:resource_util",
         "//mediapipe/framework:calculator_framework",
@@ -334,7 +325,6 @@ cc_library(
     srcs = ["tflite_model_calculator.cc"],
     visibility = ["//visibility:public"],
     deps = [
-        ":util",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:packet",
         "//mediapipe/framework/port:ret_check",
@@ -348,7 +338,6 @@ cc_library(
     srcs = ["tflite_tensors_to_segmentation_calculator.cc"],
     visibility = ["//visibility:public"],
    deps = [
-        ":util",
         ":tflite_tensors_to_segmentation_calculator_cc_proto",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
@@ -418,7 +407,6 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/util/tflite:config",
-        ":util",
         ":tflite_tensors_to_detections_calculator_cc_proto",
         "//mediapipe/framework/formats:detection_cc_proto",
         "@com_google_absl//absl/strings:str_format",
@@ -551,6 +539,7 @@ cc_test(
         "//mediapipe/framework/port:parse_text_proto",
         "//mediapipe/framework/tool:validate_type",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@org_tensorflow//tensorflow/lite:framework",
     ],
 )

View File

@@ -16,7 +16,6 @@
 #include <vector>

 #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
-#include "mediapipe/calculators/tflite/util.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/matrix.h"
@@ -146,8 +145,7 @@ class TfLiteConverterCalculator : public CalculatorBase {
   ::mediapipe::Status LoadOptions(CalculatorContext* cc);
   template <class T>
   ::mediapipe::Status NormalizeImage(const ImageFrame& image_frame,
-                                     bool zero_center, bool flip_vertically,
-                                     float* tensor_ptr);
+                                     bool flip_vertically, float* tensor_ptr);
   ::mediapipe::Status CopyMatrixToTensor(const Matrix& matrix,
                                          float* tensor_ptr);
   ::mediapipe::Status ProcessCPU(CalculatorContext* cc);
@@ -165,10 +163,7 @@ class TfLiteConverterCalculator : public CalculatorBase {
   bool initialized_ = false;
   bool use_gpu_ = false;
-  bool zero_center_ = true;  // normalize range to [-1,1] | otherwise [0,1]
-  bool use_custom_normalization_ = false;
-  float custom_div_ = -1.0f;
-  float custom_sub_ = -1.0f;
+  absl::optional<std::pair<float, float>> output_range_;
   bool flip_vertically_ = false;
   bool row_major_matrix_ = false;
   bool use_quantized_tensors_ = false;
@@ -362,11 +357,11 @@ bool ShouldUseGpu(CC* cc) {
     float* tensor_buffer = tensor->data.f;
     RET_CHECK(tensor_buffer);
     if (image_frame.ByteDepth() == 1) {
-      MP_RETURN_IF_ERROR(NormalizeImage<uint8>(
-          image_frame, zero_center_, flip_vertically_, tensor_buffer));
+      MP_RETURN_IF_ERROR(NormalizeImage<uint8>(image_frame, flip_vertically_,
+                                               tensor_buffer));
     } else if (image_frame.ByteDepth() == 4) {
-      MP_RETURN_IF_ERROR(NormalizeImage<float>(
-          image_frame, zero_center_, flip_vertically_, tensor_buffer));
+      MP_RETURN_IF_ERROR(NormalizeImage<float>(image_frame, flip_vertically_,
+                                               tensor_buffer));
     } else {
       return ::mediapipe::InternalError(
           "Only byte-based (8 bit) and float (32 bit) images supported.");
@@ -427,11 +422,11 @@ bool ShouldUseGpu(CC* cc) {
         auto src = gpu_helper_.CreateSourceTexture(input);
         glActiveTexture(GL_TEXTURE0 + 0);
         glBindTexture(GL_TEXTURE_2D, src.name());
-        RET_CHECK_CALL(gpu_data_out_->buffer.BindToIndex(1));
+        MP_RETURN_IF_ERROR(gpu_data_out_->buffer.BindToIndex(1));
         const tflite::gpu::uint3 workgroups = {
             NumGroups(input.width(), kWorkgroupSize),
             NumGroups(input.height(), kWorkgroupSize), 1};
-        RET_CHECK_CALL(gpu_data_out_->program.Dispatch(workgroups));
+        MP_RETURN_IF_ERROR(gpu_data_out_->program.Dispatch(workgroups));
         glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
         glBindTexture(GL_TEXTURE_2D, 0);
         src.Release();
@@ -445,9 +440,9 @@ bool ShouldUseGpu(CC* cc) {
         output_tensors->resize(1);
         {
           GpuTensor& tensor = output_tensors->at(0);
-          RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+          MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
               gpu_data_out_->elements, &tensor));
-          RET_CHECK_CALL(CopyBuffer(gpu_data_out_->buffer, tensor));
+          MP_RETURN_IF_ERROR(CopyBuffer(gpu_data_out_->buffer, tensor));
         }
         return ::mediapipe::OkStatus();
       }));
@@ -521,7 +516,7 @@ bool ShouldUseGpu(CC* cc) {
   MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
       [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
         // Device memory.
-        RET_CHECK_CALL(
+        MP_RETURN_IF_ERROR(
             ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
                 gpu_data_out_->elements, &gpu_data_out_->buffer));
@@ -544,7 +539,13 @@ bool ShouldUseGpu(CC* cc) {
           $6  // alpha channel
         })",
         /*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(),
-        /*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "",
+        /*$3=*/
+        output_range_.has_value()
+            ? absl::Substitute(
+                  "pixel = pixel * float($0) + float($1);",
+                  (output_range_->second - output_range_->first),
+                  output_range_->first)
+            : "",
         /*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y",
         /*$5=*/
         single_channel
@@ -555,10 +556,10 @@ bool ShouldUseGpu(CC* cc) {
             include_alpha ? "output_data.elements[linear_index + 3] = pixel.w;"
                           : "",
         /*$7=*/max_num_channels_);
-    RET_CHECK_CALL(GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source,
-                                           &gpu_data_out_->shader));
-    RET_CHECK_CALL(GlProgram::CreateWithShader(gpu_data_out_->shader,
-                                               &gpu_data_out_->program));
+    MP_RETURN_IF_ERROR(GlShader::CompileShader(
+        GL_COMPUTE_SHADER, shader_source, &gpu_data_out_->shader));
+    MP_RETURN_IF_ERROR(GlProgram::CreateWithShader(
+        gpu_data_out_->shader, &gpu_data_out_->program));
     return ::mediapipe::OkStatus();
   }));
@@ -599,7 +600,12 @@ bool ShouldUseGpu(CC* cc) {
       )",
       /*$0=*/include_alpha ? "float4" : "float3",
       /*$1=*/include_alpha ? "rgba" : "rgb",
-      /*$2=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "",
+      /*$2=*/
+      output_range_.has_value()
+          ? absl::Substitute("pixel = pixel * float($0) + float($1);",
+                             (output_range_->second - output_range_->first),
+                             output_range_->first)
+          : "",
      /*$3=*/flip_vertically_ ? "(in_tex.get_height() - 1 - gid.y)" : "gid.y",
       /*$4=*/include_alpha ? 4 : 3,
       /*$5=*/include_alpha ? "out_buf[linear_index + 3] = pixel.w;" : "");
@@ -630,13 +636,27 @@ bool ShouldUseGpu(CC* cc) {
   const auto& options =
       cc->Options<::mediapipe::TfLiteConverterCalculatorOptions>();

-  // Get data normalization mode.
-  zero_center_ = options.zero_center();
+  // If zero_center, set output float range to match [-1, 1] as specified in
+  // calculator proto.
+  if (options.zero_center()) {
+    output_range_.emplace(std::pair<float, float>(-1.0, 1.0));
+  }
+
+  // Custom output_tensor_float_range values.
+  // If the float range is specified in pb text, use the specified values
+  // instead.
+  if (options.has_output_tensor_float_range()) {
+    output_range_.emplace(options.output_tensor_float_range().min(),
+                          options.output_tensor_float_range().max());
+    CHECK_GT(output_range_->second, output_range_->first);
+  }

   // Custom div and sub values.
-  use_custom_normalization_ = options.use_custom_normalization();
-  custom_div_ = options.custom_div();
-  custom_sub_ = options.custom_sub();
+  if (options.use_custom_normalization()) {
+    output_range_.emplace(std::pair<float, float>(
+        -options.custom_sub(),
+        -options.custom_sub() + 255.0 / options.custom_div()));
+  }

   // Get y-flip mode.
   flip_vertically_ = options.flip_vertically();
@@ -664,40 +684,46 @@ bool ShouldUseGpu(CC* cc) {
 template <class T>
 ::mediapipe::Status TfLiteConverterCalculator::NormalizeImage(
-    const ImageFrame& image_frame, bool zero_center, bool flip_vertically,
-    float* tensor_ptr) {
+    const ImageFrame& image_frame, bool flip_vertically, float* tensor_ptr) {
   const int height = image_frame.Height();
   const int width = image_frame.Width();
   const int channels = image_frame.NumberOfChannels();
   const int channels_preserved = std::min(channels, max_num_channels_);
   const int channels_ignored = channels - channels_preserved;

-  float div, sub;
-  if (use_custom_normalization_) {
-    RET_CHECK_GT(custom_div_, 0.0f);
-    RET_CHECK_GE(custom_sub_, 0.0f);
-    div = custom_div_;
-    sub = custom_sub_;
-  } else if (zero_center) {
-    // [-1,1]
-    div = 127.5f;
-    sub = 1.0f;
-  } else {
-    // [0,1]
-    div = 255.0f;
-    sub = 0.0f;
-  }
-
-  for (int i = 0; i < height; ++i) {
-    const T* image_ptr = reinterpret_cast<const T*>(
-        image_frame.PixelData() +
-        (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep());
-    for (int j = 0; j < width; ++j) {
-      for (int c = 0; c < channels_preserved; ++c) {
-        *tensor_ptr++ = *image_ptr++ / div - sub;
+  if (output_range_.has_value()) {
+    // If the output float range is set and we are not using custom
+    // normalization, normalize the pixel values from [0, 255] to the specified
+    // output range.
+    RET_CHECK_NE(output_range_->first, output_range_->second);
+    const float scale = (output_range_->second - output_range_->first) / 255.0f;
+    const float bias = output_range_->first;
+
+    for (int i = 0; i < height; ++i) {
+      const T* image_ptr = reinterpret_cast<const T*>(
+          image_frame.PixelData() +
+          (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep());
+      for (int j = 0; j < width; ++j) {
+        for (int c = 0; c < channels_preserved; ++c) {
+          *tensor_ptr++ = *image_ptr++ * scale + bias;
+        }
+        image_ptr += channels_ignored;
+      }
+    }
+  } else {
+    // [0,1], scale only (bias == 0)
+    // Verified that there are no precision issues with 1.0f / 255.0f expression
+    const float scale = 1.0f / 255.0f;
+
+    for (int i = 0; i < height; ++i) {
+      const T* image_ptr = reinterpret_cast<const T*>(
+          image_frame.PixelData() +
+          (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep());
+      for (int j = 0; j < width; ++j) {
+        for (int c = 0; c < channels_preserved; ++c) {
+          *tensor_ptr++ = *image_ptr++ * scale;
+        }
+        image_ptr += channels_ignored;
       }
-      image_ptr += channels_ignored;
     }
   }

View File

@@ -56,4 +56,14 @@ message TfLiteConverterCalculatorOptions {
   // Quantization option (CPU only).
   // When true, output kTfLiteUInt8 tensor instead of kTfLiteFloat32.
   optional bool use_quantized_tensors = 5 [default = false];
+
+  // Normalization option.
+  // Setting output_tensor_float_range results in the values being normalized
+  // to the range [output_tensor_float_range.min, output_tensor_float_range.max].
+  optional TensorFloatRange output_tensor_float_range = 9;
+
+  message TensorFloatRange {
+    optional float min = 1;
+    optional float max = 2;
+  }
 }

View File

@ -16,6 +16,7 @@
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
@ -40,6 +41,7 @@ constexpr char kTransposeOptionsString[] =
} // namespace } // namespace
using RandomEngine = std::mt19937_64; using RandomEngine = std::mt19937_64;
using testing::Eq;
const uint32 kSeed = 1234; const uint32 kSeed = 1234;
const int kNumSizes = 8; const int kNumSizes = 8;
const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
@ -232,7 +234,6 @@ TEST_F(TfLiteConverterCalculatorTest, CustomDivAndSub) {
// Wait until the calculator done processing. // Wait until the calculator done processing.
MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(1, output_packets.size());
// Get and process results. // Get and process results.
const std::vector<TfLiteTensor>& tensor_vec = const std::vector<TfLiteTensor>& tensor_vec =
@@ -249,4 +250,70 @@ TEST_F(TfLiteConverterCalculatorTest, CustomDivAndSub) {
   MP_ASSERT_OK(graph.WaitUntilDone());
 }

+TEST_F(TfLiteConverterCalculatorTest, SetOutputRange) {
+  std::vector<std::pair<float, float>> range_values = {
+      std::make_pair(0.0, 1.0), std::make_pair(-1.0, 1.0),
+      std::make_pair(-0.5, 0.5)};
+  for (std::pair<float, float> range : range_values) {
+    CalculatorGraph graph;
+    CalculatorGraphConfig graph_config =
+        ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
+            absl::Substitute(R"(
+        input_stream: "input_image"
+        node {
+          calculator: "TfLiteConverterCalculator"
+          input_stream: "IMAGE:input_image"
+          output_stream: "TENSORS:tensor"
+          options {
+            [mediapipe.TfLiteConverterCalculatorOptions.ext] {
+              output_tensor_float_range {
+                min: $0
+                max: $1
+              }
+            }
+          }
+        }
+        )",
+                             /*$0=*/range.first,
+                             /*$1=*/range.second));
+    std::vector<Packet> output_packets;
+    tool::AddVectorSink("tensor", &graph_config, &output_packets);
+
+    // Run the graph.
+    MP_ASSERT_OK(graph.Initialize(graph_config));
+    MP_ASSERT_OK(graph.StartRun({}));
+    auto input_image = absl::make_unique<ImageFrame>(ImageFormat::GRAY8, 1, 1);
+    cv::Mat mat = ::mediapipe::formats::MatView(input_image.get());
+    mat.at<uint8>(0, 0) = 200;
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        "input_image", Adopt(input_image.release()).At(Timestamp(0))));
+
+    // Wait until the calculator finishes processing.
+    MP_ASSERT_OK(graph.WaitUntilIdle());
+    EXPECT_THAT(output_packets.size(), Eq(1));
+
+    // Get and process results.
+    const std::vector<TfLiteTensor>& tensor_vec =
+        output_packets[0].Get<std::vector<TfLiteTensor>>();
+    EXPECT_THAT(tensor_vec.size(), Eq(1));
+    const TfLiteTensor* tensor = &tensor_vec[0];
+
+    // Calculate the expected normalized value:
+    float normalized_value =
+        range.first + (200 * (range.second - range.first)) / 255.0;
+
+    EXPECT_THAT(tensor->type, Eq(kTfLiteFloat32));
+    EXPECT_THAT(normalized_value,
+                testing::FloatNear(*tensor->data.f,
+                                   2.0f * std::abs(*tensor->data.f) *
+                                       std::numeric_limits<float>::epsilon()));
+
+    // Fully close graph at end, otherwise calculator+tensors are destroyed
+    // after calling WaitUntilDone().
+    MP_ASSERT_OK(graph.CloseInputStream("input_image"));
+    MP_ASSERT_OK(graph.WaitUntilDone());
+  }
+}
+
 }  // namespace mediapipe


@@ -19,7 +19,6 @@
 #include "absl/memory/memory.h"
 #include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
-#include "mediapipe/calculators/tflite/util.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/util/tflite/config.h"
@@ -496,7 +495,7 @@ bool ShouldUseGpu(CC* cc) {
     output_tensors_gpu->resize(gpu_data_out_.size());
     for (int i = 0; i < gpu_data_out_.size(); ++i) {
       GpuTensor& tensor = output_tensors_gpu->at(i);
-      RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+      MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
           gpu_data_out_[i]->elements, &tensor));
       MP_RETURN_IF_ERROR(
           tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
@@ -518,7 +517,7 @@ bool ShouldUseGpu(CC* cc) {
     // Explicit copy input.
     gpu_data_in_.resize(input_tensors.size());
     for (int i = 0; i < input_tensors.size(); ++i) {
-      RET_CHECK_CALL(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
+      MP_RETURN_IF_ERROR(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
     }
 #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     const auto& input_tensors =
@@ -582,7 +581,7 @@ bool ShouldUseGpu(CC* cc) {
     for (int i = 0; i < tensor_indexes.size(); ++i) {
       TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
       std::vector<float> gpu_data(tensor->bytes / sizeof(float));
-      RET_CHECK_CALL(gpu_data_out_[i]->buffer.Read(
+      MP_RETURN_IF_ERROR(gpu_data_out_[i]->buffer.Read(
           absl::MakeSpan(tensor->data.f, tensor->bytes)));
       output_tensors_cpu->emplace_back(*tensor);
     }
@@ -599,9 +598,9 @@ bool ShouldUseGpu(CC* cc) {
     for (int i = 0; i < gpu_data_out_.size(); ++i) {
       GpuTensor& tensor = output_tensors_gpu->at(i);
       // Allocate output tensor.
-      RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+      MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
           gpu_data_out_[i]->elements, &tensor));
-      RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
+      MP_RETURN_IF_ERROR(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
     }
     cc->Outputs()
         .Tag(kTensorsGpuTag)
@@ -655,7 +654,8 @@ bool ShouldUseGpu(CC* cc) {
   options.priority3 = tflite::gpu::InferencePriority::AUTO;
   options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
   tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
-  RET_CHECK_CALL(tflite_gpu_runner_->InitializeWithModel(model, op_resolver));
+  MP_RETURN_IF_ERROR(
+      tflite_gpu_runner_->InitializeWithModel(model, op_resolver));

   // Allocate interpreter memory for cpu output.
   if (!gpu_output_) {
@@ -688,10 +688,11 @@ bool ShouldUseGpu(CC* cc) {
     ASSIGN_OR_RETURN(gpu_data_out_[i]->elements,
                      tflite_gpu_runner_->GetOutputElements(i));
     // Create and bind input buffer.
-    RET_CHECK_CALL(::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
-        gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
+    MP_RETURN_IF_ERROR(
+        ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
+            gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
   }
-  RET_CHECK_CALL(tflite_gpu_runner_->Build());
+  MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
 #endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

   return ::mediapipe::OkStatus();
@@ -841,7 +842,7 @@ bool ShouldUseGpu(CC* cc) {
       gpu_data_in_[i]->elements *= tensor->dims->data[d];
     }
     // Create and bind input buffer.
-    RET_CHECK_CALL(
+    MP_RETURN_IF_ERROR(
         ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
             gpu_data_in_[i]->elements, &gpu_data_in_[i]->buffer));
     RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor(
@@ -866,7 +867,7 @@ bool ShouldUseGpu(CC* cc) {
   // Create and bind output buffers.
   interpreter_->SetAllowBufferHandleOutput(true);
   for (int i = 0; i < gpu_data_out_.size(); ++i) {
-    RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+    MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
         gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
     RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor(
                      delegate_.get(), gpu_data_out_[i]->buffer.id(),


@@ -18,7 +18,6 @@
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h"
-#include "mediapipe/calculators/tflite/util.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/detection.pb.h"
@@ -404,8 +403,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
                   &output_detections]()
                      -> ::mediapipe::Status {
     // Copy inputs.
-    RET_CHECK_CALL(CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer));
-    RET_CHECK_CALL(CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer));
+    MP_RETURN_IF_ERROR(
+        CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer));
+    MP_RETURN_IF_ERROR(
+        CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer));
     if (!anchors_init_) {
       if (side_packet_anchors_) {
         CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
@@ -413,11 +414,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
             cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
         std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
         ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
-        RET_CHECK_CALL(gpu_data_->raw_anchors_buffer.Write<float>(
+        MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.Write<float>(
             absl::MakeSpan(raw_anchors)));
       } else {
         CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
-        RET_CHECK_CALL(
+        MP_RETURN_IF_ERROR(
             CopyBuffer(input_tensors[2], gpu_data_->raw_anchors_buffer));
       }
       anchors_init_ = true;
@@ -425,23 +426,24 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
     // Run shaders.
     // Decode boxes.
-    RET_CHECK_CALL(gpu_data_->decoded_boxes_buffer.BindToIndex(0));
-    RET_CHECK_CALL(gpu_data_->raw_boxes_buffer.BindToIndex(1));
-    RET_CHECK_CALL(gpu_data_->raw_anchors_buffer.BindToIndex(2));
+    MP_RETURN_IF_ERROR(gpu_data_->decoded_boxes_buffer.BindToIndex(0));
+    MP_RETURN_IF_ERROR(gpu_data_->raw_boxes_buffer.BindToIndex(1));
+    MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.BindToIndex(2));
     const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1};
-    RET_CHECK_CALL(gpu_data_->decode_program.Dispatch(decode_workgroups));
+    MP_RETURN_IF_ERROR(gpu_data_->decode_program.Dispatch(decode_workgroups));

     // Score boxes.
-    RET_CHECK_CALL(gpu_data_->scored_boxes_buffer.BindToIndex(0));
-    RET_CHECK_CALL(gpu_data_->raw_scores_buffer.BindToIndex(1));
+    MP_RETURN_IF_ERROR(gpu_data_->scored_boxes_buffer.BindToIndex(0));
+    MP_RETURN_IF_ERROR(gpu_data_->raw_scores_buffer.BindToIndex(1));
     const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1};
-    RET_CHECK_CALL(gpu_data_->score_program.Dispatch(score_workgroups));
+    MP_RETURN_IF_ERROR(gpu_data_->score_program.Dispatch(score_workgroups));

     // Copy decoded boxes from GPU to CPU.
     std::vector<float> boxes(num_boxes_ * num_coords_);
-    RET_CHECK_CALL(gpu_data_->decoded_boxes_buffer.Read(absl::MakeSpan(boxes)));
+    MP_RETURN_IF_ERROR(
+        gpu_data_->decoded_boxes_buffer.Read(absl::MakeSpan(boxes)));
     std::vector<float> score_class_id_pairs(num_boxes_ * 2);
-    RET_CHECK_CALL(gpu_data_->scored_boxes_buffer.Read(
+    MP_RETURN_IF_ERROR(gpu_data_->scored_boxes_buffer.Read(
         absl::MakeSpan(score_class_id_pairs)));

     // TODO: b/138851969. Is it possible to output a float vector
@@ -802,20 +804,20 @@ void main() {

   // Shader program
   GlShader decode_shader;
-  RET_CHECK_CALL(
+  MP_RETURN_IF_ERROR(
       GlShader::CompileShader(GL_COMPUTE_SHADER, decode_src, &decode_shader));
-  RET_CHECK_CALL(GpuProgram::CreateWithShader(decode_shader,
-                                              &gpu_data_->decode_program));
+  MP_RETURN_IF_ERROR(GpuProgram::CreateWithShader(
+      decode_shader, &gpu_data_->decode_program));
   // Outputs
   size_t decoded_boxes_length = num_boxes_ * num_coords_;
-  RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+  MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
       decoded_boxes_length, &gpu_data_->decoded_boxes_buffer));
   // Inputs
   size_t raw_boxes_length = num_boxes_ * num_coords_;
-  RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+  MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
       raw_boxes_length, &gpu_data_->raw_boxes_buffer));
   size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox;
-  RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+  MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
       raw_anchors_length, &gpu_data_->raw_anchors_buffer));

   // Parameters
   glUseProgram(gpu_data_->decode_program.id());
@@ -896,17 +898,17 @@ void main() {

   // Shader program
   GlShader score_shader;
-  RET_CHECK_CALL(
+  MP_RETURN_IF_ERROR(
       GlShader::CompileShader(GL_COMPUTE_SHADER, score_src, &score_shader));
-  RET_CHECK_CALL(
+  MP_RETURN_IF_ERROR(
       GpuProgram::CreateWithShader(score_shader, &gpu_data_->score_program));
   // Outputs
   size_t scored_boxes_length = num_boxes_ * 2;  // score, class
-  RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+  MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
       scored_boxes_length, &gpu_data_->scored_boxes_buffer));
   // Inputs
   size_t raw_scores_length = num_boxes_ * num_classes_;
-  RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+  MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
       raw_scores_length, &gpu_data_->raw_scores_buffer));

   return ::mediapipe::OkStatus();


@@ -17,7 +17,6 @@
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.pb.h"
-#include "mediapipe/calculators/tflite/util.h"
 #include "mediapipe/framework/calculator_context.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image_frame.h"
@@ -400,7 +399,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);

     // Create initial working mask texture.
     ::tflite::gpu::gl::GlTexture small_mask_texture;
-    RET_CHECK_CALL(CreateReadWriteRgbaImageTexture(
+    MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture(
         tflite::gpu::DataType::UINT8,  // GL_RGBA8
         {tensor_width_, tensor_height_}, &small_mask_texture));
@@ -410,7 +409,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
             : mediapipe::GlTexture();

     // Copy input tensor.
-    RET_CHECK_CALL(CopyBuffer(input_tensors[0], *tensor_buffer_));
+    MP_RETURN_IF_ERROR(CopyBuffer(input_tensors[0], *tensor_buffer_));

     // Run shader, process mask tensor.
     // Run softmax over tensor output and blend with previous mask.
@@ -418,18 +417,18 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
       const int output_index = 0;
       glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0,
                          GL_WRITE_ONLY, GL_RGBA8);
-      RET_CHECK_CALL(tensor_buffer_->BindToIndex(2));
+      MP_RETURN_IF_ERROR(tensor_buffer_->BindToIndex(2));
       const tflite::gpu::uint3 workgroups = {
           NumGroups(tensor_width_, kWorkgroupSize),
           NumGroups(tensor_height_, kWorkgroupSize), 1};
       if (!has_prev_mask) {
-        RET_CHECK_CALL(mask_program_no_prev_->Dispatch(workgroups));
+        MP_RETURN_IF_ERROR(mask_program_no_prev_->Dispatch(workgroups));
       } else {
         glActiveTexture(GL_TEXTURE1);
         glBindTexture(GL_TEXTURE_2D, input_mask_texture.name());
-        RET_CHECK_CALL(mask_program_with_prev_->Dispatch(workgroups));
+        MP_RETURN_IF_ERROR(mask_program_with_prev_->Dispatch(workgroups));
         glActiveTexture(GL_TEXTURE1);
         glBindTexture(GL_TEXTURE_2D, 0);
       }
@@ -622,22 +621,22 @@ void main() {

   // Shader programs.
   GlShader shader_without_previous;
-  RET_CHECK_CALL(GlShader::CompileShader(
+  MP_RETURN_IF_ERROR(GlShader::CompileShader(
       GL_COMPUTE_SHADER, shader_src_no_previous, &shader_without_previous));
   mask_program_no_prev_ = absl::make_unique<GlProgram>();
-  RET_CHECK_CALL(GlProgram::CreateWithShader(shader_without_previous,
-                                             mask_program_no_prev_.get()));
+  MP_RETURN_IF_ERROR(GlProgram::CreateWithShader(
+      shader_without_previous, mask_program_no_prev_.get()));
   GlShader shader_with_previous;
-  RET_CHECK_CALL(GlShader::CompileShader(
+  MP_RETURN_IF_ERROR(GlShader::CompileShader(
       GL_COMPUTE_SHADER, shader_src_with_previous, &shader_with_previous));
   mask_program_with_prev_ = absl::make_unique<GlProgram>();
-  RET_CHECK_CALL(GlProgram::CreateWithShader(shader_with_previous,
-                                             mask_program_with_prev_.get()));
+  MP_RETURN_IF_ERROR(GlProgram::CreateWithShader(
+      shader_with_previous, mask_program_with_prev_.get()));

   // Buffer storage for input tensor.
   size_t tensor_length = tensor_width_ * tensor_height_ * tensor_channels_;
   tensor_buffer_ = absl::make_unique<GlBuffer>();
-  RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
+  MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
       tensor_length, tensor_buffer_.get()));

   // Parameters.


@@ -1,25 +0,0 @@
-// Copyright 2019 The MediaPipe Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
-#define MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
-
-#define RET_CHECK_CALL(call)                               \
-  do {                                                     \
-    const auto status = (call);                            \
-    if (ABSL_PREDICT_FALSE(!status.ok()))                  \
-      return ::mediapipe::InternalError(status.message()); \
-  } while (0);
-
-#endif  // MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_


@@ -700,6 +700,8 @@ mediapipe_cc_proto_library(
     deps = [":rect_to_render_data_calculator_proto"],
 )

+# TODO: What is that one for?
+
 mediapipe_cc_proto_library(
     name = "detections_to_render_data_calculator_cc_proto",
     srcs = ["detections_to_render_data_calculator.proto"],


@@ -160,6 +160,8 @@ class AnnotationOverlayCalculator : public CalculatorBase {
   GLuint image_mat_tex_ = 0;  // Overlay drawing image for GPU.
   int width_ = 0;
   int height_ = 0;
+  int width_gpu_ = 0;  // Size of overlay drawing texture.
+  int height_gpu_ = 0;
 #endif  // MEDIAPIPE_DISABLE_GPU
 };
 REGISTER_CALCULATOR(AnnotationOverlayCalculator);
@@ -389,7 +391,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
     glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

     glBindTexture(GL_TEXTURE_2D, image_mat_tex_);
-    glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width_, height_, GL_RGB,
+    glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width_gpu_, height_gpu_, GL_RGB,
                     GL_UNSIGNED_BYTE, overlay_image);
     glBindTexture(GL_TEXTURE_2D, 0);
   }
@@ -492,12 +494,12 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
     if (format != mediapipe::ImageFormat::SRGBA &&
         format != mediapipe::ImageFormat::SRGB)
       RET_CHECK_FAIL() << "Unsupported GPU input format: " << format;
-    image_mat = absl::make_unique<cv::Mat>(height_, width_, CV_8UC3);
+    image_mat = absl::make_unique<cv::Mat>(height_gpu_, width_gpu_, CV_8UC3);
     memset(image_mat->data, kAnnotationBackgroundColor,
-           height_ * width_ * image_mat->elemSize());
+           height_gpu_ * width_gpu_ * image_mat->elemSize());
   } else {
     image_mat = absl::make_unique<cv::Mat>(
-        options_.canvas_height_px(), options_.canvas_width_px(), CV_8UC3,
+        height_gpu_, width_gpu_, CV_8UC3,
         cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(),
                    options_.canvas_color().b()));
   }
@@ -632,18 +634,28 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
                        kAnnotationBackgroundColor / 255.0,
                        kAnnotationBackgroundColor / 255.0);

-  // Init texture for opencv rendered frame.
-  const auto& input_frame =
-      cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
   // Ensure GPU texture is divisible by 4. See b/138751944 for more info.
-  width_ =
-      RoundUp(input_frame.width(), ImageFrame::kGlDefaultAlignmentBoundary);
-  height_ =
-      RoundUp(input_frame.height(), ImageFrame::kGlDefaultAlignmentBoundary);
+  const float alignment = ImageFrame::kGlDefaultAlignmentBoundary;
+  const float scale_factor = options_.gpu_scale_factor();
+  if (image_frame_available_) {
+    const auto& input_frame =
+        cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
+    width_ = RoundUp(input_frame.width(), alignment);
+    height_ = RoundUp(input_frame.height(), alignment);
+  } else {
+    width_ = RoundUp(options_.canvas_width_px(), alignment);
+    height_ = RoundUp(options_.canvas_height_px(), alignment);
+  }
+  width_gpu_ = RoundUp(width_ * scale_factor, alignment);
+  height_gpu_ = RoundUp(height_ * scale_factor, alignment);

+  // Init texture for opencv rendered frame.
   {
     glGenTextures(1, &image_mat_tex_);
     glBindTexture(GL_TEXTURE_2D, image_mat_tex_);
-    glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB8, width_, height_, 0, GL_RGB,
+    // TODO
+    // OpenCV only renders to RGB images, not RGBA. Ideally this should be RGBA.
+    glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB8, width_gpu_, height_gpu_, 0, GL_RGB,
                  GL_UNSIGNED_BYTE, nullptr);
     glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
     glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);


@@ -45,4 +45,12 @@ message AnnotationOverlayCalculatorOptions {
   // origin. (Historically, OpenGL uses bottom left origin, but most MediaPipe
   // examples expect textures to have top-left origin.)
   optional bool gpu_uses_top_left_origin = 6 [default = true];
+
+  // Scale factor for intermediate image for GPU rendering.
+  // This can be used to speed up annotation by drawing the annotation on an
+  // intermediate image with a reduced scale, e.g. 0.5 (of the input image width
+  // and height), before resizing and overlaying it on top of the input image.
+  // Should only be used if *all* render data uses normalized coordinates
+  // (or absolute coordinates are updated to scale accordingly).
+  optional float gpu_scale_factor = 7 [default = 1.0];
 }


@@ -235,9 +235,10 @@ void DetectionsToRenderDataCalculator::AddLabels(
     const Detection& detection,
     const DetectionsToRenderDataCalculatorOptions& options,
     float text_line_height, RenderData* render_data) {
-  CHECK(detection.label().empty() || detection.label_id().empty())
-      << "Either std::string or integer labels must be used for detection "
-         "but not both at the same time.";
+  CHECK(detection.label().empty() || detection.label_id().empty() ||
+        detection.label_size() == detection.label_id_size())
+      << "String or integer labels should be of same size. Or only one of them "
+         "is present.";
   const auto num_labels =
       std::max(detection.label_size(), detection.label_id_size());
   CHECK_EQ(detection.score_size(), num_labels)


@@ -316,6 +316,7 @@ cc_library(
         "//mediapipe/util/tracking",
         "//mediapipe/util/tracking:box_tracker",
         "//mediapipe/util/tracking:tracking_visualization_utilities",
+        "@com_google_absl//absl/container:node_hash_set",
         "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,


@@ -18,6 +18,7 @@
 #include <unordered_map>
 #include <unordered_set>

+#include "absl/container/node_hash_set.h"
 #include "absl/strings/numbers.h"
 #include "mediapipe/calculators/video/box_tracker_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -193,12 +194,12 @@ class BoxTrackerCalculator : public CalculatorBase {
   TimedBoxProtoList initial_pos_;

   // Keeps tracks boxes that have already been initialized.
-  std::unordered_set<int> initialized_ids_;
+  absl::node_hash_set<int> initialized_ids_;

   // Non empty for batch mode tracking.
   std::string cache_dir_;

   // Ids to be tracked in batch_mode.
-  std::unordered_set<int> batch_track_ids_;
+  absl::node_hash_set<int> batch_track_ids_;

   int frame_num_ = 0;


@@ -1 +1 @@
-This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev)for details.
+This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev) for details.


@@ -43,19 +43,23 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
     inputSidePackets.put(INPUT_NUM_FACES_SIDE_PACKET_NAME, packetCreator.createInt32(NUM_FACES));
     processor.setInputSidePackets(inputSidePackets);

-    processor.addPacketCallback(
-        OUTPUT_LANDMARKS_STREAM_NAME,
-        (packet) -> {
-          Log.d(TAG, "Received multi face landmarks packet.");
-          List<NormalizedLandmarkList> multiFaceLandmarks =
-              PacketGetter.getProtoVector(packet, NormalizedLandmarkList.parser());
-          Log.d(
-              TAG,
-              "[TS:"
-                  + packet.getTimestamp()
-                  + "] "
-                  + getMultiFaceLandmarksDebugString(multiFaceLandmarks));
-        });
+    // To show verbose logging, run:
+    // adb shell setprop log.tag.MainActivity VERBOSE
+    if (Log.isLoggable(TAG, Log.VERBOSE)) {
+      processor.addPacketCallback(
+          OUTPUT_LANDMARKS_STREAM_NAME,
+          (packet) -> {
+            Log.v(TAG, "Received multi face landmarks packet.");
+            List<NormalizedLandmarkList> multiFaceLandmarks =
+                PacketGetter.getProtoVector(packet, NormalizedLandmarkList.parser());
+            Log.v(
+                TAG,
+                "[TS:"
+                    + packet.getTimestamp()
+                    + "] "
+                    + getMultiFaceLandmarksDebugString(multiFaceLandmarks));
+          });
+    }
   }

   private static String getMultiFaceLandmarksDebugString(


@@ -43,29 +43,33 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
           }
         });

-    processor.addPacketCallback(
-        OUTPUT_LANDMARKS_STREAM_NAME,
-        (packet) -> {
-          byte[] landmarksRaw = PacketGetter.getProtoBytes(packet);
-          try {
-            NormalizedLandmarkList landmarks = NormalizedLandmarkList.parseFrom(landmarksRaw);
-            if (landmarks == null) {
-              Log.d(TAG, "[TS:" + packet.getTimestamp() + "] No hand landmarks.");
-              return;
-            }
-            // Note: If hand_presence is false, these landmarks are useless.
-            Log.d(
-                TAG,
-                "[TS:"
-                    + packet.getTimestamp()
-                    + "] #Landmarks for hand: "
-                    + landmarks.getLandmarkCount());
-            Log.d(TAG, getLandmarksDebugString(landmarks));
-          } catch (InvalidProtocolBufferException e) {
-            Log.e(TAG, "Couldn't Exception received - " + e);
-            return;
-          }
-        });
+    // To show verbose logging, run:
+    // adb shell setprop log.tag.MainActivity VERBOSE
+    if (Log.isLoggable(TAG, Log.VERBOSE)) {
+      processor.addPacketCallback(
+          OUTPUT_LANDMARKS_STREAM_NAME,
+          (packet) -> {
+            byte[] landmarksRaw = PacketGetter.getProtoBytes(packet);
+            try {
+              NormalizedLandmarkList landmarks = NormalizedLandmarkList.parseFrom(landmarksRaw);
+              if (landmarks == null) {
+                Log.v(TAG, "[TS:" + packet.getTimestamp() + "] No hand landmarks.");
+                return;
+              }
+              // Note: If hand_presence is false, these landmarks are useless.
+              Log.v(
+                  TAG,
+                  "[TS:"
+                      + packet.getTimestamp()
+                      + "] #Landmarks for hand: "
+                      + landmarks.getLandmarkCount());
+              Log.v(TAG, getLandmarksDebugString(landmarks));
+            } catch (InvalidProtocolBufferException e) {
+              Log.e(TAG, "Couldn't Exception received - " + e);
+              return;
+            }
+          });
+    }
   }

   private static String getLandmarksDebugString(NormalizedLandmarkList landmarks) {

View File

@@ -31,19 +31,23 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
   protected void onCreate(Bundle savedInstanceState) {
     super.onCreate(savedInstanceState);
-    processor.addPacketCallback(
-        OUTPUT_LANDMARKS_STREAM_NAME,
-        (packet) -> {
-          Log.d(TAG, "Received multi-hand landmarks packet.");
-          List<NormalizedLandmarkList> multiHandLandmarks =
-              PacketGetter.getProtoVector(packet, NormalizedLandmarkList.parser());
-          Log.d(
-              TAG,
-              "[TS:"
-                  + packet.getTimestamp()
-                  + "] "
-                  + getMultiHandLandmarksDebugString(multiHandLandmarks));
-        });
+    // To show verbose logging, run:
+    //   adb shell setprop log.tag.MainActivity VERBOSE
+    if (Log.isLoggable(TAG, Log.VERBOSE)) {
+      processor.addPacketCallback(
+          OUTPUT_LANDMARKS_STREAM_NAME,
+          (packet) -> {
+            Log.v(TAG, "Received multi-hand landmarks packet.");
+            List<NormalizedLandmarkList> multiHandLandmarks =
+                PacketGetter.getProtoVector(packet, NormalizedLandmarkList.parser());
+            Log.v(
+                TAG,
+                "[TS:"
+                    + packet.getTimestamp()
+                    + "] "
+                    + getMultiHandLandmarksDebugString(multiHandLandmarks));
+          });
+    }
   }

   private String getMultiHandLandmarksDebugString(List<NormalizedLandmarkList> multiHandLandmarks) {

View File

@@ -324,6 +324,19 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
   int path_offset_y;
   MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y));

+  // Prevent box from extending beyond the image after camera smoothing.
+  if (path_offset_y - ceil(path_height / 2.0) < 0) {
+    path_offset_y = ceil(path_height / 2.0);
+  } else if (path_offset_y + ceil(path_height / 2.0) > frame_height_) {
+    path_offset_y = frame_height_ - ceil(path_height / 2.0);
+  }
+  int path_width = path_height * target_aspect_;
+  if (path_offset_x - ceil(path_width / 2.0) < 0) {
+    path_offset_x = ceil(path_width / 2.0);
+  } else if (path_offset_x + ceil(path_width / 2.0) > frame_width_) {
+    path_offset_x = frame_width_ - ceil(path_width / 2.0);
+  }
+
   // Convert to top/bottom borders to remove.
   int path_top = path_offset_y - path_height / 2;
   int path_bottom = frame_height_ - (path_offset_y + path_height / 2);
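The clamping added in this hunk can be sketched outside the calculator. This is a minimal illustration, not MediaPipe API: `clamp_center` and its argument names are invented here, and the same ceil-based half-extent logic is applied to one axis at a time.

```python
import math

def clamp_center(center, extent, frame_size):
    # Keep [center - extent/2, center + extent/2] inside [0, frame_size],
    # mirroring the ceil-based bounds checks in the hunk above.
    half = math.ceil(extent / 2.0)
    if center - half < 0:
        return half
    if center + half > frame_size:
        return frame_size - half
    return center

print(clamp_center(10, 40, 1000))   # near the top edge: pushed down to 20
print(clamp_center(995, 40, 1000))  # near the bottom edge: pulled up to 980
print(clamp_center(500, 40, 1000))  # interior box: unchanged at 500
```

The calculator runs this once for the y axis (with `path_height`) and once for the x axis (with `path_width`).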

View File

@@ -344,6 +344,28 @@ TEST(ContentZoomingCalculatorTest, ZoomTestPairSize) {
   CheckBorder(static_features, 1000, 1000, 495, 395);
 }

+TEST(ContentZoomingCalculatorTest, ZoomTestNearOutsideBorder) {
+  auto runner = ::absl::make_unique<CalculatorRunner>(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
+  AddDetection(cv::Rect_<float>(.95, .95, .05, .05), 0, runner.get());
+  AddDetection(cv::Rect_<float>(.9, .9, .1, .1), 1000000, runner.get());
+  MP_ASSERT_OK(runner->Run());
+  CheckCropRect(972, 972, 55, 55, 0,
+                runner->Outputs().Tag("CROP_RECT").packets);
+  CheckCropRect(958, 958, 83, 83, 1,
+                runner->Outputs().Tag("CROP_RECT").packets);
+}
+
+TEST(ContentZoomingCalculatorTest, ZoomTestNearInsideBorder) {
+  auto runner = ::absl::make_unique<CalculatorRunner>(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
+  AddDetection(cv::Rect_<float>(0, 0, .05, .05), 0, runner.get());
+  AddDetection(cv::Rect_<float>(0, 0, .1, .1), 1000000, runner.get());
+  MP_ASSERT_OK(runner->Run());
+  CheckCropRect(28, 28, 55, 55, 0, runner->Outputs().Tag("CROP_RECT").packets);
+  CheckCropRect(42, 42, 83, 83, 1, runner->Outputs().Tag("CROP_RECT").packets);
+}
+
 }  // namespace
 }  // namespace autoflip

View File

@@ -10,6 +10,10 @@ namespace autoflip {
   current_time_ = time_us;
   initialized_ = true;
   current_velocity_deg_per_s_ = 0;
+  RET_CHECK_GT(pixels_per_degree_, 0)
+      << "pixels_per_degree must be larger than 0.";
+  RET_CHECK_GE(options_.min_motion_to_reframe(), options_.reframe_window())
+      << "Reframe window cannot exceed min_motion_to_reframe.";
   return ::mediapipe::OkStatus();
 }

@@ -22,6 +26,14 @@ namespace autoflip {
   if (abs(delta_degs) < options_.min_motion_to_reframe()) {
     position = current_position_px_;
     delta_degs = 0;
+  } else if (delta_degs > 0) {
+    // Apply new position, less the reframe window size.
+    position = position - pixels_per_degree_ * options_.reframe_window();
+    delta_degs = (position - current_position_px_) / pixels_per_degree_;
+  } else {
+    // Apply new position, plus the reframe window size.
+    position = position + pixels_per_degree_ * options_.reframe_window();
+    delta_degs = (position - current_position_px_) / pixels_per_degree_;
   }

   // Time and position updates.
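The reframe-window behavior added above can be summarized in a short Python sketch. The function and argument names are illustrative, not the solver's API; only the dead-zone check and the stop-short offset from the hunk are modeled.

```python
def reframed_position(current_px, target_px, pixels_per_degree,
                      min_motion_deg, reframe_window_deg):
    # Motion below min_motion_deg is ignored; otherwise the camera moves
    # toward the target but stops reframe_window_deg short of it.
    delta_deg = (target_px - current_px) / pixels_per_degree
    if abs(delta_deg) < min_motion_deg:
        return current_px
    if delta_deg > 0:
        return target_px - pixels_per_degree * reframe_window_deg
    return target_px + pixels_per_degree * reframe_window_deg

print(reframed_position(500, 510, 16.6, 1.0, 0.75))  # below threshold: stays at 500
print(reframed_position(500, 520, 16.6, 1.0, 0.75))  # moves to 520 - 16.6 * 0.75
```

Stopping short of the target reduces total camera travel when the subject oscillates around the reframe threshold.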

View File

@@ -10,4 +10,9 @@ message KinematicOptions {
   optional double max_velocity = 2 [default = 18];
   // Min motion (in degrees) to react in pixels.
   optional float min_motion_to_reframe = 3 [default = 1.8];
+  // When motion exceeds min_motion_to_reframe, move the camera to within this
+  // distance (in degrees) of the target. Setting this value non-zero reduces
+  // the total reframe distance on average. The value cannot exceed
+  // min_motion_to_reframe.
+  optional float reframe_window = 4 [default = 0];
 }

View File

@@ -27,6 +27,12 @@ namespace mediapipe {
 namespace autoflip {
 namespace {

+TEST(KinematicPathSolverTest, FailZeroPixelsPerDegree) {
+  KinematicOptions options;
+  KinematicPathSolver solver(options, 0, 1000, 0);
+  EXPECT_FALSE(solver.AddObservation(500, kMicroSecInSec * 0).ok());
+}
+
 TEST(KinematicPathSolverTest, FailNotInitializedState) {
   KinematicOptions options;
   KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView);

@@ -109,6 +115,38 @@ TEST(KinematicPathSolverTest, PassEnoughMotionSmallImg) {
   EXPECT_EQ(state, 410);
 }

+TEST(KinematicPathSolverTest, FailReframeWindowSetting) {
+  KinematicOptions options;
+  // Set min motion to 1 deg.
+  options.set_min_motion_to_reframe(1.0);
+  options.set_update_rate(1);
+  options.set_max_velocity(1000);
+  // Set the reframe window to 1.1, which exceeds min_motion_to_reframe.
+  options.set_reframe_window(1.1);
+  // Set degrees / pixel to 16.6.
+  KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView);
+  ASSERT_FALSE(solver.AddObservation(500, kMicroSecInSec * 0).ok());
+}
+
+TEST(KinematicPathSolverTest, PassReframeWindow) {
+  KinematicOptions options;
+  // Set min motion to 1 deg.
+  options.set_min_motion_to_reframe(1.0);
+  options.set_update_rate(1);
+  options.set_max_velocity(1000);
+  // Set reframe window size to .75 for test.
+  options.set_reframe_window(0.75);
+  // Set degrees / pixel to 16.6.
+  KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView);
+  int state;
+  MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0));
+  // Move target by 20px / 16.6 = 1.2 deg.
+  MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1));
+  MP_ASSERT_OK(solver.GetState(&state));
+  // Expect cam to move (1.2 - .75) deg * 16.6 = 7.47px; 500 + 7 = 507.
+  EXPECT_EQ(state, 507);
+}
+
 TEST(KinematicPathSolverTest, PassUpdateRate) {
   KinematicOptions options;
   options.set_min_motion_to_reframe(1.0);
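The expected value in `PassReframeWindow` can be checked by hand. This sketch assumes the test's `kWidthFieldOfView` constant is 60 degrees, which is consistent with the roughly 16.6 px/deg figure quoted in the test comments:

```python
ppd = 1000.0 / 60                    # assumed pixels per degree, ~16.67
delta_deg = 20 / ppd                 # a 20 px move is ~1.2 deg, above the 1.0 deg threshold
moved_px = (delta_deg - 0.75) * ppd  # stop 0.75 deg short of the target, ~7.5 px
print(int(500 + moved_px))           # integer state reported by the solver
```

The truncated integer lands on 507, matching the test's `EXPECT_EQ(state, 507)`.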

View File

@@ -30,7 +30,7 @@ SECONDS_TO_MICROSECONDS = 1000000

 def bytes23(string):
   """Creates a bytes string in either Python 2 or 3."""
   if sys.version_info >= (3, 0):
     return bytes(string, 'utf8')
   else:

View File

@ -0,0 +1,26 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration helper for iOS app bundle ids and provisioning profiles.
"""
BUNDLE_ID_PREFIX = "*SEE_IOS_INSTRUCTIONS*.mediapipe.examples"

# Look for a provisioning profile in the example's directory first,
# otherwise look for a common one.
def example_provisioning():
    local_profile = native.glob(["provisioning_profile.mobileprovision"])
    if local_profile:
        return local_profile[0]
    return "//mediapipe/examples/ios:provisioning_profile"
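The lookup order `example_provisioning` implements — a per-example profile first, then a shared fallback — can be sketched in plain Python. The function name and paths below are illustrative, not part of the build setup:

```python
import os

def pick_provisioning_profile(example_dir, shared_profile):
    # Prefer a profile sitting next to the example, as native.glob does in
    # the bzl helper; otherwise fall back to the shared profile target.
    local = os.path.join(example_dir, "provisioning_profile.mobileprovision")
    return local if os.path.exists(local) else shared_profile

print(pick_provisioning_profile("/tmp/definitely-missing-dir-xyz",
                                "shared.mobileprovision"))
```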

View File

@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-licenses(["notice"])  # Apache 2.0
-
-MIN_IOS_VERSION = "10.0"
-
 load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)
+
+licenses(["notice"])  # Apache 2.0
+
+MIN_IOS_VERSION = "10.0"

 alias(
     name = "edgedetectiongpu",

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "EdgeDetectionGpuApp",
-    bundle_id = "com.google.mediapipe.EdgeDetectionGpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".EdgeDetectionGpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [":EdgeDetectionGpuAppLibrary"],
 )

View File

@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-licenses(["notice"])  # Apache 2.0
-
-MIN_IOS_VERSION = "10.0"
-
 load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)
+
+licenses(["notice"])  # Apache 2.0
+
+MIN_IOS_VERSION = "10.0"

 alias(
     name = "facedetectioncpu",

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "FaceDetectionCpuApp",
-    bundle_id = "com.google.mediapipe.FaceDetectionCpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".FaceDetectionCpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [
         ":FaceDetectionCpuAppLibrary",
         "@ios_opencv//:OpencvFramework",

View File

@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-licenses(["notice"])  # Apache 2.0
-
-MIN_IOS_VERSION = "10.0"
-
 load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)
+
+licenses(["notice"])  # Apache 2.0
+
+MIN_IOS_VERSION = "10.0"

 alias(
     name = "facedetectiongpu",

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "FaceDetectionGpuApp",
-    bundle_id = "com.google.mediapipe.FaceDetectionGpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".FaceDetectionGpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [
         ":FaceDetectionGpuAppLibrary",
         "@ios_opencv//:OpencvFramework",

View File

@@ -16,6 +16,11 @@ load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)

 licenses(["notice"])  # Apache 2.0

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "FaceMeshGpuApp",
-    bundle_id = "com.google.mediapipe.FaceMeshGpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".FaceMeshGpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [
         ":FaceMeshGpuAppLibrary",
         "@ios_opencv//:OpencvFramework",

View File

@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-licenses(["notice"])  # Apache 2.0
-
-MIN_IOS_VERSION = "10.0"
-
 load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)
+
+licenses(["notice"])  # Apache 2.0
+
+MIN_IOS_VERSION = "10.0"

 alias(
     name = "handdetectiongpu",

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "HandDetectionGpuApp",
-    bundle_id = "com.google.mediapipe.HandDetectionGpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".HandDetectionGpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [
         ":HandDetectionGpuAppLibrary",
         "@ios_opencv//:OpencvFramework",

View File

@@ -16,6 +16,11 @@ load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)

 licenses(["notice"])  # Apache 2.0

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "HandTrackingGpuApp",
-    bundle_id = "com.google.mediapipe.HandTrackingGpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".HandTrackingGpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [
         ":HandTrackingGpuAppLibrary",
         "@ios_opencv//:OpencvFramework",

View File

@ -0,0 +1,158 @@
#!/usr/bin/env python3
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""This script is used to set up automatic provisioning for iOS examples.
It scans the provisioning profiles used by Xcode, looking for one matching the
application identifier for each example app. If found, it symlinks the profile
in the appropriate location for Bazel to find it.
It also checks whether the bundle_id.bzl file has a placeholder bundle ID, and
replaces it with a unique ID if so.
"""
import os
import plistlib
import re
import subprocess
import uuid
# This script is meant to be located in the MediaPipe iOS examples directory
# root. The logic below will have to be changed if the directory structure is
# reorganized.
examples_ios = os.path.dirname(os.path.realpath(__file__))
example_names = {
f for f in os.listdir(examples_ios)
if os.path.isdir(os.path.join(examples_ios, f))
}
def configure_bundle_id_prefix(
bundle_id_bzl=os.path.join(examples_ios, "bundle_id.bzl")) -> str:
"""Configures the bundle id prefix to use.
Gets the bundle id prefix in use from bundle_id.bzl; sets up a unique
prefix if not already set.
Args:
bundle_id_bzl: Path to the bzl file where the bundle id is stored.
Returns:
The bundle id prefix to use.
Raises:
Exception: If parsing of bundle_id.bzl fails.
"""
bundle_id_re = re.compile(
r'^BUNDLE_ID_PREFIX\s*=\s*"(.*)"', flags=re.MULTILINE)
with open(bundle_id_bzl, "r") as f:
contents = f.read()
match = bundle_id_re.search(contents)
if not match:
raise Exception("could not find BUNDLE_ID_PREFIX")
bundle_id_prefix = match.group(1)
# The default value contains a *, which is an invalid character in bundle IDs.
if "*" in bundle_id_prefix:
bundle_id_prefix = str(uuid.uuid4()) + ".mediapipe.examples"
contents = contents[:match.start(1)] + bundle_id_prefix + contents[match
.end(1):]
with open(bundle_id_bzl, "w") as f:
f.write(contents)
print("Set up a unique bundle ID prefix: " + bundle_id_prefix)
return bundle_id_prefix
def get_app_id(profile_path) -> str:
try:
plist = subprocess.check_output(
["security", "cms", "-D", "-i", profile_path])
profile = plistlib.loads(plist)
return profile["Entitlements"]["application-identifier"]
except Exception: # pylint: disable=broad-except
return None
def update_symlink(target_path, link_path):
if os.path.islink(link_path):
print(f" Removing existing symlink at {link_path}")
os.remove(link_path)
elif os.path.exists(link_path):
print(f" Unexpected existing file at {link_path}; skipping")
return
os.symlink(target_path, link_path)
print(f" Created symlink to {target_path} at {link_path}")
def process_profile(profile_path, our_app_id_re):
"""Processes one mobileprovision file.
Checks if its app ID matches one of our example apps, and symlinks it in the
appropriate location if so.
Args:
profile_path: Path to the mobileprovision file.
    our_app_id_re: Regular expression to extract the example name from one of
      our app ids.
"""
app_id = get_app_id(profile_path)
if not app_id:
print(f"Could not parse '{profile_path}', skipping")
return
match = our_app_id_re.match(app_id)
if not match:
return
app_name = match.group(1)
app_dir_name = app_name.lower()
if app_dir_name not in example_names:
    print(f"The app id '{app_id}' has our prefix, but does not seem to match " +
          "any of our examples. Skipping.")
return
print(f"Found profile for {app_name}")
link_path = os.path.join(examples_ios, app_dir_name,
"provisioning_profile.mobileprovision")
update_symlink(profile_path, link_path)
def main():
bundle_id_prefix = configure_bundle_id_prefix()
our_app_id_re = re.compile(r"[0-9A-Z]+\." + re.escape(bundle_id_prefix) +
r"\.(.*)")
profile_dir = os.path.expanduser(
"~/Library/MobileDevice/Provisioning Profiles")
if not os.path.isdir(profile_dir):
print(f"Could not find provisioning profiles directory at {profile_dir}")
return 2
print(
f"Looking for profiles for app ids with prefix '{bundle_id_prefix}' in '{profile_dir}'"
)
for name in os.listdir(profile_dir):
if not name.endswith(".mobileprovision"):
continue
profile_path = os.path.join(profile_dir, name)
process_profile(profile_path, our_app_id_re)
if __name__ == "__main__":
main()

View File

@@ -16,6 +16,11 @@ load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)

 licenses(["notice"])  # Apache 2.0

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "MultiHandTrackingGpuApp",
-    bundle_id = "com.google.mediapipe.MultiHandTrackingGpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".MultiHandTrackingGpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [
         ":MultiHandTrackingGpuAppLibrary",
         "@ios_opencv//:OpencvFramework",

View File

@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-licenses(["notice"])  # Apache 2.0
-
-MIN_IOS_VERSION = "10.0"
-
 load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)
+
+licenses(["notice"])  # Apache 2.0
+
+MIN_IOS_VERSION = "10.0"

 alias(
     name = "objectdetectioncpu",

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "ObjectDetectionCpuApp",
-    bundle_id = "com.google.mediapipe.ObjectDetectionCpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".ObjectDetectionCpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [
         ":ObjectDetectionCpuAppLibrary",
         "@ios_opencv//:OpencvFramework",

View File

@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-licenses(["notice"])  # Apache 2.0
-
-MIN_IOS_VERSION = "10.0"
-
 load(
     "@build_bazel_rules_apple//apple:ios.bzl",
     "ios_application",
 )
+load(
+    "//mediapipe/examples/ios:bundle_id.bzl",
+    "BUNDLE_ID_PREFIX",
+    "example_provisioning",
+)
+
+licenses(["notice"])  # Apache 2.0
+
+MIN_IOS_VERSION = "10.0"

 alias(
     name = "objectdetectiongpu",

@@ -28,14 +33,14 @@ alias(
 ios_application(
     name = "ObjectDetectionGpuApp",
-    bundle_id = "com.google.mediapipe.ObjectDetectionGpu",
+    bundle_id = BUNDLE_ID_PREFIX + ".ObjectDetectionGpu",
     families = [
         "iphone",
         "ipad",
     ],
     infoplists = ["Info.plist"],
     minimum_os_version = MIN_IOS_VERSION,
-    provisioning_profile = "//mediapipe/examples/ios:provisioning_profile",
+    provisioning_profile = example_provisioning(),
     deps = [
         ":ObjectDetectionGpuAppLibrary",
         "@ios_opencv//:OpencvFramework",

View File

@@ -112,7 +112,7 @@ mediapipe_proto_library(
 mediapipe_proto_library(
     name = "stream_handler_proto",
     srcs = ["stream_handler.proto"],
-    visibility = ["//mediapipe/framework:__subpackages__"],
+    visibility = [":mediapipe_internal"],
     deps = ["//mediapipe/framework:mediapipe_options_proto"],
 )

@@ -130,7 +130,7 @@ mediapipe_proto_library(
 mediapipe_proto_library(
     name = "thread_pool_executor_proto",
     srcs = ["thread_pool_executor.proto"],
-    visibility = ["//mediapipe/framework:__subpackages__"],
+    visibility = [":mediapipe_internal"],
     deps = ["//mediapipe/framework:mediapipe_options_proto"],
 )

@@ -923,6 +923,15 @@ cc_library(
     ],
 )

+# When --copt=-fno-rtti is set, MEDIAPIPE_HAS_RTTI is cleared in port.h.
+# To explicitly clear MEDIAPIPE_HAS_RTTI, compile with:
+#   bazel build --define=disable_rtti_and_exceptions=true
+config_setting(
+    name = "disable_rtti_and_exceptions",
+    define_values = {"disable_rtti_and_exceptions": "true"},
+    visibility = ["//visibility:public"],
+)
+
 cc_library(
     name = "port",
     hdrs = ["port.h"],

@@ -931,6 +940,11 @@ cc_library(
     }) + select({
         "//conditions:default": [],
         "//mediapipe/gpu:disable_gpu": ["MEDIAPIPE_DISABLE_GPU"],
+    }) + select({
+        "//conditions:default": [],
+        "//mediapipe/framework:disable_rtti_and_exceptions": [
+            "MEDIAPIPE_HAS_RTTI=0",
+        ],
     }),
     visibility = [
         "//mediapipe/framework:__subpackages__",

@@ -1167,6 +1181,7 @@ cc_test(
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/tool:status_util",
         "//mediapipe/framework/tool:tag_map_helper",
+        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )

View File

@@ -15,6 +15,7 @@
 #include "mediapipe/framework/calculator_base.h"

 // TODO: Move protos in another CL after the C++ code migration.
+#include "absl/container/flat_hash_set.h"
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/calculator_context.h"
 #include "mediapipe/framework/calculator_context_manager.h"

@@ -192,8 +193,8 @@ TEST(CalculatorTest, CreateByName) {
 // Tests registration of a calculator within a whitelisted namespace.
 TEST(CalculatorTest, CreateByNameWhitelisted) {
   // Reset the registration namespace whitelist.
-  *const_cast<std::unordered_set<std::string>*>(
-      &NamespaceWhitelist::TopNamespaces()) = std::unordered_set<std::string>{
+  *const_cast<absl::flat_hash_set<std::string>*>(
+      &NamespaceWhitelist::TopNamespaces()) = absl::flat_hash_set<std::string>{
           "mediapipe::test_ns::whitelisted_ns",
           "mediapipe",
       };

View File

@@ -375,7 +375,7 @@ TEST_F(CalculatorGraphEventLoopTest, TryToAddPacketToInputStream) {
                        this, std::placeholders::_1))},
       {"blocking_mutex", mutex_side_packet}}));

-  constexpr int kNumInputPackets = 2;
+  constexpr int kNumInputPackets = 20;
   constexpr int kMaxQueueSize = 1;

   // Lock the mutex so that the BlockingPassThroughCalculator cannot read any of

View File

@@ -1828,14 +1828,14 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) {
   status = graph->Initialize(config);
   EXPECT_THAT(status.message(),
-              testing::AllOf(testing::HasSubstr("StringStatusHandler"),
-                             // The problematic input side packet.
-                             testing::HasSubstr("generated_by_generator"),
-                             // Actual type.
-                             testing::HasSubstr("string"),
-                             // Expected type.
-                             testing::HasSubstr(
-                                 MediaPipeTypeStringOrDemangled<uint32>())));
+              testing::AllOf(
+                  testing::HasSubstr("StringStatusHandler"),
+                  // The problematic input side packet.
+                  testing::HasSubstr("generated_by_generator"),
+                  // Actual type.
+                  testing::HasSubstr(MediaPipeTypeStringOrDemangled<uint32>()),
+                  // Expected type.
+                  testing::HasSubstr("string")));
 }

 TEST(CalculatorGraph, GenerateInInitialize) {

View File

@@ -405,7 +405,7 @@ namespace {
 // Returns the Packet sent to an OutputSidePacket, or an empty packet
 // if none available.
 const Packet GetPacket(const OutputSidePacket& out) {
-  auto impl = dynamic_cast<const OutputSidePacketImpl*>(&out);
+  auto impl = static_cast<const OutputSidePacketImpl*>(&out);
   return (impl == nullptr) ? Packet() : impl->GetPacket();
 }

View File

@@ -209,6 +209,7 @@ cc_library(
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:statusor",
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/meta:type_traits",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",

View File

@@ -14,6 +14,8 @@
 #include "mediapipe/framework/deps/registration.h"
 
+#include "absl/container/flat_hash_set.h"
+
 namespace mediapipe {
 namespace {
@@ -34,9 +36,9 @@ inline size_t array_size(T (&arr)[SIZE]) {
 }  // namespace
 
 /*static*/
-const std::unordered_set<std::string>& NamespaceWhitelist::TopNamespaces() {
-  static std::unordered_set<std::string>* result =
-      new std::unordered_set<std::string>(
+const absl::flat_hash_set<std::string>& NamespaceWhitelist::TopNamespaces() {
+  static absl::flat_hash_set<std::string>* result =
+      new absl::flat_hash_set<std::string>(
           kTopNamespaces, kTopNamespaces + array_size(kTopNamespaces));
   return *result;
 }

View File

@@ -26,6 +26,7 @@
 #include "absl/base/macros.h"
 #include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/meta/type_traits.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
@@ -145,7 +146,7 @@ struct WrapStatusOr<::mediapipe::StatusOr<T>> {
 class NamespaceWhitelist {
  public:
-  static const std::unordered_set<std::string>& TopNamespaces();
+  static const absl::flat_hash_set<std::string>& TopNamespaces();
 };
 
 template <typename R, typename... Args>

View File

@@ -95,18 +95,23 @@ cc_library(
     hdrs = ["image_frame.h"],
     visibility = ["//visibility:public"],
     deps = [
-        "//mediapipe/framework:port",
         "//mediapipe/framework/formats:image_format_cc_proto",
-        "//mediapipe/framework/port:aligned_malloc_and_free",
-        "//mediapipe/framework/port:core_proto",
-        "//mediapipe/framework/port:integral_types",
-        "//mediapipe/framework/port:logging",
-        "//mediapipe/framework/port:source_location",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
-    ],
+        "//mediapipe/framework:port",
+        "//mediapipe/framework/port:aligned_malloc_and_free",
+        "//mediapipe/framework/port:core_proto",
+        "//mediapipe/framework/port:integral_types",
+        "//mediapipe/framework/port:source_location",
+        "//mediapipe/framework/port:logging",
+        "//mediapipe/framework/tool:type_util",
+    ] + select({
+        "//conditions:default": [
+        ],
+        "//mediapipe/framework:disable_rtti_and_exceptions": [],
+    }),
 )
 
 cc_library(

View File

@@ -42,6 +42,9 @@
 #include "mediapipe/framework/formats/image_format.pb.h"
 #include "mediapipe/framework/port.h"
 #include "mediapipe/framework/port/integral_types.h"
+#include "mediapipe/framework/tool/type_util.h"
+
+#define IMAGE_FRAME_RAW_IMAGE MEDIAPIPE_HAS_RTTI
 
 namespace mediapipe {

View File

@@ -407,7 +407,7 @@ StatusOr<std::vector<const proto_ns::MessageLite*>>
 ConvertToVectorOfProtoMessageLitePtrs(const T* data,
                                       /*is_proto_vector=*/std::false_type) {
   return ::mediapipe::InvalidArgumentError(absl::StrCat(
-      "The Packet stores \"", typeid(T).name(), "\"",
+      "The Packet stores \"", tool::TypeId<T>().name(), "\"",
       "which is not convertible to vector<proto_ns::MessageLite*>."));
 }

View File

@ -147,6 +147,9 @@ struct UnregisteredPairStruct {
}; };
MEDIAPIPE_REGISTER_TYPE(::mediapipe::RegisteredPairStruct, MEDIAPIPE_REGISTER_TYPE(::mediapipe::RegisteredPairStruct,
"::mediapipe::RegisteredPairStruct", nullptr, nullptr); "::mediapipe::RegisteredPairStruct", nullptr, nullptr);
MEDIAPIPE_REGISTER_TYPE(int, "int", nullptr, nullptr);
MEDIAPIPE_REGISTER_TYPE(float, "float", nullptr, nullptr);
constexpr bool kHaveUnregisteredTypeNames = MEDIAPIPE_HAS_RTTI;
TEST(PacketTest, TypeRegistrationDebugString) { TEST(PacketTest, TypeRegistrationDebugString) {
// Test registered type. // Test registered type.
@ -159,9 +162,13 @@ TEST(PacketTest, TypeRegistrationDebugString) {
// Unregistered type. // Unregistered type.
UnregisteredPairStruct u{"s", true}; UnregisteredPairStruct u{"s", true};
Packet packet2 = MakePacket<UnregisteredPairStruct>(u); Packet packet2 = MakePacket<UnregisteredPairStruct>(u);
std::string expected_type_name =
(kHaveUnregisteredTypeNames)
? "mediapipe::(anonymous namespace)::UnregisteredPairStruct"
: "<unknown>";
EXPECT_EQ(packet2.DebugString(), EXPECT_EQ(packet2.DebugString(),
"mediapipe::Packet with timestamp: Timestamp::Unset() and type: " "mediapipe::Packet with timestamp: Timestamp::Unset() and type: " +
"mediapipe::(anonymous namespace)::UnregisteredPairStruct"); expected_type_name);
} }
TEST(PacketTest, ReturnGenericProtobufMessage) { TEST(PacketTest, ReturnGenericProtobufMessage) {

View File

@@ -80,4 +80,17 @@
 #endif
 #endif
 
+#ifndef MEDIAPIPE_HAS_RTTI
+// Detect if RTTI is disabled in the compiler.
+#if defined(__clang__) && defined(__has_feature)
+#define MEDIAPIPE_HAS_RTTI __has_feature(cxx_rtti)
+#elif defined(__GNUC__) && !defined(__GXX_RTTI)
+#define MEDIAPIPE_HAS_RTTI 0
+#elif defined(_MSC_VER) && !defined(_CPPRTTI)
+#define MEDIAPIPE_HAS_RTTI 0
+#else
+#define MEDIAPIPE_HAS_RTTI 1
+#endif
+#endif  // MEDIAPIPE_HAS_RTTI
+
 #endif  // MEDIAPIPE_FRAMEWORK_PORT_H_

View File

@@ -307,11 +307,10 @@ cc_library(
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":core_proto",
         ":logging",
         "//mediapipe/framework:port",
-    ] + select({
-        "//conditions:default": ["@com_google_protobuf//:protobuf"],
-    }),
+    ],
 )
 
 cc_library(

View File

@@ -15,16 +15,21 @@
 #ifndef MEDIAPIPE_PORT_PARSE_TEXT_PROTO_H_
 #define MEDIAPIPE_PORT_PARSE_TEXT_PROTO_H_
 
-#include "google/protobuf/text_format.h"
+#include "mediapipe/framework/port/core_proto_inc.h"
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/proto_ns.h"
 
 namespace mediapipe {
 
+template <typename T>
+bool ParseTextProto(const std::string& input, T* proto) {
+  return proto_ns::TextFormat::ParseFromString(input, proto);
+}
+
 template <typename T>
 T ParseTextProtoOrDie(const std::string& input) {
   T result;
-  CHECK(google::protobuf::TextFormat::ParseFromString(input, &result));
+  CHECK(ParseTextProto(input, &result));
   return result;
 }

View File

@@ -261,6 +261,7 @@ cc_test(
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:threadpool",
+        "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
     ],

View File

@@ -16,6 +16,7 @@
 #include <functional>
 
+#include "absl/container/node_hash_map.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/time/clock.h"
 #include "absl/time/time.h"
@@ -122,7 +123,7 @@ absl::Duration time(const std::function<void()>& f) {
 // With bazel build -c opt, the ShardedMap reduces CPU time by 60%.
 TEST(ShardedMapTest, TestParallelAccess) {
   absl::Duration simple_time = time([] {
-    std::unordered_map<int64, int64> simple_map;
+    absl::node_hash_map<int64, int64> simple_map;
     TestParallelAccess(simple_map, 1);
   });
   absl::Duration safe_time = time([] {

View File

@@ -22,6 +22,10 @@ package mediapipe;
 
 import "mediapipe/framework/mediapipe_options.proto";
 
+option java_package = "com.google.mediapipe.proto";
+option java_outer_classname = "StreamHandlerProto";
+option objc_class_prefix = "MediaPipe";
+
 // Settings specifying an input stream handler.
 message InputStreamHandlerConfig {
   // Name of the registered input stream handler class.

View File

@@ -167,6 +167,7 @@ cc_library(
         "//mediapipe/framework:packet_generator_cc_proto",
         "//mediapipe/framework:packet_set",
         "//mediapipe/framework/port:any_proto",
+        "//mediapipe/framework/tool:type_util",
     ],
 )
@@ -371,6 +372,9 @@ cc_library(
     name = "type_util",
     hdrs = ["type_util.h"],
     visibility = ["//mediapipe/framework:mediapipe_internal"],
+    deps = [
+        "//mediapipe/framework:port",
+    ],
 )
 
 cc_library(

View File

@@ -15,13 +15,12 @@
 #ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_UTIL_H_
 #define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_UTIL_H_
 
-#include <typeindex>
-
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/packet.h"
 #include "mediapipe/framework/packet_generator.pb.h"
 #include "mediapipe/framework/packet_set.h"
 #include "mediapipe/framework/port/any_proto.h"
+#include "mediapipe/framework/tool/type_util.h"
 
 namespace mediapipe {
@@ -54,18 +53,18 @@ class TypeMap {
  public:
   template <class T>
   bool Has() const {
-    return content_.count(typeid(T)) > 0;
+    return content_.count(TypeId<T>()) > 0;
   }
 
   template <class T>
   T* Get() const {
     if (!Has<T>()) {
-      content_[typeid(T)] = std::make_shared<T>();
+      content_[TypeId<T>()] = std::make_shared<T>();
     }
-    return static_cast<T*>(content_[typeid(T)].get());
+    return static_cast<T*>(content_[TypeId<T>()].get());
   }
 
  private:
-  mutable std::map<std::type_index, std::shared_ptr<void>> content_;
+  mutable std::map<TypeIndex, std::shared_ptr<void>> content_;
 };
 
 template <class T,