commit d68f5e416903e3d756ebe07d08c4a3b911741a91 Author: MediaPipe Team Date: Sun Jun 16 16:03:25 2019 -0700 Project import generated by Copybara. PiperOrigin-RevId: 253489161 diff --git a/.bazelrc b/.bazelrc new file mode 100644 index 000000000..ebcc402fb --- /dev/null +++ b/.bazelrc @@ -0,0 +1,34 @@ +# The bazelrc file for MediaPipe OSS. + +# Basic build settings +build --jobs 128 +build --define='absl=1' +build --cxxopt='-std=c++11' +build --copt='-Wno-sign-compare' +build --copt='-Wno-unused-function' +build --copt='-Wno-uninitialized' +build --copt='-Wno-unused-result' +build --copt='-Wno-comment' +build --copt='-Wno-return-type' +build --copt='-Wno-unused-local-typedefs' +build --copt='-Wno-ignored-attributes' + +# Sets the default Apple platform to macOS. +build --apple_platform_type=macos + +# Android configs. +build:android --crosstool_top=//external:android/crosstool +build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain +build:android --linkopt=-landroid +build:android --linkopt=-ldl +build:android --linkopt=-llog +build:android --linkopt=-lm +build:android --linkopt=-Wl,--gc-sections + +build:android_arm --config=android +build:android_arm --cpu=armeabi-v7a +build:android_arm --fat_apk_cpu=armeabi-v7a + +build:android_arm64 --config=android +build:android_arm64 --cpu=arm64-v8a +build:android_arm64 --fat_apk_cpu=arm64-v8a diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..331d38729 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,2 @@ +.git +Dockerfile diff --git a/BUILD b/BUILD new file mode 100644 index 000000000..38d7cc1d7 --- /dev/null +++ b/BUILD @@ -0,0 +1,17 @@ +# Copyright 2019 The MediaPipeOSS Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..3703a7014 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,127 @@ +# Contributing guidelines + +## Pull Request Checklist + +Before sending your pull requests, make sure you have followed this list. + +- Read [contributing guidelines](CONTRIBUTING.md). +- Read [Code of Conduct](CODE_OF_CONDUCT.md). +- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- Check if your changes are consistent with the [guidelines](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- Ensure your changes are consistent with the [Coding Style](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#c-coding-style). +- Run [Unit Tests](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#running-unit-tests). + +## How to become a contributor and submit your own code + +### Contributor License Agreements + +We'd love to accept your patches! Before we can take them, we have to jump a couple of legal hurdles. + +Please fill out either the individual or corporate Contributor License Agreement (CLA). + + * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html). 
+ * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html). + +Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests. + +***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository. + +### Contributing code + +If you have improvements to MediaPipe, send us your pull requests! For those +just getting started, GitHub has a [howto](https://help.github.com/articles/using-pull-requests/). + +MediaPipe team members will be assigned to review your pull requests. Once the +pull requests are approved and pass continuous integration checks, a MediaPipe +team member will apply `ready to pull` label to your change. This means we are +working on getting your pull request submitted to our internal repository. After +the change has been submitted internally, your pull request will be merged +automatically on GitHub. + +If you want to contribute but you're not sure where to start, take a look at the +[issues with the "contributions welcome" label](https://github.com/google/mediapipe/labels/stat%3Acontributions%20welcome). +These are issues that we believe are particularly well suited for outside +contributions, often because we probably won't get to them right now. If you +decide to start on an issue, leave a comment so that other people know that +you're working on it. If you want to help out, but not alone, use the issue +comment thread to coordinate. + +### Contribution guidelines and standards + +Before sending your pull request for +[review](https://github.com/google/mediapipe/pulls), +make sure your changes are consistent with the guidelines and follow the +MediaPipe coding style. 
+ +#### General guidelines and philosophy for contribution + +* Include unit tests when you contribute new features, as they help to a) + prove that your code works correctly, and b) guard against future breaking + changes to lower the maintenance cost. +* Bug fixes also generally require unit tests, because the presence of bugs + usually indicates insufficient test coverage. +* Keep API compatibility in mind when you change code in the MediaPipe framework, + e.g., code in + [mediapipe/framework](https://github.com/google/mediapipe/tree/master/mediapipe/framework) + and + [mediapipe/calculators](https://github.com/google/mediapipe/tree/master/mediapipe/calculators). + Once MediaPipe has reached version 1, we will not make + non-backward-compatible API changes without a major release. Reviewers of + your pull request will comment on any API compatibility issues. +* When you contribute a new feature to MediaPipe, the maintenance burden is + (by default) transferred to the MediaPipe team. This means that the benefit of + the contribution must be compared against the cost of maintaining the + feature. +* Full new features (e.g., a new op implementing a cutting-edge algorithm) + typically will live in + [mediapipe/addons](https://github.com/google/mediapipe/addons) to get some + airtime before a decision is made regarding whether they are to be migrated to + the core. + +#### License + +Include a license at the top of new files. + +* [C/C++ license example](https://github.com/google/mediapipe/blob/master/mediapipe/framework/calculator_base.cc#L1) +* [Java license example](https://github.com/google/mediapipe/blob/master/mediapipe/java/com/google/mediapipe/components/CameraHelper.java) + +Bazel BUILD files also need to include a license section, e.g., +[BUILD example](https://github.com/google/mediapipe/blob/master/mediapipe/framework/BUILD#L61). 
+ +#### C++ coding style + +Changes to MediaPipe C++ code should conform to +[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). + +Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: + +```bash +apt-get install -y clang-tidy +``` + +You can check a C/C++ file by doing: + + +```bash +clang-format --style=google my_cc_file.cc > /tmp/my_cc_file.cc +diff my_cc_file.cc /tmp/my_cc_file.cc +``` + +#### Coding style for other languages + +* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) +* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) +* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) +* [Google Objective-C Style Guide](https://google.github.io/styleguide/objcguide.html) + +#### Running sanity check + +If you have Docker installed on your system, you can perform a sanity check on +your changes by running the command: + +```bash +mediapipe/tools/ci_build/ci_build.sh CPU mediapipe/tools/ci_build/ci_sanity.sh +``` + +This will catch most license, Python coding style and BUILD file issues that +may exist in your changes. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..ad7c6f909 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,52 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +FROM ubuntu:latest + +MAINTAINER <mediapipe@google.com> + +WORKDIR /io +WORKDIR /mediapipe + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + git \ + wget \ + unzip \ + python \ + libopencv-core-dev \ + libopencv-highgui-dev \ + libopencv-imgproc-dev \ + libopencv-video-dev \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install bazel +ARG BAZEL_VERSION=0.26.1 +RUN mkdir /bazel && \ + wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ +azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ + wget --no-check-certificate -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ + chmod +x /bazel/installer.sh && \ + /bazel/installer.sh && \ + rm -f /bazel/installer.sh + +COPY . /mediapipe/ + +# If we want the docker image to contain the pre-built object_detection_offline_demo binary, do the following + # RUN bazel build -c opt --define 'MEDIAPIPE_DISABLE_GPU=1' mediapipe/examples/desktop/demo:object_detection_tensorflow_demo diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 000000000..3276de974 --- /dev/null +++ b/README.md @@ -0,0 +1,29 @@ +![MediaPipe](mediapipe/docs/images/mediapipe_small.png?raw=true "MediaPipe logo") +======================================================================= + +#### We will be [presenting at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17~20 in Long Beach, CA. Come join us! + +[MediaPipe](http://g.co/mediapipe) is a framework for building multimodal (eg. video, audio, any time series data) applied ML pipelines. With MediaPipe, a perception pipeline can be built as a graph of modular components, including, for instance, inference models (e.g., TensorFlow, TFLite) and media processing functions. + +![Real-time Face Detection](mediapipe/docs/images/mobile/face_detection_android_gpu_small.gif) + +## Installation +Follow these [instructions](mediapipe/docs/install.md). + +## Getting started +See mobile and desktop [examples](mediapipe/docs/examples.md). + +## Documentation +On [MediaPipe Read-the-Docs](https://mediapipe.readthedocs.io/). + +## Visualizing MediaPipe graphs +A web-based visualizer is hosted on [MediaPipe Visualizer](https://mediapipe-viz.appspot.com/). Please also see instructions [here](mediapipe/docs/visualizer.md). 
+ +## Publications +* [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/) on [arXiv](https://arxiv.org/). +* [MediaPipe: A Framework for Perceiving and Augmenting Reality](http://mixedreality.cs.cornell.edu/s/22_crv2_MediaPipe_CVPR_CV4ARVR_Workshop_2019_v2.pdf), extended abstract for [Third Workshop on Computer Vision for AR/VR](http://mixedreality.cs.cornell.edu/workshop/program). + +## Contributing +We welcome contributions. Please follow these [guidelines](./CONTRIBUTING.md). + +We use GitHub issues for tracking requests and bugs. Please post questions to the MediaPipe Stack Overflow with a 'mediapipe' tag. diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 000000000..9d371a10f --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,181 @@ +workspace(name = "mediapipe") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "bazel_skylib", + sha256 = "bbccf674aa441c266df9894182d80de104cabd19be98be002f6d478aaa31574d", + strip_prefix = "bazel-skylib-2169ae1c374aab4a09aa90e65efe1a3aad4e279b", + urls = ["https://github.com/bazelbuild/bazel-skylib/archive/2169ae1c374aab4a09aa90e65efe1a3aad4e279b.tar.gz"], +) +load("@bazel_skylib//lib:versions.bzl", "versions") +versions.check(minimum_bazel_version = "0.23.0") + +# ABSL cpp library. +http_archive( + name = "com_google_absl", + # Head commit on 2019-04-12. + # TODO: Switch to the latest absl version when the problem gets + # fixed. + urls = [ + "https://github.com/abseil/abseil-cpp/archive/a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a.tar.gz", + ], + sha256 = "d437920d1434c766d22e85773b899c77c672b8b4865d5dc2cd61a29fdff3cf03", + strip_prefix = "abseil-cpp-a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a", +) + +# GoogleTest/GoogleMock framework. Used by most unit-tests. +http_archive( + name = "com_google_googletest", + urls = ["https://github.com/google/googletest/archive/master.zip"], + strip_prefix = "googletest-master", +) + +# Google Benchmark library. 
+http_archive( + name = "com_google_benchmark", + urls = ["https://github.com/google/benchmark/archive/master.zip"], + strip_prefix = "benchmark-master", + build_file = "@//third_party:benchmark.BUILD", +) + +# gflags needed by glog +http_archive( + name = "com_github_gflags_gflags", + sha256 = "6e16c8bc91b1310a44f3965e616383dbda48f83e8c1eaa2370a215057b00cabe", + strip_prefix = "gflags-77592648e3f3be87d6c7123eb81cbad75f9aef5a", + urls = [ + "https://mirror.bazel.build/github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz", + "https://github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz", + ], +) + +# glog +http_archive( + name = "com_google_glog", + url = "https://github.com/google/glog/archive/v0.3.5.zip", + sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8", + strip_prefix = "glog-0.3.5", + build_file = "@//third_party:glog.BUILD", +) + +# libyuv +http_archive( + name = "libyuv", + urls = ["https://chromium.googlesource.com/libyuv/libyuv/+archive/refs/heads/master.tar.gz"], + build_file = "@//third_party:libyuv.BUILD", +) + +http_archive( + name = "com_google_protobuf_javalite", + sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc", + strip_prefix = "protobuf-384989534b2246d413dbcd750744faab2607b516", + urls = ["https://github.com/google/protobuf/archive/384989534b2246d413dbcd750744faab2607b516.zip"], +) + +# Needed by TensorFlow +http_archive( + name = "io_bazel_rules_closure", + sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9", + strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df", + urls = [ + "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04 + ], +) + +# TensorFlow r1.14-rc0 +http_archive( + name = "org_tensorflow", + 
strip_prefix = "tensorflow-1.14.0-rc0", + sha256 = "76404a6157a45e8d7a07e4f5690275256260130145924c2a7c73f6eda2a3de10", + urls = ["https://github.com/tensorflow/tensorflow/archive/v1.14.0-rc0.zip"], +) + +load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") +tf_workspace(tf_repo_name = "org_tensorflow") + +# Please run $ sudo apt-get install libopencv-dev +new_local_repository( + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr", +) + +# Please run $ brew install opencv +new_local_repository( + name = "macos_opencv", + build_file = "@//third_party:opencv_macos.BUILD", + path = "/usr", +) + +http_archive( + name = "android_opencv", + sha256="cd7e5d5ec76eeddadf36a1cfe5197129328e80287d4d198c169e090421f838ba", + build_file = "@//third_party:opencv_android.BUILD", + strip_prefix = "OpenCV-android-sdk", + type = "zip", + url = "https://sourceforge.net/projects/opencvlibrary/files/4.0.1/opencv-4.0.1-android-sdk.zip/download" +) + +# Google Maven Repository +GMAVEN_TAG = "20181212-2" + +http_archive( + name = "gmaven_rules", + strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG, + url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG, +) + +load("@gmaven_rules//:gmaven.bzl", "gmaven_rules") + +gmaven_rules() + +maven_server( + name = "google_server", + url = "http://maven.google.com", +) + +maven_jar( + name = "androidx_lifecycle", + artifact = "androidx.lifecycle:lifecycle-common:2.0.0", + server = "google_server", +) + +maven_jar( + name = "androidx_concurrent_futures", + artifact = "androidx.concurrent:concurrent-futures:1.0.0-alpha03", + server = "google_server", +) + +maven_jar( + name = "com_google_guava_android", + artifact = "com.google.guava:guava:27.0.1-android", + sha1 = "b7e1c37f66ef193796ccd7ea6e80c2b05426182d", +) + +maven_jar( + name = "com_google_common_flogger", + artifact = "com.google.flogger:flogger:0.3.1", + sha1 = "585030fe1ec709760cbef997a459729fb965df0e", +) + +maven_jar( + 
name = "com_google_common_flogger_system_backend", + artifact = "com.google.flogger:flogger-system-backend:0.3.1", + sha1 = "287b569d76abcd82f9de87fe41829fbc7ebd8ac9", +) + +maven_jar( + name = "com_google_code_findbugs", + artifact = "com.google.code.findbugs:jsr305:3.0.2", +) + +# You may run setup_android.sh to install Android SDK and NDK. +android_ndk_repository( + name = "androidndk", +) + +android_sdk_repository( + name = "androidsdk", +) diff --git a/mediapipe/BUILD b/mediapipe/BUILD new file mode 100644 index 000000000..a5fc4d1f6 --- /dev/null +++ b/mediapipe/BUILD @@ -0,0 +1,75 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +config_setting( + name = "android", + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_x86", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "x86", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_x86_64", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "x86_64", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_armeabi", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "armeabi", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_arm", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "armeabi-v7a", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_arm64", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "arm64-v8a", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "macos", + values = { + "apple_platform_type": "macos", + "cpu": "darwin", + }, + visibility = ["//visibility:public"], +) diff --git a/mediapipe/__init__.py b/mediapipe/__init__.py new file mode 100644 index 000000000..6db73bc52 --- /dev/null +++ b/mediapipe/__init__.py @@ -0,0 +1,14 @@ +"""Copyright 2019 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD new file mode 100644 index 000000000..e4fdd0bed --- /dev/null +++ b/mediapipe/calculators/core/BUILD @@ -0,0 +1,270 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +proto_library( + name = "packet_resampler_calculator_proto", + srcs = ["packet_resampler_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "packet_resampler_calculator_cc_proto", + srcs = ["packet_resampler_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":packet_resampler_calculator_proto"], +) + +cc_library( + name = "counting_source_calculator", + srcs = ["counting_source_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "make_pair_calculator", + srcs = ["make_pair_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + 
"//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "mux_calculator", + srcs = ["mux_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/stream_handler:mux_input_stream_handler", + ], + alwayslink = 1, +) + +cc_library( + name = "packet_cloner_calculator", + srcs = ["packet_cloner_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "pass_through_calculator", + srcs = ["pass_through_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "round_robin_demux_calculator", + srcs = ["round_robin_demux_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +cc_library( + name = "immediate_mux_calculator", + srcs = ["immediate_mux_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "previous_loopback_calculator", + srcs = ["previous_loopback_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + ], + alwayslink = 1, +) + 
+cc_library( + name = "real_time_flow_limiter_calculator", + srcs = ["real_time_flow_limiter_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/util:header_util", + ], + alwayslink = 1, +) + +cc_test( + name = "immediate_mux_calculator_test", + srcs = ["immediate_mux_calculator_test.cc"], + deps = [ + ":immediate_mux_calculator", + ":round_robin_demux_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:test_calculators", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:threadpool", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "packet_resampler_calculator", + srcs = ["packet_resampler_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/tool:options_util", + "@com_google_absl//absl/strings", + "//mediapipe/framework/deps:mathutil", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:integral_types", + ] + select({ + "//conditions:default": [ + "//mediapipe/framework/deps:random", + ], + }), + 
alwayslink = 1, +) + +cc_test( + name = "packet_resampler_calculator_test", + timeout = "short", + srcs = ["packet_resampler_calculator_test.cc"], + deps = [ + ":packet_resampler_calculator", + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "previous_loopback_calculator_test", + srcs = ["previous_loopback_calculator_test.cc"], + deps = [ + ":previous_loopback_calculator", + "//mediapipe/calculators/core:make_pair_calculator", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "real_time_flow_limiter_calculator_test", + srcs = ["real_time_flow_limiter_calculator_test.cc"], + deps = [ + ":real_time_flow_limiter_calculator", + "//mediapipe/calculators/core:counting_source_calculator", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:test_calculators", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + 
"//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/time", + ], +) diff --git a/mediapipe/calculators/core/counting_source_calculator.cc b/mediapipe/calculators/core/counting_source_calculator.cc new file mode 100644 index 000000000..7b2f79a0c --- /dev/null +++ b/mediapipe/calculators/core/counting_source_calculator.cc @@ -0,0 +1,114 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/string_view.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of +// sequential numbers from INITIAL_VALUE (default 0) with a common +// difference of INCREMENT (default 1) between successive numbers (with +// timestamps corresponding to the sequence numbers). The packets are +// produced in BATCH_SIZE sized batches with each call to Process(). An +// error will be returned after ERROR_COUNT batches. An error will be +// produced in Open() if ERROR_ON_OPEN is true. Either MAX_COUNT or +// ERROR_COUNT must be provided and non-negative. If BATCH_SIZE is not +// provided, then batches are of size 1. 
+class CountingSourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + + if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) { + cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set(); + } + + RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") || + cc->InputSidePackets().HasTag("ERROR_COUNT")); + if (cc->InputSidePackets().HasTag("MAX_COUNT")) { + cc->InputSidePackets().Tag("MAX_COUNT").Set(); + } + if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { + cc->InputSidePackets().Tag("ERROR_COUNT").Set(); + } + + if (cc->InputSidePackets().HasTag("BATCH_SIZE")) { + cc->InputSidePackets().Tag("BATCH_SIZE").Set(); + } + if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) { + cc->InputSidePackets().Tag("INITIAL_VALUE").Set(); + } + if (cc->InputSidePackets().HasTag("INCREMENT")) { + cc->InputSidePackets().Tag("INCREMENT").Set(); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") && + cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get()) { + return ::mediapipe::NotFoundError("expected error"); + } + if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { + error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get(); + RET_CHECK_LE(0, error_count_); + } + if (cc->InputSidePackets().HasTag("MAX_COUNT")) { + max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get(); + RET_CHECK_LE(0, max_count_); + } + if (cc->InputSidePackets().HasTag("BATCH_SIZE")) { + batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get(); + RET_CHECK_LT(0, batch_size_); + } + if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) { + counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get(); + } + if (cc->InputSidePackets().HasTag("INCREMENT")) { + increment_ = cc->InputSidePackets().Tag("INCREMENT").Get(); + RET_CHECK_LT(0, increment_); + } + RET_CHECK(error_count_ >= 0 || max_count_ >= 0); + return 
::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + if (error_count_ >= 0 && batch_counter_ >= error_count_) { + return ::mediapipe::InternalError("expected error"); + } + if (max_count_ >= 0 && batch_counter_ >= max_count_) { + return tool::StatusStop(); + } + for (int i = 0; i < batch_size_; ++i) { + cc->Outputs().Index(0).Add(new int(counter_), Timestamp(counter_)); + counter_ += increment_; + } + ++batch_counter_; + return ::mediapipe::OkStatus(); + } + + private: + int max_count_ = -1; + int error_count_ = -1; + int batch_size_ = 1; + int batch_counter_ = 0; + int counter_ = 0; + int increment_ = 1; +}; +REGISTER_CALCULATOR(CountingSourceCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/immediate_mux_calculator.cc b/mediapipe/calculators/core/immediate_mux_calculator.cc new file mode 100644 index 000000000..cb930bed7 --- /dev/null +++ b/mediapipe/calculators/core/immediate_mux_calculator.cc @@ -0,0 +1,90 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// This Calculator multiplexes several input streams into a single +// output stream, dropping input packets with timestamps older than the +// last output packet. 
In case two packets arrive with the same timestamp, +// the packet with the lower stream index will be output and the rest will +// be dropped. +// +// This Calculator optionally produces a finish inidicator as its second +// output stream. One indicator packet is produced for each input packet +// received. +// +// This Calculator can be used with an ImmediateInputStreamHandler or with the +// default ISH. Note that currently ImmediateInputStreamHandler seems to +// interfere with timestamp bound propagation, so it is better to use the +// default unless the immediate one is needed. (b/118387598) +// +// This Calculator is designed to work with a Demux calculator such as +// the RoundRobinDemuxCalculator. Therefore, packets from different +// input streams are normally not expected to have the same timestamp. +// +class ImmediateMuxCalculator : public CalculatorBase { + public: + // This calculator combines any set of input streams into a single + // output stream. All input stream types must match the output stream type. + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + // Passes any input packet to the output stream immediately, unless the + // packet timestamp is lower than a previously passed packet. 
+ ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Open(CalculatorContext* cc) override; +}; +REGISTER_CALCULATOR(ImmediateMuxCalculator); + +::mediapipe::Status ImmediateMuxCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Outputs().NumEntries() >= 1 && cc->Outputs().NumEntries() <= 2) + << "This calculator produces only one or two output streams."; + cc->Outputs().Index(0).SetAny(); + if (cc->Outputs().NumEntries() >= 2) { + cc->Outputs().Index(1).Set(); + } + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).SetSameAs(&cc->Outputs().Index(0)); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImmediateMuxCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) { + // Pass along the first packet, unless it has been superseded. + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + const Packet& packet = cc->Inputs().Index(i).Value(); + if (!packet.IsEmpty()) { + if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) { + cc->Outputs().Index(0).AddPacket(packet); + } + if (cc->Outputs().NumEntries() >= 2) { + Timestamp output_timestamp = std::max( + cc->InputTimestamp(), cc->Outputs().Index(1).NextTimestampBound()); + cc->Outputs().Index(1).Add(new bool(true), output_timestamp); + } + } + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/immediate_mux_calculator_test.cc b/mediapipe/calculators/core/immediate_mux_calculator_test.cc new file mode 100644 index 000000000..6fe318712 --- /dev/null +++ b/mediapipe/calculators/core/immediate_mux_calculator_test.cc @@ -0,0 +1,373 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/port/threadpool.h" +#include "mediapipe/framework/tool/sink.h" + +// Tests for ImmediateMuxCalculator. These tests show how parallel output +// packets are handled when they arrive in various orders. +using testing::ElementsAre; + +namespace mediapipe { + +namespace { + +// A simple Semaphore for synchronizing test threads. +class AtomicSemaphore { + public: + AtomicSemaphore(int64_t supply) : supply_(supply) {} + void Acquire(int64_t amount) { + while (supply_.fetch_sub(amount) - amount < 0) { + Release(amount); + } + } + void Release(int64_t amount) { supply_.fetch_add(amount); } + + private: + std::atomic supply_; +}; + +// A mediapipe::Executor that signals the start and finish of each task. +// Provides 4 worker threads. 
+class CountingExecutor : public Executor { + public: + CountingExecutor(std::function start_callback, + std::function finish_callback) + : thread_pool_(4), + start_callback_(std::move(start_callback)), + finish_callback_(std::move(finish_callback)) { + thread_pool_.StartWorkers(); + } + void Schedule(std::function task) override { + start_callback_(); + thread_pool_.Schedule([this, task] { + task(); + finish_callback_(); + }); + } + + private: + ThreadPool thread_pool_; + std::function start_callback_; + std::function finish_callback_; +}; + +// Returns a new mediapipe::Executor with 4 worker threads. +std::shared_ptr MakeExecutor(std::function start_callback, + std::function finish_callback) { + return std::make_shared(start_callback, finish_callback); +} + +// Tests showing ImmediateMuxCalculator dropping packets in various sequences. +class ImmediateMuxCalculatorTest : public ::testing::Test { + protected: + void SetUpMuxGraph() { + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( + input_stream: "input_packets_0" + input_stream: "input_packets_1" + node { + calculator: "ImmediateMuxCalculator" + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + input_stream: "input_packets_0" + input_stream: "input_packets_1" + output_stream: "output_packets_0" + } + )", + &graph_config_)); + } + + void SetUpDemuxGraph() { + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( + input_stream: "input_packets_0" + node { + calculator: "RoundRobinDemuxCalculator" + input_stream: "input_packets_0" + output_stream: "OUTPUT:0:input_0" + output_stream: "OUTPUT:1:input_1" + } + node { + calculator: "LambdaCalculator" + input_side_packet: 'callback_0' + input_stream: "input_0" + output_stream: "output_0" + } + node { + calculator: "LambdaCalculator" + input_side_packet: 'callback_1' + input_stream: "input_1" + output_stream: "output_1" + } + node { + calculator: "ImmediateMuxCalculator" + input_stream_handler { + input_stream_handler: 
"ImmediateInputStreamHandler" + } + input_stream: "output_0" + input_stream: "output_1" + output_stream: "output_packets_0" + } + )", + &graph_config_)); + } + + void SetUpDemuxInFlightGraph() { + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( + input_stream: "input_packets_0" + node { + calculator: 'RealTimeFlowLimiterCalculator' + input_stream_handler { + input_stream_handler: 'ImmediateInputStreamHandler' + } + input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' + input_stream: 'input_packets_0' + input_stream: 'FINISHED:finish_indicator' + input_stream_info: { + tag_index: 'FINISHED' + back_edge: true + } + output_stream: 'input_0_sampled' + } + node { + calculator: "RoundRobinDemuxCalculator" + input_stream: "input_0_sampled" + output_stream: "OUTPUT:0:input_0" + output_stream: "OUTPUT:1:input_1" + } + node { + calculator: "LambdaCalculator" + input_side_packet: 'callback_0' + input_stream: "input_0" + output_stream: "output_0" + } + node { + calculator: "LambdaCalculator" + input_side_packet: 'callback_1' + input_stream: "input_1" + output_stream: "output_1" + } + node { + calculator: "ImmediateMuxCalculator" + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + input_stream: "output_0" + input_stream: "output_1" + output_stream: 'output_packets_0' + output_stream: 'finish_indicator' + } + )", + &graph_config_)); + } + + static Packet PacketAt(int64 ts) { + return Adopt(new int64(999)).At(Timestamp(ts)); + } + static Packet None() { return Packet().At(Timestamp::OneOverPostStream()); } + static bool IsNone(const Packet& packet) { + return packet.Timestamp() == Timestamp::OneOverPostStream(); + } + // Return the values of the timestamps of a vector of Packets. + static std::vector TimestampValues( + const std::vector& packets) { + std::vector result; + for (const Packet& p : packets) { + result.push_back(p.Timestamp().Value()); + } + return result; + } + + // Runs a CalculatorGraph with a series of packet sets. 
+ // Returns a vector of packets from each graph output stream. + void RunGraph(const std::vector>& input_sets, + std::vector* output_packets) { + // Register output packet observers. + tool::AddVectorSink("output_packets_0", &graph_config_, output_packets); + + // Start running the graph. + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config_)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + + // Send each packet to the graph in the specified order. + for (int t = 0; t < input_sets.size(); t++) { + const std::vector& input_set = input_sets[t]; + MEDIAPIPE_EXPECT_OK(graph.WaitUntilIdle()); + for (int i = 0; i < input_set.size(); i++) { + const Packet& packet = input_set[i]; + if (!IsNone(packet)) { + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + absl::StrCat("input_packets_", i), packet)); + } + } + } + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + } + + CalculatorGraphConfig graph_config_; +}; + +TEST_F(ImmediateMuxCalculatorTest, IncreasingTimestamps) { + // Run the graph with a series of packet sets. + std::vector> input_sets = { + {PacketAt(10000), None()}, // + {PacketAt(20000), None()}, // + {None(), PacketAt(30000)}, // + {None(), PacketAt(40000)}, + }; + SetUpMuxGraph(); + std::vector output_packets; + RunGraph(input_sets, &output_packets); + + // Validate the output packets. + EXPECT_THAT(TimestampValues(output_packets), + ElementsAre(10000, 20000, 30000, 40000)); +} + +TEST_F(ImmediateMuxCalculatorTest, SupersededTimestamp) { + // Run the graph with a series of packet sets. + std::vector> input_sets = { + {PacketAt(10000), None()}, // + {PacketAt(30000), None()}, // + {None(), PacketAt(20000)}, // + {None(), PacketAt(40000)}, + }; + SetUpMuxGraph(); + std::vector output_packets; + RunGraph(input_sets, &output_packets); + + // Output packet 20000 is superseded and dropped. 
+ EXPECT_THAT(TimestampValues(output_packets), + ElementsAre(10000, 30000, 40000)); +} + +TEST_F(ImmediateMuxCalculatorTest, SimultaneousTimestamps) { + // Run the graph with a series of packet sets. + std::vector> input_sets = { + {PacketAt(10000), None()}, // + {PacketAt(40000), PacketAt(20000)}, // + {None(), PacketAt(30000)}, + }; + SetUpMuxGraph(); + std::vector output_packets; + RunGraph(input_sets, &output_packets); + + // Output packet 20000 is superseded and dropped. + EXPECT_THAT(TimestampValues(output_packets), ElementsAre(10000, 40000)); +} + +// A Calculator::Process callback function. +typedef std::function<::mediapipe::Status(const InputStreamShardSet&, + OutputStreamShardSet*)> + ProcessFunction; + +// A testing callback function that passes through all packets. +::mediapipe::Status PassThrough(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + for (int i = 0; i < inputs.NumEntries(); ++i) { + if (!inputs.Index(i).Value().IsEmpty()) { + outputs->Index(i).AddPacket(inputs.Index(i).Value()); + } + } + return ::mediapipe::OkStatus(); +} + +TEST_F(ImmediateMuxCalculatorTest, Demux) { + // Semaphores to sequence the parallel Process outputs. + AtomicSemaphore semaphore_0(0); + AtomicSemaphore semaphore_1(0); + ProcessFunction wait_0 = [&semaphore_0](const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + semaphore_0.Acquire(1); + return PassThrough(inputs, outputs); + }; + ProcessFunction wait_1 = [&semaphore_1](const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + semaphore_1.Acquire(1); + return PassThrough(inputs, outputs); + }; + + // A callback to await and capture output packets. 
+ std::vector out_packets; + absl::Mutex out_mutex; + auto out_cb = [&](const Packet& p) { + absl::MutexLock lock(&out_mutex); + out_packets.push_back(p); + return ::mediapipe::OkStatus(); + }; + auto wait_for = [&](std::function cond) { + absl::MutexLock lock(&out_mutex); + out_mutex.Await(absl::Condition(&cond)); + }; + SetUpDemuxGraph(); + + // Start the graph and add five input packets. + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize( + graph_config_, { + {"callback_0", Adopt(new auto(wait_0))}, + {"callback_1", Adopt(new auto(wait_1))}, + })); + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream("output_packets_0", out_cb)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("input_packets_0", PacketAt(10000))); + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("input_packets_0", PacketAt(20000))); + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("input_packets_0", PacketAt(30000))); + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("input_packets_0", PacketAt(40000))); + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("input_packets_0", PacketAt(50000))); + + // Release the outputs in order 20000, 10000, 30000, 50000, 40000. + semaphore_1.Release(1); // 20000 + wait_for([&] { return !out_packets.empty(); }); + semaphore_0.Release(1); // 10000 + semaphore_0.Release(1); // 30000 + wait_for([&] { return out_packets.size() >= 2; }); + semaphore_0.Release(1); // 50000 + wait_for([&] { return out_packets.size() >= 3; }); + semaphore_1.Release(1); // 40000 + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + + // Output packets 10000 and 40000 are superseded and dropped. 
+ EXPECT_THAT(TimestampValues(out_packets), ElementsAre(20000, 30000, 50000)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/make_pair_calculator.cc b/mediapipe/calculators/core/make_pair_calculator.cc new file mode 100644 index 000000000..8eb4cb67b --- /dev/null +++ b/mediapipe/calculators/core/make_pair_calculator.cc @@ -0,0 +1,61 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// Given two input streams (A, B), output a single stream containing a pair. 
+// +// Example config: +// node { +// calculator: "MakePairCalculator" +// input_stream: "packet_a" +// input_stream: "packet_b" +// output_stream: "output_pair_a_b" +// } +class MakePairCalculator : public CalculatorBase { + public: + MakePairCalculator() {} + ~MakePairCalculator() override {} + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Inputs().Index(1).SetAny(); + cc->Outputs().Index(0).Set>(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->Outputs().Index(0).Add( + new std::pair(cc->Inputs().Index(0).Value(), + cc->Inputs().Index(1).Value()), + cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; + +REGISTER_CALCULATOR(MakePairCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc new file mode 100644 index 000000000..7f96da760 --- /dev/null +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -0,0 +1,77 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +// A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ..., +// using the integer value (0, 1, ...) in the packet on the "SELECT" input +// stream, and passes the packet on the selected input stream to the "OUTPUT" +// output stream. +// +// Note that this calculator defaults to use MuxInputStreamHandler, which is +// required for this calculator. +class MuxCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag("SELECT").Set(); + CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT"); + PacketType* data_input0 = &cc->Inputs().Get(data_input_id); + data_input0->SetAny(); + ++data_input_id; + for (; data_input_id < cc->Inputs().EndId("INPUT"); ++data_input_id) { + cc->Inputs().Get(data_input_id).SetSameAs(data_input0); + } + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); + cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0); + + // Assign this calculator's default InputStreamHandler. 
+ cc->SetInputStreamHandler("MuxInputStreamHandler"); + MediaPipeOptions options; + cc->SetInputStreamHandlerOptions(options); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + select_input_ = cc->Inputs().GetId("SELECT", 0); + data_input_base_ = cc->Inputs().GetId("INPUT", 0); + num_data_inputs_ = cc->Inputs().NumEntries("INPUT"); + output_ = cc->Outputs().GetId("OUTPUT", 0); + cc->SetOffset(mediapipe::TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int select = cc->Inputs().Get(select_input_).Get(); + RET_CHECK(0 <= select && select < num_data_inputs_); + if (!cc->Inputs().Get(data_input_base_ + select).IsEmpty()) { + cc->Outputs().Get(output_).AddPacket( + cc->Inputs().Get(data_input_base_ + select).Value()); + } + return ::mediapipe::OkStatus(); + } + + private: + CollectionItemId select_input_; + CollectionItemId data_input_base_; + int num_data_inputs_ = 0; + CollectionItemId output_; +}; + +REGISTER_CALCULATOR(MuxCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_cloner_calculator.cc b/mediapipe/calculators/core/packet_cloner_calculator.cc new file mode 100644 index 000000000..2750f1257 --- /dev/null +++ b/mediapipe/calculators/core/packet_cloner_calculator.cc @@ -0,0 +1,98 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// This takes packets from N+1 streams, A_1, A_2, ..., A_N, B. +// For every packet that appears in B, outputs the most recent packet from each +// of the A_i on a separate stream. + +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_framework.h" + +namespace mediapipe { + +// For every packet received on the last stream, output the latest packet +// obtained on all other streams. Therefore, if the last stream outputs at a +// higher rate than the others, this effectively clones the packets from the +// other streams to match the last. +// +// Example config: +// node { +// calculator: "PacketClonerCalculator" +// input_stream: "first_base_signal" +// input_stream: "second_base_signal" +// input_stream: "tick_signal" +// output_stream: "cloned_first_base_signal" +// output_stream: "cloned_second_base_signal" +// } +// +// Related: +// merge_input_streams_calculator.cc: One output stream. +// packet_inner_join_calculator.cc: Don't output unless all inputs are new. +class PacketClonerCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + const int tick_signal_index = cc->Inputs().NumEntries() - 1; + for (int i = 0; i < tick_signal_index; ++i) { + cc->Inputs().Index(i).SetAny(); + cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); + } + cc->Inputs().Index(tick_signal_index).SetAny(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + tick_signal_index_ = cc->Inputs().NumEntries() - 1; + current_.resize(tick_signal_index_); + // Pass along the header for each stream if present. + for (int i = 0; i < tick_signal_index_; ++i) { + if (!cc->Inputs().Index(i).Header().IsEmpty()) { + cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header()); + } + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + // Store input signals. 
+ for (int i = 0; i < tick_signal_index_; ++i) { + if (!cc->Inputs().Index(i).Value().IsEmpty()) { + current_[i] = cc->Inputs().Index(i).Value(); + } + } + + // Output if the tick signal is non-empty. + if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) { + for (int i = 0; i < tick_signal_index_; ++i) { + if (!current_[i].IsEmpty()) { + cc->Outputs().Index(i).AddPacket( + current_[i].At(cc->InputTimestamp())); + } else { + cc->Outputs().Index(i).SetNextTimestampBound( + cc->InputTimestamp().NextAllowedInStream()); + } + } + } + return ::mediapipe::OkStatus(); + } + + private: + std::vector current_; + int tick_signal_index_; +}; + +REGISTER_CALCULATOR(PacketClonerCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc new file mode 100644 index 000000000..5b851e874 --- /dev/null +++ b/mediapipe/calculators/core/packet_resampler_calculator.cc @@ -0,0 +1,434 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/deps/mathutil.h" +#include "mediapipe/framework/deps/random_base.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/tool/options_util.h" + +namespace { + +// Creates a secure random number generator for use in ProcessWithJitter. +// If no secure random number generator can be constructed, the jitter +// option is disabled in order to mainatain a consistent security and +// consistent random seeding. +std::unique_ptr CreateSecureRandom(const std::string& seed) { + RandomBase* result = nullptr; + return std::unique_ptr(result); +} + +} // namespace + +namespace mediapipe { + +// This calculator is used to normalize the frequency of the packets +// out of a stream. Given a desired frame rate, packets are going to be +// removed or added to achieve it. +// +// The jitter feature is disabled by default. To enable it, you need to +// implement CreateSecureRandom(const std::string&). +// +// The data stream may be either specified as the only stream (by index) +// or as the stream with tag "DATA". +// +// The input and output streams may be accompanied by a VIDEO_HEADER +// stream. This stream includes a VideoHeader at Timestamp::PreStream(). +// The input VideoHeader on the VIDEO_HEADER stream will always be updated +// with the resampler frame rate no matter what the options value for +// output_header is before being output on the output VIDEO_HEADER stream. 
+// If the input VideoHeader is not available, then only the frame rate +// value will be set in the output. +// +// Related: +// packet_downsampler_calculator.cc: skips packets regardless of timestamps. +class PacketResamplerCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + // Logic for Process() when jitter_ != 0.0. + ::mediapipe::Status ProcessWithJitter(CalculatorContext* cc); + + // Logic for Process() when jitter_ == 0.0. + ::mediapipe::Status ProcessWithoutJitter(CalculatorContext* cc); + + // Given the current count of periods that have passed, this returns + // the next valid timestamp of the middle point of the next period: + // if count is 0, it returns the first_timestamp_. + // if count is 1, it returns the first_timestamp_ + period (corresponding + // to the first tick using exact fps) + // e.g. for frame_rate=30 and first_timestamp_=0: + // 0: 0 + // 1: 33333 + // 2: 66667 + // 3: 100000 + // + // Can only be used if jitter_ equals zero. + Timestamp PeriodIndexToTimestamp(int64 index) const; + + // Given a Timestamp, finds the closest sync Timestamp based on + // first_timestamp_ and the desired fps. + // + // Can only be used if jitter_ equals zero. + int64 TimestampToPeriodIndex(Timestamp timestamp) const; + + // Outputs a packet if it is in range (start_time_, end_time_). + void OutputWithinLimits(CalculatorContext* cc, const Packet& packet) const; + + // The timestamp of the first packet received. + Timestamp first_timestamp_; + + // Number of frames per second (desired output frequency). + double frame_rate_; + + // Inverse of frame_rate_. + int64 frame_time_usec_; + + // Number of periods that have passed (= #packets sent to the output). 
+ // + // Can only be used if jitter_ equals zero. + int64 period_count_; + + // The last packet that was received. + Packet last_packet_; + + VideoHeader video_header_; + // The "DATA" input stream. + CollectionItemId input_data_id_; + // The "DATA" output stream. + CollectionItemId output_data_id_; + + // Indicator whether to flush last packet even if its timestamp is greater + // than the final stream timestamp. Set to false when jitter_ is non-zero. + bool flush_last_packet_; + + // Jitter-related variables. + std::unique_ptr random_; + double jitter_ = 0.0; + Timestamp next_output_timestamp_; + + // If specified, output timestamps are aligned with base_timestamp. + // Otherwise, they are aligned with the first input timestamp. + Timestamp base_timestamp_; + + // If specified, only outputs at/after start_time are included. + Timestamp start_time_; + + // If specified, only outputs before end_time are included. + Timestamp end_time_; + + // If set, the output timestamps nearest to start_time and end_time + // are included in the output, even if the nearest timestamp is not + // between start_time and end_time. + bool round_limits_; +}; + +REGISTER_CALCULATOR(PacketResamplerCalculator); + +namespace { +// Returns a TimestampDiff (assuming microseconds) corresponding to the +// given time in seconds. 
+TimestampDiff TimestampDiffFromSeconds(double seconds) { + return TimestampDiff(MathUtil::SafeRound( + seconds * Timestamp::kTimestampUnitsPerSecond)); +} +} // namespace + +::mediapipe::Status PacketResamplerCalculator::GetContract( + CalculatorContract* cc) { + const auto& resampler_options = + cc->Options(); + if (cc->InputSidePackets().HasTag("OPTIONS")) { + cc->InputSidePackets().Tag("OPTIONS").Set(); + } + CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0); + if (!input_data_id.IsValid()) { + input_data_id = cc->Inputs().GetId("", 0); + } + cc->Inputs().Get(input_data_id).SetAny(); + if (cc->Inputs().HasTag("VIDEO_HEADER")) { + cc->Inputs().Tag("VIDEO_HEADER").Set(); + } + + CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0); + if (!output_data_id.IsValid()) { + output_data_id = cc->Outputs().GetId("", 0); + } + cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id)); + if (cc->Outputs().HasTag("VIDEO_HEADER")) { + cc->Outputs().Tag("VIDEO_HEADER").Set(); + } + + if (resampler_options.jitter() != 0.0) { + RET_CHECK_GT(resampler_options.jitter(), 0.0); + RET_CHECK_LE(resampler_options.jitter(), 1.0); + RET_CHECK(cc->InputSidePackets().HasTag("SEED")); + cc->InputSidePackets().Tag("SEED").Set(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { + const auto resampler_options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); + + flush_last_packet_ = resampler_options.flush_last_packet(); + jitter_ = resampler_options.jitter(); + + input_data_id_ = cc->Inputs().GetId("DATA", 0); + if (!input_data_id_.IsValid()) { + input_data_id_ = cc->Inputs().GetId("", 0); + } + output_data_id_ = cc->Outputs().GetId("DATA", 0); + if (!output_data_id_.IsValid()) { + output_data_id_ = cc->Outputs().GetId("", 0); + } + + period_count_ = 0; + frame_rate_ = resampler_options.frame_rate(); + base_timestamp_ = 
resampler_options.has_base_timestamp() + ? Timestamp(resampler_options.base_timestamp()) + : Timestamp::Unset(); + start_time_ = resampler_options.has_start_time() + ? Timestamp(resampler_options.start_time()) + : Timestamp::Min(); + end_time_ = resampler_options.has_end_time() + ? Timestamp(resampler_options.end_time()) + : Timestamp::Max(); + round_limits_ = resampler_options.round_limits(); + // The frame_rate has a default value of -1.0, so the user must set it! + RET_CHECK_LT(0, frame_rate_) + << "The output frame rate must be greater than zero"; + RET_CHECK_LE(frame_rate_, Timestamp::kTimestampUnitsPerSecond) + << "The output frame rate must be smaller than " + << Timestamp::kTimestampUnitsPerSecond; + + frame_time_usec_ = static_cast(1000000.0 / frame_rate_); + video_header_.frame_rate = frame_rate_; + + if (resampler_options.output_header() != + PacketResamplerCalculatorOptions::NONE && + !cc->Inputs().Get(input_data_id_).Header().IsEmpty()) { + if (resampler_options.output_header() == + PacketResamplerCalculatorOptions::UPDATE_VIDEO_HEADER) { + video_header_ = + cc->Inputs().Get(input_data_id_).Header().Get(); + video_header_.frame_rate = frame_rate_; + cc->Outputs() + .Get(output_data_id_) + .SetHeader(Adopt(new VideoHeader(video_header_))); + } else { + cc->Outputs() + .Get(output_data_id_) + .SetHeader(cc->Inputs().Get(input_data_id_).Header()); + } + } + + if (jitter_ != 0.0) { + if (resampler_options.output_header() != + PacketResamplerCalculatorOptions::NONE) { + LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not " + "the actual value."; + } + if (flush_last_packet_) { + flush_last_packet_ = false; + LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is " + "ignored, because we are adding jitter."; + } + const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); + random_ = CreateSecureRandom(seed); + if (random_ == nullptr) { + return ::mediapipe::Status( + ::mediapipe::StatusCode::kInvalidArgument, + 
"SecureRandom is not available. With \"jitter\" specified, " + "PacketResamplerCalculator processing cannot proceed."); + } + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { + if (cc->InputTimestamp() == Timestamp::PreStream() && + cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") && + !cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) { + video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get(); + video_header_.frame_rate = frame_rate_; + if (cc->Inputs().Get(input_data_id_).IsEmpty()) { + return ::mediapipe::OkStatus(); + } + } + if (jitter_ != 0.0 && random_ != nullptr) { + RETURN_IF_ERROR(ProcessWithJitter(cc)); + } else { + RETURN_IF_ERROR(ProcessWithoutJitter(cc)); + } + last_packet_ = cc->Inputs().Get(input_data_id_).Value(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketResamplerCalculator::ProcessWithJitter( + CalculatorContext* cc) { + RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); + RET_CHECK_NE(jitter_, 0.0); + + if (first_timestamp_ == Timestamp::Unset()) { + first_timestamp_ = cc->InputTimestamp(); + next_output_timestamp_ = + first_timestamp_ + frame_time_usec_ * random_->RandFloat(); + return ::mediapipe::OkStatus(); + } + + LOG_IF(WARNING, frame_time_usec_ < + (cc->InputTimestamp() - last_packet_.Timestamp()).Value()) + << "Adding jitter is meaningless when upsampling."; + + const int64 curr_diff = + (next_output_timestamp_ - cc->InputTimestamp()).Value(); + const int64 last_diff = + (next_output_timestamp_ - last_packet_.Timestamp()).Value(); + if (curr_diff * last_diff > 0) { + return ::mediapipe::OkStatus(); + } + OutputWithinLimits(cc, (std::abs(curr_diff) > std::abs(last_diff) + ? 
last_packet_ + : cc->Inputs().Get(input_data_id_).Value()) + .At(next_output_timestamp_)); + next_output_timestamp_ += + frame_time_usec_ * + ((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketResamplerCalculator::ProcessWithoutJitter( + CalculatorContext* cc) { + RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); + RET_CHECK_EQ(jitter_, 0.0); + + if (first_timestamp_ == Timestamp::Unset()) { + // This is the first packet, initialize the first_timestamp_. + if (base_timestamp_ == Timestamp::Unset()) { + // Initialize first_timestamp_ with exactly the first packet timestamp. + first_timestamp_ = cc->InputTimestamp(); + } else { + // Initialize first_timestamp_ with the first packet timestamp + // aligned to the base_timestamp_. + int64 first_index = MathUtil::SafeRound( + (cc->InputTimestamp() - base_timestamp_).Seconds() * frame_rate_); + first_timestamp_ = + base_timestamp_ + TimestampDiffFromSeconds(first_index / frame_rate_); + } + if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) { + cc->Outputs() + .Tag("VIDEO_HEADER") + .Add(new VideoHeader(video_header_), Timestamp::PreStream()); + } + } + const Timestamp received_timestamp = cc->InputTimestamp(); + const int64 received_timestamp_idx = + TimestampToPeriodIndex(received_timestamp); + // Only consider the received packet if it belongs to the current period + // (== period_count_) or to a newer one (> period_count_). + if (received_timestamp_idx >= period_count_) { + // Fill the empty periods until we are in the same index as the received + // packet. + while (received_timestamp_idx > period_count_) { + OutputWithinLimits( + cc, last_packet_.At(PeriodIndexToTimestamp(period_count_))); + ++period_count_; + } + // Now, if the received packet has a timestamp larger than the middle of + // the current period, we can send a packet without waiting. We send the + // one closer to the middle. 
+ Timestamp target_timestamp = PeriodIndexToTimestamp(period_count_); + if (received_timestamp >= target_timestamp) { + bool have_last_packet = (last_packet_.Timestamp() != Timestamp::Unset()); + bool send_current = + !have_last_packet || (received_timestamp - target_timestamp <= + target_timestamp - last_packet_.Timestamp()); + if (send_current) { + OutputWithinLimits( + cc, cc->Inputs().Get(input_data_id_).Value().At(target_timestamp)); + } else { + OutputWithinLimits(cc, last_packet_.At(target_timestamp)); + } + ++period_count_; + } + // TODO: Add a mechanism to the framework to allow these packets + // to be output earlier (without waiting for a much later packet to + // arrive) + + // Update the bound for the next packet. + cc->Outputs() + .Get(output_data_id_) + .SetNextTimestampBound(PeriodIndexToTimestamp(period_count_)); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketResamplerCalculator::Close(CalculatorContext* cc) { + if (!cc->GraphStatus().ok()) { + return ::mediapipe::OkStatus(); + } + // Emit the last packet received if we have at least one packet, but + // haven't sent anything for its period. 
+ if (first_timestamp_ != Timestamp::Unset() && flush_last_packet_ && + TimestampToPeriodIndex(last_packet_.Timestamp()) == period_count_) { + OutputWithinLimits(cc, + last_packet_.At(PeriodIndexToTimestamp(period_count_))); + } + return ::mediapipe::OkStatus(); +} + +Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const { + CHECK_EQ(jitter_, 0.0); + CHECK_NE(first_timestamp_, Timestamp::Unset()); + return first_timestamp_ + TimestampDiffFromSeconds(index / frame_rate_); +} + +int64 PacketResamplerCalculator::TimestampToPeriodIndex( + Timestamp timestamp) const { + CHECK_EQ(jitter_, 0.0); + CHECK_NE(first_timestamp_, Timestamp::Unset()); + return MathUtil::SafeRound( + (timestamp - first_timestamp_).Seconds() * frame_rate_); +} + +void PacketResamplerCalculator::OutputWithinLimits(CalculatorContext* cc, + const Packet& packet) const { + TimestampDiff margin((round_limits_) ? frame_time_usec_ / 2 : 0); + if (packet.Timestamp() >= start_time_ - margin && + packet.Timestamp() < end_time_ + margin) { + cc->Outputs().Get(output_data_id_).AddPacket(packet); + } +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_resampler_calculator.proto b/mediapipe/calculators/core/packet_resampler_calculator.proto new file mode 100644 index 000000000..190a9c269 --- /dev/null +++ b/mediapipe/calculators/core/packet_resampler_calculator.proto @@ -0,0 +1,95 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message PacketResamplerCalculatorOptions { + extend CalculatorOptions { + optional PacketResamplerCalculatorOptions ext = 95743844; + } + + // The output frame rate measured in frames per second. + // + // The closest packet in time in each period will be chosen. If there + // is no packet in the period then the most recent packet will be chosen + // (not the closest in time). + optional double frame_rate = 1 [default = -1.0]; + + enum OutputHeader { + // Do not output a header, even if the input contained one. + NONE = 0; + // Pass the header, if the input contained one. + PASS_HEADER = 1; + // Update the frame rate in the header, which must be of type VideoHeader. + UPDATE_VIDEO_HEADER = 2; + } + + // Whether and what kind of header to place on the output stream. + // Note, this is about the actual header, not the VIDEO_HEADER stream. + // If this option is set to UPDATE_VIDEO_HEADER then the header will + // also be parsed (updated) and passed along to the VIDEO_HEADER stream. + optional OutputHeader output_header = 2 [default = NONE]; + + // Flush last packet even if its timestamp is greater than the final stream + // timestamp. + optional bool flush_last_packet = 3 [default = true]; + + // Adds jitter to resampling if set, so that Google's sampling is not + // externally deterministic. + // + // When set, the randomizer will be initialized with a seed. Then, the first + // sample is chosen randomly (uniform distribution) among frames that + // correspond to timestamps [0, 1/frame_rate). Let the chosen frame + // correspond to timestamp t. The next frame is chosen randomly (uniform + // distribution) among frames that correspond to [t+(1-jitter)/frame_rate, + // t+(1+jitter)/frame_rate]. t is updated and the process is repeated. 
+ // + // Valid values are in the range of [0.0, 1.0] with the default being 0.0 (no + // jitter). A typical value would be a value in the range of 0.1-0.25. + // + // Note that this does NOT guarantee the desired frame rate, but if the + // pseudo-random number generator does its job and the number of frames is + // sufficiently large, the average frame rate will be close to this value. + optional double jitter = 4; + + // If specified, output timestamps are aligned with base_timestamp. + // Otherwise, they are aligned with the first input timestamp. + // + // In order to ensure that the outptut timestamps are reproducible, + // with round_limits = false, the bounds for input timestamps must include: + // [start_time - period / 2, end_time + period / 2], + // with round_limits = true, the bounds for input timestamps must include: + // [start_time - period, end_time + period], + // where period = 1 / frame_rate. + // + // For example, in PacketResamplerCalculatorOptions specify + // "start_time: 3000000", and in MediaDecoderOptions specify + // "start_time: 2999950". + optional int64 base_timestamp = 5; + + // If specified, only outputs at/after start_time are included. + optional int64 start_time = 6; + + // If specified, only outputs before end_time are included. + optional int64 end_time = 7; + + // If set, the output timestamps nearest to start_time and end_time + // are included in the output, even if the nearest timestamp is not + // between start_time and end_time. + optional bool round_limits = 8 [default = false]; +} diff --git a/mediapipe/calculators/core/packet_resampler_calculator_test.cc b/mediapipe/calculators/core/packet_resampler_calculator_test.cc new file mode 100644 index 000000000..c7be91439 --- /dev/null +++ b/mediapipe/calculators/core/packet_resampler_calculator_test.cc @@ -0,0 +1,679 @@ +// Copyright 2018 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +// A simple version of CalculatorRunner with built-in convenience +// methods for setting inputs from a vector and checking outputs +// against expected outputs (both timestamps and contents). 
+class SimpleRunner : public CalculatorRunner { + public: + explicit SimpleRunner(const std::string& options_string) + : CalculatorRunner("PacketResamplerCalculator", options_string, 1, 1, 0) { + } + explicit SimpleRunner(const CalculatorGraphConfig::Node& node_config) + : CalculatorRunner(node_config) {} + + virtual ~SimpleRunner() {} + + void SetInput(const std::vector& timestamp_list) { + MutableInputs()->Index(0).packets.clear(); + for (const int64 ts : timestamp_list) { + MutableInputs()->Index(0).packets.push_back( + Adopt(new std::string(absl::StrCat("Frame #", ts))) + .At(Timestamp(ts))); + } + } + + void SetVideoHeader(const double frame_rate) { + video_header_.width = static_count_; + video_header_.height = static_count_ * 10; + video_header_.frame_rate = frame_rate; + video_header_.duration = static_count_ * 100.0; + video_header_.format = static_cast( + static_count_ % ImageFormat::Format_ARRAYSIZE); + MutableInputs()->Index(0).header = Adopt(new VideoHeader(video_header_)); + ++static_count_; + } + + void CheckOutputTimestamps( + const std::vector& expected_frames, + const std::vector& expected_timestamps) const { + EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size()); + EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size()); + int count = 0; + for (const Packet& packet : Outputs().Index(0).packets) { + EXPECT_EQ(Timestamp(expected_timestamps[count]), packet.Timestamp()); + const std::string& packet_contents = packet.Get(); + EXPECT_EQ(std::string(absl::StrCat("Frame #", expected_frames[count])), + packet_contents); + ++count; + } + } + + void CheckVideoHeader(const double expected_frame_rate) const { + ASSERT_FALSE(Outputs().Index(0).header.IsEmpty()); + const VideoHeader& header = Outputs().Index(0).header.Get(); + const double frame_rate = header.frame_rate; + + EXPECT_EQ(video_header_.width, header.width); + EXPECT_EQ(video_header_.height, header.height); + EXPECT_DOUBLE_EQ(expected_frame_rate, frame_rate); + 
EXPECT_FLOAT_EQ(video_header_.duration, header.duration); + EXPECT_EQ(video_header_.format, header.format); + } + + private: + VideoHeader video_header_; + static int static_count_; +}; + +int SimpleRunner::static_count_ = 0; + +TEST(PacketResamplerCalculatorTest, NoPacketsInStream) { + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + } +} + +TEST(PacketResamplerCalculatorTest, SinglePacketInStream) { + // Stream with 1 packet / 1 period. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0}, {0}); + } + + // Stream with 1 packet / 1 period (0 < packet timestamp < first limit). + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({1000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({1000}, {1000}); + } + + // Stream with 1 packet / 1 period (packet timestamp > first limit). + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({16668}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({16668}, {16668}); + } +} + +TEST(PacketResamplerCalculatorTest, TwoPacketsInStream) { + // Stream with 2 packets / 1 period. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 16666}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0}, {0}); + } + + // Stream with 2 packets / 2 periods (left extreme for second period). 
+ { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 16667}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 16667}, {0, 33333}); + } + + // Stream with 2 packets / 2 periods (right extreme for second period). + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 49999}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 49999}, {0, 33333}); + } + + // Stream with 2 packets / 3 periods (filling 1 in the middle). + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 50000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 0, 50000}, {0, 33333, 66667}); + } + + // Stream with 2 packets / 4 periods (filling 2 in the middle). + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({2000, 118666}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({2000, 2000, 2000, 118666}, + {2000, 35333, 68667, 102000}); + } +} + +TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepoints) { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 33333, 66667, 100000, 133333, 166667, 200000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps( + {0, 33333, 66667, 100000, 133333, 166667, 200000}, + {0, 33333, 66667, 100000, 133333, 166667, 200000}); +} + +// When there are several candidates for a period, the one closer to the center +// should be sent to the output. 
+TEST(PacketResamplerCalculatorTest, MultiplePacketsForPeriods) { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 16666, 16667, 20000, 33300, 49999, 50000, 66600}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 33300, 66600}, {0, 33333, 66667}); +} + +// When a period must be filled, we use the latest packet received (not +// necessarily the same as the one stored for the best in the previous period). +TEST(PacketResamplerCalculatorTest, FillPeriodsWithLatestPacket) { + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 5000, 16666, 83334}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 16666, 16666, 83334}, + {0, 33333, 66667, 100000}); + } + + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 16666, 16667, 25000, 33000, 35000, 135000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 33000, 35000, 35000, 135000}, + {0, 33333, 66667, 100000, 133333}); + } + + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({0, 15000, 32000, 49999, 150000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 32000, 49999, 49999, 49999, 150000}, + {0, 33333, 66667, 100000, 133333, 166667}); + } +} + +TEST(PacketResamplerCalculatorTest, SuperHighFrameRate) { + // frame rate == 500000 (a packet will have to be sent every 2 ticks). + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:500000}"); + runner.SetInput({0, 10, 13}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 0, 0, 0, 0, 10, 10, 13}, + {0, 2, 4, 6, 8, 10, 12, 14}); + } + + // frame rate == 1000000 (a packet will have to be sent in each tick). 
+ { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:1000000}"); + runner.SetInput({0, 10, 13}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps( + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 13}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}); + } +} + +TEST(PacketResamplerCalculatorTest, NegativeTimestampTest) { + // Stream with negative timestamps / 1 period. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({-200, -20, 16466}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-200}, {-200}); + } + + // Stream with negative timestamps / 2 periods. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({-200, -20, 16467}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-200, 16467}, {-200, 33133}); + } + + // Stream with negative timestamps and filling an empty period. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({-500, 66667}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-500, -500, 66667}, {-500, 32833, 66167}); + } + + // Stream with negative timestamps and initial packet < -period. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({-50000, -33334, 33334}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-50000, -33334, -33334, 33334}, + {-50000, -16667, 16667, 50000}); + } +} + +TEST(PacketResamplerCalculatorTest, ExactFramesPerSecond) { + // Using frame_rate=50, that makes a period of 20000 microsends (exact). 
+ { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:50}"); + runner.SetInput({0, 9999, 29999}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 29999}, {0, 20000}); + } + + // Test filling empty periods. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:50}"); + runner.SetInput({0, 10000, 50000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 10000, 10000, 50000}, + {0, 20000, 40000, 60000}); + } +} + +TEST(PacketResamplerCalculatorTest, FrameRateTest) { + // Test changing Frame Rate to the same initial value. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:50, output_header:UPDATE_VIDEO_HEADER}"); + runner.SetInput({0, 10000, 30000, 50000, 60000}); + runner.SetVideoHeader(50.0); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 10000, 30000, 60000}, + {0, 20000, 40000, 60000}); + runner.CheckVideoHeader(50.0); + } + + // Test changing Frame Rate to new value. + { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:50, output_header:UPDATE_VIDEO_HEADER}"); + runner.SetInput({0, 5000, 10010, 15001, 19990}); + runner.SetVideoHeader(200.0); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 19990}, {0, 20000}); + runner.CheckVideoHeader(50.0); + } + + // Test that the frame rate is not changing if update_video_header = false. 
+ { + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:50, output_header:PASS_HEADER}"); + runner.SetInput({0, 5000, 10010, 15001, 19990}); + runner.SetVideoHeader(200.0); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({0, 19990}, {0, 20000}); + runner.CheckVideoHeader(200.0); + } +} + +TEST(PacketResamplerCalculatorTest, SetVideoHeader) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "PacketResamplerCalculator" + input_stream: "DATA:in_data" + input_stream: "VIDEO_HEADER:in_video_header" + output_stream: "DATA:out_data" + output_stream: "VIDEO_HEADER:out_video_header" + options { + [mediapipe.PacketResamplerCalculatorOptions.ext] { frame_rate: 50.0 } + } + )")); + + for (const int64 ts : {0, 5000, 10010, 15001, 19990}) { + runner.MutableInputs()->Tag("DATA").packets.push_back( + Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); + } + VideoHeader video_header_in; + video_header_in.width = 10; + video_header_in.height = 100; + video_header_in.frame_rate = 1.0; + video_header_in.duration = 1.0; + video_header_in.format = ImageFormat::SRGB; + runner.MutableInputs() + ->Tag("VIDEO_HEADER") + .packets.push_back( + Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream())); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size()); + EXPECT_EQ(Timestamp::PreStream(), + runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp()); + const VideoHeader& video_header_out = + runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get(); + EXPECT_EQ(video_header_in.width, video_header_out.width); + EXPECT_EQ(video_header_in.height, video_header_out.height); + EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate); + EXPECT_FLOAT_EQ(video_header_in.duration, video_header_out.duration); + EXPECT_EQ(video_header_in.format, video_header_out.format); +} + +TEST(PacketResamplerCalculatorTest, FlushLastPacketWithoutRound) { + 
SimpleRunner runner(R"( + [mediapipe.PacketResamplerCalculatorOptions.ext] { + frame_rate: 1 + })"); + runner.SetInput({0, 333333, 666667, 1000000, 1333333}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + // 1333333 is not emitted as 2000000, because it does not round to 2000000. + runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000}); +} + +TEST(PacketResamplerCalculatorTest, FlushLastPacketWithRound) { + SimpleRunner runner(R"( + [mediapipe.PacketResamplerCalculatorOptions.ext] { + frame_rate: 1 + })"); + runner.SetInput({0, 333333, 666667, 1000000, 1333333, 1666667}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + // 1666667 is emitted as 2000000, because it rounds to 2000000. + runner.CheckOutputTimestamps({0, 1000000, 1666667}, {0, 1000000, 2000000}); +} + +TEST(PacketResamplerCalculatorTest, DoNotFlushLastPacketWithoutRound) { + SimpleRunner runner(R"( + [mediapipe.PacketResamplerCalculatorOptions.ext] { + frame_rate: 1 + flush_last_packet: false + })"); + runner.SetInput({0, 333333, 666667, 1000000, 1333333}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + // 1333333 is not emitted no matter what; see FlushLastPacketWithoutRound. + runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000}); +} + +TEST(PacketResamplerCalculatorTest, DoNotFlushLastPacketWithRound) { + SimpleRunner runner(R"( + [mediapipe.PacketResamplerCalculatorOptions.ext] { + frame_rate: 1 + flush_last_packet: false + })"); + runner.SetInput({0, 333333, 666667, 1000000, 1333333, 1666667}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + // 1666667 is not emitted due to flush_last_packet: false. + runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000}); +} + +// When base_timestamp is specified, output timestamps are aligned with it. +TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepointsAligned) { + { + // Without base_timestamp, outputs are aligned with the first input + // timestamp, (33333 - 222). 
+ SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({33111, 66667, 100000, 133333, 166667, 200000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({33111, 66667, 100000, 133333, 166667, 200000}, + {33111, 66444, 99778, 133111, 166444, 199778}); + } + { + // With base_timestamp, outputs are aligned with base_timestamp, 0. + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:0}"); + runner.SetInput({33111, 66667, 100000, 133333, 166667, 200000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps( + {33111, 66667, 100000, 133333, 166667, 200000}, + {33333, 66666, 100000, 133333, 166666, 200000}); + } +} + +// When base_timestamp is specified, output timestamps are aligned with it. +TEST(PacketResamplerCalculatorTest, MultiplePacketsForPeriodsAligned) { + { + // Without base_timestamp, outputs are aligned with the first input, -222. + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({-222, 16666, 16667, 20000, 33300, 49999, 50000, 66600}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-222, 33300, 66600}, {-222, 33111, 66445}); + } + { + // With base_timestamp, outputs are aligned with base_timestamp, 900011. + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:900011}"); + runner.SetInput({-222, 16666, 16667, 20000, 33300, 49999, 50000, 66600}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-222, 33300, 66600}, {11, 33344, 66678}); + } + { + // With base_timestamp, outputs still approximate input timestamps, + // while aligned to base_timestamp, 11. 
+ SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:11}"); + runner.SetInput( + {899888, 916666, 916667, 920000, 933300, 949999, 950000, 966600}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({899888, 933300, 966600}, + {900011, 933344, 966678}); + } +} + +// When a period must be filled, we use the latest packet received. +// When base_timestamp is specified, output timestamps are aligned with it. +TEST(PacketResamplerCalculatorTest, FillPeriodsWithLatestPacketAligned) { + { + // Without base_timestamp, outputs are aligned with the first input, -222. + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30}"); + runner.SetInput({-222, 15000, 32000, 49999, 150000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000}, + {-222, 33111, 66445, 99778, 133111, 166445}); + } + { + // With base_timestamp, outputs are aligned with base_timestamp, 0. + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:0}"); + runner.SetInput({-222, 15000, 32000, 49999, 150000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000}, + {0, 33333, 66667, 100000, 133333, 166667}); + } +} + +// When base_timestamp is specified, output timestamps are aligned with it. +// The first packet is included, because we assume that the input includes the +// whole first sampling interval. +TEST(PacketResamplerCalculatorTest, FirstInputAfterMiddlepointAligned) { + { + // Packet 100020 is omitted from the output sequence because + // packet 99990 is closer to the period midpoint. 
+ SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:0}"); + runner.SetInput({66667, 100020, 133333, 166667}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({66667, 100020, 133333, 166667}, + {66667, 100000, 133334, 166667}); + } + { + // If we seek to packet 100020, packet 100020 is included in + // the output sequence, because we assume that the input includes the + // whole first sampling interval. + // + // We assume that the input includes whole sampling intervals + // in order to produce "reproducible timestamps", which are timestamps + // from the series of timestamps starting at 0. + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:0}"); + runner.SetInput({100020, 133333, 166667}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({100020, 133333, 166667}, + {100000, 133333, 166667}); + } +} + +TEST(PacketResamplerCalculatorTest, OutputTimestampRangeAligned) { + { + // With base_timestamp, outputs are aligned with base_timestamp, 0. + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:0}"); + runner.SetInput({-222, 15000, 32000, 49999, 150000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000}, + {0, 33333, 66667, 100000, 133333, 166667}); + } + { + // With start_time, end_time, outputs are filtered. + SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:0 " + "start_time:40000 " + "end_time:160000}"); + runner.SetInput({-222, 15000, 32000, 49999, 150000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({49999, 49999, 49999}, + {66667, 100000, 133333}); + } + { + // With start_time, end_time, round_limits, outputs are filtered, + // rounding to the nearest limit. 
+ SimpleRunner runner( + "[mediapipe.PacketResamplerCalculatorOptions.ext]: " + "{frame_rate:30 " + "base_timestamp:0 " + "start_time:40000 " + "end_time:160000 " + "round_limits:true}"); + runner.SetInput({-222, 15000, 32000, 49999, 150000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + runner.CheckOutputTimestamps({32000, 49999, 49999, 49999, 150000}, + {33333, 66667, 100000, 133333, 166667}); + } +} + +TEST(PacketResamplerCalculatorTest, OptionsSidePacket) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "PacketResamplerCalculator" + input_side_packet: "OPTIONS:options" + input_stream: "input" + output_stream: "output" + options { + [mediapipe.PacketResamplerCalculatorOptions.ext] { + frame_rate: 60 + base_timestamp: 0 + } + })"); + + { + SimpleRunner runner(node_config); + auto options = + new CalculatorOptions(ParseTextProtoOrDie( + R"( + [mediapipe.PacketResamplerCalculatorOptions.ext] { + frame_rate: 30 + })")); + runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options); + runner.SetInput({-222, 15000, 32000, 49999, 150000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + EXPECT_EQ(6, runner.Outputs().Index(0).packets.size()); + } + { + SimpleRunner runner(node_config); + + auto options = + new CalculatorOptions(ParseTextProtoOrDie(R"( + merge_fields: false + [mediapipe.PacketResamplerCalculatorOptions.ext] { + frame_rate: 30 + base_timestamp: 0 + })")); + runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options); + + runner.SetInput({-222, 15000, 32000, 49999, 150000}); + MEDIAPIPE_ASSERT_OK(runner.Run()); + EXPECT_EQ(6, runner.Outputs().Index(0).packets.size()); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/pass_through_calculator.cc b/mediapipe/calculators/core/pass_through_calculator.cc new file mode 100644 index 000000000..d4e648037 --- /dev/null +++ b/mediapipe/calculators/core/pass_through_calculator.cc @@ -0,0 +1,98 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/canonical_errors.h" + +namespace mediapipe { + +// A Calculator that simply passes its input Packets and header through, +// unchanged. The inputs may be specified by tag or index. The outputs +// must match the inputs exactly. Any number of input side packets may +// also be specified. If output side packets are specified, they must +// match the input side packets exactly and the Calculator passes its +// input side packets through, unchanged. Otherwise, the input side +// packets will be ignored (allowing PassThroughCalculator to be used to +// test internal behavior). Any options may be specified and will be +// ignored. 
+class PassThroughCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + if (!cc->Inputs().TagMap()->SameAs(*cc->Outputs().TagMap())) { + return ::mediapipe::InvalidArgumentError( + "Input and output streams to PassThroughCalculator must use " + "matching tags and indexes."); + } + for (CollectionItemId id = cc->Inputs().BeginId(); + id < cc->Inputs().EndId(); ++id) { + cc->Inputs().Get(id).SetAny(); + cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Get(id)); + } + for (CollectionItemId id = cc->InputSidePackets().BeginId(); + id < cc->InputSidePackets().EndId(); ++id) { + cc->InputSidePackets().Get(id).SetAny(); + } + if (cc->OutputSidePackets().NumEntries() != 0) { + if (!cc->InputSidePackets().TagMap()->SameAs( + *cc->OutputSidePackets().TagMap())) { + return ::mediapipe::InvalidArgumentError( + "Input and output side packets to PassThroughCalculator must use " + "matching tags and indexes."); + } + for (CollectionItemId id = cc->InputSidePackets().BeginId(); + id < cc->InputSidePackets().EndId(); ++id) { + cc->OutputSidePackets().Get(id).SetSameAs( + &cc->InputSidePackets().Get(id)); + } + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + for (CollectionItemId id = cc->Inputs().BeginId(); + id < cc->Inputs().EndId(); ++id) { + if (!cc->Inputs().Get(id).Header().IsEmpty()) { + cc->Outputs().Get(id).SetHeader(cc->Inputs().Get(id).Header()); + } + } + if (cc->OutputSidePackets().NumEntries() != 0) { + for (CollectionItemId id = cc->InputSidePackets().BeginId(); + id < cc->InputSidePackets().EndId(); ++id) { + cc->OutputSidePackets().Get(id).Set(cc->InputSidePackets().Get(id)); + } + } + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->GetCounter("PassThrough")->Increment(); + if (cc->Inputs().NumEntries() == 0) { + return tool::StatusStop(); + } + for 
(CollectionItemId id = cc->Inputs().BeginId(); + id < cc->Inputs().EndId(); ++id) { + if (!cc->Inputs().Get(id).IsEmpty()) { + VLOG(3) << "Passing " << cc->Inputs().Get(id).Name() << " to " + << cc->Outputs().Get(id).Name() << " at " + << cc->InputTimestamp().DebugString(); + cc->Outputs().Get(id).AddPacket(cc->Inputs().Get(id).Value()); + } + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(PassThroughCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/previous_loopback_calculator.cc b/mediapipe/calculators/core/previous_loopback_calculator.cc new file mode 100644 index 000000000..6b23a0e70 --- /dev/null +++ b/mediapipe/calculators/core/previous_loopback_calculator.cc @@ -0,0 +1,118 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// PreviousLoopbackCalculator is useful when a graph needs to process an input +// together with some previous output. +// +// For the first packet that arrives on the MAIN input, the timestamp bound is +// advanced on the output. Downstream calculators will see this as an empty +// packet. This way they are not kept waiting for the previous output, which +// for the first iteration does not exist. 
+// +// Thereafter, each packet received on MAIN is matched with a packet received +// on LOOP; the LOOP packet's timestamp is changed to that of the MAIN packet, +// and it is output on PREV_LOOP. +// +// Example config: +// node { +// calculator: "PreviousLoopbackCalculator" +// input_stream: "MAIN:input" +// input_stream: "LOOP:output" +// input_stream_info: { tag_index: 'LOOP' back_edge: true } +// output_stream: "PREV_LOOP:prev_output" +// } +// node { +// calculator: "FaceTracker" +// input_stream: "VIDEO:input" +// input_stream: "PREV_TRACK:prev_output" +// output_stream: "TRACK:output" +// } +class PreviousLoopbackCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Get("MAIN", 0).SetAny(); + cc->Inputs().Get("LOOP", 0).SetAny(); + cc->Outputs().Get("PREV_LOOP", 0).SetSameAs(&(cc->Inputs().Get("LOOP", 0))); + // TODO: an optional PREV_TIMESTAMP output could be added to + // carry the original timestamp of the packet on PREV_LOOP. + cc->SetInputStreamHandler("ImmediateInputStreamHandler"); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + main_id_ = cc->Inputs().GetId("MAIN", 0); + loop_id_ = cc->Inputs().GetId("LOOP", 0); + loop_out_id_ = cc->Outputs().GetId("PREV_LOOP", 0); + cc->Outputs() + .Get(loop_out_id_) + .SetHeader(cc->Inputs().Get(loop_id_).Header()); + + // Use an empty packet for the first round, since there is no previous + // output. 
+ loopback_packets_.push_back({}); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + Packet& main_packet = cc->Inputs().Get(main_id_).Value(); + if (!main_packet.IsEmpty()) { + main_ts_.push_back(main_packet.Timestamp()); + } + Packet& loopback_packet = cc->Inputs().Get(loop_id_).Value(); + if (!loopback_packet.IsEmpty()) { + loopback_packets_.push_back(loopback_packet); + while (!main_ts_.empty() && + main_ts_.front() <= loopback_packets_.front().Timestamp()) { + main_ts_.pop_front(); + } + } + + while (!main_ts_.empty() && !loopback_packets_.empty()) { + Timestamp main_timestamp = main_ts_.front(); + main_ts_.pop_front(); + Packet previous_loopback = loopback_packets_.front().At(main_timestamp); + loopback_packets_.pop_front(); + + if (previous_loopback.IsEmpty()) { + // TODO: SetCompleteTimestampBound would be more useful. + cc->Outputs() + .Get(loop_out_id_) + .SetNextTimestampBound(main_timestamp + 1); + } else { + cc->Outputs().Get(loop_out_id_).AddPacket(std::move(previous_loopback)); + } + } + return ::mediapipe::OkStatus(); + } + + private: + CollectionItemId main_id_; + CollectionItemId loop_id_; + CollectionItemId loop_out_id_; + + std::deque<Timestamp> main_ts_; + std::deque<Packet> loopback_packets_; +}; +REGISTER_CALCULATOR(PreviousLoopbackCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/previous_loopback_calculator_test.cc b/mediapipe/calculators/core/previous_loopback_calculator_test.cc new file mode 100644 index 000000000..1dc359ba1 --- /dev/null +++ b/mediapipe/calculators/core/previous_loopback_calculator_test.cc @@ -0,0 +1,111 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/sink.h" + +namespace mediapipe { + +namespace { + +// Returns the timestamp values for a vector of Packets. +// TODO: put this kind of test util in a common place. +std::vector<int64> TimestampValues(const std::vector<Packet>& packets) { + std::vector<int64> result; + for (const Packet& packet : packets) { + result.push_back(packet.Timestamp().Value()); + } + return result; +} + +TEST(PreviousLoopbackCalculator, CorrectTimestamps) { + std::vector<Packet> in_prev; + CalculatorGraphConfig graph_config_ = + ParseTextProtoOrDie<CalculatorGraphConfig>(R"( + input_stream: 'in' + node { + calculator: 'PreviousLoopbackCalculator' + input_stream: 'MAIN:in' + input_stream: 'LOOP:out' + input_stream_info: { tag_index: 'LOOP' back_edge: true } + output_stream: 'PREV_LOOP:previous' + } + # This calculator synchronizes its inputs as normal, so it is used + # to check that both "in" and "previous" are ready.
+ node { + calculator: 'PassThroughCalculator' + input_stream: 'in' + input_stream: 'previous' + output_stream: 'out' + output_stream: 'previous2' + } + node { + calculator: 'MakePairCalculator' + input_stream: 'out' + input_stream: 'previous2' + output_stream: 'pair' + } + )"); + tool::AddVectorSink("pair", &graph_config_, &in_prev); + + CalculatorGraph graph_; + MEDIAPIPE_ASSERT_OK(graph_.Initialize(graph_config_, {})); + MEDIAPIPE_ASSERT_OK(graph_.StartRun({})); + + auto send_packet = [&graph_](const std::string& input_name, int n) { + MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream( + input_name, MakePacket<int>(n).At(Timestamp(n)))); + }; + auto pair_values = [](const Packet& packet) { + auto pair = packet.Get<std::pair<Packet, Packet>>(); + int first = pair.first.IsEmpty() ? -1 : pair.first.Get<int>(); + int second = pair.second.IsEmpty() ? -1 : pair.second.Get<int>(); + return std::make_pair(first, second); + }; + + send_packet("in", 1); + MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1})); + EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(1, -1)); + + send_packet("in", 5); + MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 5})); + EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(5, 1)); + + send_packet("in", 15); + MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle()); + EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 5, 15})); + EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(15, 5)); + + MEDIAPIPE_EXPECT_OK(graph_.CloseAllInputStreams()); + MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone()); +} + +} // anonymous namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc new file mode 100644 index 000000000..ea857dc43 --- /dev/null +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -0,0 +1,199 @@ +// Copyright 2019 The MediaPipe Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/header_util.h" + +namespace mediapipe { + +// RealTimeFlowLimiterCalculator is used to limit the number of pipelined +// processing operations in a section of the graph. +// +// Typical topology: +// +// in ->-[RTFLC]-[foo]-...-[bar]-+->- out +// ^____________________| +// FINISHED +// +// By connecting the output of the graph section to this calculator's FINISHED +// input with a backwards edge, this allows RTFLC to keep track of how many +// timestamps are currently being processed. +// +// The limit defaults to 1, and can be overridden with the MAX_IN_FLIGHT side +// packet. +// +// As long as the number of timestamps being processed ("in flight") is below +// the limit, RTFLC allows input to pass through. When the limit is reached, +// RTFLC starts dropping input packets, keeping only the most recent. When the +// processing count decreases again, as signaled by the receipt of a packet on +// FINISHED, RTFLC allows packets to flow again, releasing the most recently +// queued packet, if any. +// +// If there are multiple input streams, packet dropping is synchronized. 
+//
+// IMPORTANT: for each timestamp where RTFLC forwards a packet (or a set of
+// packets, if using multiple data streams), a packet must eventually arrive on
+// the FINISHED stream. Dropping packets in the section between RTFLC and
+// FINISHED will make the in-flight count incorrect.
+//
+// TODO: Remove this comment when graph-level ISH has been removed.
+// NOTE: this calculator should always use the ImmediateInputStreamHandler and
+// uses it by default. However, if the graph specifies a graph-level
+// InputStreamHandler, to override that setting, the InputStreamHandler must
+// be explicitly specified as shown below.
+//
+// Example config:
+// node {
+//   calculator: "RealTimeFlowLimiterCalculator"
+//   input_stream: "raw_frames"
+//   input_stream: "FINISHED:finished"
+//   input_stream_info: {
+//     tag_index: 'FINISHED'
+//     back_edge: true
+//   }
+//   input_stream_handler {
+//     input_stream_handler: 'ImmediateInputStreamHandler'
+//   }
+//   output_stream: "gated_frames"
+// }
+class RealTimeFlowLimiterCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    int num_data_streams = cc->Inputs().NumEntries("");
+    RET_CHECK_GE(num_data_streams, 1);
+    RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams)
+        << "Output streams must correspond input streams except for the "
+           "finish indicator input stream.";
+    for (int i = 0; i < num_data_streams; ++i) {
+      cc->Inputs().Get("", i).SetAny();
+      cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
+    }
+    cc->Inputs().Get("FINISHED", 0).SetAny();
+    if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
+      // Side packet carries the in-flight limit; Get<int>() below fixes the
+      // type, so the contract declares Set<int>().
+      cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>();
+    }
+    if (cc->Outputs().HasTag("ALLOW")) {
+      // ALLOW emits bool packets (see MakePacket<bool>(Allow()) in Process).
+      cc->Outputs().Tag("ALLOW").Set<bool>();
+    }
+
+    cc->SetInputStreamHandler("ImmediateInputStreamHandler");
+
+    return ::mediapipe::OkStatus();
+  }
+
+  ::mediapipe::Status Open(CalculatorContext* cc) final {
+    finished_id_ = cc->Inputs().GetId("FINISHED", 0);
+    max_in_flight_ = 1;
+    if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
+      max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>();
+    }
+    RET_CHECK_GE(max_in_flight_, 1);
+    num_in_flight_ = 0;
+
+    allowed_id_ = cc->Outputs().GetId("ALLOW", 0);
+    allow_ctr_ts_ = Timestamp(0);
+
+    num_data_streams_ = cc->Inputs().NumEntries("");
+    data_stream_bound_ts_.resize(num_data_streams_);
+    RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
+    return ::mediapipe::OkStatus();
+  }
+
+  bool Allow() { return num_in_flight_ < max_in_flight_; }
+
+  ::mediapipe::Status Process(CalculatorContext* cc) final {
+    bool old_allow = Allow();
+    Timestamp lowest_incomplete_ts = Timestamp::Done();
+
+    // Process FINISHED stream.
+    if (!cc->Inputs().Get(finished_id_).Value().IsEmpty()) {
+      RET_CHECK_GT(num_in_flight_, 0)
+          << "Received a FINISHED packet, but we had none in flight.";
+      --num_in_flight_;
+    }
+
+    // Process data streams.
+    for (int i = 0; i < num_data_streams_; ++i) {
+      auto& stream = cc->Inputs().Get("", i);
+      auto& out = cc->Outputs().Get("", i);
+      Packet& packet = stream.Value();
+      auto ts = packet.Timestamp();
+      if (ts.IsRangeValue() && data_stream_bound_ts_[i] <= ts) {
+        data_stream_bound_ts_[i] = ts + 1;
+        // Note: it's ok to update the output bound here, before sending the
+        // packet, because updates are batched during the Process function.
+        out.SetNextTimestampBound(data_stream_bound_ts_[i]);
+      }
+      lowest_incomplete_ts =
+          std::min(lowest_incomplete_ts, data_stream_bound_ts_[i]);
+
+      if (packet.IsEmpty()) {
+        // If the input stream is closed, close the corresponding output.
+        if (stream.IsDone() && !out.IsClosed()) {
+          out.Close();
+        }
+        // TODO: if the packet is empty, the ts is unset, and we
+        // cannot read the timestamp bound, even though we'd like to propagate
+        // it.
+      } else if (mediapipe::ContainsKey(pending_ts_, ts)) {
+        // If we have already sent this timestamp (on another stream), send it
+        // on this stream too.
+        out.AddPacket(std::move(packet));
+      } else if (Allow() && (ts > last_dropped_ts_)) {
+        // If the in-flight is under the limit, and if we have not already
+        // dropped this or a later timestamp on another stream, then send
+        // the packet and add an in-flight timestamp.
+        out.AddPacket(std::move(packet));
+        pending_ts_.insert(ts);
+        ++num_in_flight_;
+      } else {
+        // Otherwise, we'll drop the packet.
+        last_dropped_ts_ = std::max(last_dropped_ts_, ts);
+      }
+    }
+
+    // Remove old pending_ts_ entries.
+    auto it = std::lower_bound(pending_ts_.begin(), pending_ts_.end(),
+                               lowest_incomplete_ts);
+    pending_ts_.erase(pending_ts_.begin(), it);
+
+    // Update ALLOW signal.
+    if ((old_allow != Allow()) && allowed_id_.IsValid()) {
+      cc->Outputs()
+          .Get(allowed_id_)
+          .AddPacket(MakePacket<bool>(Allow()).At(++allow_ctr_ts_));
+    }
+    return ::mediapipe::OkStatus();
+  }
+
+ private:
+  // Timestamps currently in flight (forwarded but not yet FINISHED).
+  std::set<Timestamp> pending_ts_;
+  Timestamp last_dropped_ts_;
+  int num_data_streams_;
+  int num_in_flight_;
+  int max_in_flight_;
+  CollectionItemId finished_id_;
+  CollectionItemId allowed_id_;
+  Timestamp allow_ctr_ts_;
+  // Per data stream: next timestamp bound to propagate downstream.
+  std::vector<Timestamp> data_stream_bound_ts_;
+};
+REGISTER_CALCULATOR(RealTimeFlowLimiterCalculator);
+
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc
new file mode 100644
index 000000000..8c386b8ea
--- /dev/null
+++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc
@@ -0,0 +1,496 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <atomic>
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/integral_types.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/timestamp.h"
+#include "mediapipe/framework/tool/sink.h"
+
+namespace mediapipe {
+
+namespace {
+// A simple Semaphore for synchronizing test threads.
+class AtomicSemaphore {
+ public:
+  AtomicSemaphore(int64_t supply) : supply_(supply) {}
+  void Acquire(int64_t amount) {
+    while (supply_.fetch_sub(amount) - amount < 0) {
+      Release(amount);
+    }
+  }
+  void Release(int64_t amount) { supply_.fetch_add(amount); }
+
+ private:
+  std::atomic<int64_t> supply_;
+};
+
+// Returns the timestamp values for a vector of Packets.
+std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
+  std::vector<int64> result;
+  for (const Packet& packet : packets) {
+    result.push_back(packet.Timestamp().Value());
+  }
+  return result;
+}
+
+// Returns the packet values for a vector of Packets.
+template <typename T>
+std::vector<T> PacketValues(const std::vector<Packet>& packets) {
+  std::vector<T> result;
+  for (const Packet& packet : packets) {
+    result.push_back(packet.Get<T>());
+  }
+  return result;
+}
+
+constexpr int kNumImageFrames = 5;
+constexpr int kNumFinished = 3;
+CalculatorGraphConfig::Node GetDefaultNode() {
+  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
+    calculator: "RealTimeFlowLimiterCalculator"
+    input_stream: "raw_frames"
+    input_stream: "FINISHED:finished"
+    input_stream_info: { tag_index: "FINISHED" back_edge: true }
+    output_stream: "gated_frames"
+  )");
+}
+
+// Simple test to make sure that the RealTimeFlowLimiterCalculator outputs
+// just one packet when MAX_IN_FLIGHT is 1.
+TEST(RealTimeFlowLimiterCalculator, OneOutputTest) {
+  // Setup the calculator runner and add only ImageFrame packets.
+  CalculatorRunner runner(GetDefaultNode());
+  for (int i = 0; i < kNumImageFrames; ++i) {
+    Timestamp timestamp = Timestamp(i * Timestamp::kTimestampUnitsPerSecond);
+    runner.MutableInputs()->Index(0).packets.push_back(
+        MakePacket<ImageFrame>().At(timestamp));
+  }
+
+  // Run the calculator.
+  MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
+  const std::vector<Packet>& frame_output_packets =
+      runner.Outputs().Index(0).packets;
+
+  EXPECT_EQ(frame_output_packets.size(), 1);
+}
+
+// Simple test to make sure that the RealTimeFlowLimiterCalculator waits for all
+// input streams to have at least one packet available before publishing.
+TEST(RealTimeFlowLimiterCalculator, BasicTest) {
+  // Setup the calculator runner and add both ImageFrame and finish packets.
+  CalculatorRunner runner(GetDefaultNode());
+  for (int i = 0; i < kNumImageFrames; ++i) {
+    Timestamp timestamp = Timestamp(i * Timestamp::kTimestampUnitsPerSecond);
+    runner.MutableInputs()->Index(0).packets.push_back(
+        MakePacket<ImageFrame>().At(timestamp));
+  }
+  for (int i = 0; i < kNumFinished; ++i) {
+    Timestamp timestamp =
+        Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond);
+    runner.MutableInputs()
+        ->Tag("FINISHED")
+        .packets.push_back(MakePacket<bool>(true).At(timestamp));
+  }
+
+  // Run the calculator.
+  MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
+  const std::vector<Packet>& frame_output_packets =
+      runner.Outputs().Index(0).packets;
+
+  // Only outputs packets if both input streams are available.
+  int expected_num_packets = std::min(kNumImageFrames, kNumFinished + 1);
+  EXPECT_EQ(frame_output_packets.size(), expected_num_packets);
+}
+
+// A Calculator::Process callback function.
+typedef std::function<::mediapipe::Status(const InputStreamShardSet&,
+                                          OutputStreamShardSet*)>
+    ProcessFunction;
+
+// A testing callback function that passes through all packets.
+::mediapipe::Status PassthroughFunction(const InputStreamShardSet& inputs,
+                                        OutputStreamShardSet* outputs) {
+  for (int i = 0; i < inputs.NumEntries(); ++i) {
+    if (!inputs.Index(i).Value().IsEmpty()) {
+      outputs->Index(i).AddPacket(inputs.Index(i).Value());
+    }
+  }
+  return ::mediapipe::OkStatus();
+}
+
+// A Calculator that runs a testing callback function in Close.
+class CloseCallbackCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    for (CollectionItemId id = cc->Inputs().BeginId();
+         id < cc->Inputs().EndId(); ++id) {
+      cc->Inputs().Get(id).SetAny();
+    }
+    for (CollectionItemId id = cc->Outputs().BeginId();
+         id < cc->Outputs().EndId(); ++id) {
+      cc->Outputs().Get(id).SetAny();
+    }
+    // The side packet holds the Close() callback (see Close() below).
+    cc->InputSidePackets().Index(0).Set<std::function<::mediapipe::Status()>>();
+    return ::mediapipe::OkStatus();
+  }
+
+  ::mediapipe::Status Process(CalculatorContext* cc) override {
+    return PassthroughFunction(cc->Inputs(), &(cc->Outputs()));
+  }
+
+  ::mediapipe::Status Close(CalculatorContext* cc) override {
+    const auto& callback = cc->InputSidePackets()
+                               .Index(0)
+                               .Get<std::function<::mediapipe::Status()>>();
+    return callback();
+  }
+};
+REGISTER_CALCULATOR(CloseCallbackCalculator);
+
+// Tests demostrating an RealTimeFlowLimiterCalculator operating in a cyclic
+// graph.
+// TODO: clean up these tests.
+class RealTimeFlowLimiterCalculatorTest : public testing::Test {
+ public:
+  RealTimeFlowLimiterCalculatorTest()
+      : enter_semaphore_(0), exit_semaphore_(0) {}
+
+  void SetUp() override {
+    graph_config_ = InflightGraphConfig();
+    tool::AddVectorSink("out_1", &graph_config_, &out_1_packets_);
+    tool::AddVectorSink("out_2", &graph_config_, &out_2_packets_);
+  }
+
+  void InitializeGraph(int max_in_flight) {
+    ProcessFunction semaphore_0_func = [&](const InputStreamShardSet& inputs,
+                                           OutputStreamShardSet* outputs) {
+      enter_semaphore_.Release(1);
+      return PassthroughFunction(inputs, outputs);
+    };
+    ProcessFunction semaphore_1_func = [&](const InputStreamShardSet& inputs,
+                                           OutputStreamShardSet* outputs) {
+      exit_semaphore_.Acquire(1);
+      return PassthroughFunction(inputs, outputs);
+    };
+    std::function<::mediapipe::Status()> close_func = [this]() {
+      close_count_++;
+      return ::mediapipe::OkStatus();
+    };
+    MEDIAPIPE_ASSERT_OK(graph_.Initialize(
+        graph_config_, {
+                           {"max_in_flight", MakePacket<int>(max_in_flight)},
+                           {"callback_0", Adopt(new auto(semaphore_0_func))},
+                           {"callback_1", Adopt(new auto(semaphore_1_func))},
+                           {"callback_2", Adopt(new auto(close_func))},
+                       }));
+  }
+
+  // Adds a packet to a graph input stream.
+  void AddPacket(const std::string& input_name, int value) {
+    MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
+        input_name, MakePacket<int>(value).At(Timestamp(value))));
+  }
+
+  // A calculator graph starting with an RealTimeFlowLimiterCalculator and
+  // ending with a InFlightFinishCalculator.
+  // Back-edge "finished" limits processing to one frame in-flight.
+  // The two LambdaCalculators are used to keep certain packet sets in flight.
+  CalculatorGraphConfig InflightGraphConfig() {
+    return ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+      input_stream: 'in_1'
+      input_stream: 'in_2'
+      node {
+        calculator: 'RealTimeFlowLimiterCalculator'
+        input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
+        input_stream: 'in_1'
+        input_stream: 'in_2'
+        input_stream: 'FINISHED:out_1'
+        input_stream_info: { tag_index: 'FINISHED' back_edge: true }
+        output_stream: 'in_1_sampled'
+        output_stream: 'in_2_sampled'
+      }
+      node {
+        calculator: 'LambdaCalculator'
+        input_side_packet: 'callback_0'
+        input_stream: 'in_1_sampled'
+        input_stream: 'in_2_sampled'
+        output_stream: 'queue_1'
+        output_stream: 'queue_2'
+      }
+      node {
+        calculator: 'LambdaCalculator'
+        input_side_packet: 'callback_1'
+        input_stream: 'queue_1'
+        input_stream: 'queue_2'
+        output_stream: 'close_1'
+        output_stream: 'close_2'
+      }
+      node {
+        calculator: 'CloseCallbackCalculator'
+        input_side_packet: 'callback_2'
+        input_stream: 'close_1'
+        input_stream: 'close_2'
+        output_stream: 'out_1'
+        output_stream: 'out_2'
+      }
+    )");
+  }
+
+ protected:
+  CalculatorGraphConfig graph_config_;
+  CalculatorGraph graph_;
+  AtomicSemaphore enter_semaphore_;
+  AtomicSemaphore exit_semaphore_;
+  std::vector<Packet> out_1_packets_;
+  std::vector<Packet> out_2_packets_;
+  int close_count_ = 0;
+};
+
+// A test demonstrating an RealTimeFlowLimiterCalculator operating in a cyclic
+// graph.  This test shows that:
+//
+// (1) Timestamps are passed through unaltered.
+// (2) All output streams including the back_edge stream are closed when
+//     the first input stream is closed.
+//
+TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) {
+  InitializeGraph(1);
+  MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
+
+  auto send_packet = [this](const std::string& input_name, int64 n) {
+    MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
+        input_name, MakePacket<int64>(n).At(Timestamp(n))));
+  };
+
+  for (int i = 0; i < 10; i++) {
+    send_packet("in_1", i * 10);
+    // This next input should be dropped.
+    send_packet("in_1", i * 10 + 5);
+    MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+    send_packet("in_2", i * 10);
+    exit_semaphore_.Release(1);
+    MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  }
+  MEDIAPIPE_EXPECT_OK(graph_.CloseInputStream("in_1"));
+  MEDIAPIPE_EXPECT_OK(graph_.CloseInputStream("in_2"));
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+
+  // All output streams are closed and all output packets are delivered,
+  // with stream "in_1" and stream "in_2" closed.
+  EXPECT_EQ(10, out_1_packets_.size());
+  EXPECT_EQ(10, out_2_packets_.size());
+
+  // Timestamps have not been messed with.
+  EXPECT_EQ(PacketValues<int64>(out_1_packets_),
+            TimestampValues(out_1_packets_));
+  EXPECT_EQ(PacketValues<int64>(out_2_packets_),
+            TimestampValues(out_2_packets_));
+
+  // Extra inputs on in_1 have been dropped
+  EXPECT_EQ(TimestampValues(out_1_packets_),
+            (std::vector<int64>{0, 10, 20, 30, 40, 50, 60, 70, 80, 90}));
+  EXPECT_EQ(TimestampValues(out_1_packets_), TimestampValues(out_2_packets_));
+
+  // The closing of the stream has been propagated.
+  EXPECT_EQ(1, close_count_);
+}
+
+// A test demonstrating that all output streams are closed when all
+// input streams are closed after the last input packet has been processed.
+TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) {
+  InitializeGraph(1);
+  MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
+
+  exit_semaphore_.Release(10);
+  for (int i = 0; i < 10; i++) {
+    AddPacket("in_1", i);
+    MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+    AddPacket("in_2", i);
+    MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  }
+  MEDIAPIPE_EXPECT_OK(graph_.CloseAllInputStreams());
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+
+  EXPECT_EQ(TimestampValues(out_1_packets_), TimestampValues(out_2_packets_));
+  EXPECT_EQ(TimestampValues(out_1_packets_),
+            (std::vector<int64>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
+  EXPECT_EQ(1, close_count_);
+}
+
+TEST(RealTimeFlowLimiterCalculator, TwoStreams) {
+  std::vector<Packet> a_passed;
+  std::vector<Packet> b_passed;
+  CalculatorGraphConfig graph_config_ =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+        input_stream: 'in_a'
+        input_stream: 'in_b'
+        input_stream: 'finished'
+        node {
+          name: 'input_dropper'
+          calculator: 'RealTimeFlowLimiterCalculator'
+          input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
+          input_stream: 'in_a'
+          input_stream: 'in_b'
+          input_stream: 'FINISHED:finished'
+          input_stream_info: { tag_index: 'FINISHED' back_edge: true }
+          output_stream: 'in_a_sampled'
+          output_stream: 'in_b_sampled'
+          output_stream: 'ALLOW:allow'
+        }
+      )");
+  std::string allow_cb_name;
+  tool::AddVectorSink("in_a_sampled", &graph_config_, &a_passed);
+  tool::AddVectorSink("in_b_sampled", &graph_config_, &b_passed);
+  tool::AddCallbackCalculator("allow", &graph_config_, &allow_cb_name, true);
+
+  bool allow = true;
+  auto allow_cb = [&allow](const Packet& packet) {
+    allow = packet.Get<bool>();
+  };
+
+  CalculatorGraph graph_;
+  MEDIAPIPE_EXPECT_OK(graph_.Initialize(
+      graph_config_,
+      {
+          {"max_in_flight", MakePacket<int>(1)},
+          {allow_cb_name,
+           MakePacket<std::function<void(const Packet&)>>(allow_cb)},
+      }));
+
+  MEDIAPIPE_EXPECT_OK(graph_.StartRun({}));
+
+  auto send_packet = [&graph_](const std::string& input_name, int n) {
+    MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
+        input_name, MakePacket<int>(n).At(Timestamp(n))));
+  };
+  send_packet("in_a", 1);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(allow, false);
+  EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{}));
+
+  send_packet("in_a", 2);
+  send_packet("in_b", 1);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(allow, false);
+
+  send_packet("finished", 1);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(allow, true);
+
+  send_packet("in_b", 2);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(allow, true);
+
+  send_packet("in_b", 3);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1, 3}));
+  EXPECT_EQ(allow, false);
+
+  send_packet("in_b", 4);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
+  EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1, 3}));
+  EXPECT_EQ(allow, false);
+
+  send_packet("in_a", 3);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1, 3}));
+  EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1, 3}));
+  EXPECT_EQ(allow, false);
+
+  send_packet("finished", 3);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1, 3}));
+  EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1, 3}));
+  EXPECT_EQ(allow, true);
+
+  MEDIAPIPE_EXPECT_OK(graph_.CloseAllInputStreams());
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone());
+}
+
+TEST(RealTimeFlowLimiterCalculator, CanConsume) {
+  std::vector<Packet> in_sampled_packets_;
+  CalculatorGraphConfig graph_config_ =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+        input_stream: 'in'
+        input_stream: 'finished'
+        node {
+          name: 'input_dropper'
+          calculator: 'RealTimeFlowLimiterCalculator'
+          input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
+          input_stream: 'in'
+          input_stream: 'FINISHED:finished'
+          input_stream_info: { tag_index: 'FINISHED' back_edge: true }
+          output_stream: 'in_sampled'
+          output_stream: 'ALLOW:allow'
+        }
+      )");
+  std::string allow_cb_name;
+  tool::AddVectorSink("in_sampled", &graph_config_, &in_sampled_packets_);
+  tool::AddCallbackCalculator("allow", &graph_config_, &allow_cb_name, true);
+
+  bool allow = true;
+  auto allow_cb = [&allow](const Packet& packet) {
+    allow = packet.Get<bool>();
+  };
+
+  CalculatorGraph graph_;
+  MEDIAPIPE_EXPECT_OK(graph_.Initialize(
+      graph_config_,
+      {
+          {"max_in_flight", MakePacket<int>(1)},
+          {allow_cb_name,
+           MakePacket<std::function<void(const Packet&)>>(allow_cb)},
+      }));
+
+  MEDIAPIPE_EXPECT_OK(graph_.StartRun({}));
+
+  auto send_packet = [&graph_](const std::string& input_name, int n) {
+    MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
+        input_name, MakePacket<int>(n).At(Timestamp(n))));
+  };
+  send_packet("in", 1);
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
+  EXPECT_EQ(allow, false);
+  EXPECT_EQ(TimestampValues(in_sampled_packets_), (std::vector<int64>{1}));
+
+  MEDIAPIPE_EXPECT_OK(in_sampled_packets_[0].Consume<int>());
+
+  MEDIAPIPE_EXPECT_OK(graph_.CloseAllInputStreams());
+  MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone());
+}
+
+}  // anonymous namespace
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/core/round_robin_demux_calculator.cc b/mediapipe/calculators/core/round_robin_demux_calculator.cc
new file mode 100644
index 000000000..c84e08884
--- /dev/null
+++ b/mediapipe/calculators/core/round_robin_demux_calculator.cc
@@ -0,0 +1,120 @@
+// Copyright 2019 The MediaPipe Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +// Forwards the input packet to one of the n output streams "OUTPUT:0", +// "OUTPUT:1", ..., in round robin fashion. The index of the selected output +// stream is emitted to the output stream "SELECT". If not needed, the +// output stream "SELECT" may be omitted. +// +// Designed to run graph bottlenecks in parallel and thus reduce graph +// processing latency by parallelizing. 
+// +// A simple example config is: +// +// node { +// calculator: "RoundRobinDemuxCalculator" +// input_stream: "signal" +// output_stream: "OUTPUT:0:signal0" +// output_stream: "OUTPUT:1:signal1" +// output_stream: "SELECT:select" +// } +// +// node { +// calculator: "SlowCalculator" +// input_stream: "signal0" +// output_stream: "output0" +// } +// +// node { +// calculator: "SlowCalculator" +// input_stream: "signal1" +// output_stream: "output1" +// } +// +// node { +// calculator: "MuxCalculator" +// input_stream: "INPUT:0:output0" +// input_stream: "INPUT:1:output1" +// input_stream: "SELECT:select" +// output_stream: "OUTPUT:output" +// input_stream_handler { +// input_stream_handler: "MuxInputStreamHandler" +// } +// } +// +// which is essentially running the following configuration in parallel with a +// concurrency level of two: +// +// node { +// calculator: "SlowCalculator" +// input_stream: "signal" +// output_stream: "output" +// } +// +// If SlowCalculator has more than one output stream, the user can group the +// output with MakePairCalculator, MakeVectorCalculator, or a similar variant to +// use it with MuxCalculator and later unpack, or can create new variants of +// MuxCalculator/MuxInputStreamHandler. 
+class RoundRobinDemuxCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    RET_CHECK_EQ(cc->Inputs().NumEntries(), 1);
+    cc->Inputs().Index(0).SetAny();
+    if (cc->Outputs().HasTag("SELECT")) {
+      // SELECT emits the int index of the chosen output (see Process below).
+      cc->Outputs().Tag("SELECT").Set<int>();
+    }
+    for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
+         id < cc->Outputs().EndId("OUTPUT"); ++id) {
+      cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Index(0));
+    }
+    return ::mediapipe::OkStatus();
+  }
+
+  ::mediapipe::Status Open(CalculatorContext* cc) override {
+    select_output_ = cc->Outputs().GetId("SELECT", 0);
+    output_data_stream_index_ = 0;
+    output_data_stream_base_ = cc->Outputs().GetId("OUTPUT", 0);
+    num_output_data_streams_ = cc->Outputs().NumEntries("OUTPUT");
+    return ::mediapipe::OkStatus();
+  }
+
+  ::mediapipe::Status Process(CalculatorContext* cc) override {
+    cc->Outputs()
+        .Get(output_data_stream_base_ + output_data_stream_index_)
+        .AddPacket(cc->Inputs().Index(0).Value());
+    if (select_output_.IsValid()) {
+      cc->Outputs()
+          .Get(select_output_)
+          .Add(new int(output_data_stream_index_), cc->InputTimestamp());
+    }
+    output_data_stream_index_ =
+        (output_data_stream_index_ + 1) % num_output_data_streams_;
+    return ::mediapipe::OkStatus();
+  }
+
+ private:
+  CollectionItemId select_output_;
+  CollectionItemId output_data_stream_base_;
+  int num_output_data_streams_;
+  int output_data_stream_index_;
+};
+
+REGISTER_CALCULATOR(RoundRobinDemuxCalculator);
+
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD
new file mode 100644
index 000000000..b3cd31b3e
--- /dev/null
+++ b/mediapipe/calculators/image/BUILD
@@ -0,0 +1,415 @@
+# Copyright 2019 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +exports_files(["LICENSE"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +proto_library( + name = "opencv_image_encoder_calculator_proto", + srcs = ["opencv_image_encoder_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "scale_image_calculator_proto", + srcs = ["scale_image_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/framework/formats:image_format_proto", + ], +) + +proto_library( + name = "set_alpha_calculator_proto", + srcs = ["set_alpha_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "recolor_calculator_proto", + srcs = ["recolor_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/util:color_proto", + ], +) + +mediapipe_cc_proto_library( + name = "opencv_image_encoder_calculator_cc_proto", + srcs = ["opencv_image_encoder_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":opencv_image_encoder_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "mask_overlay_calculator_cc_proto", + srcs = ["mask_overlay_calculator.proto"], + cc_deps = 
["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":mask_overlay_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "scale_image_calculator_cc_proto", + srcs = ["scale_image_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":scale_image_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "set_alpha_calculator_cc_proto", + srcs = ["set_alpha_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":set_alpha_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "recolor_calculator_cc_proto", + srcs = ["recolor_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/util:color_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":recolor_calculator_proto"], +) + +cc_library( + name = "color_convert_calculator", + srcs = ["color_convert_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "opencv_encoded_image_to_image_frame_calculator", + srcs = ["opencv_encoded_image_to_image_frame_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame_opencv", + 
"//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "opencv_image_encoder_calculator", + srcs = ["opencv_image_encoder_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":opencv_image_encoder_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "opencv_put_text_calculator", + srcs = ["opencv_put_text_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "set_alpha_calculator", + srcs = ["set_alpha_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":set_alpha_calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +proto_library( + name = "image_transformation_calculator_proto", + srcs = 
["image_transformation_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/gpu:scale_mode_proto", + ], +) + +mediapipe_cc_proto_library( + name = "image_transformation_calculator_cc_proto", + srcs = ["image_transformation_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/gpu:scale_mode_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":image_transformation_calculator_proto"], +) + +cc_library( + name = "image_transformation_calculator", + srcs = ["image_transformation_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":image_transformation_calculator_cc_proto", + "//mediapipe/gpu:scale_mode_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gl_quad_renderer", + "//mediapipe/gpu:shader_util", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_library( + name = "luminance_calculator", + srcs = ["luminance_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/gpu:gl_simple_calculator", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:shader_util", + ], + alwayslink = 1, +) + +cc_library( + name = "sobel_edges_calculator", + srcs = ["sobel_edges_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/gpu:gl_simple_calculator", + 
"//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:shader_util", + ], + alwayslink = 1, +) + +cc_library( + name = "recolor_calculator", + srcs = ["recolor_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":recolor_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util:color_cc_proto", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_library( + name = "scale_image_utils", + srcs = ["scale_image_utils.cc"], + hdrs = ["scale_image_utils.h"], + visibility = [ + "//mediapipe:__subpackages__", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "scale_image_calculator", + srcs = ["scale_image_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":scale_image_utils", + "//mediapipe/calculators/image:scale_image_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/formats:yuv_image", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:image_resizer", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + 
"//mediapipe/util:image_frame_util", + "@com_google_absl//absl/strings", + "@libyuv", + ], + alwayslink = 1, +) + +cc_test( + name = "opencv_encoded_image_to_image_frame_calculator_test", + srcs = ["opencv_encoded_image_to_image_frame_calculator_test.cc"], + data = ["//mediapipe/calculators/image/testdata:test_images"], + deps = [ + ":opencv_encoded_image_to_image_frame_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) + +cc_test( + name = "opencv_image_encoder_calculator_test", + srcs = ["opencv_image_encoder_calculator_test.cc"], + data = ["//mediapipe/calculators/image/testdata:test_images"], + deps = [ + ":opencv_image_encoder_calculator", + ":opencv_image_encoder_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) + +cc_test( + name = "scale_image_utils_test", + srcs = ["scale_image_utils_test.cc"], + deps = [ + ":scale_image_utils", + "//mediapipe/framework/port:gtest_main", + ], +) + +proto_library( + name = "mask_overlay_calculator_proto", + srcs = ["mask_overlay_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +cc_library( + name = "mask_overlay_calculator", + srcs = 
["mask_overlay_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":mask_overlay_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:shader_util", + ], + alwayslink = 1, +) diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc new file mode 100644 index 000000000..875323e9e --- /dev/null +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -0,0 +1,160 @@ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/source_location.h" +#include "mediapipe/framework/port/status_builder.h" + +namespace mediapipe { +namespace { +void SetColorChannel(int channel, uint8 value, cv::Mat* mat) { + CHECK(mat->depth() == CV_8U); + CHECK(channel < mat->channels()); + const int step = mat->channels(); + for (int r = 0; r < mat->rows; ++r) { + uint8* row_ptr = mat->ptr(r); + for (int offset = channel; offset < mat->cols * step; offset += step) { + row_ptr[offset] = value; + } + } +} + +constexpr char kRgbaInTag[] = "RGBA_IN"; +constexpr char kRgbInTag[] = "RGB_IN"; +constexpr char kGrayInTag[] = "GRAY_IN"; +constexpr char kRgbaOutTag[] = "RGBA_OUT"; +constexpr char kRgbOutTag[] = "RGB_OUT"; +constexpr char kGrayOutTag[] = "GRAY_OUT"; +} // namespace + +// A portable color conversion calculator calculator. 
+// +// The following conversions are currently supported, but it's fairly easy to +// add new ones if this doesn't meet your needs--Don't forget to add a test to +// color_convert_calculator_test.cc if you do! +// RGBA -> RGB +// GRAY -> RGB +// RGB -> GRAY +// RGB -> RGBA +// +// This calculator only supports a single input stream and output stream at a +// time. If more than one input stream or output stream is present, the +// calculator will fail at FillExpectations. +// TODO: Remove this requirement by replacing the typed input streams +// with a single generic input and allow multiple simultaneous outputs. +// +// Input streams: +// RGBA_IN: The input video stream (ImageFrame, SRGBA). +// RGB_IN: The input video stream (ImageFrame, SRGB). +// GRAY_IN: The input video stream (ImageFrame, GRAY8). +// +// Output streams: +// RGBA_OUT: The output video stream (ImageFrame, SRGBA). +// RGB_OUT: The output video stream (ImageFrame, SRGB). +// GRAY_OUT: The output video stream (ImageFrame, GRAY8). +class ColorConvertCalculator : public CalculatorBase { + public: + ~ColorConvertCalculator() override = default; + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + // Wrangles the appropriate inputs and outputs to perform the color + // conversion. The ImageFrame on input_tag is converted using the + // open_cv_convert_code provided and then output on the output_tag stream. + // Note that the output_format must match the destination conversion code. 
+ ::mediapipe::Status ConvertAndOutput(const std::string& input_tag, + const std::string& output_tag, + ImageFormat::Format output_format, + int open_cv_convert_code, + CalculatorContext* cc); +}; + +REGISTER_CALCULATOR(ColorConvertCalculator); + +::mediapipe::Status ColorConvertCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is allowed."; + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is allowed."; + + if (cc->Inputs().HasTag(kRgbaInTag)) { + cc->Inputs().Tag(kRgbaInTag).Set(); + } + + if (cc->Inputs().HasTag(kGrayInTag)) { + cc->Inputs().Tag(kGrayInTag).Set(); + } + + if (cc->Inputs().HasTag(kRgbInTag)) { + cc->Inputs().Tag(kRgbInTag).Set(); + } + + if (cc->Outputs().HasTag(kRgbOutTag)) { + cc->Outputs().Tag(kRgbOutTag).Set(); + } + + if (cc->Outputs().HasTag(kGrayOutTag)) { + cc->Outputs().Tag(kGrayOutTag).Set(); + } + + if (cc->Outputs().HasTag(kRgbaOutTag)) { + cc->Outputs().Tag(kRgbaOutTag).Set(); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ColorConvertCalculator::ConvertAndOutput( + const std::string& input_tag, const std::string& output_tag, + ImageFormat::Format output_format, int open_cv_convert_code, + CalculatorContext* cc) { + const cv::Mat& input_mat = + formats::MatView(&cc->Inputs().Tag(input_tag).Get()); + std::unique_ptr output_frame( + new ImageFrame(output_format, input_mat.cols, input_mat.rows)); + cv::Mat output_mat = formats::MatView(output_frame.get()); + cv::cvtColor(input_mat, output_mat, open_cv_convert_code); + + // cv::cvtColor will leave the alpha channel set to 0, which is a bizarre + // design choice. Instead, let's set alpha to 255. 
+ if (open_cv_convert_code == cv::COLOR_RGB2RGBA) { + SetColorChannel(3, 255, &output_mat); + } + cc->Outputs() + .Tag(output_tag) + .Add(output_frame.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ColorConvertCalculator::Process(CalculatorContext* cc) { + // RGBA -> RGB + if (cc->Inputs().HasTag(kRgbaInTag) && cc->Outputs().HasTag(kRgbOutTag)) { + return ConvertAndOutput(kRgbaInTag, kRgbOutTag, ImageFormat::SRGB, + cv::COLOR_RGBA2RGB, cc); + } + // GRAY -> RGB + if (cc->Inputs().HasTag(kGrayInTag) && cc->Outputs().HasTag(kRgbOutTag)) { + return ConvertAndOutput(kGrayInTag, kRgbOutTag, ImageFormat::SRGB, + cv::COLOR_GRAY2RGB, cc); + } + // RGB -> GRAY + if (cc->Inputs().HasTag(kRgbInTag) && cc->Outputs().HasTag(kGrayOutTag)) { + return ConvertAndOutput(kRgbInTag, kGrayOutTag, ImageFormat::GRAY8, + cv::COLOR_RGB2GRAY, cc); + } + // RGB -> RGBA + if (cc->Inputs().HasTag(kRgbInTag) && cc->Outputs().HasTag(kRgbaOutTag)) { + return ConvertAndOutput(kRgbInTag, kRgbaOutTag, ImageFormat::SRGBA, + cv::COLOR_RGB2RGBA, cc); + } + + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Unsupported image format conversion."; +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc new file mode 100644 index 000000000..303712f32 --- /dev/null +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -0,0 +1,465 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/image_transformation_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/gpu/scale_mode.pb.h" + +#if defined(__ANDROID__) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_quad_renderer.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" +#endif // __ANDROID__ + +#if defined(__ANDROID__) +// The size of Java arrays is dynamic, which makes it difficult to +// generate the right packet type with a fixed size. Therefore, we +// are using unsized arrays on Android. 
+typedef int DimensionsPacketType[]; +#else +typedef int DimensionsPacketType[2]; +#endif // __ANDROID__ + +#define DEFAULT_SCALE_MODE mediapipe::ScaleMode_Mode_STRETCH + +namespace mediapipe { + +#if defined(__ANDROID__) + +#endif // __ANDROID__ + +namespace { +int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) { + switch (rotation) { + case mediapipe::RotationMode_Mode_UNKNOWN: + case mediapipe::RotationMode_Mode_ROTATION_0: + return 0; + case mediapipe::RotationMode_Mode_ROTATION_90: + return 90; + case mediapipe::RotationMode_Mode_ROTATION_180: + return 180; + case mediapipe::RotationMode_Mode_ROTATION_270: + return 270; + } +} +mediapipe::RotationMode_Mode DegreesToRotationMode(int degrees) { + switch (degrees) { + case 0: + return mediapipe::RotationMode_Mode_ROTATION_0; + case 90: + return mediapipe::RotationMode_Mode_ROTATION_90; + case 180: + return mediapipe::RotationMode_Mode_ROTATION_180; + case 270: + return mediapipe::RotationMode_Mode_ROTATION_270; + default: + return mediapipe::RotationMode_Mode_UNKNOWN; + } +} +mediapipe::ScaleMode_Mode ParseScaleMode( + mediapipe::ScaleMode_Mode scale_mode, + mediapipe::ScaleMode_Mode default_mode) { + switch (scale_mode) { + case mediapipe::ScaleMode_Mode_DEFAULT: + return default_mode; + case mediapipe::ScaleMode_Mode_STRETCH: + return scale_mode; + case mediapipe::ScaleMode_Mode_FIT: + return scale_mode; + case mediapipe::ScaleMode_Mode_FILL_AND_CROP: + return scale_mode; + default: + return default_mode; + } +} +} // namespace + +// Scales, rotates, and flips images horizontally or vertically. +// +// Input: +// One of the following two tags: +// IMAGE: ImageFrame representing the input image. +// IMAGE_GPU: GpuBuffer representing the input image. +// +// ROTATION_DEGREES (optional): The counterclockwise rotation angle in +// degrees. This allows different rotation angles for different frames. It has +// to be a multiple of 90 degrees. 
If provided, it overrides the +// ROTATION_DEGREES input side packet. +// +// Output: +// One of the following two tags: +// IMAGE - ImageFrame representing the output image. +// IMAGE_GPU - GpuBuffer representing the output image. +// +// LETTERBOX_PADDING (optional): An std::array representing the +// letterbox padding from the 4 sides ([left, top, right, bottom]) of the +// output image, normalized to [0.f, 1.f] by the output dimensions. The +// padding values are non-zero only when the scale mode specified in the +// calculator options is FIT. For instance, when the input image is 10x10 +// (width x height) and the output dimensions specified in the calculator +// option are 20x40 and scale mode is FIT, the calculator scales the input +// image to 20x20 and places it in the middle of the output image with an +// equal padding of 10 pixels at the top and the bottom. The resulting array +// is therefore [0.f, 0.25f, 0.f, 0.25f] (10/40 = 0.25f). +// +// Input side packet: +// OUTPUT_DIMENSIONS (optional): The output width and height in pixels as the +// first two elements in an integer array. It overrides the corresponding +// field in the calculator options. +// +// ROTATION_DEGREES (optional): The counterclockwise rotation angle in +// degrees. It has to be a multiple of 90 degrees. It overrides the +// corresponding field in the calculator options. +// +// Calculator options (see image_transformation_calculator.proto): +// output_width, output_height - (optional) Desired scaled image size. +// rotation_mode - (optional) Rotation in multiples of 90 degrees. +// flip_vertically, flip_horizontally - (optional) flip about x or y axis. +// scale_mode - (optional) Stretch, Fit, or Fill and Crop +// +// Note: To enable horizontal or vertical flipping, specify them in the +// calculator options. Flipping is applied after rotation. +// +// Note: Only scale mode STRETCH is currently supported on CPU, +// and flipping is not yet supported either. 
+// +class ImageTransformationCalculator : public CalculatorBase { + public: + ImageTransformationCalculator() = default; + ~ImageTransformationCalculator() override = default; + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status RenderCpu(CalculatorContext* cc); + ::mediapipe::Status RenderGpu(CalculatorContext* cc); + ::mediapipe::Status GlSetup(); + + void ComputeOutputDimensions(int input_width, int input_height, + int* output_width, int* output_height); + void ComputeOutputLetterboxPadding(int input_width, int input_height, + int output_width, int output_height, + std::array* padding); + + ImageTransformationCalculatorOptions options_; + int output_width_ = 0; + int output_height_ = 0; + mediapipe::RotationMode_Mode rotation_; + mediapipe::ScaleMode_Mode scale_mode_; + + bool use_gpu_ = false; +#if defined(__ANDROID__) + GlCalculatorHelper helper_; + std::unique_ptr rgb_renderer_; + std::unique_ptr ext_rgb_renderer_; +#endif // __ANDROID__ +}; +REGISTER_CALCULATOR(ImageTransformationCalculator); + +// static +::mediapipe::Status ImageTransformationCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU")); + RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU")); + + if (cc->Inputs().HasTag("IMAGE")) { + RET_CHECK(cc->Outputs().HasTag("IMAGE")); + cc->Inputs().Tag("IMAGE").Set(); + cc->Outputs().Tag("IMAGE").Set(); + } +#if defined(__ANDROID__) + if (cc->Inputs().HasTag("IMAGE_GPU")) { + RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU")); + cc->Inputs().Tag("IMAGE_GPU").Set(); + cc->Outputs().Tag("IMAGE_GPU").Set(); + } +#endif // __ANDROID__ + if (cc->Inputs().HasTag("ROTATION_DEGREES")) { + cc->Inputs().Tag("ROTATION_DEGREES").Set(); + } + + 
if (cc->InputSidePackets().HasTag("OUTPUT_DIMENSIONS")) { + cc->InputSidePackets().Tag("OUTPUT_DIMENSIONS").Set(); + } + if (cc->InputSidePackets().HasTag("ROTATION_DEGREES")) { + cc->InputSidePackets().Tag("ROTATION_DEGREES").Set(); + } + + if (cc->Outputs().HasTag("LETTERBOX_PADDING")) { + cc->Outputs().Tag("LETTERBOX_PADDING").Set>(); + } + +#if defined(__ANDROID__) + RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { + // Inform the framework that we always output at the same timestamp + // as we receive a packet at. + cc->SetOffset(mediapipe::TimestampDiff(0)); + + options_ = cc->Options(); + + if (cc->Inputs().HasTag("IMAGE_GPU")) { + use_gpu_ = true; + } + + if (cc->InputSidePackets().HasTag("OUTPUT_DIMENSIONS")) { + const auto& dimensions = cc->InputSidePackets() + .Tag("OUTPUT_DIMENSIONS") + .Get(); + output_width_ = dimensions[0]; + output_height_ = dimensions[1]; + } else { + output_width_ = options_.output_width(); + output_height_ = options_.output_height(); + } + if (cc->InputSidePackets().HasTag("ROTATION_DEGREES")) { + rotation_ = DegreesToRotationMode( + cc->InputSidePackets().Tag("ROTATION_DEGREES").Get()); + } else { + rotation_ = DegreesToRotationMode(options_.rotation_mode()); + } + + scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE); + + if (use_gpu_) { +#if defined(__ANDROID__) + // Let the helper access the GL context information. 
+ RETURN_IF_ERROR(helper_.Open(cc)); +#else + RET_CHECK_FAIL() << "GPU processing for non-Android not supported yet."; +#endif // __ANDROID__ + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImageTransformationCalculator::Process( + CalculatorContext* cc) { + if (use_gpu_) { +#if defined(__ANDROID__) + return helper_.RunInGlContext( + [this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); }); +#endif // __ANDROID__ + } else { + return RenderCpu(cc); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImageTransformationCalculator::Close( + CalculatorContext* cc) { + if (use_gpu_) { +#if defined(__ANDROID__) + QuadRenderer* rgb_renderer = rgb_renderer_.release(); + QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release(); + helper_.RunInGlContext([rgb_renderer, ext_rgb_renderer] { + if (rgb_renderer) { + rgb_renderer->GlTeardown(); + delete rgb_renderer; + } + if (ext_rgb_renderer) { + ext_rgb_renderer->GlTeardown(); + delete ext_rgb_renderer; + } + }); +#endif // __ANDROID__ + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImageTransformationCalculator::RenderCpu( + CalculatorContext* cc) { + int input_width = cc->Inputs().Tag("IMAGE").Get().Width(); + int input_height = cc->Inputs().Tag("IMAGE").Get().Height(); + + int output_width; + int output_height; + ComputeOutputDimensions(input_width, input_height, &output_width, + &output_height); + if (cc->Outputs().HasTag("LETTERBOX_PADDING")) { + auto padding = absl::make_unique>(); + ComputeOutputLetterboxPadding(input_width, input_height, output_width, + output_height, padding.get()); + cc->Outputs() + .Tag("LETTERBOX_PADDING") + .Add(padding.release(), cc->InputTimestamp()); + } + + if (cc->InputSidePackets().HasTag("ROTATION_DEGREES")) { + rotation_ = DegreesToRotationMode( + cc->InputSidePackets().Tag("ROTATION_DEGREES").Get()); + } + + const auto& input_img = cc->Inputs().Tag("IMAGE").Get(); + std::unique_ptr output_frame( + new ImageFrame(input_img.Format(), 
output_width, output_height)); + cv::Mat input_mat = formats::MatView(&input_img); + cv::Mat output_mat = formats::MatView(output_frame.get()); + + cv::Mat scaled_mat; + if (scale_mode_ != mediapipe::ScaleMode_Mode_STRETCH) { + // TODO finish CPU version features. + return ::mediapipe::UnimplementedError( + "Only STRETCH scale mode currently supported."); + } + cv::resize(input_mat, scaled_mat, cv::Size(output_width_, output_height_)); + + cv::Mat rotated_mat; + const int angle = RotationModeToDegrees(rotation_); + cv::Point2f src_center(scaled_mat.cols / 2.0, scaled_mat.rows / 2.0); + cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0); + cv::warpAffine(scaled_mat, rotated_mat, rotation_mat, scaled_mat.size()); + + rotated_mat.copyTo(output_mat); + cc->Outputs().Tag("IMAGE").Add(output_frame.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImageTransformationCalculator::RenderGpu( + CalculatorContext* cc) { +#if defined(__ANDROID__) + int input_width = cc->Inputs().Tag("IMAGE_GPU").Get().width(); + int input_height = cc->Inputs().Tag("IMAGE_GPU").Get().height(); + + int output_width; + int output_height; + ComputeOutputDimensions(input_width, input_height, &output_width, + &output_height); + + if (cc->Outputs().HasTag("LETTERBOX_PADDING")) { + auto padding = absl::make_unique>(); + ComputeOutputLetterboxPadding(input_width, input_height, output_width, + output_height, padding.get()); + cc->Outputs() + .Tag("LETTERBOX_PADDING") + .Add(padding.release(), cc->InputTimestamp()); + } + + const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get(); + QuadRenderer* renderer = nullptr; + GlTexture src1; + { + src1 = helper_.CreateSourceTexture(input); + if (src1.target() == GL_TEXTURE_EXTERNAL_OES) { + if (!ext_rgb_renderer_) { + ext_rgb_renderer_ = absl::make_unique(); + RETURN_IF_ERROR(ext_rgb_renderer_->GlSetup( + ::mediapipe::kBasicTexturedFragmentShaderOES, {"video_frame"})); + } + renderer = 
ext_rgb_renderer_.get(); + } else { + if (!rgb_renderer_) { + rgb_renderer_ = absl::make_unique(); + RETURN_IF_ERROR(rgb_renderer_->GlSetup()); + } + renderer = rgb_renderer_.get(); + } + } + RET_CHECK(renderer) << "Unsupported input texture type"; + + if (cc->InputSidePackets().HasTag("ROTATION_DEGREES")) { + rotation_ = DegreesToRotationMode( + cc->InputSidePackets().Tag("ROTATION_DEGREES").Get()); + } + + static mediapipe::FrameScaleMode scale_mode = + mediapipe::FrameScaleModeFromProto(scale_mode_, + mediapipe::FrameScaleMode::kStretch); + mediapipe::FrameRotation rotation = + mediapipe::FrameRotationFromDegrees(RotationModeToDegrees(rotation_)); + + auto dst = helper_.CreateDestinationTexture(output_width, output_height, + input.format()); + + helper_.BindFramebuffer(dst); // GL_TEXTURE0 + glActiveTexture(GL_TEXTURE1); + glBindTexture(src1.target(), src1.name()); + + RETURN_IF_ERROR(renderer->GlRender( + src1.width(), src1.height(), dst.width(), dst.height(), scale_mode, + rotation, options_.flip_horizontally(), options_.flip_vertically(), + /*flip_texture=*/false)); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(src1.target(), 0); + + // Execute GL commands, before getting result. 
+ glFlush(); + + auto output = dst.GetFrame(); + cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp()); + +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +void ImageTransformationCalculator::ComputeOutputDimensions( + int input_width, int input_height, int* output_width, int* output_height) { + if (output_width_ > 0 && output_height_ > 0) { + *output_width = output_width_; + *output_height = output_height_; + } else if (rotation_ == mediapipe::RotationMode_Mode_ROTATION_90 || + rotation_ == mediapipe::RotationMode_Mode_ROTATION_270) { + *output_width = input_height; + *output_height = input_width; + } else { + *output_width = input_width; + *output_height = input_height; + } +} + +void ImageTransformationCalculator::ComputeOutputLetterboxPadding( + int input_width, int input_height, int output_width, int output_height, + std::array* padding) { + if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) { + if (rotation_ == mediapipe::RotationMode_Mode_ROTATION_90 || + rotation_ == mediapipe::RotationMode_Mode_ROTATION_270) { + std::swap(input_width, input_height); + } + const float input_aspect_ratio = + static_cast(input_width) / input_height; + const float output_aspect_ratio = + static_cast(output_width) / output_height; + if (input_aspect_ratio < output_aspect_ratio) { + // Compute left and right padding. + (*padding)[0] = (1.f - input_aspect_ratio / output_aspect_ratio) / 2.f; + (*padding)[2] = (*padding)[0]; + } else if (output_aspect_ratio < input_aspect_ratio) { + // Compute top and bottom padding. 
+ (*padding)[1] = (1.f - output_aspect_ratio / input_aspect_ratio) / 2.f; + (*padding)[3] = (*padding)[1]; + } + } +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/image_transformation_calculator.proto b/mediapipe/calculators/image/image_transformation_calculator.proto new file mode 100644 index 000000000..235354a00 --- /dev/null +++ b/mediapipe/calculators/image/image_transformation_calculator.proto @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/gpu/scale_mode.proto"; + +// Counterclockwise rotation. +message RotationMode { + enum Mode { + UNKNOWN = 0; + ROTATION_0 = 1; + ROTATION_90 = 2; + ROTATION_180 = 3; + ROTATION_270 = 4; + } +} + +message ImageTransformationCalculatorOptions { + extend CalculatorOptions { + optional ImageTransformationCalculatorOptions ext = 251952830; + } + + // Output dimensions. Set to 0 if they should be the same as the input. + optional int32 output_width = 1 [default = 0]; + optional int32 output_height = 2 [default = 0]; + // Counterclockwise rotation mode. + optional RotationMode.Mode rotation_mode = 3; + // Vertical flipping, applied after rotation. + optional bool flip_vertically = 4 [default = false]; + // Horizontal flipping, applied after rotation. + optional bool flip_horizontally = 5 [default = false]; + // Scale mode. 
+ optional ScaleMode.Mode scale_mode = 6; +} diff --git a/mediapipe/calculators/image/luminance_calculator.cc b/mediapipe/calculators/image/luminance_calculator.cc new file mode 100644 index 000000000..325745d99 --- /dev/null +++ b/mediapipe/calculators/image/luminance_calculator.cc @@ -0,0 +1,151 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/gpu/gl_simple_calculator.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" + +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; + +namespace mediapipe { + +// Converts RGB images into luminance images, still stored in RGB format. +// See GlSimpleCalculatorBase for inputs, outputs and input side packets. 
+class LuminanceCalculator : public GlSimpleCalculator { + public: + ::mediapipe::Status GlSetup() override; + ::mediapipe::Status GlRender(const GlTexture& src, + const GlTexture& dst) override; + ::mediapipe::Status GlTeardown() override; + + private: + GLuint program_ = 0; + GLint frame_; +}; +REGISTER_CALCULATOR(LuminanceCalculator); + +::mediapipe::Status LuminanceCalculator::GlSetup() { + // Load vertex and fragment shaders + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + const GLchar* frag_src = GLES_VERSION_COMPAT + R"( +#if __VERSION__ < 130 + #define in varying +#endif // __VERSION__ < 130 + +#ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; +#else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; +#endif // defined(GL_ES) + + in vec2 sample_coordinate; + uniform sampler2D video_frame; + const highp vec3 W = vec3(0.2125, 0.7154, 0.0721); + + void main() { + vec4 color = texture2D(video_frame, sample_coordinate); + float luminance = dot(color.rgb, W); + fragColor.rgb = vec3(luminance); + fragColor.a = color.a; + } + + )"; + + // shader program + GlhCreateProgram(kBasicVertexShader, frag_src, NUM_ATTRIBUTES, + (const GLchar**)&attr_name[0], attr_location, &program_); + RET_CHECK(program_) << "Problem initializing the program."; + frame_ = glGetUniformLocation(program_, "video_frame"); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status LuminanceCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // 
program + glUseProgram(program_); + glUniform1i(frame_, 1); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status LuminanceCalculator::GlTeardown() { + if (program_) { + glDeleteProgram(program_); + program_ = 0; + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/mask_overlay_calculator.cc b/mediapipe/calculators/image/mask_overlay_calculator.cc new file mode 100644 index 000000000..ef0cc4ca3 --- /dev/null +++ b/mediapipe/calculators/image/mask_overlay_calculator.cc @@ -0,0 +1,282 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/mask_overlay_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" + +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; + +namespace mediapipe { + +using ::mediapipe::MaskOverlayCalculatorOptions_MaskChannel_ALPHA; +using ::mediapipe::MaskOverlayCalculatorOptions_MaskChannel_RED; +using ::mediapipe::MaskOverlayCalculatorOptions_MaskChannel_UNKNOWN; + +// Mixes two frames using a third mask frame or constant value. +// +// Inputs: +// VIDEO:[0,1] (GpuBuffer): +// Two inputs should be provided. +// MASK (GpuBuffer): +// Optional. +// Where the mask is 0, VIDEO:0 will be used. Where it is 1, VIDEO:1. +// Intermediate values will blend. +// If not specified, CONST_MASK float must be present. +// CONST_MASK (float): +// Optional. +// If not specified, MASK GpuBuffer must be present. +// Similar to MASK GpuBuffer, but applied globally to every pixel. +// +// Outputs: +// OUTPUT (GpuBuffer): +// The mix. 
+ +class MaskOverlayCalculator : public CalculatorBase { + public: + MaskOverlayCalculator() {} + ~MaskOverlayCalculator(); + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + ::mediapipe::Status GlSetup( + const MaskOverlayCalculatorOptions::MaskChannel mask_channel); + ::mediapipe::Status GlRender(const float mask_const); + + private: + GlCalculatorHelper helper_; + bool initialized_ = false; + bool use_mask_tex_ = false; // Otherwise, use constant float value. + GLuint program_ = 0; + GLint unif_frame1_; + GLint unif_frame2_; + GLint unif_mask_; +}; +REGISTER_CALCULATOR(MaskOverlayCalculator); + +// static +::mediapipe::Status MaskOverlayCalculator::GetContract(CalculatorContract* cc) { + RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); + cc->Inputs().Get("VIDEO", 0).Set(); + cc->Inputs().Get("VIDEO", 1).Set(); + if (cc->Inputs().HasTag("MASK")) + cc->Inputs().Tag("MASK").Set(); + else if (cc->Inputs().HasTag("CONST_MASK")) + cc->Inputs().Tag("CONST_MASK").Set(); + else + return ::mediapipe::Status( + ::mediapipe::StatusCode::kNotFound, + "At least one mask input stream must be present."); + cc->Outputs().Tag("OUTPUT").Set(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MaskOverlayCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + if (cc->Inputs().HasTag("MASK")) { + use_mask_tex_ = true; + } + return helper_.Open(cc); +} + +::mediapipe::Status MaskOverlayCalculator::Process(CalculatorContext* cc) { + return helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { + if (!initialized_) { + const auto& options = cc->Options(); + const auto mask_channel = options.mask_channel(); + + RETURN_IF_ERROR(GlSetup(mask_channel)); + initialized_ = true; + } + + glDisable(GL_BLEND); + + const Packet& input1_packet = cc->Inputs().Get("VIDEO", 1).Value(); + const Packet& 
mask_packet = use_mask_tex_ + ? cc->Inputs().Tag("MASK").Value() + : cc->Inputs().Tag("CONST_MASK").Value(); + + if (mask_packet.IsEmpty()) { + cc->Outputs().Tag("OUTPUT").AddPacket(input1_packet); + return ::mediapipe::OkStatus(); + } + + const auto& input0_buffer = cc->Inputs().Get("VIDEO", 0).Get(); + const auto& input1_buffer = input1_packet.Get(); + + auto src1 = helper_.CreateSourceTexture(input0_buffer); + auto src2 = helper_.CreateSourceTexture(input1_buffer); + + GlTexture mask_tex; + if (use_mask_tex_) { + const auto& mask_buffer = mask_packet.Get(); + mask_tex = helper_.CreateSourceTexture(mask_buffer); + } + + auto dst = helper_.CreateDestinationTexture(src1.width(), src1.height()); + + helper_.BindFramebuffer(dst); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(src1.target(), src1.name()); + + glActiveTexture(GL_TEXTURE2); + glBindTexture(src2.target(), src2.name()); + + if (use_mask_tex_) { + const float mask_const = -1; + + glActiveTexture(GL_TEXTURE3); + glBindTexture(mask_tex.target(), mask_tex.name()); + + RETURN_IF_ERROR(GlRender(mask_const)); + + glActiveTexture(GL_TEXTURE3); + glBindTexture(mask_tex.target(), 0); + + } else { + const float mask_const = mask_packet.Get(); + + RETURN_IF_ERROR(GlRender(mask_const)); + } + + glActiveTexture(GL_TEXTURE2); + glBindTexture(src2.target(), 0); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(src1.target(), 0); + + glFlush(); + + auto output = dst.GetFrame(); + src1.Release(); + src2.Release(); + if (use_mask_tex_) mask_tex.Release(); + dst.Release(); + + cc->Outputs().Tag("OUTPUT").Add(output.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + }); +} + +::mediapipe::Status MaskOverlayCalculator::GlSetup( + const MaskOverlayCalculatorOptions::MaskChannel mask_channel) { + // Load vertex and fragment shaders + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + 
"texture_coordinate", + }; + + std::string mask_component; + switch (mask_channel) { + case MaskOverlayCalculatorOptions_MaskChannel_UNKNOWN: + case MaskOverlayCalculatorOptions_MaskChannel_RED: + mask_component = "r"; + break; + case MaskOverlayCalculatorOptions_MaskChannel_ALPHA: + mask_component = "a"; + break; + } + + const std::string frag_src_tex = + std::string(kMediaPipeFragmentShaderPreamble) + + R"( + DEFAULT_PRECISION(highp, float) + + in vec2 sample_coordinate; + uniform sampler2D frame1; + uniform sampler2D frame2; + uniform sampler2D mask; + + void main() { + vec4 color1 = texture2D(frame1, sample_coordinate); + vec4 color2 = texture2D(frame2, sample_coordinate); + vec4 weight = texture2D(mask, sample_coordinate); + + #define MASK_COMPONENT )" + + mask_component + + R"( + + gl_FragColor = mix(color1, color2, weight.MASK_COMPONENT); + } + )"; + + const GLchar* frag_src_const = R"( + precision highp float; + + varying vec2 sample_coordinate; + uniform sampler2D frame1; + uniform sampler2D frame2; + uniform float mask; + + void main() { + vec4 color1 = texture2D(frame1, sample_coordinate); + vec4 color2 = texture2D(frame2, sample_coordinate); + float weight = mask; + + gl_FragColor = mix(color1, color2, weight); + } + )"; + + // shader program + GlhCreateProgram(kBasicVertexShader, + use_mask_tex_ ? 
frag_src_tex.c_str() : frag_src_const, + NUM_ATTRIBUTES, &attr_name[0], attr_location, &program_); + RET_CHECK(program_) << "Problem initializing the program."; + unif_frame1_ = glGetUniformLocation(program_, "frame1"); + unif_frame2_ = glGetUniformLocation(program_, "frame2"); + unif_mask_ = glGetUniformLocation(program_, "mask"); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MaskOverlayCalculator::GlRender(const float mask_const) { + glUseProgram(program_); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, kBasicSquareVertices); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, + kBasicTextureVertices); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + + glUniform1i(unif_frame1_, 1); + glUniform1i(unif_frame2_, 2); + if (use_mask_tex_) + glUniform1i(unif_mask_, 3); + else + glUniform1f(unif_mask_, mask_const); + + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + return ::mediapipe::OkStatus(); +} + +MaskOverlayCalculator::~MaskOverlayCalculator() { + helper_.RunInGlContext([this] { + if (program_) { + glDeleteProgram(program_); + program_ = 0; + } + }); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/mask_overlay_calculator.proto b/mediapipe/calculators/image/mask_overlay_calculator.proto new file mode 100644 index 000000000..3b82d610e --- /dev/null +++ b/mediapipe/calculators/image/mask_overlay_calculator.proto @@ -0,0 +1,34 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message MaskOverlayCalculatorOptions { + extend CalculatorOptions { + optional MaskOverlayCalculatorOptions ext = 252129282; + } + + enum MaskChannel { + UNKNOWN = 0; + RED = 1; + ALPHA = 2; + } + + // Selects which channel of the MASK input to use for masking. + optional MaskChannel mask_channel = 1 [default = RED]; +} diff --git a/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc new file mode 100644 index 000000000..445329009 --- /dev/null +++ b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc @@ -0,0 +1,81 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" + +namespace mediapipe { + +// Takes in an encoded image std::string, decodes it by OpenCV, and converts to +// an ImageFrame. 
Note that this calculator only supports grayscale and RGB +// images for now. +// +// Example config: +// node { +// calculator: "OpenCvEncodedImageToImageFrameCalculator" +// input_stream: "encoded_image" +// output_stream: "image_frame" +// } +class OpenCvEncodedImageToImageFrameCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Process(CalculatorContext* cc) override; +}; + +::mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::GetContract( + CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::Process( + CalculatorContext* cc) { + const std::string& contents = cc->Inputs().Index(0).Get(); + const std::vector contents_vector(contents.begin(), contents.end()); + cv::Mat decoded_mat = + cv::imdecode(contents_vector, -1 /* return the loaded image as-is */); + + ImageFormat::Format image_format = ImageFormat::UNKNOWN; + cv::Mat output_mat; + switch (decoded_mat.channels()) { + case 1: + image_format = ImageFormat::GRAY8; + output_mat = decoded_mat; + break; + case 3: + image_format = ImageFormat::SRGB; + cv::cvtColor(decoded_mat, output_mat, cv::COLOR_BGR2RGB); + break; + case 4: + return ::mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) + << "4-channel image isn't supported yet"; + default: + return ::mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) + << "Unsupported number of channels: " << decoded_mat.channels(); + } + std::unique_ptr output_frame = absl::make_unique( + image_format, decoded_mat.size().width, decoded_mat.size().height); + output_mat.copyTo(formats::MatView(output_frame.get())); + cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +REGISTER_CALCULATOR(OpenCvEncodedImageToImageFrameCalculator); + +} // namespace mediapipe diff --git 
a/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator_test.cc b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator_test.cc new file mode 100644 index 000000000..b5db460e9 --- /dev/null +++ b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator_test.cc @@ -0,0 +1,106 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +TEST(OpenCvEncodedImageToImageFrameCalculatorTest, TestRgbJpeg) { + std::string contents; + MEDIAPIPE_ASSERT_OK(file::GetContents( + file::JoinPath("./", "/mediapipe/calculators/image/testdata/dino.jpg"), + &contents)); + Packet input_packet = MakePacket(contents); + + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "OpenCvEncodedImageToImageFrameCalculator" + input_stream: 
"encoded_image" + output_stream: "image_frame" + )"); + CalculatorRunner runner(node_config); + runner.MutableInputs()->Index(0).packets.push_back( + input_packet.At(Timestamp(0))); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const auto& outputs = runner.Outputs(); + ASSERT_EQ(1, outputs.NumEntries()); + const std::vector& packets = outputs.Index(0).packets; + ASSERT_EQ(1, packets.size()); + const ImageFrame& output_frame = packets[0].Get(); + + cv::Mat input_mat = cv::imread( + file::JoinPath("./", "/mediapipe/calculators/image/testdata/dino.jpg")); + cv::Mat output_mat; + cv::cvtColor(formats::MatView(&output_frame), output_mat, cv::COLOR_RGB2BGR); + cv::Mat diff; + cv::absdiff(input_mat, output_mat, diff); + double max_val; + cv::minMaxLoc(diff, nullptr, &max_val); + // Expects that the maximum absolute pixel-by-pixel difference is less + // than 10. + EXPECT_LE(max_val, 10); +} + +TEST(OpenCvEncodedImageToImageFrameCalculatorTest, TestGrayscaleJpeg) { + cv::Mat input_mat; + cv::cvtColor(cv::imread(file::JoinPath("./", + "/mediapipe/calculators/" + "image/testdata/dino.jpg")), + input_mat, cv::COLOR_RGB2GRAY); + std::vector encode_buffer; + std::vector parameters; + parameters.push_back(cv::IMWRITE_JPEG_QUALITY); + parameters.push_back(100); + cv::imencode(".jpg", input_mat, encode_buffer, parameters); + Packet input_packet = MakePacket(std::string(absl::string_view( + reinterpret_cast(&encode_buffer[0]), encode_buffer.size()))); + + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "OpenCvEncodedImageToImageFrameCalculator" + input_stream: "encoded_image" + output_stream: "image_frame" + )"); + CalculatorRunner runner(node_config); + runner.MutableInputs()->Index(0).packets.push_back( + input_packet.At(Timestamp(0))); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const auto& outputs = runner.Outputs(); + ASSERT_EQ(1, outputs.NumEntries()); + const std::vector& packets = outputs.Index(0).packets; + ASSERT_EQ(1, packets.size()); + const 
ImageFrame& output_frame = packets[0].Get(); + cv::Mat diff; + cv::absdiff(input_mat, formats::MatView(&output_frame), diff); + double max_val; + cv::minMaxLoc(diff, nullptr, &max_val); + // Expects that the maximum absolute pixel-by-pixel difference is less + // than 10. + EXPECT_LE(max_val, 10); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/image/opencv_image_encoder_calculator.cc b/mediapipe/calculators/image/opencv_image_encoder_calculator.cc new file mode 100644 index 000000000..efe79d99c --- /dev/null +++ b/mediapipe/calculators/image/opencv_image_encoder_calculator.cc @@ -0,0 +1,121 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" + +namespace mediapipe { + +// Calculator to encode raw image frames. This will result in considerable space +// savings if the frames need to be stored on disk. 
+// +// Example config: +// node { +// calculator: "OpenCvImageEncoderCalculator" +// input_stream: "image" +// output_stream: "encoded_image" +// node_options { +// [type.googleapis.com/mediapipe.OpenCvImageEncoderCalculatorOptions]: { +// quality: 80 +// } +// } +// } +class OpenCvImageEncoderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + int encoding_quality_; +}; + +::mediapipe::Status OpenCvImageEncoderCalculator::GetContract( + CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status OpenCvImageEncoderCalculator::Open(CalculatorContext* cc) { + auto options = cc->Options(); + encoding_quality_ = options.quality(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status OpenCvImageEncoderCalculator::Process( + CalculatorContext* cc) { + const ImageFrame& image_frame = cc->Inputs().Index(0).Get(); + CHECK_EQ(1, image_frame.ByteDepth()); + + std::unique_ptr encoded_result = + absl::make_unique(); + encoded_result->set_width(image_frame.Width()); + encoded_result->set_height(image_frame.Height()); + + cv::Mat original_mat = formats::MatView(&image_frame); + cv::Mat input_mat; + switch (original_mat.channels()) { + case 1: + input_mat = original_mat; + encoded_result->set_colorspace( + OpenCvImageEncoderCalculatorResults::GRAYSCALE); + break; + case 3: + // OpenCV assumes the image to be BGR order. To use imencode(), do color + // conversion first. 
+ cv::cvtColor(original_mat, input_mat, cv::COLOR_RGB2BGR); + encoded_result->set_colorspace(OpenCvImageEncoderCalculatorResults::RGB); + break; + case 4: + return ::mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) + << "4-channel image isn't supported yet"; + default: + return ::mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) + << "Unsupported number of channels: " << original_mat.channels(); + } + + std::vector parameters; + parameters.push_back(cv::IMWRITE_JPEG_QUALITY); + parameters.push_back(encoding_quality_); + + std::vector encode_buffer; + // Note that imencode() will store the data in RGB order. + // Check its JpegEncoder::write() in "imgcodecs/src/grfmt_jpeg.cpp" for more + // info. + if (!cv::imencode(".jpg", input_mat, encode_buffer, parameters)) { + return ::mediapipe::InternalErrorBuilder(MEDIAPIPE_LOC) + << "Fail to encode the image to be jpeg format."; + } + + encoded_result->set_encoded_image(std::string(absl::string_view( + reinterpret_cast(&encode_buffer[0]), encode_buffer.size()))); + + cc->Outputs().Index(0).Add(encoded_result.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status OpenCvImageEncoderCalculator::Close(CalculatorContext* cc) { + return ::mediapipe::OkStatus(); +} + +REGISTER_CALCULATOR(OpenCvImageEncoderCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/opencv_image_encoder_calculator.proto b/mediapipe/calculators/image/opencv_image_encoder_calculator.proto new file mode 100644 index 000000000..43172b319 --- /dev/null +++ b/mediapipe/calculators/image/opencv_image_encoder_calculator.proto @@ -0,0 +1,47 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message OpenCvImageEncoderCalculatorOptions { + extend CalculatorOptions { + optional OpenCvImageEncoderCalculatorOptions ext = 227563646; + } + + // Quality of the encoding. An integer between (0, 100]. + optional int32 quality = 1; +} + +// TODO: Consider renaming it to EncodedImage. +message OpenCvImageEncoderCalculatorResults { + // Encoded image + optional string encoded_image = 1; + + // Dimensions of the encoded image + optional int32 height = 2; + optional int32 width = 3; + + enum ColorSpace { + UNKNOWN = 0; + GRAYSCALE = 1; + RGB = 2; + } + + // Color space used. + optional ColorSpace colorspace = 4; +} diff --git a/mediapipe/calculators/image/opencv_image_encoder_calculator_test.cc b/mediapipe/calculators/image/opencv_image_encoder_calculator_test.cc new file mode 100644 index 000000000..48b867c2f --- /dev/null +++ b/mediapipe/calculators/image/opencv_image_encoder_calculator_test.cc @@ -0,0 +1,87 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +TEST(OpenCvImageEncoderCalculatorTest, TestJpegWithQualities) { + cv::Mat input_mat; + cv::cvtColor(cv::imread(file::JoinPath("./", + "/mediapipe/calculators/" + "image/testdata/dino.jpg")), + input_mat, cv::COLOR_BGR2RGB); + Packet input_packet = MakePacket( + ImageFormat::SRGB, input_mat.size().width, input_mat.size().height); + input_mat.copyTo(formats::MatView(&(input_packet.Get()))); + + std::vector qualities = {50, 80}; + for (int quality : qualities) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie( + absl::Substitute(R"( + calculator: "OpenCvImageEncoderCalculator" + input_stream: "image_frames" + output_stream: "encoded_images" + node_options { + [type.googleapis.com/mediapipe.OpenCvImageEncoderCalculatorOptions]: { + quality: $0 + } + })", + quality)); + CalculatorRunner runner(node_config); + runner.MutableInputs()->Index(0).packets.push_back( + input_packet.At(Timestamp(0))); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const auto& outputs = runner.Outputs(); + ASSERT_EQ(1, outputs.NumEntries()); + const std::vector& packets = outputs.Index(0).packets; + ASSERT_EQ(1, packets.size()); + const auto& result = packets[0].Get(); + ASSERT_EQ(input_mat.size().height, result.height()); + 
ASSERT_EQ(input_mat.size().width, result.width()); + ASSERT_EQ(OpenCvImageEncoderCalculatorResults::RGB, result.colorspace()); + + cv::Mat expected_output = cv::imread( + file::JoinPath("./", absl::Substitute("/mediapipe/calculators/image/" + "testdata/dino_quality_$0.jpg", + quality))); + const std::vector contents_vector(result.encoded_image().begin(), + result.encoded_image().end()); + cv::Mat decoded_output = + cv::imdecode(contents_vector, -1 /* return the loaded image as-is */); + cv::Mat diff; + cv::absdiff(expected_output, decoded_output, diff); + double max_val; + cv::minMaxLoc(diff, nullptr, &max_val); + // Expects that the maximum absolute pixel-by-pixel difference is less + // than 10. + EXPECT_LE(max_val, 10); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/image/opencv_put_text_calculator.cc b/mediapipe/calculators/image/opencv_put_text_calculator.cc new file mode 100644 index 000000000..ff336ff92 --- /dev/null +++ b/mediapipe/calculators/image/opencv_put_text_calculator.cc @@ -0,0 +1,60 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" + +namespace mediapipe { + +// Takes in a std::string, draws the text std::string by cv::putText(), and +// outputs an ImageFrame. +// +// Example config: +// node { +// calculator: "OpenCvPutTextCalculator" +// input_stream: "text_to_put" +// output_stream: "out_image_frames" +// } +// TODO: Generalize the calculator for other text use cases. +class OpenCvPutTextCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Process(CalculatorContext* cc) override; +}; + +::mediapipe::Status OpenCvPutTextCalculator::GetContract( + CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) { + const std::string& text_content = cc->Inputs().Index(0).Get(); + cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC3); + cv::putText(mat, text_content, cv::Point(15, 70), cv::FONT_HERSHEY_PLAIN, 3, + cv::Scalar(255, 255, 0), 4); + std::unique_ptr output_frame = absl::make_unique( + ImageFormat::SRGB, mat.size().width, mat.size().height); + mat.copyTo(formats::MatView(output_frame.get())); + cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +REGISTER_CALCULATOR(OpenCvPutTextCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/recolor_calculator.cc b/mediapipe/calculators/image/recolor_calculator.cc new file mode 100644 index 000000000..b5b49adcb --- /dev/null +++ b/mediapipe/calculators/image/recolor_calculator.cc @@ -0,0 +1,377 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/image/recolor_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/color.pb.h" + +#if defined(__ANDROID__) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/shader_util.h" +#endif // __ANDROID__ + +namespace { +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; +} // namespace + +namespace mediapipe { + +// A calculator to recolor a masked area of an image to a specified color. +// +// A mask image is used to specify where to overlay a user defined color. +// The luminance of the input image is used to adjust the blending weight, +// to help preserve image textures. +// +// TODO implement cpu support. +// +// Inputs: +// One of the following IMAGE tags: +// IMAGE: An ImageFrame input image, RGB or RGBA. +// IMAGE_GPU: A GpuBuffer input image, RGBA. +// One of the following MASK tags: +// MASK: An ImageFrame input mask, Gray, RGB or RGBA. +// MASK_GPU: A GpuBuffer input mask, RGBA. +// Output: +// One of the following IMAGE tags: +// IMAGE: An ImageFrame output image. +// IMAGE_GPU: A GpuBuffer output image. 
+// +// Options: +// color_rgb (required): A map of RGB values [0-255]. +// mask_channel (optional): Which channel of mask image is used [RED or ALPHA] +// +// Usage example: +// node { +// calculator: "RecolorCalculator" +// input_stream: "IMAGE_GPU:input_image" +// input_stream: "MASK_GPU:input_mask" +// output_stream: "IMAGE_GPU:output_image" +// node_options: { +// [mediapipe.RecolorCalculatorOptions] { +// color { r: 0 g: 0 b: 255 } +// mask_channel: RED +// } +// } +// } +// +class RecolorCalculator : public CalculatorBase { + public: + RecolorCalculator() = default; + ~RecolorCalculator() override = default; + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status LoadOptions(CalculatorContext* cc); + ::mediapipe::Status InitGpu(CalculatorContext* cc); + ::mediapipe::Status RenderGpu(CalculatorContext* cc); + ::mediapipe::Status RenderCpu(CalculatorContext* cc); + void GlRender(); + + bool initialized_ = false; + std::vector color_; + mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_; + + bool use_gpu_ = false; +#if defined(__ANDROID__) + mediapipe::GlCalculatorHelper gpu_helper_; + GLuint program_ = 0; +#endif // __ANDROID__ +}; +REGISTER_CALCULATOR(RecolorCalculator); + +// static +::mediapipe::Status RecolorCalculator::GetContract(CalculatorContract* cc) { + RET_CHECK(!cc->Inputs().GetTags().empty()); + RET_CHECK(!cc->Outputs().GetTags().empty()); + +#if defined(__ANDROID__) + if (cc->Inputs().HasTag("IMAGE_GPU")) { + cc->Inputs().Tag("IMAGE_GPU").Set(); + } +#endif // __ANDROID__ + if (cc->Inputs().HasTag("IMAGE")) { + cc->Inputs().Tag("IMAGE").Set(); + } + +#if defined(__ANDROID__) + if (cc->Inputs().HasTag("MASK_GPU")) { + cc->Inputs().Tag("MASK_GPU").Set(); + } +#endif // __ANDROID__ + if 
(cc->Inputs().HasTag("MASK")) { + cc->Inputs().Tag("MASK").Set(); + } + +#if defined(__ANDROID__) + if (cc->Outputs().HasTag("IMAGE_GPU")) { + cc->Outputs().Tag("IMAGE_GPU").Set(); + } +#endif // __ANDROID__ + if (cc->Outputs().HasTag("IMAGE")) { + cc->Outputs().Tag("IMAGE").Set(); + } + +#if defined(__ANDROID__) + RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RecolorCalculator::Open(CalculatorContext* cc) { + if (cc->Inputs().HasTag("IMAGE_GPU")) { + use_gpu_ = true; +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif // __ANDROID__ + } + + RETURN_IF_ERROR(LoadOptions(cc)); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) { + if (use_gpu_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { + if (!initialized_) { + RETURN_IF_ERROR(InitGpu(cc)); + initialized_ = true; + } + RETURN_IF_ERROR(RenderGpu(cc)); + return ::mediapipe::OkStatus(); + })); +#endif // __ANDROID__ + } else { + RETURN_IF_ERROR(RenderCpu(cc)); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) { +#if defined(__ANDROID__) + gpu_helper_.RunInGlContext([this] { + if (program_) glDeleteProgram(program_); + program_ = 0; + }); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { + return ::mediapipe::UnimplementedError("CPU support is not implemented yet."); +} + +::mediapipe::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { + if (cc->Inputs().Tag("MASK_GPU").IsEmpty()) { + return ::mediapipe::OkStatus(); + } +#if defined(__ANDROID__) + // Get inputs and setup output. 
+ const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value(); + const Packet& mask_packet = cc->Inputs().Tag("MASK_GPU").Value(); + + const auto& input_buffer = input_packet.Get(); + const auto& mask_buffer = mask_packet.Get(); + + auto img_tex = gpu_helper_.CreateSourceTexture(input_buffer); + auto mask_tex = gpu_helper_.CreateSourceTexture(mask_buffer); + auto dst_tex = + gpu_helper_.CreateDestinationTexture(img_tex.width(), img_tex.height()); + + // Run recolor shader on GPU. + { + gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0 + + glActiveTexture(GL_TEXTURE1); + glBindTexture(img_tex.target(), img_tex.name()); + glActiveTexture(GL_TEXTURE2); + glBindTexture(mask_tex.target(), mask_tex.name()); + + GlRender(); + + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + } + + // Send result image in GPU packet. + auto output = dst_tex.GetFrame(); + cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp()); + + // Cleanup + img_tex.Release(); + mask_tex.Release(); + dst_tex.Release(); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +void RecolorCalculator::GlRender() { +#if defined(__ANDROID__) + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // program + glUseProgram(program_); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 
4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); +#endif // __ANDROID__ +} + +::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { + const auto& options = cc->Options(); + + mask_channel_ = options.mask_channel(); + + if (!options.has_color()) RET_CHECK_FAIL() << "Missing color option."; + + color_.push_back(options.color().r() / 255.0); + color_.push_back(options.color().g() / 255.0); + color_.push_back(options.color().b() / 255.0); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { +#if defined(__ANDROID__) + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + std::string mask_component; + switch (mask_channel_) { + case mediapipe::RecolorCalculatorOptions_MaskChannel_UNKNOWN: + case mediapipe::RecolorCalculatorOptions_MaskChannel_RED: + mask_component = "r"; + break; + case mediapipe::RecolorCalculatorOptions_MaskChannel_ALPHA: + mask_component = "a"; + break; + } + + // A shader to blend a color onto an image where the mask > 0. + // The blending is based on the input image luminosity. 
+ const std::string frag_src = R"( + #if __VERSION__ < 130 + #define in varying + #endif // __VERSION__ < 130 + + #ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; + #else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; + #endif // defined(GL_ES) + + #define MASK_COMPONENT )" + mask_component + + R"( + + in vec2 sample_coordinate; + uniform sampler2D frame; + uniform sampler2D mask; + uniform vec3 recolor; + + void main() { + vec4 weight = texture2D(mask, sample_coordinate); + vec4 color1 = texture2D(frame, sample_coordinate); + vec4 color2 = vec4(recolor, 1.0); + + float luminance = dot(color1.rgb, vec3(0.299, 0.587, 0.114)); + float mix_value = weight.MASK_COMPONENT * luminance; + + fragColor = mix(color1, color2, mix_value); + } + )"; + + // shader program and params + mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src.c_str(), + NUM_ATTRIBUTES, &attr_name[0], attr_location, + &program_); + RET_CHECK(program_) << "Problem initializing the program."; + glUseProgram(program_); + glUniform1i(glGetUniformLocation(program_, "frame"), 1); + glUniform1i(glGetUniformLocation(program_, "mask"), 2); + glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1], + color_[2]); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/recolor_calculator.proto b/mediapipe/calculators/image/recolor_calculator.proto new file mode 100644 index 000000000..76326c079 --- /dev/null +++ b/mediapipe/calculators/image/recolor_calculator.proto @@ -0,0 +1,39 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/util/color.proto"; + +message RecolorCalculatorOptions { + extend CalculatorOptions { + optional RecolorCalculatorOptions ext = 252527117; + } + + enum MaskChannel { + UNKNOWN = 0; + RED = 1; + ALPHA = 2; + } + + // Selects which channel of the MASK input to use for masking. + optional MaskChannel mask_channel = 1 [default = RED]; + + // Color to blend into input image where mask is > 0. + // The blending is based on the input image luminosity. + optional Color color = 2; +} diff --git a/mediapipe/calculators/image/scale_image_calculator.cc b/mediapipe/calculators/image/scale_image_calculator.cc new file mode 100644 index 000000000..c9be8774b --- /dev/null +++ b/mediapipe/calculators/image/scale_image_calculator.cc @@ -0,0 +1,693 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This Calculator takes an ImageFrame and scales it appropriately. 
+ +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" +#include "libyuv/scale.h" +#include "mediapipe/calculators/image/scale_image_calculator.pb.h" +#include "mediapipe/calculators/image/scale_image_utils.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/formats/yuv_image.h" +#include "mediapipe/framework/port/image_resizer.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/image_frame_util.h" + +namespace mediapipe { + +namespace { + +// Given an upscaling algorithm, determine which OpenCV interpolation algorithm +// to use. 
+// Maps the user-selected upscaling algorithm from the calculator options onto
+// the matching OpenCV interpolation flag, written to *interpolation_algorithm.
+// DEFAULT_WITHOUT_UPSCALE maps to the sentinel value -1, which callers treat
+// as "never upscale". Returns an error status for unrecognized enum values.
+::mediapipe::Status FindInterpolationAlgorithm(
+    ScaleImageCalculatorOptions::ScaleAlgorithm upscaling_algorithm,
+    int* interpolation_algorithm) {
+  switch (upscaling_algorithm) {
+    case ScaleImageCalculatorOptions::DEFAULT:
+      *interpolation_algorithm = cv::INTER_CUBIC;
+      break;
+    case ScaleImageCalculatorOptions::LINEAR:
+      *interpolation_algorithm = cv::INTER_LINEAR;
+      break;
+    case ScaleImageCalculatorOptions::CUBIC:
+      *interpolation_algorithm = cv::INTER_CUBIC;
+      break;
+    case ScaleImageCalculatorOptions::AREA:
+      *interpolation_algorithm = cv::INTER_AREA;
+      break;
+    case ScaleImageCalculatorOptions::LANCZOS:
+      *interpolation_algorithm = cv::INTER_LANCZOS4;
+      break;
+    case ScaleImageCalculatorOptions::DEFAULT_WITHOUT_UPSCALE:
+      // Sentinel: no OpenCV algorithm; downstream code skips upscaling.
+      *interpolation_algorithm = -1;
+      break;
+    default:
+      RET_CHECK_FAIL() << absl::Substitute("Unknown upscaling algorithm: $0",
+                                           upscaling_algorithm);
+  }
+  return ::mediapipe::OkStatus();
+}
+
+// Copies the crop_width x crop_height window whose top-left corner is at
+// (col_start, row_start) from `original` into `cropped`, one row at a time
+// with std::memcpy. Assumes `cropped` was already allocated with at least
+// crop_width x crop_height and a pixel format matching `original` — sizes
+// are not validated here (TODO confirm callers guarantee this).
+void CropImageFrame(const ImageFrame& original, int col_start, int row_start,
+                    int crop_width, int crop_height, ImageFrame* cropped) {
+  const uint8* src = original.PixelData();
+  uint8* dst = cropped->MutablePixelData();
+
+  // des_y is the destination row index; source rows start at row_start.
+  int des_y = 0;
+  for (int y = row_start; y < row_start + crop_height; ++y) {
+    const uint8* src_line = src + y * original.WidthStep();
+    // Advance to the first cropped pixel of this source row.
+    const uint8* src_pixel = src_line + col_start *
+                                            original.NumberOfChannels() *
+                                            original.ByteDepth();
+    uint8* dst_line = dst + des_y * cropped->WidthStep();
+    std::memcpy(
+        dst_line, src_pixel,
+        crop_width * cropped->NumberOfChannels() * cropped->ByteDepth());
+    ++des_y;
+  }
+}
+
+} // namespace
+
+// Crops and scales an ImageFrame or YUVImage according to the options.
+// The output can be cropped and scaled ImageFrame with the SRGB format. If the
+// input is a YUVImage, the output can be a scaled YUVImage (the scaling is done
+// using libyuv). Cropping is not yet supported for a YUVImage to a scaled
+// YUVImage conversion.
+//
+// Example config:
+// node {
+//   calculator: "ScaleImageCalculator"
+//   input_stream: "raw_frames"
+//   output_stream: "scaled_frames"
+//   node_options {
+//     [type.googleapis.com/mediapipe.ScaleImageCalculatorOptions] {
+//       target_width: 320
+//       target_height: 320
+//       preserve_aspect_ratio: true
+//       output_format: SRGB
+//       algorithm: DEFAULT
+//     }
+//   }
+// }
+//
+// ScaleImageCalculator can also create or update a VideoHeader that is
+// provided at Timestamp::PreStream on stream VIDEO_HEADER.
+//
+// Example config:
+// node {
+//   calculator: "ScaleImageCalculator"
+//   input_stream: "FRAMES:ycbcr_frames"
+//   input_stream: "VIDEO_HEADER:ycbcr_frames_header" # Optional.
+//   output_stream: "FRAMES:srgb_frames"
+//   output_stream: "VIDEO_HEADER:srgb_frames_header" # Independently Optional.
+//   node_options {
+//     [type.googleapis.com/mediapipe.ScaleImageCalculatorOptions] {
+//       target_width: 320
+//       target_height: 320
+//       preserve_aspect_ratio: true
+//       output_format: SRGB
+//       algorithm: DEFAULT
+//     }
+//   }
+// }
+//
+// The calculator options can be overridden with an input stream
+// "OVERRIDE_OPTIONS". If this is provided, and non-empty at PreStream, the
+// calculator options proto is merged with the proto provided in this packet
+// (fields are overwritten in the original options) and the
+// initialization happens in Process at PreStream, and not at Open.
+class ScaleImageCalculator : public CalculatorBase { + public: + ScaleImageCalculator(); + ~ScaleImageCalculator() override; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + ScaleImageCalculatorOptions options = + cc->Options(); + + CollectionItemId input_data_id = cc->Inputs().GetId("FRAMES", 0); + if (!input_data_id.IsValid()) { + input_data_id = cc->Inputs().GetId("", 0); + } + CollectionItemId output_data_id = cc->Outputs().GetId("FRAMES", 0); + if (!output_data_id.IsValid()) { + output_data_id = cc->Outputs().GetId("", 0); + } + + if (cc->Inputs().HasTag("VIDEO_HEADER")) { + cc->Inputs().Tag("VIDEO_HEADER").Set(); + } + if (options.has_input_format() && + options.input_format() == ImageFormat::YCBCR420P) { + cc->Inputs().Get(input_data_id).Set(); + } else { + cc->Inputs().Get(input_data_id).Set(); + } + + if (cc->Outputs().HasTag("VIDEO_HEADER")) { + cc->Outputs().Tag("VIDEO_HEADER").Set(); + } + if (options.has_output_format() && + options.output_format() == ImageFormat::YCBCR420P) { + RET_CHECK_EQ(ImageFormat::YCBCR420P, options.input_format()); + cc->Outputs().Get(output_data_id).Set(); + } else { + cc->Outputs().Get(output_data_id).Set(); + } + + if (cc->Inputs().HasTag("OVERRIDE_OPTIONS")) { + cc->Inputs().Tag("OVERRIDE_OPTIONS").Set(); + } + return ::mediapipe::OkStatus(); + } + + // From Calculator. + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + // Initialize some data members from options_. This can be called either from + // Open or Process depending on whether OVERRIDE_OPTIONS is used. + ::mediapipe::Status InitializeFromOptions(); + // Initialize crop and output parameters based on set member variable + // values. This function will also send the header information on + // the VIDEO_HEADER stream if it hasn't been done yet. 
+ ::mediapipe::Status InitializeFrameInfo(CalculatorContext* cc); + // Validate that input_format_ and output_format_ are supported image + // formats. + ::mediapipe::Status ValidateImageFormats() const; + // Validate that the image frame has the proper format and dimensions. + // If the dimensions and format weren't initialized by the header, + // then the first frame on which this function is called is used + // to initialize. + ::mediapipe::Status ValidateImageFrame(CalculatorContext* cc, + const ImageFrame& image_frame); + // Validate that the YUV image has the proper dimensions. If the + // dimensions weren't initialized by the header, then the first image + // on which this function is called is used to initialize. + ::mediapipe::Status ValidateYUVImage(CalculatorContext* cc, + const YUVImage& yuv_image); + + bool has_header_; // True if the input stream has a header. + int input_width_; + int input_height_; + int crop_width_; + int crop_height_; + int col_start_; + int row_start_; + int output_width_; + int output_height_; + ImageFormat::Format input_format_; + ImageFormat::Format output_format_; + int interpolation_algorithm_; + + // The "DATA" input stream. + CollectionItemId input_data_id_; + // The "DATA" output stream. + CollectionItemId output_data_id_; + VideoHeader input_video_header_; + + // Whether the header information was sent on the VIDEO_HEADER stream. + bool header_sent_ = false; + + // The alignment boundary that newly created images should have. + int alignment_boundary_; + + ScaleImageCalculatorOptions options_; + + // Efficient image resizer with gamma correction and optional sharpening. 
+ std::unique_ptr downscaler_; +}; + +REGISTER_CALCULATOR(ScaleImageCalculator); + +ScaleImageCalculator::ScaleImageCalculator() {} + +ScaleImageCalculator::~ScaleImageCalculator() {} + +::mediapipe::Status ScaleImageCalculator::InitializeFrameInfo( + CalculatorContext* cc) { + RETURN_IF_ERROR( + scale_image::FindCropDimensions(input_width_, input_height_, // + options_.min_aspect_ratio(), // + options_.max_aspect_ratio(), // + &crop_width_, &crop_height_, // + &col_start_, &row_start_)); + RETURN_IF_ERROR( + scale_image::FindOutputDimensions(crop_width_, crop_height_, // + options_.target_width(), // + options_.target_height(), // + options_.preserve_aspect_ratio(), // + options_.scale_to_multiple_of_two(), // + &output_width_, &output_height_)); + RETURN_IF_ERROR(FindInterpolationAlgorithm(options_.algorithm(), + &interpolation_algorithm_)); + if (interpolation_algorithm_ == -1 && + (output_width_ > crop_width_ || output_height_ > crop_height_)) { + output_width_ = crop_width_; + output_height_ = crop_height_; + } + VLOG(1) << "Image scaling parameters:" + << "\ninput_width_ " << input_width_ // + << "\ninput_height_ " << input_height_ // + << "\ninput_format_ " << input_format_ // + << "\ncrop_width_ " << crop_width_ // + << "\ncrop_height_ " << crop_height_ // + << "\ncol_start_ " << col_start_ // + << "\nrow_start_ " << row_start_ // + << "\noutput_width_ " << output_width_ // + << "\noutput_height_ " << output_height_ // + << "\noutput_format_ " << output_format_ // + << "\nOpenCV interpolation algorithm " << interpolation_algorithm_; + if (!header_sent_ && cc->Outputs().UsesTags() && + cc->Outputs().HasTag("VIDEO_HEADER")) { + header_sent_ = true; + auto header = absl::make_unique(); + *header = input_video_header_; + header->width = output_width_; + header->height = output_height_; + header->format = output_format_; + LOG(INFO) << "OUTPUTTING HEADER on stream"; + cc->Outputs() + .Tag("VIDEO_HEADER") + .Add(header.release(), Timestamp::PreStream()); + 
cc->Outputs().Tag("VIDEO_HEADER").Close(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ScaleImageCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + + input_data_id_ = cc->Inputs().GetId("FRAMES", 0); + if (!input_data_id_.IsValid()) { + input_data_id_ = cc->Inputs().GetId("", 0); + } + output_data_id_ = cc->Outputs().GetId("FRAMES", 0); + if (!output_data_id_.IsValid()) { + output_data_id_ = cc->Outputs().GetId("", 0); + } + + // The output packets are at the same timestamp as the input. + cc->Outputs().Get(output_data_id_).SetOffset(mediapipe::TimestampDiff(0)); + + has_header_ = false; + input_width_ = 0; + input_height_ = 0; + crop_width_ = 0; + crop_height_ = 0; + output_width_ = 0; + output_height_ = 0; + bool has_override_options = cc->Inputs().HasTag("OVERRIDE_OPTIONS"); + + if (!has_override_options) { + RETURN_IF_ERROR(InitializeFromOptions()); + } + + if (!cc->Inputs().Get(input_data_id_).Header().IsEmpty()) { + // If the input stream has a header then our output stream also has a + // header. + + if (has_override_options) { + // It's not possible to use OVERRIDE_OPTIONS when the main input stream + // has a header. At this point in the code, the ScaleImageCalculator + // config may be changed by the new options at PreStream, so the output + // header can't be determined. 
+ return ::mediapipe::InvalidArgumentError( + "OVERRIDE_OPTIONS stream can't be used when the main input stream " + "has a header."); + } + input_video_header_ = + cc->Inputs().Get(input_data_id_).Header().Get(); + + input_format_ = input_video_header_.format; + if (options_.has_input_format()) { + RET_CHECK_EQ(input_format_, options_.input_format()) + << "The input header format does not match the input_format option."; + } + + input_width_ = input_video_header_.width; + input_height_ = input_video_header_.height; + + if (options_.has_output_format()) { + output_format_ = options_.output_format(); + } else { + output_format_ = input_format_; + } + + if (output_format_ == ImageFormat::YCBCR420P) { + RET_CHECK(options_.scale_to_multiple_of_two()) + << "ScaleImageCalculator always outputs width and height that are " + "divisible by 2 when output format is YCbCr420P. To scale to " + "width and height of odd numbers, the output format must be SRGB."; + } else if (options_.preserve_aspect_ratio()) { + RET_CHECK(options_.scale_to_multiple_of_two()) + << "ScaleImageCalculator always outputs width and height that are " + "divisible by 2 when perserving aspect ratio. To scale to width " + "and height of odd numbers, please set " + "preserve_aspect_ratio to false."; + } + + if (input_width_ > 0 && input_height_ > 0 && + input_format_ != ImageFormat::UNKNOWN && + output_format_ != ImageFormat::UNKNOWN) { + RETURN_IF_ERROR(ValidateImageFormats()); + RETURN_IF_ERROR(InitializeFrameInfo(cc)); + std::unique_ptr output_header(new VideoHeader()); + *output_header = input_video_header_; + output_header->format = output_format_; + output_header->width = output_width_; + output_header->height = output_height_; + cc->Outputs() + .Get(output_data_id_) + .SetHeader(Adopt(output_header.release())); + has_header_ = true; + } else { + LOG(WARNING) << "Stream had a VideoHeader which didn't have sufficient " + "information. 
" + "Dropping VideoHeader and trying to deduce needed " + "information."; + input_width_ = 0; + input_height_ = 0; + if (!options_.has_input_format()) { + input_format_ = ImageFormat::UNKNOWN; + } + output_format_ = ImageFormat::UNKNOWN; + } + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ScaleImageCalculator::InitializeFromOptions() { + if (options_.has_input_format()) { + input_format_ = options_.input_format(); + } else { + input_format_ = ImageFormat::UNKNOWN; + } + + alignment_boundary_ = 16; + if (options_.alignment_boundary() > 0) { + alignment_boundary_ = options_.alignment_boundary(); + } + + downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient())); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ScaleImageCalculator::ValidateImageFormats() const { + RET_CHECK_NE(input_format_, ImageFormat::UNKNOWN) + << "The input image format was UNKNOWN."; + RET_CHECK_NE(output_format_, ImageFormat::UNKNOWN) + << "The output image format was set to UNKNOWN."; + // TODO Remove these conditions. + RET_CHECK(output_format_ == ImageFormat::SRGB || + (input_format_ == output_format_ && + output_format_ == ImageFormat::YCBCR420P)) + << "Outputting YCbCr420P images from SRGB input is not yet supported"; + RET_CHECK(input_format_ == output_format_ || + input_format_ == ImageFormat::YCBCR420P) + << "Conversion of the color space (except from " + "YCbCr420P to SRGB) is not yet supported."; + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ScaleImageCalculator::ValidateImageFrame( + CalculatorContext* cc, const ImageFrame& image_frame) { + if (!has_header_) { + if (input_width_ != image_frame.Width() || + input_height_ != image_frame.Height() || + input_format_ != image_frame.Format()) { + // Set the dimensions based on the image frame. There was no header. 
+ input_width_ = image_frame.Width(); + input_height_ = image_frame.Height(); + RET_CHECK(input_width_ > 0 && input_height_ > 0) << absl::StrCat( + "The input image did not have positive dimensions. dimensions: ", + input_width_, "x", input_height_); + input_format_ = image_frame.Format(); + if (options_.has_input_format()) { + RET_CHECK_EQ(input_format_, options_.input_format()) + << "The input image format does not match the input_format option."; + } + if (options_.has_output_format()) { + output_format_ = options_.output_format(); + } else { + output_format_ = input_format_; + } + RETURN_IF_ERROR(InitializeFrameInfo(cc)); + } + RETURN_IF_ERROR(ValidateImageFormats()); + } else { + if (input_width_ != image_frame.Width() || + input_height_ != image_frame.Height()) { + return tool::StatusFail(absl::StrCat( + "If a header specifies a width and a height, then image frames on " + "the stream must have that size. Received frame of size ", + image_frame.Width(), "x", image_frame.Height(), " but expected ", + input_width_, "x", input_height_)); + } + if (input_format_ != image_frame.Format()) { + const proto_ns::EnumDescriptor* desc = ImageFormat::Format_descriptor(); + return tool::StatusFail(absl::StrCat( + "If a header specifies a format, then image frames on " + "the stream must have that format. Actual format ", + desc->FindValueByNumber(image_frame.Format())->DebugString(), + " but expected ", + desc->FindValueByNumber(input_format_)->DebugString())); + } + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ScaleImageCalculator::ValidateYUVImage( + CalculatorContext* cc, const YUVImage& yuv_image) { + CHECK_EQ(input_format_, ImageFormat::YCBCR420P); + if (!has_header_) { + if (input_width_ != yuv_image.width() || + input_height_ != yuv_image.height()) { + // Set the dimensions based on the YUV image. There was no header. 
+ input_width_ = yuv_image.width(); + input_height_ = yuv_image.height(); + RET_CHECK(input_width_ > 0 && input_height_ > 0) << absl::StrCat( + "The input image did not have positive dimensions. dimensions: ", + input_width_, "x", input_height_); + if (options_.has_output_format()) { + output_format_ = options_.output_format(); + } else { + output_format_ = input_format_; + } + RETURN_IF_ERROR(InitializeFrameInfo(cc)); + } + RETURN_IF_ERROR(ValidateImageFormats()); + } else { + if (input_width_ != yuv_image.width() || + input_height_ != yuv_image.height()) { + return tool::StatusFail(absl::StrCat( + "If a header specifies a width and a height, then YUV images on " + "the stream must have that size. Additionally, all YUV images in " + "a stream must have the same size. Received frame of size ", + yuv_image.width(), "x", yuv_image.height(), " but expected ", + input_width_, "x", input_height_)); + } + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ScaleImageCalculator::Process(CalculatorContext* cc) { + if (cc->InputTimestamp() == Timestamp::PreStream()) { + if (cc->Inputs().HasTag("OVERRIDE_OPTIONS")) { + if (cc->Inputs().Tag("OVERRIDE_OPTIONS").IsEmpty()) { + return ::mediapipe::InvalidArgumentError( + "The OVERRIDE_OPTIONS input stream must be non-empty at PreStream " + "time if used."); + } + options_.MergeFrom(cc->Inputs() + .Tag("OVERRIDE_OPTIONS") + .Get()); + RETURN_IF_ERROR(InitializeFromOptions()); + } + if (cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") && + !cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) { + input_video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get(); + } + if (cc->Inputs().Get(input_data_id_).IsEmpty()) { + return ::mediapipe::OkStatus(); + } + } + + cc->GetCounter("Inputs")->Increment(); + const ImageFrame* image_frame; + ImageFrame converted_image_frame; + if (input_format_ == ImageFormat::YCBCR420P) { + const YUVImage* yuv_image = + &cc->Inputs().Get(input_data_id_).Get(); + 
RETURN_IF_ERROR(ValidateYUVImage(cc, *yuv_image)); + + if (output_format_ == ImageFormat::SRGB) { + // TODO: For ease of implementation, YUVImage is converted to + // ImageFrame immediately, before cropping and scaling. Investigate how to + // make color space conversion more efficient when cropping or scaling is + // also needed. + image_frame_util::YUVImageToImageFrame(*yuv_image, &converted_image_frame, + options_.use_bt709()); + image_frame = &converted_image_frame; + } else if (output_format_ == ImageFormat::YCBCR420P) { + RET_CHECK(row_start_ == 0 && col_start_ == 0 && + crop_width_ == input_width_ && crop_height_ == input_height_) + << "ScaleImageCalculator only supports scaling on YUVImages. To crop " + "images, the output format must be SRGB."; + + // Scale the YUVImage and output without converting the color space. + const int y_size = output_width_ * output_height_; + const int uv_size = output_width_ * output_height_ / 4; + std::unique_ptr yuv_data(new uint8_t[y_size + uv_size * 2]); + uint8* y = yuv_data.get(); + uint8* u = y + y_size; + uint8* v = u + uv_size; + RET_CHECK_EQ(0, I420Scale(yuv_image->data(0), yuv_image->stride(0), + yuv_image->data(1), yuv_image->stride(1), + yuv_image->data(2), yuv_image->stride(2), + yuv_image->width(), yuv_image->height(), y, + output_width_, u, output_width_ / 2, v, + output_width_ / 2, output_width_, + output_height_, libyuv::kFilterBox)); + auto output_image = absl::make_unique( + libyuv::FOURCC_I420, std::move(yuv_data), y, output_width_, u, + output_width_ / 2, v, output_width_ / 2, output_width_, + output_height_); + cc->GetCounter("Outputs Scaled")->Increment(); + if (yuv_image->width() >= output_width_ && + yuv_image->height() >= output_height_) { + cc->GetCounter("Downscales")->Increment(); + } else if (interpolation_algorithm_ != -1) { + cc->GetCounter("Upscales")->Increment(); + } + cc->Outputs() + .Get(output_data_id_) + .Add(output_image.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); 
+ } + } else { + image_frame = &cc->Inputs().Get(input_data_id_).Get(); + RETURN_IF_ERROR(ValidateImageFrame(cc, *image_frame)); + } + + std::unique_ptr cropped_image; + if (crop_width_ < input_width_ || crop_height_ < input_height_) { + cc->GetCounter("Crops")->Increment(); + // TODO Do the crop as a range restrict inside OpenCV code below. + cropped_image.reset(new ImageFrame(image_frame->Format(), crop_width_, + crop_height_, alignment_boundary_)); + if (image_frame->ByteDepth() == 1 || image_frame->ByteDepth() == 2) { + CropImageFrame(*image_frame, col_start_, row_start_, crop_width_, + crop_height_, cropped_image.get()); + } else { + return tool::StatusInvalid( + "Input format does not have ByteDepth of 1 or 2."); + } + + // Update the image_frame to point to the cropped image. The + // unique_ptr will take care of deleting the cropped image when the + // function returns. + image_frame = cropped_image.get(); + } + + // Skip later operations if no scaling is necessary. + if (crop_width_ == output_width_ && crop_height_ == output_height_) { + // Efficiently use either the cropped image or the original image. + if (image_frame == cropped_image.get()) { + if (options_.set_alignment_padding()) { + cropped_image->SetAlignmentPaddingAreas(); + } + cc->GetCounter("Outputs Cropped")->Increment(); + cc->Outputs() + .Get(output_data_id_) + .Add(cropped_image.release(), cc->InputTimestamp()); + } else { + if (options_.alignment_boundary() <= 0 && + (!options_.set_alignment_padding() || image_frame->IsContiguous())) { + // Any alignment is acceptable and we don't need to clear the + // alignment padding (either because the user didn't request it + // or because the data is contiguous). + cc->GetCounter("Outputs Inputs")->Increment(); + cc->Outputs() + .Get(output_data_id_) + .AddPacket(cc->Inputs().Get(input_data_id_).Value()); + } else { + // Make a copy with the correct alignment. 
+ std::unique_ptr output_frame(new ImageFrame()); + output_frame->CopyFrom(*image_frame, alignment_boundary_); + if (options_.set_alignment_padding()) { + output_frame->SetAlignmentPaddingAreas(); + } + cc->GetCounter("Outputs Aligned")->Increment(); + cc->Outputs() + .Get(output_data_id_) + .Add(output_frame.release(), cc->InputTimestamp()); + } + } + return ::mediapipe::OkStatus(); + } + + // Rescale the image frame. + std::unique_ptr output_frame(new ImageFrame()); + if (image_frame->Width() >= output_width_ && + image_frame->Height() >= output_height_) { + // Downscale. + cc->GetCounter("Downscales")->Increment(); + cv::Mat input_mat = ::mediapipe::formats::MatView(image_frame); + output_frame->Reset(image_frame->Format(), output_width_, output_height_, + alignment_boundary_); + cv::Mat output_mat = ::mediapipe::formats::MatView(output_frame.get()); + downscaler_->Resize(input_mat, &output_mat); + } else { + // Upscale. If upscaling is disallowed, output_width_ and output_height_ are + // the same as the input/crop width and height. + image_frame_util::RescaleImageFrame( + *image_frame, output_width_, output_height_, alignment_boundary_, + interpolation_algorithm_, output_frame.get()); + if (interpolation_algorithm_ != -1) { + cc->GetCounter("Upscales")->Increment(); + } + } + + if (options_.set_alignment_padding()) { + cc->GetCounter("Pads")->Increment(); + output_frame->SetAlignmentPaddingAreas(); + } + + cc->GetCounter("Outputs Scaled")->Increment(); + cc->Outputs() + .Get(output_data_id_) + .Add(output_frame.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/scale_image_calculator.proto b/mediapipe/calculators/image/scale_image_calculator.proto new file mode 100644 index 000000000..2fc782a4f --- /dev/null +++ b/mediapipe/calculators/image/scale_image_calculator.proto @@ -0,0 +1,109 @@ +// Options for ScaleImageCalculator. 
+// width and height that are odd numbers when the output format is SRGB and not
+// preserving the aspect ratio. See scale_to_multiple_of_two option for details.
+ // For example, for a min_aspect_ratio of "9/16" and max of "16/9" the + // following cropping will occur: + // 1920x1080 (which is 16:9) is not cropped + // 640x1024 (which is 10:16) is not cropped + // 640x320 (which is 2:1) cropped to 568x320 (just under 16/9) + // 96x480 (which is 1:5), cropped to 96x170 (just over 9/16) + // The resultant frame will always be between (or at) the + // min_aspect_ratio and max_aspect_ratio. + optional string min_aspect_ratio = 4 [default = "9/16"]; + optional string max_aspect_ratio = 5 [default = "16/9"]; + + // If unset, use the same format as the input. + // NOTE: in the current implementation, the output format (either specified + // in the output_format option or inherited from the input format) must be + // SRGB. It can be YCBCR420P if the input_format is also the same. + optional ImageFormat.Format output_format = 6; + + enum ScaleAlgorithm { + DEFAULT = 0; + LINEAR = 1; + CUBIC = 2; + AREA = 3; + LANCZOS = 4; + DEFAULT_WITHOUT_UPSCALE = 5; // Option to disallow upscaling. + } + + // The upscaling algorithm to use. The default is to use CUBIC. Note that + // downscaling unconditionally uses DDA; see image_processing:: + // AffineGammaResizer for documentation. + optional ScaleAlgorithm algorithm = 7 [default = DEFAULT]; + + // The output image will have this alignment. If set to zero, then + // any alignment could be used. If set to one, the output image will + // be stored contiguously. + optional int32 alignment_boundary = 8 [default = 16]; + + // Set the alignment padding area to deterministic values (as opposed + // to possibly leaving it as uninitialized memory). The padding is + // the space between the pixel values in a row and the end of the row + // (which may be different due to alignment requirements on the length + // of a row). 
+ optional bool set_alignment_padding = 9 [default = true]; + + optional bool OBSOLETE_skip_linear_rgb_conversion = 10 [default = false]; + + // Applies sharpening for downscaled images as post-processing. See + // image_processing::AffineGammaResizer for documentation. + optional float post_sharpening_coefficient = 11 [default = 0.0]; + + // If input_format is YCBCR420P, input packets contain a YUVImage. If + // input_format is a format other than YCBCR420P or is unset, input packets + // contain an ImageFrame. + // NOTE: in the current implementation, the input format (either specified + // in the input_format option or inferred from the input packets) must be + // SRGB or YCBCR420P. + optional ImageFormat.Format input_format = 12; + + // If true, the output width and height will be divisible by 2. Otherwise it + // will use the exact specified output width and height, which is only + // supported when the output format is SRGB and preserve_aspect_ratio option + // is set to false. + optional bool scale_to_multiple_of_two = 13 [default = true]; + + // If true, assume the input YUV is BT.709 (this is the HDTV standard, so most + // content is likely using it). If false use the previous assumption of BT.601 + // (mid-80s standard). Ideally this information should be contained in the + // input YUV Frame, but as of 02/06/2019, it's not. Once this info is baked + // in, this flag becomes useless. + optional bool use_bt709 = 14 [default = false]; +} diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc new file mode 100644 index 000000000..db55774ad --- /dev/null +++ b/mediapipe/calculators/image/scale_image_utils.cc @@ -0,0 +1,159 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/scale_image_utils.h" + +#include + +#include + +#include "absl/strings/str_split.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +namespace scale_image { + +namespace { +double ParseRational(const std::string& rational) { + const std::vector v = absl::StrSplit(rational, '/'); + const double numerator = std::strtod(v[0].c_str(), nullptr); + const double denominator = std::strtod(v[1].c_str(), nullptr); + return numerator / denominator; +} +} // namespace + +::mediapipe::Status FindCropDimensions(int input_width, int input_height, // + const std::string& min_aspect_ratio, // + const std::string& max_aspect_ratio, // + int* crop_width, int* crop_height, // + int* col_start, int* row_start) { + CHECK(crop_width); + CHECK(crop_height); + CHECK(col_start); + CHECK(row_start); + + double min_aspect_ratio_q = 0.0; + double max_aspect_ratio_q = 0.0; + if (!min_aspect_ratio.empty()) { + min_aspect_ratio_q = ParseRational(min_aspect_ratio); + } + if (!max_aspect_ratio.empty()) { + max_aspect_ratio_q = ParseRational(max_aspect_ratio); + } + + *crop_width = input_width; + *crop_height = input_height; + *col_start = 0; + *row_start = 0; + + // Determine the current aspect ratio. 
+ const double aspect_ratio = + static_cast(input_width) / static_cast(input_height); + + if (!std::isinf(max_aspect_ratio_q) && !std::isinf(min_aspect_ratio_q)) { + if (max_aspect_ratio_q > 0 && aspect_ratio > max_aspect_ratio_q) { + // Determine the width based on the height multiplied by the max + // aspect ratio. + *crop_width = static_cast(static_cast(input_height) * + max_aspect_ratio_q); + *crop_width = (*crop_width / 2) * 2; + // The col_start should be half the difference between the input width + // and the output width. + *col_start = (input_width - *crop_width) / 2; + } else if (min_aspect_ratio_q > 0 && aspect_ratio < min_aspect_ratio_q) { + // Determine the height based on the width divided by the min + // aspect ratio. + *crop_height = static_cast(static_cast(input_width) / + min_aspect_ratio_q); + *crop_height = (*crop_height / 2) * 2; + *row_start = (input_height - *crop_height) / 2; + } + } + + CHECK_LE(*crop_width, input_width); + CHECK_LE(*crop_height, input_height); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status FindOutputDimensions(int input_width, // + int input_height, // + int target_width, // + int target_height, // + bool preserve_aspect_ratio, // + bool scale_to_multiple_of_two, // + int* output_width, + int* output_height) { + CHECK(output_width); + CHECK(output_height); + + if (!preserve_aspect_ratio || (target_width <= 0 && target_height <= 0)) { + if (target_width <= 0) { + target_width = input_width; + } + if (target_height <= 0) { + target_height = input_height; + } + if (scale_to_multiple_of_two) { + *output_width = (target_width / 2) * 2; + *output_height = (target_height / 2) * 2; + } else { + *output_width = target_width; + *output_height = target_height; + } + return ::mediapipe::OkStatus(); + } + + if (target_width > 0) { + // Try setting the height based on the width and the aspect ratio. 
+    if (target_width <= 0 || try_width <= target_width) {
+      // The resulting width based on the target height and aspect ratio
+      // was within the image, so use these dimensions.
+// to preserve_aspect_ratio and to scale_to_multiple_of_two if these options are
+// specified.
+::mediapipe::Status FindOutputDimensions(int input_width, int input_height, // + int target_width, + int target_height, // + bool preserve_aspect_ratio, // + bool scale_to_multiple_of_two, // + int* output_width, int* output_height); + +} // namespace scale_image +} // namespace mediapipe + +#endif // MEDIAPIPE_IMAGE_SCALE_IMAGE_UTILS_H_ diff --git a/mediapipe/calculators/image/scale_image_utils_test.cc b/mediapipe/calculators/image/scale_image_utils_test.cc new file mode 100644 index 000000000..62522be8f --- /dev/null +++ b/mediapipe/calculators/image/scale_image_utils_test.cc @@ -0,0 +1,153 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/image/scale_image_utils.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace scale_image { +namespace { + +TEST(ScaleImageUtilsTest, FindCropDimensions) { + int crop_width; + int crop_height; + int col_start; + int row_start; + // No cropping because aspect ratios should be ignored. + MEDIAPIPE_ASSERT_OK(FindCropDimensions(50, 100, "0/1", "1/0", &crop_width, + &crop_height, &col_start, &row_start)); + EXPECT_EQ(50, crop_width); + EXPECT_EQ(100, crop_height); + EXPECT_EQ(0, row_start); + EXPECT_EQ(0, col_start); + + // Tests proto examples. + // 16:9 aspect ratio, should be unchanged. 
+ MEDIAPIPE_ASSERT_OK(FindCropDimensions(1920, 1080, "9/16", "16/9", + &crop_width, &crop_height, &col_start, + &row_start)); + EXPECT_EQ(0, col_start); + EXPECT_EQ(1920, crop_width); + EXPECT_EQ(0, row_start); + EXPECT_EQ(1080, crop_height); + // 10:16 aspect ratio, should be unchanged. + MEDIAPIPE_ASSERT_OK(FindCropDimensions(640, 1024, "9/16", "16/9", &crop_width, + &crop_height, &col_start, &row_start)); + EXPECT_EQ(0, col_start); + EXPECT_EQ(640, crop_width); + EXPECT_EQ(0, row_start); + EXPECT_EQ(1024, crop_height); + + // 2:1 aspect ratio, width is cropped. + MEDIAPIPE_ASSERT_OK(FindCropDimensions(640, 320, "9/16", "16/9", &crop_width, + &crop_height, &col_start, &row_start)); + EXPECT_EQ(36, col_start); + EXPECT_EQ(568, crop_width); + EXPECT_EQ(0, row_start); + EXPECT_EQ(320, crop_height); + // 1:5 aspect ratio, height is cropped. + MEDIAPIPE_ASSERT_OK(FindCropDimensions(96, 480, "9/16", "16/9", &crop_width, + &crop_height, &col_start, &row_start)); + EXPECT_EQ(0, col_start); + EXPECT_EQ(96, crop_width); + EXPECT_EQ(155, row_start); + EXPECT_EQ(170, crop_height); + + // Tests min = max, crops width. + MEDIAPIPE_ASSERT_OK(FindCropDimensions(200, 100, "1/1", "1/1", &crop_width, + &crop_height, &col_start, &row_start)); + EXPECT_EQ(50, col_start); + EXPECT_EQ(100, crop_width); + EXPECT_EQ(0, row_start); + EXPECT_EQ(100, crop_height); +} + +TEST(ScaleImageUtilsTest, FindOutputDimensionsPreserveRatio) { + int output_width; + int output_height; + // Not scale. + MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, true, true, + &output_width, &output_height)); + EXPECT_EQ(200, output_width); + EXPECT_EQ(100, output_height); + // Not scale with odd input size. + MEDIAPIPE_ASSERT_OK(FindOutputDimensions(201, 101, -1, -1, false, false, + &output_width, &output_height)); + EXPECT_EQ(201, output_width); + EXPECT_EQ(101, output_height); + // Scale down by 1/2. 
+  // Fits a 2:1 image into a 150 x 150 box. Output dimensions are always
+  // divisible by 2.
+ MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, false, true, + &output_width, &output_height)); + EXPECT_EQ(200, output_width); + EXPECT_EQ(200, output_height); + // Scale both dimensions. + MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, false, true, + &output_width, &output_height)); + EXPECT_EQ(150, output_width); + EXPECT_EQ(200, output_height); +} + +} // namespace +} // namespace scale_image +} // namespace mediapipe diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc new file mode 100644 index 000000000..52668a14d --- /dev/null +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -0,0 +1,462 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// Adding alpha channel to a Grayscale (single channel) input is not supported.
+// +// ALPHA (optional): ImageFrame alpha mask to apply, +// can be any # of channels, only first channel used, +// must be same format as input +// ALPHA_GPU (optional): GpuBuffer alpha mask to apply, +// can be any # of channels, only first channel used, +// must be same format as input +// If ALPHA* input tag is not set, the 'alpha_value' option must be used. +// +// Output: +// One of the following two tags: +// IMAGE: An ImageFrame with alpha channel set - RGBA only. +// IMAGE_GPU: A GpuBuffer with alpha channel set - RGBA only. +// +// Options: +// alpha_value (optional): The alpha value to set to input image, [0-255], +// takes precedence over input mask. +// If alpha_value is not set, the ALPHA* input tag must be used. +// +// Notes: +// Either alpha_value option or ALPHA (or ALPHA_GPU) must be set. +// All CPU inputs must have the same image dimensions and data type. +// +class SetAlphaCalculator : public CalculatorBase { + public: + SetAlphaCalculator() = default; + ~SetAlphaCalculator() override = default; + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + // From Calculator. 
+ ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status RenderGpu(CalculatorContext* cc); + ::mediapipe::Status RenderCpu(CalculatorContext* cc); + + ::mediapipe::Status GlRender(CalculatorContext* cc); + ::mediapipe::Status GlSetup(CalculatorContext* cc); + + mediapipe::SetAlphaCalculatorOptions options_; + float alpha_value_ = -1.f; + + bool use_gpu_ = false; + bool gpu_initialized_ = false; +#if defined(__ANDROID__) + mediapipe::GlCalculatorHelper gpu_helper_; + GLuint program_ = 0; +#endif // __ANDROID__ +}; +REGISTER_CALCULATOR(SetAlphaCalculator); + +::mediapipe::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { + CHECK_GE(cc->Inputs().NumEntries(), 1); + + if (cc->Inputs().HasTag(kInputFrameTag) && + cc->Inputs().HasTag(kInputFrameTagGpu)) { + return ::mediapipe::InternalError("Cannot have multiple input images."); + } + if (cc->Inputs().HasTag(kInputFrameTagGpu) != + cc->Outputs().HasTag(kOutputFrameTagGpu)) { + return ::mediapipe::InternalError("GPU output must have GPU input."); + } + + // Input image to add/edit alpha channel. +#if defined(__ANDROID__) + if (cc->Inputs().HasTag(kInputFrameTagGpu)) { + cc->Inputs().Tag(kInputFrameTagGpu).Set(); + } +#endif // __ANDROID__ + if (cc->Inputs().HasTag(kInputFrameTag)) { + cc->Inputs().Tag(kInputFrameTag).Set(); + } + + // Input alpha image mask (optional) +#if defined(__ANDROID__) + if (cc->Inputs().HasTag(kInputAlphaTagGpu)) { + cc->Inputs().Tag(kInputAlphaTagGpu).Set(); + } +#endif // __ANDROID__ + if (cc->Inputs().HasTag(kInputAlphaTag)) { + cc->Inputs().Tag(kInputAlphaTag).Set(); + } + + // RGBA output image. 
+#if defined(__ANDROID__) + if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { + cc->Outputs().Tag(kOutputFrameTagGpu).Set(); + } +#endif // __ANDROID__ + if (cc->Outputs().HasTag(kOutputFrameTag)) { + cc->Outputs().Tag(kOutputFrameTag).Set(); + } + +#if defined(__ANDROID__) + RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SetAlphaCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + + if (cc->Inputs().HasTag(kInputFrameTagGpu) && + cc->Outputs().HasTag(kOutputFrameTagGpu)) { +#if defined(__ANDROID__) + use_gpu_ = true; +#else + RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet."; +#endif // __ANDROID__ + } + + // Get global value from options (-1 if not set). + alpha_value_ = options_.alpha_value(); + if (use_gpu_) alpha_value_ /= 255.0; + + const bool use_image_mask = cc->Inputs().HasTag(kInputAlphaTag) || + cc->Inputs().HasTag(kInputAlphaTagGpu); + if (!((alpha_value_ >= 0) ^ use_image_mask)) + RET_CHECK_FAIL() << "Must use either image mask or options alpha value."; + + if (use_gpu_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) { + if (use_gpu_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { + if (!gpu_initialized_) { + RETURN_IF_ERROR(GlSetup(cc)); + gpu_initialized_ = true; + } + RETURN_IF_ERROR(RenderGpu(cc)); + return ::mediapipe::OkStatus(); + })); +#endif // __ANDROID__ + } else { + RETURN_IF_ERROR(RenderCpu(cc)); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) { +#if defined(__ANDROID__) + gpu_helper_.RunInGlContext([this] { + if (program_) glDeleteProgram(program_); + program_ = 0; + }); +#endif // __ANDROID__ + + return 
::mediapipe::OkStatus(); +} + +::mediapipe::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { + if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) { + return ::mediapipe::OkStatus(); + } + + // Setup source image + const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get(); + const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame); + if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) { + LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported"; + } + + // Setup destination image + auto output_frame = absl::make_unique( + ImageFormat::SRGBA, input_mat.cols, input_mat.rows); + cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get()); + + const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) && + !cc->Inputs().Tag(kInputAlphaTag).IsEmpty(); + const bool use_alpa_mask = alpha_value_ < 0 && has_alpha_mask; + + // Setup alpha image and Update image in CPU. + if (use_alpa_mask) { + const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get(); + cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask); + RET_CHECK_EQ(input_mat.rows, alpha_mat.rows); + RET_CHECK_EQ(input_mat.cols, alpha_mat.cols); + + for (int i = 0; i < output_mat.rows; ++i) { + const uchar* in_ptr = input_mat.ptr(i); + uchar* alpha_ptr = alpha_mat.ptr(i); + uchar* out_ptr = output_mat.ptr(i); + for (int j = 0; j < output_mat.cols; ++j) { + const int out_idx = j * kNumChannelsRGBA; + const int in_idx = j * input_mat.channels(); + const int alpha_idx = j * alpha_mat.channels(); + out_ptr[out_idx + 0] = in_ptr[in_idx + 0]; + out_ptr[out_idx + 1] = in_ptr[in_idx + 1]; + out_ptr[out_idx + 2] = in_ptr[in_idx + 2]; + out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask + } + } + } else { + const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f); + for (int i = 0; i < output_mat.rows; ++i) { + const uchar* in_ptr = input_mat.ptr(i); + uchar* out_ptr = output_mat.ptr(i); + for (int j = 0; j < 
output_mat.cols; ++j) { + const int out_idx = j * kNumChannelsRGBA; + const int in_idx = j * input_mat.channels(); + out_ptr[out_idx + 0] = in_ptr[in_idx + 0]; + out_ptr[out_idx + 1] = in_ptr[in_idx + 1]; + out_ptr[out_idx + 2] = in_ptr[in_idx + 2]; + out_ptr[out_idx + 3] = alpha_value; // use value from options + } + } + } + + cc->Outputs() + .Tag(kOutputFrameTag) + .Add(output_frame.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { + if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { + return ::mediapipe::OkStatus(); + } +#if defined(__ANDROID__) + // Setup source texture. + const auto& input_frame = + cc->Inputs().Tag(kInputFrameTagGpu).Get(); + if (!(input_frame.format() == mediapipe::GpuBufferFormat::kBGRA32 || + input_frame.format() == mediapipe::GpuBufferFormat::kRGB24)) { + LOG(ERROR) << "Only RGB or RGBA input image supported"; + } + auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); + + // Setup destination texture. + const int width = input_frame.width(), height = input_frame.height(); + auto output_texture = gpu_helper_.CreateDestinationTexture( + width, height, mediapipe::GpuBufferFormat::kBGRA32); + + const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTagGpu) && + !cc->Inputs().Tag(kInputAlphaTagGpu).IsEmpty(); + + // Setup alpha texture and Update image in GPU shader. 
+ if (has_alpha_mask) { + const auto& alpha_mask = + cc->Inputs().Tag(kInputAlphaTagGpu).Get(); + auto alpha_texture = gpu_helper_.CreateSourceTexture(alpha_mask); + gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, input_texture.name()); + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, alpha_texture.name()); + GlRender(cc); // use channel 0 of mask + alpha_texture.Release(); + } else { + gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, input_texture.name()); + GlRender(cc); // use value from options + } + + // Send out image as GPU packet. + auto output_frame = output_texture.GetFrame(); + cc->Outputs() + .Tag(kOutputFrameTagGpu) + .Add(output_frame.release(), cc->InputTimestamp()); + + // Cleanup + input_texture.Release(); + output_texture.Release(); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SetAlphaCalculator::GlRender(CalculatorContext* cc) { +#if defined(__ANDROID__) + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // program + glUseProgram(program_); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); 
+ glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); + + // execute command queue + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { +#if defined(__ANDROID__) + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + // Shader to overlay a texture onto another when overlay is non-zero. + const GLchar* frag_src = GLES_VERSION_COMPAT + R"( + #if __VERSION__ < 130 + #define in varying + #endif // __VERSION__ < 130 + + #ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; + #else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; + #endif // defined(GL_ES) + + in vec2 sample_coordinate; + uniform sampler2D input_frame; + uniform sampler2D alpha_mask; + uniform float alpha_value; + + void main() { + vec3 image_pix = texture2D(input_frame, sample_coordinate).rgb; + float alpha = alpha_value; + if (alpha_value < 0.0) alpha = texture2D(alpha_mask, sample_coordinate).r; + vec4 out_pix = vec4(image_pix, alpha); + fragColor = out_pix; + } + )"; + + // Create shader program and set parameters. 
+ mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src, + NUM_ATTRIBUTES, (const GLchar**)&attr_name[0], + attr_location, &program_); + RET_CHECK(program_) << "Problem initializing the program."; + glUseProgram(program_); + glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); + glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2); + glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_); + +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/set_alpha_calculator.proto b/mediapipe/calculators/image/set_alpha_calculator.proto new file mode 100644 index 000000000..0e2bc9732 --- /dev/null +++ b/mediapipe/calculators/image/set_alpha_calculator.proto @@ -0,0 +1,16 @@ +// Options for SetAlphaCalculator +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message SetAlphaCalculatorOptions { + extend CalculatorOptions { + optional SetAlphaCalculatorOptions ext = 250949799; + } + + // The value to set the alpha channel to (0-255). + // This option is ignored when set to -1 (use image mask instead). + optional sint32 alpha_value = 1 [default = -1]; +} diff --git a/mediapipe/calculators/image/sobel_edges_calculator.cc b/mediapipe/calculators/image/sobel_edges_calculator.cc new file mode 100644 index 000000000..e710a99f5 --- /dev/null +++ b/mediapipe/calculators/image/sobel_edges_calculator.cc @@ -0,0 +1,239 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/gpu/gl_simple_calculator.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" + +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; + +namespace mediapipe { + +// Applies the Sobel filter to an image. Expects a grayscale image stored as +// RGB, like LuminanceCalculator outputs. +// See GlSimpleCalculatorBase for inputs, outputs and input side packets. +class SobelEdgesCalculator : public GlSimpleCalculator { + public: + ::mediapipe::Status GlSetup() override; + ::mediapipe::Status GlRender(const GlTexture& src, + const GlTexture& dst) override; + ::mediapipe::Status GlTeardown() override; + + private: + GLuint program_ = 0; + GLint frame_; + GLint pixel_w_; + GLint pixel_h_; +}; +REGISTER_CALCULATOR(SobelEdgesCalculator); + +::mediapipe::Status SobelEdgesCalculator::GlSetup() { + // Load vertex and fragment shaders + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "vertexPosition", + "vertexTextureCoordinate", + }; + + const GLchar* vert_src = GLES_VERSION_COMPAT + R"( +#if __VERSION__ < 130 + #define in attribute + #define out varying +#endif // __VERSION__ < 130 + + in vec4 vertexPosition; + in vec4 vertexTextureCoordinate; + + // width of a pixel in normalized texture coordinates (0..1) + uniform highp float pixelW; + + // height of a pixel in normalized texture coordinates (0..1) + uniform highp float pixelH; + + // Dependent texture reads (i.e. texture reads where texture coordinates + // are computed in the fragment shader) are slow on pre-ES 3.0 hardware. + // Avoid them by computing all texture coordinates in the vertex shader. 
+ + // iOS OGLES performance guide: https://developer.apple.com/library/ios/documentation/3DDrawing/Conceptual/OpenGLES_ProgrammingGuide/BestPracticesforShaders/BestPracticesforShaders.html + + // Code for coordinates: u = up, d = down, l = left, r = right, c = center. + // Horizontal coordinate first, then vertical. + out vec2 luTexCoord; + out vec2 lcTexCoord; + out vec2 ldTexCoord; + + out vec2 cuTexCoord; +// out vec2 ccTexCoord; + out vec2 cdTexCoord; + + out vec2 ruTexCoord; + out vec2 rcTexCoord; + out vec2 rdTexCoord; + + void main() { + gl_Position = vertexPosition; + + vec2 right = vec2(pixelW, 0.0); + vec2 up = vec2(0.0, pixelH); + + lcTexCoord = vertexTextureCoordinate.xy - right; + luTexCoord = lcTexCoord + up; + ldTexCoord = lcTexCoord - up; + + vec2 ccTexCoord = vertexTextureCoordinate.xy; + cuTexCoord = ccTexCoord + up; + cdTexCoord = ccTexCoord - up; + + rcTexCoord = vertexTextureCoordinate.xy + right; + ruTexCoord = rcTexCoord + up; + rdTexCoord = rcTexCoord - up; + } + )"; + const GLchar* frag_src = GLES_VERSION_COMPAT + R"( +#if __VERSION__ < 130 + #define in varying +#endif // __VERSION__ < 130 + +#ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; +#else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; +#endif // defined(GL_ES) + + in vec2 luTexCoord; + in vec2 lcTexCoord; + in vec2 ldTexCoord; + + in vec2 cuTexCoord; +// in vec2 ccTexCoord; + in vec2 cdTexCoord; + + in vec2 ruTexCoord; + in vec2 rcTexCoord; + in vec2 rdTexCoord; + + uniform sampler2D inputImage; + + void main() { + float luPx = texture2D(inputImage, luTexCoord).r; + float lcPx = texture2D(inputImage, lcTexCoord).r; + float ldPx = texture2D(inputImage, ldTexCoord).r; + + float cuPx = texture2D(inputImage, cuTexCoord).r; +// float ccPx = texture2D(inputImage, ccTexCoord).r; + float cdPx = texture2D(inputImage, cdTexCoord).r; + + float ruPx = texture2D(inputImage, ruTexCoord).r; + float rcPx = 
texture2D(inputImage, rcTexCoord).r; + float rdPx = texture2D(inputImage, rdTexCoord).r; + + float h = -luPx - 2.0 * lcPx - ldPx + ruPx + 2.0 * rcPx + rdPx; + float v = -luPx - 2.0 * cuPx - ruPx + ldPx + 2.0 * cdPx + rdPx; + + float mag = length(vec2(h, v)); + + fragColor = vec4(vec3(mag), 1.0); + } + )"; + + // shader program + GlhCreateProgram(vert_src, frag_src, NUM_ATTRIBUTES, + (const GLchar**)&attr_name[0], attr_location, &program_); + RET_CHECK(program_) << "Problem initializing the program."; + frame_ = glGetUniformLocation(program_, "inputImage"); + pixel_w_ = glGetUniformLocation(program_, "pixelW"); + pixel_h_ = glGetUniformLocation(program_, "pixelH"); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SobelEdgesCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const float texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // program + glUseProgram(program_); + glUniform1i(frame_, 1); + + // parameters + glUniform1i(frame_, 1); + glUniform1f(pixel_w_, 1.0 / src.width()); + glUniform1f(pixel_h_, 1.0 / src.height()); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 
0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SobelEdgesCalculator::GlTeardown() { + if (program_) { + glDeleteProgram(program_); + program_ = 0; + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/testdata/BUILD b/mediapipe/calculators/image/testdata/BUILD new file mode 100644 index 000000000..1d14543cd --- /dev/null +++ b/mediapipe/calculators/image/testdata/BUILD @@ -0,0 +1,26 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "test_images", + srcs = [ + "dino.jpg", + "dino_quality_50.jpg", + "dino_quality_80.jpg", + ], + visibility = ["//visibility:public"], +) diff --git a/mediapipe/calculators/image/testdata/dino.jpg b/mediapipe/calculators/image/testdata/dino.jpg new file mode 100644 index 000000000..df2dfc693 Binary files /dev/null and b/mediapipe/calculators/image/testdata/dino.jpg differ diff --git a/mediapipe/calculators/image/testdata/dino_quality_50.jpg b/mediapipe/calculators/image/testdata/dino_quality_50.jpg new file mode 100644 index 000000000..de8af0dd4 Binary files /dev/null and b/mediapipe/calculators/image/testdata/dino_quality_50.jpg differ diff --git a/mediapipe/calculators/image/testdata/dino_quality_80.jpg b/mediapipe/calculators/image/testdata/dino_quality_80.jpg new file mode 100644 index 000000000..eb7aa7b25 Binary files /dev/null and b/mediapipe/calculators/image/testdata/dino_quality_80.jpg differ diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD new file mode 100644 index 000000000..20f77fbd2 --- /dev/null +++ b/mediapipe/calculators/internal/BUILD @@ -0,0 +1,47 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +package(default_visibility = ["//visibility:private"]) + +proto_library( + name = "callback_packet_calculator_proto", + srcs = ["callback_packet_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "callback_packet_calculator_cc_proto", + srcs = ["callback_packet_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":callback_packet_calculator_proto"], +) + +cc_library( + name = "callback_packet_calculator", + srcs = ["callback_packet_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":callback_packet_calculator_cc_proto", + "//mediapipe/framework:calculator_base", + "//mediapipe/framework:calculator_registry", + "//mediapipe/framework:output_side_packet", + ], + alwayslink = 1, +) diff --git a/mediapipe/calculators/internal/callback_packet_calculator.cc b/mediapipe/calculators/internal/callback_packet_calculator.cc new file mode 100644 index 000000000..e9f85ee83 --- /dev/null +++ b/mediapipe/calculators/internal/callback_packet_calculator.cc @@ -0,0 +1,103 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "mediapipe/calculators/internal/callback_packet_calculator.pb.h" // NOLINT +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_registry.h" +#include "mediapipe/framework/output_side_packet.h" + +namespace mediapipe { + +namespace { + +// Callback function for writing a packet to a vector. The output is before the +// input since std::bind fills arguments from left to right (and only +// dumped_data is filled by std::bind). +void DumpToVector(std::vector* dumped_data, const Packet& packet) { + dumped_data->push_back(packet); +} + +// Callback function for saving the Timestamp::PostStream() packet. +// The output is before the input since std::bind fills arguments from left to +// right (and only post_stream_packet is filled by std::bind). +void DumpPostStreamPacket(Packet* post_stream_packet, const Packet& packet) { + if (packet.Timestamp() == Timestamp::PostStream()) { + *post_stream_packet = packet; + } +} +} // namespace + +// Creates a callback which takes a packet and stores it either in a +// vector of packets or stores only the packet at PostStream timestamp. +// The kind of callback is controlled by an option. The callback is +// a std::function and is directly usable by CallbackCalculator. +// Since the options for the packet generator include a serialized pointer +// value, the resulting callback is only valid on the original machine +// while that pointer is still alive. 
+class CallbackPacketCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + const auto& options = cc->Options(); + switch (options.type()) { + case CallbackPacketCalculatorOptions::VECTOR_PACKET: + case CallbackPacketCalculatorOptions::POST_STREAM_PACKET: + cc->OutputSidePackets() + .Index(0) + .Set>(); + break; + default: + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Invalid type of callback to produce."; + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + const auto& options = cc->Options(); + void* ptr; + if (sscanf(options.pointer().c_str(), "%p", &ptr) != 1) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Stored pointer value in options is invalid."; + } + switch (options.type()) { + case CallbackPacketCalculatorOptions::VECTOR_PACKET: + cc->OutputSidePackets().Index(0).Set( + MakePacket>(std::bind( + &DumpToVector, reinterpret_cast*>(ptr), + std::placeholders::_1))); + break; + case CallbackPacketCalculatorOptions::POST_STREAM_PACKET: + cc->OutputSidePackets().Index(0).Set( + MakePacket>( + std::bind(&DumpPostStreamPacket, reinterpret_cast(ptr), + std::placeholders::_1))); + break; + default: + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Invalid type to dump into."; + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } +}; + +REGISTER_CALCULATOR(CallbackPacketCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/internal/callback_packet_calculator.proto b/mediapipe/calculators/internal/callback_packet_calculator.proto new file mode 100644 index 000000000..6a5cfb05a --- /dev/null +++ b/mediapipe/calculators/internal/callback_packet_calculator.proto @@ -0,0 +1,40 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message CallbackPacketCalculatorOptions { + extend CalculatorOptions { + optional CallbackPacketCalculatorOptions ext = 245965803; + } + + enum PointerType { + UNKNOWN = 0; + VECTOR_PACKET = 1; + POST_STREAM_PACKET = 2; + } + + // The type of the data pointer that the callback will put data into. + optional PointerType type = 1; + // The location of the data stored as a string printed with + // snprintf(address, sizeof(address), "%p", pointer). + // This calculator only produces a reasonable callback if it is + // constructed on the same machine as the original pointer was created on and + // that pointer is still alive. + optional bytes pointer = 2; +} diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD new file mode 100644 index 000000000..66493aba6 --- /dev/null +++ b/mediapipe/calculators/tensorflow/BUILD @@ -0,0 +1,975 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +proto_library( + name = "graph_tensors_packet_generator_proto", + srcs = ["graph_tensors_packet_generator.proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/framework:packet_generator_proto", + ], +) + +proto_library( + name = "matrix_to_tensor_calculator_options_proto", + srcs = ["matrix_to_tensor_calculator_options.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "lapped_tensor_buffer_calculator_proto", + srcs = ["lapped_tensor_buffer_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "object_detection_tensors_to_detections_calculator_proto", + srcs = ["object_detection_tensors_to_detections_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tensorflow_inference_calculator_proto", + srcs = ["tensorflow_inference_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tensorflow_session_from_saved_model_generator_proto", + srcs = ["tensorflow_session_from_saved_model_generator.proto"], + visibility = ["//visibility:public"], + deps = 
["//mediapipe/framework:packet_generator_proto"], +) + +proto_library( + name = "tensorflow_session_from_saved_model_calculator_proto", + srcs = ["tensorflow_session_from_saved_model_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "tensor_squeeze_dimensions_calculator_proto", + srcs = ["tensor_squeeze_dimensions_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tensor_to_image_frame_calculator_proto", + srcs = ["tensor_to_image_frame_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tensor_to_matrix_calculator_proto", + srcs = ["tensor_to_matrix_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/framework/formats:time_series_header_proto", + ], +) + +proto_library( + name = "tensor_to_vector_float_calculator_options_proto", + srcs = ["tensor_to_vector_float_calculator_options.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "vector_float_to_tensor_calculator_options_proto", + srcs = ["vector_float_to_tensor_calculator_options.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "graph_tensors_packet_generator_cc_proto", + srcs = ["graph_tensors_packet_generator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":graph_tensors_packet_generator_proto"], +) + +mediapipe_cc_proto_library( + name = "image_frame_to_tensor_calculator_cc_proto", + srcs = ["image_frame_to_tensor_calculator.proto"], + 
cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":image_frame_to_tensor_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "matrix_to_tensor_calculator_options_cc_proto", + srcs = ["matrix_to_tensor_calculator_options.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":matrix_to_tensor_calculator_options_proto"], +) + +mediapipe_cc_proto_library( + name = "lapped_tensor_buffer_calculator_cc_proto", + srcs = ["lapped_tensor_buffer_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":lapped_tensor_buffer_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "object_detection_tensors_to_detections_calculator_cc_proto", + srcs = ["object_detection_tensors_to_detections_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":object_detection_tensors_to_detections_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "pack_media_sequence_calculator_cc_proto", + srcs = ["pack_media_sequence_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":pack_media_sequence_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tensorflow_inference_calculator_cc_proto", + srcs = ["tensorflow_inference_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tensorflow_inference_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tensorflow_session_from_frozen_graph_generator_cc_proto", + srcs = 
["tensorflow_session_from_frozen_graph_generator.proto"], + cc_deps = [ + "//mediapipe/framework:packet_generator_cc_proto", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tensorflow_session_from_frozen_graph_generator_proto"], +) + +mediapipe_cc_proto_library( + name = "tensorflow_session_from_saved_model_generator_cc_proto", + srcs = ["tensorflow_session_from_saved_model_generator.proto"], + cc_deps = ["//mediapipe/framework:packet_generator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tensorflow_session_from_saved_model_generator_proto"], +) + +mediapipe_cc_proto_library( + name = "tensorflow_session_from_saved_model_calculator_cc_proto", + srcs = ["tensorflow_session_from_saved_model_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tensorflow_session_from_saved_model_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tensor_squeeze_dimensions_calculator_cc_proto", + srcs = ["tensor_squeeze_dimensions_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tensor_squeeze_dimensions_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tensor_to_image_frame_calculator_cc_proto", + srcs = ["tensor_to_image_frame_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tensor_to_image_frame_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tensor_to_matrix_calculator_cc_proto", + srcs = ["tensor_to_matrix_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tensor_to_matrix_calculator_proto"], +) + +mediapipe_cc_proto_library( + 
name = "tensor_to_vector_float_calculator_options_cc_proto", + srcs = ["tensor_to_vector_float_calculator_options.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tensor_to_vector_float_calculator_options_proto"], +) + +mediapipe_cc_proto_library( + name = "unpack_media_sequence_calculator_cc_proto", + srcs = ["unpack_media_sequence_calculator.proto"], + cc_deps = [ + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":unpack_media_sequence_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "vector_float_to_tensor_calculator_options_cc_proto", + srcs = ["vector_float_to_tensor_calculator_options.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":vector_float_to_tensor_calculator_options_proto"], +) + +cc_library( + name = "graph_tensors_packet_generator", + srcs = ["graph_tensors_packet_generator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/calculators/tensorflow:graph_tensors_packet_generator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "@org_tensorflow//tensorflow/core:framework", + ], + alwayslink = 1, +) + +cc_library( + name = "image_frame_to_tensor_calculator", + srcs = ["image_frame_to_tensor_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/calculators/tensorflow:image_frame_to_tensor_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ] + select({ + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:framework", 
+ ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:android_lib_lite", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "matrix_to_tensor_calculator", + srcs = ["matrix_to_tensor_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/calculators/tensorflow:matrix_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ] + select({ + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:framework", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:android_lib_lite", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "lapped_tensor_buffer_calculator", + srcs = ["lapped_tensor_buffer_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/profiler:circular_buffer", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "object_detection_tensors_to_detections_calculator", + srcs = ["object_detection_tensors_to_detections_calculator.cc"], + features = [ + # Layering check doesn't play nicely with portable proto wrappers. 
+ "no_layering_check", + ], + visibility = [ + "//visibility:public", + ], + deps = [ + ":object_detection_tensors_to_detections_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/util:tensor_to_detection", + "@org_tensorflow//tensorflow/core:framework", + ], + alwayslink = 1, +) + +cc_library( + name = "pack_media_sequence_calculator", + srcs = ["pack_media_sequence_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", + "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/util/sequence:media_sequence", + "//mediapipe/util/sequence:media_sequence_util", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +cc_library( + name = "string_to_sequence_example_calculator", + srcs = ["string_to_sequence_example_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +# On android, this calculator is configured to run with lite protos. Therefore, +# compile your binary with the flag TENSORFLOW_PROTOS=lite. 
+cc_library( + name = "tensorflow_inference_calculator", + srcs = ["tensorflow_inference_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tensorflow_session", + "//mediapipe/calculators/tensorflow:tensorflow_inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ] + select({ + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:framework", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "tensorflow_session", + hdrs = [ + "tensorflow_session.h", + ], + features = ["no_layering_check"], + visibility = ["//visibility:public"], + deps = select({ + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:core", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos", + ], + }), +) + +cc_library( + name = "tensorflow_session_from_frozen_graph_generator", + srcs = ["tensorflow_session_from_frozen_graph_generator.cc"], + features = ["no_layering_check"], + visibility = ["//visibility:public"], + deps = [ + ":tensorflow_session", + "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ] + select({ + "//conditions:default": [ + "//mediapipe/framework/port:file_helpers", + "@org_tensorflow//tensorflow/core:core", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos", + "//mediapipe/android/file/base", + ], + }), + 
alwayslink = 1, +) + +cc_library( + name = "tensorflow_session_from_saved_model_calculator", + srcs = ["tensorflow_session_from_saved_model_calculator.cc"], + defines = select({ + "//mediapipe:android": ["__ANDROID__"], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":tensorflow_session", + ":tensorflow_session_from_saved_model_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/cc/saved_model:constants", + "@org_tensorflow//tensorflow/cc/saved_model:loader_lite", + "@org_tensorflow//tensorflow/cc/saved_model:tag_constants", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ] + select({ + "//conditions:default": [ + "//mediapipe/framework/port:file_helpers", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "tensorflow_session_from_saved_model_generator", + srcs = ["tensorflow_session_from_saved_model_generator.cc"], + defines = select({ + "//mediapipe:android": ["__ANDROID__"], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":tensorflow_session", + "//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_generator_cc_proto", + "//mediapipe/framework:packet_generator", + "//mediapipe/framework:packet_type", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/cc/saved_model:constants", + "@org_tensorflow//tensorflow/cc/saved_model:loader_lite", + "@org_tensorflow//tensorflow/cc/saved_model:tag_constants", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ] + select({ + "//conditions:default": [ + "//mediapipe/framework/port:file_helpers", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "tensor_squeeze_dimensions_calculator", + srcs = 
["tensor_squeeze_dimensions_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/calculators/tensorflow:tensor_squeeze_dimensions_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@org_tensorflow//tensorflow/core:framework", + ], + alwayslink = 1, +) + +cc_library( + name = "tensor_to_image_frame_calculator", + srcs = ["tensor_to_image_frame_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@org_tensorflow//tensorflow/core:framework", + ], + alwayslink = 1, +) + +cc_library( + name = "tensor_to_matrix_calculator", + srcs = ["tensor_to_matrix_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/calculators/tensorflow:tensor_to_matrix_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ] + select({ + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:framework", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:android_lib_lite", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "tensor_to_vector_float_calculator", + srcs = ["tensor_to_vector_float_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + "//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator_options_cc_proto", + ] + select({ + "//conditions:default": [ + 
"@org_tensorflow//tensorflow/core:framework", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:android_lib_lite", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "unpack_media_sequence_calculator", + srcs = ["unpack_media_sequence_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/util/sequence:media_sequence", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +cc_library( + name = "vector_float_to_tensor_calculator", + srcs = ["vector_float_to_tensor_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@org_tensorflow//tensorflow/core:framework", + ], + alwayslink = 1, +) + +cc_test( + name = "graph_tensors_packet_generator_test", + srcs = ["graph_tensors_packet_generator_test.cc"], + deps = [ + ":graph_tensors_packet_generator", + "//mediapipe/calculators/tensorflow:graph_tensors_packet_generator_cc_proto", + "//mediapipe/framework:packet", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:packet_set", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/tool:validate_type", + "@org_tensorflow//tensorflow/core:framework", + ], +) + +cc_test( + name = "image_frame_to_tensor_calculator_test", + size = "small", + srcs = ["image_frame_to_tensor_calculator_test.cc"], + deps = [ + ":image_frame_to_tensor_calculator", + 
"//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "@org_tensorflow//tensorflow/core:framework", + ], +) + +cc_test( + name = "matrix_to_tensor_calculator_test", + size = "small", + srcs = ["matrix_to_tensor_calculator_test.cc"], + deps = [ + ":matrix_to_tensor_calculator", + "//mediapipe/calculators/tensorflow:matrix_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "@org_tensorflow//tensorflow/core:framework", + ], +) + +cc_test( + name = "lapped_tensor_buffer_calculator_test", + size = "small", + srcs = ["lapped_tensor_buffer_calculator_test.cc"], + deps = [ + ":lapped_tensor_buffer_calculator", + "//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/memory", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "object_detection_tensors_to_detections_calculator_test", + srcs = ["object_detection_tensors_to_detections_calculator_test.cc"], + linkstatic = 1, + deps = [ + ":object_detection_tensors_to_detections_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@org_tensorflow//tensorflow/core:framework", + 
"@org_tensorflow//tensorflow/core:testlib", + ], +) + +cc_test( + name = "pack_media_sequence_calculator_test", + srcs = ["pack_media_sequence_calculator_test.cc"], + deps = [ + ":pack_media_sequence_calculator", + "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", + "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:status", + "//mediapipe/util/sequence:media_sequence", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "tensorflow_session_from_frozen_graph_generator_test", + srcs = ["tensorflow_session_from_frozen_graph_generator_test.cc"], + data = [":test_frozen_graph"], + linkstatic = 1, + deps = [ + ":tensorflow_inference_calculator", + ":tensorflow_session", + ":tensorflow_session_from_frozen_graph_generator", + "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:tag_map_helper", + "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:direct_session", + "@org_tensorflow//tensorflow/core:framework", + 
"@org_tensorflow//tensorflow/core:protos_all_cc", + "@org_tensorflow//tensorflow/core:testlib", + "@org_tensorflow//tensorflow/core/kernels:conv_ops", + "@org_tensorflow//tensorflow/core/kernels:math", + ], +) + +cc_test( + name = "tensorflow_session_from_saved_model_generator_test", + srcs = ["tensorflow_session_from_saved_model_generator_test.cc"], + data = [":test_saved_model"], + linkstatic = 1, + deps = [ + ":tensorflow_inference_calculator", + ":tensorflow_session", + ":tensorflow_session_from_saved_model_generator", + "//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_generator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:tag_map_helper", + "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:direct_session", + "@org_tensorflow//tensorflow/core/kernels:array", + "@org_tensorflow//tensorflow/core/kernels:bitcast_op", + "@org_tensorflow//tensorflow/core/kernels:conv_ops", + "@org_tensorflow//tensorflow/core/kernels:io", + "@org_tensorflow//tensorflow/core/kernels:state", + "@org_tensorflow//tensorflow/core/kernels:string", + ], +) + +cc_test( + name = "tensorflow_session_from_saved_model_calculator_test", + srcs = ["tensorflow_session_from_saved_model_calculator_test.cc"], + data = [":test_saved_model"], + linkstatic = 1, + deps = [ + ":tensorflow_inference_calculator", + ":tensorflow_session", + ":tensorflow_session_from_saved_model_calculator", + ":tensorflow_session_from_saved_model_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + 
"//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:tag_map_helper", + "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:direct_session", + "@org_tensorflow//tensorflow/core/kernels:array", + "@org_tensorflow//tensorflow/core/kernels:bitcast_op", + "@org_tensorflow//tensorflow/core/kernels:conv_ops", + "@org_tensorflow//tensorflow/core/kernels:io", + "@org_tensorflow//tensorflow/core/kernels:state", + "@org_tensorflow//tensorflow/core/kernels:string", + ], +) + +cc_test( + name = "tensor_squeeze_dimensions_calculator_test", + srcs = ["tensor_squeeze_dimensions_calculator_test.cc"], + deps = [ + ":tensor_squeeze_dimensions_calculator", + "//mediapipe/calculators/tensorflow:tensor_squeeze_dimensions_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "tensor_to_image_frame_calculator_test", + size = "small", + srcs = ["tensor_to_image_frame_calculator_test.cc"], + deps = [ + ":tensor_to_image_frame_calculator", + "//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:gtest_main", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "tensor_to_matrix_calculator_test", + size = "small", + srcs = ["tensor_to_matrix_calculator_test.cc"], + deps = [ + ":tensor_to_matrix_calculator", + "//mediapipe/calculators/tensorflow:tensor_to_matrix_calculator_cc_proto", + 
"//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:gtest_main", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "tensor_to_vector_float_calculator_test", + srcs = ["tensor_to_vector_float_calculator_test.cc"], + deps = [ + ":tensor_to_vector_float_calculator", + "//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "unpack_media_sequence_calculator_test", + srcs = ["unpack_media_sequence_calculator_test.cc"], + deps = [ + ":unpack_media_sequence_calculator", + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:rectangle", + "//mediapipe/util/sequence:media_sequence", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "vector_float_to_tensor_calculator_test", + srcs = ["vector_float_to_tensor_calculator_test.cc"], + deps = [ + ":vector_float_to_tensor_calculator", + "//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + 
"@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + +test_suite( + name = "ios", + tags = ["ios"], +) + +test_suite( + name = "android", + tags = ["android"], +) + +cc_test( + name = "tensorflow_inference_calculator_test", + size = "medium", + srcs = ["tensorflow_inference_calculator_test.cc"], + data = [":test_frozen_graph"], + linkstatic = 1, + deps = [ + ":tensorflow_session", + ":tensorflow_inference_calculator", + ":tensorflow_session_from_frozen_graph_generator", + "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/tool:sink", + "//mediapipe/framework/tool:validate_type", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + ] + select({ + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:testlib", + "@org_tensorflow//tensorflow/core/kernels:math", + "@org_tensorflow//tensorflow/core/kernels:conv_ops", + "@org_tensorflow//tensorflow/core:direct_session", + ], + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/core:android_tensorflow_lib_with_ops_lite_proto_no_rtti_lib", + ], + }), +) + +filegroup( + name = "test_frozen_graph", + srcs = [ + "testdata/bundle/00000000/checkpoint", + "testdata/bundle/00000000/export.meta", + "testdata/bundle/00000000/export-00000-of-00001", + "testdata/frozen_graph_def.pb", + "testdata/model.chkpt-0", + "testdata/model.chkpt-0.meta", + "testdata/tf_graph_def.pb", + ], +) + +filegroup( + name = "test_saved_model", + srcs = [ + "testdata/tensorflow_saved_model/00000000/saved_model.pb", + "testdata/tensorflow_saved_model/00000000/variables/variables.data-00000-of-00001", + "testdata/tensorflow_saved_model/00000000/variables/variables.index", + ], +) diff --git 
a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc new file mode 100644 index 000000000..54126cf1d --- /dev/null +++ b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc @@ -0,0 +1,73 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Generates row tensors of prescribed length that are initialized to zeros. +// The tensors are placed in an ordered map, which maps the tensors to the +// tensor tags, and emitted as a packet. This generator has been developed +// primarily to generate initialization states for LSTMs. 
+ +#include +#include + +#include "mediapipe/calculators/tensorflow/graph_tensors_packet_generator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/status_util.h" +#include "tensorflow/core/framework/tensor.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; + +class GraphTensorsPacketGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, + PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { + RET_CHECK(extendable_options.HasExtension( + GraphTensorsPacketGeneratorOptions::ext)); + const auto& options = extendable_options.GetExtension( // NOLINT + GraphTensorsPacketGeneratorOptions::ext); + output_side_packets->Index(0) + .Set>>( + /* "A map of tensor tags and tensors" */); + RET_CHECK_EQ(options.tensor_tag_size(), options.tensor_num_nodes_size()); + RET_CHECK_GT(options.tensor_tag_size(), 0); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& packet_generator_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + const GraphTensorsPacketGeneratorOptions& options = + packet_generator_options.GetExtension( + GraphTensorsPacketGeneratorOptions::ext); + // Output bundle packet. 
+ auto tensor_map = absl::make_unique>(); + + for (int i = 0; i < options.tensor_tag_size(); ++i) { + const std::string& tensor_tag = options.tensor_tag(i); + const int32 tensor_num_nodes = options.tensor_num_nodes(i); + (*tensor_map)[tensor_tag] = + tf::Tensor(tf::DT_FLOAT, tf::TensorShape{1, tensor_num_nodes}); + (*tensor_map)[tensor_tag].flat().setZero(); + } + output_side_packets->Index(0) = AdoptAsUniquePtr(tensor_map.release()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_PACKET_GENERATOR(GraphTensorsPacketGenerator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.proto b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.proto new file mode 100644 index 000000000..1196ca1ef --- /dev/null +++ b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.proto @@ -0,0 +1,50 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/packet_generator.proto"; + +message GraphTensorsPacketGeneratorOptions { + extend mediapipe.PacketGeneratorOptions { + optional GraphTensorsPacketGeneratorOptions ext = 142721046; + } + + // Names of tensor tags for each of the generated tensors. + // Examples are: "STATE_C_0" or "STATE_M_0". + repeated string tensor_tag = 1; + + // Must be same length as tensor_tag. Each tensor tag must be paired with the + // number of nodes. 
+ // Tags must be capitalized, matching regex [A-Z0-9_]+. Examples: "STATE_C_0" + // and "STATE_M_0". Then, those tags can be used as the MediaPipe tags of + // tensors to initialized in TensorflowInferenceCalculator consuming + // the packet produced by this generator. For example, a mediapipe graph + // can include the node: + // packet_generator { + // packet_generator: "GraphTensorsPacketGenerator" + // output_side_packet: "init_tensors" + // options { + // [mediapipe.StateTensorsPacketGeneratorOptions.ext]: { + // tensor_tag: "STATE_C_0" + // tensor_num_nodes:128 + // tensor_tag: "STATE_M_0" + // tensor_num_nodes:128 + // } + // } + // } + repeated int32 tensor_num_nodes = 2; +} diff --git a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc new file mode 100644 index 000000000..d826ce9e3 --- /dev/null +++ b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc @@ -0,0 +1,82 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/tensorflow/graph_tensors_packet_generator.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/core/framework/tensor.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +// Helper function that creates a row tensor that is initialized to zeros. +tf::Tensor ZeroRowTensor(const int col_length) { + tf::Tensor tensor(tf::DT_FLOAT, tf::TensorShape{1, col_length}); + tensor.flat().setZero(); + + return tensor; +} + +class GraphTensorsPacketGeneratorTest : public ::testing::Test { + protected: + void SetUp() override { + extendable_options_.Clear(); + generator_options_ = extendable_options_.MutableExtension( + GraphTensorsPacketGeneratorOptions::ext); + generator_options_->add_tensor_tag("A"); + generator_options_->add_tensor_num_nodes(3); + generator_options_->add_tensor_tag("B"); + generator_options_->add_tensor_num_nodes(4); + } + + void VerifyTensorMap(PacketSet* output_side_packets) { + const std::map* tensor_map = + GetFromUniquePtr>( + output_side_packets->Index(0)); + + EXPECT_FALSE(tensor_map->find("A") == tensor_map->end()); + EXPECT_FALSE(tensor_map->find("B") == tensor_map->end()); + + tf::Tensor expected_tensor = ZeroRowTensor(3); + EXPECT_EQ(expected_tensor.DebugString(), + tensor_map->find("A")->second.DebugString()); + + expected_tensor = ZeroRowTensor(4); + EXPECT_EQ(expected_tensor.DebugString(), + tensor_map->find("B")->second.DebugString()); + } + PacketGeneratorOptions extendable_options_; + GraphTensorsPacketGeneratorOptions* generator_options_; +}; + +// Test that the tensors are of the right size and shape +TEST_F(GraphTensorsPacketGeneratorTest, VerifyTensorSizeShapeAndValue) { + PacketSet inputs({}); + PacketSet outputs(1); 
+ + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "GraphTensorsPacketGenerator", extendable_options_, inputs, &outputs); + MEDIAPIPE_EXPECT_OK(run_status) << run_status.message(); + VerifyTensorMap(&outputs); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc new file mode 100644 index 000000000..fd109a3bd --- /dev/null +++ b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc @@ -0,0 +1,180 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace tf = tensorflow; + +namespace { +// Convert the ImageFrame into Tensor with floating point value type. +// The value will be normalized based on mean and stddev. 
+std::unique_ptr ImageFrameToNormalizedTensor( + const ImageFrame& image_frame, float mean, float stddev) { + const int cols = image_frame.Width(); + const int rows = image_frame.Height(); + const int channels = image_frame.NumberOfChannels(); + const uint8* pixel = image_frame.PixelData(); + const int width_padding = image_frame.WidthStep() - cols * channels; + auto tensor = ::absl::make_unique( + tf::DT_FLOAT, tf::TensorShape({rows, cols, channels})); + auto tensor_data = tensor->tensor(); + + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + for (int channel = 0; channel < channels; ++channel) { + tensor_data(row, col, channel) = (pixel[channel] - mean) / stddev; + } + pixel += channels; + } + pixel += width_padding; + } + return tensor; +} + +} // namespace + +// Converts ImageFrames to TensorFlow Tensors. +// +// The calculator expects one input (a packet containing an ImageFrame) and +// generates one output (a packet containing a tf::Tensor holding the same +// pixel data). The output tensor will be 3D with dimensions corresponding to +// height, width, and the number of channels (e.g. 3 for RGB or 1 for GRAY8). +// +// This calculator supports ImageFrame objects with any valid format (SRGB +// SRGBA, GRAY8, GRAY16, and VEC32F1). It will generate a Tensor using DT_UINT8 +// for the first three types, DT_UINT16 for GRAY16, and DT_FLOAT for VEC32F1. +// +// The ImageFrame data can be packed or padded. The pixel data will be copied +// to the Tensor in row-major order. 
+// +// Example config: +// node { +// calculator: "ImageFrameToTensorCalculator" +// input_stream: "scaled_frames" +// output_stream: "video_tensors" +// } +class ImageFrameToTensorCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + ImageFrameToTensorCalculatorOptions options_; +}; +REGISTER_CALCULATOR(ImageFrameToTensorCalculator); + +::mediapipe::Status ImageFrameToTensorCalculator::GetContract( + CalculatorContract* cc) { + // Start with only one input packet. + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + cc->Inputs().Index(0).Set( + // ImageFrame frame. + ); + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + cc->Outputs().Index(0).Set( + // Output TensorFlow Tensor. + ); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImageFrameToTensorCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + // Inform the framework that we always output at the same timestamp + // as we receive a packet at. + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImageFrameToTensorCalculator::Process( + CalculatorContext* cc) { + const Packet& input_item = cc->Inputs().Index(0).Value(); + RET_CHECK(!input_item.IsEmpty()) << "Input cannot be empty."; + + // Extract the ImageFrame and metadata from the input packet. 
+ const ImageFrame& video_frame = input_item.Get(); + const int bytes_per_pixel = video_frame.ByteDepth(); + + std::unique_ptr tensor; + if (options_.has_data_type()) { + RET_CHECK_EQ(bytes_per_pixel, 1) << "Unsupported image format (" + << bytes_per_pixel << " bytes per pixel)"; + const tf::DataType data_type = options_.data_type(); + RET_CHECK_EQ(data_type, tf::DT_FLOAT) + << "Unsupported data type " << data_type; + RET_CHECK_GT(options_.stddev(), 0.0f); + tensor = ImageFrameToNormalizedTensor(video_frame, options_.mean(), + options_.stddev()); + } else { + const int height = video_frame.Height(); + const int width = video_frame.Width(); + const int num_channels = video_frame.NumberOfChannels(); + const int num_components = width * height * num_channels; + tf::TensorShape tensor_shape({height, width, num_channels}); + + // Use uint8 uint16, or float as the TF type depending on bpp of ImageFrame. + tf::DataType data_type; + if (bytes_per_pixel == 1) { + data_type = tf::DT_UINT8; + } else if (bytes_per_pixel == 2) { + data_type = tf::DT_UINT16; + } else if (bytes_per_pixel == 4) { + data_type = tf::DT_FLOAT; + } else { + return ::mediapipe::InvalidArgumentError(absl::StrCat( + "Unsupported image format (", bytes_per_pixel, " bytes per pixel)")); + } + + // This failure should never trigger, but it protects the code against + // internal TF changes. + RET_CHECK(tf::DataTypeCanUseMemcpy(data_type)) + << "Tensor data type does not support memcpy (type=" << data_type + << ")"; + + // Create the output tensor. + tensor = ::absl::make_unique(data_type, tensor_shape); + + // Copy pixel data from the ImageFrame to the tensor. 
+ if (data_type == tf::DT_UINT8) { + uint8* dst = tensor->flat().data(); + video_frame.CopyToBuffer(dst, num_components); + } else if (data_type == tf::DT_UINT16) { + uint16* dst = tensor->flat().data(); + video_frame.CopyToBuffer(dst, num_components); + } else { + float* dst = tensor->flat().data(); + video_frame.CopyToBuffer(dst, num_components); + } + } + + cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.proto b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.proto new file mode 100644 index 000000000..0e5a47716 --- /dev/null +++ b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.proto @@ -0,0 +1,37 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "tensorflow/core/framework/types.proto"; + +message ImageFrameToTensorCalculatorOptions { + extend CalculatorOptions { + optional ImageFrameToTensorCalculatorOptions ext = 120603667; + } + + // If set, the output tensor will be of data type specified by this field. + // Otherwise, the output tensor data type is equal to that of the input image + // frame. 
+ optional tensorflow.DataType data_type = 1; + + // If set, the output tensor T is equal to (F - mean * J) / stddev, where F + // and J are the input image frame and the all-ones matrix of the same size, + // respectively. Otherwise, T is equal to F. + optional float mean = 2; + optional float stddev = 3; +} diff --git a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator_test.cc new file mode 100644 index 000000000..925d40d25 --- /dev/null +++ b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator_test.cc @@ -0,0 +1,457 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace tf = tensorflow; +using RandomEngine = std::mt19937_64; + +const uint8 kGray8 = 42; +const uint16 kGray16 = 4242; +const float kFloat = 42.0; +const uint kRed = 255; +const uint kGreen = 36; +const uint kBlue = 156; +const uint kAlpha = 42; + +const int kFixedNoiseWidth = 3; +const int kFixedNoiseHeight = 2; +const uint8 kFixedNoiseData[kFixedNoiseWidth * kFixedNoiseHeight * 3] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 123, 213, 156, 9, 10, 11, 255, 0, 128}; + +class ImageFrameToTensorCalculatorTest : public ::testing::Test { + protected: + // Set image_frame to a constant per-channel pix_value. + template + void SetToColor(const T* pix_value, ImageFrame* image_frame) { + const int cols = image_frame->Width(); + const int rows = image_frame->Height(); + const int channels = image_frame->NumberOfChannels(); + const int width_padding = + image_frame->WidthStep() / (sizeof(T)) - cols * channels; + T* pixel = reinterpret_cast(image_frame->MutablePixelData()); + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + for (int channel = 0; channel < channels; ++channel) { + pixel[channel] = pix_value[channel]; + } + pixel += channels; + } + pixel += width_padding; + } + } + + // Adds a packet with a solid red 8-bit RGB ImageFrame. 
+ void AddRGBFrame(int width, int height) { + auto image_frame = + ::absl::make_unique(ImageFormat::SRGB, width, height); + const uint8 color[] = {kRed, kGreen, kBlue}; + SetToColor(color, image_frame.get()); + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + } + + // Adds a packet with a solid red 8-bit RGBA ImageFrame. + void AddRGBAFrame(int width, int height) { + auto image_frame = + ::absl::make_unique(ImageFormat::SRGBA, width, height); + const uint8 color[] = {kRed, kGreen, kBlue, kAlpha}; + SetToColor(color, image_frame.get()); + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + } + + // Adds a packet with a solid GRAY8 ImageFrame. + void AddGray8Frame(int width, int height) { + auto image_frame = + ::absl::make_unique(ImageFormat::GRAY8, width, height); + const uint8 gray[] = {kGray8}; + SetToColor(gray, image_frame.get()); + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + } + + // Adds a packet with a solid GRAY16 ImageFrame. + void AddGray16Frame(int width, int height) { + auto image_frame = + ::absl::make_unique(ImageFormat::GRAY16, width, height, 1); + const uint16 gray[] = {kGray16}; + SetToColor(gray, image_frame.get()); + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + } + + // Adds a packet with a solid VEC32F1 ImageFrame. + void AddFloatFrame(int width, int height) { + auto image_frame = + ::absl::make_unique(ImageFormat::VEC32F1, width, height, 1); + const float gray[] = {kFloat}; + SetToColor(gray, image_frame.get()); + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + } + + // Adds a packet with an 8-bit RGB ImageFrame containing pre-determined noise. 
+ void AddFixedNoiseRGBFrame() { + auto image_frame = ::absl::make_unique( + ImageFormat::SRGB, kFixedNoiseWidth, kFixedNoiseHeight); + + // Copy fixed noise data into the ImageFrame. + const uint8* src = kFixedNoiseData; + uint8* pixels = image_frame->MutablePixelData(); + for (int y = 0; y < kFixedNoiseHeight; ++y) { + uint8* row = pixels + y * image_frame->WidthStep(); + std::memcpy(row, src, kFixedNoiseWidth * 3); + src += kFixedNoiseWidth * 3; + } + + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + } + + // Adds a packet with an 8-bit RGB ImageFrame containing random noise. + void AddRandomRGBFrame(int width, int height, uint32 seed) { + RandomEngine random(seed); + std::uniform_int_distribution uniform_dist{ + 0, std::numeric_limits::max()}; + auto image_frame = + ::absl::make_unique(ImageFormat::SRGB, width, height); + + // Copy "noisy data" into the ImageFrame. + const int num_components_per_row = width * image_frame->NumberOfChannels(); + uint8* pixels = image_frame->MutablePixelData(); + for (int y = 0; y < kFixedNoiseHeight; ++y) { + uint8* p = pixels + y * image_frame->WidthStep(); + for (int i = 0; i < num_components_per_row; ++i) { + p[i] = uniform_dist(random); + } + } + + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + } + + std::unique_ptr runner_; +}; + +TEST_F(ImageFrameToTensorCalculatorTest, SolidRedRGBFrame) { + // Check two widths to cover packed and padded ImageFrame. + const int num_widths = 2; + const int widths[num_widths] = {10, 24}; + const int height = 5; + for (int width_index = 0; width_index < num_widths; ++width_index) { + const int width = widths[width_index]; + const int num_pixels = width * height; + + // Run the calculator and verify that one output is generated. 
+ runner_ = ::absl::make_unique( + "ImageFrameToTensorCalculator", "", 1, 1, 0); + AddRGBFrame(width, height); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that tensor is 3-dimensional + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(3, tensor.dims()); + ASSERT_EQ(tf::DT_UINT8, tensor.dtype()); + + // Verify that each dimension has the correct size / number of channels. + const tf::TensorShape& shape = tensor.shape(); + ASSERT_EQ(height, shape.dim_size(0)); + ASSERT_EQ(width, shape.dim_size(1)); + ASSERT_EQ(3, shape.dim_size(2)); + + // Verify that the data in the tensor is correct. + const uint8* pixels = + reinterpret_cast(tensor.tensor_data().data()); + for (int i = 0; i < num_pixels; ++i) { + ASSERT_EQ(kRed, pixels[0]); + ASSERT_EQ(kGreen, pixels[1]); + ASSERT_EQ(kBlue, pixels[2]); + pixels += 3; + } + } +} + +TEST_F(ImageFrameToTensorCalculatorTest, SolidRedRGBAFrame) { + // Check two widths to cover packed and padded ImageFrame. + const int num_widths = 2; + const int widths[num_widths] = {10, 24}; + const int height = 5; + for (int width_index = 0; width_index < num_widths; ++width_index) { + const int width = widths[width_index]; + const int num_pixels = width * height; + + // Run the calculator and verify that one output is generated. + runner_.reset( + new CalculatorRunner("ImageFrameToTensorCalculator", "", 1, 1, 0)); + AddRGBAFrame(width, height); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that tensor is 3-dimensional + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(3, tensor.dims()); + ASSERT_EQ(tf::DT_UINT8, tensor.dtype()); + + // Verify that each dimension has the correct size / number of channels. 
+ const tf::TensorShape& shape = tensor.shape(); + ASSERT_EQ(height, shape.dim_size(0)); + ASSERT_EQ(width, shape.dim_size(1)); + ASSERT_EQ(4, shape.dim_size(2)); + + // Verify that the data in the tensor is correct. + const uint8* pixels = + reinterpret_cast(tensor.tensor_data().data()); + for (int i = 0; i < num_pixels; ++i) { + ASSERT_EQ(kRed, pixels[0]); + ASSERT_EQ(kGreen, pixels[1]); + ASSERT_EQ(kBlue, pixels[2]); + ASSERT_EQ(kAlpha, pixels[3]); + pixels += 4; + } + } +} + +TEST_F(ImageFrameToTensorCalculatorTest, SolidGray8Frame) { + // Check two widths to cover packed and padded ImageFrame. + const int num_widths = 2; + const int widths[num_widths] = {10, 24}; + const int height = 5; + for (int width_index = 0; width_index < num_widths; ++width_index) { + const int width = widths[width_index]; + const int num_pixels = width * height; + + // Run the calculator and verify that one output is generated. + runner_.reset( + new CalculatorRunner("ImageFrameToTensorCalculator", "", 1, 1, 0)); + AddGray8Frame(width, height); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that tensor is 3-dimensional + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(3, tensor.dims()); + ASSERT_EQ(tf::DT_UINT8, tensor.dtype()); + + // Verify that each dimension has the correct size / number of channels. + const tf::TensorShape& shape = tensor.shape(); + ASSERT_EQ(height, shape.dim_size(0)); + ASSERT_EQ(width, shape.dim_size(1)); + ASSERT_EQ(1, shape.dim_size(2)); + + // Verify that the data in the tensor is correct. + const uint8* pixels = + reinterpret_cast(tensor.tensor_data().data()); + for (int i = 0; i < num_pixels; ++i) { + ASSERT_EQ(kGray8, pixels[0]); + ++pixels; + } + } +} + +TEST_F(ImageFrameToTensorCalculatorTest, SolidGray16Frame) { + // Check two widths to cover packed and padded ImageFrame. 
+ const int num_widths = 2; + const int widths[num_widths] = {10, 24}; + const int height = 5; + for (int width_index = 0; width_index < num_widths; ++width_index) { + const int width = widths[width_index]; + const int num_pixels = width * height; + + // Run the calculator and verify that one output is generated. + runner_.reset( + new CalculatorRunner("ImageFrameToTensorCalculator", "", 1, 1, 0)); + AddGray16Frame(width, height); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that tensor is 3-dimensional + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(3, tensor.dims()); + ASSERT_EQ(tf::DT_UINT16, tensor.dtype()); + + // Verify that each dimension has the correct size / number of channels. + const tf::TensorShape& shape = tensor.shape(); + ASSERT_EQ(height, shape.dim_size(0)); + ASSERT_EQ(width, shape.dim_size(1)); + ASSERT_EQ(1, shape.dim_size(2)); + + // Verify that the data in the tensor is correct. + const uint16* pixels = + reinterpret_cast(tensor.tensor_data().data()); + for (int i = 0; i < num_pixels; ++i) { + ASSERT_EQ(kGray16, pixels[0]); + ++pixels; + } + } +} + +TEST_F(ImageFrameToTensorCalculatorTest, SolidFloatFrame) { + // Check two widths to cover packed and padded ImageFrame. + const int num_widths = 2; + const int widths[num_widths] = {10, 24}; + const int height = 5; + for (int width_index = 0; width_index < num_widths; ++width_index) { + const int width = widths[width_index]; + const int num_pixels = width * height; + + // Run the calculator and verify that one output is generated. 
+ runner_.reset( + new CalculatorRunner("ImageFrameToTensorCalculator", "", 1, 1, 0)); + AddFloatFrame(width, height); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that tensor is 3-dimensional + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(3, tensor.dims()); + ASSERT_EQ(tf::DT_FLOAT, tensor.dtype()); + + // Verify that each dimension has the correct size / number of channels. + const tf::TensorShape& shape = tensor.shape(); + ASSERT_EQ(height, shape.dim_size(0)); + ASSERT_EQ(width, shape.dim_size(1)); + ASSERT_EQ(1, shape.dim_size(2)); + + // Verify that the data in the tensor is correct. + const float* pixels = + reinterpret_cast(tensor.tensor_data().data()); + for (int i = 0; i < num_pixels; ++i) { + ASSERT_EQ(kFloat, pixels[0]); + ++pixels; + } + } +} + +TEST_F(ImageFrameToTensorCalculatorTest, FixedNoiseRGBFrame) { + // Run the calculator and verify that one output is generated. + runner_.reset( + new CalculatorRunner("ImageFrameToTensorCalculator", "", 1, 1, 0)); + AddFixedNoiseRGBFrame(); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that tensor is 3-dimensional + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(3, tensor.dims()); + ASSERT_EQ(tf::DT_UINT8, tensor.dtype()); + + // Verify that each dimension has the correct size / number of channels. + const tf::TensorShape& shape = tensor.shape(); + ASSERT_EQ(kFixedNoiseHeight, shape.dim_size(0)); + ASSERT_EQ(kFixedNoiseWidth, shape.dim_size(1)); + ASSERT_EQ(3, shape.dim_size(2)); + + // Verify that the data in the tensor is correct. 
+ const int num_pixels = kFixedNoiseWidth * kFixedNoiseHeight; + const uint8* pixels = + reinterpret_cast(tensor.tensor_data().data()); + for (int i = 0; i < num_pixels; ++i) { + ASSERT_EQ(kFixedNoiseData[i], pixels[i]); + } +} + +TEST_F(ImageFrameToTensorCalculatorTest, RandomRGBFrame) { + // Run the calculator and verify that one output is generated. + const uint32 seed = 1234; + const int height = 2; + for (int width = 1; width <= 33; ++width) { + runner_.reset( + new CalculatorRunner("ImageFrameToTensorCalculator", "", 1, 1, 0)); + AddRandomRGBFrame(width, height, seed); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that tensor is 3-dimensional + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(3, tensor.dims()); + ASSERT_EQ(tf::DT_UINT8, tensor.dtype()); + + // Verify that each dimension has the correct size / number of channels. + const tf::TensorShape& shape = tensor.shape(); + ASSERT_EQ(height, shape.dim_size(0)); + ASSERT_EQ(width, shape.dim_size(1)); + ASSERT_EQ(3, shape.dim_size(2)); + + // Verify that the data in the tensor is correct. + RandomEngine random(seed); + std::uniform_int_distribution uniform_dist{ + 0, std::numeric_limits::max()}; + const int num_pixels = width * height; + const uint8* pixels = + reinterpret_cast(tensor.tensor_data().data()); + for (int i = 0; i < num_pixels; ++i) { + const uint8 expected = uniform_dist(random); + ASSERT_EQ(expected, pixels[i]); + } + } +} + +TEST_F(ImageFrameToTensorCalculatorTest, FixedRGBFrameWithMeanAndStddev) { + runner_ = ::absl::make_unique( + "ImageFrameToTensorCalculator", + "[mediapipe.ImageFrameToTensorCalculatorOptions.ext]" + "{data_type:DT_FLOAT mean:128.0 stddev:128.0}", + 1, 1, 0); + + // Create a single pixel image of fixed color #0080ff. 
+ auto image_frame = ::absl::make_unique(ImageFormat::SRGB, 1, 1); + const uint8 color[] = {0, 128, 255}; + SetToColor(color, image_frame.get()); + + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const auto& tensor = runner_->Outputs().Index(0).packets[0].Get(); + EXPECT_EQ(tensor.dtype(), tf::DT_FLOAT); + ASSERT_EQ(tensor.dims(), 3); + EXPECT_EQ(tensor.shape().dim_size(0), 1); + EXPECT_EQ(tensor.shape().dim_size(1), 1); + EXPECT_EQ(tensor.shape().dim_size(2), 3); + const float* actual = tensor.flat().data(); + EXPECT_EQ(actual[0], -1.0f); // ( 0 - 128) / 128 + EXPECT_EQ(actual[1], 0.0f); // (128 - 128) / 128 + EXPECT_EQ(actual[2], 127.0f / 128.0f); // (255 - 128) / 128 +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc new file mode 100644 index 000000000..78ee50871 --- /dev/null +++ b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc @@ -0,0 +1,154 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "absl/memory/memory.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/profiler/circular_buffer.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace mediapipe { + +namespace tf = tensorflow; + +// Given an input stream of tensors, concatenates the tensors over timesteps. +// The concatenated output tensors can be specified to have overlap between +// output timesteps. The tensors are concatenated along the first dimension, and +// a flag controls whether a new first dimension is inserted before +// concatenation. +// +// Currently, the number of tensors output will be buffer_size less than the +// number of input tensors because no padding is implemented and only full +// buffers are output. +// +// The timestamp of the output batch will match the timestamp of the first +// tensor in that batch by default. (e.g. when buffer_size frames are added, the +// output tensor will have the timestamp of the first input.). This behavior can +// be adjusted by the timestamp_offset option. 
+// +// Example config: +// node { +// calculator: "LappedTensorBufferCalculator" +// input_stream: "input_tensor" +// output_stream: "output_tensor" +// options { +// [mediapipe.LappedTensorBufferCalculatorOptions.ext] { +// buffer_size: 2 +// overlap: 1 +// add_batch_dim_to_tensors: false +// } +// } +// } +class LappedTensorBufferCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + // Adds a batch dimension to the input tensor if specified in the calculator + // options. + ::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor); + + int steps_until_output_; + std::unique_ptr> timestamp_buffer_; + std::unique_ptr> buffer_; + LappedTensorBufferCalculatorOptions options_; +}; +REGISTER_CALCULATOR(LappedTensorBufferCalculator); + +::mediapipe::Status LappedTensorBufferCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + cc->Inputs().Index(0).Set( + // tensorflow::Tensor stream. + ); + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one output stream is supported."; + cc->Outputs().Index(0).Set( + // Output tensorflow::Tensor stream with possibly overlapping steps. 
+ ); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + RET_CHECK_LT(options_.overlap(), options_.buffer_size()); + RET_CHECK_GE(options_.timestamp_offset(), 0) + << "Negative timestamp_offset is not allowed."; + RET_CHECK_LT(options_.timestamp_offset(), options_.buffer_size()) + << "output_frame_num_offset has to be less than buffer_size."; + timestamp_buffer_ = + absl::make_unique>(options_.buffer_size()); + buffer_ = + absl::make_unique>(options_.buffer_size()); + steps_until_output_ = options_.buffer_size(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status LappedTensorBufferCalculator::Process( + CalculatorContext* cc) { + // These are cheap, shallow copies. + tensorflow::Tensor input_tensor( + cc->Inputs().Index(0).Get()); + if (options_.add_batch_dim_to_tensors()) { + RET_CHECK_OK(AddBatchDimension(&input_tensor)); + } + buffer_->push_back(input_tensor); + timestamp_buffer_->push_back(cc->InputTimestamp()); + --steps_until_output_; + + if (steps_until_output_ <= 0) { + auto concatenated = ::absl::make_unique(); + + const tf::Status concat_status = tf::tensor::Concat( + std::vector(buffer_->begin(), buffer_->end()), + concatenated.get()); + RET_CHECK(concat_status.ok()) << concat_status.ToString(); + + cc->Outputs().Index(0).Add( + concatenated.release(), + timestamp_buffer_->Get(options_.timestamp_offset())); + + steps_until_output_ = options_.buffer_size() - options_.overlap(); + } + return ::mediapipe::OkStatus(); +} + +// Adds a batch dimension to the input tensor if specified in the calculator +// options. 
+::mediapipe::Status LappedTensorBufferCalculator::AddBatchDimension( + tf::Tensor* input_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(input_tensor->shape()); + new_shape.InsertDim(0, 1); + RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape)) + << "Could not add 0th dimension to tensor without changing its shape." + << " Current shape: " << input_tensor->shape().DebugString(); + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.proto b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.proto new file mode 100644 index 000000000..543c65368 --- /dev/null +++ b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.proto @@ -0,0 +1,48 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message LappedTensorBufferCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional LappedTensorBufferCalculatorOptions ext = 175222228; + } + + // The output tensor will have this many tensors concatenated together along + // their first dimension. + optional int32 buffer_size = 1; + + // The overlap determines how many input tensors are shared between frames. 
  // Because the input tensors may have a non-singleton first dimension, this
  // is not necessarily the number of overlapping entries in the first
  // dimension.
  optional int32 overlap = 2;

  // If true, inserts a singleton first dimension before concatenating the
  // tensors together.
  optional bool add_batch_dim_to_tensors = 3 [default = false];

  // Timestamp offset for output batch. The valid range is [0, buffer_size).
  // The timestamp of the output tensor will match the timestamp of the input
  // corresponding to the offset. For example, setting the offset to 0 outputs
  // at the timestamp matching the first input tensor, and setting it to
  // int((N-1) / 2) outputs at the timestamp matching the middle input tensor.
  // This is useful for aligning the timestamp to be centered on the input
  // range.
  optional int32 timestamp_offset = 4 [default = 0];
}
diff --git a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc
new file mode 100644
index 000000000..71cc6d1da
--- /dev/null
+++ b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc
@@ -0,0 +1,240 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +class LappedTensorBufferCalculatorTest : public ::testing::Test { + protected: + void SetUpCalculator(int buffer_size, int overlap, bool add_dim, + int timestamp_offset) { + CalculatorGraphConfig::Node config; + config.set_calculator("LappedTensorBufferCalculator"); + config.add_input_stream("input_tensor"); + config.add_output_stream("output_tensor"); + auto options = config.mutable_options()->MutableExtension( + LappedTensorBufferCalculatorOptions::ext); + options->set_buffer_size(buffer_size); + options->set_overlap(overlap); + if (add_dim) { + options->set_add_batch_dim_to_tensors(true); + } + options->set_timestamp_offset(timestamp_offset); + runner_ = ::absl::make_unique(config); + } + std::unique_ptr runner_; +}; + +TEST_F(LappedTensorBufferCalculatorTest, OneToOne) { + SetUpCalculator(1, 0, false, 0); + int num_timesteps = 3; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(num_timesteps, output_packets.size()); + for (int i = 0; i < num_timesteps; ++i) { + float value = output_packets[i].Get().tensor()(0); + ASSERT_NEAR(i, value, 0.0001); + } +} + +TEST_F(LappedTensorBufferCalculatorTest, 
OneToTwo) { + int buffer_size = 2; + int overlap = 1; + bool add_dim = false; + SetUpCalculator(buffer_size, overlap, add_dim, 0); + int num_timesteps = 3; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size()); + for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) { + for (int j = 0; j < buffer_size; ++j) { + float value = output_packets[i].Get().tensor()(j); + ASSERT_NEAR(i + j, value, 0.0001); + } + } +} + +TEST_F(LappedTensorBufferCalculatorTest, OneToThree) { + int buffer_size = 3; + int overlap = 2; + bool add_dim = false; + SetUpCalculator(buffer_size, overlap, add_dim, 0); + int num_timesteps = 3; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size()); + for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) { + for (int j = 0; j < buffer_size; ++j) { + float value = output_packets[i].Get().tensor()(j); + ASSERT_NEAR(i + j, value, 0.0001); + } + } +} + +TEST_F(LappedTensorBufferCalculatorTest, OneToThreeSkip) { + int buffer_size = 3; + int overlap = 1; + bool add_dim = false; + SetUpCalculator(buffer_size, overlap, add_dim, 0); + int num_timesteps = 3; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, 
tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size()); + for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) { + for (int j = 0; j < buffer_size; ++j) { + float value = output_packets[i].Get().tensor()(j); + ASSERT_NEAR((i * 2) + j, value, 0.0001); + } + } +} + +TEST_F(LappedTensorBufferCalculatorTest, OneToThreeBatch) { + int buffer_size = 3; + int overlap = 2; + bool add_dim = true; + SetUpCalculator(buffer_size, overlap, add_dim, 0); + int num_timesteps = 3; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size()); + for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) { + for (int j = 0; j < buffer_size; ++j) { + float value = + output_packets[i].Get().tensor()(j, 0); + ASSERT_NEAR(i + j, value, 0.0001); + } + } +} + +TEST_F(LappedTensorBufferCalculatorTest, NegativeTimestampOffsetFails) { + int buffer_size = 16; + int overlap = 15; + bool add_dim = true; + int timestamp_offset = -7; + SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset); + int num_timesteps = 20; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_FALSE(runner_->Run().ok()); +} + 
+TEST_F(LappedTensorBufferCalculatorTest, OutOfRangeTimestampOffsetFails) { + int buffer_size = 16; + int overlap = 15; + bool add_dim = true; + int timestamp_offset = buffer_size; + SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset); + int num_timesteps = 20; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_FALSE(runner_->Run().ok()); +} + +TEST_F(LappedTensorBufferCalculatorTest, OneToThreeBatchTimestampOffset) { + int buffer_size = 16; + int overlap = 15; + bool add_dim = true; + int timestamp_offset = 7; + SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset); + int num_timesteps = 20; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size()); + for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) { + for (int j = 0; j < buffer_size; ++j) { + int64 value = output_packets[i].Timestamp().Value(); + ASSERT_EQ(i + timestamp_offset, value); + } + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc new file mode 100644 index 000000000..31243e133 --- /dev/null +++ b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc @@ -0,0 +1,157 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace { +::mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header) { + CHECK(header); + if (header_packet.IsEmpty()) { + return ::mediapipe::UnknownError("No header found."); + } + if (!header_packet.ValidateAsType().ok()) { + return ::mediapipe::UnknownError( + "Packet does not contain TimeSeriesHeader."); + } + *header = header_packet.Get(); + if (header->has_sample_rate() && header->sample_rate() >= 0 && + header->has_num_channels() && header->num_channels() >= 0) { + return ::mediapipe::OkStatus(); + } else { + std::string error_message = + "TimeSeriesHeader is missing necessary fields: " + "sample_rate or num_channels, or one of their values is negative. 
"; +#ifndef MEDIAPIPE_MOBILE + absl::StrAppend(&error_message, "Got header:\n", + header->ShortDebugString()); +#endif + return ::mediapipe::InvalidArgumentError(error_message); + } +} +} // namespace + +namespace tf = tensorflow; + +typedef Eigen::Matrix + RowMajorMatrixXf; +typedef Eigen::Matrix + ColMajorMatrixXf; + +// Converts an input Matrix into a 2D or 3D tf::Tensor. +// +// The calculator expects one input (a packet containing a Matrix) and +// generates one output (a packet containing a tf::Tensor containing the same +// data). The output tensor will be 2D with dimensions corresponding to the +// input matrix, while it will be 3D if add_trailing_dimension is set to true. +// The option for making the tensor be 3D is useful for using audio and image +// features for training multimodal models, so that the number of tensor +// dimensions match up. It will hold DT_FLOAT values. +// +// Example config: +// node { +// calculator: "MatrixToTensorCalculator" +// input_stream: "matrix_features" +// output_stream: "tensor_features" +// } +class MatrixToTensorCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + MatrixToTensorCalculatorOptions options_; +}; +REGISTER_CALCULATOR(MatrixToTensorCalculator); + +::mediapipe::Status MatrixToTensorCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + cc->Inputs().Index(0).Set( + // Input Matrix stream with optional TimeSeriesHeader. + ); + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + cc->Outputs().Index(0).Set( + // Output stream with data as tf::Tensor and the same + // TimeSeriesHeader as the input (or no header if the input has no + // header). 
+ ); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MatrixToTensorCalculator::Open(CalculatorContext* cc) { + // If the input is part of a time series, then preserve the header so that + // downstream consumers can access the sample rate if needed. + options_ = cc->Options(); + auto input_header = ::absl::make_unique(); + const ::mediapipe::Status header_status = FillTimeSeriesHeaderIfValid( + cc->Inputs().Index(0).Header(), input_header.get()); + if (header_status.ok()) { + cc->Outputs().Index(0).SetHeader(Adopt(input_header.release())); + } + + // Inform the framework that we always output at the same timestamp + // as we receive a packet at. + cc->SetOffset(mediapipe::TimestampDiff(0)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MatrixToTensorCalculator::Process(CalculatorContext* cc) { + const Matrix& matrix = cc->Inputs().Index(0).Get(); + tf::TensorShape tensor_shape; + if (options_.transpose()) { + tensor_shape = tf::TensorShape({matrix.cols(), matrix.rows()}); + } else { + tensor_shape = tf::TensorShape({matrix.rows(), matrix.cols()}); + } + auto tensor = ::absl::make_unique(tf::DT_FLOAT, tensor_shape); + + float* tensor_data = tensor->flat().data(); + if (options_.transpose()) { + auto matrix_map = + Eigen::Map(tensor_data, matrix.rows(), matrix.cols()); + matrix_map = matrix; + } else { + auto matrix_map = + Eigen::Map(tensor_data, matrix.rows(), matrix.cols()); + matrix_map = matrix; + } + + if (options_.add_trailing_dimension()) { + tf::TensorShape new_shape(tensor_shape); + new_shape.AddDim(1 /* size of dimension */); + RET_CHECK(tensor->CopyFrom(*tensor, new_shape)) + << "Could not add dimension to tensor without changing its shape." 
+ << " Current shape: " << tensor->shape().DebugString(); + } + cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.proto b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.proto new file mode 100644 index 000000000..50a1775cc --- /dev/null +++ b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.proto @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message MatrixToTensorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional MatrixToTensorCalculatorOptions ext = 130781699; + } + + optional bool transpose = 1 [default = false]; + // Adds a 3rd dimension of size 1 when this is set to true. + optional bool add_trailing_dimension = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_test.cc new file mode 100644 index 000000000..df9e7bc44 --- /dev/null +++ b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_test.cc @@ -0,0 +1,165 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace { + +constexpr char kTransposeOptionsString[] = + "[mediapipe.MatrixToTensorCalculatorOptions.ext]: {" + "transpose: True}"; +constexpr char kAddDimensionOptionsString[] = + "[mediapipe.MatrixToTensorCalculatorOptions.ext]: {" + "add_trailing_dimension: True}"; + +} // namespace + +namespace tf = tensorflow; +using RandomEngine = std::mt19937_64; +const uint32 kSeed = 1234; +const int kNumSizes = 8; +const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, + {5, 3}, {7, 13}, {16, 32}, {101, 2}}; + +class MatrixToTensorCalculatorTest : public ::testing::Test { + protected: + // Adds a packet with a matrix filled with random values in [0,1]. 
+ void AddRandomMatrix(int num_rows, int num_columns, uint32 seed) { + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + auto matrix = ::absl::make_unique(); + matrix->resize(num_rows, num_columns); + for (int y = 0; y < num_rows; ++y) { + for (int x = 0; x < num_columns; ++x) { + (*matrix)(y, x) = uniform_dist(random); + } + } + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(matrix.release()).At(Timestamp(0))); + } + + std::unique_ptr runner_; +}; + +TEST_F(MatrixToTensorCalculatorTest, RandomMatrix) { + for (int size_index = 0; size_index < kNumSizes; ++size_index) { + const int num_rows = sizes[size_index][0]; + const int num_columns = sizes[size_index][1]; + + // Run the calculator and verify that one output is generated. + runner_ = ::absl::make_unique("MatrixToTensorCalculator", + "", 1, 1, 0); + AddRandomMatrix(num_rows, num_columns, kSeed); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that the packet contains a 2D float tensor. + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(2, tensor.dims()); + ASSERT_EQ(tf::DT_FLOAT, tensor.dtype()); + + // Verify that the data is correct. + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + const auto matrix = tensor.matrix(); + for (int y = 0; y < num_rows; ++y) { + for (int x = 0; x < num_columns; ++x) { + const float expected = uniform_dist(random); + ASSERT_EQ(expected, matrix(y, x)); + } + } + } +} + +TEST_F(MatrixToTensorCalculatorTest, RandomMatrixTranspose) { + for (int size_index = 0; size_index < kNumSizes; ++size_index) { + const int num_rows = sizes[size_index][0]; + const int num_columns = sizes[size_index][1]; + + // Run the calculator and verify that one output is generated. 
+ runner_ = ::absl::make_unique( + "MatrixToTensorCalculator", kTransposeOptionsString, 1, 1, 0); + AddRandomMatrix(num_rows, num_columns, kSeed); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that the packet contains a 2D float tensor. + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(2, tensor.dims()); + ASSERT_EQ(tf::DT_FLOAT, tensor.dtype()); + + // Verify that the data is correct. + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + const auto matrix = tensor.matrix(); + for (int y = 0; y < num_rows; ++y) { + for (int x = 0; x < num_columns; ++x) { + const float expected = uniform_dist(random); + ASSERT_EQ(expected, matrix(x, y)); + } + } + } +} + +TEST_F(MatrixToTensorCalculatorTest, RandomMatrixAddDimension) { + for (int size_index = 0; size_index < kNumSizes; ++size_index) { + const int num_rows = sizes[size_index][0]; + const int num_columns = sizes[size_index][1]; + + // Run the calculator and verify that one output is generated. + runner_ = ::absl::make_unique( + "MatrixToTensorCalculator", kAddDimensionOptionsString, 1, 1, 0); + AddRandomMatrix(num_rows, num_columns, kSeed); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(1, output_packets.size()); + + // Verify that the packet contains a 3D float tensor. + const tf::Tensor& tensor = output_packets[0].Get(); + ASSERT_EQ(3, tensor.dims()); + ASSERT_EQ(tf::DT_FLOAT, tensor.dtype()); + + // Verify that the data is correct. 
+ RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + // const auto matrix = tensor.matrix(); + const float* tensor_data = tensor.flat().data(); + for (int y = 0; y < num_rows; ++y) { + for (int x = 0; x < num_columns; ++x) { + const float expected = uniform_dist(random); + ASSERT_EQ(expected, tensor_data[y * num_columns + x]); + } + } + } +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc new file mode 100644 index 000000000..fa4fd1035 --- /dev/null +++ b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc @@ -0,0 +1,239 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include "mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/source_location.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/util/tensor_to_detection.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace mediapipe { +class CalculatorOptions; +} // namespace mediapipe + +namespace mediapipe { + +namespace tf = ::tensorflow; + +namespace { +const char kNumDetections[] = "NUM_DETECTIONS"; +const char kBoxes[] = "BOXES"; +const char kScores[] = "SCORES"; +const char kClasses[] = "CLASSES"; +const char kDetections[] = "DETECTIONS"; +const char kKeypoints[] = "KEYPOINTS"; +const char kMasks[] = "MASKS"; +const char kLabelMap[] = "LABELMAP"; +const int kNumCoordsPerBox = 4; +} // namespace + +// Takes object detection results and converts them into MediaPipe Detections. +// +// Inputs are assumed to be tensors of the form: +// `num_detections` : float32 scalar tensor indicating the number of valid +// detections. +// `detection_boxes` : float32 tensor of the form [num_boxes, 4]. Format for +// coordinates is {ymin, xmin, ymax, xmax}. +// `detection_scores` : float32 tensor of the form [num_boxes]. +// `detection_classes` : float32 tensor of the form [num_boxes]. +// `detection_keypoints`: float32 tensor of the form +// [num_boxes, num_keypoints, 2]. +// `detection_masks` : float32 tensor of the form +// [num_boxes, height, width]. 
+// +// These are generated according to the Vale object detector model exporter, +// which may be found in +// image/understanding/object_detection/export_inference_graph.py +// +// By default, the output Detections store label ids (integers) for each +// detection. Optionally, a label map (of the form std::map +// mapping label ids to label names as strings) can be made available as an +// input side packet, in which case the output Detections store +// labels as their associated std::string provided by the label map. +// +// Usage example: +// node { +// calculator: "ObjectDetectionTensorsToDetectionsCalculator" +// input_stream: "BOXES:detection_boxes_tensor" +// input_stream: "SCORES:detection_scores_tensor" +// input_stream: "CLASSES:detection_classes_tensor" +// input_stream: "NUM_DETECTIONS:num_detections_tensor" +// output_stream: "DETECTIONS:detections" +// options: { +// [mediapipe.ObjectDetectionsTensorToDetectionsCalculatorOptions.ext]: { +// tensor_dim_to_squeeze: 0 +// } +// } +// } +class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { + public: + ObjectDetectionTensorsToDetectionsCalculator() = default; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag(kBoxes).Set(); + cc->Inputs().Tag(kScores).Set(); + + if (cc->Inputs().HasTag(kNumDetections)) { + cc->Inputs().Tag(kNumDetections).Set(); + } + if (cc->Inputs().HasTag(kClasses)) { + cc->Inputs().Tag(kClasses).Set(); + } + if (cc->Inputs().HasTag(kKeypoints)) { + cc->Inputs().Tag(kKeypoints).Set(); + } + + if (cc->Inputs().HasTag(kMasks)) { + cc->Inputs().Tag(kMasks).Set(); + + const auto& calculator_options = + cc->Options(); + float mask_threshold = calculator_options.mask_threshold(); + if (!(mask_threshold >= 0.0 && mask_threshold <= 1.0)) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "mask_threshold must be in range [0.0, 1.0]"; + } + } + + cc->Outputs().Tag(kDetections).Set>(); + + if 
(cc->InputSidePackets().HasTag(kLabelMap)) { + cc->InputSidePackets() + .Tag(kLabelMap) + .Set>>(); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + if (cc->InputSidePackets().HasTag(kLabelMap)) { + label_map_ = GetFromUniquePtr>( + cc->InputSidePackets().Tag(kLabelMap)); + } + const auto& tensor_dim_to_squeeze_field = + cc->Options() + .tensor_dim_to_squeeze(); + tensor_dims_to_squeeze_ = std::vector( + tensor_dim_to_squeeze_field.begin(), tensor_dim_to_squeeze_field.end()); + std::sort(tensor_dims_to_squeeze_.rbegin(), tensor_dims_to_squeeze_.rend()); + cc->SetOffset(0); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + const auto& options = + cc->Options(); + + tf::Tensor input_num_detections_tensor = + tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0})); + if (cc->Inputs().HasTag(kClasses)) { + ASSIGN_OR_RETURN( + input_num_detections_tensor, + MaybeSqueezeDims(kNumDetections, + cc->Inputs().Tag(kNumDetections).Get())); + } + if (input_num_detections_tensor.dtype() != tf::DT_INT32) { + RET_CHECK_EQ(input_num_detections_tensor.dtype(), tf::DT_FLOAT); + } + + ASSIGN_OR_RETURN( + auto input_boxes_tensor, + MaybeSqueezeDims(kBoxes, cc->Inputs().Tag(kBoxes).Get())); + RET_CHECK_EQ(input_boxes_tensor.dtype(), tf::DT_FLOAT); + + ASSIGN_OR_RETURN( + auto input_scores_tensor, + MaybeSqueezeDims(kScores, cc->Inputs().Tag(kScores).Get())); + RET_CHECK_EQ(input_scores_tensor.dtype(), tf::DT_FLOAT); + + tf::Tensor input_classes_tensor = + tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0})); + if (cc->Inputs().HasTag(kClasses)) { + ASSIGN_OR_RETURN( + input_classes_tensor, + MaybeSqueezeDims(kClasses, + cc->Inputs().Tag(kClasses).Get())); + } + RET_CHECK_EQ(input_classes_tensor.dtype(), tf::DT_FLOAT); + + auto output_detections = absl::make_unique>(); + + const tf::Tensor& input_keypoints_tensor = + cc->Inputs().HasTag(kKeypoints) + ? 
cc->Inputs().Tag(kKeypoints).Get() + : tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0, 0, 0})); + + const tf::Tensor& input_masks_tensor = + cc->Inputs().HasTag(kMasks) + ? cc->Inputs().Tag(kMasks).Get() + : tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0, 0, 0})); + RET_CHECK_EQ(input_masks_tensor.dtype(), tf::DT_FLOAT); + + const std::map label_map = + (label_map_ == nullptr) ? std::map{} : *label_map_; + + RET_CHECK_OK(TensorsToDetections( + input_num_detections_tensor, input_boxes_tensor, input_scores_tensor, + input_classes_tensor, input_keypoints_tensor, input_masks_tensor, + options.mask_threshold(), label_map, output_detections.get())); + + cc->Outputs() + .Tag(kDetections) + .Add(output_detections.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); + } + + private: + std::map* label_map_; + std::vector tensor_dims_to_squeeze_; + + ::mediapipe::StatusOr MaybeSqueezeDims( + const std::string& tensor_tag, const tf::Tensor& input_tensor) { + if (tensor_dims_to_squeeze_.empty()) { + return input_tensor; + } + tf::TensorShape tensor_shape = input_tensor.shape(); + for (const int dim : tensor_dims_to_squeeze_) { + RET_CHECK_GT(tensor_shape.dims(), dim) + << "Dimension " << dim + << " does not exist in input tensor with num dimensions " + << input_tensor.dims() << " dims"; + RET_CHECK_EQ(tensor_shape.dim_size(dim), 1) + << "Cannot remove dimension " << dim << " with size " + << tensor_shape.dim_size(dim); + tensor_shape.RemoveDim(dim); + } + tf::Tensor output_tensor; + RET_CHECK(output_tensor.CopyFrom(input_tensor, tensor_shape)); + return std::move(output_tensor); + } +}; + +REGISTER_CALCULATOR(ObjectDetectionTensorsToDetectionsCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.proto b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.proto new file mode 100644 index 000000000..149144684 --- /dev/null +++ 
b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.proto
@@ -0,0 +1,35 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The option proto for the ObjectDetectionsTensorToDetectionsCalculator.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message ObjectDetectionsTensorToDetectionsCalculatorOptions {
  extend .mediapipe.CalculatorOptions {
    optional ObjectDetectionsTensorToDetectionsCalculatorOptions ext =
        192676232;
  }

  // The threshold used to compute the binary segmentation mask.
  optional float mask_threshold = 1 [default = 0.0];

  // The specific singleton dimensions to squeeze (remove). The calculator can
  // only remove dimensions of size 1.
  repeated int32 tensor_dim_to_squeeze = 2;
}
diff --git a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator_test.cc b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator_test.cc
new file mode 100644
index 000000000..adce27040
--- /dev/null
+++ b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator_test.cc
@@ -0,0 +1,355 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; +namespace { +const char kNumDetections[] = "NUM_DETECTIONS"; +const char kBoxes[] = "BOXES"; +const char kScores[] = "SCORES"; +const char kClasses[] = "CLASSES"; +const char kKeypoints[] = "KEYPOINTS"; +const char kDetections[] = "DETECTIONS"; +const int kNumBoxes = 3; +const int kNumClasses = 4; +const int kNumCoordsPerBox = 4; +const int kNumKeypointsPerBox = 2; +const int kNumCoordsPerKeypoint = 2; + +class ObjectDetectionTensorsToDetectionsCalculatorTest + : public ::testing::Test { + protected: + void SetUp() override { SetUpInputs(); } + + void SetUpInputs() { + input_num_detections_ = tf::test::AsTensor({kNumBoxes}, {1}); + // {ymin, xmin, ymax, xmax} format. 
+ input_boxes_ = + tf::test::AsTensor({0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.2f, 0.3f, + 0.4f, 0.1f, 0.2f, 0.3f, 0.4f}, + {kNumBoxes, kNumCoordsPerBox}); + input_scores_ = tf::test::AsTensor({0.1f, 0.5f, 1.0f}, {kNumBoxes}); + input_scores_for_all_classes_ = + tf::test::AsTensor({0.0f, 0.1f, 0.05f, 0.02f, 0.0f, 0.1f, 0.5f, + 0.2f, 0.0f, 0.5f, 0.8f, 1.0f}, + {kNumBoxes, kNumClasses}); + input_classes_ = tf::test::AsTensor({1.0, 2.0, 3.0}, {kNumBoxes}); + input_keypoints_ = tf::test::AsTensor( + {0.6f, 0.5f, 0.6f, 0.5f, 0.6f, 0.5f, 0.6f, 0.5f, 0.6f, 0.5f, 0.6f, + 0.5f}, + {kNumBoxes, kNumKeypointsPerBox, kNumCoordsPerKeypoint}); + } + + void CreateNodeConfig(CalculatorGraphConfig::Node* node_config) const { + ASSERT_NE(nullptr, node_config); + *node_config = ParseTextProtoOrDie(R"( + calculator: "ObjectDetectionTensorsToDetectionsCalculator" + input_stream: "NUM_DETECTIONS:num_detections" + input_stream: "BOXES:boxes" + input_stream: "SCORES:scores" + input_stream: "CLASSES:classes" + output_stream: "DETECTIONS:detections" + )"); + } + + void CreateNodeConfigRawTensors( + CalculatorGraphConfig::Node* node_config) const { + ASSERT_NE(nullptr, node_config); + *node_config = ParseTextProtoOrDie(R"( + calculator: "ObjectDetectionTensorsToDetectionsCalculator" + input_stream: "BOXES:raw_detection_boxes" + input_stream: "SCORES:raw_detection_scores" + output_stream: "DETECTIONS:detections" + )"); + } + + void CreateNodeConfigWithKeypoints( + CalculatorGraphConfig::Node* node_config) const { + ASSERT_NE(nullptr, node_config); + *node_config = ParseTextProtoOrDie(R"( + calculator: "ObjectDetectionTensorsToDetectionsCalculator" + input_stream: "NUM_DETECTIONS:num_detections" + input_stream: "BOXES:boxes" + input_stream: "SCORES:scores" + input_stream: "CLASSES:classes" + input_stream: "KEYPOINTS:keypoints" + output_stream: "DETECTIONS:detections" + )"); + } + + void SetUpCalculatorRunner() { + CalculatorGraphConfig::Node node_config; + CreateNodeConfig(&node_config); + runner_ = 
absl::make_unique(node_config); + } + + void SetUpCalculatorRunnerRawTensors() { + CalculatorGraphConfig::Node node_config; + CreateNodeConfigRawTensors(&node_config); + runner_ = absl::make_unique(node_config); + } + + void SetUpCalculatorRunnerWithKeypoints() { + CalculatorGraphConfig::Node node_config; + CreateNodeConfigWithKeypoints(&node_config); + runner_ = absl::make_unique(node_config); + } + + void RunCalculator() { + SetUpCalculatorRunner(); + runner_->MutableInputs() + ->Tag(kNumDetections) + .packets.push_back( + PointToForeign(&input_num_detections_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kBoxes).packets.push_back( + PointToForeign(&input_boxes_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kScores).packets.push_back( + PointToForeign(&input_scores_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kClasses).packets.push_back( + PointToForeign(&input_classes_).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + ASSERT_EQ(1, runner_->Outputs().Tag(kDetections).packets.size()); + } + + void RunCalculatorRawTensors() { + SetUpCalculatorRunnerRawTensors(); + runner_->MutableInputs()->Tag(kBoxes).packets.push_back( + PointToForeign(&input_boxes_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kScores).packets.push_back( + PointToForeign(&input_scores_for_all_classes_) + .At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + ASSERT_EQ(1, runner_->Outputs().Tag(kDetections).packets.size()); + } + + void RunCalculatorWithKeypoints() { + SetUpCalculatorRunnerWithKeypoints(); + runner_->MutableInputs() + ->Tag(kNumDetections) + .packets.push_back( + PointToForeign(&input_num_detections_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kBoxes).packets.push_back( + PointToForeign(&input_boxes_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kScores).packets.push_back( + PointToForeign(&input_scores_).At(Timestamp::PostStream())); 
+ runner_->MutableInputs()->Tag(kClasses).packets.push_back( + PointToForeign(&input_classes_).At(Timestamp::PostStream())); + runner_->MutableInputs() + ->Tag(kKeypoints) + .packets.push_back( + PointToForeign(&input_keypoints_).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + ASSERT_EQ(1, runner_->Outputs().Tag(kDetections).packets.size()); + } + + void RunCalculatorWithTensorDimensionSqueezing() { + InsertExtraSingltonDim(&input_num_detections_); + InsertExtraSingltonDim(&input_boxes_); + InsertExtraSingltonDim(&input_scores_); + InsertExtraSingltonDim(&input_classes_); + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "ObjectDetectionTensorsToDetectionsCalculator" + input_stream: "NUM_DETECTIONS:num_detections" + input_stream: "BOXES:boxes" + input_stream: "SCORES:scores" + input_stream: "CLASSES:classes" + output_stream: "DETECTIONS:detections" + options: { + [mediapipe.ObjectDetectionsTensorToDetectionsCalculatorOptions + .ext]: { tensor_dim_to_squeeze: 0 } + } + )"); + runner_ = absl::make_unique(node_config); + runner_->MutableInputs() + ->Tag(kNumDetections) + .packets.push_back( + PointToForeign(&input_num_detections_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kBoxes).packets.push_back( + PointToForeign(&input_boxes_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kScores).packets.push_back( + PointToForeign(&input_scores_).At(Timestamp::PostStream())); + runner_->MutableInputs()->Tag(kClasses).packets.push_back( + PointToForeign(&input_classes_).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + ASSERT_EQ(1, runner_->Outputs().Tag(kDetections).packets.size()); + } + + void InsertExtraSingltonDim(tf::Tensor* tensor) { + tf::TensorShape new_shape(tensor->shape()); + new_shape.InsertDim(0, 1); + ASSERT_TRUE(tensor->CopyFrom(*tensor, new_shape)); + } + + std::unique_ptr runner_; + + tf::Tensor input_num_detections_; + tf::Tensor input_boxes_; + 
tf::Tensor input_scores_; + tf::Tensor input_scores_for_all_classes_; + tf::Tensor input_classes_; + tf::Tensor input_keypoints_; +}; + +TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest, OutputsDetections) { + RunCalculator(); + EXPECT_EQ(kNumBoxes, runner_->Outputs() + .Tag(kDetections) + .packets[0] + .Get>() + .size()); +} + +TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest, + OutputsDetectionsFromRawTensors) { + RunCalculatorRawTensors(); + EXPECT_EQ(kNumBoxes, runner_->Outputs() + .Tag(kDetections) + .packets[0] + .Get>() + .size()); +} + +TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest, + OutputsDetectionsWithKeypoints) { + RunCalculatorWithKeypoints(); + EXPECT_EQ(kNumBoxes, runner_->Outputs() + .Tag(kDetections) + .packets[0] + .Get>() + .size()); +} + +TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest, + OutputsDetectionsWithCorrectValues) { + RunCalculator(); + const std::vector detections = runner_->Outputs() + .Tag(kDetections) + .packets[0] + .Get>(); + EXPECT_EQ(kNumBoxes, detections.size()); + for (const auto& detection : detections) { + LocationData::RelativeBoundingBox relative_bbox = + detection.location_data().relative_bounding_box(); + EXPECT_FLOAT_EQ(0.2, relative_bbox.xmin()); + EXPECT_FLOAT_EQ(0.1, relative_bbox.ymin()); + EXPECT_FLOAT_EQ(0.2, relative_bbox.width()); + EXPECT_FLOAT_EQ(0.2, relative_bbox.height()); + } + EXPECT_FLOAT_EQ(0.1f, detections[0].score(0)); + EXPECT_FLOAT_EQ(0.5f, detections[1].score(0)); + EXPECT_FLOAT_EQ(1.0f, detections[2].score(0)); + EXPECT_EQ(1, detections[0].label_id(0)); + EXPECT_EQ(2, detections[1].label_id(0)); + EXPECT_EQ(3, detections[2].label_id(0)); +} + +TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest, + OutputsDetectionsFromRawTensorsWithCorrectValues) { + RunCalculatorRawTensors(); + const std::vector detections = runner_->Outputs() + .Tag(kDetections) + .packets[0] + .Get>(); + EXPECT_EQ(kNumBoxes, detections.size()); + for (const auto& detection : detections) { + 
LocationData::RelativeBoundingBox relative_bbox = + detection.location_data().relative_bounding_box(); + EXPECT_FLOAT_EQ(0.2, relative_bbox.xmin()); + EXPECT_FLOAT_EQ(0.1, relative_bbox.ymin()); + EXPECT_FLOAT_EQ(0.2, relative_bbox.width()); + EXPECT_FLOAT_EQ(0.2, relative_bbox.height()); + } + EXPECT_FLOAT_EQ(0.1f, detections[0].score(0)); + EXPECT_FLOAT_EQ(0.5f, detections[1].score(0)); + EXPECT_FLOAT_EQ(1.0f, detections[2].score(0)); + EXPECT_EQ(1, detections[0].label_id(0)); + EXPECT_EQ(2, detections[1].label_id(0)); + EXPECT_EQ(3, detections[2].label_id(0)); +} + +TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest, + OutputsDetectionsWithKeypointsAndCorrectValues) { + RunCalculatorWithKeypoints(); + const std::vector detections = runner_->Outputs() + .Tag(kDetections) + .packets[0] + .Get>(); + EXPECT_EQ(kNumBoxes, detections.size()); + for (const auto& detection : detections) { + LocationData::RelativeBoundingBox relative_bbox = + detection.location_data().relative_bounding_box(); + EXPECT_FLOAT_EQ(0.2, relative_bbox.xmin()); + EXPECT_FLOAT_EQ(0.1, relative_bbox.ymin()); + EXPECT_FLOAT_EQ(0.2, relative_bbox.width()); + EXPECT_FLOAT_EQ(0.2, relative_bbox.height()); + for (const auto& relative_keypoint : + detection.location_data().relative_keypoints()) { + EXPECT_FLOAT_EQ(0.5, relative_keypoint.x()); + EXPECT_FLOAT_EQ(0.6, relative_keypoint.y()); + } + } + EXPECT_FLOAT_EQ(0.1f, detections[0].score(0)); + EXPECT_FLOAT_EQ(0.5f, detections[1].score(0)); + EXPECT_FLOAT_EQ(1.0f, detections[2].score(0)); + EXPECT_EQ(1, detections[0].label_id(0)); + EXPECT_EQ(2, detections[1].label_id(0)); + EXPECT_EQ(3, detections[2].label_id(0)); +} + +TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest, + SqueezesInputTensorDimensionAndOutputsDetectionsWithCorrectValues) { + RunCalculatorWithTensorDimensionSqueezing(); + const std::vector detections = runner_->Outputs() + .Tag(kDetections) + .packets[0] + .Get>(); + EXPECT_EQ(kNumBoxes, detections.size()); + for (const 
auto& detection : detections) { + LocationData::RelativeBoundingBox relative_bbox = + detection.location_data().relative_bounding_box(); + EXPECT_FLOAT_EQ(0.2, relative_bbox.xmin()); + EXPECT_FLOAT_EQ(0.1, relative_bbox.ymin()); + EXPECT_FLOAT_EQ(0.2, relative_bbox.width()); + EXPECT_FLOAT_EQ(0.2, relative_bbox.height()); + } + EXPECT_FLOAT_EQ(0.1f, detections[0].score(0)); + EXPECT_FLOAT_EQ(0.5f, detections[1].score(0)); + EXPECT_FLOAT_EQ(1.0f, detections[2].score(0)); + EXPECT_EQ(1, detections[0].label_id(0)); + EXPECT_EQ(2, detections[1].label_id(0)); + EXPECT_EQ(3, detections[2].label_id(0)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc new file mode 100644 index 000000000..9a6c7c97d --- /dev/null +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -0,0 +1,403 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "absl/strings/match.h" +#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" +#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/sequence/media_sequence.h" +#include "mediapipe/util/sequence/media_sequence_util.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" + +namespace mediapipe { + +const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; +const char kImageTag[] = "IMAGE"; +const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; +const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; +const char kBBoxTag[] = "BBOX"; +const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION"; + +namespace tf = ::tensorflow; +namespace mpms = ::mediapipe::mediasequence; + +// Sink calculator to package streams into tf.SequenceExamples. +// +// The calculator takes a tf.SequenceExample as a side input and then adds +// the data from inputs to the SequenceExample with timestamps. Additional +// context features can be supplied verbatim in the calculator's options. The +// SequenceExample will conform to the description in media_sequence.h. +// +// The supported input stream tags are "IMAGE", which stores the encoded +// images from the OpenCVImageEncoderCalculator, "FORWARD_FLOW_ENCODED", which +// stores the encoded optical flow from the same calculator, "BBOX" which stores +// bounding boxes from vector, and streams with the +// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector's +// associated with the name ${NAME}. 
Audio streams (i.e. Matrix with a +// TimeSeriesHeader) are given extra packing and unpacking support and are named +// similar to floats with the pattern "AUDIO_${NAME}". "IMAGE_${NAME}" and +// "BBOX_${NAME}" will also store prefixed versions of each stream, which allows +// for multiple image streams to be included. However, the default names are +// suppored by more tools. "ENCODED_MEDIA" stores a video encoding for the clip +// directly. The last packet on this stream is stored, and can be unpacked with +// the metadata generator. Because the media decoder always starts with +// timestamp zero, the "ENCODED_MEDIA_START_TIMESTAMP" should be recorded as +// well. Use the FirstTimestampCalculator to determine this value. +// +// Example config: +// node { +// calculator: "PackMediaSequenceCalculator" +// input_side_packet: "SEQUENCE_EXAMPLE:example_input_side_packet" +// input_stream: "IMAGE:frames" +// input_stream: "FLOAT_FEATURE_FDENSE:fdense_vf" +// output_stream: "SEQUENCE_EXAMPLE:example_output_stream" +// options { +// [mediapipe.PackMediaSequenceCalculatorOptions.ext]: { +// context_feature_map { +// feature { +// key: "image/frames_per_second" +// value { +// float_list { +// value: 30.0 +// } +// } +// } +// } +// } +// } +// } +namespace { +uint8 ConvertFloatToByte(const float float_value) { + float clamped_value = MathUtil::Clamp(0.0f, 1.0f, float_value); + return static_cast(clamped_value * 255.0 + .5f); +} +} // namespace + +class PackMediaSequenceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag)); + cc->InputSidePackets().Tag(kSequenceExampleTag).Set(); + + if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) { + cc->Inputs() + .Tag(kForwardFlowEncodedTag) + .Set(); + } + if (cc->Inputs().HasTag(kSegmentationMaskTag)) { + cc->Inputs().Tag(kSegmentationMaskTag).Set>(); + } + + for (const auto& tag : cc->Inputs().GetTags()) { + 
if (absl::StartsWith(tag, kImageTag)) { + std::string key = ""; + if (tag != kImageTag) { + int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kImageTag)_?" + } + } + cc->Inputs().Tag(tag).Set(); + } + if (absl::StartsWith(tag, kBBoxTag)) { + std::string key = ""; + if (tag != kBBoxTag) { + int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kBBoxTag)_?" + } + } + cc->Inputs().Tag(tag).Set>(); + } + if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { + cc->Inputs().Tag(tag).Set>(); + } + } + + CHECK(cc->Outputs().HasTag(kSequenceExampleTag) || + cc->OutputSidePackets().HasTag(kSequenceExampleTag)) + << "Neither the output stream nor the output side packet is set to " + "output the sequence example."; + if (cc->Outputs().HasTag(kSequenceExampleTag)) { + cc->Outputs().Tag(kSequenceExampleTag).Set(); + } + if (cc->OutputSidePackets().HasTag(kSequenceExampleTag)) { + cc->OutputSidePackets() + .Tag(kSequenceExampleTag) + .Set(); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + sequence_ = ::absl::make_unique( + cc->InputSidePackets() + .Tag(kSequenceExampleTag) + .Get()); + + const auto& context_features = + cc->Options().context_feature_map(); + for (const auto& feature : context_features.feature()) { + *mpms::MutableContext(feature.first, sequence_.get()) = feature.second; + } + for (const auto& tag : cc->Inputs().GetTags()) { + features_present_[tag] = false; + } + + if (cc->Options() + .GetExtension(PackMediaSequenceCalculatorOptions::ext) + .replace_data_instead_of_append()) { + for (const auto& tag : cc->Inputs().GetTags()) { + if (absl::StartsWith(tag, kImageTag)) { + std::string key = ""; + if (tag != kImageTag) { + int tag_length = 
sizeof(kImageTag) / sizeof(*kImageTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kImageTag)_?" + } + } + mpms::ClearImageEncoded(key, sequence_.get()); + mpms::ClearImageTimestamp(key, sequence_.get()); + } + } + if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) { + mpms::ClearForwardFlowEncoded(sequence_.get()); + mpms::ClearForwardFlowTimestamp(sequence_.get()); + } + + for (const auto& tag : cc->Inputs().GetTags()) { + if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { + std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / + sizeof(*kFloatFeaturePrefixTag) - + 1); + mpms::ClearFeatureFloats(key, sequence_.get()); + mpms::ClearFeatureTimestamp(key, sequence_.get()); + } + } + } + + if (cc->Outputs().HasTag(kSequenceExampleTag)) { + cc->Outputs() + .Tag(kSequenceExampleTag) + .SetNextTimestampBound(Timestamp::Max()); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status VerifySequence() { + std::string error_msg = "Missing features - "; + bool all_present = true; + for (auto iter : features_present_) { + if (!iter.second) { + all_present = false; + absl::StrAppend(&error_msg, iter.first, ", "); + } + } + if (all_present) { + return ::mediapipe::OkStatus(); + } else { + return ::mediapipe::NotFoundErrorBuilder(MEDIAPIPE_LOC) << error_msg; + } + } + + ::mediapipe::Status Close(CalculatorContext* cc) override { + auto& options = + cc->Options().GetExtension(PackMediaSequenceCalculatorOptions::ext); + if (options.reconcile_metadata()) { + RET_CHECK_OK(mpms::ReconcileMetadata(options.reconcile_bbox_annotations(), + sequence_.get())); + } + + if (options.output_only_if_all_present()) { + ::mediapipe::Status status = VerifySequence(); + if (!status.ok()) { + cc->GetCounter(status.error_message())->Increment(); + return status; + } + } + + if (cc->OutputSidePackets().HasTag(kSequenceExampleTag)) { + cc->OutputSidePackets() + .Tag(kSequenceExampleTag) + 
.Set(MakePacket(*sequence_)); + } + if (cc->Outputs().HasTag(kSequenceExampleTag)) { + cc->Outputs() + .Tag(kSequenceExampleTag) + .Add(sequence_.release(), Timestamp::PostStream()); + } + sequence_.reset(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + for (const auto& tag : cc->Inputs().GetTags()) { + if (absl::StartsWith(tag, kImageTag) && + !cc->Inputs().Tag(tag).IsEmpty()) { + std::string key = ""; + if (tag != kImageTag) { + int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kImageTag)_?" + } + } + const OpenCvImageEncoderCalculatorResults& image = + cc->Inputs().Tag(tag).Get(); + if (!image.has_encoded_image()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "No encoded image"; + } + mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(), + sequence_.get()); + mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get()); + } + } + if (cc->Inputs().HasTag(kForwardFlowEncodedTag) && + !cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) { + const OpenCvImageEncoderCalculatorResults& forward_flow = + cc->Inputs() + .Tag(kForwardFlowEncodedTag) + .Get(); + if (!forward_flow.has_encoded_image()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "No encoded forward flow"; + } + mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(), + sequence_.get()); + mpms::AddForwardFlowEncoded(forward_flow.encoded_image(), + sequence_.get()); + } + for (const auto& tag : cc->Inputs().GetTags()) { + if (absl::StartsWith(tag, kFloatFeaturePrefixTag) && + !cc->Inputs().Tag(tag).IsEmpty()) { + std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / + sizeof(*kFloatFeaturePrefixTag) - + 1); + mpms::AddFeatureTimestamp(key, cc->InputTimestamp().Value(), + sequence_.get()); + mpms::AddFeatureFloats(key, + 
cc->Inputs().Tag(tag).Get>(), + sequence_.get()); + } + } + for (const auto& tag : cc->Inputs().GetTags()) { + if (absl::StartsWith(tag, kBBoxTag) && !cc->Inputs().Tag(tag).IsEmpty()) { + std::string key = ""; + if (tag != kBBoxTag) { + int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kBBoxTag)_?" + } + } + std::vector predicted_locations; + std::vector predicted_class_strings; + std::vector predicted_label_ids; + for (auto& detection : + cc->Inputs().Tag(tag).Get>()) { + if (detection.location_data().format() == + LocationData::BOUNDING_BOX || + detection.location_data().format() == + LocationData::RELATIVE_BOUNDING_BOX) { + int height = mpms::GetImageHeight(*sequence_); + int width = mpms::GetImageWidth(*sequence_); + Location relative_bbox = Location::CreateRelativeBBoxLocation( + Location(detection.location_data()) + .ConvertToRelativeBBox(width, height)); + predicted_locations.push_back(relative_bbox); + if (detection.label_size() > 0) { + predicted_class_strings.push_back(detection.label(0)); + } + if (detection.label_id_size() > 0) { + predicted_label_ids.push_back(detection.label_id(0)); + } + } + } + if (!predicted_locations.empty()) { + mpms::AddBBox(key, predicted_locations, sequence_.get()); + mpms::AddBBoxTimestamp(key, cc->InputTimestamp().Value(), + sequence_.get()); + if (!predicted_class_strings.empty()) { + mpms::AddBBoxClassString(key, predicted_class_strings, + sequence_.get()); + } + if (!predicted_label_ids.empty()) { + mpms::AddBBoxClassIndex(key, predicted_label_ids, sequence_.get()); + } + } + } + } + if (cc->Inputs().HasTag(kSegmentationMaskTag) && + !cc->Inputs().Tag(kSegmentationMaskTag).IsEmpty()) { + bool already_has_mask = false; + for (auto& detection : cc->Inputs() + .Tag(kSegmentationMaskTag) + .Get>()) { + if (detection.location_data().format() == LocationData::MASK) { + RET_CHECK(!already_has_mask) 
+ << "We currently only support adding one mask per timestamp. " + << sequence_->DebugString(); + auto mask_mat_ptr = Location(detection.location_data()).GetCvMask(); + std::vector bytes; + RET_CHECK(cv::imencode(".png", *mask_mat_ptr, bytes, {})); + + std::string encoded_mask(bytes.begin(), bytes.end()); + mpms::AddClassSegmentationEncoded(encoded_mask, sequence_.get()); + mpms::AddClassSegmentationTimestamp(cc->InputTimestamp().Value(), + sequence_.get()); + // SegmentationClassLabelString is a context feature for the entire + // sequence. The values in the last detection will be saved. + mpms::SetClassSegmentationClassLabelString({detection.label(0)}, + sequence_.get()); + already_has_mask = true; + } else { + return ::mediapipe::UnimplementedError( + "Global detections and empty detections are not supported."); + } + } + } + for (const auto& tag : cc->Inputs().GetTags()) { + if (!cc->Inputs().Tag(tag).IsEmpty()) { + features_present_[tag] = true; + } + } + return ::mediapipe::OkStatus(); + } + + std::unique_ptr sequence_; + std::map features_present_; +}; +REGISTER_CALCULATOR(PackMediaSequenceCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto new file mode 100644 index 000000000..53a6f73c2 --- /dev/null +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto @@ -0,0 +1,53 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "tensorflow/core/example/feature.proto"; + +message PackMediaSequenceCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional PackMediaSequenceCalculatorOptions ext = 243132285; + } + + // Provide a map of strings to tf.Features to merge into the SequenceExample's + // context. Use this to add new metadata. + optional tensorflow.Features context_feature_map = 1; + + // If true, update the context for the SequenceExample features. + // (e.g. fills in the image height, width, and number of frames.) + optional bool reconcile_metadata = 2 [default = true]; + + // If true, updates the metadata for sequences with bounding boxes. This will + // align each bounding box annotation with the nearest frame and insert empty + // annotations as needed to satisfy the frame rate. + // NOTE: IF YOU DOWNSAMPLE IN TIME YOU WILL LOSE ANNOTATIONS. + // If two or more annotations are closest to the same frame, then only + // the closest annotation is saved. This matches the behavior of + // downsampling images in time. + optional bool reconcile_bbox_annotations = 5 [default = true]; + + // If true, the SequenceExample is output only if all input streams are + // present. + optional bool output_only_if_all_present = 3 [default = false]; + + // If true, will remove all data from a sequence example for a corresponding + // input stream. E.g. if images are already present and an IMAGE stream is + // present, the previous images and timestamps will be removed before adding + // the new images. 
+ optional bool replace_data_instead_of_append = 4 [default = true]; +} diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc new file mode 100644 index 000000000..d8931cfa8 --- /dev/null +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -0,0 +1,623 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" +#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/sequence/media_sequence.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" + +namespace mediapipe { +namespace { + +namespace tf = ::tensorflow; 
+namespace mpms = ::mediapipe::mediasequence; + +class PackMediaSequenceCalculatorTest : public ::testing::Test { + protected: + void SetUpCalculator(const std::vector& input_streams, + const tf::Features& features, + bool output_only_if_all_present, + bool replace_instead_of_append) { + CalculatorGraphConfig::Node config; + config.set_calculator("PackMediaSequenceCalculator"); + config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence"); + config.add_output_stream("SEQUENCE_EXAMPLE:output_sequence"); + for (const std::string& stream : input_streams) { + config.add_input_stream(stream); + } + auto options = config.mutable_options()->MutableExtension( + PackMediaSequenceCalculatorOptions::ext); + *options->mutable_context_feature_map() = features; + options->set_output_only_if_all_present(output_only_if_all_present); + options->set_replace_data_instead_of_append(replace_instead_of_append); + runner_ = ::absl::make_unique(config); + } + + std::unique_ptr runner_; +}; + +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { + SetUpCalculator({"IMAGE:images"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + std::string test_image_string(bytes.begin(), bytes.end()); + OpenCvImageEncoderCalculatorResults encoded_image; + encoded_image.set_encoded_image(test_image_string); + encoded_image.set_width(2); + encoded_image.set_height(1); + + int num_images = 2; + for (int i = 0; i < num_images; ++i) { + auto image_ptr = + ::absl::make_unique(encoded_image); + runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + Adopt(image_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const 
std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(num_images, mpms::GetImageTimestampSize(output_sequence)); + ASSERT_EQ(num_images, mpms::GetImageEncodedSize(output_sequence)); + for (int i = 0; i < num_images; ++i) { + ASSERT_EQ(i, mpms::GetImageTimestampAt(output_sequence, i)); + ASSERT_EQ(test_image_string, mpms::GetImageEncodedAt(output_sequence, i)); + } +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { + std::string prefix = "PREFIX"; + SetUpCalculator({"IMAGE_PREFIX:images"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + std::string test_image_string(bytes.begin(), bytes.end()); + OpenCvImageEncoderCalculatorResults encoded_image; + encoded_image.set_encoded_image(test_image_string); + encoded_image.set_width(2); + encoded_image.set_height(1); + + int num_images = 2; + for (int i = 0; i < num_images; ++i) { + auto image_ptr = + ::absl::make_unique(encoded_image); + runner_->MutableInputs() + ->Tag("IMAGE_PREFIX") + .packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(num_images, mpms::GetImageTimestampSize(prefix, 
output_sequence)); + ASSERT_EQ(num_images, mpms::GetImageEncodedSize(prefix, output_sequence)); + for (int i = 0; i < num_images; ++i) { + ASSERT_EQ(i, mpms::GetImageTimestampAt(prefix, output_sequence, i)); + ASSERT_EQ(test_image_string, + mpms::GetImageEncodedAt(prefix, output_sequence, i)); + } +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) { + SetUpCalculator({"FLOAT_FEATURE_TEST:test", "FLOAT_FEATURE_OTHER:test2"}, {}, + false, true); + auto input_sequence = ::absl::make_unique(); + + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + auto vf_ptr = ::absl::make_unique>(2, 2 << i); + runner_->MutableInputs() + ->Tag("FLOAT_FEATURE_TEST") + .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); + vf_ptr = ::absl::make_unique>(2, 2 << i); + runner_->MutableInputs() + ->Tag("FLOAT_FEATURE_OTHER") + .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureFloatsSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("OTHER", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetFeatureFloatsSize("OTHER", output_sequence)); + for (int i = 0; i < num_timesteps; ++i) { + ASSERT_EQ(i, mpms::GetFeatureTimestampAt("TEST", output_sequence, i)); + ASSERT_THAT(mpms::GetFeatureFloatsAt("TEST", output_sequence, i), + ::testing::ElementsAreArray(std::vector(2, 2 << i))); + ASSERT_EQ(i, mpms::GetFeatureTimestampAt("OTHER", output_sequence, i)); + ASSERT_THAT(mpms::GetFeatureFloatsAt("OTHER", output_sequence, i), + 
::testing::ElementsAreArray(std::vector(2, 2 << i))); + } +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { + tf::Features context; + (*context.mutable_feature())["TEST"].mutable_bytes_list()->add_value("YES"); + (*context.mutable_feature())["OTHER"].mutable_bytes_list()->add_value("NO"); + SetUpCalculator({"IMAGE:images"}, context, false, true); + + auto input_sequence = ::absl::make_unique(); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + std::string test_image_string(bytes.begin(), bytes.end()); + OpenCvImageEncoderCalculatorResults encoded_image; + encoded_image.set_encoded_image(test_image_string); + auto image_ptr = + ::absl::make_unique(encoded_image); + runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + Adopt(image_ptr.release()).At(Timestamp(0))); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_TRUE(mpms::HasContext(output_sequence, "TEST")); + ASSERT_TRUE(mpms::HasContext(output_sequence, "OTHER")); + ASSERT_EQ(mpms::GetContext(output_sequence, "TEST").bytes_list().value(0), + "YES"); + ASSERT_EQ(mpms::GetContext(output_sequence, "OTHER").bytes_list().value(0), + "NO"); +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) { + SetUpCalculator({"FORWARD_FLOW_ENCODED:flow"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + std::string 
test_flow_string(bytes.begin(), bytes.end()); + OpenCvImageEncoderCalculatorResults encoded_flow; + encoded_flow.set_encoded_image(test_flow_string); + encoded_flow.set_width(2); + encoded_flow.set_height(1); + + int num_flows = 2; + for (int i = 0; i < num_flows; ++i) { + auto flow_ptr = + ::absl::make_unique(encoded_flow); + runner_->MutableInputs() + ->Tag("FORWARD_FLOW_ENCODED") + .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(num_flows, mpms::GetForwardFlowTimestampSize(output_sequence)); + ASSERT_EQ(num_flows, mpms::GetForwardFlowEncodedSize(output_sequence)); + for (int i = 0; i < num_flows; ++i) { + ASSERT_EQ(i, mpms::GetForwardFlowTimestampAt(output_sequence, i)); + ASSERT_EQ(test_flow_string, + mpms::GetForwardFlowEncodedAt(output_sequence, i)); + } +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) { + SetUpCalculator({"BBOX_PREDICTED:detections"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + int height = 480; + int width = 640; + mpms::SetImageHeight(height, input_sequence.get()); + mpms::SetImageWidth(width, input_sequence.get()); + + int num_vectors = 2; + for (int i = 0; i < num_vectors; ++i) { + auto detections = ::absl::make_unique<::std::vector>(); + Detection detection; + detection.add_label("absolute bbox"); + detection.add_label_id(0); + detection.add_score(0.5); + Location::CreateBBoxLocation(0, height / 2, width / 2, height / 2) + 
.ConvertToProto(detection.mutable_location_data()); + detections->push_back(detection); + + detection = Detection(); + detection.add_label("relative bbox"); + detection.add_label_id(1); + detection.add_score(0.75); + Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5) + .ConvertToProto(detection.mutable_location_data()); + detections->push_back(detection); + + // The mask detection should be ignored in the output. + detection = Detection(); + detection.add_label("mask"); + detection.add_score(1.0); + cv::Mat image(2, 3, CV_8UC1, cv::Scalar(0)); + Location::CreateCvMaskLocation(image).ConvertToProto( + detection.mutable_location_data()); + detections->push_back(detection); + + runner_->MutableInputs() + ->Tag("BBOX_PREDICTED") + .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(height, mpms::GetImageHeight(output_sequence)); + ASSERT_EQ(width, mpms::GetImageWidth(output_sequence)); + ASSERT_EQ(num_vectors, mpms::GetPredictedBBoxSize(output_sequence)); + ASSERT_EQ(num_vectors, mpms::GetPredictedBBoxTimestampSize(output_sequence)); + ASSERT_EQ(0, mpms::GetClassSegmentationEncodedSize(output_sequence)); + ASSERT_EQ(0, mpms::GetClassSegmentationTimestampSize(output_sequence)); + for (int i = 0; i < num_vectors; ++i) { + ASSERT_EQ(i, mpms::GetPredictedBBoxTimestampAt(output_sequence, i)); + auto bboxes = mpms::GetPredictedBBoxAt(output_sequence, i); + ASSERT_EQ(2, bboxes.size()); + for (int j = 0; j < bboxes.size(); ++j) { + auto rect = bboxes[j].GetRelativeBBox(); + ASSERT_NEAR(0, rect.xmin(), 0.001); + ASSERT_NEAR(0.5, 
rect.ymin(), 0.001); + ASSERT_NEAR(0.5, rect.xmax(), 0.001); + ASSERT_NEAR(1.0, rect.ymax(), 0.001); + } + auto class_strings = + mpms::GetPredictedBBoxClassStringAt(output_sequence, i); + ASSERT_EQ("absolute bbox", class_strings[0]); + ASSERT_EQ("relative bbox", class_strings[1]); + auto class_indices = mpms::GetPredictedBBoxClassIndexAt(output_sequence, i); + ASSERT_EQ(0, class_indices[0]); + ASSERT_EQ(1, class_indices[1]); + } +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) { + SetUpCalculator({"CLASS_SEGMENTATION:detections"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + int height = 480; + int width = 640; + mpms::SetImageHeight(height, input_sequence.get()); + mpms::SetImageWidth(width, input_sequence.get()); + + int num_vectors = 2; + for (int i = 0; i < num_vectors; ++i) { + auto detections = ::absl::make_unique<::std::vector>(); + Detection detection; + detection = Detection(); + detection.add_label("mask"); + detection.add_score(1.0); + cv::Mat image(2, 3, CV_8UC1, cv::Scalar(0)); + Location::CreateCvMaskLocation(image).ConvertToProto( + detection.mutable_location_data()); + + detections->push_back(detection); + + runner_->MutableInputs() + ->Tag("CLASS_SEGMENTATION") + .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + LOG(INFO) << "output_sequence: \n" << output_sequence.DebugString(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(height, mpms::GetImageHeight(output_sequence)); + ASSERT_EQ(width, 
mpms::GetImageWidth(output_sequence)); + ASSERT_EQ(2, mpms::GetClassSegmentationEncodedSize(output_sequence)); + ASSERT_EQ(2, mpms::GetClassSegmentationTimestampSize(output_sequence)); + for (int i = 0; i < num_vectors; ++i) { + ASSERT_EQ(i, mpms::GetClassSegmentationTimestampAt(output_sequence, i)); + } + ASSERT_THAT(mpms::GetClassSegmentationClassLabelString(output_sequence), + testing::ElementsAreArray(::std::vector({"mask"}))); +} + +TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) { + SetUpCalculator( + {"FORWARD_FLOW_ENCODED:flow", "FLOAT_FEATURE_I3D_FLOW:feature"}, {}, + false, false); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + std::string test_flow_string(bytes.begin(), bytes.end()); + OpenCvImageEncoderCalculatorResults encoded_flow; + encoded_flow.set_encoded_image(test_flow_string); + encoded_flow.set_width(2); + encoded_flow.set_height(1); + + int num_flows = 2; + for (int i = 0; i < num_flows; ++i) { + auto flow_ptr = + ::absl::make_unique(encoded_flow); + runner_->MutableInputs() + ->Tag("FORWARD_FLOW_ENCODED") + .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(num_flows, mpms::GetForwardFlowTimestampSize(output_sequence)); + ASSERT_EQ(num_flows, mpms::GetForwardFlowEncodedSize(output_sequence)); + for (int i = 0; i < num_flows; ++i) { + ASSERT_EQ(i, 
mpms::GetForwardFlowTimestampAt(output_sequence, i)); + ASSERT_EQ(test_flow_string, + mpms::GetForwardFlowEncodedAt(output_sequence, i)); + } +} + +TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) { + SetUpCalculator( + {"FORWARD_FLOW_ENCODED:flow", "FLOAT_FEATURE_I3D_FLOW:feature"}, {}, true, + false); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + std::string test_flow_string(bytes.begin(), bytes.end()); + OpenCvImageEncoderCalculatorResults encoded_flow; + encoded_flow.set_encoded_image(test_flow_string); + encoded_flow.set_width(2); + encoded_flow.set_height(1); + + int num_flows = 2; + for (int i = 0; i < num_flows; ++i) { + auto flow_ptr = + ::absl::make_unique(encoded_flow); + runner_->MutableInputs() + ->Tag("FORWARD_FLOW_ENCODED") + .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + ::mediapipe::Status status = runner_->Run(); + EXPECT_FALSE(status.ok()); +} + +TEST_F(PackMediaSequenceCalculatorTest, TestReplacingImages) { + SetUpCalculator({"IMAGE:images"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + mpms::AddImageEncoded("one", input_sequence.get()); + mpms::AddImageEncoded("two", input_sequence.get()); + mpms::AddImageTimestamp(1, input_sequence.get()); + mpms::AddImageTimestamp(2, input_sequence.get()); + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, 
output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(0, mpms::GetImageTimestampSize(output_sequence)); + ASSERT_EQ(0, mpms::GetImageEncodedSize(output_sequence)); +} + +TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFlowImages) { + SetUpCalculator({"FORWARD_FLOW_ENCODED:images"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + mpms::AddForwardFlowEncoded("one", input_sequence.get()); + mpms::AddForwardFlowEncoded("two", input_sequence.get()); + mpms::AddForwardFlowTimestamp(1, input_sequence.get()); + mpms::AddForwardFlowTimestamp(2, input_sequence.get()); + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(0, mpms::GetForwardFlowTimestampSize(output_sequence)); + ASSERT_EQ(0, mpms::GetForwardFlowEncodedSize(output_sequence)); +} + +TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFloatVectors) { + SetUpCalculator({"FLOAT_FEATURE_TEST:test", "FLOAT_FEATURE_OTHER:test2"}, {}, + false, true); + auto input_sequence = ::absl::make_unique(); + + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + auto vf_ptr = ::absl::make_unique>(2, 2 << i); + mpms::AddFeatureFloats("TEST", *vf_ptr, input_sequence.get()); + mpms::AddFeatureTimestamp("TEST", i, input_sequence.get()); + vf_ptr = ::absl::make_unique>(2, 2 << i); + mpms::AddFeatureFloats("OTHER", *vf_ptr, input_sequence.get()); + mpms::AddFeatureTimestamp("OTHER", i, 
input_sequence.get()); + } + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("TEST", *input_sequence)); + ASSERT_EQ(num_timesteps, mpms::GetFeatureFloatsSize("TEST", *input_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetFeatureTimestampSize("OTHER", *input_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetFeatureFloatsSize("OTHER", *input_sequence)); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(0, mpms::GetFeatureTimestampSize("TEST", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureFloatsSize("TEST", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureTimestampSize("OTHER", output_sequence)); + ASSERT_EQ(0, mpms::GetFeatureFloatsSize("OTHER", output_sequence)); +} + +TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { + SetUpCalculator({"IMAGE:images"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + std::string test_image_string(bytes.begin(), bytes.end()); + OpenCvImageEncoderCalculatorResults encoded_image; + encoded_image.set_encoded_image(test_image_string); + encoded_image.set_width(2); + encoded_image.set_height(1); + + int num_images = 5; // Timestamps: 10, 20, 30, 40, 50 + for (int i = 0; i < num_images; ++i) { + auto image_ptr = + ::absl::make_unique(encoded_image); + runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10))); + } + + mpms::AddBBoxTimestamp(9, input_sequence.get()); + mpms::AddBBoxTimestamp(21, input_sequence.get()); + mpms::AddBBoxTimestamp(22, input_sequence.get()); + + 
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(mpms::GetBBoxTimestampSize(output_sequence), 5); + ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 0), 10); + ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 1), 20); + ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 2), 30); + ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 3), 40); + ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 4), 50); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc b/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc new file mode 100644 index 000000000..6693a0642 --- /dev/null +++ b/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc @@ -0,0 +1,99 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "tensorflow/core/example/example.pb.h" + +// A calculator to serialize/deserialize tensorflow::SequenceExample protos +// to and from strings. +// +// Example converting to SequenceExample in Open(): +// node { +// calculator: "StringToSequenceExampleCalculator" +// input_side_packet: "STRING:serialized_sequence_example" +// output_side_packet: "SEQUENCE_EXAMPLE:sequence_example" +// } +// +// Example converting to std::string in Close(): +// node { +// calculator: "StringToSequenceExampleCalculator" +// input_side_packet: "SEQUENCE_EXAMPLE:sequence_example" +// output_side_packet: "STRING:serialized_sequence_example" +// } + +namespace mediapipe { +namespace tf = ::tensorflow; +namespace { +constexpr char kString[] = "STRING"; +constexpr char kSequenceExample[] = "SEQUENCE_EXAMPLE"; +} // namespace + +class StringToSequenceExampleCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; +}; + +REGISTER_CALCULATOR(StringToSequenceExampleCalculator); + +::mediapipe::Status StringToSequenceExampleCalculator::GetContract( + CalculatorContract* cc) { + if (cc->InputSidePackets().HasTag(kString)) { + cc->InputSidePackets().Tag(kString).Set(); + cc->OutputSidePackets().Tag(kSequenceExample).Set(); + } + if (cc->InputSidePackets().HasTag(kSequenceExample)) { + cc->InputSidePackets().Tag(kSequenceExample).Set(); + cc->OutputSidePackets().Tag(kString).Set(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status StringToSequenceExampleCalculator::Open( + CalculatorContext* cc) { + if 
(cc->InputSidePackets().HasTag(kString)) { + auto string_value = cc->InputSidePackets().Tag(kString).Get(); + auto example = absl::make_unique(); + example->ParseFromString(string_value); + cc->OutputSidePackets() + .Tag(kSequenceExample) + .Set(::mediapipe::Adopt(example.release())); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status StringToSequenceExampleCalculator::Process( + CalculatorContext* cc) { + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status StringToSequenceExampleCalculator::Close( + CalculatorContext* cc) { + if (cc->InputSidePackets().HasTag(kSequenceExample)) { + const auto& example = + cc->InputSidePackets().Tag(kSequenceExample).Get(); + auto string_value = absl::make_unique(); + example.SerializeToString(string_value.get()); + cc->OutputSidePackets().Tag(kString).Set( + ::mediapipe::Adopt(string_value.release())); + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc new file mode 100644 index 000000000..b1e4f05f0 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc @@ -0,0 +1,111 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "tensorflow/core/framework/tensor.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; + +// Given an input Tensor (example dimensions [1, 1024, 1, 5]), it squeezes all +// dimensions with size 1, or dimensions at specific indices, producing a tensor +// containing identical data (example output dimensions [1024, 5]). +class TensorSqueezeDimensionsCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Need one input"; + cc->Inputs().Index(0).Set( + // Input Tensor + ); + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) << "Need one output"; + cc->Outputs().Index(0).Set( + // Output Tensor Reduced Dimensions + ); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + RET_CHECK(options_.squeeze_all_single_dims() ^ (options_.dim_size() > 0)) + << "Must specify dimensions to remove, or set squeeze_all_single_dims, " + "but not both. Received options: " + << options_.DebugString(); + if (options_.dim_size() > 0) { + remove_dims_ = + std::vector(options_.dim().begin(), options_.dim().end()); + std::sort(remove_dims_.rbegin(), remove_dims_.rend()); + remove_dims_initialized_ = true; + } + cc->SetOffset(0); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + const tf::Tensor& input_tensor = cc->Inputs().Index(0).Get(); + tf::TensorShape tensor_shape = input_tensor.shape(); + if (!remove_dims_initialized_) { + // Happens iff options.squeeze_all_single_dims is set. + // Initialize remove_dims_ to all dimensions with size 1. 
+ InitializeToRemoveAllSingletonDimensions(tensor_shape); + remove_dims_initialized_ = true; + } + for (const int dim : remove_dims_) { + RET_CHECK_GT(tensor_shape.dims(), dim) + << "Dimension " << dim + << " does not exist in input tensor with num dimensions " + << input_tensor.dims(); + RET_CHECK_EQ(tensor_shape.dim_size(dim), 1) + << "Cannot remove dimension " << dim << " with size " + << tensor_shape.dim_size(dim); + tensor_shape.RemoveDim(dim); + } + + std::unique_ptr output_tensor(new tf::Tensor); + RET_CHECK(output_tensor->CopyFrom(input_tensor, tensor_shape)); + cc->Outputs().Index(0).Add(output_tensor.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Close(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } + + private: + TensorSqueezeDimensionsCalculatorOptions options_; + std::vector remove_dims_; + bool remove_dims_initialized_; + + void InitializeToRemoveAllSingletonDimensions( + const tf::TensorShape& tensor_shape) { + const int dims = tensor_shape.dims(); + for (int i = dims - 1; i >= 0; --i) { + if (tensor_shape.dim_size(i) == 1) { + remove_dims_.push_back(i); + } + } + if (remove_dims_.empty()) { + LOG(ERROR) << "TensorSqueezeDimensionsCalculator is squeezing input with " + "no single-dimensions. Calculator will be a no-op."; + LOG(ERROR) << "Input to TensorSqueezeDimensionsCalculator has shape " + << tensor_shape.DebugString(); + } + } +}; +REGISTER_CALCULATOR(TensorSqueezeDimensionsCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.proto b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.proto new file mode 100644 index 000000000..87c59780d --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.proto @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Specifies options to TensorSqueezeDimensionsCalculator. Use this proto to +// configure which dimensions to squeeze (remove). It is only possible to remove +// dimensions of size 1. +// The fields 'squeeze_all_single_dims' and 'dim' are mutually exclusive. +message TensorSqueezeDimensionsCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorSqueezeDimensionsCalculatorOptions ext = 115619805; + } + // Remove all dimensions with size 1. + optional bool squeeze_all_single_dims = 1 [default = false]; + // Remove specific singleton dimensions. + repeated int32 dim = 2; +} diff --git a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator_test.cc new file mode 100644 index 000000000..e3b9f7233 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator_test.cc @@ -0,0 +1,120 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +class TensorSqueezeDimensionsCalculatorTest : public ::testing::Test { + protected: + TensorSqueezeDimensionsCalculatorTest() { + // Initialize tensor_ with deterministic values. 
+ tensor_shape_ = tf::TensorShape(std::vector({1, 3, 1, 3, 1})); + tensor_ = tf::Tensor(tf::DT_INT32, tensor_shape_); + auto tensor_values = tensor_.tensor(); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + tensor_values(0, i, 0, j, 0) = i * (j + 1); + } + } + } + + std::unique_ptr runner_; + tf::TensorShape tensor_shape_; + tf::Tensor tensor_; +}; + +TEST_F(TensorSqueezeDimensionsCalculatorTest, CanSqueezeAllSingleDimensions) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorSqueezeDimensionsCalculator"); + config.add_input_stream("input_tensor"); + config.add_output_stream("output_tensor"); + CalculatorOptions options; + TensorSqueezeDimensionsCalculatorOptions* squeeze_options = + options.MutableExtension(TensorSqueezeDimensionsCalculatorOptions::ext); + squeeze_options->set_squeeze_all_single_dims(true); + *config.mutable_options() = options; + + runner_.reset(new CalculatorRunner(config)); + std::unique_ptr tensor_copy(new tf::Tensor); + EXPECT_TRUE(tensor_copy->CopyFrom(tensor_, tensor_shape_)); + const tf::int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor_copy.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const tf::Tensor& output_tensor = output_packets[0].Get(); + const tf::TensorShape expected_shape(std::vector({3, 3})); + EXPECT_EQ(expected_shape.DebugString(), output_tensor.shape().DebugString()); + const auto tensor_values = output_tensor.tensor(); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + const int expected_value = i * (j + 1); + EXPECT_EQ(expected_value, tensor_values(i, j)); + } + } +} + +TEST_F(TensorSqueezeDimensionsCalculatorTest, CanSqueezeSpecifiedDimensions) { + CalculatorGraphConfig::Node config; + 
config.set_calculator("TensorSqueezeDimensionsCalculator"); + config.add_input_stream("input_tensor"); + config.add_output_stream("output_tensor"); + CalculatorOptions options; + TensorSqueezeDimensionsCalculatorOptions* squeeze_options = + options.MutableExtension(TensorSqueezeDimensionsCalculatorOptions::ext); + squeeze_options->add_dim(0); + squeeze_options->add_dim(4); + *config.mutable_options() = options; + + runner_.reset(new CalculatorRunner(config)); + std::unique_ptr tensor_copy(new tf::Tensor); + EXPECT_TRUE(tensor_copy->CopyFrom(tensor_, tensor_shape_)); + const tf::int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor_copy.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const tf::Tensor& output_tensor = output_packets[0].Get(); + const tf::TensorShape expected_shape(std::vector({3, 1, 3})); + EXPECT_EQ(expected_shape.DebugString(), output_tensor.shape().DebugString()); + const auto tensor_values = output_tensor.tensor(); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + const int expected_value = i * (j + 1); + EXPECT_EQ(expected_value, tensor_values(i, 0, j)); + } + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc new file mode 100644 index 000000000..f6e4354d3 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc @@ -0,0 +1,124 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; +namespace { + +constexpr char kImage[] = "IMAGE"; +constexpr char kTensor[] = "TENSOR"; + +} // namespace + +// Input: +// Tensor of type DT_FLOAT, with values between 0-255 (SRGB or GRAY8). The +// shape can be HxWx{3,1} or simply HxW. +// +// Optionally supports a scale factor that can scale 0-1 value ranges to 0-255. +// +// Output: +// ImageFrame containing the values of the tensor cast as uint8 (SRGB or GRAY8) +// +// Possible extensions: support other input ranges, maybe 4D tensors. 
+class TensorToImageFrameCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + float scale_factor_; +}; + +REGISTER_CALCULATOR(TensorToImageFrameCalculator); + +::mediapipe::Status TensorToImageFrameCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "One input stream must be provided."; + RET_CHECK(cc->Inputs().HasTag(kTensor)) + << "An input stream for tag: " << kTensor << " must be provided."; + cc->Inputs().Tag(kTensor).Set( + // Input Tensor. + ); + cc->Outputs().Tag(kImage).Set( + // Output ImageFrame. + ); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) { + scale_factor_ = + cc->Options().scale_factor(); + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TensorToImageFrameCalculator::Process( + CalculatorContext* cc) { + const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get(); + int32 depth = 1; + if (input_tensor.dims() != 2) { // Depth is 1 for 2D tensors. + CHECK(3 == input_tensor.dims()) + << "Only 2 or 3-D Tensors can be converted to frames. 
Instead got: " + << input_tensor.dims(); + depth = input_tensor.dim_size(2); + if (depth != 1) { + RET_CHECK_EQ(depth, 3) << "Output tensor depth must be 3 or 1."; + } + } + const int32 total_size = + input_tensor.dim_size(0) * input_tensor.dim_size(1) * depth; + std::unique_ptr buffer(new uint8[total_size]); + auto data = input_tensor.flat().data(); + for (int i = 0; i < total_size; ++i) { + float d = scale_factor_ * data[i]; + if (d < 0) d = 0; + if (d > 255) d = 255; + buffer[i] = d; + } + + ::std::unique_ptr output; + if (depth == 3) { + output = ::absl::make_unique( + ImageFormat::SRGB, input_tensor.dim_size(1), input_tensor.dim_size(0), + input_tensor.dim_size(1) * 3, buffer.release()); + } else if (depth == 1) { + output = ::absl::make_unique( + ImageFormat::GRAY8, input_tensor.dim_size(1), input_tensor.dim_size(0), + input_tensor.dim_size(1), buffer.release()); + } else { + return ::mediapipe::InvalidArgumentError("Unrecognized image depth."); + } + cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.proto b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.proto new file mode 100644 index 000000000..3410068d0 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.proto @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorToImageFrameCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorToImageFrameCalculatorOptions ext = 142032475; + } + + // Multiples floating point tensor outputs by this value before converting to + // uint8. This is useful for converting from range [0, 1] to [0, 255] + optional float scale_factor = 1 [default = 1.0]; +} diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc new file mode 100644 index 000000000..54e989a20 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc @@ -0,0 +1,146 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; +namespace { + +constexpr char kTensor[] = "TENSOR"; +constexpr char kImage[] = "IMAGE"; + +} // namespace + +class TensorToImageFrameCalculatorTest : public ::testing::Test { + protected: + void SetUpRunner() { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorToImageFrameCalculator"); + config.add_input_stream("TENSOR:input_tensor"); + config.add_output_stream("IMAGE:output_image"); + runner_ = absl::make_unique(config); + } + + std::unique_ptr runner_; +}; + +TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) { + SetUpRunner(); + constexpr int kWidth = 16; + constexpr int kHeight = 8; + const tf::TensorShape tensor_shape( + std::vector{kHeight, kWidth, 3}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto tensor_vec = tensor->flat().data(); + + // Writing sequence of integers as floats which we want back (as they were + // written). 
+ for (int i = 0; i < kWidth * kHeight * 3; ++i) { + tensor_vec[i] = i % 255; + } + + const int64 time = 1234; + runner_->MutableInputs()->Tag(kTensor).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Tag(kImage).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const ImageFrame& output_image = output_packets[0].Get(); + EXPECT_EQ(kWidth, output_image.Width()); + EXPECT_EQ(kHeight, output_image.Height()); + + for (int i = 0; i < kWidth * kHeight * 3; ++i) { + const uint8 pixel_value = output_image.PixelData()[i]; + EXPECT_EQ(i % 255, pixel_value); + } +} + +TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) { + SetUpRunner(); + constexpr int kWidth = 16; + constexpr int kHeight = 8; + const tf::TensorShape tensor_shape( + std::vector{kHeight, kWidth, 1}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto tensor_vec = tensor->flat().data(); + + // Writing sequence of integers as floats which we want back (as they were + // written). 
+ for (int i = 0; i < kWidth * kHeight; ++i) { + tensor_vec[i] = i % 255; + } + + const int64 time = 1234; + runner_->MutableInputs()->Tag(kTensor).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Tag(kImage).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const ImageFrame& output_image = output_packets[0].Get(); + EXPECT_EQ(kWidth, output_image.Width()); + EXPECT_EQ(kHeight, output_image.Height()); + + for (int i = 0; i < kWidth * kHeight; ++i) { + const uint8 pixel_value = output_image.PixelData()[i]; + EXPECT_EQ(i % 255, pixel_value); + } +} + +TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) { + SetUpRunner(); + constexpr int kWidth = 16; + constexpr int kHeight = 8; + const tf::TensorShape tensor_shape(std::vector{kHeight, kWidth}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto tensor_vec = tensor->flat().data(); + + // Writing sequence of integers as floats which we want back (as they were + // written). 
+ for (int i = 0; i < kWidth * kHeight; ++i) { + tensor_vec[i] = i % 255; + } + + const int64 time = 1234; + runner_->MutableInputs()->Tag(kTensor).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Tag(kImage).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const ImageFrame& output_image = output_packets[0].Get(); + EXPECT_EQ(kWidth, output_image.Width()); + EXPECT_EQ(kHeight, output_image.Height()); + + for (int i = 0; i < kWidth * kHeight; ++i) { + const uint8 pixel_value = output_image.PixelData()[i]; + EXPECT_EQ(i % 255, pixel_value); + } +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc new file mode 100644 index 000000000..b061fe7b3 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc @@ -0,0 +1,227 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Calculator converts from one-dimensional Tensor of DT_FLOAT to Matrix +// OR from (batched) two-dimensional Tensor of DT_FLOAT to Matrix. 
+ +#include "mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; +namespace { + +constexpr char kMatrix[] = "MATRIX"; +constexpr char kTensor[] = "TENSOR"; +constexpr char kReference[] = "REFERENCE"; + +::mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header) { + CHECK(header); + if (header_packet.IsEmpty()) { + return ::mediapipe::UnknownError("No header found."); + } + if (!header_packet.ValidateAsType().ok()) { + return ::mediapipe::UnknownError( + "Packet does not contain TimeSeriesHeader."); + } + *header = header_packet.Get(); + if (header->has_sample_rate() && header->sample_rate() >= 0 && + header->has_num_channels() && header->num_channels() >= 0) { + return ::mediapipe::OkStatus(); + } else { + std::string error_message = + "TimeSeriesHeader is missing necessary fields: " + "sample_rate or num_channels, or one of their values is negative. "; +#ifndef MEDIAPIPE_MOBILE + absl::StrAppend(&error_message, "Got header:\n", + header->ShortDebugString()); +#endif + return ::mediapipe::InvalidArgumentError(error_message); + } +} + +} // namespace + +// Converts a 1-D or a 2-D Tensor into a 2 dimensional Matrix. +// Input: +// -- 1-D or 2-D Tensor +// Output: +// -- Matrix with the same values as the Tensor +// If input tensor is 1 dimensional, the ouput Matrix is of (1xn) shape. +// If input tensor is 2 dimensional (batched), the ouput Matrix is (mxn) shape. 
+// +// Example Config +// node: { +// calculator: "TensorToMatrixCalculator" +// input_stream: "TENSOR:tensor" +// output_stream: "MATRIX:matrix" +// } +// +// +// This calculator produces a TimeSeriesHeader header on its output stream iff +// an input stream is supplied with the REFERENCE tag and that stream has a +// header of type TimeSeriesHeader. This header is modified in two ways: +// - the sample_rate is set to the packet rate of the REFERENCE stream (which +// must have a packet_rate defined in its header). This is under the +// assumption that the packets on the reference stream, input stream, and +// output stream are in a 1:1 correspondence, and that the output packets are +// 1-D column vectors that represent a single sample of output. +// - the TimeSeriesHeader overrides specified in the calculator options are +// then applied, which can override the sample_rate field. +// If the REFERENCE stream is supplied, then the TimeSeriesHeader is verified on +// the input data when it arrives in Process(). In particular, if the header +// states that we produce a 1xD column vector, the input tensor must also be 1xD +// +// This designed was discussed in http://g/speakeranalysis/4uyx7cNRwJY and +// http://g/daredevil-project/VB26tcseUy8. +// Example Config +// node: { +// calculator: "TensorToMatrixCalculator" +// input_stream: "TENSOR:tensor" +// input_stream: "REFERENCE:reference_matrix" +// output_stream: "MATRIX:matrix" +// options { +// [mediapipe.TensorToMatrixCalculatorOptions.ext] { +// time_series_header_overrides { +// num_channels: 128 +// } +// } +// } +// } +class TensorToMatrixCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + // Store header information so that we can verify the inputs in process(). 
+ TimeSeriesHeader header_; +}; +REGISTER_CALCULATOR(TensorToMatrixCalculator); + +::mediapipe::Status TensorToMatrixCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK_LE(cc->Inputs().NumEntries(), 2) + << "Only one or two input streams are supported."; + RET_CHECK_GT(cc->Inputs().NumEntries(), 0) + << "At least one input stream must be provided."; + RET_CHECK(cc->Inputs().HasTag(kTensor)) + << "An input stream for tag: " << kTensor << " must be provided."; + cc->Inputs().Tag(kTensor).Set( + // Input Tensor. + ); + if (cc->Inputs().NumEntries() == 2) { + RET_CHECK(cc->Inputs().HasTag(kReference)) + << "An input stream for tag: " << kReference + << " must be provided when" + " providing two inputs."; + cc->Inputs() + .Tag(kReference) + .Set( + // A reference stream for the header. + ); + } + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + cc->Outputs().Tag(kMatrix).Set( + // Output Matrix. + ); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) { + auto input_header = absl::make_unique(); + ::mediapipe::Status header_status; + if (cc->Inputs().HasTag(kReference)) { + header_status = FillTimeSeriesHeaderIfValid( + cc->Inputs().Tag(kReference).Header(), input_header.get()); + } + if (header_status.ok()) { + if (cc->Options() + .has_time_series_header_overrides()) { + // From design discussions with Daredevil, we only want to support single + // sample per packet for now, so we hardcode the sample_rate based on the + // packet_rate of the REFERENCE and fail noisily if we cannot. 
An + // alternative would be to calculate the sample_rate from the reference + // sample_rate and the change in num_samples between the reference and + // override headers: + // sample_rate_output = sample_rate_reference / + // (num_samples_override / num_samples_reference) + const TimeSeriesHeader& override_header = + cc->Options() + .time_series_header_overrides(); + input_header->MergeFrom(override_header); + CHECK(input_header->has_packet_rate()) + << "The TimeSeriesHeader.packet_rate must be set."; + if (!override_header.has_sample_rate()) { + CHECK_EQ(input_header->num_samples(), 1) + << "Currently the time series can only output single samples."; + input_header->set_sample_rate(input_header->packet_rate()); + } + } + header_ = *input_header; + cc->Outputs().Tag(kMatrix).SetHeader(Adopt(input_header.release())); + } + cc->SetOffset(mediapipe::TimestampDiff(0)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) { + // Daredevil requested CHECK for noisy failures rather than quieter RET_CHECK + // failures. These are absolute conditions of the graph for the graph to be + // valid, and if it is violated by any input anywhere, the graph will be + // invalid for all inputs. A hard CHECK will enable faster debugging by + // immediately exiting and more prominently displaying error messages. + // Do not replace with RET_CHECKs. + + // Verify that each reference stream packet corresponds to a tensor packet + // otherwise the header information is invalid. If we don't have a reference + // stream, Process() is only called when we have an input tensor and this is + // always True. 
+ CHECK(cc->Inputs().HasTag(kTensor)) + << "Tensor stream not available at same timestamp as the reference " + "stream."; + + const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get(); + CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims()) + << "Only 1-D or 2-D Tensors can be converted to matrices."; + const int32 length = input_tensor.dim_size(input_tensor.dims() - 1); + const int32 width = (1 == input_tensor.dims()) ? 1 : input_tensor.dim_size(0); + if (header_.has_num_channels()) { + CHECK_EQ(length, header_.num_channels()) + << "The number of channels at runtime does not match the header."; + } + if (header_.has_num_samples()) { + CHECK_EQ(width, header_.num_samples()) + << "The number of samples at runtime does not match the header."; + ; + } + auto output = absl::make_unique(width, length); + *output = + Eigen::MatrixXf::Map(input_tensor.flat().data(), length, width); + cc->Outputs().Tag(kMatrix).Add(output.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.proto b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.proto new file mode 100644 index 000000000..e8647700c --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.proto @@ -0,0 +1,36 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/formats/time_series_header.proto"; + +message TensorToMatrixCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorToMatrixCalculatorOptions ext = 136654056; + } + + // Any values present in this TimeSeriesHeader override the values in the + // header from the reference stream if the reference stream is used. + // The most common fields to override are the num_channels field which + // typically correspond to the dimensionality of an output embedding and + // the num_samples field which is generally 1 for an embedding. + // If the sampling_rate is not specified in the time_series_header, then + // the packet_rate from the reference stream will be used as the sampling_rate + // which assumes that num_samples is 1. + optional TimeSeriesHeader time_series_header_overrides = 1; +} diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc new file mode 100644 index 000000000..fce24b9b9 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator_test.cc @@ -0,0 +1,227 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; +namespace { + +constexpr char kMatrix[] = "MATRIX"; +constexpr char kTensor[] = "TENSOR"; + +} // namespace + +class TensorToMatrixCalculatorTest : public ::testing::Test { + protected: + void SetUpRunner() { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorToMatrixCalculator"); + config.add_input_stream("TENSOR:input_tensor"); + config.add_output_stream("MATRIX:output_matrix"); + runner_ = absl::make_unique(config); + } + + // Creates a reference stream and sets num_channels and num_samples if > 0. 
+ void SetUpRunnerWithReference(int channels, int samples, + int override_channels, bool include_rate) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorToMatrixCalculator"); + config.add_input_stream("TENSOR:input_tensor"); + config.add_input_stream("REFERENCE:reference"); + config.add_output_stream("MATRIX:output_matrix"); + if (override_channels > 0) { + config.mutable_options() + ->MutableExtension(TensorToMatrixCalculatorOptions::ext) + ->mutable_time_series_header_overrides() + ->set_num_channels(override_channels); + } + runner_ = absl::make_unique(config); + + auto header = absl::make_unique(); + header->set_sample_rate(1.0); + if (channels > 0) { + header->set_num_channels(channels); + } + if (samples > 0) { + header->set_num_samples(samples); + } + if (include_rate) { + header->set_packet_rate(1.0); + } + runner_->MutableInputs()->Tag("REFERENCE").header = Adopt(header.release()); + } + + std::unique_ptr runner_; +}; + +TEST_F(TensorToMatrixCalculatorTest, Converts1DTensorToMatrix) { + // This test converts a 1 Dimensional Tensor of length M to a Matrix of Mx1. + SetUpRunner(); + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + // 2^i can be represented exactly in floating point numbers if 'i' is small. 
+ tensor_vec(i) = static_cast(1 << i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Tag(kTensor).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Tag(kMatrix).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const Matrix& output_matrix = output_packets[0].Get(); + EXPECT_EQ(5, output_matrix.rows()); + for (int i = 0; i < 5; ++i) { + const float expected = static_cast(1 << i); + EXPECT_EQ(expected, output_matrix(i, 0)); + } +} + +TEST_F(TensorToMatrixCalculatorTest, Converts2DTensorofWidthOneToMatrix) { + // This test converts a 2 Dimensional Tensor of shape 1xM to a Matrix of Mx1. + SetUpRunner(); + const tf::TensorShape tensor_shape(std::vector({1, 4})); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto slice = tensor->Slice(0, 1).flat(); + for (int i = 0; i < 4; ++i) { + slice(i) = static_cast(1 << i); + } + const int64 time = 1234; + runner_->MutableInputs()->Tag(kTensor).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Tag(kMatrix).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const Matrix& output_matrix = output_packets[0].Get(); + ASSERT_EQ(1, output_matrix.cols()); + EXPECT_EQ(4, output_matrix.rows()); + for (int i = 0; i < 4; ++i) { + const float expected = static_cast(1 << i); + EXPECT_EQ(expected, output_matrix(i, 0)); + } +} + +TEST_F(TensorToMatrixCalculatorTest, Converts2DTensorToMatrix) { + // This test converts a 2 Dimensional Tensor of shape NxM to a Matrix of MxN. 
+ SetUpRunner(); + const tf::TensorShape tensor_shape(std::vector({3, 4})); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto slice = tensor->Slice(0, 1).flat(); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 4; ++j) { + slice(i * 4 + j) = static_cast(i * j); + } + } + const int64 time = 1234; + runner_->MutableInputs()->Tag(kTensor).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Tag(kMatrix).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const Matrix& output_matrix = output_packets[0].Get(); + ASSERT_EQ(3, output_matrix.cols()); + EXPECT_EQ(4, output_matrix.rows()); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 3; ++j) { + const float expected = static_cast(i * j); + EXPECT_EQ(expected, output_matrix(i, j)); + } + } +} + +TEST_F(TensorToMatrixCalculatorTest, ConvertsWithReferenceTimeSeriesHeader) { + // This test converts a 1 Dimensional Tensor of length M to a Matrix of Mx1. + SetUpRunnerWithReference(5, 1, -1, true); + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + // 2^i can be represented exactly in floating point numbers if 'i' is small. 
+ tensor_vec(i) = static_cast(1 << i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Tag(kTensor).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Tag(kMatrix).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const Matrix& output_matrix = output_packets[0].Get(); + EXPECT_EQ(5, output_matrix.rows()); + for (int i = 0; i < 5; ++i) { + const float expected = static_cast(1 << i); + EXPECT_EQ(expected, output_matrix(i, 0)); + } + + const TimeSeriesHeader& output_header = + runner_->Outputs().Tag(kMatrix).header.Get(); + EXPECT_EQ(output_header.num_channels(), 5); +} + +TEST_F(TensorToMatrixCalculatorTest, TimeSeriesOverridesWork) { + // This test converts a 1 Dimensional Tensor of length M to a Matrix of Mx1. + SetUpRunnerWithReference(7, 1, 5, true); + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + // 2^i can be represented exactly in floating point numbers if 'i' is small. 
+ tensor_vec(i) = static_cast(1 << i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Tag(kTensor).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Tag(kMatrix).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const Matrix& output_matrix = output_packets[0].Get(); + EXPECT_EQ(5, output_matrix.rows()); + for (int i = 0; i < 5; ++i) { + const float expected = static_cast(1 << i); + EXPECT_EQ(expected, output_matrix(i, 0)); + } + + const TimeSeriesHeader& output_header = + runner_->Outputs().Tag(kMatrix).header.Get(); + EXPECT_EQ(output_header.num_channels(), 5); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc new file mode 100644 index 000000000..7b447f4d5 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc @@ -0,0 +1,109 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Calculator converts from one-dimensional Tensor of DT_FLOAT to vector +// OR from (batched) two-dimensional Tensor of DT_FLOAT to vector. 
+ +#include "mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; + +class TensorToVectorFloatCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + TensorToVectorFloatCalculatorOptions options_; +}; +REGISTER_CALCULATOR(TensorToVectorFloatCalculator); + +::mediapipe::Status TensorToVectorFloatCalculator::GetContract( + CalculatorContract* cc) { + // Start with only one input packet. + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + cc->Inputs().Index(0).Set( + // Input Tensor + ); + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + const auto& options = cc->Options(); + if (options.tensor_is_2d()) { + RET_CHECK(!options.flatten_nd()); + cc->Outputs().Index(0).Set>>( + /* "Output vector>." */); + } else { + cc->Outputs().Index(0).Set>( + // Output vector. 
+ ); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TensorToVectorFloatCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TensorToVectorFloatCalculator::Process( + CalculatorContext* cc) { + const tf::Tensor& input_tensor = + cc->Inputs().Index(0).Value().Get(); + RET_CHECK(tf::DT_FLOAT == input_tensor.dtype()) + << "expected DT_FLOAT input but got " + << tensorflow::DataTypeString(input_tensor.dtype()); + + if (options_.tensor_is_2d()) { + RET_CHECK(2 == input_tensor.dims()) + << "Expected 2-dimensional Tensor, but the tensor shape is: " + << input_tensor.shape().DebugString(); + auto output = absl::make_unique>>( + input_tensor.dim_size(0), std::vector(input_tensor.dim_size(1))); + for (int i = 0; i < input_tensor.dim_size(0); ++i) { + auto& instance_output = output->at(i); + const auto& slice = input_tensor.Slice(i, i + 1).unaligned_flat(); + for (int j = 0; j < input_tensor.dim_size(1); ++j) { + instance_output.at(j) = slice(j); + } + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else { + if (!options_.flatten_nd()) { + RET_CHECK(1 == input_tensor.dims()) + << "`flatten_nd` is not set. 
Expected 1-dimensional Tensor, but the " + << "tensor shape is: " << input_tensor.shape().DebugString(); + } + auto output = + absl::make_unique>(input_tensor.NumElements()); + const auto& tensor_values = input_tensor.flat(); + for (int i = 0; i < input_tensor.NumElements(); ++i) { + output->at(i) = tensor_values(i); + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_options.proto b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_options.proto new file mode 100644 index 000000000..c9aa67f52 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_options.proto @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorToVectorFloatCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorToVectorFloatCalculatorOptions ext = 120862965; + } + + // If true, unpack a 2d tensor (matrix) into a vector>. If + // false, convert a 1d tensor (vector) into a vector. + optional bool tensor_is_2d = 1 [default = false]; + + // If true, an N-D tensor will be flattened to a vector. This is + // exclusive with tensor_is_2d. 
+ optional bool flatten_nd = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc new file mode 100644 index 000000000..69d3af60a --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc @@ -0,0 +1,133 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +class TensorToVectorFloatCalculatorTest : public ::testing::Test { + protected: + void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorToVectorFloatCalculator"); + config.add_input_stream("input_tensor"); + config.add_output_stream("output_tensor"); + auto options = config.mutable_options()->MutableExtension( + TensorToVectorFloatCalculatorOptions::ext); + options->set_tensor_is_2d(tensor_is_2d); + options->set_flatten_nd(flatten_nd); + runner_ = absl::make_unique(config); + } + + std::unique_ptr runner_; 
+}; + +TEST_F(TensorToVectorFloatCalculatorTest, ConvertsToVectorFloat) { + SetUpRunner(false, false); + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + // 2^i can be represented exactly in floating point numbers if 'i' is small. + tensor_vec(i) = static_cast(1 << i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector& output_vector = + output_packets[0].Get>(); + + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const float expected = static_cast(1 << i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +TEST_F(TensorToVectorFloatCalculatorTest, ConvertsBatchedToVectorVectorFloat) { + SetUpRunner(true, false); + const tf::TensorShape tensor_shape(std::vector{1, 5}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto slice = tensor->Slice(0, 1).flat(); + for (int i = 0; i < 5; ++i) { + // 2^i can be represented exactly in floating point numbers if 'i' is small. 
+ slice(i) = static_cast(1 << i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector>& output_vectors = + output_packets[0].Get>>(); + ASSERT_EQ(1, output_vectors.size()); + const std::vector& output_vector = output_vectors[0]; + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const float expected = static_cast(1 << i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +TEST_F(TensorToVectorFloatCalculatorTest, FlattenShouldTakeAllDimensions) { + SetUpRunner(false, true); + const tf::TensorShape tensor_shape(std::vector{2, 2, 2}); + auto tensor = absl::make_unique(tf::DT_FLOAT, tensor_shape); + auto slice = tensor->flat(); + for (int i = 0; i < 2 * 2 * 2; ++i) { + // 2^i can be represented exactly in floating point numbers if 'i' is small. 
+ slice(i) = static_cast(1 << i); + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const std::vector& output_vector = + output_packets[0].Get>(); + EXPECT_EQ(2 * 2 * 2, output_vector.size()); + for (int i = 0; i < 2 * 2 * 2; ++i) { + const float expected = static_cast(1 << i); + EXPECT_EQ(expected, output_vector[i]); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc new file mode 100644 index 000000000..47c4f21b7 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -0,0 +1,540 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/clock.h" +#include "mediapipe/framework/deps/monotonic_clock.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/tool/status_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_util.h" + +namespace tf = ::tensorflow; + +namespace mediapipe { + +namespace { +// This is a simple implementation of a semaphore using standard C++ libraries. +// It is supposed to be used only by TensorflowInferenceCalculator to throttle +// the concurrent calls of Tensorflow Session::Run. This is useful when multiple +// threads execute the graph (e.g. in a mapreduce type of job) but not to +// overload GPU/TPU/... +class SimpleSemaphore { + public: + explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {} + SimpleSemaphore(const SimpleSemaphore&) = delete; + SimpleSemaphore(SimpleSemaphore&&) = delete; + + // Acquires the semaphore by certain amount. + void Acquire(uint32 amount) { + mutex_.Lock(); + while (count_ < amount) { + cond_.Wait(&mutex_); + } + count_ -= amount; + mutex_.Unlock(); + } + + // Releases the semaphore by certain amount. + void Release(uint32 amount) { + mutex_.Lock(); + count_ += amount; + cond_.SignalAll(); + mutex_.Unlock(); + } + + private: + uint32 count_; + absl::Mutex mutex_; + absl::CondVar cond_; +}; +} // namespace + +// This calculator performs inference on a trained TensorFlow model. 
+//
+// Additional documentation and examples at
+// go/mediapipe/tensorflow_in_mediapipe.
+//
+// TensorFlow Sessions can be created from checkpoint paths, frozen models, or
+// the SavedModel system (go/saved-model). See the TensorFlowSessionFrom*
+// packet generators for details. Each of these methods defines a mapping
+// between MediaPipe streams and TensorFlow tensors. All of this information is
+// passed in as an input_side_packet.
+//
+// The input and output streams are TensorFlow tensors labeled by tags. The tags
+// for the streams are matched to feeds and fetches in a TensorFlow session
+// using a named_signature.generic_signature in the ModelManifest. The
+// generic_signature is used as key-value pairs between the MediaPipe tag and
+// the TensorFlow tensor. The signature_name in the options proto determines
+// which named_signature is used. The keys in the generic_signature must be
+// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of
+// the tensors corresponding to tags in the signature for input_streams are fed
+// to the model, and for output_streams the tensors are fetched from the model.
+//
+// Other calculators are used to convert data to and from tensors; this op only
+// handles the TensorFlow session and batching. Batching occurs by concatenating
+// input tensors along the 0th dimension across timestamps. If the 0th dimension
+// is not a batch dimension, this calculator will add a 0th dimension by
+// default. Setting add_batch_dim_to_tensors to false disables the dimension
+// addition. Once batch_size inputs have been provided, the batch will be run
+// and the output tensors sent out on the output streams with timestamps
+// corresponding to the input stream packets. Setting the batch_size to 1
+// completely disables batching, but is independent of add_batch_dim_to_tensors.
+//
+// The TensorFlowInferenceCalculator also supports feeding states recurrently
+// for RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the
+// recurrent tensors. Initializing the recurrent state can be handled by the
+// GraphTensorsPacketGenerator.
+//
+// The calculator updates two Counters to report timing information:
+// --<name>-TotalTimeUsecs = Total time spent running inference (in usecs),
+// --<name>-TotalProcessedTimestamps = # of instances processed
+// (approximately batches processed * batch_size),
+// where <name> is replaced with CalculatorGraphConfig::Node::name() if it
+// exists, or with TensorFlowInferenceCalculator if the name is not set. The
+// name must be set for timing information to be instance-specific in graphs
+// with multiple TensorFlowInferenceCalculators.
+//
+// Example config:
+// packet_generator {
+//   packet_generator: "TensorFlowSessionFromSavedModelGenerator"
+//   output_side_packet: "tensorflow_session"
+//   options {
+//     [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: {
+//       saved_model_path: "/path/to/saved/model"
+//       signature_name: "mediapipe"
+//     }
+//   }
+// }
+// node {
+//   calculator: "TensorFlowInferenceCalculator"
+//   input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag"
+//   input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag"
+//   output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag"
+//   input_side_packet: "SESSION:tensorflow_session"
+// }
+//
+// Where the input and output streams are treated as Packet<tf::Tensor> and
+// the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and
+// "LABELS" and their respective tensors exported to /path/to/bundle. For an
+// example of how this model was exported, see
+// tensorflow_inference_test_graph_generator.py
+//
+// It is possible to use a GraphDef proto that was not exported by exporter
+// (i.e. without MetaGraph with bindings). Such GraphDef could contain all of
+// its parameters in-lined (for example, it can be the output of
+// freeze_graph.py).
+// To instantiate a TensorFlow model from a GraphDef file, replace the +// packet_factory above with TensorFlowSessionFromFrozenGraphGenerator: +// +// packet_generator { +// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator" +// output_side_packet: "SESSION:tensorflow_session" +// options { +// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: { +// graph_proto_path: "[PATH]" +// tag_to_tensor_names { +// key: "JPG_STRING" +// value: "input:0" +// } +// tag_to_tensor_names { +// key: "SOFTMAX" +// value: "softmax:0" +// } +// } +// } +// } +// +// It is also possible to use a GraphDef proto and checkpoint file that have not +// been frozen. This can be used to load graphs directly as they have been +// written from training. However, it is more brittle and you are encouraged to +// use a one of the more perminent formats described above. To instantiate a +// TensorFlow model from a GraphDef file and checkpoint, replace the +// packet_factory above with TensorFlowSessionFromModelCheckpointGenerator: +// +// packet_generator { +// packet_generator: "TensorFlowSessionFromModelCheckpointGenerator" +// output_side_packet: "SESSION:tensorflow_session" +// options { +// [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: { +// graph_proto_path: "[PATH]" +// model_options { +// checkpoint_path: "[PATH2]" +// } +// tag_to_tensor_names { +// key: "JPG_STRING" +// value: "input:0" +// } +// tag_to_tensor_names { +// key: "SOFTMAX" +// value: "softmax:0" +// } +// } +// } +// } +class TensorFlowInferenceCalculator : public CalculatorBase { + public: + // Counters for recording timing information. The actual names have the value + // of CalculatorGraphConfig::Node::name() prepended. 
+ static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs"; + static constexpr char kTotalProcessedTimestampsCounterSuffix[] = + "TotalProcessedTimestamps"; + static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] = + "TotalSessionRunsTimeUsecs"; + static constexpr char kTotalNumSessionRunsCounterSuffix[] = + "TotalNumSessionRuns"; + + TensorFlowInferenceCalculator() : session_(nullptr) { + clock_ = std::unique_ptr( + mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); + } + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK(!cc->Inputs().GetTags().empty()); + for (const std::string& tag : cc->Inputs().GetTags()) { + // The tensorflow::Tensor with the tag equal to the graph node. May + // have a TimeSeriesHeader if all present TimeSeriesHeaders match. + cc->Inputs().Tag(tag).Set(); + } + RET_CHECK(!cc->Outputs().GetTags().empty()); + for (const std::string& tag : cc->Outputs().GetTags()) { + // The tensorflow::Tensor with tag equal to the graph node to + // output. Any TimeSeriesHeader from the inputs will be forwarded + // with channels set to 0. + cc->Outputs().Tag(tag).Set(); + } + // A mediapipe::TensorFlowSession with a model loaded and ready for use. + // For this calculator it must include a tag_to_tensor_map. 
+ cc->InputSidePackets().Tag("SESSION").Set(); + if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { + cc->InputSidePackets() + .Tag("RECURRENT_INIT_TENSORS") + .Set>>(); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + + RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); + session_ = cc->InputSidePackets() + .Tag("SESSION") + .Get() + .session.get(); + tag_to_tensor_map_ = cc->InputSidePackets() + .Tag("SESSION") + .Get() + .tag_to_tensor_map; + + // Validate and store the recurrent tags + RET_CHECK(options_.has_batch_size()); + RET_CHECK(options_.batch_size() == 1 || + options_.recurrent_tag_pair().empty()) + << "To use recurrent_tag_pairs, batch_size must be 1."; + for (const auto& tag_pair : options_.recurrent_tag_pair()) { + const std::vector tags = absl::StrSplit(tag_pair, ':'); + RET_CHECK_EQ(tags.size(), 2) + << "recurrent_tag_pair must be a colon " + "separated std::string with two components: " + << tag_pair; + RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) + << "Can't find tag '" << tags[0] << "' in signature " + << options_.signature_name(); + RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tags[1])) + << "Can't find tag '" << tags[1] << "' in signature " + << options_.signature_name(); + recurrent_feed_tags_.insert(tags[0]); + recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0]; + } + if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && + !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { + std::map* init_tensor_map; + init_tensor_map = GetFromUniquePtr>( + cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); + for (const auto& p : *init_tensor_map) { + input_tensor_batches_[p.first].emplace_back(p.second); + } + } + + // Check that all tags are present in this signature bound to tensors. 
+ for (const std::string& tag : cc->Inputs().GetTags()) { + RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tag)) + << "Can't find tag '" << tag << "' in signature " + << options_.signature_name(); + } + for (const std::string& tag : cc->Outputs().GetTags()) { + RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tag)) + << "Can't find tag '" << tag << "' in signature " + << options_.signature_name(); + } + + if (options_.batch_size() == 1) { + cc->SetOffset(0); + } + return ::mediapipe::OkStatus(); + } + + // Adds a batch dimension to the input tensor if specified in the calculator + // options. + ::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(input_tensor->shape()); + new_shape.InsertDim(0, 1); + RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape)) + << "Could not add 0th dimension to tensor without changing its shape." + << " Current shape: " << input_tensor->shape().DebugString(); + } + return ::mediapipe::OkStatus(); + } + + // Removes the batch dimension of the output tensor if specified in the + // calculator options. + ::mediapipe::Status RemoveBatchDimension(tf::Tensor* output_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(output_tensor->shape()); + new_shape.RemoveDim(0); + RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape)) + << "Could not remove 0th dimension from tensor without changing its " + << "shape. Current shape: " << output_tensor->shape().DebugString() + << " (The expected first dimension is 1 for a batch element.)"; + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + std::map input_tensors_by_tag; + for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) { + if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) { + // Recurrent tensors can be empty. 
+ if (!::mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) { + if (options_.skip_on_missing_features()) { + return ::mediapipe::OkStatus(); + } else { + return ::mediapipe::InvalidArgumentError(absl::StrCat( + "Tag ", tag_as_node_name, + " not present at timestamp: ", cc->InputTimestamp().Value())); + } + } + } else { + tf::Tensor input_tensor( + cc->Inputs().Tag(tag_as_node_name).Get()); + RET_CHECK_OK(AddBatchDimension(&input_tensor)); + if (::mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) { + // If we receive an input on a recurrent tag, override the state. + // It's OK to override the global state because there is just one + // input stream allowed for recurrent tensors. + input_tensor_batches_[tag_as_node_name].clear(); + } + input_tensors_by_tag.insert( + std::make_pair(tag_as_node_name, input_tensor)); + } + } + batch_timestamps_.emplace_back(cc->InputTimestamp()); + for (const auto& input_tensor_and_tag : input_tensors_by_tag) { + input_tensor_batches_[input_tensor_and_tag.first].emplace_back( + input_tensor_and_tag.second); + } + + if (batch_timestamps_.size() == options_.batch_size()) { + RETURN_IF_ERROR(OutputBatch(cc)); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Close(CalculatorContext* cc) override { + if (!batch_timestamps_.empty()) { + RETURN_IF_ERROR(OutputBatch(cc)); + } + return ::mediapipe::OkStatus(); + } + + // When a batch of input tensors is ready to be run, runs TensorFlow and + // outputs the output tensors. The output tensors have timestamps matching + // the input tensor that formed that batch element. Any requested + // batch_dimension is added and removed. This code takes advantage of the fact + // that copying a tensor shares the same reference-counted, heap allocated + // memory buffer. Therefore, copies are cheap and should not cause the memory + // buffer to fall out of scope. In contrast, concat is only used where + // necessary. 
+ ::mediapipe::Status OutputBatch(CalculatorContext* cc) { + const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); + std::vector> input_tensors; + for (auto& keyed_tensors : input_tensor_batches_) { + if (options_.batch_size() == 1) { + // Short circuit to avoid the cost of deep copying tensors in concat. + if (!keyed_tensors.second.empty()) { + input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], + keyed_tensors.second[0]); + } else { + // The input buffer can be empty for recurrent tensors. + RET_CHECK(::mediapipe::ContainsKey(recurrent_feed_tags_, + keyed_tensors.first)) + << "A non-recurrent tensor does not have an input: " + << keyed_tensors.first; + } + } else { + // Pad by replicating the first tens or, then ignore the values. + keyed_tensors.second.resize(options_.batch_size()); + std::fill(keyed_tensors.second.begin() + batch_timestamps_.size(), + keyed_tensors.second.end(), keyed_tensors.second[0]); + tf::Tensor concated; + const tf::Status concat_status = + tf::tensor::Concat(keyed_tensors.second, &concated); + CHECK(concat_status.ok()) << concat_status.ToString(); + input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], + concated); + } + } + input_tensor_batches_.clear(); + std::vector output_tensor_names; + std::vector output_name_in_signature; + for (const std::string& tag : cc->Outputs().GetTags()) { + output_tensor_names.emplace_back(tag_to_tensor_map_[tag]); + output_name_in_signature.emplace_back(tag); + } + for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { + // Ensure that we always fetch the recurrent state tensors. 
+ if (std::find(output_name_in_signature.begin(), + output_name_in_signature.end(), + tag_pair.first) == output_name_in_signature.end()) { + output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]); + output_name_in_signature.emplace_back(tag_pair.first); + } + } + std::vector outputs; + + SimpleSemaphore* session_run_throttle = nullptr; + if (options_.max_concurrent_session_runs() > 0) { + session_run_throttle = + get_session_run_throttle(options_.max_concurrent_session_runs()); + session_run_throttle->Acquire(1); + } + const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); + const tf::Status tf_status = + session_->Run(input_tensors, output_tensor_names, + {} /* target_node_names */, &outputs); + if (session_run_throttle != nullptr) { + session_run_throttle->Release(1); + } + + // RET_CHECK on the tf::Status object itself in order to print an + // informative error message. + RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.error_message(); + + const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); + cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) + ->IncrementBy(run_end_time - run_start_time); + cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); + + // Feed back the recurrent state. + for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { + int pos = std::find(output_name_in_signature.begin(), + output_name_in_signature.end(), tag_pair.first) - + output_name_in_signature.begin(); + input_tensor_batches_[tag_pair.second].emplace_back(outputs[pos]); + } + + // Set that we want to split on each index of the 0th dimension. 
+ std::vector split_vector(options_.batch_size(), 1); + for (int i = 0; i < output_tensor_names.size(); ++i) { + if (options_.batch_size() == 1) { + if (cc->Outputs().HasTag(output_name_in_signature[i])) { + tf::Tensor output_tensor(outputs[i]); + RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); + cc->Outputs() + .Tag(output_name_in_signature[i]) + .Add(new tf::Tensor(output_tensor), batch_timestamps_[0]); + } + } else { + std::vector split_tensors; + const tf::Status split_status = + tf::tensor::Split(outputs[i], split_vector, &split_tensors); + CHECK(split_status.ok()) << split_status.ToString(); + // Loop over timestamps so that we don't copy the padding. + for (int j = 0; j < batch_timestamps_.size(); ++j) { + tf::Tensor output_tensor(split_tensors[j]); + RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); + cc->Outputs() + .Tag(output_name_in_signature[i]) + .Add(new tf::Tensor(output_tensor), batch_timestamps_[j]); + } + } + } + // Get end time and report. + const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); + cc->GetCounter(kTotalUsecsCounterSuffix) + ->IncrementBy(end_time - start_time); + cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) + ->IncrementBy(batch_timestamps_.size()); + batch_timestamps_.clear(); + return ::mediapipe::OkStatus(); + } + + private: + // The Session object is provided by a packet factory and is owned by the + // MediaPipe framework. Individual calls are thread-safe, but session state + // may be shared across threads. + tf::Session* session_; + + // A mapping between stream tags and the tensor names they are bound to. + std::map tag_to_tensor_map_; + + // A mapping between stream tags and the tensors we are collecting as a batch. + std::map> input_tensor_batches_; + + // The timestamps that go into a batch. + std::vector batch_timestamps_; + + // The options for the calculator. + TensorFlowInferenceCalculatorOptions options_; + + // Store the feed and fetch tags for feed/fetch recurrent networks. 
+ std::set recurrent_feed_tags_; + std::map recurrent_fetch_tags_to_feed_tags_; + + // Clock used to measure the computation time in OutputBatch(). + std::unique_ptr clock_; + + // The static singleton semaphore to throttle concurrent session runs. + static SimpleSemaphore* get_session_run_throttle( + int32 max_concurrent_session_runs) { + static SimpleSemaphore* session_run_throttle = + new SimpleSemaphore(max_concurrent_session_runs); + return session_run_throttle; + } +}; +REGISTER_CALCULATOR(TensorFlowInferenceCalculator); + +constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[]; +constexpr char + TensorFlowInferenceCalculator::kTotalProcessedTimestampsCounterSuffix[]; +constexpr char + TensorFlowInferenceCalculator::kTotalSessionRunsTimeUsecsCounterSuffix[]; +constexpr char + TensorFlowInferenceCalculator::kTotalNumSessionRunsCounterSuffix[]; +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.proto new file mode 100644 index 000000000..a353d2f55 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.proto @@ -0,0 +1,79 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+// Options for TensorFlowInferenceCalculator: controls the signature used to
+// bind MediaPipe stream tags to TensorFlow tensors, batching behavior, and
+// recurrent-state feedback.
+message TensorFlowInferenceCalculatorOptions {
+  extend mediapipe.CalculatorOptions {
+    optional TensorFlowInferenceCalculatorOptions ext = 113766539;
+  }
+
+  // The signature_name specifies the mapping between stream tags and TensorFlow
+  // session tensors. The mapping between tags and tensors is encoded as a
+  // ModelManifest signature. The signatures are keyed by name and the
+  // named_signature matching signature_name is used by the calculator to
+  // match stream tags to tensors. The named_signature must be a
+  // ModelManifest.generic_signature with map keys that are valid tags (i.e.
+  // [A-Z0-9]*).
+  optional string signature_name = 1;
+
+  // How many elements to batch together and feed into the graph.
+  // Setting the batch_size to 1 disables batching entirely. You still may or
+  // may not need to add the batch dimension via the option below depending on
+  // the input data shape and the model's expectations.
+  optional int32 batch_size = 2;
+
+  // Whether to add a 0th dimension to the input tensors for batching.
+  // If the 0th dimension is the batch dimension, then the tensors are
+  // concatenated on that dimension. If the 0th is a data dimension, then a 0th
+  // dimension is added before concatenating. If added, the extra dimension is
+  // removed before outputting the tensor. Examples of each case: If you want
+  // to batch spectra of audio over time for an LSTM, a time-frequency
+  // representation has a 0th dimension as the batch dimension. If you want to
+  // batch frames of video that are [width, height, channels], the batch
+  // dimension needs to be added.
+  optional bool add_batch_dim_to_tensors = 3 [default = true];
+
+  // These pairs represent feed and fetch tensors for handling recurrent state.
+  // Each entry is a colon separated pair of strings. The first half of each
+  // string is the signature tag for the feed tensor for recurrent state. The
+  // second half of the string is the signature tag for the fetch tensor for the
+  // recurrent state. More than two colon separated strings is an error. During
+  // inference, the fetch tensor is fetched at every timestep and will be output
+  // if there is a corresponding output stream. The behavior of the feed tensor
+  // is determined by the following conditions in order: If the MediaPipe input
+  // stream with the matching tag has a packet available, then the input
+  // packet's tensor is passed in. If no input packet is available and we have
+  // fetched a tensor from the previous time step, we will feed the tensor from
+  // the previous timestep back in. If neither tensor is available, no tensor
+  // will be fed into the model.
+  // If this flag is set, batch_size must be 1. Do not list recurrent_tag_pair
+  // tags as initial_state_tags because those are only fed once.
+  repeated string recurrent_tag_pair = 4;
+
+  // If set to true, skips input for which any of the features are missing.
+  // If set to false, requires that all input features be available. If not,
+  // it will report an error for the calculator.
+  optional bool skip_on_missing_features = 5 [default = false];
+
+  // Maximum allowed concurrent Tensorflow session run calls in the calculator
+  // to avoid overloading local compute hardware such as TPU. Note that this
+  // only works in the local process, not "globally" across multiple processes
+  // or replicas (if any). Defaults to 0, i.e. no limit.
+  optional int32 max_concurrent_session_runs = 6 [default = 0];
+}
diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc
new file mode 100644
index 000000000..f0b8ea5e1
--- /dev/null
+++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc
@@ -0,0 +1,512 @@
+// Copyright 2019 The MediaPipe Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" + +#ifdef __APPLE__ +#include +#endif // defined(__APPLE__) + +namespace mediapipe { + +namespace tf = ::tensorflow; + +namespace { +std::string GetGraphDefPath() { +#ifdef __APPLE__ + char path[1024]; + CFURLRef bundle_url = CFBundleCopyBundleURL(CFBundleGetMainBundle()); + CFURLGetFileSystemRepresentation( + bundle_url, true, reinterpret_cast(path), sizeof(path)); + CFRelease(bundle_url); + return ::mediapipe::file::JoinPath(path, "testdata/frozen_graph_def.pb"); +#elif defined(__ANDROID__) + char path[1024]; + getcwd(path, sizeof(path)); + return ::mediapipe::file::JoinPath(path, + 
"mediapipe/calculators/tensorflow/" + "testdata/frozen_graph_def.pb"); +#else + return ::mediapipe::file::JoinPath( + "./", + // This should match the path of the output files + // of the genrule() that generates test model files. + "mediapipe/calculators/tensorflow/testdata/", "frozen_graph_def.pb"); +#endif // defined(__APPLE__) +} +} // namespace + +class TensorflowInferenceCalculatorTest : public ::testing::Test { + protected: + // Add the input side packet. + void AddSessionInputSidePacket() { + PacketGeneratorOptions extendable_options; + TensorFlowSessionFromFrozenGraphGeneratorOptions* generator_options; + generator_options = extendable_options.MutableExtension( + TensorFlowSessionFromFrozenGraphGeneratorOptions::ext); + generator_options->set_graph_proto_path(GetGraphDefPath()); + (*generator_options->mutable_tag_to_tensor_names())["MULTIPLIED"] = + "multiplied:0"; + (*generator_options->mutable_tag_to_tensor_names())["A"] = "a:0"; + (*generator_options->mutable_tag_to_tensor_names())["B"] = "b:0"; + (*generator_options->mutable_tag_to_tensor_names())["EXPENSIVE"] = + "expensive:0"; + + PacketSet input_side_packets({}); + PacketSet output_side_packets({"SESSION"}); + MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromFrozenGraphGenerator", extendable_options, + input_side_packets, &output_side_packets)); + runner_->MutableSidePackets()->Tag("SESSION") = + output_side_packets.Tag("SESSION"); + } + + // Create tensor from Vector and add as a Packet to the provided tag as input. 
+ void AddVectorToInputsAsTensor(const std::vector& input, + const std::string& tag, int64 time) { + tf::TensorShape tensor_shape; + tensor_shape.AddDim(input.size()); + auto tensor = absl::make_unique(tf::DT_INT32, tensor_shape); + for (int i = 0; i < input.size(); ++i) { + tensor->vec()(i) = input[i]; + } + runner_->MutableInputs()->Tag(tag).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + } + + std::unique_ptr runner_; +}; + +TEST_F(TensorflowInferenceCalculatorTest, GetConstants) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_in"); + config.add_output_stream("B:tensor_out"); + config.add_output_stream("MULTIPLIED:tensor_multiplied"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(1); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(false); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({0, 0, 0}, "A", 0); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_b = + runner_->Outputs().Tag("B").packets; + ASSERT_EQ(output_packets_b.size(), 1); + const tf::Tensor& tensor_b = output_packets_b[0].Get(); + tf::TensorShape expected_shape({1, 3}); + auto expected_tensor = tf::test::AsTensor({3, 2, 1}, expected_shape); + tf::test::ExpectTensorEqual(expected_tensor, tensor_b); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(1, output_packets_mult.size()); + const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); + expected_tensor = tf::test::AsTensor({0, 0, 0}, expected_shape); + tf::test::ExpectTensorEqual(expected_tensor, tensor_mult); + + EXPECT_EQ(1, runner_ + ->GetCounter( + 
"TensorFlowInferenceCalculator-TotalProcessedTimestamps") + ->Get()); +} + +TEST_F(TensorflowInferenceCalculatorTest, GetComputed) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(1); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(false); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({2, 2, 2}, "A", 0); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 0); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(1, output_packets_mult.size()); + const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); + tf::TensorShape expected_shape({3}); + auto expected_tensor = tf::test::AsTensor({6, 8, 10}, expected_shape); + tf::test::ExpectTensorEqual(expected_tensor, tensor_mult); + + // Add only one of the two expected tensors at the next timestamp, expect + // useful failure message. + AddVectorToInputsAsTensor({1, 2, 3}, "A", 1); + auto run_status = runner_->Run(); + ASSERT_FALSE(run_status.ok()); + EXPECT_THAT(run_status.ToString(), + testing::HasSubstr("TensorFlowInferenceCalculator")); + EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B")); +} + +TEST_F(TensorflowInferenceCalculatorTest, BadTag) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("BAD:tensor_in"); // This one is bad. 
+ config.add_output_stream("B:tensor_out"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(1); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + ASSERT_FALSE(runner_->Run().ok()); +} + +TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(1); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({2, 2, 2}, "A", 0); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 0); + AddVectorToInputsAsTensor({3, 3, 3}, "A", 1); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 1); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(2, output_packets_mult.size()); + const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); + auto expected_tensor = tf::test::AsTensor({6, 8, 10}); + tf::test::ExpectTensorEqual(tensor_mult, expected_tensor); + const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get(); + auto expected_tensor1 = tf::test::AsTensor({9, 12, 15}); + tf::test::ExpectTensorEqual(tensor_mult1, expected_tensor1); + + EXPECT_EQ(2, runner_ + ->GetCounter( + "TensorFlowInferenceCalculator-TotalProcessedTimestamps") + ->Get()); +} + +TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + 
config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(2); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(true); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({2, 2, 2}, "A", 0); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 0); + AddVectorToInputsAsTensor({3, 3, 3}, "A", 1); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 1); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(2, output_packets_mult.size()); + const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); + auto expected_tensor = tf::test::AsTensor({6, 8, 10}); + tf::test::ExpectTensorEqual(tensor_mult, expected_tensor); + const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get(); + auto expected_tensor1 = tf::test::AsTensor({9, 12, 15}); + tf::test::ExpectTensorEqual(tensor_mult1, expected_tensor1); + + EXPECT_EQ(2, runner_ + ->GetCounter( + "TensorFlowInferenceCalculator-TotalProcessedTimestamps") + ->Get()); +} + +TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(3); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(true); + *config.mutable_options() = 
options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({2, 2, 2}, "A", 0); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 0); + AddVectorToInputsAsTensor({3, 3, 3}, "A", 1); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 1); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(2, output_packets_mult.size()); + const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); + auto expected_tensor = tf::test::AsTensor({6, 8, 10}); + tf::test::ExpectTensorEqual(tensor_mult, expected_tensor); + const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get(); + auto expected_tensor1 = tf::test::AsTensor({9, 12, 15}); + tf::test::ExpectTensorEqual(tensor_mult1, expected_tensor1); + + EXPECT_EQ(2, runner_ + ->GetCounter( + "TensorFlowInferenceCalculator-TotalProcessedTimestamps") + ->Get()); +} + +TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(1); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(true); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->add_recurrent_tag_pair("A:MULTIPLIED"); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 0); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 1); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(2, 
output_packets_mult.size()); + const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); + LOG(INFO) << "timestamp: " << 0; + auto expected_tensor = tf::test::AsTensor({3, 8, 15}); + tf::test::ExpectTensorEqual(tensor_mult, expected_tensor); + const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get(); + auto expected_tensor1 = tf::test::AsTensor({9, 32, 75}); + LOG(INFO) << "timestamp: " << 1; + tf::test::ExpectTensorEqual(tensor_mult1, expected_tensor1); + + EXPECT_EQ(2, runner_ + ->GetCounter( + "TensorFlowInferenceCalculator-TotalProcessedTimestamps") + ->Get()); +} +TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(1); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(true); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->add_recurrent_tag_pair("A:MULTIPLIED"); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({1, 1, 1}, "A", 0); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 0); + AddVectorToInputsAsTensor({1, 1, 1}, "A", 1); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 1); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(2, output_packets_mult.size()); + const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); + LOG(INFO) << "timestamp: " << 0; + auto expected_tensor = tf::test::AsTensor({3, 4, 5}); + tf::test::ExpectTensorEqual(tensor_mult, expected_tensor); + const tf::Tensor& 
tensor_mult1 = output_packets_mult[1].Get(); + auto expected_tensor1 = tf::test::AsTensor({3, 4, 5}); + LOG(INFO) << "timestamp: " << 1; + tf::test::ExpectTensorEqual(tensor_mult1, expected_tensor1); + + EXPECT_EQ(2, runner_ + ->GetCounter( + "TensorFlowInferenceCalculator-TotalProcessedTimestamps") + ->Get()); +} + +// TODO: Investigate this test failure. +TEST_F(TensorflowInferenceCalculatorTest, DISABLED_CheckTiming) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_in"); + config.add_output_stream("EXPENSIVE:tensor_expensive"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(1); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(false); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({0, 0, 0}, "A", 0); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + EXPECT_EQ(1, runner_ + ->GetCounter( + "TensorFlowInferenceCalculator-TotalProcessedTimestamps") + ->Get()); + // We only test the timing counter here because we are requesting an + // expensive tensor output. Because the precision on android is + // sometimes closer to milliseconds, we need to request a large tensor + // to be sure this will be greater than zero. 
+ EXPECT_GT(runner_->GetCounter("TensorFlowInferenceCalculator-TotalTimeUsecs") + ->Get(), + 0); +} + +TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(2); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(true); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_skip_on_missing_features(false); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({2, 2, 2}, "A", 0); + ASSERT_FALSE(runner_->Run().ok()); +} + +TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature_Skip) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(2); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(true); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_skip_on_missing_features(true); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({2, 2, 2}, "A", 0); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(0, 
output_packets_mult.size()); +} + +TEST_F(TensorflowInferenceCalculatorTest, + MissingInputFeature_SkipCheckInternalState) { + CalculatorGraphConfig::Node config; + config.set_calculator("TensorFlowInferenceCalculator"); + config.add_input_stream("A:tensor_a"); + config.add_input_stream("B:tensor_b"); + config.add_output_stream("MULTIPLIED:tensor_o1"); + config.add_input_side_packet("SESSION:session"); + CalculatorOptions options; + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_batch_size(2); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_add_batch_dim_to_tensors(true); + options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext) + ->set_skip_on_missing_features(true); + *config.mutable_options() = options; + + runner_ = absl::make_unique(config); + AddSessionInputSidePacket(); + AddVectorToInputsAsTensor({2, 2, 2}, "A", 0); + AddVectorToInputsAsTensor({3, 3, 3}, "A", 1); + AddVectorToInputsAsTensor({3, 4, 5}, "B", 1); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets_mult = + runner_->Outputs().Tag("MULTIPLIED").packets; + ASSERT_EQ(1, output_packets_mult.size()); + const tf::Tensor& tensor_mult = output_packets_mult[0].Get(); + auto expected_tensor = tf::test::AsTensor({9, 12, 15}); + tf::test::ExpectTensorEqual(tensor_mult, expected_tensor); + + EXPECT_EQ(1, runner_ + ->GetCounter( + "TensorFlowInferenceCalculator-TotalProcessedTimestamps") + ->Get()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session.h b/mediapipe/calculators/tensorflow/tensorflow_session.h new file mode 100644 index 000000000..250e2a572 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session.h @@ -0,0 +1,35 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_TENSORFLOW_CALCULATORS_TENSORFLOW_SESSION_H_ +#define MEDIAPIPE_TENSORFLOW_CALCULATORS_TENSORFLOW_SESSION_H_ + +#include + +#include "tensorflow/core/public/session.h" + +namespace mediapipe { +struct TensorFlowSession { + // TensorFlow session wrapper to get around the RTTI issue. + std::unique_ptr session; + + // Store an optional mapping to the between MediaPipe tags and TensorFlow + // tensor names. Creating this mapping when the session is loaded allows more + // flexible definition of mapping tags to tensors across platforms. + std::map tag_to_tensor_map; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_TENSORFLOW_CALCULATORS_TENSORFLOW_SESSION_H_ diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc new file mode 100644 index 000000000..d0b1c8ffd --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc @@ -0,0 +1,131 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// Reads serialized GraphDef proto. There are three ways to load a model: +// 1. Specify the path to a graph.pb in the calculator options. +// 2. Specify the path to the graph.pb through the +// input_side_packet:STRING_MODEL_FILE_PATH +// 3. Provide a serialized GraphDef through input_side_packet:STRING_MODEL, +// typically provided by EmbeddingFilePacketFactory. +// +// See tensorflow_session_bundle_from_graph_generator.proto for options. +// Produces a SessionBundle that TensorFlowInferenceCalculator can use. + +#include + +#include "mediapipe/calculators/tensorflow/tensorflow_session.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/status_util.h" +#include "tensorflow/core/public/session_options.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; + +class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, + PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { + RET_CHECK(extendable_options.HasExtension( + TensorFlowSessionFromFrozenGraphGeneratorOptions::ext)); + const auto& options = extendable_options.GetExtension( // NOLINT + TensorFlowSessionFromFrozenGraphGeneratorOptions::ext); + bool has_exactly_one_model = + !options.graph_proto_path().empty() + ? 
!(input_side_packets->HasTag("STRING_MODEL") | + input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) + : (input_side_packets->HasTag("STRING_MODEL") ^ + input_side_packets->HasTag("STRING_MODEL_FILE_PATH")); + RET_CHECK(has_exactly_one_model) + << "Must have exactly one of graph_proto_path in options or " + "input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH"; + if (input_side_packets->HasTag("STRING_MODEL")) { + input_side_packets->Tag("STRING_MODEL") + .Set( + // String model from embedded path + ); + } else if (input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) { + input_side_packets->Tag("STRING_MODEL_FILE_PATH") + .Set( + // Filename of std::string model. + ); + } + output_side_packets->Tag("SESSION").Set( + // A TensorFlow model loaded and ready for use along with + // a map from tags to tensor names. + ); + RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& packet_generator_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + const TensorFlowSessionFromFrozenGraphGeneratorOptions& options = + packet_generator_options.GetExtension( + TensorFlowSessionFromFrozenGraphGeneratorOptions::ext); + // Output bundle packet. 
+ auto session = ::absl::make_unique(); + + tf::SessionOptions session_options; + session_options.config.CopyFrom(options.config()); + std::vector initialization_op_names; + initialization_op_names.reserve(options.initialization_op_names_size()); + for (int i = 0; i < options.initialization_op_names_size(); ++i) { + initialization_op_names.emplace_back(options.initialization_op_names(i)); + } + session->session.reset(tf::NewSession(session_options)); + + std::string graph_def_serialized; + if (input_side_packets.HasTag("STRING_MODEL")) { + graph_def_serialized = + input_side_packets.Tag("STRING_MODEL").Get(); + } else if (input_side_packets.HasTag("STRING_MODEL_FILE_PATH")) { + const std::string& frozen_graph = + input_side_packets.Tag("STRING_MODEL_FILE_PATH").Get(); + RET_CHECK_OK( + mediapipe::file::GetContents(frozen_graph, &graph_def_serialized)); + } else { + RET_CHECK_OK(mediapipe::file::GetContents(options.graph_proto_path(), + &graph_def_serialized)); + } + tensorflow::GraphDef graph_def; + + RET_CHECK(graph_def.ParseFromString(graph_def_serialized)); + const tf::Status tf_status = session->session->Create(graph_def); + RET_CHECK(tf_status.ok()) << "Create failed: " << tf_status.error_message(); + + for (const auto& key_value : options.tag_to_tensor_names()) { + session->tag_to_tensor_map[key_value.first] = key_value.second; + } + if (!initialization_op_names.empty()) { + const tf::Status tf_status = + session->session->Run({}, {}, initialization_op_names, {}); + // RET_CHECK on the tf::Status object itself in order to print an + // informative error message. 
+ RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.error_message(); + } + + output_side_packets->Tag("SESSION") = Adopt(session.release()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_PACKET_GENERATOR(TensorFlowSessionFromFrozenGraphGenerator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.proto new file mode 100644 index 000000000..183b5a5a5 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.proto @@ -0,0 +1,72 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/packet_generator.proto"; +import "tensorflow/core/protobuf/config.proto"; + +message TensorFlowSessionFromFrozenGraphGeneratorOptions { + extend mediapipe.PacketGeneratorOptions { + optional TensorFlowSessionFromFrozenGraphGeneratorOptions ext = 160666123; + } + + // Path to file containing serialized proto of type tensorflow::GraphDef. + optional string graph_proto_path = 1; + + // To run inference with MediaPipe inputs MediaPipe streams need to be mapped + // to TensorFlow tensors. This map defines the which streams are fed into + // which tensors in the model. The MediaPipe tag of the stream is the map key. 
+ // Tags must be capitalized, matching regex [A-Z0-9_]+. Examples: "JPG_STRING" + // and "SOFTMAX". Then, those tags can be used as the MediaPipe tags of + // input_stream or output_stream of the TensorflowInferenceCalculator + // consuming the packet produced by this generator. The tensor names must + // match the tensor names in the graph that you want to feed or fetch into or + // out of. Examples: "DecodeJpeg/contents:0" or "softmax:0". For example, a + // mediapipe graph can include the nodes: + // + // packet_generator { + // packet_generator: "TensorFlowSessionFromFrozenGraphGenerator" + // output_side_packet: "SESSION:session" + // options { + // [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: { + // graph_proto_path: "[PATH]" + // tag_to_tensor_names { + // key: "JPG_STRING" + // value: "input:0" + // } + // tag_to_tensor_names { + // key: "SOFTMAX" + // value: "softmax:0" + // } + // } + // } + // } + // node { + // calculator: "TensorflowInferenceCalculator" + // input_side_packet: "SESSION:graph_with_bindings" + // input_stream: "JPG_STRING:jpg_string_tensor" + // output_stream: "SOFTMAX:softmax_tensor" + // } + map tag_to_tensor_names = 2; + + // Tensorflow session config options. + optional tensorflow.ConfigProto config = 3; + + // Graph nodes to run to initialize the model. Any output of these ops is + // ignored. + repeated string initialization_op_names = 4; +} diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc new file mode 100644 index 000000000..ca9cc8141 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc @@ -0,0 +1,288 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/tag_map_helper.h" +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +std::string GetGraphDefPath() { + return mediapipe::file::JoinPath("./", + "mediapipe/calculators/tensorflow/" + "testdata/frozen_graph_def.pb"); +} + +// Helper function that creates Tensor INT32 matrix with size 1x3. 
+tf::Tensor TensorMatrix1x3(const int v1, const int v2, const int v3) { + tf::Tensor tensor(tf::DT_INT32, + tf::TensorShape(std::vector({1, 3}))); + auto matrix = tensor.matrix(); + matrix(0, 0) = v1; + matrix(0, 1) = v2; + matrix(0, 2) = v3; + return tensor; +} + +class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test { + protected: + void SetUp() override { + extendable_options_.Clear(); + generator_options_ = extendable_options_.MutableExtension( + TensorFlowSessionFromFrozenGraphGeneratorOptions::ext); + generator_options_->set_graph_proto_path(GetGraphDefPath()); + (*generator_options_->mutable_tag_to_tensor_names())["MULTIPLIED"] = + "multiplied:0"; + (*generator_options_->mutable_tag_to_tensor_names())["A"] = "a:0"; + (*generator_options_->mutable_tag_to_tensor_names())["B"] = "b:0"; + generator_options_->mutable_config()->set_intra_op_parallelism_threads(1); + generator_options_->mutable_config()->set_inter_op_parallelism_threads(2); + } + + void VerifySignatureMap(PacketSet* output_side_packets) { + const TensorFlowSession& session = + output_side_packets->Tag("SESSION").Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); + + // Bindings are inserted. + EXPECT_EQ(session.tag_to_tensor_map.size(), 3); + + // For some reason, EXPECT_EQ and EXPECT_NE are not working with iterators. + EXPECT_FALSE(session.tag_to_tensor_map.find("A") == + session.tag_to_tensor_map.end()); + EXPECT_FALSE(session.tag_to_tensor_map.find("B") == + session.tag_to_tensor_map.end()); + EXPECT_FALSE(session.tag_to_tensor_map.find("MULTIPLIED") == + session.tag_to_tensor_map.end()); + // Sanity: find() actually returns a reference to end() if element not + // found. 
+ EXPECT_TRUE(session.tag_to_tensor_map.find("Z") == + session.tag_to_tensor_map.end()); + + EXPECT_EQ(session.tag_to_tensor_map.at("A"), "a:0"); + EXPECT_EQ(session.tag_to_tensor_map.at("B"), "b:0"); + EXPECT_EQ(session.tag_to_tensor_map.at("MULTIPLIED"), "multiplied:0"); + } + + PacketGeneratorOptions extendable_options_; + TensorFlowSessionFromFrozenGraphGeneratorOptions* generator_options_; +}; + +TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, + CreatesPacketWithGraphAndBindings) { + PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MEDIAPIPE_EXPECT_OK(run_status) << run_status.message(); + VerifySignatureMap(&output_side_packets); +} + +// Integration test. Verifies that TensorFlowInferenceCalculator correctly +// consumes the Packet emitted by this generator. 
+TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, + ProducesPacketUsableByTensorFlowInferenceCalculator) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie( + absl::Substitute(R"( + node { + calculator: "TensorFlowInferenceCalculator" + input_side_packet: "SESSION:tf_model" + input_stream: "A:a_tensor" + output_stream: "MULTIPLIED:multiplied_tensor" + options { + [mediapipe.TensorFlowInferenceCalculatorOptions.ext] { + batch_size: 5 + add_batch_dim_to_tensors: false + } + } + } + + packet_generator { + packet_generator: "TensorFlowSessionFromFrozenGraphGenerator" + output_side_packet: "SESSION:tf_model" + options { + [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: { + $0 + } + } + } + input_stream: "a_tensor" + )", + generator_options_->DebugString())); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + StatusOrPoller status_or_poller = + graph.AddOutputStreamPoller("multiplied_tensor"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "a_tensor", + Adopt(new auto(TensorMatrix1x3(1, -1, 10))).At(Timestamp(0)))); + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("a_tensor")); + + Packet packet; + ASSERT_TRUE(poller.Next(&packet)); + // input tensor gets multiplied by [[3, 2, 1]]. 
Expected output: + tf::Tensor expected_multiplication = TensorMatrix1x3(3, -2, 10); + EXPECT_EQ(expected_multiplication.DebugString(), + packet.Get().DebugString()); + + ASSERT_FALSE(poller.Next(&packet)); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, + CreatesPacketWithGraphAndBindingsFromInputSidePacket) { + PacketSet input_side_packets( + tool::CreateTagMap({"STRING_MODEL:model"}).ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + std::string serialized_graph_contents; + MEDIAPIPE_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), + &serialized_graph_contents)); + generator_options_->clear_graph_proto_path(); + input_side_packets.Tag("STRING_MODEL") = + Adopt(new std::string(serialized_graph_contents)); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MEDIAPIPE_EXPECT_OK(run_status) << run_status.message(); + VerifySignatureMap(&output_side_packets); +} + +TEST_F( + TensorFlowSessionFromFrozenGraphGeneratorTest, + CreatesPacketWithGraphAndBindingsFromInputSidePacketStringModelFilePath) { + PacketSet input_side_packets( + tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + generator_options_->clear_graph_proto_path(); + input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + Adopt(new std::string(GetGraphDefPath())); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MEDIAPIPE_EXPECT_OK(run_status) << run_status.message(); + VerifySignatureMap(&output_side_packets); +} + +TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, + 
CheckFailureForOptionsAndInputsProvideGraphDefProto) { + PacketSet input_side_packets( + tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + Adopt(new std::string(GetGraphDefPath())); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, + input_side_packets, &output_side_packets); + EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal); + EXPECT_THAT( + run_status.message(), + ::testing::HasSubstr("Must have exactly one of graph_proto_path")); +} + +TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, + CheckFailureForAllInputsProvideGraphDefProto) { + PacketSet input_side_packets( + tool::CreateTagMap( + {"STRING_MODEL_FILE_PATH:model_path", "STRING_MODEL:model"}) + .ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + std::string serialized_graph_contents; + MEDIAPIPE_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), + &serialized_graph_contents)); + input_side_packets.Tag("STRING_MODEL") = + Adopt(new std::string(serialized_graph_contents)); + input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + Adopt(new std::string(GetGraphDefPath())); + + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, + input_side_packets, &output_side_packets); + EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal); + EXPECT_THAT( + run_status.message(), + ::testing::HasSubstr("Must have exactly one of graph_proto_path")); +} + +TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, + CheckFailureForOnlyBothInputSidePacketsProvideGraphDefProto) { + PacketSet input_side_packets( + tool::CreateTagMap( + {"STRING_MODEL_FILE_PATH:model_path", "STRING_MODEL:model"}) + 
.ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + std::string serialized_graph_contents; + EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), + &serialized_graph_contents)); + input_side_packets.Tag("STRING_MODEL") = + Adopt(new std::string(serialized_graph_contents)); + input_side_packets.Tag("STRING_MODEL_FILE_PATH") = + Adopt(new std::string(GetGraphDefPath())); + generator_options_->clear_graph_proto_path(); + + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, + input_side_packets, &output_side_packets); + EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal); + EXPECT_THAT( + run_status.message(), + ::testing::HasSubstr("Must have exactly one of graph_proto_path")); +} + +TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, + CheckInitializationOpName) { + PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + generator_options_->add_initialization_op_names("multiplied:0"); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MEDIAPIPE_EXPECT_OK(run_status); + VerifySignatureMap(&output_side_packets); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc new file mode 100644 index 000000000..525eb4237 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -0,0 +1,176 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#if defined(MEDIAPIPE_TPU_SUPPORT) +#include "learning/brain/google/xla/global_tpu_init.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" +#endif +#if !defined(__ANDROID__) +#include "mediapipe/framework/port/file_helpers.h" +#endif +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/cc/saved_model/tag_constants.h" + +namespace mediapipe { + +namespace { +static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; + +// Given the path to a directory containing multiple tensorflow saved models +// in subdirectories, replaces path with the alphabetically last subdirectory. 
+::mediapipe::Status GetLatestDirectory(std::string* path) { +#if defined(__ANDROID__) + return ::mediapipe::UnimplementedError( + "GetLatestDirectory is not implemented on Android"); +#else + std::vector saved_models; + RET_CHECK_OK(file::MatchInTopSubdirectories( + *path, tensorflow::kSavedModelFilenamePb, &saved_models)); + RET_CHECK_GT(saved_models.size(), 0) + << "No exported bundles found in " << path; + ::std::sort(saved_models.begin(), saved_models.end()); + *path = std::string(file::Dirname(saved_models.back())); + return ::mediapipe::OkStatus(); +#endif +} + +// If options.convert_signature_to_tags() will convert letters to uppercase +// and replace /'s with _'s. If set, this enables the standard SavedModel +// classification, regression, and prediction signatures to be used as +// uppercase INPUTS and OUTPUTS tags for streams. +const std::string MaybeConvertSignatureToTag( + const std::string& name, + const TensorFlowSessionFromSavedModelCalculatorOptions& options) { + if (options.convert_signature_to_tags()) { + std::string output; + output.resize(name.length()); + std::transform(name.begin(), name.end(), output.begin(), + [](unsigned char c) { return std::toupper(c); }); + output = absl::Substitute(output, "/", "_"); + return output; + } else { + return name; + } +} + +} // namespace + +// TensorFlowSessionFromSavedModelCalculator is a MediaPipe packet calculator +// that loads a trained TensorFlow model exported via SavedModel's exporter (see +// go/savedmodel) and returns a Packet containing a unique_ptr to a +// mediapipe::TensorFlowSession, which in turn contains a TensorFlow Session +// ready for execution and a map between tags and tensor names. 
+// +// Example usage: +// node { +// calculator: "TensorFlowSessionFromSavedModelCalculator" +// output_side_packet: "SESSION:vod_session" +// options { +// [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: { +// signature_name: "serving_default" +// saved_model_path: "path/to/model" +// } +// } +// } +class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + const auto& options = + cc->Options(); + const bool has_exactly_one_model = + options.saved_model_path().empty() == + cc->InputSidePackets().HasTag(kStringSavedModelPath); + RET_CHECK(has_exactly_one_model) + << "Must have exactly one of saved model filepath in options or " + "input_side_packets STRING_MODEL_FILE_PATH"; + // Path of savedmodel. + if (cc->InputSidePackets().HasTag(kStringSavedModelPath)) { + cc->InputSidePackets().Tag(kStringSavedModelPath).Set(); + } + // A TensorFlow model loaded and ready for use along with tensor + cc->OutputSidePackets().Tag("SESSION").Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + const auto& options = + cc->Options(); + std::string path = cc->InputSidePackets().HasTag(kStringSavedModelPath) + ? cc->InputSidePackets() + .Tag(kStringSavedModelPath) + .Get() + : options.saved_model_path(); + if (options.load_latest_model()) { + RET_CHECK_OK(GetLatestDirectory(&path)); + } + + // Set user specified tags properly. + // If no tags specified will use tensorflow::kSavedModelTagServe by default. + std::unordered_set tags_set; + for (std::string tag : options.saved_model_tag()) { + tags_set.insert(tag); + } + if (tags_set.empty()) { + tags_set.insert(tensorflow::kSavedModelTagServe); + } + + tensorflow::RunOptions run_options; + // In the future, could construct session options from the options proto. 
+ tensorflow::SessionOptions session_options; + auto saved_model = absl::make_unique(); + ::tensorflow::Status status = tensorflow::LoadSavedModel( + session_options, run_options, path, tags_set, saved_model.get()); + if (!status.ok()) { + return ::mediapipe::Status( + static_cast<::mediapipe::StatusCode>(status.code()), + status.error_message()); + } + + auto session = absl::make_unique(); + session->session = std::move(saved_model->session); + + RET_CHECK(!options.signature_name().empty()); + const auto& signature_def_map = saved_model->meta_graph_def.signature_def(); + const auto& signature_def = signature_def_map.at(options.signature_name()); + for (const auto& input_signature : signature_def.inputs()) { + session->tag_to_tensor_map[MaybeConvertSignatureToTag( + input_signature.first, options)] = input_signature.second.name(); + } + for (const auto& output_signature : signature_def.outputs()) { + session->tag_to_tensor_map[MaybeConvertSignatureToTag( + output_signature.first, options)] = output_signature.second.name(); + } + + cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } +}; + +REGISTER_CALCULATOR(TensorFlowSessionFromSavedModelCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto new file mode 100644 index 000000000..da5edb91c --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto @@ -0,0 +1,58 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorFlowSessionFromSavedModelCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorFlowSessionFromSavedModelCalculatorOptions ext = 244429915; + } + // TODO: SessionBundles provided global step versioning of models + // that let you load the latest model. If there's a similar solution for + // SavedModels, include a flag to load the most recent model. + + // Path to a directory containing a trained TensorFlow model as prepared + // by SavedModel (go/saved-model). + optional string saved_model_path = 1; + // The name of the generic signature to load into the mapping from tags to + // tensor names. + optional string signature_name = 2 [default = "serving_default"]; + // Whether to convert the signature keys to uppercase and switch /'s to + // _'s, which enables standard signatures to be used as Tags. + optional bool convert_signature_to_tags = 3 [default = true]; + // If true, saved_model_path can have multiple exported models in + // subdirectories saved_model_path/%08d and the alphabetically last (i.e., + // latest checkpoint) model is loaded. Note that saved models are not exported + // in numbered directories by default. If you want to use this feature, you + // need to arrange your directories by global_step or some other order when + // you save your models. 
+ optional bool load_latest_model = 4; + // [DEPRECATED] If true, this calculator will try to initialize local Tensor + // Processing Unit (TPU) hardware so that the Tensorflow session loaded from + // this saved model may benefit from TPU speedups. If you want to use this + // feature, you need to make sure that the calculator runs on a machine that + // has TPU hardware installed. The saved model should have correct device + // placements in the graph (have the ops already placed on TPU), typically if + // the saved model was exported through TPUEstimator then device placement is + // automatically taken care of. + optional bool use_tpu = 5 [deprecated = true]; + // User specified tags in a saved model. + // If no tag is specified, then use "serve" as the default. Note that in order + // to use TPU accelerator hardware, the tag "tpu" needs to be specified. + repeated string saved_model_tag = 6; +} diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc new file mode 100644 index 000000000..1a6902dc8 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc @@ -0,0 +1,208 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/tag_map_helper.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +std::string GetSavedModelDir() { + std::string out_path = + file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/", + "tensorflow_saved_model/00000000"); + return out_path; +} + +// Helper function that creates Tensor INT32 matrix with size 1x3. 
+tf::Tensor TensorMatrix1x3(const int v1, const int v2, const int v3) { + tf::Tensor tensor(tf::DT_INT32, + tf::TensorShape(std::vector({1, 3}))); + auto matrix = tensor.matrix(); + matrix(0, 0) = v1; + matrix(0, 1) = v2; + matrix(0, 2) = v3; + return tensor; +} + +class TensorFlowSessionFromSavedModelCalculatorTest : public ::testing::Test { + protected: + void SetUp() override { + extendable_options_.Clear(); + options_ = extendable_options_.MutableExtension( + TensorFlowSessionFromSavedModelCalculatorOptions::ext); + options_->set_saved_model_path(GetSavedModelDir()); + } + + CalculatorOptions extendable_options_; + TensorFlowSessionFromSavedModelCalculatorOptions* options_; +}; + +TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, + CreatesPacketWithGraphAndBindings) { + CalculatorRunner runner(absl::Substitute(R"( + calculator: "TensorFlowSessionFromSavedModelCalculator" + output_side_packet: "SESSION:tf_model" + options { + [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: { + $0 + } + })", + options_->DebugString())); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const TensorFlowSession& session = + runner.OutputSidePackets().Tag("SESSION").Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); + + // Bindings are inserted. + EXPECT_EQ(session.tag_to_tensor_map.size(), 4); + + // For some reason, EXPECT_EQ and EXPECT_NE are not working with iterators. + EXPECT_FALSE(session.tag_to_tensor_map.find("A") == + session.tag_to_tensor_map.end()); + EXPECT_FALSE(session.tag_to_tensor_map.find("B") == + session.tag_to_tensor_map.end()); + EXPECT_FALSE(session.tag_to_tensor_map.find("MULTIPLIED") == + session.tag_to_tensor_map.end()); + EXPECT_FALSE(session.tag_to_tensor_map.find("EXPENSIVE") == + session.tag_to_tensor_map.end()); + // Sanity: find() actually returns a reference to end() if element not + // found. 
+ EXPECT_TRUE(session.tag_to_tensor_map.find("Z") == + session.tag_to_tensor_map.end()); + + EXPECT_EQ(session.tag_to_tensor_map.at("A"), "a:0"); + EXPECT_EQ(session.tag_to_tensor_map.at("B"), "b:0"); + EXPECT_EQ(session.tag_to_tensor_map.at("MULTIPLIED"), "multiplied:0"); + EXPECT_EQ(session.tag_to_tensor_map.at("EXPENSIVE"), "expensive:0"); +} + +TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, + CreateSessionFromSidePacket) { + options_->clear_saved_model_path(); + CalculatorRunner runner(absl::Substitute(R"( + calculator: "TensorFlowSessionFromSavedModelCalculator" + input_side_packet: "STRING_SAVED_MODEL_PATH:saved_model_dir" + output_side_packet: "SESSION:tf_model" + options { + [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: { + $0 + } + })", + options_->DebugString())); + runner.MutableSidePackets()->Tag("STRING_SAVED_MODEL_PATH") = + MakePacket(GetSavedModelDir()); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const TensorFlowSession& session = + runner.OutputSidePackets().Tag("SESSION").Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); +} + +// Integration test. Verifies that TensorFlowInferenceCalculator correctly +// consumes the Packet emitted by this factory. 
+TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, + ProducesPacketUsableByTensorFlowInferenceCalculator) { + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + absl::Substitute(R"( + node { + calculator: "TensorFlowInferenceCalculator" + input_side_packet: "SESSION:tf_model" + input_stream: "A:a_tensor" + output_stream: "MULTIPLIED:multiplied_tensor" + options { + [mediapipe.TensorFlowInferenceCalculatorOptions.ext] { + batch_size: 5 + add_batch_dim_to_tensors: false + } + } + } + node { + calculator: "TensorFlowSessionFromSavedModelCalculator" + output_side_packet: "SESSION:tf_model" + options { + [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: { + $0 + } + } + } + input_stream: "a_tensor" + )", + options_->DebugString())); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config)); + StatusOrPoller status_or_poller = + graph.AddOutputStreamPoller("multiplied_tensor"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "a_tensor", + Adopt(new auto(TensorMatrix1x3(1, -1, 10))).At(Timestamp(0)))); + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("a_tensor")); + + Packet packet; + ASSERT_TRUE(poller.Next(&packet)); + // input tensor gets multiplied by [[3, 2, 1]]. 
Expected output: + tf::Tensor expected_multiplication = TensorMatrix1x3(3, -2, 10); + EXPECT_EQ(expected_multiplication.DebugString(), + packet.Get().DebugString()); + + ASSERT_FALSE(poller.Next(&packet)); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, + GetsBundleGivenParentDirectory) { + options_->set_saved_model_path( + std::string(file::SplitPath(GetSavedModelDir()).first)); + options_->set_load_latest_model(true); + + CalculatorRunner runner(absl::Substitute(R"( + calculator: "TensorFlowSessionFromSavedModelCalculator" + output_side_packet: "SESSION:tf_model" + options { + [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: { + $0 + } + })", + options_->DebugString())); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const TensorFlowSession& session = + runner.OutputSidePackets().Tag("SESSION").Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc new file mode 100644 index 000000000..d9959c5b7 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -0,0 +1,166 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#if defined(MEDIAPIPE_TPU_SUPPORT) +#include "learning/brain/google/xla/global_tpu_init.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" +#endif +#if !defined(__ANDROID__) +#include "mediapipe/framework/port/file_helpers.h" +#endif +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/packet_generator.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/status_util.h" +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/cc/saved_model/tag_constants.h" + +namespace mediapipe { + +namespace { +static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; + +// Given the path to a directory containing multiple tensorflow saved models +// in subdirectories, replaces path with the alphabetically last subdirectory. +::mediapipe::Status GetLatestDirectory(std::string* path) { +#if defined(__ANDROID__) + return ::mediapipe::UnimplementedError( + "GetLatestDirectory is not implemented on Android"); +#else + std::vector saved_models; + RET_CHECK_OK(file::MatchInTopSubdirectories( + *path, tensorflow::kSavedModelFilenamePb, &saved_models)); + RET_CHECK_GT(saved_models.size(), 0) + << "No exported bundles found in " << path; + ::std::sort(saved_models.begin(), saved_models.end()); + *path = std::string(file::Dirname(saved_models.back())); + return ::mediapipe::OkStatus(); +#endif +} + +// If options.convert_signature_to_tags() will convert letters to uppercase +// and replace /'s with _'s. 
If set, this enables the standard SavedModel +// classification, regression, and prediction signatures to be used as +// uppercase INPUTS and OUTPUTS tags for streams. +const std::string MaybeConvertSignatureToTag( + const std::string& name, + const TensorFlowSessionFromSavedModelGeneratorOptions& options) { + if (options.convert_signature_to_tags()) { + std::string output; + output.resize(name.length()); + std::transform(name.begin(), name.end(), output.begin(), + [](unsigned char c) { return std::toupper(c); }); + output = absl::Substitute(output, "/", "_"); + return output; + } else { + return name; + } +} + +} // namespace + +// TensorFlowSessionFromSavedModelGenerator is a MediaPipe packet generator +// that loads a trained TensorFlow model exported via SavedModel's exporter (see +// go/savedmodel) and returns a Packet containing a unique_ptr to a +// mediapipe::TensorFlowSession, which in turn contains a TensorFlow Session +// ready for execution and a map between tags and tensor names. +class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, + PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { + const TensorFlowSessionFromSavedModelGeneratorOptions& options = + extendable_options.GetExtension( + TensorFlowSessionFromSavedModelGeneratorOptions::ext); + const bool has_exactly_one_model = + options.saved_model_path().empty() == + input_side_packets->HasTag(kStringSavedModelPath); + RET_CHECK(has_exactly_one_model) + << "Must have exactly one of saved model filepath in options or " + "input_side_packets STRING_MODEL_FILE_PATH"; + // Path of savedmodel. 
+ if (input_side_packets->HasTag(kStringSavedModelPath)) { + input_side_packets->Tag(kStringSavedModelPath).Set(); + } + // A TensorFlow model loaded and ready for use along with tensor + output_side_packets->Tag("SESSION").Set(); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + const TensorFlowSessionFromSavedModelGeneratorOptions& options = + extendable_options.GetExtension( + TensorFlowSessionFromSavedModelGeneratorOptions::ext); + std::string path = + input_side_packets.HasTag(kStringSavedModelPath) + ? input_side_packets.Tag(kStringSavedModelPath).Get() + : options.saved_model_path(); + if (options.load_latest_model()) { + RET_CHECK_OK(GetLatestDirectory(&path)); + } + + // Set user specified tags properly. + // If no tags specified will use tensorflow::kSavedModelTagServe by default. + std::unordered_set tags_set; + for (std::string tag : options.saved_model_tag()) { + tags_set.insert(tag); + } + if (tags_set.empty()) { + tags_set.insert(tensorflow::kSavedModelTagServe); + } + + tensorflow::RunOptions run_options; + // In the future, could construct session options from the options proto. 
+ tensorflow::SessionOptions session_options; + auto saved_model = absl::make_unique(); + ::tensorflow::Status status = tensorflow::LoadSavedModel( + session_options, run_options, path, tags_set, saved_model.get()); + if (!status.ok()) { + return ::mediapipe::Status( + static_cast<::mediapipe::StatusCode>(status.code()), + status.error_message()); + } + + auto session = absl::make_unique(); + session->session = std::move(saved_model->session); + + RET_CHECK(!options.signature_name().empty()); + const auto& signature_def_map = saved_model->meta_graph_def.signature_def(); + const auto& signature_def = signature_def_map.at(options.signature_name()); + for (const auto& input_signature : signature_def.inputs()) { + session->tag_to_tensor_map[MaybeConvertSignatureToTag( + input_signature.first, options)] = input_signature.second.name(); + } + for (const auto& output_signature : signature_def.outputs()) { + session->tag_to_tensor_map[MaybeConvertSignatureToTag( + output_signature.first, options)] = output_signature.second.name(); + } + + output_side_packets->Tag("SESSION") = Adopt(session.release()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_PACKET_GENERATOR(TensorFlowSessionFromSavedModelGenerator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto new file mode 100644 index 000000000..7e33f7518 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto @@ -0,0 +1,58 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/packet_generator.proto"; + +message TensorFlowSessionFromSavedModelGeneratorOptions { + extend mediapipe.PacketGeneratorOptions { + optional TensorFlowSessionFromSavedModelGeneratorOptions ext = 151486368; + } + // TODO: SessionBundles provided global step versioning of models + // that let you load the latest model. If there's a similar solution for + // SavedModels, include a flag to load the most recent model. + + // Path to a directory containing a trained TensorFlow model as prepared + // by SavedModel (go/saved-model). + optional string saved_model_path = 1; + // The name of the generic signature to load into the mapping from tags to + // tensor names. + optional string signature_name = 2 [default = "serving_default"]; + // Whether to convert the signature keys to uppercase and switch /'s to + // _'s, which enables standard signatures to be used as Tags. + optional bool convert_signature_to_tags = 3 [default = true]; + // If true, saved_model_path can have multiple exported models in + // subdirectories saved_model_path/%08d and the alphabetically last (i.e., + // latest checkpoint) model is loaded. Note that saved models are not exported + // in numbered directories by default. If you want to use this feature, you + // need to arrange your directories by global_step or some other order when + // you save your models. 
+ optional bool load_latest_model = 4; + // [DEPRECATED] If true, this calculator will try to initialize local Tensor + // Processing Unit (TPU) hardware so that the Tensorflow session loaded from + // this saved model may benefit from TPU speedups. If you want to use this + // feature, you need to make sure that the calculator runs on a machine that + // has TPU hardware installed. The saved model should have correct device + // placements in the graph (have the ops already placed on TPU), typically if + // the saved model was exported through TPUEstimator then device placement is + // automatically taken care of. + optional bool use_tpu = 5 [deprecated = true]; + // User specified tags in a saved model. + // If no tag is specified, then use "serve" as the default. Note that in order + // to use TPU accelerator hardware, the tag "tpu" needs to be specified. + repeated string saved_model_tag = 6; +} diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc new file mode 100644 index 000000000..268158f86 --- /dev/null +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -0,0 +1,200 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session.h" +#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/tag_map_helper.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +std::string GetSavedModelDir() { + std::string out_path = + file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/", + "tensorflow_saved_model/00000000"); + return out_path; +} + +// Helper function that creates Tensor INT32 matrix with size 1x3. 
+tf::Tensor TensorMatrix1x3(const int v1, const int v2, const int v3) { + tf::Tensor tensor(tf::DT_INT32, + tf::TensorShape(std::vector({1, 3}))); + auto matrix = tensor.matrix(); + matrix(0, 0) = v1; + matrix(0, 1) = v2; + matrix(0, 2) = v3; + return tensor; +} + +class TensorFlowSessionFromSavedModelGeneratorTest : public ::testing::Test { + protected: + void SetUp() override { + extendable_options_.Clear(); + generator_options_ = extendable_options_.MutableExtension( + TensorFlowSessionFromSavedModelGeneratorOptions::ext); + generator_options_->set_saved_model_path(GetSavedModelDir()); + } + + PacketGeneratorOptions extendable_options_; + TensorFlowSessionFromSavedModelGeneratorOptions* generator_options_; +}; + +TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, + CreatesPacketWithGraphAndBindings) { + PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromSavedModelGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MEDIAPIPE_EXPECT_OK(run_status) << run_status.message(); + const TensorFlowSession& session = + output_side_packets.Tag("SESSION").Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); + + // Bindings are inserted. + EXPECT_EQ(session.tag_to_tensor_map.size(), 4); + + // For some reason, EXPECT_EQ and EXPECT_NE are not working with iterators. + EXPECT_FALSE(session.tag_to_tensor_map.find("A") == + session.tag_to_tensor_map.end()); + EXPECT_FALSE(session.tag_to_tensor_map.find("B") == + session.tag_to_tensor_map.end()); + EXPECT_FALSE(session.tag_to_tensor_map.find("MULTIPLIED") == + session.tag_to_tensor_map.end()); + EXPECT_FALSE(session.tag_to_tensor_map.find("EXPENSIVE") == + session.tag_to_tensor_map.end()); + // Sanity: find() actually returns a reference to end() if element not + // found. 
+ EXPECT_TRUE(session.tag_to_tensor_map.find("Z") == + session.tag_to_tensor_map.end()); + + EXPECT_EQ(session.tag_to_tensor_map.at("A"), "a:0"); + EXPECT_EQ(session.tag_to_tensor_map.at("B"), "b:0"); + EXPECT_EQ(session.tag_to_tensor_map.at("MULTIPLIED"), "multiplied:0"); + EXPECT_EQ(session.tag_to_tensor_map.at("EXPENSIVE"), "expensive:0"); +} + +TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, + CreateSessionFromSidePacket) { + generator_options_->clear_saved_model_path(); + PacketSet input_side_packets( + tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}) + .ValueOrDie()); + input_side_packets.Tag("STRING_SAVED_MODEL_PATH") = + Adopt(new std::string(GetSavedModelDir())); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromSavedModelGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MEDIAPIPE_EXPECT_OK(run_status) << run_status.message(); + const TensorFlowSession& session = + output_side_packets.Tag("SESSION").Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); +} + +// Integration test. Verifies that TensorFlowInferenceCalculator correctly +// consumes the Packet emitted by this factory. 
+TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, + ProducesPacketUsableByTensorFlowInferenceCalculator) { + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + absl::Substitute(R"( + node { + calculator: "TensorFlowInferenceCalculator" + input_side_packet: "SESSION:tf_model" + input_stream: "A:a_tensor" + output_stream: "MULTIPLIED:multiplied_tensor" + options { + [mediapipe.TensorFlowInferenceCalculatorOptions.ext] { + batch_size: 5 + add_batch_dim_to_tensors: false + } + } + } + + packet_generator { + packet_generator: "TensorFlowSessionFromSavedModelGenerator" + output_side_packet: "SESSION:tf_model" + options { + [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: { + $0 + } + } + } + input_stream: "a_tensor" + )", + generator_options_->DebugString())); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config)); + StatusOrPoller status_or_poller = + graph.AddOutputStreamPoller("multiplied_tensor"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "a_tensor", + Adopt(new auto(TensorMatrix1x3(1, -1, 10))).At(Timestamp(0)))); + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("a_tensor")); + + Packet packet; + ASSERT_TRUE(poller.Next(&packet)); + // input tensor gets multiplied by [[3, 2, 1]]. 
Expected output: + tf::Tensor expected_multiplication = TensorMatrix1x3(3, -2, 10); + EXPECT_EQ(expected_multiplication.DebugString(), + packet.Get().DebugString()); + + ASSERT_FALSE(poller.Next(&packet)); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, + GetsBundleGivenParentDirectory) { + generator_options_->set_saved_model_path( + std::string(file::SplitPath(GetSavedModelDir()).first)); + generator_options_->set_load_latest_model(true); + + PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + ::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromSavedModelGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MEDIAPIPE_EXPECT_OK(run_status) << run_status.message(); + const TensorFlowSession& session = + output_side_packets.Tag("SESSION").Get(); + // Session must be set. 
+ ASSERT_NE(session.session, nullptr); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/testdata/bundle/00000000/checkpoint b/mediapipe/calculators/tensorflow/testdata/bundle/00000000/checkpoint new file mode 100755 index 000000000..055afb948 --- /dev/null +++ b/mediapipe/calculators/tensorflow/testdata/bundle/00000000/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "export-?????-of-00001" +all_model_checkpoint_paths: "export-?????-of-00001" diff --git a/mediapipe/calculators/tensorflow/testdata/bundle/00000000/export-00000-of-00001 b/mediapipe/calculators/tensorflow/testdata/bundle/00000000/export-00000-of-00001 new file mode 100755 index 000000000..82e808c48 Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/bundle/00000000/export-00000-of-00001 differ diff --git a/mediapipe/calculators/tensorflow/testdata/bundle/00000000/export.meta b/mediapipe/calculators/tensorflow/testdata/bundle/00000000/export.meta new file mode 100755 index 000000000..1a5b318d4 Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/bundle/00000000/export.meta differ diff --git a/mediapipe/calculators/tensorflow/testdata/frozen_graph_def.pb b/mediapipe/calculators/tensorflow/testdata/frozen_graph_def.pb new file mode 100755 index 000000000..966790d52 Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/frozen_graph_def.pb differ diff --git a/mediapipe/calculators/tensorflow/testdata/model.chkpt-0 b/mediapipe/calculators/tensorflow/testdata/model.chkpt-0 new file mode 100755 index 000000000..82e808c48 Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/model.chkpt-0 differ diff --git a/mediapipe/calculators/tensorflow/testdata/model.chkpt-0.meta b/mediapipe/calculators/tensorflow/testdata/model.chkpt-0.meta new file mode 100755 index 000000000..c21a47f8e Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/model.chkpt-0.meta differ diff --git 
a/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/saved_model.pb b/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/saved_model.pb new file mode 100755 index 000000000..b66128177 Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/saved_model.pb differ diff --git a/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/variables/variables.data-00000-of-00001 b/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/variables/variables.data-00000-of-00001 new file mode 100755 index 000000000..b3be9e288 Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/variables/variables.data-00000-of-00001 differ diff --git a/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/variables/variables.index b/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/variables/variables.index new file mode 100755 index 000000000..a2e8d586f Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/tensorflow_saved_model/00000000/variables/variables.index differ diff --git a/mediapipe/calculators/tensorflow/testdata/tf_graph_def.pb b/mediapipe/calculators/tensorflow/testdata/tf_graph_def.pb new file mode 100755 index 000000000..bb046494a Binary files /dev/null and b/mediapipe/calculators/tensorflow/testdata/tf_graph_def.pb differ diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc new file mode 100644 index 000000000..196f91e2e --- /dev/null +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc @@ -0,0 +1,431 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/match.h" +#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h" +#include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/sequence/media_sequence.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" + +namespace mediapipe { + +// Streams: +const char kBBoxTag[] = "BBOX"; +const char kImageTag[] = "IMAGE"; +const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; +const char kForwardFlowImageTag[] = "FORWARD_FLOW_ENCODED"; + +// Side Packets: +const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; +const char kDatasetRootDirTag[] = "DATASET_ROOT"; +const char kDataPath[] = "DATA_PATH"; +const char kPacketResamplerOptions[] = "RESAMPLER_OPTIONS"; +const char kImagesFrameRateTag[] = "IMAGE_FRAME_RATE"; + +namespace tf = ::tensorflow; +namespace mpms = ::mediapipe::mediasequence; + +// Source calculator to unpack side_packets and streams from tf.SequenceExamples +// +// Often, only side_packets or streams need to be output, but both can be output +// if needed. A tf.SequenceExample always needs to be supplied as an +// input_side_packet. The SequenceExample must be in the format described in +// media_sequence.h. 
This documentation will first describe the side_packets +// the calculator can output, and then describe the streams. +// +// Side_packets are commonly used to specify which clip to extract data from. +// Seeking into a video does not necessarily provide consistent timestamps when +// resampling to a known rate. To enable consistent timestamps, we unpack the +// metadata into options for the MediaDecoderCalculator and the +// PacketResamplerCalculator. To ensure consistent timestamps, the MediaDecoder +// needs to seek to slightly before the clip starts, so it sees at least one +// packet before the first packet we want to keep. The PacketResamplerCalculator +// then trims down the timestamps. Furthermore, we should always specify that we +// want timestamps from a base timestamp of 0, so we have the same resampled +// frames after a seek that we would have from the start of a video. In summary, +// when decoding image frames, output both the DECODER_OPTIONS and +// RESAMPLER_OPTIONS. In the base_media_decoder_options, specify which streams +// you want. In the base_packet_resampler_options, specify the frame_rate you +// want and base_timestamp = 0. In the options for this calculator, specify +// padding extra_padding_from_media_decoder such that at least one frame arrives +// before the first frame the PacketResamplerCalculator should output. +// +// Optional output_side_packets include (referenced by tag): +// DATA_PATH: The data_path context feature joined onto the +// options.dataset_root_directory or input_side_packet of DATASET_ROOT. +// RESAMPLER_OPTIONS: CalculatorOptions to pass to the +// PacketResamplerCalculator. The most accurate procedure for sampling a +// range of frames is to request a padded time range from the +// MediaDecoderCalculator and then trim it down to the proper time range with +// the PacketResamplerCalculator. +// IMAGES_FRAME_RATE: The frame rate of the images in the original video as a +// double. 
+// +// Example config: +// node { +// calculator: "UnpackMediaSequenceCalculator" +// input_side_packet: "SEQUENCE_EXAMPLE:example_input_side_packet" +// input_side_packet: "ROOT_DIRECTORY:path_to_dataset_root_directory" +// output_side_packet: "DATA_PATH:full_path_to_data_element" +// output_side_packet: "RESAMPLER_OPTIONS:packet_resampler_options" +// options { +// [mediapipe.UnpackMediaSequenceCalculatorOptions.ext]: { +// base_packet_resampler_options { +// frame_rate: 1.0 # PARAM_FRAME_RATE +// base_timestamp: 0 +// } +// } +// } +// } +// +// The calculator also takes a tf.SequenceExample as a side input and outputs +// the data in streams from the SequenceExample at the proper timestamps. The +// SequenceExample must conform to the description in media_sequence.h. +// Timestamps in the SequenceExample must be in sequential order. +// +// The following output stream tags are supported: +// IMAGE: encoded images as strings. (IMAGE_${NAME} is supported.) +// FORWARD_FLOW_ENCODED: encoded FORWARD_FLOW prefix images as strings. +// FLOAT_FEATURE_${NAME}: the feature named ${NAME} as vector. +// BBOX: bounding boxes as vectors. (BBOX_${NAME} is supported.) +// +// Example config: +// node { +// calculator: "UnpackMediaSequenceCalculator" +// input_side_packet: "SEQUENCE_EXAMPLE:example_input_side_packet" +// output_stream: "IMAGE:frames" +// output_stream: "FLOAT_FEATURE_FDENSE:fdense_vf" +// output_stream: "BBOX:faces" +// } +class UnpackMediaSequenceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + const auto& options = cc->Options(); + RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag)); + cc->InputSidePackets().Tag(kSequenceExampleTag).Set(); + // Optional side inputs. 
+ if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) { + cc->InputSidePackets().Tag(kDatasetRootDirTag).Set(); + } + if (cc->OutputSidePackets().HasTag(kDataPath)) { + cc->OutputSidePackets().Tag(kDataPath).Set(); + } + if (cc->OutputSidePackets().HasTag(kImagesFrameRateTag)) { + cc->OutputSidePackets().Tag(kImagesFrameRateTag).Set(); + } + if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) { + cc->OutputSidePackets() + .Tag(kPacketResamplerOptions) + .Set(); + } + if ((options.has_padding_before_label() || + options.has_padding_after_label()) && + !(cc->OutputSidePackets().HasTag(kPacketResamplerOptions))) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "If specifying padding, must output " + << kPacketResamplerOptions; + } + + // Optional streams. + if (cc->Outputs().HasTag(kForwardFlowImageTag)) { + cc->Outputs().Tag(kForwardFlowImageTag).Set(); + } + for (const auto& tag : cc->Outputs().GetTags()) { + if (absl::StartsWith(tag, kImageTag)) { + std::string key = ""; + if (tag != kImageTag) { + int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kImageTag)_?" + } + } + cc->Outputs().Tag(tag).Set(); + } + if (absl::StartsWith(tag, kBBoxTag)) { + std::string key = ""; + if (tag != kBBoxTag) { + int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kBBoxTag)_?" + } + } + cc->Outputs().Tag(tag).Set>(); + } + if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { + cc->Outputs().Tag(tag).Set>(); + } + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + // Copy the packet to copy the otherwise inaccessible shared ptr. 
+ example_packet_holder_ = cc->InputSidePackets().Tag(kSequenceExampleTag); + sequence_ = &example_packet_holder_.Get(); + + // Collect the timestamps for all streams keyed by the timestamp feature's + // key. While creating this data structure we also identify the last + // timestamp and the associated feature. This information is used in process + // to output batches of packets in order. + timestamps_.clear(); + int64 last_timestamp_seen = Timestamp::PreStream().Value(); + first_timestamp_seen_ = Timestamp::OneOverPostStream().Value(); + for (const auto& map_kv : sequence_->feature_lists().feature_list()) { + if (absl::StrContains(map_kv.first, "/timestamp")) { + LOG(INFO) << "Found feature timestamps: " << map_kv.first + << " with size: " << map_kv.second.feature_size(); + int64 recent_timestamp = Timestamp::PreStream().Value(); + for (int i = 0; i < map_kv.second.feature_size(); ++i) { + int64 next_timestamp = + mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0); + RET_CHECK_GT(next_timestamp, recent_timestamp) + << "Timestamps must be sequential. If you're seeing this message " + << "you may have added images to the same SequenceExample twice. " + << "Key: " << map_kv.first; + timestamps_[map_kv.first].push_back(next_timestamp); + recent_timestamp = next_timestamp; + if (recent_timestamp < first_timestamp_seen_) { + first_timestamp_seen_ = recent_timestamp; + } + } + if (recent_timestamp > last_timestamp_seen) { + last_timestamp_key_ = map_kv.first; + last_timestamp_seen = recent_timestamp; + } + } + } + if (!timestamps_.empty()) { + RET_CHECK(!last_timestamp_key_.empty()) + << "Something went wrong because the timestamp key is unset. " + "Example: " + << sequence_->DebugString(); + RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value()) + << "Something went wrong because the last timestamp is unset. 
" + "Example: " + << sequence_->DebugString(); + RET_CHECK_LT(first_timestamp_seen_, + Timestamp::OneOverPostStream().Value()) + << "Something went wrong because the first timestamp is unset. " + "Example: " + << sequence_->DebugString(); + } + current_timestamp_index_ = 0; + + // Determine the data path and output it. + const auto& options = + cc->Options().GetExtension(UnpackMediaSequenceCalculatorOptions::ext); + const auto& sequence = cc->InputSidePackets() + .Tag(kSequenceExampleTag) + .Get(); + if (cc->OutputSidePackets().HasTag(kDataPath)) { + std::string root_directory = ""; + if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) { + root_directory = + cc->InputSidePackets().Tag(kDatasetRootDirTag).Get(); + } else if (options.has_dataset_root_directory()) { + root_directory = options.dataset_root_directory(); + } + + std::string data_path = mpms::GetClipDataPath(sequence); + if (!root_directory.empty()) { + if (root_directory[root_directory.size() - 1] == '/') { + data_path = root_directory + data_path; + } else { + data_path = root_directory + "/" + data_path; + } + } + cc->OutputSidePackets().Tag(kDataPath).Set( + MakePacket(data_path)); + } + + // Set the start and end of the clip in the appropriate options protos. 
+ double start_time = 0; + double end_time = 0; + if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) { + if (mpms::HasClipStartTimestamp(sequence)) { + start_time = + Timestamp(mpms::GetClipStartTimestamp(sequence)).Seconds() - + options.padding_before_label(); + } + if (mpms::HasClipEndTimestamp(sequence)) { + end_time = Timestamp(mpms::GetClipEndTimestamp(sequence)).Seconds() + + options.padding_after_label(); + } + } + if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) { + auto resampler_options = absl::make_unique(); + *(resampler_options->MutableExtension( + PacketResamplerCalculatorOptions::ext)) = + options.base_packet_resampler_options(); + if (mpms::HasClipStartTimestamp(sequence)) { + resampler_options + ->MutableExtension(PacketResamplerCalculatorOptions::ext) + ->set_start_time(Timestamp::FromSeconds(start_time).Value()); + } + if (mpms::HasClipEndTimestamp(sequence)) { + resampler_options + ->MutableExtension(PacketResamplerCalculatorOptions::ext) + ->set_end_time(Timestamp::FromSeconds(end_time).Value()); + } + + LOG(INFO) << "Created PacketResamplerOptions:\n" + << resampler_options->DebugString(); + cc->OutputSidePackets() + .Tag(kPacketResamplerOptions) + .Set(Adopt(resampler_options.release())); + } + + // Output the remaining side outputs. + if (cc->OutputSidePackets().HasTag(kImagesFrameRateTag)) { + cc->OutputSidePackets() + .Tag(kImagesFrameRateTag) + .Set(MakePacket(mpms::GetImageFrameRate(sequence))); + } + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + if (timestamps_.empty()) { + // This occurs when we only have metadata to unpack. + LOG(INFO) << "only unpacking metadata because there are no timestamps."; + return tool::StatusStop(); + } + // In Process(), we loop through timestamps on a reference stream and emit + // all packets on all streams that have a timestamp between the current + // reference timestep and the previous reference timestep. 
This ensures that + // we emit all timestamps in order, but also only emit a limited number in + // any particular call to Process(). + int64 start_timestamp = + timestamps_[last_timestamp_key_][current_timestamp_index_]; + if (current_timestamp_index_ == 0) { + start_timestamp = first_timestamp_seen_; + } + + int64 end_timestamp = start_timestamp + 1; // Base case at end of sequence. + if (current_timestamp_index_ < + timestamps_[last_timestamp_key_].size() - 1) { + end_timestamp = + timestamps_[last_timestamp_key_][current_timestamp_index_ + 1]; + } + + for (const auto& map_kv : timestamps_) { + for (int i = 0; i < map_kv.second.size(); ++i) { + if (map_kv.second[i] >= start_timestamp && + map_kv.second[i] < end_timestamp) { + const Timestamp current_timestamp = + map_kv.second[i] == Timestamp::PostStream().Value() + ? Timestamp::PostStream() + : Timestamp(map_kv.second[i]); + + LOG(INFO) << "key: " << map_kv.first; + if (absl::StrContains(map_kv.first, mpms::GetImageTimestampKey())) { + std::vector pieces = absl::StrSplit(map_kv.first, '/'); + std::string feature_key = ""; + std::string possible_tag = kImageTag; + if (pieces[0] != "image") { + feature_key = pieces[0]; + possible_tag = absl::StrCat(kImageTag, "_", feature_key); + } + if (cc->Outputs().HasTag(possible_tag)) { + cc->Outputs() + .Tag(possible_tag) + .Add(new std::string( + mpms::GetImageEncodedAt(feature_key, *sequence_, i)), + current_timestamp); + } + } + + if (cc->Outputs().HasTag(kForwardFlowImageTag) && + map_kv.first == mpms::GetForwardFlowTimestampKey()) { + cc->Outputs() + .Tag(kForwardFlowImageTag) + .Add(new std::string( + mpms::GetForwardFlowEncodedAt(*sequence_, i)), + current_timestamp); + } + if (absl::StrContains(map_kv.first, mpms::GetBBoxTimestampKey())) { + std::vector pieces = absl::StrSplit(map_kv.first, '/'); + std::string feature_key = ""; + std::string possible_tag = kBBoxTag; + if (pieces[0] != "region") { + feature_key = pieces[0]; + possible_tag = absl::StrCat(kBBoxTag, 
"_", feature_key); + } + if (cc->Outputs().HasTag(possible_tag)) { + const auto& bboxes = mpms::GetBBoxAt(feature_key, *sequence_, i); + cc->Outputs() + .Tag(possible_tag) + .Add(new std::vector(bboxes.begin(), bboxes.end()), + current_timestamp); + } + } + + if (absl::StrContains(map_kv.first, "feature")) { + std::vector pieces = absl::StrSplit(map_kv.first, '/'); + RET_CHECK_GT(pieces.size(), 1) + << "Failed to parse the feature substring before / from key " + << map_kv.first; + std::string feature_key = pieces[0]; + std::string possible_tag = kFloatFeaturePrefixTag + feature_key; + if (cc->Outputs().HasTag(possible_tag)) { + const auto& float_list = + mpms::GetFeatureFloatsAt(feature_key, *sequence_, i); + cc->Outputs() + .Tag(possible_tag) + .Add(new std::vector(float_list.begin(), + float_list.end()), + current_timestamp); + } + } + } + } + } + + ++current_timestamp_index_; + if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) { + return ::mediapipe::OkStatus(); + } else { + return tool::StatusStop(); + } + } + + // Hold a copy of the packet to prevent the shared_ptr from dying and then + // access the SequenceExample with a handy pointer. + const tf::SequenceExample* sequence_; + Packet example_packet_holder_; + + // Store a map from the keys for each stream to the timestamps for each + // key. This allows us to identify which packets to output for each stream + // for timestamps within a given time window. + std::map> timestamps_; + // Store the stream with the latest timestamp in the SequenceExample. + std::string last_timestamp_key_; + // Store the index of the current timestamp. Will be less than + // timestamps_[last_timestamp_key_].size(). + int current_timestamp_index_; + // Store the very first timestamp, so we output everything on the first frame. 
+ int64 first_timestamp_seen_; +}; +REGISTER_CALCULATOR(UnpackMediaSequenceCalculator); +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto new file mode 100644 index 000000000..7088ff076 --- /dev/null +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto @@ -0,0 +1,52 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/calculators/core/packet_resampler_calculator.proto"; +import "mediapipe/framework/calculator.proto"; + +message UnpackMediaSequenceCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional UnpackMediaSequenceCalculatorOptions ext = 244411537; + } + + // Path to the root directory of the data set that SequenceExample directory + // paths are relative from. If present, the input_side_packet overrides this + // value. + optional string dataset_root_directory = 1; + + // Time in seconds to pad before (or after) timestamps in context's + // clip/timestamp/start and clip/timestamp/end. These settings modify the + // clip's time range in the base_media_decoder_options. + optional float padding_before_label = 3; + optional float padding_after_label = 4; + + // Time in seconds to apply as additional padding to the media decoder, but + // not to the packet resampler. 
+ optional float extra_padding_from_media_decoder = 5 [default = 0.0]; + + // Stores the packet resampler settings for the graph. The most accurate + // proceedure for sampling a range of frames is to request a padded time range + // from the MediaDecoderCalculator and then trim it down to the proper time + // range with the PacketResamplerCalculator. + optional PacketResamplerCalculatorOptions base_packet_resampler_options = 6; + + // Decode media from time zero. This setting overrides other padding + // parameters for the MediaDecoderCalculator. End time parameters are still + // respected. + optional bool force_decoding_from_start_of_media = 7; +} diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc new file mode 100644 index 000000000..0d582fb2e --- /dev/null +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -0,0 +1,512 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h" +#include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/rectangle.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/sequence/media_sequence.h" +#include "tensorflow/core/example/example.pb.h" + +namespace mediapipe { +namespace { + +namespace tf = ::tensorflow; +namespace mpms = ::mediapipe::mediasequence; + +class UnpackMediaSequenceCalculatorTest : public ::testing::Test { + protected: + void SetUpCalculator(const std::vector& output_streams, + const std::vector& output_side_packets, + const std::vector& input_side_packets = {}, + const CalculatorOptions* options = nullptr) { + CalculatorGraphConfig::Node config; + config.set_calculator("UnpackMediaSequenceCalculator"); + config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence"); + for (const std::string& stream : output_streams) { + config.add_output_stream(stream); + } + for (const std::string& side_packet : output_side_packets) { + config.add_output_side_packet(side_packet); + } + for (const std::string& side_packet : input_side_packets) { + config.add_input_side_packet(side_packet); + } + if (options != nullptr) { + *config.mutable_options() = *options; + } + LOG(INFO) << config.DebugString(); + runner_ = absl::make_unique(config); + } + + void SetUp() override { + sequence_ = absl::make_unique(); + mpms::SetClipMediaId(video_id_, sequence_.get()); + mpms::SetClipDataPath(data_path_, sequence_.get()); + mpms::SetClipStartTimestamp(start_time_, sequence_.get()); + mpms::SetClipEndTimestamp(end_time_, sequence_.get()); + 
mpms::SetClipEncodedMediaBytes(encoded_video_data_, sequence_.get()); + mpms::SetClipEncodedMediaStartTimestamp(encoded_video_start_timestamp_, + sequence_.get()); + mpms::SetImageFrameRate(image_frame_rate_, sequence_.get()); + } + + std::unique_ptr sequence_; + std::unique_ptr runner_; + const std::string video_id_ = "test_video_id"; + const std::string data_path_ = "test_directory"; + const int64 start_time_ = 3000000; + const int64 end_time_ = 5000000; + const std::string encoded_video_data_ = "encoded_video_data"; + const int64 encoded_video_start_timestamp_ = 1000000; + const double image_frame_rate_ = 1.0; +}; + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneImage) { + SetUpCalculator({"IMAGE:images"}, {}); + auto input_sequence = absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + std::string test_image_string = "test_image_string"; + + int num_images = 1; + for (int i = 0; i < num_images; ++i) { + mpms::AddImageTimestamp(i, input_sequence.get()); + mpms::AddImageEncoded(test_image_string, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("IMAGE").packets; + ASSERT_EQ(num_images, output_packets.size()); + + for (int i = 0; i < num_images; ++i) { + const std::string& output_image = output_packets[i].Get(); + ASSERT_EQ(output_image, test_image_string); + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoImages) { + SetUpCalculator({"IMAGE:images"}, {}); + auto input_sequence = absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + std::string test_image_string = "test_image_string"; + + int num_images = 2; + for (int i = 0; i < num_images; ++i) { + mpms::AddImageTimestamp(i, input_sequence.get()); + 
mpms::AddImageEncoded(test_image_string, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("IMAGE").packets; + ASSERT_EQ(num_images, output_packets.size()); + + for (int i = 0; i < num_images; ++i) { + const std::string& output_image = output_packets[i].Get(); + ASSERT_EQ(output_image, test_image_string); + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPrefixedImages) { + std::string prefix = "PREFIX"; + SetUpCalculator({"IMAGE_PREFIX:images"}, {}); + auto input_sequence = absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + std::string test_image_string = "test_image_string"; + + int num_images = 2; + for (int i = 0; i < num_images; ++i) { + mpms::AddImageTimestamp(prefix, i, input_sequence.get()); + mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("IMAGE_PREFIX").packets; + ASSERT_EQ(num_images, output_packets.size()); + + for (int i = 0; i < num_images; ++i) { + const std::string& output_image = output_packets[i].Get(); + ASSERT_EQ(output_image, test_image_string); + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneForwardFlowImage) { + SetUpCalculator({"FORWARD_FLOW_ENCODED:flow_images"}, {}); + auto input_sequence = absl::make_unique(); + const std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + const std::string test_image_string = "test_image_string"; + const int num_forward_flow_images = 1; + for (int i = 0; i < num_forward_flow_images; ++i) { + mpms::AddForwardFlowTimestamp(i, 
input_sequence.get()); + mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets; + ASSERT_EQ(num_forward_flow_images, output_packets.size()); + + for (int i = 0; i < num_forward_flow_images; ++i) { + const std::string& output_image = output_packets[i].Get(); + ASSERT_EQ(output_image, test_image_string); + ASSERT_EQ(output_packets[i].Timestamp().Value(), static_cast(i)); + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoForwardFlowImages) { + SetUpCalculator({"FORWARD_FLOW_ENCODED:flow_images"}, {}); + auto input_sequence = absl::make_unique(); + const std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + const std::string test_image_strings[2] = {"test_image_string0", + "test_image_string1"}; + const int num_forward_flow_images = 2; + for (int i = 0; i < num_forward_flow_images; ++i) { + mpms::AddForwardFlowTimestamp(i, input_sequence.get()); + mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets; + ASSERT_EQ(num_forward_flow_images, output_packets.size()); + + for (int i = 0; i < num_forward_flow_images; ++i) { + const std::string& output_image = output_packets[i].Get(); + ASSERT_EQ(output_image, test_image_strings[i]); + ASSERT_EQ(output_packets[i].Timestamp().Value(), static_cast(i)); + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksBBoxes) { + SetUpCalculator({"BBOX:test", "FLOAT_FEATURE_OTHER:other"}, {}); + auto input_sequence = absl::make_unique(); + + std::vector> bboxes = { + 
{Location::CreateRelativeBBoxLocation(0.1, 0.2, 0.7, 0.7), + Location::CreateRelativeBBoxLocation(0.3, 0.4, 0.2, 0.1)}, + {Location::CreateRelativeBBoxLocation(0.2, 0.3, 0.4, 0.5)}}; + + for (int i = 0; i < bboxes.size(); ++i) { + mpms::AddBBox(bboxes[i], input_sequence.get()); + mpms::AddBBoxTimestamp(i, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("BBOX").packets; + ASSERT_EQ(bboxes.size(), output_packets.size()); + + for (int i = 0; i < bboxes.size(); ++i) { + const auto& output_vector = + output_packets[i].Get>(); + for (int j = 0; j < bboxes[i].size(); ++j) { + ASSERT_EQ(output_vector[j].GetRelativeBBox(), + bboxes[i][j].GetRelativeBBox()); + } + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPrefixedBBoxes) { + std::string prefix = "PREFIX"; + SetUpCalculator({"BBOX_PREFIX:test", "FLOAT_FEATURE_OTHER:other"}, {}); + auto input_sequence = absl::make_unique(); + + std::vector> bboxes = { + {Location::CreateRelativeBBoxLocation(0.1, 0.2, 0.7, 0.7), + Location::CreateRelativeBBoxLocation(0.3, 0.4, 0.2, 0.1)}, + {Location::CreateRelativeBBoxLocation(0.2, 0.3, 0.4, 0.5)}}; + + for (int i = 0; i < bboxes.size(); ++i) { + mpms::AddBBox(prefix, bboxes[i], input_sequence.get()); + mpms::AddBBoxTimestamp(prefix, i, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("BBOX_PREFIX").packets; + ASSERT_EQ(bboxes.size(), output_packets.size()); + + for (int i = 0; i < bboxes.size(); ++i) { + const auto& output_vector = + output_packets[i].Get>(); + for (int j = 0; j < bboxes[i].size(); ++j) { + ASSERT_EQ(output_vector[j].GetRelativeBBox(), + bboxes[i][j].GetRelativeBBox()); + } + } +} + 
+TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) { + SetUpCalculator({"FLOAT_FEATURE_TEST:test", "FLOAT_FEATURE_OTHER:other"}, {}); + auto input_sequence = absl::make_unique(); + + int num_float_lists = 2; + for (int i = 0; i < num_float_lists; ++i) { + std::vector data(2, 2 << i); + mpms::AddFeatureFloats("TEST", data, input_sequence.get()); + mpms::AddFeatureFloats("OTHER", data, input_sequence.get()); + mpms::AddFeatureTimestamp("TEST", i, input_sequence.get()); + mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("FLOAT_FEATURE_TEST").packets; + ASSERT_EQ(num_float_lists, output_packets.size()); + + for (int i = 0; i < num_float_lists; ++i) { + const auto& output_vector = output_packets[i].Get>(); + ASSERT_THAT(output_vector, + ::testing::ElementsAreArray(std::vector(2, 2 << i))); + } + + const std::vector& output_packets_other = + runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets; + ASSERT_EQ(num_float_lists, output_packets_other.size()); + + for (int i = 0; i < num_float_lists; ++i) { + const auto& output_vector = + output_packets_other[i].Get>(); + ASSERT_THAT(output_vector, + ::testing::ElementsAreArray(std::vector(2, 2 << i))); + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) { + SetUpCalculator({"IMAGE:images", "FLOAT_FEATURE_OTHER:other"}, {}); + auto input_sequence = absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + std::string test_image_string = "test_image_string"; + int num_images = 2; + for (int i = 0; i < num_images; ++i) { + mpms::AddImageTimestamp(i, input_sequence.get()); + mpms::AddImageEncoded(test_image_string, input_sequence.get()); + } + int num_float_lists = 2; + for (int i = 0; i < 
num_float_lists; ++i) { + std::vector data(2, 2 << i); + mpms::AddFeatureFloats("OTHER", data, input_sequence.get()); + mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get()); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("IMAGE").packets; + ASSERT_EQ(num_images, output_packets.size()); + + for (int i = 0; i < num_images; ++i) { + const std::string& output_image = output_packets[i].Get(); + ASSERT_EQ(output_image, test_image_string); + } + + const std::vector& output_packets_other = + runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets; + ASSERT_EQ(num_float_lists, output_packets_other.size()); + + for (int i = 0; i < num_float_lists; ++i) { + const auto& output_vector = + output_packets_other[i].Get>(); + ASSERT_THAT(output_vector, + ::testing::ElementsAreArray(std::vector(2, 2 << i))); + } +} + +TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) { + SetUpCalculator( + {"FLOAT_FEATURE_FDENSE_AVG:avg", "FLOAT_FEATURE_FDENSE_MAX:max"}, {}); + auto input_sequence = absl::make_unique(); + mpms::AddFeatureFloats("FDENSE_AVG", {1.0f, 2.0f}, input_sequence.get()); + mpms::AddFeatureTimestamp("FDENSE_AVG", Timestamp::PostStream().Value(), + input_sequence.get()); + + mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get()); + mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), + input_sequence.get()); + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& fdense_avg_packets = + runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_AVG").packets; + ASSERT_EQ(fdense_avg_packets.size(), 1); + const auto& fdense_avg_vector = + fdense_avg_packets[0].Get>(); + ASSERT_THAT(fdense_avg_vector, ::testing::ElementsAreArray({1.0f, 2.0f})); + 
ASSERT_THAT(fdense_avg_packets[0].Timestamp(), + ::testing::Eq(Timestamp::PostStream())); + + const std::vector& fdense_max_packets = + runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets; + ASSERT_EQ(fdense_max_packets.size(), 1); + const auto& fdense_max_vector = + fdense_max_packets[0].Get>(); + ASSERT_THAT(fdense_max_vector, ::testing::ElementsAreArray({3.0f, 4.0f})); + ASSERT_THAT(fdense_max_packets[0].Timestamp(), + ::testing::Eq(Timestamp::PostStream())); +} + +TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) { + SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"}); + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(sequence_.release()); + + std::string root = "test_root"; + runner_->MutableSidePackets()->Tag("DATASET_ROOT") = PointToForeign(&root); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + MEDIAPIPE_ASSERT_OK(runner_->OutputSidePackets() + .Tag("DATA_PATH") + .ValidateAsType()); + ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + root + "/" + data_path_); +} + +TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromOptions) { + CalculatorOptions options; + std::string root = "test_root"; + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->set_dataset_root_directory(root); + SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(sequence_.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + MEDIAPIPE_ASSERT_OK(runner_->OutputSidePackets() + .Tag("DATA_PATH") + .ValidateAsType()); + ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + root + "/" + data_path_); +} + +TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) { + SetUpCalculator({}, {"DATA_PATH:data_path"}); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(sequence_.release()); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + MEDIAPIPE_ASSERT_OK(runner_->OutputSidePackets() + .Tag("DATA_PATH") + 
.ValidateAsType()); + ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get(), + data_path_); +} + +TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { + CalculatorOptions options; + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->set_padding_before_label(1); + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->set_padding_after_label(2); + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->mutable_base_packet_resampler_options() + ->set_frame_rate(1.0); + SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(sequence_.release()); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + MEDIAPIPE_EXPECT_OK(runner_->OutputSidePackets() + .Tag("RESAMPLER_OPTIONS") + .ValidateAsType()); + EXPECT_NEAR(runner_->OutputSidePackets() + .Tag("RESAMPLER_OPTIONS") + .Get() + .GetExtension(PacketResamplerCalculatorOptions::ext) + .start_time(), + 2000000, 1); + EXPECT_NEAR(runner_->OutputSidePackets() + .Tag("RESAMPLER_OPTIONS") + .Get() + .GetExtension(PacketResamplerCalculatorOptions::ext) + .end_time(), + 7000000, 1); + EXPECT_NEAR(runner_->OutputSidePackets() + .Tag("RESAMPLER_OPTIONS") + .Get() + .GetExtension(PacketResamplerCalculatorOptions::ext) + .frame_rate(), + 1.0, 1e-5); +} + +TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) { + SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"}); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(sequence_.release()); + MEDIAPIPE_ASSERT_OK(runner_->Run()); + MEDIAPIPE_EXPECT_OK(runner_->OutputSidePackets() + .Tag("IMAGE_FRAME_RATE") + .ValidateAsType()); + EXPECT_EQ(runner_->OutputSidePackets().Tag("IMAGE_FRAME_RATE").Get(), + image_frame_rate_); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc 
b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc new file mode 100644 index 000000000..a96e39918 --- /dev/null +++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc @@ -0,0 +1,131 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Converts vector (or vector>) to 1D (or 2D) tf::Tensor. + +#include "mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace tf = ::tensorflow; + +auto& INPUT_1D = VectorFloatToTensorCalculatorOptions::INPUT_1D; +auto& INPUT_2D = VectorFloatToTensorCalculatorOptions::INPUT_2D; + +// The calculator expects one input (a packet containing a vector or +// vector>) and generates one output (a packet containing a +// tf::Tensor containing the same data). The output tensor will be either +// 1D or 2D with dimensions corresponding to the input vector float. +// It will hold DT_FLOAT values. 
+// +// Example config: +// node { +// calculator: "VectorFloatToTensorCalculator" +// input_stream: "vector_float_features" +// output_stream: "tensor_features" +// } +class VectorFloatToTensorCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + VectorFloatToTensorCalculatorOptions options_; +}; +REGISTER_CALCULATOR(VectorFloatToTensorCalculator); + +::mediapipe::Status VectorFloatToTensorCalculator::GetContract( + CalculatorContract* cc) { + const auto& options = cc->Options(); + // Start with only one input packet. + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + if (options.input_size() == INPUT_2D) { + cc->Inputs().Index(0).Set>>( + /* "Input vector>." */); + } else if (options.input_size() == INPUT_1D) { + cc->Inputs().Index(0).Set>( + // Output vector. + ); + } else { + LOG(FATAL) << "input size not supported"; + } + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + cc->Outputs().Index(0).Set( + // Output stream with data as tf::Tensor and the same TimeSeriesHeader. 
+ ); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status VectorFloatToTensorCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status VectorFloatToTensorCalculator::Process( + CalculatorContext* cc) { + tf::TensorShape tensor_shape; + if (options_.input_size() == INPUT_2D) { + const std::vector>& input = + cc->Inputs().Index(0).Value().Get>>(); + + const int32 rows = input.size(); + CHECK_GE(rows, 1); + const int32 cols = input[0].size(); + CHECK_GE(cols, 1); + for (int i = 1; i < rows; ++i) { + CHECK_EQ(input[i].size(), cols); + } + if (options_.transpose()) { + tensor_shape = tf::TensorShape({cols, rows}); + } else { + tensor_shape = tf::TensorShape({rows, cols}); + } + auto output = ::absl::make_unique(tf::DT_FLOAT, tensor_shape); + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + if (options_.transpose()) { + output->tensor()(c, r) = input[r][c]; + } else { + output->tensor()(r, c) = input[r][c]; + } + } + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else if (options_.input_size() == INPUT_1D) { + const std::vector& input = + cc->Inputs().Index(0).Value().Get>(); + CHECK_GE(input.size(), 1); + const int32 length = input.size(); + tensor_shape = tf::TensorShape({length}); + auto output = ::absl::make_unique(tf::DT_FLOAT, tensor_shape); + for (int i = 0; i < length; ++i) { + output->tensor()(i) = input.at(i); + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else { + LOG(FATAL) << "input size not supported"; + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_options.proto b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_options.proto new file mode 100644 index 000000000..01be3b72c --- /dev/null +++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_options.proto @@ -0,0 
+1,40 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message VectorFloatToTensorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional VectorFloatToTensorCalculatorOptions ext = 136399889; + } + enum InputSize { + UNKNOWN = 0; + INPUT_1D = 1; + INPUT_2D = 2; + } + + // If input_size is INPUT_2D, unpack a vector> to a + // 2d tensor (matrix). If INPUT_1D, + // convert a vector into a 1d tensor (vector). + optional InputSize input_size = 1 [default = INPUT_1D]; + + // If true, the output tensor is transposed. + // Otherwise, the output tensor is not transposed. + // It will be ignored if tensor_is_2d is INPUT_1D. + optional bool transpose = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc new file mode 100644 index 000000000..aadce3615 --- /dev/null +++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc @@ -0,0 +1,122 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +class VectorToTensorFloatCalculatorTest : public ::testing::Test { + protected: + void SetUpRunner( + const VectorFloatToTensorCalculatorOptions::InputSize input_size, + const bool transpose) { + CalculatorGraphConfig::Node config; + config.set_calculator("VectorFloatToTensorCalculator"); + config.add_input_stream("input_float"); + config.add_output_stream("output_tensor"); + auto options = config.mutable_options()->MutableExtension( + VectorFloatToTensorCalculatorOptions::ext); + options->set_input_size(input_size); + options->set_transpose(transpose); + runner_ = ::absl::make_unique(config); + } + + void TestConvertFromVectoVectorFloat(const bool transpose) { + SetUpRunner(VectorFloatToTensorCalculatorOptions::INPUT_2D, transpose); + auto input = ::absl::make_unique>>( + 2, std::vector(2)); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + // 2^i can be represented exactly in floating point numbers + // if 'i' is small. 
+ input->at(i).at(j) = static_cast(1 << (i * 2 + j)); + } + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const tf::Tensor& output_tensor = output_packets[0].Get(); + + EXPECT_EQ(2, output_tensor.dims()); + EXPECT_EQ(tf::DT_FLOAT, output_tensor.dtype()); + const auto matrix = output_tensor.matrix(); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + if (!transpose) { + EXPECT_EQ(1 << (i * 2 + j), matrix(i, j)); + } else { + EXPECT_EQ(1 << (j * 2 + i), matrix(i, j)); + } + } + } + } + + std::unique_ptr runner_; +}; + +TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorFloat) { + SetUpRunner(VectorFloatToTensorCalculatorOptions::INPUT_1D, false); + auto input = ::absl::make_unique>(5); + for (int i = 0; i < 5; ++i) { + // 2^i can be represented exactly in floating point numbers if 'i' is small. 
+ input->at(i) = static_cast(1 << i); + } + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const tf::Tensor& output_tensor = output_packets[0].Get(); + + EXPECT_EQ(1, output_tensor.dims()); + EXPECT_EQ(tf::DT_FLOAT, output_tensor.dtype()); + const auto vec = output_tensor.vec(); + + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(1 << i, vec(i)); + } +} + +TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorVectorFloat) { + for (bool transpose : {false, true}) { + TestConvertFromVectoVectorFloat(transpose); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD new file mode 100644 index 000000000..65a955cad --- /dev/null +++ b/mediapipe/calculators/tflite/BUILD @@ -0,0 +1,304 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +proto_library( + name = "ssd_anchors_calculator_proto", + srcs = ["ssd_anchors_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tflite_custom_op_resolver_calculator_proto", + srcs = ["tflite_custom_op_resolver_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tflite_inference_calculator_proto", + srcs = ["tflite_inference_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tflite_converter_calculator_proto", + srcs = ["tflite_converter_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tflite_tensors_to_segmentation_calculator_proto", + srcs = ["tflite_tensors_to_segmentation_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "tflite_tensors_to_detections_calculator_proto", + srcs = ["tflite_tensors_to_detections_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "ssd_anchors_calculator_cc_proto", + srcs = ["ssd_anchors_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":ssd_anchors_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tflite_custom_op_resolver_calculator_cc_proto", + srcs = ["tflite_custom_op_resolver_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = 
["//mediapipe:__subpackages__"], + deps = [":tflite_custom_op_resolver_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tflite_converter_calculator_cc_proto", + srcs = ["tflite_converter_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tflite_converter_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tflite_tensors_to_segmentation_calculator_cc_proto", + srcs = ["tflite_tensors_to_segmentation_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tflite_tensors_to_segmentation_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tflite_inference_calculator_cc_proto", + srcs = ["tflite_inference_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tflite_inference_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "tflite_tensors_to_detections_calculator_cc_proto", + srcs = ["tflite_tensors_to_detections_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tflite_tensors_to_detections_calculator_proto"], +) + +cc_library( + name = "ssd_anchors_calculator", + srcs = ["ssd_anchors_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":ssd_anchors_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "tflite_custom_op_resolver_calculator", + srcs = ["tflite_custom_op_resolver_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":tflite_custom_op_resolver_calculator_cc_proto", + 
"//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/util/tflite:cpu_op_resolver", + "//mediapipe/util/tflite:op_resolver", + ], + alwayslink = 1, +) + +filegroup( + name = "anchor_golden_files", + srcs = [ + "testdata/anchor_golden_file_0.txt", + "testdata/anchor_golden_file_1.txt", + ], +) + +cc_test( + name = "ssd_anchors_calculator_test", + srcs = ["ssd_anchors_calculator_test.cc"], + data = [":anchor_golden_files"], + deps = [ + ":ssd_anchors_calculator", + ":ssd_anchors_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:validate_type", + ], +) + +cc_library( + name = "tflite_inference_calculator", + srcs = ["tflite_inference_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tflite_inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/util:resource_util", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", + "//mediapipe/framework/port:ret_check", + ] + select({ + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_library( + name = 
"tflite_converter_calculator", + srcs = ["tflite_converter_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tflite_converter_calculator_cc_proto", + "//mediapipe/util:resource_util", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gl_calculator_helper", + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_library( + name = "tflite_tensors_to_segmentation_calculator", + srcs = ["tflite_tensors_to_segmentation_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tflite_tensors_to_segmentation_calculator_cc_proto", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/util:resource_util", + "@org_tensorflow//tensorflow/lite:framework", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + 
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_library( + name = "tflite_tensors_to_detections_calculator", + srcs = ["tflite_tensors_to_detections_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tflite_tensors_to_detections_calculator_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/lite:framework", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gl_calculator_helper", + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_test( + name = "tflite_inference_calculator_test", + srcs = ["tflite_inference_calculator_test.cc"], + data = ["testdata/add.bin"], + linkstatic = 1, + deps = [ + ":tflite_inference_calculator", + ":tflite_inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:validate_type", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], 
+) diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc new file mode 100644 index 000000000..2dc60b990 --- /dev/null +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc @@ -0,0 +1,206 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/object_detection/anchor.pb.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +float CalculateScale(float min_scale, float max_scale, int stride_index, + int num_strides) { + return min_scale + + (max_scale - min_scale) * 1.0 * stride_index / (num_strides - 1.0f); +} + +} // namespace + +// Generate anchors for SSD object detection model. +// Output: +// ANCHORS: A list of anchors. Model generates predictions based on the +// offsets of these anchors. 
+// +// Usage example: +// node { +// calculator: "SsdAnchorsCalculator" +// output_side_packet: "anchors" +// options { +// [mediapipe.SsdAnchorsCalculatorOptions.ext] { +// num_layers: 6 +// min_scale: 0.2 +// max_scale: 0.95 +// input_size_height: 300 +// input_size_width: 300 +// anchor_offset_x: 0.5 +// anchor_offset_y: 0.5 +// strides: 16 +// strides: 32 +// strides: 64 +// strides: 128 +// strides: 256 +// strides: 512 +// aspect_ratios: 1.0 +// aspect_ratios: 2.0 +// aspect_ratios: 0.5 +// aspect_ratios: 3.0 +// aspect_ratios: 0.3333 +// reduce_boxes_in_lowest_layer: true +// } +// } +// } +class SsdAnchorsCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->OutputSidePackets().Index(0).Set>(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + const SsdAnchorsCalculatorOptions& options = + cc->Options(); + + auto anchors = absl::make_unique>(); + RETURN_IF_ERROR(GenerateAnchors(anchors.get(), options)); + cc->OutputSidePackets().Index(0).Set(Adopt(anchors.release())); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } + + private: + static ::mediapipe::Status GenerateAnchors( + std::vector* anchors, const SsdAnchorsCalculatorOptions& options); +}; +REGISTER_CALCULATOR(SsdAnchorsCalculator); + +::mediapipe::Status SsdAnchorsCalculator::GenerateAnchors( + std::vector* anchors, const SsdAnchorsCalculatorOptions& options) { + // Verify the options. + if (!options.feature_map_height_size() && !options.strides_size()) { + return ::mediapipe::InvalidArgumentError( + "Both feature map shape and strides are missing. Must provide either " + "one."); + } + if (options.feature_map_height_size()) { + if (options.strides_size()) { + LOG(ERROR) << "Found feature map shapes. 
Strides will be ignored."; + } + CHECK_EQ(options.feature_map_height_size(), options.num_layers()); + CHECK_EQ(options.feature_map_height_size(), + options.feature_map_width_size()); + } else { + CHECK_EQ(options.strides_size(), options.num_layers()); + } + + int layer_id = 0; + while (layer_id < options.strides_size()) { + std::vector anchor_height; + std::vector anchor_width; + std::vector aspect_ratios; + std::vector scales; + + // For same strides, we merge the anchors in the same order. + int last_same_stride_layer = layer_id; + while (last_same_stride_layer < options.strides_size() && + options.strides(last_same_stride_layer) == + options.strides(layer_id)) { + const float scale = + CalculateScale(options.min_scale(), options.max_scale(), + last_same_stride_layer, options.strides_size()); + if (last_same_stride_layer == 0 && + options.reduce_boxes_in_lowest_layer()) { + // For first layer, it can be specified to use predefined anchors. + aspect_ratios.push_back(1.0); + aspect_ratios.push_back(2.0); + aspect_ratios.push_back(0.5); + scales.push_back(0.1); + scales.push_back(scale); + scales.push_back(scale); + } else { + for (int aspect_ratio_id = 0; + aspect_ratio_id < options.aspect_ratios_size(); + ++aspect_ratio_id) { + aspect_ratios.push_back(options.aspect_ratios(aspect_ratio_id)); + scales.push_back(scale); + } + if (options.interpolated_scale_aspect_ratio() > 0.0) { + const float scale_next = + last_same_stride_layer == options.strides_size() - 1 + ? 
1.0f + : CalculateScale(options.min_scale(), options.max_scale(), + last_same_stride_layer + 1, + options.strides_size()); + scales.push_back(std::sqrt(scale * scale_next)); + aspect_ratios.push_back(options.interpolated_scale_aspect_ratio()); + } + } + last_same_stride_layer++; + } + + for (int i = 0; i < aspect_ratios.size(); ++i) { + const float ratio_sqrts = std::sqrt(aspect_ratios[i]); + anchor_height.push_back(scales[i] / ratio_sqrts); + anchor_width.push_back(scales[i] * ratio_sqrts); + } + + int feature_map_height = 0; + int feature_map_width = 0; + if (options.feature_map_height_size()) { + feature_map_height = options.feature_map_height(layer_id); + feature_map_width = options.feature_map_width(layer_id); + } else { + const int stride = options.strides(layer_id); + feature_map_height = + std::ceil(1.0f * options.input_size_height() / stride); + feature_map_width = std::ceil(1.0f * options.input_size_width() / stride); + } + + for (int y = 0; y < feature_map_height; ++y) { + for (int x = 0; x < feature_map_width; ++x) { + for (int anchor_id = 0; anchor_id < anchor_height.size(); ++anchor_id) { + // TODO: Support specifying anchor_offset_x, anchor_offset_y. 
+ const float x_center = + (x + options.anchor_offset_x()) * 1.0f / feature_map_width; + const float y_center = + (y + options.anchor_offset_y()) * 1.0f / feature_map_height; + + Anchor new_anchor; + new_anchor.set_x_center(x_center); + new_anchor.set_y_center(y_center); + + if (options.fixed_anchor_size()) { + new_anchor.set_w(1.0f); + new_anchor.set_h(1.0f); + } else { + new_anchor.set_w(anchor_width[anchor_id]); + new_anchor.set_h(anchor_height[anchor_id]); + } + anchors->push_back(new_anchor); + } + } + } + layer_id = last_same_stride_layer; + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.proto b/mediapipe/calculators/tflite/ssd_anchors_calculator.proto new file mode 100644 index 000000000..c89248822 --- /dev/null +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator.proto @@ -0,0 +1,63 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Options to generate anchors for SSD object detection models. +message SsdAnchorsCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional SsdAnchorsCalculatorOptions ext = 247258239; + } + // Size of input images. + required int32 input_size_width = 1; + required int32 input_size_height = 2; + + // Min and max scales for generating anchor boxes on feature maps. 
+ required float min_scale = 3; + required float max_scale = 4; + + // The offset for the center of anchors. The value is in the scale of stride. + // E.g. 0.5 meaning 0.5 * |current_stride| in pixels. + required float anchor_offset_x = 5 [default = 0.5]; + required float anchor_offset_y = 6 [default = 0.5]; + + // Number of output feature maps to generate the anchors on. + required int32 num_layers = 7; + // Sizes of output feature maps to create anchors. Either feature_map size or + // stride should be provided. + repeated int32 feature_map_width = 8; + repeated int32 feature_map_height = 9; + // Strides of each output feature maps. + repeated int32 strides = 10; + + // List of different aspect ratio to generate anchors. + repeated float aspect_ratios = 11; + + // A boolean to indicate whether the fixed 3 boxes per location is used in the + // lowest layer. + optional bool reduce_boxes_in_lowest_layer = 12 [default = false]; + // An additional anchor is added with this aspect ratio and a scale + // interpolated between the scale for a layer and the scale for the next layer + // (1.0 for the last layer). This anchor is not included if this value is 0. + optional float interpolated_scale_aspect_ratio = 13 [default = 1.0]; + + // Whether use fixed width and height (e.g. both 1.0f) for each anchor. + // This option can be used when the predicted anchor width and height are in + // pixels. + optional bool fixed_anchor_size = 14 [default = false]; +} diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc new file mode 100644 index 000000000..df0814e8f --- /dev/null +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc @@ -0,0 +1,150 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/object_detection/anchor.pb.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +std::string GetGoldenFilePath(const std::string& filename) { + return mediapipe::file::JoinPath( + "./", "mediapipe/calculators/tflite/testdata/" + filename); +} + +void ParseAnchorsFromText(const std::string& text, + std::vector* anchors) { + const std::string line_delimiter = "\n"; + const std::string number_delimiter = ","; + + std::istringstream stream(text); + std::string line; + while (std::getline(stream, line)) { + Anchor anchor; + float values[4]; + std::string::size_type pos; + for (int i = 0; i < 4; ++i) { + values[i] = std::stof(line, &pos); + line = line.substr(pos); + } + anchor.set_x_center(values[0]); + anchor.set_y_center(values[1]); + anchor.set_w(values[2]); + anchor.set_h(values[3]); + anchors->push_back(anchor); + } +} + +void CompareAnchors(const std::vector& anchors_0, + const std::vector& anchors_1) { + EXPECT_EQ(anchors_0.size(), anchors_1.size()); + for (int i = 0; i < anchors_0.size(); ++i) { + const auto& anchor_0 = anchors_0[i]; + const auto& anchor_1 = anchors_1[i]; + 
EXPECT_THAT(anchor_0.x_center(), + testing::FloatNear(anchor_1.x_center(), 1e-5)); + EXPECT_THAT(anchor_0.y_center(), + testing::FloatNear(anchor_1.y_center(), 1e-5)); + } +} + +TEST(SsdAnchorCalculatorTest, FaceDetectionConfig) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + options { + [mediapipe.SsdAnchorsCalculatorOptions.ext] { + num_layers: 5 + min_scale: 0.1171875 + max_scale: 0.75 + input_size_height: 256 + input_size_width: 256 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 32 + strides: 32 + strides: 32 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } + )")); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + + const auto& anchors = + runner.OutputSidePackets().Index(0).Get>(); + std::string anchors_string; + MEDIAPIPE_EXPECT_OK(mediapipe::file::GetContents( + GetGoldenFilePath("anchor_golden_file_0.txt"), &anchors_string)); + + std::vector anchors_golden; + ParseAnchorsFromText(anchors_string, &anchors_golden); + + CompareAnchors(anchors, anchors_golden); +} + +TEST(SsdAnchorCalculatorTest, MobileSSDConfig) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + options { + [mediapipe.SsdAnchorsCalculatorOptions.ext] { + num_layers: 6 + min_scale: 0.2 + max_scale: 0.95 + input_size_height: 300 + input_size_width: 300 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 16 + strides: 32 + strides: 64 + strides: 128 + strides: 256 + strides: 512 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + aspect_ratios: 3.0 + aspect_ratios: 0.3333 + reduce_boxes_in_lowest_layer: true + } + } + )")); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const auto& anchors = + runner.OutputSidePackets().Index(0).Get>(); + + std::string anchors_string; + MEDIAPIPE_EXPECT_OK(mediapipe::file::GetContents( + 
GetGoldenFilePath("anchor_golden_file_1.txt"), &anchors_string)); + + std::vector anchors_golden; + ParseAnchorsFromText(anchors_string, &anchors_golden); + + CompareAnchors(anchors, anchors_golden); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/testdata/add.bin b/mediapipe/calculators/tflite/testdata/add.bin new file mode 100644 index 000000000..b4c02350c Binary files /dev/null and b/mediapipe/calculators/tflite/testdata/add.bin differ diff --git a/mediapipe/calculators/tflite/testdata/add.json b/mediapipe/calculators/tflite/testdata/add.json new file mode 100644 index 000000000..f589bebfb --- /dev/null +++ b/mediapipe/calculators/tflite/testdata/add.json @@ -0,0 +1,79 @@ +{ + version: 3, + operator_codes: [ + { + } + ], + subgraphs: [ + { + tensors: [ + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "add" + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "input" + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "output" + } + ], + inputs: [ + 1 + ], + outputs: [ + 2 + ], + operators: [ + { + inputs: [ + 1, + 1 + ], + outputs: [ + 0 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + }, + { + inputs: [ + 0, + 1 + ], + outputs: [ + 2 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + } + ] + } + ], + buffers: [ + { + data: [ + + ] + } + ] +} diff --git a/mediapipe/calculators/tflite/testdata/anchor_golden_file_0.txt b/mediapipe/calculators/tflite/testdata/anchor_golden_file_0.txt new file mode 100644 index 000000000..3d936c86b --- /dev/null +++ b/mediapipe/calculators/tflite/testdata/anchor_golden_file_0.txt @@ -0,0 +1,2944 @@ +0.015625 0.015625 1 1 +0.015625 0.015625 1 1 +0.046875 0.015625 1 1 +0.046875 0.015625 1 1 +0.078125 0.015625 1 1 +0.078125 0.015625 1 1 +0.109375 0.015625 1 1 +0.109375 0.015625 1 1 +0.140625 0.015625 1 1 +0.140625 0.015625 1 1 +0.171875 0.015625 1 1 +0.171875 0.015625 1 1 +0.203125 0.015625 1 1 +0.203125 0.015625 1 1 +0.234375 0.015625 1 1 +0.234375 0.015625 1 1 +0.265625 
0.015625 1 1 +0.265625 0.015625 1 1 +0.296875 0.015625 1 1 +0.296875 0.015625 1 1 +0.328125 0.015625 1 1 +0.328125 0.015625 1 1 +0.359375 0.015625 1 1 +0.359375 0.015625 1 1 +0.390625 0.015625 1 1 +0.390625 0.015625 1 1 +0.421875 0.015625 1 1 +0.421875 0.015625 1 1 +0.453125 0.015625 1 1 +0.453125 0.015625 1 1 +0.484375 0.015625 1 1 +0.484375 0.015625 1 1 +0.515625 0.015625 1 1 +0.515625 0.015625 1 1 +0.546875 0.015625 1 1 +0.546875 0.015625 1 1 +0.578125 0.015625 1 1 +0.578125 0.015625 1 1 +0.609375 0.015625 1 1 +0.609375 0.015625 1 1 +0.640625 0.015625 1 1 +0.640625 0.015625 1 1 +0.671875 0.015625 1 1 +0.671875 0.015625 1 1 +0.703125 0.015625 1 1 +0.703125 0.015625 1 1 +0.734375 0.015625 1 1 +0.734375 0.015625 1 1 +0.765625 0.015625 1 1 +0.765625 0.015625 1 1 +0.796875 0.015625 1 1 +0.796875 0.015625 1 1 +0.828125 0.015625 1 1 +0.828125 0.015625 1 1 +0.859375 0.015625 1 1 +0.859375 0.015625 1 1 +0.890625 0.015625 1 1 +0.890625 0.015625 1 1 +0.921875 0.015625 1 1 +0.921875 0.015625 1 1 +0.953125 0.015625 1 1 +0.953125 0.015625 1 1 +0.984375 0.015625 1 1 +0.984375 0.015625 1 1 +0.015625 0.046875 1 1 +0.015625 0.046875 1 1 +0.046875 0.046875 1 1 +0.046875 0.046875 1 1 +0.078125 0.046875 1 1 +0.078125 0.046875 1 1 +0.109375 0.046875 1 1 +0.109375 0.046875 1 1 +0.140625 0.046875 1 1 +0.140625 0.046875 1 1 +0.171875 0.046875 1 1 +0.171875 0.046875 1 1 +0.203125 0.046875 1 1 +0.203125 0.046875 1 1 +0.234375 0.046875 1 1 +0.234375 0.046875 1 1 +0.265625 0.046875 1 1 +0.265625 0.046875 1 1 +0.296875 0.046875 1 1 +0.296875 0.046875 1 1 +0.328125 0.046875 1 1 +0.328125 0.046875 1 1 +0.359375 0.046875 1 1 +0.359375 0.046875 1 1 +0.390625 0.046875 1 1 +0.390625 0.046875 1 1 +0.421875 0.046875 1 1 +0.421875 0.046875 1 1 +0.453125 0.046875 1 1 +0.453125 0.046875 1 1 +0.484375 0.046875 1 1 +0.484375 0.046875 1 1 +0.515625 0.046875 1 1 +0.515625 0.046875 1 1 +0.546875 0.046875 1 1 +0.546875 0.046875 1 1 +0.578125 0.046875 1 1 +0.578125 0.046875 1 1 +0.609375 0.046875 1 1 
+0.609375 0.046875 1 1 +0.640625 0.046875 1 1 +0.640625 0.046875 1 1 +0.671875 0.046875 1 1 +0.671875 0.046875 1 1 +0.703125 0.046875 1 1 +0.703125 0.046875 1 1 +0.734375 0.046875 1 1 +0.734375 0.046875 1 1 +0.765625 0.046875 1 1 +0.765625 0.046875 1 1 +0.796875 0.046875 1 1 +0.796875 0.046875 1 1 +0.828125 0.046875 1 1 +0.828125 0.046875 1 1 +0.859375 0.046875 1 1 +0.859375 0.046875 1 1 +0.890625 0.046875 1 1 +0.890625 0.046875 1 1 +0.921875 0.046875 1 1 +0.921875 0.046875 1 1 +0.953125 0.046875 1 1 +0.953125 0.046875 1 1 +0.984375 0.046875 1 1 +0.984375 0.046875 1 1 +0.015625 0.078125 1 1 +0.015625 0.078125 1 1 +0.046875 0.078125 1 1 +0.046875 0.078125 1 1 +0.078125 0.078125 1 1 +0.078125 0.078125 1 1 +0.109375 0.078125 1 1 +0.109375 0.078125 1 1 +0.140625 0.078125 1 1 +0.140625 0.078125 1 1 +0.171875 0.078125 1 1 +0.171875 0.078125 1 1 +0.203125 0.078125 1 1 +0.203125 0.078125 1 1 +0.234375 0.078125 1 1 +0.234375 0.078125 1 1 +0.265625 0.078125 1 1 +0.265625 0.078125 1 1 +0.296875 0.078125 1 1 +0.296875 0.078125 1 1 +0.328125 0.078125 1 1 +0.328125 0.078125 1 1 +0.359375 0.078125 1 1 +0.359375 0.078125 1 1 +0.390625 0.078125 1 1 +0.390625 0.078125 1 1 +0.421875 0.078125 1 1 +0.421875 0.078125 1 1 +0.453125 0.078125 1 1 +0.453125 0.078125 1 1 +0.484375 0.078125 1 1 +0.484375 0.078125 1 1 +0.515625 0.078125 1 1 +0.515625 0.078125 1 1 +0.546875 0.078125 1 1 +0.546875 0.078125 1 1 +0.578125 0.078125 1 1 +0.578125 0.078125 1 1 +0.609375 0.078125 1 1 +0.609375 0.078125 1 1 +0.640625 0.078125 1 1 +0.640625 0.078125 1 1 +0.671875 0.078125 1 1 +0.671875 0.078125 1 1 +0.703125 0.078125 1 1 +0.703125 0.078125 1 1 +0.734375 0.078125 1 1 +0.734375 0.078125 1 1 +0.765625 0.078125 1 1 +0.765625 0.078125 1 1 +0.796875 0.078125 1 1 +0.796875 0.078125 1 1 +0.828125 0.078125 1 1 +0.828125 0.078125 1 1 +0.859375 0.078125 1 1 +0.859375 0.078125 1 1 +0.890625 0.078125 1 1 +0.890625 0.078125 1 1 +0.921875 0.078125 1 1 +0.921875 0.078125 1 1 +0.953125 0.078125 1 1 +0.953125 0.078125 1 
1 +0.984375 0.078125 1 1 +0.984375 0.078125 1 1 +0.015625 0.109375 1 1 +0.015625 0.109375 1 1 +0.046875 0.109375 1 1 +0.046875 0.109375 1 1 +0.078125 0.109375 1 1 +0.078125 0.109375 1 1 +0.109375 0.109375 1 1 +0.109375 0.109375 1 1 +0.140625 0.109375 1 1 +0.140625 0.109375 1 1 +0.171875 0.109375 1 1 +0.171875 0.109375 1 1 +0.203125 0.109375 1 1 +0.203125 0.109375 1 1 +0.234375 0.109375 1 1 +0.234375 0.109375 1 1 +0.265625 0.109375 1 1 +0.265625 0.109375 1 1 +0.296875 0.109375 1 1 +0.296875 0.109375 1 1 +0.328125 0.109375 1 1 +0.328125 0.109375 1 1 +0.359375 0.109375 1 1 +0.359375 0.109375 1 1 +0.390625 0.109375 1 1 +0.390625 0.109375 1 1 +0.421875 0.109375 1 1 +0.421875 0.109375 1 1 +0.453125 0.109375 1 1 +0.453125 0.109375 1 1 +0.484375 0.109375 1 1 +0.484375 0.109375 1 1 +0.515625 0.109375 1 1 +0.515625 0.109375 1 1 +0.546875 0.109375 1 1 +0.546875 0.109375 1 1 +0.578125 0.109375 1 1 +0.578125 0.109375 1 1 +0.609375 0.109375 1 1 +0.609375 0.109375 1 1 +0.640625 0.109375 1 1 +0.640625 0.109375 1 1 +0.671875 0.109375 1 1 +0.671875 0.109375 1 1 +0.703125 0.109375 1 1 +0.703125 0.109375 1 1 +0.734375 0.109375 1 1 +0.734375 0.109375 1 1 +0.765625 0.109375 1 1 +0.765625 0.109375 1 1 +0.796875 0.109375 1 1 +0.796875 0.109375 1 1 +0.828125 0.109375 1 1 +0.828125 0.109375 1 1 +0.859375 0.109375 1 1 +0.859375 0.109375 1 1 +0.890625 0.109375 1 1 +0.890625 0.109375 1 1 +0.921875 0.109375 1 1 +0.921875 0.109375 1 1 +0.953125 0.109375 1 1 +0.953125 0.109375 1 1 +0.984375 0.109375 1 1 +0.984375 0.109375 1 1 +0.015625 0.140625 1 1 +0.015625 0.140625 1 1 +0.046875 0.140625 1 1 +0.046875 0.140625 1 1 +0.078125 0.140625 1 1 +0.078125 0.140625 1 1 +0.109375 0.140625 1 1 +0.109375 0.140625 1 1 +0.140625 0.140625 1 1 +0.140625 0.140625 1 1 +0.171875 0.140625 1 1 +0.171875 0.140625 1 1 +0.203125 0.140625 1 1 +0.203125 0.140625 1 1 +0.234375 0.140625 1 1 +0.234375 0.140625 1 1 +0.265625 0.140625 1 1 +0.265625 0.140625 1 1 +0.296875 0.140625 1 1 +0.296875 0.140625 1 1 +0.328125 0.140625 
1 1 +0.328125 0.140625 1 1 +0.359375 0.140625 1 1 +0.359375 0.140625 1 1 +0.390625 0.140625 1 1 +0.390625 0.140625 1 1 +0.421875 0.140625 1 1 +0.421875 0.140625 1 1 +0.453125 0.140625 1 1 +0.453125 0.140625 1 1 +0.484375 0.140625 1 1 +0.484375 0.140625 1 1 +0.515625 0.140625 1 1 +0.515625 0.140625 1 1 +0.546875 0.140625 1 1 +0.546875 0.140625 1 1 +0.578125 0.140625 1 1 +0.578125 0.140625 1 1 +0.609375 0.140625 1 1 +0.609375 0.140625 1 1 +0.640625 0.140625 1 1 +0.640625 0.140625 1 1 +0.671875 0.140625 1 1 +0.671875 0.140625 1 1 +0.703125 0.140625 1 1 +0.703125 0.140625 1 1 +0.734375 0.140625 1 1 +0.734375 0.140625 1 1 +0.765625 0.140625 1 1 +0.765625 0.140625 1 1 +0.796875 0.140625 1 1 +0.796875 0.140625 1 1 +0.828125 0.140625 1 1 +0.828125 0.140625 1 1 +0.859375 0.140625 1 1 +0.859375 0.140625 1 1 +0.890625 0.140625 1 1 +0.890625 0.140625 1 1 +0.921875 0.140625 1 1 +0.921875 0.140625 1 1 +0.953125 0.140625 1 1 +0.953125 0.140625 1 1 +0.984375 0.140625 1 1 +0.984375 0.140625 1 1 +0.015625 0.171875 1 1 +0.015625 0.171875 1 1 +0.046875 0.171875 1 1 +0.046875 0.171875 1 1 +0.078125 0.171875 1 1 +0.078125 0.171875 1 1 +0.109375 0.171875 1 1 +0.109375 0.171875 1 1 +0.140625 0.171875 1 1 +0.140625 0.171875 1 1 +0.171875 0.171875 1 1 +0.171875 0.171875 1 1 +0.203125 0.171875 1 1 +0.203125 0.171875 1 1 +0.234375 0.171875 1 1 +0.234375 0.171875 1 1 +0.265625 0.171875 1 1 +0.265625 0.171875 1 1 +0.296875 0.171875 1 1 +0.296875 0.171875 1 1 +0.328125 0.171875 1 1 +0.328125 0.171875 1 1 +0.359375 0.171875 1 1 +0.359375 0.171875 1 1 +0.390625 0.171875 1 1 +0.390625 0.171875 1 1 +0.421875 0.171875 1 1 +0.421875 0.171875 1 1 +0.453125 0.171875 1 1 +0.453125 0.171875 1 1 +0.484375 0.171875 1 1 +0.484375 0.171875 1 1 +0.515625 0.171875 1 1 +0.515625 0.171875 1 1 +0.546875 0.171875 1 1 +0.546875 0.171875 1 1 +0.578125 0.171875 1 1 +0.578125 0.171875 1 1 +0.609375 0.171875 1 1 +0.609375 0.171875 1 1 +0.640625 0.171875 1 1 +0.640625 0.171875 1 1 +0.671875 0.171875 1 1 +0.671875 
0.171875 1 1 +0.703125 0.171875 1 1 +0.703125 0.171875 1 1 +0.734375 0.171875 1 1 +0.734375 0.171875 1 1 +0.765625 0.171875 1 1 +0.765625 0.171875 1 1 +0.796875 0.171875 1 1 +0.796875 0.171875 1 1 +0.828125 0.171875 1 1 +0.828125 0.171875 1 1 +0.859375 0.171875 1 1 +0.859375 0.171875 1 1 +0.890625 0.171875 1 1 +0.890625 0.171875 1 1 +0.921875 0.171875 1 1 +0.921875 0.171875 1 1 +0.953125 0.171875 1 1 +0.953125 0.171875 1 1 +0.984375 0.171875 1 1 +0.984375 0.171875 1 1 +0.015625 0.203125 1 1 +0.015625 0.203125 1 1 +0.046875 0.203125 1 1 +0.046875 0.203125 1 1 +0.078125 0.203125 1 1 +0.078125 0.203125 1 1 +0.109375 0.203125 1 1 +0.109375 0.203125 1 1 +0.140625 0.203125 1 1 +0.140625 0.203125 1 1 +0.171875 0.203125 1 1 +0.171875 0.203125 1 1 +0.203125 0.203125 1 1 +0.203125 0.203125 1 1 +0.234375 0.203125 1 1 +0.234375 0.203125 1 1 +0.265625 0.203125 1 1 +0.265625 0.203125 1 1 +0.296875 0.203125 1 1 +0.296875 0.203125 1 1 +0.328125 0.203125 1 1 +0.328125 0.203125 1 1 +0.359375 0.203125 1 1 +0.359375 0.203125 1 1 +0.390625 0.203125 1 1 +0.390625 0.203125 1 1 +0.421875 0.203125 1 1 +0.421875 0.203125 1 1 +0.453125 0.203125 1 1 +0.453125 0.203125 1 1 +0.484375 0.203125 1 1 +0.484375 0.203125 1 1 +0.515625 0.203125 1 1 +0.515625 0.203125 1 1 +0.546875 0.203125 1 1 +0.546875 0.203125 1 1 +0.578125 0.203125 1 1 +0.578125 0.203125 1 1 +0.609375 0.203125 1 1 +0.609375 0.203125 1 1 +0.640625 0.203125 1 1 +0.640625 0.203125 1 1 +0.671875 0.203125 1 1 +0.671875 0.203125 1 1 +0.703125 0.203125 1 1 +0.703125 0.203125 1 1 +0.734375 0.203125 1 1 +0.734375 0.203125 1 1 +0.765625 0.203125 1 1 +0.765625 0.203125 1 1 +0.796875 0.203125 1 1 +0.796875 0.203125 1 1 +0.828125 0.203125 1 1 +0.828125 0.203125 1 1 +0.859375 0.203125 1 1 +0.859375 0.203125 1 1 +0.890625 0.203125 1 1 +0.890625 0.203125 1 1 +0.921875 0.203125 1 1 +0.921875 0.203125 1 1 +0.953125 0.203125 1 1 +0.953125 0.203125 1 1 +0.984375 0.203125 1 1 +0.984375 0.203125 1 1 +0.015625 0.234375 1 1 +0.015625 0.234375 1 1 
+0.046875 0.234375 1 1 +0.046875 0.234375 1 1 +0.078125 0.234375 1 1 +0.078125 0.234375 1 1 +0.109375 0.234375 1 1 +0.109375 0.234375 1 1 +0.140625 0.234375 1 1 +0.140625 0.234375 1 1 +0.171875 0.234375 1 1 +0.171875 0.234375 1 1 +0.203125 0.234375 1 1 +0.203125 0.234375 1 1 +0.234375 0.234375 1 1 +0.234375 0.234375 1 1 +0.265625 0.234375 1 1 +0.265625 0.234375 1 1 +0.296875 0.234375 1 1 +0.296875 0.234375 1 1 +0.328125 0.234375 1 1 +0.328125 0.234375 1 1 +0.359375 0.234375 1 1 +0.359375 0.234375 1 1 +0.390625 0.234375 1 1 +0.390625 0.234375 1 1 +0.421875 0.234375 1 1 +0.421875 0.234375 1 1 +0.453125 0.234375 1 1 +0.453125 0.234375 1 1 +0.484375 0.234375 1 1 +0.484375 0.234375 1 1 +0.515625 0.234375 1 1 +0.515625 0.234375 1 1 +0.546875 0.234375 1 1 +0.546875 0.234375 1 1 +0.578125 0.234375 1 1 +0.578125 0.234375 1 1 +0.609375 0.234375 1 1 +0.609375 0.234375 1 1 +0.640625 0.234375 1 1 +0.640625 0.234375 1 1 +0.671875 0.234375 1 1 +0.671875 0.234375 1 1 +0.703125 0.234375 1 1 +0.703125 0.234375 1 1 +0.734375 0.234375 1 1 +0.734375 0.234375 1 1 +0.765625 0.234375 1 1 +0.765625 0.234375 1 1 +0.796875 0.234375 1 1 +0.796875 0.234375 1 1 +0.828125 0.234375 1 1 +0.828125 0.234375 1 1 +0.859375 0.234375 1 1 +0.859375 0.234375 1 1 +0.890625 0.234375 1 1 +0.890625 0.234375 1 1 +0.921875 0.234375 1 1 +0.921875 0.234375 1 1 +0.953125 0.234375 1 1 +0.953125 0.234375 1 1 +0.984375 0.234375 1 1 +0.984375 0.234375 1 1 +0.015625 0.265625 1 1 +0.015625 0.265625 1 1 +0.046875 0.265625 1 1 +0.046875 0.265625 1 1 +0.078125 0.265625 1 1 +0.078125 0.265625 1 1 +0.109375 0.265625 1 1 +0.109375 0.265625 1 1 +0.140625 0.265625 1 1 +0.140625 0.265625 1 1 +0.171875 0.265625 1 1 +0.171875 0.265625 1 1 +0.203125 0.265625 1 1 +0.203125 0.265625 1 1 +0.234375 0.265625 1 1 +0.234375 0.265625 1 1 +0.265625 0.265625 1 1 +0.265625 0.265625 1 1 +0.296875 0.265625 1 1 +0.296875 0.265625 1 1 +0.328125 0.265625 1 1 +0.328125 0.265625 1 1 +0.359375 0.265625 1 1 +0.359375 0.265625 1 1 +0.390625 0.265625 1 
1 +0.390625 0.265625 1 1 +0.421875 0.265625 1 1 +0.421875 0.265625 1 1 +0.453125 0.265625 1 1 +0.453125 0.265625 1 1 +0.484375 0.265625 1 1 +0.484375 0.265625 1 1 +0.515625 0.265625 1 1 +0.515625 0.265625 1 1 +0.546875 0.265625 1 1 +0.546875 0.265625 1 1 +0.578125 0.265625 1 1 +0.578125 0.265625 1 1 +0.609375 0.265625 1 1 +0.609375 0.265625 1 1 +0.640625 0.265625 1 1 +0.640625 0.265625 1 1 +0.671875 0.265625 1 1 +0.671875 0.265625 1 1 +0.703125 0.265625 1 1 +0.703125 0.265625 1 1 +0.734375 0.265625 1 1 +0.734375 0.265625 1 1 +0.765625 0.265625 1 1 +0.765625 0.265625 1 1 +0.796875 0.265625 1 1 +0.796875 0.265625 1 1 +0.828125 0.265625 1 1 +0.828125 0.265625 1 1 +0.859375 0.265625 1 1 +0.859375 0.265625 1 1 +0.890625 0.265625 1 1 +0.890625 0.265625 1 1 +0.921875 0.265625 1 1 +0.921875 0.265625 1 1 +0.953125 0.265625 1 1 +0.953125 0.265625 1 1 +0.984375 0.265625 1 1 +0.984375 0.265625 1 1 +0.015625 0.296875 1 1 +0.015625 0.296875 1 1 +0.046875 0.296875 1 1 +0.046875 0.296875 1 1 +0.078125 0.296875 1 1 +0.078125 0.296875 1 1 +0.109375 0.296875 1 1 +0.109375 0.296875 1 1 +0.140625 0.296875 1 1 +0.140625 0.296875 1 1 +0.171875 0.296875 1 1 +0.171875 0.296875 1 1 +0.203125 0.296875 1 1 +0.203125 0.296875 1 1 +0.234375 0.296875 1 1 +0.234375 0.296875 1 1 +0.265625 0.296875 1 1 +0.265625 0.296875 1 1 +0.296875 0.296875 1 1 +0.296875 0.296875 1 1 +0.328125 0.296875 1 1 +0.328125 0.296875 1 1 +0.359375 0.296875 1 1 +0.359375 0.296875 1 1 +0.390625 0.296875 1 1 +0.390625 0.296875 1 1 +0.421875 0.296875 1 1 +0.421875 0.296875 1 1 +0.453125 0.296875 1 1 +0.453125 0.296875 1 1 +0.484375 0.296875 1 1 +0.484375 0.296875 1 1 +0.515625 0.296875 1 1 +0.515625 0.296875 1 1 +0.546875 0.296875 1 1 +0.546875 0.296875 1 1 +0.578125 0.296875 1 1 +0.578125 0.296875 1 1 +0.609375 0.296875 1 1 +0.609375 0.296875 1 1 +0.640625 0.296875 1 1 +0.640625 0.296875 1 1 +0.671875 0.296875 1 1 +0.671875 0.296875 1 1 +0.703125 0.296875 1 1 +0.703125 0.296875 1 1 +0.734375 0.296875 1 1 +0.734375 0.296875 
1 1 +0.765625 0.296875 1 1 +0.765625 0.296875 1 1 +0.796875 0.296875 1 1 +0.796875 0.296875 1 1 +0.828125 0.296875 1 1 +0.828125 0.296875 1 1 +0.859375 0.296875 1 1 +0.859375 0.296875 1 1 +0.890625 0.296875 1 1 +0.890625 0.296875 1 1 +0.921875 0.296875 1 1 +0.921875 0.296875 1 1 +0.953125 0.296875 1 1 +0.953125 0.296875 1 1 +0.984375 0.296875 1 1 +0.984375 0.296875 1 1 +0.015625 0.328125 1 1 +0.015625 0.328125 1 1 +0.046875 0.328125 1 1 +0.046875 0.328125 1 1 +0.078125 0.328125 1 1 +0.078125 0.328125 1 1 +0.109375 0.328125 1 1 +0.109375 0.328125 1 1 +0.140625 0.328125 1 1 +0.140625 0.328125 1 1 +0.171875 0.328125 1 1 +0.171875 0.328125 1 1 +0.203125 0.328125 1 1 +0.203125 0.328125 1 1 +0.234375 0.328125 1 1 +0.234375 0.328125 1 1 +0.265625 0.328125 1 1 +0.265625 0.328125 1 1 +0.296875 0.328125 1 1 +0.296875 0.328125 1 1 +0.328125 0.328125 1 1 +0.328125 0.328125 1 1 +0.359375 0.328125 1 1 +0.359375 0.328125 1 1 +0.390625 0.328125 1 1 +0.390625 0.328125 1 1 +0.421875 0.328125 1 1 +0.421875 0.328125 1 1 +0.453125 0.328125 1 1 +0.453125 0.328125 1 1 +0.484375 0.328125 1 1 +0.484375 0.328125 1 1 +0.515625 0.328125 1 1 +0.515625 0.328125 1 1 +0.546875 0.328125 1 1 +0.546875 0.328125 1 1 +0.578125 0.328125 1 1 +0.578125 0.328125 1 1 +0.609375 0.328125 1 1 +0.609375 0.328125 1 1 +0.640625 0.328125 1 1 +0.640625 0.328125 1 1 +0.671875 0.328125 1 1 +0.671875 0.328125 1 1 +0.703125 0.328125 1 1 +0.703125 0.328125 1 1 +0.734375 0.328125 1 1 +0.734375 0.328125 1 1 +0.765625 0.328125 1 1 +0.765625 0.328125 1 1 +0.796875 0.328125 1 1 +0.796875 0.328125 1 1 +0.828125 0.328125 1 1 +0.828125 0.328125 1 1 +0.859375 0.328125 1 1 +0.859375 0.328125 1 1 +0.890625 0.328125 1 1 +0.890625 0.328125 1 1 +0.921875 0.328125 1 1 +0.921875 0.328125 1 1 +0.953125 0.328125 1 1 +0.953125 0.328125 1 1 +0.984375 0.328125 1 1 +0.984375 0.328125 1 1 +0.015625 0.359375 1 1 +0.015625 0.359375 1 1 +0.046875 0.359375 1 1 +0.046875 0.359375 1 1 +0.078125 0.359375 1 1 +0.078125 0.359375 1 1 +0.109375 
0.359375 1 1 +0.109375 0.359375 1 1 +0.140625 0.359375 1 1 +0.140625 0.359375 1 1 +0.171875 0.359375 1 1 +0.171875 0.359375 1 1 +0.203125 0.359375 1 1 +0.203125 0.359375 1 1 +0.234375 0.359375 1 1 +0.234375 0.359375 1 1 +0.265625 0.359375 1 1 +0.265625 0.359375 1 1 +0.296875 0.359375 1 1 +0.296875 0.359375 1 1 +0.328125 0.359375 1 1 +0.328125 0.359375 1 1 +0.359375 0.359375 1 1 +0.359375 0.359375 1 1 +0.390625 0.359375 1 1 +0.390625 0.359375 1 1 +0.421875 0.359375 1 1 +0.421875 0.359375 1 1 +0.453125 0.359375 1 1 +0.453125 0.359375 1 1 +0.484375 0.359375 1 1 +0.484375 0.359375 1 1 +0.515625 0.359375 1 1 +0.515625 0.359375 1 1 +0.546875 0.359375 1 1 +0.546875 0.359375 1 1 +0.578125 0.359375 1 1 +0.578125 0.359375 1 1 +0.609375 0.359375 1 1 +0.609375 0.359375 1 1 +0.640625 0.359375 1 1 +0.640625 0.359375 1 1 +0.671875 0.359375 1 1 +0.671875 0.359375 1 1 +0.703125 0.359375 1 1 +0.703125 0.359375 1 1 +0.734375 0.359375 1 1 +0.734375 0.359375 1 1 +0.765625 0.359375 1 1 +0.765625 0.359375 1 1 +0.796875 0.359375 1 1 +0.796875 0.359375 1 1 +0.828125 0.359375 1 1 +0.828125 0.359375 1 1 +0.859375 0.359375 1 1 +0.859375 0.359375 1 1 +0.890625 0.359375 1 1 +0.890625 0.359375 1 1 +0.921875 0.359375 1 1 +0.921875 0.359375 1 1 +0.953125 0.359375 1 1 +0.953125 0.359375 1 1 +0.984375 0.359375 1 1 +0.984375 0.359375 1 1 +0.015625 0.390625 1 1 +0.015625 0.390625 1 1 +0.046875 0.390625 1 1 +0.046875 0.390625 1 1 +0.078125 0.390625 1 1 +0.078125 0.390625 1 1 +0.109375 0.390625 1 1 +0.109375 0.390625 1 1 +0.140625 0.390625 1 1 +0.140625 0.390625 1 1 +0.171875 0.390625 1 1 +0.171875 0.390625 1 1 +0.203125 0.390625 1 1 +0.203125 0.390625 1 1 +0.234375 0.390625 1 1 +0.234375 0.390625 1 1 +0.265625 0.390625 1 1 +0.265625 0.390625 1 1 +0.296875 0.390625 1 1 +0.296875 0.390625 1 1 +0.328125 0.390625 1 1 +0.328125 0.390625 1 1 +0.359375 0.390625 1 1 +0.359375 0.390625 1 1 +0.390625 0.390625 1 1 +0.390625 0.390625 1 1 +0.421875 0.390625 1 1 +0.421875 0.390625 1 1 +0.453125 0.390625 1 1 
+0.453125 0.390625 1 1 +0.484375 0.390625 1 1 +0.484375 0.390625 1 1 +0.515625 0.390625 1 1 +0.515625 0.390625 1 1 +0.546875 0.390625 1 1 +0.546875 0.390625 1 1 +0.578125 0.390625 1 1 +0.578125 0.390625 1 1 +0.609375 0.390625 1 1 +0.609375 0.390625 1 1 +0.640625 0.390625 1 1 +0.640625 0.390625 1 1 +0.671875 0.390625 1 1 +0.671875 0.390625 1 1 +0.703125 0.390625 1 1 +0.703125 0.390625 1 1 +0.734375 0.390625 1 1 +0.734375 0.390625 1 1 +0.765625 0.390625 1 1 +0.765625 0.390625 1 1 +0.796875 0.390625 1 1 +0.796875 0.390625 1 1 +0.828125 0.390625 1 1 +0.828125 0.390625 1 1 +0.859375 0.390625 1 1 +0.859375 0.390625 1 1 +0.890625 0.390625 1 1 +0.890625 0.390625 1 1 +0.921875 0.390625 1 1 +0.921875 0.390625 1 1 +0.953125 0.390625 1 1 +0.953125 0.390625 1 1 +0.984375 0.390625 1 1 +0.984375 0.390625 1 1 +0.015625 0.421875 1 1 +0.015625 0.421875 1 1 +0.046875 0.421875 1 1 +0.046875 0.421875 1 1 +0.078125 0.421875 1 1 +0.078125 0.421875 1 1 +0.109375 0.421875 1 1 +0.109375 0.421875 1 1 +0.140625 0.421875 1 1 +0.140625 0.421875 1 1 +0.171875 0.421875 1 1 +0.171875 0.421875 1 1 +0.203125 0.421875 1 1 +0.203125 0.421875 1 1 +0.234375 0.421875 1 1 +0.234375 0.421875 1 1 +0.265625 0.421875 1 1 +0.265625 0.421875 1 1 +0.296875 0.421875 1 1 +0.296875 0.421875 1 1 +0.328125 0.421875 1 1 +0.328125 0.421875 1 1 +0.359375 0.421875 1 1 +0.359375 0.421875 1 1 +0.390625 0.421875 1 1 +0.390625 0.421875 1 1 +0.421875 0.421875 1 1 +0.421875 0.421875 1 1 +0.453125 0.421875 1 1 +0.453125 0.421875 1 1 +0.484375 0.421875 1 1 +0.484375 0.421875 1 1 +0.515625 0.421875 1 1 +0.515625 0.421875 1 1 +0.546875 0.421875 1 1 +0.546875 0.421875 1 1 +0.578125 0.421875 1 1 +0.578125 0.421875 1 1 +0.609375 0.421875 1 1 +0.609375 0.421875 1 1 +0.640625 0.421875 1 1 +0.640625 0.421875 1 1 +0.671875 0.421875 1 1 +0.671875 0.421875 1 1 +0.703125 0.421875 1 1 +0.703125 0.421875 1 1 +0.734375 0.421875 1 1 +0.734375 0.421875 1 1 +0.765625 0.421875 1 1 +0.765625 0.421875 1 1 +0.796875 0.421875 1 1 +0.796875 0.421875 1 
1 +0.828125 0.421875 1 1 +0.828125 0.421875 1 1 +0.859375 0.421875 1 1 +0.859375 0.421875 1 1 +0.890625 0.421875 1 1 +0.890625 0.421875 1 1 +0.921875 0.421875 1 1 +0.921875 0.421875 1 1 +0.953125 0.421875 1 1 +0.953125 0.421875 1 1 +0.984375 0.421875 1 1 +0.984375 0.421875 1 1 +0.015625 0.453125 1 1 +0.015625 0.453125 1 1 +0.046875 0.453125 1 1 +0.046875 0.453125 1 1 +0.078125 0.453125 1 1 +0.078125 0.453125 1 1 +0.109375 0.453125 1 1 +0.109375 0.453125 1 1 +0.140625 0.453125 1 1 +0.140625 0.453125 1 1 +0.171875 0.453125 1 1 +0.171875 0.453125 1 1 +0.203125 0.453125 1 1 +0.203125 0.453125 1 1 +0.234375 0.453125 1 1 +0.234375 0.453125 1 1 +0.265625 0.453125 1 1 +0.265625 0.453125 1 1 +0.296875 0.453125 1 1 +0.296875 0.453125 1 1 +0.328125 0.453125 1 1 +0.328125 0.453125 1 1 +0.359375 0.453125 1 1 +0.359375 0.453125 1 1 +0.390625 0.453125 1 1 +0.390625 0.453125 1 1 +0.421875 0.453125 1 1 +0.421875 0.453125 1 1 +0.453125 0.453125 1 1 +0.453125 0.453125 1 1 +0.484375 0.453125 1 1 +0.484375 0.453125 1 1 +0.515625 0.453125 1 1 +0.515625 0.453125 1 1 +0.546875 0.453125 1 1 +0.546875 0.453125 1 1 +0.578125 0.453125 1 1 +0.578125 0.453125 1 1 +0.609375 0.453125 1 1 +0.609375 0.453125 1 1 +0.640625 0.453125 1 1 +0.640625 0.453125 1 1 +0.671875 0.453125 1 1 +0.671875 0.453125 1 1 +0.703125 0.453125 1 1 +0.703125 0.453125 1 1 +0.734375 0.453125 1 1 +0.734375 0.453125 1 1 +0.765625 0.453125 1 1 +0.765625 0.453125 1 1 +0.796875 0.453125 1 1 +0.796875 0.453125 1 1 +0.828125 0.453125 1 1 +0.828125 0.453125 1 1 +0.859375 0.453125 1 1 +0.859375 0.453125 1 1 +0.890625 0.453125 1 1 +0.890625 0.453125 1 1 +0.921875 0.453125 1 1 +0.921875 0.453125 1 1 +0.953125 0.453125 1 1 +0.953125 0.453125 1 1 +0.984375 0.453125 1 1 +0.984375 0.453125 1 1 +0.015625 0.484375 1 1 +0.015625 0.484375 1 1 +0.046875 0.484375 1 1 +0.046875 0.484375 1 1 +0.078125 0.484375 1 1 +0.078125 0.484375 1 1 +0.109375 0.484375 1 1 +0.109375 0.484375 1 1 +0.140625 0.484375 1 1 +0.140625 0.484375 1 1 +0.171875 0.484375 
1 1 +0.171875 0.484375 1 1 +0.203125 0.484375 1 1 +0.203125 0.484375 1 1 +0.234375 0.484375 1 1 +0.234375 0.484375 1 1 +0.265625 0.484375 1 1 +0.265625 0.484375 1 1 +0.296875 0.484375 1 1 +0.296875 0.484375 1 1 +0.328125 0.484375 1 1 +0.328125 0.484375 1 1 +0.359375 0.484375 1 1 +0.359375 0.484375 1 1 +0.390625 0.484375 1 1 +0.390625 0.484375 1 1 +0.421875 0.484375 1 1 +0.421875 0.484375 1 1 +0.453125 0.484375 1 1 +0.453125 0.484375 1 1 +0.484375 0.484375 1 1 +0.484375 0.484375 1 1 +0.515625 0.484375 1 1 +0.515625 0.484375 1 1 +0.546875 0.484375 1 1 +0.546875 0.484375 1 1 +0.578125 0.484375 1 1 +0.578125 0.484375 1 1 +0.609375 0.484375 1 1 +0.609375 0.484375 1 1 +0.640625 0.484375 1 1 +0.640625 0.484375 1 1 +0.671875 0.484375 1 1 +0.671875 0.484375 1 1 +0.703125 0.484375 1 1 +0.703125 0.484375 1 1 +0.734375 0.484375 1 1 +0.734375 0.484375 1 1 +0.765625 0.484375 1 1 +0.765625 0.484375 1 1 +0.796875 0.484375 1 1 +0.796875 0.484375 1 1 +0.828125 0.484375 1 1 +0.828125 0.484375 1 1 +0.859375 0.484375 1 1 +0.859375 0.484375 1 1 +0.890625 0.484375 1 1 +0.890625 0.484375 1 1 +0.921875 0.484375 1 1 +0.921875 0.484375 1 1 +0.953125 0.484375 1 1 +0.953125 0.484375 1 1 +0.984375 0.484375 1 1 +0.984375 0.484375 1 1 +0.015625 0.515625 1 1 +0.015625 0.515625 1 1 +0.046875 0.515625 1 1 +0.046875 0.515625 1 1 +0.078125 0.515625 1 1 +0.078125 0.515625 1 1 +0.109375 0.515625 1 1 +0.109375 0.515625 1 1 +0.140625 0.515625 1 1 +0.140625 0.515625 1 1 +0.171875 0.515625 1 1 +0.171875 0.515625 1 1 +0.203125 0.515625 1 1 +0.203125 0.515625 1 1 +0.234375 0.515625 1 1 +0.234375 0.515625 1 1 +0.265625 0.515625 1 1 +0.265625 0.515625 1 1 +0.296875 0.515625 1 1 +0.296875 0.515625 1 1 +0.328125 0.515625 1 1 +0.328125 0.515625 1 1 +0.359375 0.515625 1 1 +0.359375 0.515625 1 1 +0.390625 0.515625 1 1 +0.390625 0.515625 1 1 +0.421875 0.515625 1 1 +0.421875 0.515625 1 1 +0.453125 0.515625 1 1 +0.453125 0.515625 1 1 +0.484375 0.515625 1 1 +0.484375 0.515625 1 1 +0.515625 0.515625 1 1 +0.515625 
0.515625 1 1 +0.546875 0.515625 1 1 +0.546875 0.515625 1 1 +0.578125 0.515625 1 1 +0.578125 0.515625 1 1 +0.609375 0.515625 1 1 +0.609375 0.515625 1 1 +0.640625 0.515625 1 1 +0.640625 0.515625 1 1 +0.671875 0.515625 1 1 +0.671875 0.515625 1 1 +0.703125 0.515625 1 1 +0.703125 0.515625 1 1 +0.734375 0.515625 1 1 +0.734375 0.515625 1 1 +0.765625 0.515625 1 1 +0.765625 0.515625 1 1 +0.796875 0.515625 1 1 +0.796875 0.515625 1 1 +0.828125 0.515625 1 1 +0.828125 0.515625 1 1 +0.859375 0.515625 1 1 +0.859375 0.515625 1 1 +0.890625 0.515625 1 1 +0.890625 0.515625 1 1 +0.921875 0.515625 1 1 +0.921875 0.515625 1 1 +0.953125 0.515625 1 1 +0.953125 0.515625 1 1 +0.984375 0.515625 1 1 +0.984375 0.515625 1 1 +0.015625 0.546875 1 1 +0.015625 0.546875 1 1 +0.046875 0.546875 1 1 +0.046875 0.546875 1 1 +0.078125 0.546875 1 1 +0.078125 0.546875 1 1 +0.109375 0.546875 1 1 +0.109375 0.546875 1 1 +0.140625 0.546875 1 1 +0.140625 0.546875 1 1 +0.171875 0.546875 1 1 +0.171875 0.546875 1 1 +0.203125 0.546875 1 1 +0.203125 0.546875 1 1 +0.234375 0.546875 1 1 +0.234375 0.546875 1 1 +0.265625 0.546875 1 1 +0.265625 0.546875 1 1 +0.296875 0.546875 1 1 +0.296875 0.546875 1 1 +0.328125 0.546875 1 1 +0.328125 0.546875 1 1 +0.359375 0.546875 1 1 +0.359375 0.546875 1 1 +0.390625 0.546875 1 1 +0.390625 0.546875 1 1 +0.421875 0.546875 1 1 +0.421875 0.546875 1 1 +0.453125 0.546875 1 1 +0.453125 0.546875 1 1 +0.484375 0.546875 1 1 +0.484375 0.546875 1 1 +0.515625 0.546875 1 1 +0.515625 0.546875 1 1 +0.546875 0.546875 1 1 +0.546875 0.546875 1 1 +0.578125 0.546875 1 1 +0.578125 0.546875 1 1 +0.609375 0.546875 1 1 +0.609375 0.546875 1 1 +0.640625 0.546875 1 1 +0.640625 0.546875 1 1 +0.671875 0.546875 1 1 +0.671875 0.546875 1 1 +0.703125 0.546875 1 1 +0.703125 0.546875 1 1 +0.734375 0.546875 1 1 +0.734375 0.546875 1 1 +0.765625 0.546875 1 1 +0.765625 0.546875 1 1 +0.796875 0.546875 1 1 +0.796875 0.546875 1 1 +0.828125 0.546875 1 1 +0.828125 0.546875 1 1 +0.859375 0.546875 1 1 +0.859375 0.546875 1 1 
+0.890625 0.546875 1 1 +0.890625 0.546875 1 1 +0.921875 0.546875 1 1 +0.921875 0.546875 1 1 +0.953125 0.546875 1 1 +0.953125 0.546875 1 1 +0.984375 0.546875 1 1 +0.984375 0.546875 1 1 +0.015625 0.578125 1 1 +0.015625 0.578125 1 1 +0.046875 0.578125 1 1 +0.046875 0.578125 1 1 +0.078125 0.578125 1 1 +0.078125 0.578125 1 1 +0.109375 0.578125 1 1 +0.109375 0.578125 1 1 +0.140625 0.578125 1 1 +0.140625 0.578125 1 1 +0.171875 0.578125 1 1 +0.171875 0.578125 1 1 +0.203125 0.578125 1 1 +0.203125 0.578125 1 1 +0.234375 0.578125 1 1 +0.234375 0.578125 1 1 +0.265625 0.578125 1 1 +0.265625 0.578125 1 1 +0.296875 0.578125 1 1 +0.296875 0.578125 1 1 +0.328125 0.578125 1 1 +0.328125 0.578125 1 1 +0.359375 0.578125 1 1 +0.359375 0.578125 1 1 +0.390625 0.578125 1 1 +0.390625 0.578125 1 1 +0.421875 0.578125 1 1 +0.421875 0.578125 1 1 +0.453125 0.578125 1 1 +0.453125 0.578125 1 1 +0.484375 0.578125 1 1 +0.484375 0.578125 1 1 +0.515625 0.578125 1 1 +0.515625 0.578125 1 1 +0.546875 0.578125 1 1 +0.546875 0.578125 1 1 +0.578125 0.578125 1 1 +0.578125 0.578125 1 1 +0.609375 0.578125 1 1 +0.609375 0.578125 1 1 +0.640625 0.578125 1 1 +0.640625 0.578125 1 1 +0.671875 0.578125 1 1 +0.671875 0.578125 1 1 +0.703125 0.578125 1 1 +0.703125 0.578125 1 1 +0.734375 0.578125 1 1 +0.734375 0.578125 1 1 +0.765625 0.578125 1 1 +0.765625 0.578125 1 1 +0.796875 0.578125 1 1 +0.796875 0.578125 1 1 +0.828125 0.578125 1 1 +0.828125 0.578125 1 1 +0.859375 0.578125 1 1 +0.859375 0.578125 1 1 +0.890625 0.578125 1 1 +0.890625 0.578125 1 1 +0.921875 0.578125 1 1 +0.921875 0.578125 1 1 +0.953125 0.578125 1 1 +0.953125 0.578125 1 1 +0.984375 0.578125 1 1 +0.984375 0.578125 1 1 +0.015625 0.609375 1 1 +0.015625 0.609375 1 1 +0.046875 0.609375 1 1 +0.046875 0.609375 1 1 +0.078125 0.609375 1 1 +0.078125 0.609375 1 1 +0.109375 0.609375 1 1 +0.109375 0.609375 1 1 +0.140625 0.609375 1 1 +0.140625 0.609375 1 1 +0.171875 0.609375 1 1 +0.171875 0.609375 1 1 +0.203125 0.609375 1 1 +0.203125 0.609375 1 1 +0.234375 0.609375 1 
1 +0.234375 0.609375 1 1 +0.265625 0.609375 1 1 +0.265625 0.609375 1 1 +0.296875 0.609375 1 1 +0.296875 0.609375 1 1 +0.328125 0.609375 1 1 +0.328125 0.609375 1 1 +0.359375 0.609375 1 1 +0.359375 0.609375 1 1 +0.390625 0.609375 1 1 +0.390625 0.609375 1 1 +0.421875 0.609375 1 1 +0.421875 0.609375 1 1 +0.453125 0.609375 1 1 +0.453125 0.609375 1 1 +0.484375 0.609375 1 1 +0.484375 0.609375 1 1 +0.515625 0.609375 1 1 +0.515625 0.609375 1 1 +0.546875 0.609375 1 1 +0.546875 0.609375 1 1 +0.578125 0.609375 1 1 +0.578125 0.609375 1 1 +0.609375 0.609375 1 1 +0.609375 0.609375 1 1 +0.640625 0.609375 1 1 +0.640625 0.609375 1 1 +0.671875 0.609375 1 1 +0.671875 0.609375 1 1 +0.703125 0.609375 1 1 +0.703125 0.609375 1 1 +0.734375 0.609375 1 1 +0.734375 0.609375 1 1 +0.765625 0.609375 1 1 +0.765625 0.609375 1 1 +0.796875 0.609375 1 1 +0.796875 0.609375 1 1 +0.828125 0.609375 1 1 +0.828125 0.609375 1 1 +0.859375 0.609375 1 1 +0.859375 0.609375 1 1 +0.890625 0.609375 1 1 +0.890625 0.609375 1 1 +0.921875 0.609375 1 1 +0.921875 0.609375 1 1 +0.953125 0.609375 1 1 +0.953125 0.609375 1 1 +0.984375 0.609375 1 1 +0.984375 0.609375 1 1 +0.015625 0.640625 1 1 +0.015625 0.640625 1 1 +0.046875 0.640625 1 1 +0.046875 0.640625 1 1 +0.078125 0.640625 1 1 +0.078125 0.640625 1 1 +0.109375 0.640625 1 1 +0.109375 0.640625 1 1 +0.140625 0.640625 1 1 +0.140625 0.640625 1 1 +0.171875 0.640625 1 1 +0.171875 0.640625 1 1 +0.203125 0.640625 1 1 +0.203125 0.640625 1 1 +0.234375 0.640625 1 1 +0.234375 0.640625 1 1 +0.265625 0.640625 1 1 +0.265625 0.640625 1 1 +0.296875 0.640625 1 1 +0.296875 0.640625 1 1 +0.328125 0.640625 1 1 +0.328125 0.640625 1 1 +0.359375 0.640625 1 1 +0.359375 0.640625 1 1 +0.390625 0.640625 1 1 +0.390625 0.640625 1 1 +0.421875 0.640625 1 1 +0.421875 0.640625 1 1 +0.453125 0.640625 1 1 +0.453125 0.640625 1 1 +0.484375 0.640625 1 1 +0.484375 0.640625 1 1 +0.515625 0.640625 1 1 +0.515625 0.640625 1 1 +0.546875 0.640625 1 1 +0.546875 0.640625 1 1 +0.578125 0.640625 1 1 +0.578125 0.640625 
1 1 +0.609375 0.640625 1 1 +0.609375 0.640625 1 1 +0.640625 0.640625 1 1 +0.640625 0.640625 1 1 +0.671875 0.640625 1 1 +0.671875 0.640625 1 1 +0.703125 0.640625 1 1 +0.703125 0.640625 1 1 +0.734375 0.640625 1 1 +0.734375 0.640625 1 1 +0.765625 0.640625 1 1 +0.765625 0.640625 1 1 +0.796875 0.640625 1 1 +0.796875 0.640625 1 1 +0.828125 0.640625 1 1 +0.828125 0.640625 1 1 +0.859375 0.640625 1 1 +0.859375 0.640625 1 1 +0.890625 0.640625 1 1 +0.890625 0.640625 1 1 +0.921875 0.640625 1 1 +0.921875 0.640625 1 1 +0.953125 0.640625 1 1 +0.953125 0.640625 1 1 +0.984375 0.640625 1 1 +0.984375 0.640625 1 1 +0.015625 0.671875 1 1 +0.015625 0.671875 1 1 +0.046875 0.671875 1 1 +0.046875 0.671875 1 1 +0.078125 0.671875 1 1 +0.078125 0.671875 1 1 +0.109375 0.671875 1 1 +0.109375 0.671875 1 1 +0.140625 0.671875 1 1 +0.140625 0.671875 1 1 +0.171875 0.671875 1 1 +0.171875 0.671875 1 1 +0.203125 0.671875 1 1 +0.203125 0.671875 1 1 +0.234375 0.671875 1 1 +0.234375 0.671875 1 1 +0.265625 0.671875 1 1 +0.265625 0.671875 1 1 +0.296875 0.671875 1 1 +0.296875 0.671875 1 1 +0.328125 0.671875 1 1 +0.328125 0.671875 1 1 +0.359375 0.671875 1 1 +0.359375 0.671875 1 1 +0.390625 0.671875 1 1 +0.390625 0.671875 1 1 +0.421875 0.671875 1 1 +0.421875 0.671875 1 1 +0.453125 0.671875 1 1 +0.453125 0.671875 1 1 +0.484375 0.671875 1 1 +0.484375 0.671875 1 1 +0.515625 0.671875 1 1 +0.515625 0.671875 1 1 +0.546875 0.671875 1 1 +0.546875 0.671875 1 1 +0.578125 0.671875 1 1 +0.578125 0.671875 1 1 +0.609375 0.671875 1 1 +0.609375 0.671875 1 1 +0.640625 0.671875 1 1 +0.640625 0.671875 1 1 +0.671875 0.671875 1 1 +0.671875 0.671875 1 1 +0.703125 0.671875 1 1 +0.703125 0.671875 1 1 +0.734375 0.671875 1 1 +0.734375 0.671875 1 1 +0.765625 0.671875 1 1 +0.765625 0.671875 1 1 +0.796875 0.671875 1 1 +0.796875 0.671875 1 1 +0.828125 0.671875 1 1 +0.828125 0.671875 1 1 +0.859375 0.671875 1 1 +0.859375 0.671875 1 1 +0.890625 0.671875 1 1 +0.890625 0.671875 1 1 +0.921875 0.671875 1 1 +0.921875 0.671875 1 1 +0.953125 
0.671875 1 1 +0.953125 0.671875 1 1 +0.984375 0.671875 1 1 +0.984375 0.671875 1 1 +0.015625 0.703125 1 1 +0.015625 0.703125 1 1 +0.046875 0.703125 1 1 +0.046875 0.703125 1 1 +0.078125 0.703125 1 1 +0.078125 0.703125 1 1 +0.109375 0.703125 1 1 +0.109375 0.703125 1 1 +0.140625 0.703125 1 1 +0.140625 0.703125 1 1 +0.171875 0.703125 1 1 +0.171875 0.703125 1 1 +0.203125 0.703125 1 1 +0.203125 0.703125 1 1 +0.234375 0.703125 1 1 +0.234375 0.703125 1 1 +0.265625 0.703125 1 1 +0.265625 0.703125 1 1 +0.296875 0.703125 1 1 +0.296875 0.703125 1 1 +0.328125 0.703125 1 1 +0.328125 0.703125 1 1 +0.359375 0.703125 1 1 +0.359375 0.703125 1 1 +0.390625 0.703125 1 1 +0.390625 0.703125 1 1 +0.421875 0.703125 1 1 +0.421875 0.703125 1 1 +0.453125 0.703125 1 1 +0.453125 0.703125 1 1 +0.484375 0.703125 1 1 +0.484375 0.703125 1 1 +0.515625 0.703125 1 1 +0.515625 0.703125 1 1 +0.546875 0.703125 1 1 +0.546875 0.703125 1 1 +0.578125 0.703125 1 1 +0.578125 0.703125 1 1 +0.609375 0.703125 1 1 +0.609375 0.703125 1 1 +0.640625 0.703125 1 1 +0.640625 0.703125 1 1 +0.671875 0.703125 1 1 +0.671875 0.703125 1 1 +0.703125 0.703125 1 1 +0.703125 0.703125 1 1 +0.734375 0.703125 1 1 +0.734375 0.703125 1 1 +0.765625 0.703125 1 1 +0.765625 0.703125 1 1 +0.796875 0.703125 1 1 +0.796875 0.703125 1 1 +0.828125 0.703125 1 1 +0.828125 0.703125 1 1 +0.859375 0.703125 1 1 +0.859375 0.703125 1 1 +0.890625 0.703125 1 1 +0.890625 0.703125 1 1 +0.921875 0.703125 1 1 +0.921875 0.703125 1 1 +0.953125 0.703125 1 1 +0.953125 0.703125 1 1 +0.984375 0.703125 1 1 +0.984375 0.703125 1 1 +0.015625 0.734375 1 1 +0.015625 0.734375 1 1 +0.046875 0.734375 1 1 +0.046875 0.734375 1 1 +0.078125 0.734375 1 1 +0.078125 0.734375 1 1 +0.109375 0.734375 1 1 +0.109375 0.734375 1 1 +0.140625 0.734375 1 1 +0.140625 0.734375 1 1 +0.171875 0.734375 1 1 +0.171875 0.734375 1 1 +0.203125 0.734375 1 1 +0.203125 0.734375 1 1 +0.234375 0.734375 1 1 +0.234375 0.734375 1 1 +0.265625 0.734375 1 1 +0.265625 0.734375 1 1 +0.296875 0.734375 1 1 
+0.296875 0.734375 1 1 +0.328125 0.734375 1 1 +0.328125 0.734375 1 1 +0.359375 0.734375 1 1 +0.359375 0.734375 1 1 +0.390625 0.734375 1 1 +0.390625 0.734375 1 1 +0.421875 0.734375 1 1 +0.421875 0.734375 1 1 +0.453125 0.734375 1 1 +0.453125 0.734375 1 1 +0.484375 0.734375 1 1 +0.484375 0.734375 1 1 +0.515625 0.734375 1 1 +0.515625 0.734375 1 1 +0.546875 0.734375 1 1 +0.546875 0.734375 1 1 +0.578125 0.734375 1 1 +0.578125 0.734375 1 1 +0.609375 0.734375 1 1 +0.609375 0.734375 1 1 +0.640625 0.734375 1 1 +0.640625 0.734375 1 1 +0.671875 0.734375 1 1 +0.671875 0.734375 1 1 +0.703125 0.734375 1 1 +0.703125 0.734375 1 1 +0.734375 0.734375 1 1 +0.734375 0.734375 1 1 +0.765625 0.734375 1 1 +0.765625 0.734375 1 1 +0.796875 0.734375 1 1 +0.796875 0.734375 1 1 +0.828125 0.734375 1 1 +0.828125 0.734375 1 1 +0.859375 0.734375 1 1 +0.859375 0.734375 1 1 +0.890625 0.734375 1 1 +0.890625 0.734375 1 1 +0.921875 0.734375 1 1 +0.921875 0.734375 1 1 +0.953125 0.734375 1 1 +0.953125 0.734375 1 1 +0.984375 0.734375 1 1 +0.984375 0.734375 1 1 +0.015625 0.765625 1 1 +0.015625 0.765625 1 1 +0.046875 0.765625 1 1 +0.046875 0.765625 1 1 +0.078125 0.765625 1 1 +0.078125 0.765625 1 1 +0.109375 0.765625 1 1 +0.109375 0.765625 1 1 +0.140625 0.765625 1 1 +0.140625 0.765625 1 1 +0.171875 0.765625 1 1 +0.171875 0.765625 1 1 +0.203125 0.765625 1 1 +0.203125 0.765625 1 1 +0.234375 0.765625 1 1 +0.234375 0.765625 1 1 +0.265625 0.765625 1 1 +0.265625 0.765625 1 1 +0.296875 0.765625 1 1 +0.296875 0.765625 1 1 +0.328125 0.765625 1 1 +0.328125 0.765625 1 1 +0.359375 0.765625 1 1 +0.359375 0.765625 1 1 +0.390625 0.765625 1 1 +0.390625 0.765625 1 1 +0.421875 0.765625 1 1 +0.421875 0.765625 1 1 +0.453125 0.765625 1 1 +0.453125 0.765625 1 1 +0.484375 0.765625 1 1 +0.484375 0.765625 1 1 +0.515625 0.765625 1 1 +0.515625 0.765625 1 1 +0.546875 0.765625 1 1 +0.546875 0.765625 1 1 +0.578125 0.765625 1 1 +0.578125 0.765625 1 1 +0.609375 0.765625 1 1 +0.609375 0.765625 1 1 +0.640625 0.765625 1 1 +0.640625 0.765625 1 
1 +0.671875 0.765625 1 1 +0.671875 0.765625 1 1 +0.703125 0.765625 1 1 +0.703125 0.765625 1 1 +0.734375 0.765625 1 1 +0.734375 0.765625 1 1 +0.765625 0.765625 1 1 +0.765625 0.765625 1 1 +0.796875 0.765625 1 1 +0.796875 0.765625 1 1 +0.828125 0.765625 1 1 +0.828125 0.765625 1 1 +0.859375 0.765625 1 1 +0.859375 0.765625 1 1 +0.890625 0.765625 1 1 +0.890625 0.765625 1 1 +0.921875 0.765625 1 1 +0.921875 0.765625 1 1 +0.953125 0.765625 1 1 +0.953125 0.765625 1 1 +0.984375 0.765625 1 1 +0.984375 0.765625 1 1 +0.015625 0.796875 1 1 +0.015625 0.796875 1 1 +0.046875 0.796875 1 1 +0.046875 0.796875 1 1 +0.078125 0.796875 1 1 +0.078125 0.796875 1 1 +0.109375 0.796875 1 1 +0.109375 0.796875 1 1 +0.140625 0.796875 1 1 +0.140625 0.796875 1 1 +0.171875 0.796875 1 1 +0.171875 0.796875 1 1 +0.203125 0.796875 1 1 +0.203125 0.796875 1 1 +0.234375 0.796875 1 1 +0.234375 0.796875 1 1 +0.265625 0.796875 1 1 +0.265625 0.796875 1 1 +0.296875 0.796875 1 1 +0.296875 0.796875 1 1 +0.328125 0.796875 1 1 +0.328125 0.796875 1 1 +0.359375 0.796875 1 1 +0.359375 0.796875 1 1 +0.390625 0.796875 1 1 +0.390625 0.796875 1 1 +0.421875 0.796875 1 1 +0.421875 0.796875 1 1 +0.453125 0.796875 1 1 +0.453125 0.796875 1 1 +0.484375 0.796875 1 1 +0.484375 0.796875 1 1 +0.515625 0.796875 1 1 +0.515625 0.796875 1 1 +0.546875 0.796875 1 1 +0.546875 0.796875 1 1 +0.578125 0.796875 1 1 +0.578125 0.796875 1 1 +0.609375 0.796875 1 1 +0.609375 0.796875 1 1 +0.640625 0.796875 1 1 +0.640625 0.796875 1 1 +0.671875 0.796875 1 1 +0.671875 0.796875 1 1 +0.703125 0.796875 1 1 +0.703125 0.796875 1 1 +0.734375 0.796875 1 1 +0.734375 0.796875 1 1 +0.765625 0.796875 1 1 +0.765625 0.796875 1 1 +0.796875 0.796875 1 1 +0.796875 0.796875 1 1 +0.828125 0.796875 1 1 +0.828125 0.796875 1 1 +0.859375 0.796875 1 1 +0.859375 0.796875 1 1 +0.890625 0.796875 1 1 +0.890625 0.796875 1 1 +0.921875 0.796875 1 1 +0.921875 0.796875 1 1 +0.953125 0.796875 1 1 +0.953125 0.796875 1 1 +0.984375 0.796875 1 1 +0.984375 0.796875 1 1 +0.015625 0.828125 
1 1 +0.015625 0.828125 1 1 +0.046875 0.828125 1 1 +0.046875 0.828125 1 1 +0.078125 0.828125 1 1 +0.078125 0.828125 1 1 +0.109375 0.828125 1 1 +0.109375 0.828125 1 1 +0.140625 0.828125 1 1 +0.140625 0.828125 1 1 +0.171875 0.828125 1 1 +0.171875 0.828125 1 1 +0.203125 0.828125 1 1 +0.203125 0.828125 1 1 +0.234375 0.828125 1 1 +0.234375 0.828125 1 1 +0.265625 0.828125 1 1 +0.265625 0.828125 1 1 +0.296875 0.828125 1 1 +0.296875 0.828125 1 1 +0.328125 0.828125 1 1 +0.328125 0.828125 1 1 +0.359375 0.828125 1 1 +0.359375 0.828125 1 1 +0.390625 0.828125 1 1 +0.390625 0.828125 1 1 +0.421875 0.828125 1 1 +0.421875 0.828125 1 1 +0.453125 0.828125 1 1 +0.453125 0.828125 1 1 +0.484375 0.828125 1 1 +0.484375 0.828125 1 1 +0.515625 0.828125 1 1 +0.515625 0.828125 1 1 +0.546875 0.828125 1 1 +0.546875 0.828125 1 1 +0.578125 0.828125 1 1 +0.578125 0.828125 1 1 +0.609375 0.828125 1 1 +0.609375 0.828125 1 1 +0.640625 0.828125 1 1 +0.640625 0.828125 1 1 +0.671875 0.828125 1 1 +0.671875 0.828125 1 1 +0.703125 0.828125 1 1 +0.703125 0.828125 1 1 +0.734375 0.828125 1 1 +0.734375 0.828125 1 1 +0.765625 0.828125 1 1 +0.765625 0.828125 1 1 +0.796875 0.828125 1 1 +0.796875 0.828125 1 1 +0.828125 0.828125 1 1 +0.828125 0.828125 1 1 +0.859375 0.828125 1 1 +0.859375 0.828125 1 1 +0.890625 0.828125 1 1 +0.890625 0.828125 1 1 +0.921875 0.828125 1 1 +0.921875 0.828125 1 1 +0.953125 0.828125 1 1 +0.953125 0.828125 1 1 +0.984375 0.828125 1 1 +0.984375 0.828125 1 1 +0.015625 0.859375 1 1 +0.015625 0.859375 1 1 +0.046875 0.859375 1 1 +0.046875 0.859375 1 1 +0.078125 0.859375 1 1 +0.078125 0.859375 1 1 +0.109375 0.859375 1 1 +0.109375 0.859375 1 1 +0.140625 0.859375 1 1 +0.140625 0.859375 1 1 +0.171875 0.859375 1 1 +0.171875 0.859375 1 1 +0.203125 0.859375 1 1 +0.203125 0.859375 1 1 +0.234375 0.859375 1 1 +0.234375 0.859375 1 1 +0.265625 0.859375 1 1 +0.265625 0.859375 1 1 +0.296875 0.859375 1 1 +0.296875 0.859375 1 1 +0.328125 0.859375 1 1 +0.328125 0.859375 1 1 +0.359375 0.859375 1 1 +0.359375 
0.859375 1 1 +0.390625 0.859375 1 1 +0.390625 0.859375 1 1 +0.421875 0.859375 1 1 +0.421875 0.859375 1 1 +0.453125 0.859375 1 1 +0.453125 0.859375 1 1 +0.484375 0.859375 1 1 +0.484375 0.859375 1 1 +0.515625 0.859375 1 1 +0.515625 0.859375 1 1 +0.546875 0.859375 1 1 +0.546875 0.859375 1 1 +0.578125 0.859375 1 1 +0.578125 0.859375 1 1 +0.609375 0.859375 1 1 +0.609375 0.859375 1 1 +0.640625 0.859375 1 1 +0.640625 0.859375 1 1 +0.671875 0.859375 1 1 +0.671875 0.859375 1 1 +0.703125 0.859375 1 1 +0.703125 0.859375 1 1 +0.734375 0.859375 1 1 +0.734375 0.859375 1 1 +0.765625 0.859375 1 1 +0.765625 0.859375 1 1 +0.796875 0.859375 1 1 +0.796875 0.859375 1 1 +0.828125 0.859375 1 1 +0.828125 0.859375 1 1 +0.859375 0.859375 1 1 +0.859375 0.859375 1 1 +0.890625 0.859375 1 1 +0.890625 0.859375 1 1 +0.921875 0.859375 1 1 +0.921875 0.859375 1 1 +0.953125 0.859375 1 1 +0.953125 0.859375 1 1 +0.984375 0.859375 1 1 +0.984375 0.859375 1 1 +0.015625 0.890625 1 1 +0.015625 0.890625 1 1 +0.046875 0.890625 1 1 +0.046875 0.890625 1 1 +0.078125 0.890625 1 1 +0.078125 0.890625 1 1 +0.109375 0.890625 1 1 +0.109375 0.890625 1 1 +0.140625 0.890625 1 1 +0.140625 0.890625 1 1 +0.171875 0.890625 1 1 +0.171875 0.890625 1 1 +0.203125 0.890625 1 1 +0.203125 0.890625 1 1 +0.234375 0.890625 1 1 +0.234375 0.890625 1 1 +0.265625 0.890625 1 1 +0.265625 0.890625 1 1 +0.296875 0.890625 1 1 +0.296875 0.890625 1 1 +0.328125 0.890625 1 1 +0.328125 0.890625 1 1 +0.359375 0.890625 1 1 +0.359375 0.890625 1 1 +0.390625 0.890625 1 1 +0.390625 0.890625 1 1 +0.421875 0.890625 1 1 +0.421875 0.890625 1 1 +0.453125 0.890625 1 1 +0.453125 0.890625 1 1 +0.484375 0.890625 1 1 +0.484375 0.890625 1 1 +0.515625 0.890625 1 1 +0.515625 0.890625 1 1 +0.546875 0.890625 1 1 +0.546875 0.890625 1 1 +0.578125 0.890625 1 1 +0.578125 0.890625 1 1 +0.609375 0.890625 1 1 +0.609375 0.890625 1 1 +0.640625 0.890625 1 1 +0.640625 0.890625 1 1 +0.671875 0.890625 1 1 +0.671875 0.890625 1 1 +0.703125 0.890625 1 1 +0.703125 0.890625 1 1 
+0.734375 0.890625 1 1 +0.734375 0.890625 1 1 +0.765625 0.890625 1 1 +0.765625 0.890625 1 1 +0.796875 0.890625 1 1 +0.796875 0.890625 1 1 +0.828125 0.890625 1 1 +0.828125 0.890625 1 1 +0.859375 0.890625 1 1 +0.859375 0.890625 1 1 +0.890625 0.890625 1 1 +0.890625 0.890625 1 1 +0.921875 0.890625 1 1 +0.921875 0.890625 1 1 +0.953125 0.890625 1 1 +0.953125 0.890625 1 1 +0.984375 0.890625 1 1 +0.984375 0.890625 1 1 +0.015625 0.921875 1 1 +0.015625 0.921875 1 1 +0.046875 0.921875 1 1 +0.046875 0.921875 1 1 +0.078125 0.921875 1 1 +0.078125 0.921875 1 1 +0.109375 0.921875 1 1 +0.109375 0.921875 1 1 +0.140625 0.921875 1 1 +0.140625 0.921875 1 1 +0.171875 0.921875 1 1 +0.171875 0.921875 1 1 +0.203125 0.921875 1 1 +0.203125 0.921875 1 1 +0.234375 0.921875 1 1 +0.234375 0.921875 1 1 +0.265625 0.921875 1 1 +0.265625 0.921875 1 1 +0.296875 0.921875 1 1 +0.296875 0.921875 1 1 +0.328125 0.921875 1 1 +0.328125 0.921875 1 1 +0.359375 0.921875 1 1 +0.359375 0.921875 1 1 +0.390625 0.921875 1 1 +0.390625 0.921875 1 1 +0.421875 0.921875 1 1 +0.421875 0.921875 1 1 +0.453125 0.921875 1 1 +0.453125 0.921875 1 1 +0.484375 0.921875 1 1 +0.484375 0.921875 1 1 +0.515625 0.921875 1 1 +0.515625 0.921875 1 1 +0.546875 0.921875 1 1 +0.546875 0.921875 1 1 +0.578125 0.921875 1 1 +0.578125 0.921875 1 1 +0.609375 0.921875 1 1 +0.609375 0.921875 1 1 +0.640625 0.921875 1 1 +0.640625 0.921875 1 1 +0.671875 0.921875 1 1 +0.671875 0.921875 1 1 +0.703125 0.921875 1 1 +0.703125 0.921875 1 1 +0.734375 0.921875 1 1 +0.734375 0.921875 1 1 +0.765625 0.921875 1 1 +0.765625 0.921875 1 1 +0.796875 0.921875 1 1 +0.796875 0.921875 1 1 +0.828125 0.921875 1 1 +0.828125 0.921875 1 1 +0.859375 0.921875 1 1 +0.859375 0.921875 1 1 +0.890625 0.921875 1 1 +0.890625 0.921875 1 1 +0.921875 0.921875 1 1 +0.921875 0.921875 1 1 +0.953125 0.921875 1 1 +0.953125 0.921875 1 1 +0.984375 0.921875 1 1 +0.984375 0.921875 1 1 +0.015625 0.953125 1 1 +0.015625 0.953125 1 1 +0.046875 0.953125 1 1 +0.046875 0.953125 1 1 +0.078125 0.953125 1 
1 +0.078125 0.953125 1 1 +0.109375 0.953125 1 1 +0.109375 0.953125 1 1 +0.140625 0.953125 1 1 +0.140625 0.953125 1 1 +0.171875 0.953125 1 1 +0.171875 0.953125 1 1 +0.203125 0.953125 1 1 +0.203125 0.953125 1 1 +0.234375 0.953125 1 1 +0.234375 0.953125 1 1 +0.265625 0.953125 1 1 +0.265625 0.953125 1 1 +0.296875 0.953125 1 1 +0.296875 0.953125 1 1 +0.328125 0.953125 1 1 +0.328125 0.953125 1 1 +0.359375 0.953125 1 1 +0.359375 0.953125 1 1 +0.390625 0.953125 1 1 +0.390625 0.953125 1 1 +0.421875 0.953125 1 1 +0.421875 0.953125 1 1 +0.453125 0.953125 1 1 +0.453125 0.953125 1 1 +0.484375 0.953125 1 1 +0.484375 0.953125 1 1 +0.515625 0.953125 1 1 +0.515625 0.953125 1 1 +0.546875 0.953125 1 1 +0.546875 0.953125 1 1 +0.578125 0.953125 1 1 +0.578125 0.953125 1 1 +0.609375 0.953125 1 1 +0.609375 0.953125 1 1 +0.640625 0.953125 1 1 +0.640625 0.953125 1 1 +0.671875 0.953125 1 1 +0.671875 0.953125 1 1 +0.703125 0.953125 1 1 +0.703125 0.953125 1 1 +0.734375 0.953125 1 1 +0.734375 0.953125 1 1 +0.765625 0.953125 1 1 +0.765625 0.953125 1 1 +0.796875 0.953125 1 1 +0.796875 0.953125 1 1 +0.828125 0.953125 1 1 +0.828125 0.953125 1 1 +0.859375 0.953125 1 1 +0.859375 0.953125 1 1 +0.890625 0.953125 1 1 +0.890625 0.953125 1 1 +0.921875 0.953125 1 1 +0.921875 0.953125 1 1 +0.953125 0.953125 1 1 +0.953125 0.953125 1 1 +0.984375 0.953125 1 1 +0.984375 0.953125 1 1 +0.015625 0.984375 1 1 +0.015625 0.984375 1 1 +0.046875 0.984375 1 1 +0.046875 0.984375 1 1 +0.078125 0.984375 1 1 +0.078125 0.984375 1 1 +0.109375 0.984375 1 1 +0.109375 0.984375 1 1 +0.140625 0.984375 1 1 +0.140625 0.984375 1 1 +0.171875 0.984375 1 1 +0.171875 0.984375 1 1 +0.203125 0.984375 1 1 +0.203125 0.984375 1 1 +0.234375 0.984375 1 1 +0.234375 0.984375 1 1 +0.265625 0.984375 1 1 +0.265625 0.984375 1 1 +0.296875 0.984375 1 1 +0.296875 0.984375 1 1 +0.328125 0.984375 1 1 +0.328125 0.984375 1 1 +0.359375 0.984375 1 1 +0.359375 0.984375 1 1 +0.390625 0.984375 1 1 +0.390625 0.984375 1 1 +0.421875 0.984375 1 1 +0.421875 0.984375 
1 1 +0.453125 0.984375 1 1 +0.453125 0.984375 1 1 +0.484375 0.984375 1 1 +0.484375 0.984375 1 1 +0.515625 0.984375 1 1 +0.515625 0.984375 1 1 +0.546875 0.984375 1 1 +0.546875 0.984375 1 1 +0.578125 0.984375 1 1 +0.578125 0.984375 1 1 +0.609375 0.984375 1 1 +0.609375 0.984375 1 1 +0.640625 0.984375 1 1 +0.640625 0.984375 1 1 +0.671875 0.984375 1 1 +0.671875 0.984375 1 1 +0.703125 0.984375 1 1 +0.703125 0.984375 1 1 +0.734375 0.984375 1 1 +0.734375 0.984375 1 1 +0.765625 0.984375 1 1 +0.765625 0.984375 1 1 +0.796875 0.984375 1 1 +0.796875 0.984375 1 1 +0.828125 0.984375 1 1 +0.828125 0.984375 1 1 +0.859375 0.984375 1 1 +0.859375 0.984375 1 1 +0.890625 0.984375 1 1 +0.890625 0.984375 1 1 +0.921875 0.984375 1 1 +0.921875 0.984375 1 1 +0.953125 0.984375 1 1 +0.953125 0.984375 1 1 +0.984375 0.984375 1 1 +0.984375 0.984375 1 1 +0.03125 0.03125 1 1 +0.03125 0.03125 1 1 +0.09375 0.03125 1 1 +0.09375 0.03125 1 1 +0.15625 0.03125 1 1 +0.15625 0.03125 1 1 +0.21875 0.03125 1 1 +0.21875 0.03125 1 1 +0.28125 0.03125 1 1 +0.28125 0.03125 1 1 +0.34375 0.03125 1 1 +0.34375 0.03125 1 1 +0.40625 0.03125 1 1 +0.40625 0.03125 1 1 +0.46875 0.03125 1 1 +0.46875 0.03125 1 1 +0.53125 0.03125 1 1 +0.53125 0.03125 1 1 +0.59375 0.03125 1 1 +0.59375 0.03125 1 1 +0.65625 0.03125 1 1 +0.65625 0.03125 1 1 +0.71875 0.03125 1 1 +0.71875 0.03125 1 1 +0.78125 0.03125 1 1 +0.78125 0.03125 1 1 +0.84375 0.03125 1 1 +0.84375 0.03125 1 1 +0.90625 0.03125 1 1 +0.90625 0.03125 1 1 +0.96875 0.03125 1 1 +0.96875 0.03125 1 1 +0.03125 0.09375 1 1 +0.03125 0.09375 1 1 +0.09375 0.09375 1 1 +0.09375 0.09375 1 1 +0.15625 0.09375 1 1 +0.15625 0.09375 1 1 +0.21875 0.09375 1 1 +0.21875 0.09375 1 1 +0.28125 0.09375 1 1 +0.28125 0.09375 1 1 +0.34375 0.09375 1 1 +0.34375 0.09375 1 1 +0.40625 0.09375 1 1 +0.40625 0.09375 1 1 +0.46875 0.09375 1 1 +0.46875 0.09375 1 1 +0.53125 0.09375 1 1 +0.53125 0.09375 1 1 +0.59375 0.09375 1 1 +0.59375 0.09375 1 1 +0.65625 0.09375 1 1 +0.65625 0.09375 1 1 +0.71875 0.09375 1 1 +0.71875 
0.09375 1 1 +0.78125 0.09375 1 1 +0.78125 0.09375 1 1 +0.84375 0.09375 1 1 +0.84375 0.09375 1 1 +0.90625 0.09375 1 1 +0.90625 0.09375 1 1 +0.96875 0.09375 1 1 +0.96875 0.09375 1 1 +0.03125 0.15625 1 1 +0.03125 0.15625 1 1 +0.09375 0.15625 1 1 +0.09375 0.15625 1 1 +0.15625 0.15625 1 1 +0.15625 0.15625 1 1 +0.21875 0.15625 1 1 +0.21875 0.15625 1 1 +0.28125 0.15625 1 1 +0.28125 0.15625 1 1 +0.34375 0.15625 1 1 +0.34375 0.15625 1 1 +0.40625 0.15625 1 1 +0.40625 0.15625 1 1 +0.46875 0.15625 1 1 +0.46875 0.15625 1 1 +0.53125 0.15625 1 1 +0.53125 0.15625 1 1 +0.59375 0.15625 1 1 +0.59375 0.15625 1 1 +0.65625 0.15625 1 1 +0.65625 0.15625 1 1 +0.71875 0.15625 1 1 +0.71875 0.15625 1 1 +0.78125 0.15625 1 1 +0.78125 0.15625 1 1 +0.84375 0.15625 1 1 +0.84375 0.15625 1 1 +0.90625 0.15625 1 1 +0.90625 0.15625 1 1 +0.96875 0.15625 1 1 +0.96875 0.15625 1 1 +0.03125 0.21875 1 1 +0.03125 0.21875 1 1 +0.09375 0.21875 1 1 +0.09375 0.21875 1 1 +0.15625 0.21875 1 1 +0.15625 0.21875 1 1 +0.21875 0.21875 1 1 +0.21875 0.21875 1 1 +0.28125 0.21875 1 1 +0.28125 0.21875 1 1 +0.34375 0.21875 1 1 +0.34375 0.21875 1 1 +0.40625 0.21875 1 1 +0.40625 0.21875 1 1 +0.46875 0.21875 1 1 +0.46875 0.21875 1 1 +0.53125 0.21875 1 1 +0.53125 0.21875 1 1 +0.59375 0.21875 1 1 +0.59375 0.21875 1 1 +0.65625 0.21875 1 1 +0.65625 0.21875 1 1 +0.71875 0.21875 1 1 +0.71875 0.21875 1 1 +0.78125 0.21875 1 1 +0.78125 0.21875 1 1 +0.84375 0.21875 1 1 +0.84375 0.21875 1 1 +0.90625 0.21875 1 1 +0.90625 0.21875 1 1 +0.96875 0.21875 1 1 +0.96875 0.21875 1 1 +0.03125 0.28125 1 1 +0.03125 0.28125 1 1 +0.09375 0.28125 1 1 +0.09375 0.28125 1 1 +0.15625 0.28125 1 1 +0.15625 0.28125 1 1 +0.21875 0.28125 1 1 +0.21875 0.28125 1 1 +0.28125 0.28125 1 1 +0.28125 0.28125 1 1 +0.34375 0.28125 1 1 +0.34375 0.28125 1 1 +0.40625 0.28125 1 1 +0.40625 0.28125 1 1 +0.46875 0.28125 1 1 +0.46875 0.28125 1 1 +0.53125 0.28125 1 1 +0.53125 0.28125 1 1 +0.59375 0.28125 1 1 +0.59375 0.28125 1 1 +0.65625 0.28125 1 1 +0.65625 0.28125 1 1 +0.71875 
0.28125 1 1 +0.71875 0.28125 1 1 +0.78125 0.28125 1 1 +0.78125 0.28125 1 1 +0.84375 0.28125 1 1 +0.84375 0.28125 1 1 +0.90625 0.28125 1 1 +0.90625 0.28125 1 1 +0.96875 0.28125 1 1 +0.96875 0.28125 1 1 +0.03125 0.34375 1 1 +0.03125 0.34375 1 1 +0.09375 0.34375 1 1 +0.09375 0.34375 1 1 +0.15625 0.34375 1 1 +0.15625 0.34375 1 1 +0.21875 0.34375 1 1 +0.21875 0.34375 1 1 +0.28125 0.34375 1 1 +0.28125 0.34375 1 1 +0.34375 0.34375 1 1 +0.34375 0.34375 1 1 +0.40625 0.34375 1 1 +0.40625 0.34375 1 1 +0.46875 0.34375 1 1 +0.46875 0.34375 1 1 +0.53125 0.34375 1 1 +0.53125 0.34375 1 1 +0.59375 0.34375 1 1 +0.59375 0.34375 1 1 +0.65625 0.34375 1 1 +0.65625 0.34375 1 1 +0.71875 0.34375 1 1 +0.71875 0.34375 1 1 +0.78125 0.34375 1 1 +0.78125 0.34375 1 1 +0.84375 0.34375 1 1 +0.84375 0.34375 1 1 +0.90625 0.34375 1 1 +0.90625 0.34375 1 1 +0.96875 0.34375 1 1 +0.96875 0.34375 1 1 +0.03125 0.40625 1 1 +0.03125 0.40625 1 1 +0.09375 0.40625 1 1 +0.09375 0.40625 1 1 +0.15625 0.40625 1 1 +0.15625 0.40625 1 1 +0.21875 0.40625 1 1 +0.21875 0.40625 1 1 +0.28125 0.40625 1 1 +0.28125 0.40625 1 1 +0.34375 0.40625 1 1 +0.34375 0.40625 1 1 +0.40625 0.40625 1 1 +0.40625 0.40625 1 1 +0.46875 0.40625 1 1 +0.46875 0.40625 1 1 +0.53125 0.40625 1 1 +0.53125 0.40625 1 1 +0.59375 0.40625 1 1 +0.59375 0.40625 1 1 +0.65625 0.40625 1 1 +0.65625 0.40625 1 1 +0.71875 0.40625 1 1 +0.71875 0.40625 1 1 +0.78125 0.40625 1 1 +0.78125 0.40625 1 1 +0.84375 0.40625 1 1 +0.84375 0.40625 1 1 +0.90625 0.40625 1 1 +0.90625 0.40625 1 1 +0.96875 0.40625 1 1 +0.96875 0.40625 1 1 +0.03125 0.46875 1 1 +0.03125 0.46875 1 1 +0.09375 0.46875 1 1 +0.09375 0.46875 1 1 +0.15625 0.46875 1 1 +0.15625 0.46875 1 1 +0.21875 0.46875 1 1 +0.21875 0.46875 1 1 +0.28125 0.46875 1 1 +0.28125 0.46875 1 1 +0.34375 0.46875 1 1 +0.34375 0.46875 1 1 +0.40625 0.46875 1 1 +0.40625 0.46875 1 1 +0.46875 0.46875 1 1 +0.46875 0.46875 1 1 +0.53125 0.46875 1 1 +0.53125 0.46875 1 1 +0.59375 0.46875 1 1 +0.59375 0.46875 1 1 +0.65625 0.46875 1 1 +0.65625 
0.46875 1 1 +0.71875 0.46875 1 1 +0.71875 0.46875 1 1 +0.78125 0.46875 1 1 +0.78125 0.46875 1 1 +0.84375 0.46875 1 1 +0.84375 0.46875 1 1 +0.90625 0.46875 1 1 +0.90625 0.46875 1 1 +0.96875 0.46875 1 1 +0.96875 0.46875 1 1 +0.03125 0.53125 1 1 +0.03125 0.53125 1 1 +0.09375 0.53125 1 1 +0.09375 0.53125 1 1 +0.15625 0.53125 1 1 +0.15625 0.53125 1 1 +0.21875 0.53125 1 1 +0.21875 0.53125 1 1 +0.28125 0.53125 1 1 +0.28125 0.53125 1 1 +0.34375 0.53125 1 1 +0.34375 0.53125 1 1 +0.40625 0.53125 1 1 +0.40625 0.53125 1 1 +0.46875 0.53125 1 1 +0.46875 0.53125 1 1 +0.53125 0.53125 1 1 +0.53125 0.53125 1 1 +0.59375 0.53125 1 1 +0.59375 0.53125 1 1 +0.65625 0.53125 1 1 +0.65625 0.53125 1 1 +0.71875 0.53125 1 1 +0.71875 0.53125 1 1 +0.78125 0.53125 1 1 +0.78125 0.53125 1 1 +0.84375 0.53125 1 1 +0.84375 0.53125 1 1 +0.90625 0.53125 1 1 +0.90625 0.53125 1 1 +0.96875 0.53125 1 1 +0.96875 0.53125 1 1 +0.03125 0.59375 1 1 +0.03125 0.59375 1 1 +0.09375 0.59375 1 1 +0.09375 0.59375 1 1 +0.15625 0.59375 1 1 +0.15625 0.59375 1 1 +0.21875 0.59375 1 1 +0.21875 0.59375 1 1 +0.28125 0.59375 1 1 +0.28125 0.59375 1 1 +0.34375 0.59375 1 1 +0.34375 0.59375 1 1 +0.40625 0.59375 1 1 +0.40625 0.59375 1 1 +0.46875 0.59375 1 1 +0.46875 0.59375 1 1 +0.53125 0.59375 1 1 +0.53125 0.59375 1 1 +0.59375 0.59375 1 1 +0.59375 0.59375 1 1 +0.65625 0.59375 1 1 +0.65625 0.59375 1 1 +0.71875 0.59375 1 1 +0.71875 0.59375 1 1 +0.78125 0.59375 1 1 +0.78125 0.59375 1 1 +0.84375 0.59375 1 1 +0.84375 0.59375 1 1 +0.90625 0.59375 1 1 +0.90625 0.59375 1 1 +0.96875 0.59375 1 1 +0.96875 0.59375 1 1 +0.03125 0.65625 1 1 +0.03125 0.65625 1 1 +0.09375 0.65625 1 1 +0.09375 0.65625 1 1 +0.15625 0.65625 1 1 +0.15625 0.65625 1 1 +0.21875 0.65625 1 1 +0.21875 0.65625 1 1 +0.28125 0.65625 1 1 +0.28125 0.65625 1 1 +0.34375 0.65625 1 1 +0.34375 0.65625 1 1 +0.40625 0.65625 1 1 +0.40625 0.65625 1 1 +0.46875 0.65625 1 1 +0.46875 0.65625 1 1 +0.53125 0.65625 1 1 +0.53125 0.65625 1 1 +0.59375 0.65625 1 1 +0.59375 0.65625 1 1 +0.65625 
0.65625 1 1 +0.65625 0.65625 1 1 +0.71875 0.65625 1 1 +0.71875 0.65625 1 1 +0.78125 0.65625 1 1 +0.78125 0.65625 1 1 +0.84375 0.65625 1 1 +0.84375 0.65625 1 1 +0.90625 0.65625 1 1 +0.90625 0.65625 1 1 +0.96875 0.65625 1 1 +0.96875 0.65625 1 1 +0.03125 0.71875 1 1 +0.03125 0.71875 1 1 +0.09375 0.71875 1 1 +0.09375 0.71875 1 1 +0.15625 0.71875 1 1 +0.15625 0.71875 1 1 +0.21875 0.71875 1 1 +0.21875 0.71875 1 1 +0.28125 0.71875 1 1 +0.28125 0.71875 1 1 +0.34375 0.71875 1 1 +0.34375 0.71875 1 1 +0.40625 0.71875 1 1 +0.40625 0.71875 1 1 +0.46875 0.71875 1 1 +0.46875 0.71875 1 1 +0.53125 0.71875 1 1 +0.53125 0.71875 1 1 +0.59375 0.71875 1 1 +0.59375 0.71875 1 1 +0.65625 0.71875 1 1 +0.65625 0.71875 1 1 +0.71875 0.71875 1 1 +0.71875 0.71875 1 1 +0.78125 0.71875 1 1 +0.78125 0.71875 1 1 +0.84375 0.71875 1 1 +0.84375 0.71875 1 1 +0.90625 0.71875 1 1 +0.90625 0.71875 1 1 +0.96875 0.71875 1 1 +0.96875 0.71875 1 1 +0.03125 0.78125 1 1 +0.03125 0.78125 1 1 +0.09375 0.78125 1 1 +0.09375 0.78125 1 1 +0.15625 0.78125 1 1 +0.15625 0.78125 1 1 +0.21875 0.78125 1 1 +0.21875 0.78125 1 1 +0.28125 0.78125 1 1 +0.28125 0.78125 1 1 +0.34375 0.78125 1 1 +0.34375 0.78125 1 1 +0.40625 0.78125 1 1 +0.40625 0.78125 1 1 +0.46875 0.78125 1 1 +0.46875 0.78125 1 1 +0.53125 0.78125 1 1 +0.53125 0.78125 1 1 +0.59375 0.78125 1 1 +0.59375 0.78125 1 1 +0.65625 0.78125 1 1 +0.65625 0.78125 1 1 +0.71875 0.78125 1 1 +0.71875 0.78125 1 1 +0.78125 0.78125 1 1 +0.78125 0.78125 1 1 +0.84375 0.78125 1 1 +0.84375 0.78125 1 1 +0.90625 0.78125 1 1 +0.90625 0.78125 1 1 +0.96875 0.78125 1 1 +0.96875 0.78125 1 1 +0.03125 0.84375 1 1 +0.03125 0.84375 1 1 +0.09375 0.84375 1 1 +0.09375 0.84375 1 1 +0.15625 0.84375 1 1 +0.15625 0.84375 1 1 +0.21875 0.84375 1 1 +0.21875 0.84375 1 1 +0.28125 0.84375 1 1 +0.28125 0.84375 1 1 +0.34375 0.84375 1 1 +0.34375 0.84375 1 1 +0.40625 0.84375 1 1 +0.40625 0.84375 1 1 +0.46875 0.84375 1 1 +0.46875 0.84375 1 1 +0.53125 0.84375 1 1 +0.53125 0.84375 1 1 +0.59375 0.84375 1 1 +0.59375 
0.84375 1 1 +0.65625 0.84375 1 1 +0.65625 0.84375 1 1 +0.71875 0.84375 1 1 +0.71875 0.84375 1 1 +0.78125 0.84375 1 1 +0.78125 0.84375 1 1 +0.84375 0.84375 1 1 +0.84375 0.84375 1 1 +0.90625 0.84375 1 1 +0.90625 0.84375 1 1 +0.96875 0.84375 1 1 +0.96875 0.84375 1 1 +0.03125 0.90625 1 1 +0.03125 0.90625 1 1 +0.09375 0.90625 1 1 +0.09375 0.90625 1 1 +0.15625 0.90625 1 1 +0.15625 0.90625 1 1 +0.21875 0.90625 1 1 +0.21875 0.90625 1 1 +0.28125 0.90625 1 1 +0.28125 0.90625 1 1 +0.34375 0.90625 1 1 +0.34375 0.90625 1 1 +0.40625 0.90625 1 1 +0.40625 0.90625 1 1 +0.46875 0.90625 1 1 +0.46875 0.90625 1 1 +0.53125 0.90625 1 1 +0.53125 0.90625 1 1 +0.59375 0.90625 1 1 +0.59375 0.90625 1 1 +0.65625 0.90625 1 1 +0.65625 0.90625 1 1 +0.71875 0.90625 1 1 +0.71875 0.90625 1 1 +0.78125 0.90625 1 1 +0.78125 0.90625 1 1 +0.84375 0.90625 1 1 +0.84375 0.90625 1 1 +0.90625 0.90625 1 1 +0.90625 0.90625 1 1 +0.96875 0.90625 1 1 +0.96875 0.90625 1 1 +0.03125 0.96875 1 1 +0.03125 0.96875 1 1 +0.09375 0.96875 1 1 +0.09375 0.96875 1 1 +0.15625 0.96875 1 1 +0.15625 0.96875 1 1 +0.21875 0.96875 1 1 +0.21875 0.96875 1 1 +0.28125 0.96875 1 1 +0.28125 0.96875 1 1 +0.34375 0.96875 1 1 +0.34375 0.96875 1 1 +0.40625 0.96875 1 1 +0.40625 0.96875 1 1 +0.46875 0.96875 1 1 +0.46875 0.96875 1 1 +0.53125 0.96875 1 1 +0.53125 0.96875 1 1 +0.59375 0.96875 1 1 +0.59375 0.96875 1 1 +0.65625 0.96875 1 1 +0.65625 0.96875 1 1 +0.71875 0.96875 1 1 +0.71875 0.96875 1 1 +0.78125 0.96875 1 1 +0.78125 0.96875 1 1 +0.84375 0.96875 1 1 +0.84375 0.96875 1 1 +0.90625 0.96875 1 1 +0.90625 0.96875 1 1 +0.96875 0.96875 1 1 +0.96875 0.96875 1 1 +0.0625 0.0625 1 1 +0.0625 0.0625 1 1 +0.0625 0.0625 1 1 +0.0625 0.0625 1 1 +0.0625 0.0625 1 1 +0.0625 0.0625 1 1 +0.1875 0.0625 1 1 +0.1875 0.0625 1 1 +0.1875 0.0625 1 1 +0.1875 0.0625 1 1 +0.1875 0.0625 1 1 +0.1875 0.0625 1 1 +0.3125 0.0625 1 1 +0.3125 0.0625 1 1 +0.3125 0.0625 1 1 +0.3125 0.0625 1 1 +0.3125 0.0625 1 1 +0.3125 0.0625 1 1 +0.4375 0.0625 1 1 +0.4375 0.0625 1 1 +0.4375 
0.0625 1 1 +0.4375 0.0625 1 1 +0.4375 0.0625 1 1 +0.4375 0.0625 1 1 +0.5625 0.0625 1 1 +0.5625 0.0625 1 1 +0.5625 0.0625 1 1 +0.5625 0.0625 1 1 +0.5625 0.0625 1 1 +0.5625 0.0625 1 1 +0.6875 0.0625 1 1 +0.6875 0.0625 1 1 +0.6875 0.0625 1 1 +0.6875 0.0625 1 1 +0.6875 0.0625 1 1 +0.6875 0.0625 1 1 +0.8125 0.0625 1 1 +0.8125 0.0625 1 1 +0.8125 0.0625 1 1 +0.8125 0.0625 1 1 +0.8125 0.0625 1 1 +0.8125 0.0625 1 1 +0.9375 0.0625 1 1 +0.9375 0.0625 1 1 +0.9375 0.0625 1 1 +0.9375 0.0625 1 1 +0.9375 0.0625 1 1 +0.9375 0.0625 1 1 +0.0625 0.1875 1 1 +0.0625 0.1875 1 1 +0.0625 0.1875 1 1 +0.0625 0.1875 1 1 +0.0625 0.1875 1 1 +0.0625 0.1875 1 1 +0.1875 0.1875 1 1 +0.1875 0.1875 1 1 +0.1875 0.1875 1 1 +0.1875 0.1875 1 1 +0.1875 0.1875 1 1 +0.1875 0.1875 1 1 +0.3125 0.1875 1 1 +0.3125 0.1875 1 1 +0.3125 0.1875 1 1 +0.3125 0.1875 1 1 +0.3125 0.1875 1 1 +0.3125 0.1875 1 1 +0.4375 0.1875 1 1 +0.4375 0.1875 1 1 +0.4375 0.1875 1 1 +0.4375 0.1875 1 1 +0.4375 0.1875 1 1 +0.4375 0.1875 1 1 +0.5625 0.1875 1 1 +0.5625 0.1875 1 1 +0.5625 0.1875 1 1 +0.5625 0.1875 1 1 +0.5625 0.1875 1 1 +0.5625 0.1875 1 1 +0.6875 0.1875 1 1 +0.6875 0.1875 1 1 +0.6875 0.1875 1 1 +0.6875 0.1875 1 1 +0.6875 0.1875 1 1 +0.6875 0.1875 1 1 +0.8125 0.1875 1 1 +0.8125 0.1875 1 1 +0.8125 0.1875 1 1 +0.8125 0.1875 1 1 +0.8125 0.1875 1 1 +0.8125 0.1875 1 1 +0.9375 0.1875 1 1 +0.9375 0.1875 1 1 +0.9375 0.1875 1 1 +0.9375 0.1875 1 1 +0.9375 0.1875 1 1 +0.9375 0.1875 1 1 +0.0625 0.3125 1 1 +0.0625 0.3125 1 1 +0.0625 0.3125 1 1 +0.0625 0.3125 1 1 +0.0625 0.3125 1 1 +0.0625 0.3125 1 1 +0.1875 0.3125 1 1 +0.1875 0.3125 1 1 +0.1875 0.3125 1 1 +0.1875 0.3125 1 1 +0.1875 0.3125 1 1 +0.1875 0.3125 1 1 +0.3125 0.3125 1 1 +0.3125 0.3125 1 1 +0.3125 0.3125 1 1 +0.3125 0.3125 1 1 +0.3125 0.3125 1 1 +0.3125 0.3125 1 1 +0.4375 0.3125 1 1 +0.4375 0.3125 1 1 +0.4375 0.3125 1 1 +0.4375 0.3125 1 1 +0.4375 0.3125 1 1 +0.4375 0.3125 1 1 +0.5625 0.3125 1 1 +0.5625 0.3125 1 1 +0.5625 0.3125 1 1 +0.5625 0.3125 1 1 +0.5625 0.3125 1 1 +0.5625 
0.3125 1 1 +0.6875 0.3125 1 1 +0.6875 0.3125 1 1 +0.6875 0.3125 1 1 +0.6875 0.3125 1 1 +0.6875 0.3125 1 1 +0.6875 0.3125 1 1 +0.8125 0.3125 1 1 +0.8125 0.3125 1 1 +0.8125 0.3125 1 1 +0.8125 0.3125 1 1 +0.8125 0.3125 1 1 +0.8125 0.3125 1 1 +0.9375 0.3125 1 1 +0.9375 0.3125 1 1 +0.9375 0.3125 1 1 +0.9375 0.3125 1 1 +0.9375 0.3125 1 1 +0.9375 0.3125 1 1 +0.0625 0.4375 1 1 +0.0625 0.4375 1 1 +0.0625 0.4375 1 1 +0.0625 0.4375 1 1 +0.0625 0.4375 1 1 +0.0625 0.4375 1 1 +0.1875 0.4375 1 1 +0.1875 0.4375 1 1 +0.1875 0.4375 1 1 +0.1875 0.4375 1 1 +0.1875 0.4375 1 1 +0.1875 0.4375 1 1 +0.3125 0.4375 1 1 +0.3125 0.4375 1 1 +0.3125 0.4375 1 1 +0.3125 0.4375 1 1 +0.3125 0.4375 1 1 +0.3125 0.4375 1 1 +0.4375 0.4375 1 1 +0.4375 0.4375 1 1 +0.4375 0.4375 1 1 +0.4375 0.4375 1 1 +0.4375 0.4375 1 1 +0.4375 0.4375 1 1 +0.5625 0.4375 1 1 +0.5625 0.4375 1 1 +0.5625 0.4375 1 1 +0.5625 0.4375 1 1 +0.5625 0.4375 1 1 +0.5625 0.4375 1 1 +0.6875 0.4375 1 1 +0.6875 0.4375 1 1 +0.6875 0.4375 1 1 +0.6875 0.4375 1 1 +0.6875 0.4375 1 1 +0.6875 0.4375 1 1 +0.8125 0.4375 1 1 +0.8125 0.4375 1 1 +0.8125 0.4375 1 1 +0.8125 0.4375 1 1 +0.8125 0.4375 1 1 +0.8125 0.4375 1 1 +0.9375 0.4375 1 1 +0.9375 0.4375 1 1 +0.9375 0.4375 1 1 +0.9375 0.4375 1 1 +0.9375 0.4375 1 1 +0.9375 0.4375 1 1 +0.0625 0.5625 1 1 +0.0625 0.5625 1 1 +0.0625 0.5625 1 1 +0.0625 0.5625 1 1 +0.0625 0.5625 1 1 +0.0625 0.5625 1 1 +0.1875 0.5625 1 1 +0.1875 0.5625 1 1 +0.1875 0.5625 1 1 +0.1875 0.5625 1 1 +0.1875 0.5625 1 1 +0.1875 0.5625 1 1 +0.3125 0.5625 1 1 +0.3125 0.5625 1 1 +0.3125 0.5625 1 1 +0.3125 0.5625 1 1 +0.3125 0.5625 1 1 +0.3125 0.5625 1 1 +0.4375 0.5625 1 1 +0.4375 0.5625 1 1 +0.4375 0.5625 1 1 +0.4375 0.5625 1 1 +0.4375 0.5625 1 1 +0.4375 0.5625 1 1 +0.5625 0.5625 1 1 +0.5625 0.5625 1 1 +0.5625 0.5625 1 1 +0.5625 0.5625 1 1 +0.5625 0.5625 1 1 +0.5625 0.5625 1 1 +0.6875 0.5625 1 1 +0.6875 0.5625 1 1 +0.6875 0.5625 1 1 +0.6875 0.5625 1 1 +0.6875 0.5625 1 1 +0.6875 0.5625 1 1 +0.8125 0.5625 1 1 +0.8125 0.5625 1 1 +0.8125 
0.5625 1 1 +0.8125 0.5625 1 1 +0.8125 0.5625 1 1 +0.8125 0.5625 1 1 +0.9375 0.5625 1 1 +0.9375 0.5625 1 1 +0.9375 0.5625 1 1 +0.9375 0.5625 1 1 +0.9375 0.5625 1 1 +0.9375 0.5625 1 1 +0.0625 0.6875 1 1 +0.0625 0.6875 1 1 +0.0625 0.6875 1 1 +0.0625 0.6875 1 1 +0.0625 0.6875 1 1 +0.0625 0.6875 1 1 +0.1875 0.6875 1 1 +0.1875 0.6875 1 1 +0.1875 0.6875 1 1 +0.1875 0.6875 1 1 +0.1875 0.6875 1 1 +0.1875 0.6875 1 1 +0.3125 0.6875 1 1 +0.3125 0.6875 1 1 +0.3125 0.6875 1 1 +0.3125 0.6875 1 1 +0.3125 0.6875 1 1 +0.3125 0.6875 1 1 +0.4375 0.6875 1 1 +0.4375 0.6875 1 1 +0.4375 0.6875 1 1 +0.4375 0.6875 1 1 +0.4375 0.6875 1 1 +0.4375 0.6875 1 1 +0.5625 0.6875 1 1 +0.5625 0.6875 1 1 +0.5625 0.6875 1 1 +0.5625 0.6875 1 1 +0.5625 0.6875 1 1 +0.5625 0.6875 1 1 +0.6875 0.6875 1 1 +0.6875 0.6875 1 1 +0.6875 0.6875 1 1 +0.6875 0.6875 1 1 +0.6875 0.6875 1 1 +0.6875 0.6875 1 1 +0.8125 0.6875 1 1 +0.8125 0.6875 1 1 +0.8125 0.6875 1 1 +0.8125 0.6875 1 1 +0.8125 0.6875 1 1 +0.8125 0.6875 1 1 +0.9375 0.6875 1 1 +0.9375 0.6875 1 1 +0.9375 0.6875 1 1 +0.9375 0.6875 1 1 +0.9375 0.6875 1 1 +0.9375 0.6875 1 1 +0.0625 0.8125 1 1 +0.0625 0.8125 1 1 +0.0625 0.8125 1 1 +0.0625 0.8125 1 1 +0.0625 0.8125 1 1 +0.0625 0.8125 1 1 +0.1875 0.8125 1 1 +0.1875 0.8125 1 1 +0.1875 0.8125 1 1 +0.1875 0.8125 1 1 +0.1875 0.8125 1 1 +0.1875 0.8125 1 1 +0.3125 0.8125 1 1 +0.3125 0.8125 1 1 +0.3125 0.8125 1 1 +0.3125 0.8125 1 1 +0.3125 0.8125 1 1 +0.3125 0.8125 1 1 +0.4375 0.8125 1 1 +0.4375 0.8125 1 1 +0.4375 0.8125 1 1 +0.4375 0.8125 1 1 +0.4375 0.8125 1 1 +0.4375 0.8125 1 1 +0.5625 0.8125 1 1 +0.5625 0.8125 1 1 +0.5625 0.8125 1 1 +0.5625 0.8125 1 1 +0.5625 0.8125 1 1 +0.5625 0.8125 1 1 +0.6875 0.8125 1 1 +0.6875 0.8125 1 1 +0.6875 0.8125 1 1 +0.6875 0.8125 1 1 +0.6875 0.8125 1 1 +0.6875 0.8125 1 1 +0.8125 0.8125 1 1 +0.8125 0.8125 1 1 +0.8125 0.8125 1 1 +0.8125 0.8125 1 1 +0.8125 0.8125 1 1 +0.8125 0.8125 1 1 +0.9375 0.8125 1 1 +0.9375 0.8125 1 1 +0.9375 0.8125 1 1 +0.9375 0.8125 1 1 +0.9375 0.8125 1 1 +0.9375 
0.8125 1 1 +0.0625 0.9375 1 1 +0.0625 0.9375 1 1 +0.0625 0.9375 1 1 +0.0625 0.9375 1 1 +0.0625 0.9375 1 1 +0.0625 0.9375 1 1 +0.1875 0.9375 1 1 +0.1875 0.9375 1 1 +0.1875 0.9375 1 1 +0.1875 0.9375 1 1 +0.1875 0.9375 1 1 +0.1875 0.9375 1 1 +0.3125 0.9375 1 1 +0.3125 0.9375 1 1 +0.3125 0.9375 1 1 +0.3125 0.9375 1 1 +0.3125 0.9375 1 1 +0.3125 0.9375 1 1 +0.4375 0.9375 1 1 +0.4375 0.9375 1 1 +0.4375 0.9375 1 1 +0.4375 0.9375 1 1 +0.4375 0.9375 1 1 +0.4375 0.9375 1 1 +0.5625 0.9375 1 1 +0.5625 0.9375 1 1 +0.5625 0.9375 1 1 +0.5625 0.9375 1 1 +0.5625 0.9375 1 1 +0.5625 0.9375 1 1 +0.6875 0.9375 1 1 +0.6875 0.9375 1 1 +0.6875 0.9375 1 1 +0.6875 0.9375 1 1 +0.6875 0.9375 1 1 +0.6875 0.9375 1 1 +0.8125 0.9375 1 1 +0.8125 0.9375 1 1 +0.8125 0.9375 1 1 +0.8125 0.9375 1 1 +0.8125 0.9375 1 1 +0.8125 0.9375 1 1 +0.9375 0.9375 1 1 +0.9375 0.9375 1 1 +0.9375 0.9375 1 1 +0.9375 0.9375 1 1 +0.9375 0.9375 1 1 +0.9375 0.9375 1 1 diff --git a/mediapipe/calculators/tflite/testdata/anchor_golden_file_1.txt b/mediapipe/calculators/tflite/testdata/anchor_golden_file_1.txt new file mode 100644 index 000000000..c894e0f8d --- /dev/null +++ b/mediapipe/calculators/tflite/testdata/anchor_golden_file_1.txt @@ -0,0 +1,1917 @@ +0.0263158 0.0263158 0.1 0.1 +0.0263158 0.0263158 0.282843 0.141421 +0.0263158 0.0263158 0.141421 0.282843 +0.0789474 0.0263158 0.1 0.1 +0.0789474 0.0263158 0.282843 0.141421 +0.0789474 0.0263158 0.141421 0.282843 +0.131579 0.0263158 0.1 0.1 +0.131579 0.0263158 0.282843 0.141421 +0.131579 0.0263158 0.141421 0.282843 +0.184211 0.0263158 0.1 0.1 +0.184211 0.0263158 0.282843 0.141421 +0.184211 0.0263158 0.141421 0.282843 +0.236842 0.0263158 0.1 0.1 +0.236842 0.0263158 0.282843 0.141421 +0.236842 0.0263158 0.141421 0.282843 +0.289474 0.0263158 0.1 0.1 +0.289474 0.0263158 0.282843 0.141421 +0.289474 0.0263158 0.141421 0.282843 +0.342105 0.0263158 0.1 0.1 +0.342105 0.0263158 0.282843 0.141421 +0.342105 0.0263158 0.141421 0.282843 +0.394737 0.0263158 0.1 0.1 +0.394737 0.0263158 
0.282843 0.141421 +0.394737 0.0263158 0.141421 0.282843 +0.447368 0.0263158 0.1 0.1 +0.447368 0.0263158 0.282843 0.141421 +0.447368 0.0263158 0.141421 0.282843 +0.5 0.0263158 0.1 0.1 +0.5 0.0263158 0.282843 0.141421 +0.5 0.0263158 0.141421 0.282843 +0.552632 0.0263158 0.1 0.1 +0.552632 0.0263158 0.282843 0.141421 +0.552632 0.0263158 0.141421 0.282843 +0.605263 0.0263158 0.1 0.1 +0.605263 0.0263158 0.282843 0.141421 +0.605263 0.0263158 0.141421 0.282843 +0.657895 0.0263158 0.1 0.1 +0.657895 0.0263158 0.282843 0.141421 +0.657895 0.0263158 0.141421 0.282843 +0.710526 0.0263158 0.1 0.1 +0.710526 0.0263158 0.282843 0.141421 +0.710526 0.0263158 0.141421 0.282843 +0.763158 0.0263158 0.1 0.1 +0.763158 0.0263158 0.282843 0.141421 +0.763158 0.0263158 0.141421 0.282843 +0.81579 0.0263158 0.1 0.1 +0.81579 0.0263158 0.282843 0.141421 +0.81579 0.0263158 0.141421 0.282843 +0.868421 0.0263158 0.1 0.1 +0.868421 0.0263158 0.282843 0.141421 +0.868421 0.0263158 0.141421 0.282843 +0.921053 0.0263158 0.1 0.1 +0.921053 0.0263158 0.282843 0.141421 +0.921053 0.0263158 0.141421 0.282843 +0.973684 0.0263158 0.1 0.1 +0.973684 0.0263158 0.282843 0.141421 +0.973684 0.0263158 0.141421 0.282843 +0.0263158 0.0789474 0.1 0.1 +0.0263158 0.0789474 0.282843 0.141421 +0.0263158 0.0789474 0.141421 0.282843 +0.0789474 0.0789474 0.1 0.1 +0.0789474 0.0789474 0.282843 0.141421 +0.0789474 0.0789474 0.141421 0.282843 +0.131579 0.0789474 0.1 0.1 +0.131579 0.0789474 0.282843 0.141421 +0.131579 0.0789474 0.141421 0.282843 +0.184211 0.0789474 0.1 0.1 +0.184211 0.0789474 0.282843 0.141421 +0.184211 0.0789474 0.141421 0.282843 +0.236842 0.0789474 0.1 0.1 +0.236842 0.0789474 0.282843 0.141421 +0.236842 0.0789474 0.141421 0.282843 +0.289474 0.0789474 0.1 0.1 +0.289474 0.0789474 0.282843 0.141421 +0.289474 0.0789474 0.141421 0.282843 +0.342105 0.0789474 0.1 0.1 +0.342105 0.0789474 0.282843 0.141421 +0.342105 0.0789474 0.141421 0.282843 +0.394737 0.0789474 0.1 0.1 +0.394737 0.0789474 0.282843 0.141421 +0.394737 
0.0789474 0.141421 0.282843 +0.447368 0.0789474 0.1 0.1 +0.447368 0.0789474 0.282843 0.141421 +0.447368 0.0789474 0.141421 0.282843 +0.5 0.0789474 0.1 0.1 +0.5 0.0789474 0.282843 0.141421 +0.5 0.0789474 0.141421 0.282843 +0.552632 0.0789474 0.1 0.1 +0.552632 0.0789474 0.282843 0.141421 +0.552632 0.0789474 0.141421 0.282843 +0.605263 0.0789474 0.1 0.1 +0.605263 0.0789474 0.282843 0.141421 +0.605263 0.0789474 0.141421 0.282843 +0.657895 0.0789474 0.1 0.1 +0.657895 0.0789474 0.282843 0.141421 +0.657895 0.0789474 0.141421 0.282843 +0.710526 0.0789474 0.1 0.1 +0.710526 0.0789474 0.282843 0.141421 +0.710526 0.0789474 0.141421 0.282843 +0.763158 0.0789474 0.1 0.1 +0.763158 0.0789474 0.282843 0.141421 +0.763158 0.0789474 0.141421 0.282843 +0.81579 0.0789474 0.1 0.1 +0.81579 0.0789474 0.282843 0.141421 +0.81579 0.0789474 0.141421 0.282843 +0.868421 0.0789474 0.1 0.1 +0.868421 0.0789474 0.282843 0.141421 +0.868421 0.0789474 0.141421 0.282843 +0.921053 0.0789474 0.1 0.1 +0.921053 0.0789474 0.282843 0.141421 +0.921053 0.0789474 0.141421 0.282843 +0.973684 0.0789474 0.1 0.1 +0.973684 0.0789474 0.282843 0.141421 +0.973684 0.0789474 0.141421 0.282843 +0.0263158 0.131579 0.1 0.1 +0.0263158 0.131579 0.282843 0.141421 +0.0263158 0.131579 0.141421 0.282843 +0.0789474 0.131579 0.1 0.1 +0.0789474 0.131579 0.282843 0.141421 +0.0789474 0.131579 0.141421 0.282843 +0.131579 0.131579 0.1 0.1 +0.131579 0.131579 0.282843 0.141421 +0.131579 0.131579 0.141421 0.282843 +0.184211 0.131579 0.1 0.1 +0.184211 0.131579 0.282843 0.141421 +0.184211 0.131579 0.141421 0.282843 +0.236842 0.131579 0.1 0.1 +0.236842 0.131579 0.282843 0.141421 +0.236842 0.131579 0.141421 0.282843 +0.289474 0.131579 0.1 0.1 +0.289474 0.131579 0.282843 0.141421 +0.289474 0.131579 0.141421 0.282843 +0.342105 0.131579 0.1 0.1 +0.342105 0.131579 0.282843 0.141421 +0.342105 0.131579 0.141421 0.282843 +0.394737 0.131579 0.1 0.1 +0.394737 0.131579 0.282843 0.141421 +0.394737 0.131579 0.141421 0.282843 +0.447368 0.131579 0.1 0.1 
+0.447368 0.131579 0.282843 0.141421 +0.447368 0.131579 0.141421 0.282843 +0.5 0.131579 0.1 0.1 +0.5 0.131579 0.282843 0.141421 +0.5 0.131579 0.141421 0.282843 +0.552632 0.131579 0.1 0.1 +0.552632 0.131579 0.282843 0.141421 +0.552632 0.131579 0.141421 0.282843 +0.605263 0.131579 0.1 0.1 +0.605263 0.131579 0.282843 0.141421 +0.605263 0.131579 0.141421 0.282843 +0.657895 0.131579 0.1 0.1 +0.657895 0.131579 0.282843 0.141421 +0.657895 0.131579 0.141421 0.282843 +0.710526 0.131579 0.1 0.1 +0.710526 0.131579 0.282843 0.141421 +0.710526 0.131579 0.141421 0.282843 +0.763158 0.131579 0.1 0.1 +0.763158 0.131579 0.282843 0.141421 +0.763158 0.131579 0.141421 0.282843 +0.81579 0.131579 0.1 0.1 +0.81579 0.131579 0.282843 0.141421 +0.81579 0.131579 0.141421 0.282843 +0.868421 0.131579 0.1 0.1 +0.868421 0.131579 0.282843 0.141421 +0.868421 0.131579 0.141421 0.282843 +0.921053 0.131579 0.1 0.1 +0.921053 0.131579 0.282843 0.141421 +0.921053 0.131579 0.141421 0.282843 +0.973684 0.131579 0.1 0.1 +0.973684 0.131579 0.282843 0.141421 +0.973684 0.131579 0.141421 0.282843 +0.0263158 0.184211 0.1 0.1 +0.0263158 0.184211 0.282843 0.141421 +0.0263158 0.184211 0.141421 0.282843 +0.0789474 0.184211 0.1 0.1 +0.0789474 0.184211 0.282843 0.141421 +0.0789474 0.184211 0.141421 0.282843 +0.131579 0.184211 0.1 0.1 +0.131579 0.184211 0.282843 0.141421 +0.131579 0.184211 0.141421 0.282843 +0.184211 0.184211 0.1 0.1 +0.184211 0.184211 0.282843 0.141421 +0.184211 0.184211 0.141421 0.282843 +0.236842 0.184211 0.1 0.1 +0.236842 0.184211 0.282843 0.141421 +0.236842 0.184211 0.141421 0.282843 +0.289474 0.184211 0.1 0.1 +0.289474 0.184211 0.282843 0.141421 +0.289474 0.184211 0.141421 0.282843 +0.342105 0.184211 0.1 0.1 +0.342105 0.184211 0.282843 0.141421 +0.342105 0.184211 0.141421 0.282843 +0.394737 0.184211 0.1 0.1 +0.394737 0.184211 0.282843 0.141421 +0.394737 0.184211 0.141421 0.282843 +0.447368 0.184211 0.1 0.1 +0.447368 0.184211 0.282843 0.141421 +0.447368 0.184211 0.141421 0.282843 +0.5 0.184211 0.1 
0.1 +0.5 0.184211 0.282843 0.141421 +0.5 0.184211 0.141421 0.282843 +0.552632 0.184211 0.1 0.1 +0.552632 0.184211 0.282843 0.141421 +0.552632 0.184211 0.141421 0.282843 +0.605263 0.184211 0.1 0.1 +0.605263 0.184211 0.282843 0.141421 +0.605263 0.184211 0.141421 0.282843 +0.657895 0.184211 0.1 0.1 +0.657895 0.184211 0.282843 0.141421 +0.657895 0.184211 0.141421 0.282843 +0.710526 0.184211 0.1 0.1 +0.710526 0.184211 0.282843 0.141421 +0.710526 0.184211 0.141421 0.282843 +0.763158 0.184211 0.1 0.1 +0.763158 0.184211 0.282843 0.141421 +0.763158 0.184211 0.141421 0.282843 +0.81579 0.184211 0.1 0.1 +0.81579 0.184211 0.282843 0.141421 +0.81579 0.184211 0.141421 0.282843 +0.868421 0.184211 0.1 0.1 +0.868421 0.184211 0.282843 0.141421 +0.868421 0.184211 0.141421 0.282843 +0.921053 0.184211 0.1 0.1 +0.921053 0.184211 0.282843 0.141421 +0.921053 0.184211 0.141421 0.282843 +0.973684 0.184211 0.1 0.1 +0.973684 0.184211 0.282843 0.141421 +0.973684 0.184211 0.141421 0.282843 +0.0263158 0.236842 0.1 0.1 +0.0263158 0.236842 0.282843 0.141421 +0.0263158 0.236842 0.141421 0.282843 +0.0789474 0.236842 0.1 0.1 +0.0789474 0.236842 0.282843 0.141421 +0.0789474 0.236842 0.141421 0.282843 +0.131579 0.236842 0.1 0.1 +0.131579 0.236842 0.282843 0.141421 +0.131579 0.236842 0.141421 0.282843 +0.184211 0.236842 0.1 0.1 +0.184211 0.236842 0.282843 0.141421 +0.184211 0.236842 0.141421 0.282843 +0.236842 0.236842 0.1 0.1 +0.236842 0.236842 0.282843 0.141421 +0.236842 0.236842 0.141421 0.282843 +0.289474 0.236842 0.1 0.1 +0.289474 0.236842 0.282843 0.141421 +0.289474 0.236842 0.141421 0.282843 +0.342105 0.236842 0.1 0.1 +0.342105 0.236842 0.282843 0.141421 +0.342105 0.236842 0.141421 0.282843 +0.394737 0.236842 0.1 0.1 +0.394737 0.236842 0.282843 0.141421 +0.394737 0.236842 0.141421 0.282843 +0.447368 0.236842 0.1 0.1 +0.447368 0.236842 0.282843 0.141421 +0.447368 0.236842 0.141421 0.282843 +0.5 0.236842 0.1 0.1 +0.5 0.236842 0.282843 0.141421 +0.5 0.236842 0.141421 0.282843 +0.552632 0.236842 0.1 
0.1 +0.552632 0.236842 0.282843 0.141421 +0.552632 0.236842 0.141421 0.282843 +0.605263 0.236842 0.1 0.1 +0.605263 0.236842 0.282843 0.141421 +0.605263 0.236842 0.141421 0.282843 +0.657895 0.236842 0.1 0.1 +0.657895 0.236842 0.282843 0.141421 +0.657895 0.236842 0.141421 0.282843 +0.710526 0.236842 0.1 0.1 +0.710526 0.236842 0.282843 0.141421 +0.710526 0.236842 0.141421 0.282843 +0.763158 0.236842 0.1 0.1 +0.763158 0.236842 0.282843 0.141421 +0.763158 0.236842 0.141421 0.282843 +0.81579 0.236842 0.1 0.1 +0.81579 0.236842 0.282843 0.141421 +0.81579 0.236842 0.141421 0.282843 +0.868421 0.236842 0.1 0.1 +0.868421 0.236842 0.282843 0.141421 +0.868421 0.236842 0.141421 0.282843 +0.921053 0.236842 0.1 0.1 +0.921053 0.236842 0.282843 0.141421 +0.921053 0.236842 0.141421 0.282843 +0.973684 0.236842 0.1 0.1 +0.973684 0.236842 0.282843 0.141421 +0.973684 0.236842 0.141421 0.282843 +0.0263158 0.289474 0.1 0.1 +0.0263158 0.289474 0.282843 0.141421 +0.0263158 0.289474 0.141421 0.282843 +0.0789474 0.289474 0.1 0.1 +0.0789474 0.289474 0.282843 0.141421 +0.0789474 0.289474 0.141421 0.282843 +0.131579 0.289474 0.1 0.1 +0.131579 0.289474 0.282843 0.141421 +0.131579 0.289474 0.141421 0.282843 +0.184211 0.289474 0.1 0.1 +0.184211 0.289474 0.282843 0.141421 +0.184211 0.289474 0.141421 0.282843 +0.236842 0.289474 0.1 0.1 +0.236842 0.289474 0.282843 0.141421 +0.236842 0.289474 0.141421 0.282843 +0.289474 0.289474 0.1 0.1 +0.289474 0.289474 0.282843 0.141421 +0.289474 0.289474 0.141421 0.282843 +0.342105 0.289474 0.1 0.1 +0.342105 0.289474 0.282843 0.141421 +0.342105 0.289474 0.141421 0.282843 +0.394737 0.289474 0.1 0.1 +0.394737 0.289474 0.282843 0.141421 +0.394737 0.289474 0.141421 0.282843 +0.447368 0.289474 0.1 0.1 +0.447368 0.289474 0.282843 0.141421 +0.447368 0.289474 0.141421 0.282843 +0.5 0.289474 0.1 0.1 +0.5 0.289474 0.282843 0.141421 +0.5 0.289474 0.141421 0.282843 +0.552632 0.289474 0.1 0.1 +0.552632 0.289474 0.282843 0.141421 +0.552632 0.289474 0.141421 0.282843 +0.605263 
0.289474 0.1 0.1 +0.605263 0.289474 0.282843 0.141421 +0.605263 0.289474 0.141421 0.282843 +0.657895 0.289474 0.1 0.1 +0.657895 0.289474 0.282843 0.141421 +0.657895 0.289474 0.141421 0.282843 +0.710526 0.289474 0.1 0.1 +0.710526 0.289474 0.282843 0.141421 +0.710526 0.289474 0.141421 0.282843 +0.763158 0.289474 0.1 0.1 +0.763158 0.289474 0.282843 0.141421 +0.763158 0.289474 0.141421 0.282843 +0.81579 0.289474 0.1 0.1 +0.81579 0.289474 0.282843 0.141421 +0.81579 0.289474 0.141421 0.282843 +0.868421 0.289474 0.1 0.1 +0.868421 0.289474 0.282843 0.141421 +0.868421 0.289474 0.141421 0.282843 +0.921053 0.289474 0.1 0.1 +0.921053 0.289474 0.282843 0.141421 +0.921053 0.289474 0.141421 0.282843 +0.973684 0.289474 0.1 0.1 +0.973684 0.289474 0.282843 0.141421 +0.973684 0.289474 0.141421 0.282843 +0.0263158 0.342105 0.1 0.1 +0.0263158 0.342105 0.282843 0.141421 +0.0263158 0.342105 0.141421 0.282843 +0.0789474 0.342105 0.1 0.1 +0.0789474 0.342105 0.282843 0.141421 +0.0789474 0.342105 0.141421 0.282843 +0.131579 0.342105 0.1 0.1 +0.131579 0.342105 0.282843 0.141421 +0.131579 0.342105 0.141421 0.282843 +0.184211 0.342105 0.1 0.1 +0.184211 0.342105 0.282843 0.141421 +0.184211 0.342105 0.141421 0.282843 +0.236842 0.342105 0.1 0.1 +0.236842 0.342105 0.282843 0.141421 +0.236842 0.342105 0.141421 0.282843 +0.289474 0.342105 0.1 0.1 +0.289474 0.342105 0.282843 0.141421 +0.289474 0.342105 0.141421 0.282843 +0.342105 0.342105 0.1 0.1 +0.342105 0.342105 0.282843 0.141421 +0.342105 0.342105 0.141421 0.282843 +0.394737 0.342105 0.1 0.1 +0.394737 0.342105 0.282843 0.141421 +0.394737 0.342105 0.141421 0.282843 +0.447368 0.342105 0.1 0.1 +0.447368 0.342105 0.282843 0.141421 +0.447368 0.342105 0.141421 0.282843 +0.5 0.342105 0.1 0.1 +0.5 0.342105 0.282843 0.141421 +0.5 0.342105 0.141421 0.282843 +0.552632 0.342105 0.1 0.1 +0.552632 0.342105 0.282843 0.141421 +0.552632 0.342105 0.141421 0.282843 +0.605263 0.342105 0.1 0.1 +0.605263 0.342105 0.282843 0.141421 +0.605263 0.342105 0.141421 0.282843 
+0.657895 0.342105 0.1 0.1 +0.657895 0.342105 0.282843 0.141421 +0.657895 0.342105 0.141421 0.282843 +0.710526 0.342105 0.1 0.1 +0.710526 0.342105 0.282843 0.141421 +0.710526 0.342105 0.141421 0.282843 +0.763158 0.342105 0.1 0.1 +0.763158 0.342105 0.282843 0.141421 +0.763158 0.342105 0.141421 0.282843 +0.81579 0.342105 0.1 0.1 +0.81579 0.342105 0.282843 0.141421 +0.81579 0.342105 0.141421 0.282843 +0.868421 0.342105 0.1 0.1 +0.868421 0.342105 0.282843 0.141421 +0.868421 0.342105 0.141421 0.282843 +0.921053 0.342105 0.1 0.1 +0.921053 0.342105 0.282843 0.141421 +0.921053 0.342105 0.141421 0.282843 +0.973684 0.342105 0.1 0.1 +0.973684 0.342105 0.282843 0.141421 +0.973684 0.342105 0.141421 0.282843 +0.0263158 0.394737 0.1 0.1 +0.0263158 0.394737 0.282843 0.141421 +0.0263158 0.394737 0.141421 0.282843 +0.0789474 0.394737 0.1 0.1 +0.0789474 0.394737 0.282843 0.141421 +0.0789474 0.394737 0.141421 0.282843 +0.131579 0.394737 0.1 0.1 +0.131579 0.394737 0.282843 0.141421 +0.131579 0.394737 0.141421 0.282843 +0.184211 0.394737 0.1 0.1 +0.184211 0.394737 0.282843 0.141421 +0.184211 0.394737 0.141421 0.282843 +0.236842 0.394737 0.1 0.1 +0.236842 0.394737 0.282843 0.141421 +0.236842 0.394737 0.141421 0.282843 +0.289474 0.394737 0.1 0.1 +0.289474 0.394737 0.282843 0.141421 +0.289474 0.394737 0.141421 0.282843 +0.342105 0.394737 0.1 0.1 +0.342105 0.394737 0.282843 0.141421 +0.342105 0.394737 0.141421 0.282843 +0.394737 0.394737 0.1 0.1 +0.394737 0.394737 0.282843 0.141421 +0.394737 0.394737 0.141421 0.282843 +0.447368 0.394737 0.1 0.1 +0.447368 0.394737 0.282843 0.141421 +0.447368 0.394737 0.141421 0.282843 +0.5 0.394737 0.1 0.1 +0.5 0.394737 0.282843 0.141421 +0.5 0.394737 0.141421 0.282843 +0.552632 0.394737 0.1 0.1 +0.552632 0.394737 0.282843 0.141421 +0.552632 0.394737 0.141421 0.282843 +0.605263 0.394737 0.1 0.1 +0.605263 0.394737 0.282843 0.141421 +0.605263 0.394737 0.141421 0.282843 +0.657895 0.394737 0.1 0.1 +0.657895 0.394737 0.282843 0.141421 +0.657895 0.394737 0.141421 
0.282843 +0.710526 0.394737 0.1 0.1 +0.710526 0.394737 0.282843 0.141421 +0.710526 0.394737 0.141421 0.282843 +0.763158 0.394737 0.1 0.1 +0.763158 0.394737 0.282843 0.141421 +0.763158 0.394737 0.141421 0.282843 +0.81579 0.394737 0.1 0.1 +0.81579 0.394737 0.282843 0.141421 +0.81579 0.394737 0.141421 0.282843 +0.868421 0.394737 0.1 0.1 +0.868421 0.394737 0.282843 0.141421 +0.868421 0.394737 0.141421 0.282843 +0.921053 0.394737 0.1 0.1 +0.921053 0.394737 0.282843 0.141421 +0.921053 0.394737 0.141421 0.282843 +0.973684 0.394737 0.1 0.1 +0.973684 0.394737 0.282843 0.141421 +0.973684 0.394737 0.141421 0.282843 +0.0263158 0.447368 0.1 0.1 +0.0263158 0.447368 0.282843 0.141421 +0.0263158 0.447368 0.141421 0.282843 +0.0789474 0.447368 0.1 0.1 +0.0789474 0.447368 0.282843 0.141421 +0.0789474 0.447368 0.141421 0.282843 +0.131579 0.447368 0.1 0.1 +0.131579 0.447368 0.282843 0.141421 +0.131579 0.447368 0.141421 0.282843 +0.184211 0.447368 0.1 0.1 +0.184211 0.447368 0.282843 0.141421 +0.184211 0.447368 0.141421 0.282843 +0.236842 0.447368 0.1 0.1 +0.236842 0.447368 0.282843 0.141421 +0.236842 0.447368 0.141421 0.282843 +0.289474 0.447368 0.1 0.1 +0.289474 0.447368 0.282843 0.141421 +0.289474 0.447368 0.141421 0.282843 +0.342105 0.447368 0.1 0.1 +0.342105 0.447368 0.282843 0.141421 +0.342105 0.447368 0.141421 0.282843 +0.394737 0.447368 0.1 0.1 +0.394737 0.447368 0.282843 0.141421 +0.394737 0.447368 0.141421 0.282843 +0.447368 0.447368 0.1 0.1 +0.447368 0.447368 0.282843 0.141421 +0.447368 0.447368 0.141421 0.282843 +0.5 0.447368 0.1 0.1 +0.5 0.447368 0.282843 0.141421 +0.5 0.447368 0.141421 0.282843 +0.552632 0.447368 0.1 0.1 +0.552632 0.447368 0.282843 0.141421 +0.552632 0.447368 0.141421 0.282843 +0.605263 0.447368 0.1 0.1 +0.605263 0.447368 0.282843 0.141421 +0.605263 0.447368 0.141421 0.282843 +0.657895 0.447368 0.1 0.1 +0.657895 0.447368 0.282843 0.141421 +0.657895 0.447368 0.141421 0.282843 +0.710526 0.447368 0.1 0.1 +0.710526 0.447368 0.282843 0.141421 +0.710526 0.447368 
0.141421 0.282843 +0.763158 0.447368 0.1 0.1 +0.763158 0.447368 0.282843 0.141421 +0.763158 0.447368 0.141421 0.282843 +0.81579 0.447368 0.1 0.1 +0.81579 0.447368 0.282843 0.141421 +0.81579 0.447368 0.141421 0.282843 +0.868421 0.447368 0.1 0.1 +0.868421 0.447368 0.282843 0.141421 +0.868421 0.447368 0.141421 0.282843 +0.921053 0.447368 0.1 0.1 +0.921053 0.447368 0.282843 0.141421 +0.921053 0.447368 0.141421 0.282843 +0.973684 0.447368 0.1 0.1 +0.973684 0.447368 0.282843 0.141421 +0.973684 0.447368 0.141421 0.282843 +0.0263158 0.5 0.1 0.1 +0.0263158 0.5 0.282843 0.141421 +0.0263158 0.5 0.141421 0.282843 +0.0789474 0.5 0.1 0.1 +0.0789474 0.5 0.282843 0.141421 +0.0789474 0.5 0.141421 0.282843 +0.131579 0.5 0.1 0.1 +0.131579 0.5 0.282843 0.141421 +0.131579 0.5 0.141421 0.282843 +0.184211 0.5 0.1 0.1 +0.184211 0.5 0.282843 0.141421 +0.184211 0.5 0.141421 0.282843 +0.236842 0.5 0.1 0.1 +0.236842 0.5 0.282843 0.141421 +0.236842 0.5 0.141421 0.282843 +0.289474 0.5 0.1 0.1 +0.289474 0.5 0.282843 0.141421 +0.289474 0.5 0.141421 0.282843 +0.342105 0.5 0.1 0.1 +0.342105 0.5 0.282843 0.141421 +0.342105 0.5 0.141421 0.282843 +0.394737 0.5 0.1 0.1 +0.394737 0.5 0.282843 0.141421 +0.394737 0.5 0.141421 0.282843 +0.447368 0.5 0.1 0.1 +0.447368 0.5 0.282843 0.141421 +0.447368 0.5 0.141421 0.282843 +0.5 0.5 0.1 0.1 +0.5 0.5 0.282843 0.141421 +0.5 0.5 0.141421 0.282843 +0.552632 0.5 0.1 0.1 +0.552632 0.5 0.282843 0.141421 +0.552632 0.5 0.141421 0.282843 +0.605263 0.5 0.1 0.1 +0.605263 0.5 0.282843 0.141421 +0.605263 0.5 0.141421 0.282843 +0.657895 0.5 0.1 0.1 +0.657895 0.5 0.282843 0.141421 +0.657895 0.5 0.141421 0.282843 +0.710526 0.5 0.1 0.1 +0.710526 0.5 0.282843 0.141421 +0.710526 0.5 0.141421 0.282843 +0.763158 0.5 0.1 0.1 +0.763158 0.5 0.282843 0.141421 +0.763158 0.5 0.141421 0.282843 +0.81579 0.5 0.1 0.1 +0.81579 0.5 0.282843 0.141421 +0.81579 0.5 0.141421 0.282843 +0.868421 0.5 0.1 0.1 +0.868421 0.5 0.282843 0.141421 +0.868421 0.5 0.141421 0.282843 +0.921053 0.5 0.1 0.1 
+0.921053 0.5 0.282843 0.141421 +0.921053 0.5 0.141421 0.282843 +0.973684 0.5 0.1 0.1 +0.973684 0.5 0.282843 0.141421 +0.973684 0.5 0.141421 0.282843 +0.0263158 0.552632 0.1 0.1 +0.0263158 0.552632 0.282843 0.141421 +0.0263158 0.552632 0.141421 0.282843 +0.0789474 0.552632 0.1 0.1 +0.0789474 0.552632 0.282843 0.141421 +0.0789474 0.552632 0.141421 0.282843 +0.131579 0.552632 0.1 0.1 +0.131579 0.552632 0.282843 0.141421 +0.131579 0.552632 0.141421 0.282843 +0.184211 0.552632 0.1 0.1 +0.184211 0.552632 0.282843 0.141421 +0.184211 0.552632 0.141421 0.282843 +0.236842 0.552632 0.1 0.1 +0.236842 0.552632 0.282843 0.141421 +0.236842 0.552632 0.141421 0.282843 +0.289474 0.552632 0.1 0.1 +0.289474 0.552632 0.282843 0.141421 +0.289474 0.552632 0.141421 0.282843 +0.342105 0.552632 0.1 0.1 +0.342105 0.552632 0.282843 0.141421 +0.342105 0.552632 0.141421 0.282843 +0.394737 0.552632 0.1 0.1 +0.394737 0.552632 0.282843 0.141421 +0.394737 0.552632 0.141421 0.282843 +0.447368 0.552632 0.1 0.1 +0.447368 0.552632 0.282843 0.141421 +0.447368 0.552632 0.141421 0.282843 +0.5 0.552632 0.1 0.1 +0.5 0.552632 0.282843 0.141421 +0.5 0.552632 0.141421 0.282843 +0.552632 0.552632 0.1 0.1 +0.552632 0.552632 0.282843 0.141421 +0.552632 0.552632 0.141421 0.282843 +0.605263 0.552632 0.1 0.1 +0.605263 0.552632 0.282843 0.141421 +0.605263 0.552632 0.141421 0.282843 +0.657895 0.552632 0.1 0.1 +0.657895 0.552632 0.282843 0.141421 +0.657895 0.552632 0.141421 0.282843 +0.710526 0.552632 0.1 0.1 +0.710526 0.552632 0.282843 0.141421 +0.710526 0.552632 0.141421 0.282843 +0.763158 0.552632 0.1 0.1 +0.763158 0.552632 0.282843 0.141421 +0.763158 0.552632 0.141421 0.282843 +0.81579 0.552632 0.1 0.1 +0.81579 0.552632 0.282843 0.141421 +0.81579 0.552632 0.141421 0.282843 +0.868421 0.552632 0.1 0.1 +0.868421 0.552632 0.282843 0.141421 +0.868421 0.552632 0.141421 0.282843 +0.921053 0.552632 0.1 0.1 +0.921053 0.552632 0.282843 0.141421 +0.921053 0.552632 0.141421 0.282843 +0.973684 0.552632 0.1 0.1 +0.973684 
0.552632 0.282843 0.141421 +0.973684 0.552632 0.141421 0.282843 +0.0263158 0.605263 0.1 0.1 +0.0263158 0.605263 0.282843 0.141421 +0.0263158 0.605263 0.141421 0.282843 +0.0789474 0.605263 0.1 0.1 +0.0789474 0.605263 0.282843 0.141421 +0.0789474 0.605263 0.141421 0.282843 +0.131579 0.605263 0.1 0.1 +0.131579 0.605263 0.282843 0.141421 +0.131579 0.605263 0.141421 0.282843 +0.184211 0.605263 0.1 0.1 +0.184211 0.605263 0.282843 0.141421 +0.184211 0.605263 0.141421 0.282843 +0.236842 0.605263 0.1 0.1 +0.236842 0.605263 0.282843 0.141421 +0.236842 0.605263 0.141421 0.282843 +0.289474 0.605263 0.1 0.1 +0.289474 0.605263 0.282843 0.141421 +0.289474 0.605263 0.141421 0.282843 +0.342105 0.605263 0.1 0.1 +0.342105 0.605263 0.282843 0.141421 +0.342105 0.605263 0.141421 0.282843 +0.394737 0.605263 0.1 0.1 +0.394737 0.605263 0.282843 0.141421 +0.394737 0.605263 0.141421 0.282843 +0.447368 0.605263 0.1 0.1 +0.447368 0.605263 0.282843 0.141421 +0.447368 0.605263 0.141421 0.282843 +0.5 0.605263 0.1 0.1 +0.5 0.605263 0.282843 0.141421 +0.5 0.605263 0.141421 0.282843 +0.552632 0.605263 0.1 0.1 +0.552632 0.605263 0.282843 0.141421 +0.552632 0.605263 0.141421 0.282843 +0.605263 0.605263 0.1 0.1 +0.605263 0.605263 0.282843 0.141421 +0.605263 0.605263 0.141421 0.282843 +0.657895 0.605263 0.1 0.1 +0.657895 0.605263 0.282843 0.141421 +0.657895 0.605263 0.141421 0.282843 +0.710526 0.605263 0.1 0.1 +0.710526 0.605263 0.282843 0.141421 +0.710526 0.605263 0.141421 0.282843 +0.763158 0.605263 0.1 0.1 +0.763158 0.605263 0.282843 0.141421 +0.763158 0.605263 0.141421 0.282843 +0.81579 0.605263 0.1 0.1 +0.81579 0.605263 0.282843 0.141421 +0.81579 0.605263 0.141421 0.282843 +0.868421 0.605263 0.1 0.1 +0.868421 0.605263 0.282843 0.141421 +0.868421 0.605263 0.141421 0.282843 +0.921053 0.605263 0.1 0.1 +0.921053 0.605263 0.282843 0.141421 +0.921053 0.605263 0.141421 0.282843 +0.973684 0.605263 0.1 0.1 +0.973684 0.605263 0.282843 0.141421 +0.973684 0.605263 0.141421 0.282843 +0.0263158 0.657895 0.1 0.1 
+0.0263158 0.657895 0.282843 0.141421 +0.0263158 0.657895 0.141421 0.282843 +0.0789474 0.657895 0.1 0.1 +0.0789474 0.657895 0.282843 0.141421 +0.0789474 0.657895 0.141421 0.282843 +0.131579 0.657895 0.1 0.1 +0.131579 0.657895 0.282843 0.141421 +0.131579 0.657895 0.141421 0.282843 +0.184211 0.657895 0.1 0.1 +0.184211 0.657895 0.282843 0.141421 +0.184211 0.657895 0.141421 0.282843 +0.236842 0.657895 0.1 0.1 +0.236842 0.657895 0.282843 0.141421 +0.236842 0.657895 0.141421 0.282843 +0.289474 0.657895 0.1 0.1 +0.289474 0.657895 0.282843 0.141421 +0.289474 0.657895 0.141421 0.282843 +0.342105 0.657895 0.1 0.1 +0.342105 0.657895 0.282843 0.141421 +0.342105 0.657895 0.141421 0.282843 +0.394737 0.657895 0.1 0.1 +0.394737 0.657895 0.282843 0.141421 +0.394737 0.657895 0.141421 0.282843 +0.447368 0.657895 0.1 0.1 +0.447368 0.657895 0.282843 0.141421 +0.447368 0.657895 0.141421 0.282843 +0.5 0.657895 0.1 0.1 +0.5 0.657895 0.282843 0.141421 +0.5 0.657895 0.141421 0.282843 +0.552632 0.657895 0.1 0.1 +0.552632 0.657895 0.282843 0.141421 +0.552632 0.657895 0.141421 0.282843 +0.605263 0.657895 0.1 0.1 +0.605263 0.657895 0.282843 0.141421 +0.605263 0.657895 0.141421 0.282843 +0.657895 0.657895 0.1 0.1 +0.657895 0.657895 0.282843 0.141421 +0.657895 0.657895 0.141421 0.282843 +0.710526 0.657895 0.1 0.1 +0.710526 0.657895 0.282843 0.141421 +0.710526 0.657895 0.141421 0.282843 +0.763158 0.657895 0.1 0.1 +0.763158 0.657895 0.282843 0.141421 +0.763158 0.657895 0.141421 0.282843 +0.81579 0.657895 0.1 0.1 +0.81579 0.657895 0.282843 0.141421 +0.81579 0.657895 0.141421 0.282843 +0.868421 0.657895 0.1 0.1 +0.868421 0.657895 0.282843 0.141421 +0.868421 0.657895 0.141421 0.282843 +0.921053 0.657895 0.1 0.1 +0.921053 0.657895 0.282843 0.141421 +0.921053 0.657895 0.141421 0.282843 +0.973684 0.657895 0.1 0.1 +0.973684 0.657895 0.282843 0.141421 +0.973684 0.657895 0.141421 0.282843 +0.0263158 0.710526 0.1 0.1 +0.0263158 0.710526 0.282843 0.141421 +0.0263158 0.710526 0.141421 0.282843 +0.0789474 
0.710526 0.1 0.1 +0.0789474 0.710526 0.282843 0.141421 +0.0789474 0.710526 0.141421 0.282843 +0.131579 0.710526 0.1 0.1 +0.131579 0.710526 0.282843 0.141421 +0.131579 0.710526 0.141421 0.282843 +0.184211 0.710526 0.1 0.1 +0.184211 0.710526 0.282843 0.141421 +0.184211 0.710526 0.141421 0.282843 +0.236842 0.710526 0.1 0.1 +0.236842 0.710526 0.282843 0.141421 +0.236842 0.710526 0.141421 0.282843 +0.289474 0.710526 0.1 0.1 +0.289474 0.710526 0.282843 0.141421 +0.289474 0.710526 0.141421 0.282843 +0.342105 0.710526 0.1 0.1 +0.342105 0.710526 0.282843 0.141421 +0.342105 0.710526 0.141421 0.282843 +0.394737 0.710526 0.1 0.1 +0.394737 0.710526 0.282843 0.141421 +0.394737 0.710526 0.141421 0.282843 +0.447368 0.710526 0.1 0.1 +0.447368 0.710526 0.282843 0.141421 +0.447368 0.710526 0.141421 0.282843 +0.5 0.710526 0.1 0.1 +0.5 0.710526 0.282843 0.141421 +0.5 0.710526 0.141421 0.282843 +0.552632 0.710526 0.1 0.1 +0.552632 0.710526 0.282843 0.141421 +0.552632 0.710526 0.141421 0.282843 +0.605263 0.710526 0.1 0.1 +0.605263 0.710526 0.282843 0.141421 +0.605263 0.710526 0.141421 0.282843 +0.657895 0.710526 0.1 0.1 +0.657895 0.710526 0.282843 0.141421 +0.657895 0.710526 0.141421 0.282843 +0.710526 0.710526 0.1 0.1 +0.710526 0.710526 0.282843 0.141421 +0.710526 0.710526 0.141421 0.282843 +0.763158 0.710526 0.1 0.1 +0.763158 0.710526 0.282843 0.141421 +0.763158 0.710526 0.141421 0.282843 +0.81579 0.710526 0.1 0.1 +0.81579 0.710526 0.282843 0.141421 +0.81579 0.710526 0.141421 0.282843 +0.868421 0.710526 0.1 0.1 +0.868421 0.710526 0.282843 0.141421 +0.868421 0.710526 0.141421 0.282843 +0.921053 0.710526 0.1 0.1 +0.921053 0.710526 0.282843 0.141421 +0.921053 0.710526 0.141421 0.282843 +0.973684 0.710526 0.1 0.1 +0.973684 0.710526 0.282843 0.141421 +0.973684 0.710526 0.141421 0.282843 +0.0263158 0.763158 0.1 0.1 +0.0263158 0.763158 0.282843 0.141421 +0.0263158 0.763158 0.141421 0.282843 +0.0789474 0.763158 0.1 0.1 +0.0789474 0.763158 0.282843 0.141421 +0.0789474 0.763158 0.141421 0.282843 
+0.131579 0.763158 0.1 0.1 +0.131579 0.763158 0.282843 0.141421 +0.131579 0.763158 0.141421 0.282843 +0.184211 0.763158 0.1 0.1 +0.184211 0.763158 0.282843 0.141421 +0.184211 0.763158 0.141421 0.282843 +0.236842 0.763158 0.1 0.1 +0.236842 0.763158 0.282843 0.141421 +0.236842 0.763158 0.141421 0.282843 +0.289474 0.763158 0.1 0.1 +0.289474 0.763158 0.282843 0.141421 +0.289474 0.763158 0.141421 0.282843 +0.342105 0.763158 0.1 0.1 +0.342105 0.763158 0.282843 0.141421 +0.342105 0.763158 0.141421 0.282843 +0.394737 0.763158 0.1 0.1 +0.394737 0.763158 0.282843 0.141421 +0.394737 0.763158 0.141421 0.282843 +0.447368 0.763158 0.1 0.1 +0.447368 0.763158 0.282843 0.141421 +0.447368 0.763158 0.141421 0.282843 +0.5 0.763158 0.1 0.1 +0.5 0.763158 0.282843 0.141421 +0.5 0.763158 0.141421 0.282843 +0.552632 0.763158 0.1 0.1 +0.552632 0.763158 0.282843 0.141421 +0.552632 0.763158 0.141421 0.282843 +0.605263 0.763158 0.1 0.1 +0.605263 0.763158 0.282843 0.141421 +0.605263 0.763158 0.141421 0.282843 +0.657895 0.763158 0.1 0.1 +0.657895 0.763158 0.282843 0.141421 +0.657895 0.763158 0.141421 0.282843 +0.710526 0.763158 0.1 0.1 +0.710526 0.763158 0.282843 0.141421 +0.710526 0.763158 0.141421 0.282843 +0.763158 0.763158 0.1 0.1 +0.763158 0.763158 0.282843 0.141421 +0.763158 0.763158 0.141421 0.282843 +0.81579 0.763158 0.1 0.1 +0.81579 0.763158 0.282843 0.141421 +0.81579 0.763158 0.141421 0.282843 +0.868421 0.763158 0.1 0.1 +0.868421 0.763158 0.282843 0.141421 +0.868421 0.763158 0.141421 0.282843 +0.921053 0.763158 0.1 0.1 +0.921053 0.763158 0.282843 0.141421 +0.921053 0.763158 0.141421 0.282843 +0.973684 0.763158 0.1 0.1 +0.973684 0.763158 0.282843 0.141421 +0.973684 0.763158 0.141421 0.282843 +0.0263158 0.81579 0.1 0.1 +0.0263158 0.81579 0.282843 0.141421 +0.0263158 0.81579 0.141421 0.282843 +0.0789474 0.81579 0.1 0.1 +0.0789474 0.81579 0.282843 0.141421 +0.0789474 0.81579 0.141421 0.282843 +0.131579 0.81579 0.1 0.1 +0.131579 0.81579 0.282843 0.141421 +0.131579 0.81579 0.141421 0.282843 
+0.184211 0.81579 0.1 0.1 +0.184211 0.81579 0.282843 0.141421 +0.184211 0.81579 0.141421 0.282843 +0.236842 0.81579 0.1 0.1 +0.236842 0.81579 0.282843 0.141421 +0.236842 0.81579 0.141421 0.282843 +0.289474 0.81579 0.1 0.1 +0.289474 0.81579 0.282843 0.141421 +0.289474 0.81579 0.141421 0.282843 +0.342105 0.81579 0.1 0.1 +0.342105 0.81579 0.282843 0.141421 +0.342105 0.81579 0.141421 0.282843 +0.394737 0.81579 0.1 0.1 +0.394737 0.81579 0.282843 0.141421 +0.394737 0.81579 0.141421 0.282843 +0.447368 0.81579 0.1 0.1 +0.447368 0.81579 0.282843 0.141421 +0.447368 0.81579 0.141421 0.282843 +0.5 0.81579 0.1 0.1 +0.5 0.81579 0.282843 0.141421 +0.5 0.81579 0.141421 0.282843 +0.552632 0.81579 0.1 0.1 +0.552632 0.81579 0.282843 0.141421 +0.552632 0.81579 0.141421 0.282843 +0.605263 0.81579 0.1 0.1 +0.605263 0.81579 0.282843 0.141421 +0.605263 0.81579 0.141421 0.282843 +0.657895 0.81579 0.1 0.1 +0.657895 0.81579 0.282843 0.141421 +0.657895 0.81579 0.141421 0.282843 +0.710526 0.81579 0.1 0.1 +0.710526 0.81579 0.282843 0.141421 +0.710526 0.81579 0.141421 0.282843 +0.763158 0.81579 0.1 0.1 +0.763158 0.81579 0.282843 0.141421 +0.763158 0.81579 0.141421 0.282843 +0.81579 0.81579 0.1 0.1 +0.81579 0.81579 0.282843 0.141421 +0.81579 0.81579 0.141421 0.282843 +0.868421 0.81579 0.1 0.1 +0.868421 0.81579 0.282843 0.141421 +0.868421 0.81579 0.141421 0.282843 +0.921053 0.81579 0.1 0.1 +0.921053 0.81579 0.282843 0.141421 +0.921053 0.81579 0.141421 0.282843 +0.973684 0.81579 0.1 0.1 +0.973684 0.81579 0.282843 0.141421 +0.973684 0.81579 0.141421 0.282843 +0.0263158 0.868421 0.1 0.1 +0.0263158 0.868421 0.282843 0.141421 +0.0263158 0.868421 0.141421 0.282843 +0.0789474 0.868421 0.1 0.1 +0.0789474 0.868421 0.282843 0.141421 +0.0789474 0.868421 0.141421 0.282843 +0.131579 0.868421 0.1 0.1 +0.131579 0.868421 0.282843 0.141421 +0.131579 0.868421 0.141421 0.282843 +0.184211 0.868421 0.1 0.1 +0.184211 0.868421 0.282843 0.141421 +0.184211 0.868421 0.141421 0.282843 +0.236842 0.868421 0.1 0.1 +0.236842 
0.868421 0.282843 0.141421 +0.236842 0.868421 0.141421 0.282843 +0.289474 0.868421 0.1 0.1 +0.289474 0.868421 0.282843 0.141421 +0.289474 0.868421 0.141421 0.282843 +0.342105 0.868421 0.1 0.1 +0.342105 0.868421 0.282843 0.141421 +0.342105 0.868421 0.141421 0.282843 +0.394737 0.868421 0.1 0.1 +0.394737 0.868421 0.282843 0.141421 +0.394737 0.868421 0.141421 0.282843 +0.447368 0.868421 0.1 0.1 +0.447368 0.868421 0.282843 0.141421 +0.447368 0.868421 0.141421 0.282843 +0.5 0.868421 0.1 0.1 +0.5 0.868421 0.282843 0.141421 +0.5 0.868421 0.141421 0.282843 +0.552632 0.868421 0.1 0.1 +0.552632 0.868421 0.282843 0.141421 +0.552632 0.868421 0.141421 0.282843 +0.605263 0.868421 0.1 0.1 +0.605263 0.868421 0.282843 0.141421 +0.605263 0.868421 0.141421 0.282843 +0.657895 0.868421 0.1 0.1 +0.657895 0.868421 0.282843 0.141421 +0.657895 0.868421 0.141421 0.282843 +0.710526 0.868421 0.1 0.1 +0.710526 0.868421 0.282843 0.141421 +0.710526 0.868421 0.141421 0.282843 +0.763158 0.868421 0.1 0.1 +0.763158 0.868421 0.282843 0.141421 +0.763158 0.868421 0.141421 0.282843 +0.81579 0.868421 0.1 0.1 +0.81579 0.868421 0.282843 0.141421 +0.81579 0.868421 0.141421 0.282843 +0.868421 0.868421 0.1 0.1 +0.868421 0.868421 0.282843 0.141421 +0.868421 0.868421 0.141421 0.282843 +0.921053 0.868421 0.1 0.1 +0.921053 0.868421 0.282843 0.141421 +0.921053 0.868421 0.141421 0.282843 +0.973684 0.868421 0.1 0.1 +0.973684 0.868421 0.282843 0.141421 +0.973684 0.868421 0.141421 0.282843 +0.0263158 0.921053 0.1 0.1 +0.0263158 0.921053 0.282843 0.141421 +0.0263158 0.921053 0.141421 0.282843 +0.0789474 0.921053 0.1 0.1 +0.0789474 0.921053 0.282843 0.141421 +0.0789474 0.921053 0.141421 0.282843 +0.131579 0.921053 0.1 0.1 +0.131579 0.921053 0.282843 0.141421 +0.131579 0.921053 0.141421 0.282843 +0.184211 0.921053 0.1 0.1 +0.184211 0.921053 0.282843 0.141421 +0.184211 0.921053 0.141421 0.282843 +0.236842 0.921053 0.1 0.1 +0.236842 0.921053 0.282843 0.141421 +0.236842 0.921053 0.141421 0.282843 +0.289474 0.921053 0.1 0.1 
+0.289474 0.921053 0.282843 0.141421 +0.289474 0.921053 0.141421 0.282843 +0.342105 0.921053 0.1 0.1 +0.342105 0.921053 0.282843 0.141421 +0.342105 0.921053 0.141421 0.282843 +0.394737 0.921053 0.1 0.1 +0.394737 0.921053 0.282843 0.141421 +0.394737 0.921053 0.141421 0.282843 +0.447368 0.921053 0.1 0.1 +0.447368 0.921053 0.282843 0.141421 +0.447368 0.921053 0.141421 0.282843 +0.5 0.921053 0.1 0.1 +0.5 0.921053 0.282843 0.141421 +0.5 0.921053 0.141421 0.282843 +0.552632 0.921053 0.1 0.1 +0.552632 0.921053 0.282843 0.141421 +0.552632 0.921053 0.141421 0.282843 +0.605263 0.921053 0.1 0.1 +0.605263 0.921053 0.282843 0.141421 +0.605263 0.921053 0.141421 0.282843 +0.657895 0.921053 0.1 0.1 +0.657895 0.921053 0.282843 0.141421 +0.657895 0.921053 0.141421 0.282843 +0.710526 0.921053 0.1 0.1 +0.710526 0.921053 0.282843 0.141421 +0.710526 0.921053 0.141421 0.282843 +0.763158 0.921053 0.1 0.1 +0.763158 0.921053 0.282843 0.141421 +0.763158 0.921053 0.141421 0.282843 +0.81579 0.921053 0.1 0.1 +0.81579 0.921053 0.282843 0.141421 +0.81579 0.921053 0.141421 0.282843 +0.868421 0.921053 0.1 0.1 +0.868421 0.921053 0.282843 0.141421 +0.868421 0.921053 0.141421 0.282843 +0.921053 0.921053 0.1 0.1 +0.921053 0.921053 0.282843 0.141421 +0.921053 0.921053 0.141421 0.282843 +0.973684 0.921053 0.1 0.1 +0.973684 0.921053 0.282843 0.141421 +0.973684 0.921053 0.141421 0.282843 +0.0263158 0.973684 0.1 0.1 +0.0263158 0.973684 0.282843 0.141421 +0.0263158 0.973684 0.141421 0.282843 +0.0789474 0.973684 0.1 0.1 +0.0789474 0.973684 0.282843 0.141421 +0.0789474 0.973684 0.141421 0.282843 +0.131579 0.973684 0.1 0.1 +0.131579 0.973684 0.282843 0.141421 +0.131579 0.973684 0.141421 0.282843 +0.184211 0.973684 0.1 0.1 +0.184211 0.973684 0.282843 0.141421 +0.184211 0.973684 0.141421 0.282843 +0.236842 0.973684 0.1 0.1 +0.236842 0.973684 0.282843 0.141421 +0.236842 0.973684 0.141421 0.282843 +0.289474 0.973684 0.1 0.1 +0.289474 0.973684 0.282843 0.141421 +0.289474 0.973684 0.141421 0.282843 +0.342105 0.973684 
0.1 0.1 +0.342105 0.973684 0.282843 0.141421 +0.342105 0.973684 0.141421 0.282843 +0.394737 0.973684 0.1 0.1 +0.394737 0.973684 0.282843 0.141421 +0.394737 0.973684 0.141421 0.282843 +0.447368 0.973684 0.1 0.1 +0.447368 0.973684 0.282843 0.141421 +0.447368 0.973684 0.141421 0.282843 +0.5 0.973684 0.1 0.1 +0.5 0.973684 0.282843 0.141421 +0.5 0.973684 0.141421 0.282843 +0.552632 0.973684 0.1 0.1 +0.552632 0.973684 0.282843 0.141421 +0.552632 0.973684 0.141421 0.282843 +0.605263 0.973684 0.1 0.1 +0.605263 0.973684 0.282843 0.141421 +0.605263 0.973684 0.141421 0.282843 +0.657895 0.973684 0.1 0.1 +0.657895 0.973684 0.282843 0.141421 +0.657895 0.973684 0.141421 0.282843 +0.710526 0.973684 0.1 0.1 +0.710526 0.973684 0.282843 0.141421 +0.710526 0.973684 0.141421 0.282843 +0.763158 0.973684 0.1 0.1 +0.763158 0.973684 0.282843 0.141421 +0.763158 0.973684 0.141421 0.282843 +0.81579 0.973684 0.1 0.1 +0.81579 0.973684 0.282843 0.141421 +0.81579 0.973684 0.141421 0.282843 +0.868421 0.973684 0.1 0.1 +0.868421 0.973684 0.282843 0.141421 +0.868421 0.973684 0.141421 0.282843 +0.921053 0.973684 0.1 0.1 +0.921053 0.973684 0.282843 0.141421 +0.921053 0.973684 0.141421 0.282843 +0.973684 0.973684 0.1 0.1 +0.973684 0.973684 0.282843 0.141421 +0.973684 0.973684 0.141421 0.282843 +0.05 0.05 0.35 0.35 +0.05 0.05 0.494975 0.247487 +0.05 0.05 0.247487 0.494975 +0.05 0.05 0.606218 0.202073 +0.05 0.05 0.202062 0.606248 +0.05 0.05 0.41833 0.41833 +0.15 0.05 0.35 0.35 +0.15 0.05 0.494975 0.247487 +0.15 0.05 0.247487 0.494975 +0.15 0.05 0.606218 0.202073 +0.15 0.05 0.202062 0.606248 +0.15 0.05 0.41833 0.41833 +0.25 0.05 0.35 0.35 +0.25 0.05 0.494975 0.247487 +0.25 0.05 0.247487 0.494975 +0.25 0.05 0.606218 0.202073 +0.25 0.05 0.202062 0.606248 +0.25 0.05 0.41833 0.41833 +0.35 0.05 0.35 0.35 +0.35 0.05 0.494975 0.247487 +0.35 0.05 0.247487 0.494975 +0.35 0.05 0.606218 0.202073 +0.35 0.05 0.202062 0.606248 +0.35 0.05 0.41833 0.41833 +0.45 0.05 0.35 0.35 +0.45 0.05 0.494975 0.247487 +0.45 0.05 
0.247487 0.494975 +0.45 0.05 0.606218 0.202073 +0.45 0.05 0.202062 0.606248 +0.45 0.05 0.41833 0.41833 +0.55 0.05 0.35 0.35 +0.55 0.05 0.494975 0.247487 +0.55 0.05 0.247487 0.494975 +0.55 0.05 0.606218 0.202073 +0.55 0.05 0.202062 0.606248 +0.55 0.05 0.41833 0.41833 +0.65 0.05 0.35 0.35 +0.65 0.05 0.494975 0.247487 +0.65 0.05 0.247487 0.494975 +0.65 0.05 0.606218 0.202073 +0.65 0.05 0.202062 0.606248 +0.65 0.05 0.41833 0.41833 +0.75 0.05 0.35 0.35 +0.75 0.05 0.494975 0.247487 +0.75 0.05 0.247487 0.494975 +0.75 0.05 0.606218 0.202073 +0.75 0.05 0.202062 0.606248 +0.75 0.05 0.41833 0.41833 +0.85 0.05 0.35 0.35 +0.85 0.05 0.494975 0.247487 +0.85 0.05 0.247487 0.494975 +0.85 0.05 0.606218 0.202073 +0.85 0.05 0.202062 0.606248 +0.85 0.05 0.41833 0.41833 +0.95 0.05 0.35 0.35 +0.95 0.05 0.494975 0.247487 +0.95 0.05 0.247487 0.494975 +0.95 0.05 0.606218 0.202073 +0.95 0.05 0.202063 0.606248 +0.95 0.05 0.41833 0.41833 +0.05 0.15 0.35 0.35 +0.05 0.15 0.494975 0.247487 +0.05 0.15 0.247487 0.494975 +0.05 0.15 0.606218 0.202073 +0.05 0.15 0.202062 0.606248 +0.05 0.15 0.41833 0.41833 +0.15 0.15 0.35 0.35 +0.15 0.15 0.494975 0.247487 +0.15 0.15 0.247487 0.494975 +0.15 0.15 0.606218 0.202073 +0.15 0.15 0.202062 0.606248 +0.15 0.15 0.41833 0.41833 +0.25 0.15 0.35 0.35 +0.25 0.15 0.494975 0.247487 +0.25 0.15 0.247487 0.494975 +0.25 0.15 0.606218 0.202073 +0.25 0.15 0.202062 0.606248 +0.25 0.15 0.41833 0.41833 +0.35 0.15 0.35 0.35 +0.35 0.15 0.494975 0.247487 +0.35 0.15 0.247487 0.494975 +0.35 0.15 0.606218 0.202073 +0.35 0.15 0.202062 0.606248 +0.35 0.15 0.41833 0.41833 +0.45 0.15 0.35 0.35 +0.45 0.15 0.494975 0.247487 +0.45 0.15 0.247487 0.494975 +0.45 0.15 0.606218 0.202073 +0.45 0.15 0.202062 0.606248 +0.45 0.15 0.41833 0.41833 +0.55 0.15 0.35 0.35 +0.55 0.15 0.494975 0.247487 +0.55 0.15 0.247487 0.494975 +0.55 0.15 0.606218 0.202073 +0.55 0.15 0.202062 0.606248 +0.55 0.15 0.41833 0.41833 +0.65 0.15 0.35 0.35 +0.65 0.15 0.494975 0.247487 +0.65 0.15 0.247487 0.494975 +0.65 0.15 
0.606218 0.202073 +0.65 0.15 0.202062 0.606248 +0.65 0.15 0.41833 0.41833 +0.75 0.15 0.35 0.35 +0.75 0.15 0.494975 0.247487 +0.75 0.15 0.247487 0.494975 +0.75 0.15 0.606218 0.202073 +0.75 0.15 0.202062 0.606248 +0.75 0.15 0.41833 0.41833 +0.85 0.15 0.35 0.35 +0.85 0.15 0.494975 0.247487 +0.85 0.15 0.247487 0.494975 +0.85 0.15 0.606218 0.202073 +0.85 0.15 0.202062 0.606248 +0.85 0.15 0.41833 0.41833 +0.95 0.15 0.35 0.35 +0.95 0.15 0.494975 0.247487 +0.95 0.15 0.247487 0.494975 +0.95 0.15 0.606218 0.202073 +0.95 0.15 0.202063 0.606248 +0.95 0.15 0.41833 0.41833 +0.05 0.25 0.35 0.35 +0.05 0.25 0.494975 0.247487 +0.05 0.25 0.247487 0.494975 +0.05 0.25 0.606218 0.202073 +0.05 0.25 0.202062 0.606248 +0.05 0.25 0.41833 0.41833 +0.15 0.25 0.35 0.35 +0.15 0.25 0.494975 0.247487 +0.15 0.25 0.247487 0.494975 +0.15 0.25 0.606218 0.202073 +0.15 0.25 0.202062 0.606248 +0.15 0.25 0.41833 0.41833 +0.25 0.25 0.35 0.35 +0.25 0.25 0.494975 0.247487 +0.25 0.25 0.247487 0.494975 +0.25 0.25 0.606218 0.202073 +0.25 0.25 0.202062 0.606248 +0.25 0.25 0.41833 0.41833 +0.35 0.25 0.35 0.35 +0.35 0.25 0.494975 0.247487 +0.35 0.25 0.247487 0.494975 +0.35 0.25 0.606218 0.202073 +0.35 0.25 0.202062 0.606248 +0.35 0.25 0.41833 0.41833 +0.45 0.25 0.35 0.35 +0.45 0.25 0.494975 0.247487 +0.45 0.25 0.247487 0.494975 +0.45 0.25 0.606218 0.202073 +0.45 0.25 0.202062 0.606248 +0.45 0.25 0.41833 0.41833 +0.55 0.25 0.35 0.35 +0.55 0.25 0.494975 0.247487 +0.55 0.25 0.247487 0.494975 +0.55 0.25 0.606218 0.202073 +0.55 0.25 0.202062 0.606248 +0.55 0.25 0.41833 0.41833 +0.65 0.25 0.35 0.35 +0.65 0.25 0.494975 0.247487 +0.65 0.25 0.247487 0.494975 +0.65 0.25 0.606218 0.202073 +0.65 0.25 0.202062 0.606248 +0.65 0.25 0.41833 0.41833 +0.75 0.25 0.35 0.35 +0.75 0.25 0.494975 0.247487 +0.75 0.25 0.247487 0.494975 +0.75 0.25 0.606218 0.202073 +0.75 0.25 0.202062 0.606248 +0.75 0.25 0.41833 0.41833 +0.85 0.25 0.35 0.35 +0.85 0.25 0.494975 0.247487 +0.85 0.25 0.247487 0.494975 +0.85 0.25 0.606218 0.202073 +0.85 0.25 
0.202062 0.606248 +0.85 0.25 0.41833 0.41833 +0.95 0.25 0.35 0.35 +0.95 0.25 0.494975 0.247487 +0.95 0.25 0.247487 0.494975 +0.95 0.25 0.606218 0.202073 +0.95 0.25 0.202063 0.606248 +0.95 0.25 0.41833 0.41833 +0.05 0.35 0.35 0.35 +0.05 0.35 0.494975 0.247487 +0.05 0.35 0.247487 0.494975 +0.05 0.35 0.606218 0.202073 +0.05 0.35 0.202062 0.606248 +0.05 0.35 0.41833 0.41833 +0.15 0.35 0.35 0.35 +0.15 0.35 0.494975 0.247487 +0.15 0.35 0.247487 0.494975 +0.15 0.35 0.606218 0.202073 +0.15 0.35 0.202062 0.606248 +0.15 0.35 0.41833 0.41833 +0.25 0.35 0.35 0.35 +0.25 0.35 0.494975 0.247487 +0.25 0.35 0.247487 0.494975 +0.25 0.35 0.606218 0.202073 +0.25 0.35 0.202062 0.606248 +0.25 0.35 0.41833 0.41833 +0.35 0.35 0.35 0.35 +0.35 0.35 0.494975 0.247487 +0.35 0.35 0.247487 0.494975 +0.35 0.35 0.606218 0.202073 +0.35 0.35 0.202062 0.606248 +0.35 0.35 0.41833 0.41833 +0.45 0.35 0.35 0.35 +0.45 0.35 0.494975 0.247487 +0.45 0.35 0.247487 0.494975 +0.45 0.35 0.606218 0.202073 +0.45 0.35 0.202062 0.606248 +0.45 0.35 0.41833 0.41833 +0.55 0.35 0.35 0.35 +0.55 0.35 0.494975 0.247487 +0.55 0.35 0.247487 0.494975 +0.55 0.35 0.606218 0.202073 +0.55 0.35 0.202062 0.606248 +0.55 0.35 0.41833 0.41833 +0.65 0.35 0.35 0.35 +0.65 0.35 0.494975 0.247487 +0.65 0.35 0.247487 0.494975 +0.65 0.35 0.606218 0.202073 +0.65 0.35 0.202062 0.606248 +0.65 0.35 0.41833 0.41833 +0.75 0.35 0.35 0.35 +0.75 0.35 0.494975 0.247487 +0.75 0.35 0.247487 0.494975 +0.75 0.35 0.606218 0.202073 +0.75 0.35 0.202062 0.606248 +0.75 0.35 0.41833 0.41833 +0.85 0.35 0.35 0.35 +0.85 0.35 0.494975 0.247487 +0.85 0.35 0.247487 0.494975 +0.85 0.35 0.606218 0.202073 +0.85 0.35 0.202062 0.606248 +0.85 0.35 0.41833 0.41833 +0.95 0.35 0.35 0.35 +0.95 0.35 0.494975 0.247487 +0.95 0.35 0.247487 0.494975 +0.95 0.35 0.606218 0.202073 +0.95 0.35 0.202063 0.606248 +0.95 0.35 0.41833 0.41833 +0.05 0.45 0.35 0.35 +0.05 0.45 0.494975 0.247487 +0.05 0.45 0.247487 0.494975 +0.05 0.45 0.606218 0.202073 +0.05 0.45 0.202062 0.606248 +0.05 0.45 
0.41833 0.41833 +0.15 0.45 0.35 0.35 +0.15 0.45 0.494975 0.247487 +0.15 0.45 0.247487 0.494975 +0.15 0.45 0.606218 0.202073 +0.15 0.45 0.202062 0.606248 +0.15 0.45 0.41833 0.41833 +0.25 0.45 0.35 0.35 +0.25 0.45 0.494975 0.247487 +0.25 0.45 0.247487 0.494975 +0.25 0.45 0.606218 0.202073 +0.25 0.45 0.202062 0.606248 +0.25 0.45 0.41833 0.41833 +0.35 0.45 0.35 0.35 +0.35 0.45 0.494975 0.247487 +0.35 0.45 0.247487 0.494975 +0.35 0.45 0.606218 0.202073 +0.35 0.45 0.202062 0.606248 +0.35 0.45 0.41833 0.41833 +0.45 0.45 0.35 0.35 +0.45 0.45 0.494975 0.247487 +0.45 0.45 0.247487 0.494975 +0.45 0.45 0.606218 0.202073 +0.45 0.45 0.202062 0.606248 +0.45 0.45 0.41833 0.41833 +0.55 0.45 0.35 0.35 +0.55 0.45 0.494975 0.247487 +0.55 0.45 0.247487 0.494975 +0.55 0.45 0.606218 0.202073 +0.55 0.45 0.202062 0.606248 +0.55 0.45 0.41833 0.41833 +0.65 0.45 0.35 0.35 +0.65 0.45 0.494975 0.247487 +0.65 0.45 0.247487 0.494975 +0.65 0.45 0.606218 0.202073 +0.65 0.45 0.202062 0.606248 +0.65 0.45 0.41833 0.41833 +0.75 0.45 0.35 0.35 +0.75 0.45 0.494975 0.247487 +0.75 0.45 0.247487 0.494975 +0.75 0.45 0.606218 0.202073 +0.75 0.45 0.202062 0.606248 +0.75 0.45 0.41833 0.41833 +0.85 0.45 0.35 0.35 +0.85 0.45 0.494975 0.247487 +0.85 0.45 0.247487 0.494975 +0.85 0.45 0.606218 0.202073 +0.85 0.45 0.202062 0.606248 +0.85 0.45 0.41833 0.41833 +0.95 0.45 0.35 0.35 +0.95 0.45 0.494975 0.247487 +0.95 0.45 0.247487 0.494975 +0.95 0.45 0.606218 0.202073 +0.95 0.45 0.202063 0.606248 +0.95 0.45 0.41833 0.41833 +0.05 0.55 0.35 0.35 +0.05 0.55 0.494975 0.247487 +0.05 0.55 0.247487 0.494975 +0.05 0.55 0.606218 0.202073 +0.05 0.55 0.202062 0.606248 +0.05 0.55 0.41833 0.41833 +0.15 0.55 0.35 0.35 +0.15 0.55 0.494975 0.247487 +0.15 0.55 0.247487 0.494975 +0.15 0.55 0.606218 0.202073 +0.15 0.55 0.202062 0.606248 +0.15 0.55 0.41833 0.41833 +0.25 0.55 0.35 0.35 +0.25 0.55 0.494975 0.247487 +0.25 0.55 0.247487 0.494975 +0.25 0.55 0.606218 0.202073 +0.25 0.55 0.202062 0.606248 +0.25 0.55 0.41833 0.41833 +0.35 0.55 0.35 
0.35 +0.35 0.55 0.494975 0.247487 +0.35 0.55 0.247487 0.494975 +0.35 0.55 0.606218 0.202073 +0.35 0.55 0.202062 0.606248 +0.35 0.55 0.41833 0.41833 +0.45 0.55 0.35 0.35 +0.45 0.55 0.494975 0.247487 +0.45 0.55 0.247487 0.494975 +0.45 0.55 0.606218 0.202073 +0.45 0.55 0.202062 0.606248 +0.45 0.55 0.41833 0.41833 +0.55 0.55 0.35 0.35 +0.55 0.55 0.494975 0.247487 +0.55 0.55 0.247487 0.494975 +0.55 0.55 0.606218 0.202073 +0.55 0.55 0.202062 0.606248 +0.55 0.55 0.41833 0.41833 +0.65 0.55 0.35 0.35 +0.65 0.55 0.494975 0.247487 +0.65 0.55 0.247487 0.494975 +0.65 0.55 0.606218 0.202073 +0.65 0.55 0.202062 0.606248 +0.65 0.55 0.41833 0.41833 +0.75 0.55 0.35 0.35 +0.75 0.55 0.494975 0.247487 +0.75 0.55 0.247487 0.494975 +0.75 0.55 0.606218 0.202073 +0.75 0.55 0.202062 0.606248 +0.75 0.55 0.41833 0.41833 +0.85 0.55 0.35 0.35 +0.85 0.55 0.494975 0.247487 +0.85 0.55 0.247487 0.494975 +0.85 0.55 0.606218 0.202073 +0.85 0.55 0.202062 0.606248 +0.85 0.55 0.41833 0.41833 +0.95 0.55 0.35 0.35 +0.95 0.55 0.494975 0.247487 +0.95 0.55 0.247487 0.494975 +0.95 0.55 0.606218 0.202073 +0.95 0.55 0.202063 0.606248 +0.95 0.55 0.41833 0.41833 +0.05 0.65 0.35 0.35 +0.05 0.65 0.494975 0.247487 +0.05 0.65 0.247487 0.494975 +0.05 0.65 0.606218 0.202073 +0.05 0.65 0.202062 0.606248 +0.05 0.65 0.41833 0.41833 +0.15 0.65 0.35 0.35 +0.15 0.65 0.494975 0.247487 +0.15 0.65 0.247487 0.494975 +0.15 0.65 0.606218 0.202073 +0.15 0.65 0.202062 0.606248 +0.15 0.65 0.41833 0.41833 +0.25 0.65 0.35 0.35 +0.25 0.65 0.494975 0.247487 +0.25 0.65 0.247487 0.494975 +0.25 0.65 0.606218 0.202073 +0.25 0.65 0.202062 0.606248 +0.25 0.65 0.41833 0.41833 +0.35 0.65 0.35 0.35 +0.35 0.65 0.494975 0.247487 +0.35 0.65 0.247487 0.494975 +0.35 0.65 0.606218 0.202073 +0.35 0.65 0.202062 0.606248 +0.35 0.65 0.41833 0.41833 +0.45 0.65 0.35 0.35 +0.45 0.65 0.494975 0.247487 +0.45 0.65 0.247487 0.494975 +0.45 0.65 0.606218 0.202073 +0.45 0.65 0.202062 0.606248 +0.45 0.65 0.41833 0.41833 +0.55 0.65 0.35 0.35 +0.55 0.65 0.494975 
0.247487 +0.55 0.65 0.247487 0.494975 +0.55 0.65 0.606218 0.202073 +0.55 0.65 0.202062 0.606248 +0.55 0.65 0.41833 0.41833 +0.65 0.65 0.35 0.35 +0.65 0.65 0.494975 0.247487 +0.65 0.65 0.247487 0.494975 +0.65 0.65 0.606218 0.202073 +0.65 0.65 0.202062 0.606248 +0.65 0.65 0.41833 0.41833 +0.75 0.65 0.35 0.35 +0.75 0.65 0.494975 0.247487 +0.75 0.65 0.247487 0.494975 +0.75 0.65 0.606218 0.202073 +0.75 0.65 0.202062 0.606248 +0.75 0.65 0.41833 0.41833 +0.85 0.65 0.35 0.35 +0.85 0.65 0.494975 0.247487 +0.85 0.65 0.247487 0.494975 +0.85 0.65 0.606218 0.202073 +0.85 0.65 0.202062 0.606248 +0.85 0.65 0.41833 0.41833 +0.95 0.65 0.35 0.35 +0.95 0.65 0.494975 0.247487 +0.95 0.65 0.247487 0.494975 +0.95 0.65 0.606218 0.202073 +0.95 0.65 0.202063 0.606248 +0.95 0.65 0.41833 0.41833 +0.05 0.75 0.35 0.35 +0.05 0.75 0.494975 0.247487 +0.05 0.75 0.247487 0.494975 +0.05 0.75 0.606218 0.202073 +0.05 0.75 0.202062 0.606248 +0.05 0.75 0.41833 0.41833 +0.15 0.75 0.35 0.35 +0.15 0.75 0.494975 0.247487 +0.15 0.75 0.247487 0.494975 +0.15 0.75 0.606218 0.202073 +0.15 0.75 0.202062 0.606248 +0.15 0.75 0.41833 0.41833 +0.25 0.75 0.35 0.35 +0.25 0.75 0.494975 0.247487 +0.25 0.75 0.247487 0.494975 +0.25 0.75 0.606218 0.202073 +0.25 0.75 0.202062 0.606248 +0.25 0.75 0.41833 0.41833 +0.35 0.75 0.35 0.35 +0.35 0.75 0.494975 0.247487 +0.35 0.75 0.247487 0.494975 +0.35 0.75 0.606218 0.202073 +0.35 0.75 0.202062 0.606248 +0.35 0.75 0.41833 0.41833 +0.45 0.75 0.35 0.35 +0.45 0.75 0.494975 0.247487 +0.45 0.75 0.247487 0.494975 +0.45 0.75 0.606218 0.202073 +0.45 0.75 0.202062 0.606248 +0.45 0.75 0.41833 0.41833 +0.55 0.75 0.35 0.35 +0.55 0.75 0.494975 0.247487 +0.55 0.75 0.247487 0.494975 +0.55 0.75 0.606218 0.202073 +0.55 0.75 0.202062 0.606248 +0.55 0.75 0.41833 0.41833 +0.65 0.75 0.35 0.35 +0.65 0.75 0.494975 0.247487 +0.65 0.75 0.247487 0.494975 +0.65 0.75 0.606218 0.202073 +0.65 0.75 0.202062 0.606248 +0.65 0.75 0.41833 0.41833 +0.75 0.75 0.35 0.35 +0.75 0.75 0.494975 0.247487 +0.75 0.75 0.247487 
0.494975 +0.75 0.75 0.606218 0.202073 +0.75 0.75 0.202062 0.606248 +0.75 0.75 0.41833 0.41833 +0.85 0.75 0.35 0.35 +0.85 0.75 0.494975 0.247487 +0.85 0.75 0.247487 0.494975 +0.85 0.75 0.606218 0.202073 +0.85 0.75 0.202062 0.606248 +0.85 0.75 0.41833 0.41833 +0.95 0.75 0.35 0.35 +0.95 0.75 0.494975 0.247487 +0.95 0.75 0.247487 0.494975 +0.95 0.75 0.606218 0.202073 +0.95 0.75 0.202063 0.606248 +0.95 0.75 0.41833 0.41833 +0.05 0.85 0.35 0.35 +0.05 0.85 0.494975 0.247487 +0.05 0.85 0.247487 0.494975 +0.05 0.85 0.606218 0.202073 +0.05 0.85 0.202062 0.606248 +0.05 0.85 0.41833 0.41833 +0.15 0.85 0.35 0.35 +0.15 0.85 0.494975 0.247487 +0.15 0.85 0.247487 0.494975 +0.15 0.85 0.606218 0.202073 +0.15 0.85 0.202062 0.606248 +0.15 0.85 0.41833 0.41833 +0.25 0.85 0.35 0.35 +0.25 0.85 0.494975 0.247487 +0.25 0.85 0.247487 0.494975 +0.25 0.85 0.606218 0.202073 +0.25 0.85 0.202062 0.606248 +0.25 0.85 0.41833 0.41833 +0.35 0.85 0.35 0.35 +0.35 0.85 0.494975 0.247487 +0.35 0.85 0.247487 0.494975 +0.35 0.85 0.606218 0.202073 +0.35 0.85 0.202062 0.606248 +0.35 0.85 0.41833 0.41833 +0.45 0.85 0.35 0.35 +0.45 0.85 0.494975 0.247487 +0.45 0.85 0.247487 0.494975 +0.45 0.85 0.606218 0.202073 +0.45 0.85 0.202062 0.606248 +0.45 0.85 0.41833 0.41833 +0.55 0.85 0.35 0.35 +0.55 0.85 0.494975 0.247487 +0.55 0.85 0.247487 0.494975 +0.55 0.85 0.606218 0.202073 +0.55 0.85 0.202062 0.606248 +0.55 0.85 0.41833 0.41833 +0.65 0.85 0.35 0.35 +0.65 0.85 0.494975 0.247487 +0.65 0.85 0.247487 0.494975 +0.65 0.85 0.606218 0.202073 +0.65 0.85 0.202062 0.606248 +0.65 0.85 0.41833 0.41833 +0.75 0.85 0.35 0.35 +0.75 0.85 0.494975 0.247487 +0.75 0.85 0.247487 0.494975 +0.75 0.85 0.606218 0.202073 +0.75 0.85 0.202062 0.606248 +0.75 0.85 0.41833 0.41833 +0.85 0.85 0.35 0.35 +0.85 0.85 0.494975 0.247487 +0.85 0.85 0.247487 0.494975 +0.85 0.85 0.606218 0.202073 +0.85 0.85 0.202062 0.606248 +0.85 0.85 0.41833 0.41833 +0.95 0.85 0.35 0.35 +0.95 0.85 0.494975 0.247487 +0.95 0.85 0.247487 0.494975 +0.95 0.85 0.606218 
0.202073 +0.95 0.85 0.202063 0.606248 +0.95 0.85 0.41833 0.41833 +0.05 0.95 0.35 0.35 +0.05 0.95 0.494975 0.247487 +0.05 0.95 0.247487 0.494975 +0.05 0.95 0.606218 0.202073 +0.05 0.95 0.202062 0.606248 +0.05 0.95 0.41833 0.41833 +0.15 0.95 0.35 0.35 +0.15 0.95 0.494975 0.247487 +0.15 0.95 0.247487 0.494975 +0.15 0.95 0.606218 0.202073 +0.15 0.95 0.202062 0.606248 +0.15 0.95 0.41833 0.41833 +0.25 0.95 0.35 0.35 +0.25 0.95 0.494975 0.247487 +0.25 0.95 0.247487 0.494975 +0.25 0.95 0.606218 0.202073 +0.25 0.95 0.202062 0.606248 +0.25 0.95 0.41833 0.41833 +0.35 0.95 0.35 0.35 +0.35 0.95 0.494975 0.247487 +0.35 0.95 0.247487 0.494975 +0.35 0.95 0.606218 0.202073 +0.35 0.95 0.202062 0.606248 +0.35 0.95 0.41833 0.41833 +0.45 0.95 0.35 0.35 +0.45 0.95 0.494975 0.247487 +0.45 0.95 0.247487 0.494975 +0.45 0.95 0.606218 0.202073 +0.45 0.95 0.202062 0.606248 +0.45 0.95 0.41833 0.41833 +0.55 0.95 0.35 0.35 +0.55 0.95 0.494975 0.247487 +0.55 0.95 0.247487 0.494975 +0.55 0.95 0.606218 0.202073 +0.55 0.95 0.202062 0.606248 +0.55 0.95 0.41833 0.41833 +0.65 0.95 0.35 0.35 +0.65 0.95 0.494975 0.247487 +0.65 0.95 0.247487 0.494975 +0.65 0.95 0.606218 0.202073 +0.65 0.95 0.202062 0.606248 +0.65 0.95 0.41833 0.41833 +0.75 0.95 0.35 0.35 +0.75 0.95 0.494975 0.247487 +0.75 0.95 0.247487 0.494975 +0.75 0.95 0.606218 0.202073 +0.75 0.95 0.202062 0.606248 +0.75 0.95 0.41833 0.41833 +0.85 0.95 0.35 0.35 +0.85 0.95 0.494975 0.247487 +0.85 0.95 0.247487 0.494975 +0.85 0.95 0.606218 0.202073 +0.85 0.95 0.202062 0.606248 +0.85 0.95 0.41833 0.41833 +0.95 0.95 0.35 0.35 +0.95 0.95 0.494975 0.247487 +0.95 0.95 0.247487 0.494975 +0.95 0.95 0.606218 0.202073 +0.95 0.95 0.202063 0.606248 +0.95 0.95 0.41833 0.41833 +0.1 0.1 0.5 0.5 +0.1 0.1 0.707107 0.353553 +0.1 0.1 0.353553 0.707107 +0.1 0.1 0.866025 0.288675 +0.1 0.1 0.288661 0.866069 +0.1 0.1 0.570088 0.570088 +0.3 0.1 0.5 0.5 +0.3 0.1 0.707107 0.353553 +0.3 0.1 0.353553 0.707107 +0.3 0.1 0.866025 0.288675 +0.3 0.1 0.288661 0.866069 +0.3 0.1 0.570088 
0.570088 +0.5 0.1 0.5 0.5 +0.5 0.1 0.707107 0.353553 +0.5 0.1 0.353553 0.707107 +0.5 0.1 0.866025 0.288675 +0.5 0.1 0.288661 0.866069 +0.5 0.1 0.570088 0.570088 +0.7 0.1 0.5 0.5 +0.7 0.1 0.707107 0.353553 +0.7 0.1 0.353553 0.707107 +0.7 0.1 0.866025 0.288675 +0.7 0.1 0.288661 0.866069 +0.7 0.1 0.570088 0.570088 +0.9 0.1 0.5 0.5 +0.9 0.1 0.707107 0.353553 +0.9 0.1 0.353553 0.707107 +0.9 0.1 0.866025 0.288675 +0.9 0.1 0.288661 0.866069 +0.9 0.1 0.570088 0.570088 +0.1 0.3 0.5 0.5 +0.1 0.3 0.707107 0.353553 +0.1 0.3 0.353553 0.707107 +0.1 0.3 0.866025 0.288675 +0.1 0.3 0.288661 0.866069 +0.1 0.3 0.570088 0.570088 +0.3 0.3 0.5 0.5 +0.3 0.3 0.707107 0.353553 +0.3 0.3 0.353553 0.707107 +0.3 0.3 0.866025 0.288675 +0.3 0.3 0.288661 0.866069 +0.3 0.3 0.570088 0.570088 +0.5 0.3 0.5 0.5 +0.5 0.3 0.707107 0.353553 +0.5 0.3 0.353553 0.707107 +0.5 0.3 0.866025 0.288675 +0.5 0.3 0.288661 0.866069 +0.5 0.3 0.570088 0.570088 +0.7 0.3 0.5 0.5 +0.7 0.3 0.707107 0.353553 +0.7 0.3 0.353553 0.707107 +0.7 0.3 0.866025 0.288675 +0.7 0.3 0.288661 0.866069 +0.7 0.3 0.570088 0.570088 +0.9 0.3 0.5 0.5 +0.9 0.3 0.707107 0.353553 +0.9 0.3 0.353553 0.707107 +0.9 0.3 0.866025 0.288675 +0.9 0.3 0.288661 0.866069 +0.9 0.3 0.570088 0.570088 +0.1 0.5 0.5 0.5 +0.1 0.5 0.707107 0.353553 +0.1 0.5 0.353553 0.707107 +0.1 0.5 0.866025 0.288675 +0.1 0.5 0.288661 0.866069 +0.1 0.5 0.570088 0.570088 +0.3 0.5 0.5 0.5 +0.3 0.5 0.707107 0.353553 +0.3 0.5 0.353553 0.707107 +0.3 0.5 0.866025 0.288675 +0.3 0.5 0.288661 0.866069 +0.3 0.5 0.570088 0.570088 +0.5 0.5 0.5 0.5 +0.5 0.5 0.707107 0.353553 +0.5 0.5 0.353553 0.707107 +0.5 0.5 0.866025 0.288675 +0.5 0.5 0.288661 0.866069 +0.5 0.5 0.570088 0.570088 +0.7 0.5 0.5 0.5 +0.7 0.5 0.707107 0.353553 +0.7 0.5 0.353553 0.707107 +0.7 0.5 0.866025 0.288675 +0.7 0.5 0.288661 0.866069 +0.7 0.5 0.570088 0.570088 +0.9 0.5 0.5 0.5 +0.9 0.5 0.707107 0.353553 +0.9 0.5 0.353553 0.707107 +0.9 0.5 0.866025 0.288675 +0.9 0.5 0.288661 0.866069 +0.9 0.5 0.570088 0.570088 +0.1 0.7 0.5 
0.5 +0.1 0.7 0.707107 0.353553 +0.1 0.7 0.353553 0.707107 +0.1 0.7 0.866025 0.288675 +0.1 0.7 0.288661 0.866069 +0.1 0.7 0.570088 0.570088 +0.3 0.7 0.5 0.5 +0.3 0.7 0.707107 0.353553 +0.3 0.7 0.353553 0.707107 +0.3 0.7 0.866025 0.288675 +0.3 0.7 0.288661 0.866069 +0.3 0.7 0.570088 0.570088 +0.5 0.7 0.5 0.5 +0.5 0.7 0.707107 0.353553 +0.5 0.7 0.353553 0.707107 +0.5 0.7 0.866025 0.288675 +0.5 0.7 0.288661 0.866069 +0.5 0.7 0.570088 0.570088 +0.7 0.7 0.5 0.5 +0.7 0.7 0.707107 0.353553 +0.7 0.7 0.353553 0.707107 +0.7 0.7 0.866025 0.288675 +0.7 0.7 0.288661 0.866069 +0.7 0.7 0.570088 0.570088 +0.9 0.7 0.5 0.5 +0.9 0.7 0.707107 0.353553 +0.9 0.7 0.353553 0.707107 +0.9 0.7 0.866025 0.288675 +0.9 0.7 0.288661 0.866069 +0.9 0.7 0.570088 0.570088 +0.1 0.9 0.5 0.5 +0.1 0.9 0.707107 0.353553 +0.1 0.9 0.353553 0.707107 +0.1 0.9 0.866025 0.288675 +0.1 0.9 0.288661 0.866069 +0.1 0.9 0.570088 0.570088 +0.3 0.9 0.5 0.5 +0.3 0.9 0.707107 0.353553 +0.3 0.9 0.353553 0.707107 +0.3 0.9 0.866025 0.288675 +0.3 0.9 0.288661 0.866069 +0.3 0.9 0.570088 0.570088 +0.5 0.9 0.5 0.5 +0.5 0.9 0.707107 0.353553 +0.5 0.9 0.353553 0.707107 +0.5 0.9 0.866025 0.288675 +0.5 0.9 0.288661 0.866069 +0.5 0.9 0.570088 0.570088 +0.7 0.9 0.5 0.5 +0.7 0.9 0.707107 0.353553 +0.7 0.9 0.353553 0.707107 +0.7 0.9 0.866025 0.288675 +0.7 0.9 0.288661 0.866069 +0.7 0.9 0.570088 0.570088 +0.9 0.9 0.5 0.5 +0.9 0.9 0.707107 0.353553 +0.9 0.9 0.353553 0.707107 +0.9 0.9 0.866025 0.288675 +0.9 0.9 0.288661 0.866069 +0.9 0.9 0.570088 0.570088 +0.166667 0.166667 0.65 0.65 +0.166667 0.166667 0.919239 0.459619 +0.166667 0.166667 0.459619 0.919239 +0.166667 0.166667 1.12583 0.375278 +0.166667 0.166667 0.375259 1.12589 +0.166667 0.166667 0.72111 0.72111 +0.5 0.166667 0.65 0.65 +0.5 0.166667 0.919239 0.459619 +0.5 0.166667 0.459619 0.919239 +0.5 0.166667 1.12583 0.375278 +0.5 0.166667 0.375259 1.12589 +0.5 0.166667 0.72111 0.72111 +0.833333 0.166667 0.65 0.65 +0.833333 0.166667 0.919239 0.459619 +0.833333 0.166667 0.459619 0.919239 
+0.833333 0.166667 1.12583 0.375278 +0.833333 0.166667 0.375259 1.12589 +0.833333 0.166667 0.72111 0.72111 +0.166667 0.5 0.65 0.65 +0.166667 0.5 0.919239 0.459619 +0.166667 0.5 0.459619 0.919239 +0.166667 0.5 1.12583 0.375278 +0.166667 0.5 0.375259 1.12589 +0.166667 0.5 0.72111 0.72111 +0.5 0.5 0.65 0.65 +0.5 0.5 0.919239 0.459619 +0.5 0.5 0.459619 0.919239 +0.5 0.5 1.12583 0.375278 +0.5 0.5 0.375259 1.12589 +0.5 0.5 0.72111 0.72111 +0.833333 0.5 0.65 0.65 +0.833333 0.5 0.919239 0.459619 +0.833333 0.5 0.459619 0.919239 +0.833333 0.5 1.12583 0.375278 +0.833333 0.5 0.375259 1.12589 +0.833333 0.5 0.72111 0.72111 +0.166667 0.833333 0.65 0.65 +0.166667 0.833333 0.919239 0.459619 +0.166667 0.833333 0.459619 0.919239 +0.166667 0.833333 1.12583 0.375278 +0.166667 0.833333 0.375259 1.12589 +0.166667 0.833333 0.72111 0.72111 +0.5 0.833333 0.65 0.65 +0.5 0.833333 0.919239 0.459619 +0.5 0.833333 0.459619 0.919239 +0.5 0.833333 1.12583 0.375278 +0.5 0.833333 0.375259 1.12589 +0.5 0.833333 0.72111 0.72111 +0.833333 0.833333 0.65 0.65 +0.833333 0.833333 0.919239 0.459619 +0.833333 0.833333 0.459619 0.919239 +0.833333 0.833333 1.12583 0.375278 +0.833333 0.833333 0.375259 1.12589 +0.833333 0.833333 0.72111 0.72111 +0.25 0.25 0.8 0.8 +0.25 0.25 1.13137 0.565686 +0.25 0.25 0.565685 1.13137 +0.25 0.25 1.38564 0.46188 +0.25 0.25 0.461857 1.38571 +0.25 0.25 0.87178 0.87178 +0.75 0.25 0.8 0.8 +0.75 0.25 1.13137 0.565686 +0.75 0.25 0.565685 1.13137 +0.75 0.25 1.38564 0.46188 +0.75 0.25 0.461857 1.38571 +0.75 0.25 0.87178 0.87178 +0.25 0.75 0.8 0.8 +0.25 0.75 1.13137 0.565686 +0.25 0.75 0.565685 1.13137 +0.25 0.75 1.38564 0.46188 +0.25 0.75 0.461857 1.38571 +0.25 0.75 0.87178 0.87178 +0.75 0.75 0.8 0.8 +0.75 0.75 1.13137 0.565686 +0.75 0.75 0.565685 1.13137 +0.75 0.75 1.38564 0.46188 +0.75 0.75 0.461857 1.38571 +0.75 0.75 0.87178 0.87178 +0.5 0.5 0.95 0.95 +0.5 0.5 1.3435 0.671751 +0.5 0.5 0.671751 1.3435 +0.5 0.5 1.64545 0.548483 +0.5 0.5 0.548455 1.64553 +0.5 0.5 0.974679 0.974679 diff 
--git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc new file mode 100644 index 000000000..7ef0e246b --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc @@ -0,0 +1,418 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/resource_util.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" + +#if defined(__ANDROID__) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#endif // ANDROID + +namespace { +constexpr int kWorkgroupSize = 8; // Block size for GPU shader. +// Commonly used to compute the number of blocks to launch in a kernel. 
+int RoundUp(const int size, const int multiple) { + return (size + multiple - 1) / multiple; +} +} // namespace + +namespace mediapipe { + +#if defined(__ANDROID__) +using ::tflite::gpu::gl::GlBuffer; +using ::tflite::gpu::gl::GlProgram; +using ::tflite::gpu::gl::GlShader; +struct GPUData { + int width; + int height; + int channels; + GlBuffer ssbo; + GlShader shader; + GlProgram program; +}; +#endif // ANDROID + +// Calculator for normalizing and converting an ImageFrame or GpuBuffer +// into a TfLiteTensor (float 32) or tflite::gpu::GlBuffer, respetively. +// +// This calculator is designed to be used with the TfLiteInferenceCalcualtor, +// as a pre-processing step for calculator inputs. +// +// Input data is normalized to [-1,1] (default) or [0,1], specified by options. +// +// Input: +// IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data). +// IMAGE_GPU - GpuBuffer (assumed to be RGBA or RGB GL texture) +// +// Output: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 +// TENSORS_GPU - vector of GlBuffer +// +// Example use: +// node { +// calculator: "TfLiteConverterCalculator" +// input_stream: "IMAGE:input_image" +// output_stream: "TENSORS:image_tensor" +// options: { +// [mediapipe.TfLiteConverterCalculatorOptions.ext] { +// zero_center: true +// } +// } +// } +// +// IMPORTANT Notes: +// No conversion between CPU/GPU is done. +// Inputs/outputs must match type: CPU->CPU or GPU->GPU. +// GPU tensors are currently only supported on Android. +// This calculator uses FixedSizeInputStreamHandler by default. 
+// +class TfLiteConverterCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status InitGpu(CalculatorContext* cc); + ::mediapipe::Status LoadOptions(CalculatorContext* cc); + template + ::mediapipe::Status NormalizeImage(const ImageFrame& image_frame, + bool zero_center, bool flip_vertically, + float* tensor_buffer); + + std::unique_ptr interpreter_ = nullptr; + +#if defined(__ANDROID__) + mediapipe::GlCalculatorHelper gpu_helper_; + std::unique_ptr gpu_data_out_; +#endif + + bool initialized_ = false; + bool use_gpu_ = false; + bool zero_center_ = true; // normalize range to [-1,1] | otherwise [0,1] + bool flip_vertically_ = false; + int max_num_channels_ = 3; +}; +REGISTER_CALCULATOR(TfLiteConverterCalculator); + +::mediapipe::Status TfLiteConverterCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag("IMAGE") || cc->Inputs().HasTag("IMAGE_GPU")); + RET_CHECK(cc->Outputs().HasTag("TENSORS") || + cc->Outputs().HasTag("TENSORS_GPU")); + + if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set(); +#if defined(__ANDROID__) + if (cc->Inputs().HasTag("IMAGE_GPU")) + cc->Inputs().Tag("IMAGE_GPU").Set(); +#endif + + if (cc->Outputs().HasTag("TENSORS")) + cc->Outputs().Tag("TENSORS").Set>(); +#if defined(__ANDROID__) + if (cc->Outputs().HasTag("TENSORS_GPU")) + cc->Outputs().Tag("TENSORS_GPU").Set>(); +#endif + +#if defined(__ANDROID__) + RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif + + // Assign this calculator's default InputStreamHandler. 
+ cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) { + RETURN_IF_ERROR(LoadOptions(cc)); + + if (cc->Inputs().HasTag("IMAGE_GPU") || + cc->Outputs().HasTag("IMAGE_OUT_GPU")) { +#if defined(__ANDROID__) + use_gpu_ = true; +#else + RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet."; +#endif + } + + if (use_gpu_) { + // Cannot mix CPU/GPU streams. + RET_CHECK(cc->Inputs().HasTag("IMAGE_GPU") && + cc->Outputs().HasTag("TENSORS_GPU")); +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif + } else { + interpreter_ = absl::make_unique(); + interpreter_->AddTensors(1); + interpreter_->SetInputs({0}); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { + if (use_gpu_) { + // GpuBuffer to tflite::gpu::GlBuffer conversion. +#if defined(__ANDROID__) + if (!initialized_) { + RETURN_IF_ERROR(InitGpu(cc)); + initialized_ = true; + } + + const auto& input = + cc->Inputs().Tag("IMAGE_GPU").Get(); + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, &input]() -> ::mediapipe::Status { + // Convert GL texture into TfLite GlBuffer (SSBO). 
+ auto src = gpu_helper_.CreateSourceTexture(input); + glActiveTexture(GL_TEXTURE0 + 0); + glBindTexture(GL_TEXTURE_2D, src.name()); + auto status = gpu_data_out_->ssbo.BindToIndex(1); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + const tflite::gpu::uint3 workgroups = { + RoundUp(gpu_data_out_->width, kWorkgroupSize), + RoundUp(gpu_data_out_->height, kWorkgroupSize), 1}; + status = gpu_data_out_->program.Dispatch(workgroups); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); + glBindTexture(GL_TEXTURE_2D, 0); + src.Release(); + return ::mediapipe::OkStatus(); + })); + + auto output_tensors = absl::make_unique>(); + output_tensors->resize(1); + for (int i = 0; i < 1; ++i) { + GlBuffer& tensor = output_tensors->at(i); + using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; + auto status = CreateReadWriteShaderStorageBuffer( + gpu_data_out_->width * gpu_data_out_->height * + gpu_data_out_->channels, + &tensor); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + tflite::gpu::gl::CopyBuffer(gpu_data_out_->ssbo, tensor); + } + cc->Outputs() + .Tag("TENSORS_GPU") + .Add(output_tensors.release(), cc->InputTimestamp()); +#else + RET_CHECK_FAIL() + << "GPU input on non-Android devices is not supported yet."; +#endif + } else { + // CPU ImageFrame to TfLiteTensor conversion. 
+ + const auto& image_frame = cc->Inputs().Tag("IMAGE").Get(); + const int height = image_frame.Height(); + const int width = image_frame.Width(); + const int channels_preserved = + std::min(image_frame.NumberOfChannels(), max_num_channels_); + + if (!(image_frame.Format() == mediapipe::ImageFormat::SRGBA || + image_frame.Format() == mediapipe::ImageFormat::SRGB || + image_frame.Format() == mediapipe::ImageFormat::GRAY8 || + image_frame.Format() == mediapipe::ImageFormat::VEC32F1)) + RET_CHECK_FAIL() << "Unsupported CPU input format."; + + if (!initialized_) { + interpreter_->SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {channels_preserved}, TfLiteQuantization()); + initialized_ = true; + } + + const int tensor_idx = interpreter_->inputs()[0]; + TfLiteTensor* tensor = interpreter_->tensor(tensor_idx); + interpreter_->ResizeInputTensor(tensor_idx, + {height, width, channels_preserved}); + interpreter_->AllocateTensors(); + + float* tensor_buffer = tensor->data.f; + RET_CHECK(tensor_buffer); + + if (image_frame.ByteDepth() == 1) { + RETURN_IF_ERROR(NormalizeImage(image_frame, zero_center_, + flip_vertically_, tensor_buffer)); + } else if (image_frame.ByteDepth() == 4) { + RETURN_IF_ERROR(NormalizeImage(image_frame, zero_center_, + flip_vertically_, tensor_buffer)); + } else { + return ::mediapipe::InternalError( + "Only byte-based (8 bit) and float (32 bit) images supported."); + } + + auto output_tensors = absl::make_unique>(); + output_tensors->emplace_back(*tensor); + cc->Outputs().Tag("TENSORS").Add(output_tensors.release(), + cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { +#if defined(__ANDROID__) + gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); +#endif // __ANDROID__ + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { +#if defined(__ANDROID__) + // Get input image sizes. 
+ const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get(); + + mediapipe::ImageFormat::Format format = + mediapipe::ImageFormatForGpuBufferFormat(input.format()); + + gpu_data_out_ = absl::make_unique(); + gpu_data_out_->height = input.height(); + gpu_data_out_->width = input.width(); + gpu_data_out_->channels = max_num_channels_; // desired output channels + + const bool include_alpha = (max_num_channels_ == 4); + + if (!(format == mediapipe::ImageFormat::SRGB || + format == mediapipe::ImageFormat::SRGBA)) + RET_CHECK_FAIL() << "Unsupported GPU input format."; + + if (include_alpha && (format != mediapipe::ImageFormat::SRGBA)) + RET_CHECK_FAIL() << "Num input channels is less than desired output."; + + // Shader to convert GL Texture to Shader Storage Buffer Object (SSBO), + // with normalization to either: [0,1] or [-1,1]. + auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( + gpu_data_out_->width * gpu_data_out_->height * gpu_data_out_->channels, + &gpu_data_out_->ssbo); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + const std::string shader_source = absl::Substitute( + R"( #version 310 es + layout(local_size_x = $0, local_size_y = $0) in; + layout(binding = 0) uniform sampler2D input_texture; + layout(std430, binding = 1) buffer Output {float elements[];} output_data; + ivec2 width_height = ivec2($1, $2); + void main() { + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= width_height.x || gid.y >= width_height.y) return; + $5 // pixel fetch + $3 // normalize [-1,1] + int linear_index = $7 * ($4 * width_height.x + gid.x); + output_data.elements[linear_index + 0] = pixel.x; + output_data.elements[linear_index + 1] = pixel.y; + output_data.elements[linear_index + 2] = pixel.z; + $6 // alpha channel + })", + /*$0=*/kWorkgroupSize, /*$1=*/gpu_data_out_->width, + /*$2=*/gpu_data_out_->height, + /*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "", + /*$4=*/flip_vertically_ ? 
"(width_height.y - 1 - gid.y)" : "gid.y", + /*$5=*/ + include_alpha ? "vec4 pixel = texelFetch(input_texture, gid, 0);" + : "vec3 pixel = texelFetch(input_texture, gid, 0).xyz;", + /*$6=*/ + include_alpha ? "output_data.elements[linear_index + 3] = pixel.w;" : "", + /*$7=*/include_alpha ? 4 : 3); + status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, + &gpu_data_out_->shader); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + status = GlProgram::CreateWithShader(gpu_data_out_->shader, + &gpu_data_out_->program); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } +#endif // ANDROID + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteConverterCalculator::LoadOptions( + CalculatorContext* cc) { + // Get calculator options specified in the graph. + const auto& options = + cc->Options<::mediapipe::TfLiteConverterCalculatorOptions>(); + + // Get data normalization mode. + zero_center_ = options.zero_center(); + + // Get y-flip mode. + flip_vertically_ = options.flip_vertically(); + + // Get desired way to handle input channels. + max_num_channels_ = options.max_num_channels(); + // Currently only alpha channel toggling is suppored. 
+ CHECK_GE(max_num_channels_, 3); + CHECK_LE(max_num_channels_, 4); + + return ::mediapipe::OkStatus(); +} + +template +::mediapipe::Status TfLiteConverterCalculator::NormalizeImage( + const ImageFrame& image_frame, bool zero_center, bool flip_vertically, + float* tensor_buffer) { + const int height = image_frame.Height(); + const int width = image_frame.Width(); + const int channels = image_frame.NumberOfChannels(); + const int channels_preserved = std::min(channels, max_num_channels_); + const int channels_ignored = channels - channels_preserved; + + float div, sub; + if (zero_center) { + // [-1,1] + div = 127.5f; + sub = 1.0f; + } else { + // [0,1] + div = 255.0f; + sub = 0.0f; + } + + for (int i = 0; i < height; ++i) { + const T* image_ptr = reinterpret_cast( + image_frame.PixelData() + + (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep()); + for (int j = 0; j < width; ++j) { + for (int c = 0; c < channels_preserved; ++c) { + *tensor_buffer++ = *image_ptr++ / div - sub; + } + image_ptr += channels_ignored; + } + } + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.proto b/mediapipe/calculators/tflite/tflite_converter_calculator.proto new file mode 100644 index 000000000..f4c931c11 --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.proto @@ -0,0 +1,41 @@ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Full Example: +// +// node { +// calculator: "TfLiteConverterCalculator" +// input_stream: "IMAGE_IN:input_image" +// output_stream: "TENSOR_OUT:image_tensor" +// options { +// [mediapipe.TfLiteConverterCalculatorOptions.ext] { +// zero_center: true +// } +// } +// } +// +message TfLiteConverterCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TfLiteConverterCalculatorOptions ext = 245817797; + } + + // Choose normalization mode for output: + // true = [-1,1] + // false = 
[0,1] + optional bool zero_center = 1 [default = true]; + + // Whether the input image should be flipped vertically (along the + // y-direction). This is useful, for example, when the input image is defined + // with a coordinate system where the origin is at the bottom-left corner + // (e.g., in OpenGL) whereas the ML model expects an image with a top-left + // origin. + optional bool flip_vertically = 2 [default = false]; + + // Controls how many channels of the input image get passed through to the + // tensor. Currently this only controls whether or not to ignore alpha + // channel, so it must be 3 or 4. + optional int32 max_num_channels = 3 [default = 3]; +} diff --git a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc new file mode 100644 index 000000000..4628062e7 --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc @@ -0,0 +1,70 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/tflite/cpu_op_resolver.h" +#include "mediapipe/util/tflite/op_resolver.h" + +namespace mediapipe { + +// This calculator creates a custom op resolver as a side packet that can be +// used in TfLiteInferenceCalculator. Current custom op resolver supports the +// following custom op on CPU and GPU: +// Convolution2DTransposeBias +// MaxPoolArgmax +// MaxUnpooling +// +// Usage example: +// node { +// calculator: "TfLiteCustomOpResolverCalculator" +// output_side_packet: "op_resolver" +// node_options: { +// [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] { +// use_gpu: true +// } +// } +// } +class TfLiteCustomOpResolverCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->OutputSidePackets() + .Index(0) + .Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + const TfLiteCustomOpResolverCalculatorOptions& options = + cc->Options(); + + std::unique_ptr op_resolver; + if (options.use_gpu()) { + op_resolver = absl::make_unique<::mediapipe::OpResolver>(); + } else { + op_resolver = absl::make_unique<::mediapipe::CpuOpResolver>(); + } + + cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release())); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(TfLiteCustomOpResolverCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.proto b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.proto new file mode 100644 index 000000000..546165a9f --- /dev/null +++ 
b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.proto @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Options to generate an op resolver for running TfLite inference. +message TfLiteCustomOpResolverCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TfLiteCustomOpResolverCalculatorOptions ext = 252087553; + } + + // Flag for using GPU inference which uses the correspondent op resolver. + optional bool use_gpu = 1 [default = false]; +} diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc new file mode 100644 index 000000000..392a6c853 --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -0,0 +1,472 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include +#include + +#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/resource_util.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +#if defined(__ANDROID__) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#endif // ANDROID + +#if defined(__APPLE__) && !TARGET_OS_OSX // iOS +#if defined(__OBJC__) +#import +#import +#endif // OBJC +#import "mediapipe/framework/ios/NSError+util_status.h" +#import "mediapipe/gpu/MediaPipeMetalHelper.h" +#include "tensorflow/lite/delegates/gpu/metal_delegate.h" +#endif // APPLE && !TARGET_OS_OSX + +// TfLiteInferenceCalculator File Layout: +// * Header +// * Core +// * Aux +namespace mediapipe { + +#if defined(__ANDROID__) +using ::tflite::gpu::gl::GlBuffer; +using ::tflite::gpu::gl::GlProgram; +using ::tflite::gpu::gl::GlShader; +struct GPUData { + int elements = 1; + GlBuffer ssbo; + GlShader shader; + GlProgram program; +}; +#endif // ANDROID + +// Calculator Header Section + +// Runs inference on the provided input TFLite tensors and TFLite model. +// +// Creates an interpreter with given model and calls invoke(). +// Optionally run inference on CPU/GPU. +// +// This calculator is designed to be used with the TfLiteConverterCalcualtor, +// to get the appropriate inputs. 
+// +// When the input tensors are on CPU, gpu inference is optional and can be +// specified in the calculator options. +// When the input tensors are on GPU, inference is GPU and output can be CPU or +// GPU. +// +// Input: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 +// TENSORS_GPU - Vector of GlBuffer (assumed to be RGB image) +// +// Output: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 +// TENSORS_GPU - Vector of GlBuffer +// +// Input side packet: +// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver, +// instead of the builtin one. +// +// Example use: +// node { +// calculator: "TfLiteInferenceCalculator" +// input_stream: "TENSORS:tensor_image" +// output_stream: "TENSORS:tensors" +// options: { +// [mediapipe.TfLiteInferenceCalculatorOptions.ext] { +// model_path: "modelname.tflite" +// use_gpu: true +// } +// } +// } +// +// IMPORTANT Notes: +// Tensors are assumed to be ordered correctly (sequentially added to model). +// Input tensors are assumed to be of the correct size and already normalized. +// All output TfLiteTensors will be destroyed when the graph closes, +// (i.e. after calling graph.WaitUntilDone()). +// GPU tensors are currently only supported on Android. +// This calculator uses FixedSizeInputStreamHandler by default. 
+// +class TfLiteInferenceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status LoadOptions(CalculatorContext* cc); + ::mediapipe::Status LoadModel(CalculatorContext* cc); + ::mediapipe::Status LoadDelegate(CalculatorContext* cc); + + std::unique_ptr interpreter_; + std::unique_ptr model_; + TfLiteDelegate* delegate_ = nullptr; + +#if defined(__ANDROID__) + mediapipe::GlCalculatorHelper gpu_helper_; + std::unique_ptr gpu_data_in_; + std::vector> gpu_data_out_; +#endif +#if defined(__APPLE__) && !TARGET_OS_OSX // iOS + MediaPipeMetalHelper* gpu_helper_ = nullptr; +#endif + + std::string model_path_ = ""; + bool gpu_inference_ = false; + bool gpu_input_ = false; + bool gpu_output_ = false; +}; // TfLiteInferenceCalculator + +REGISTER_CALCULATOR(TfLiteInferenceCalculator); + +// Calculator Core Section + +::mediapipe::Status TfLiteInferenceCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag("TENSORS") || + cc->Inputs().HasTag("TENSORS_GPU")); + RET_CHECK(cc->Outputs().HasTag("TENSORS") || + cc->Outputs().HasTag("TENSORS_GPU")); + + if (cc->Inputs().HasTag("TENSORS")) + cc->Inputs().Tag("TENSORS").Set>(); +#if defined(__ANDROID__) + if (cc->Inputs().HasTag("TENSORS_GPU")) + cc->Inputs().Tag("TENSORS_GPU").Set>(); +#endif + + if (cc->Outputs().HasTag("TENSORS")) + cc->Outputs().Tag("TENSORS").Set>(); +#if defined(__ANDROID__) + if (cc->Outputs().HasTag("TENSORS_GPU")) + cc->Outputs().Tag("TENSORS_GPU").Set>(); +#endif + + if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { + cc->InputSidePackets() + .Tag("CUSTOM_OP_RESOLVER") + .Set(); + } + +#if defined(__ANDROID__) + RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#elif 
defined(__APPLE__) && !TARGET_OS_OSX // iOS + RETURN_IF_ERROR([MediaPipeMetalHelper updateContract:cc]); +#endif + + // Assign this calculator's default InputStreamHandler. + cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { + RETURN_IF_ERROR(LoadOptions(cc)); + + if (cc->Inputs().HasTag("TENSORS_GPU")) { +#if defined(__ANDROID__) + gpu_input_ = true; + gpu_inference_ = true; // Inference must be on GPU also. +#else + RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU")) + << "GPU input for non-Android not supported yet."; +#endif + } + + if (cc->Outputs().HasTag("TENSORS_GPU")) { +#if defined(__ANDROID__) + gpu_output_ = true; + RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU")) + << "GPU output must also have GPU Input."; +#else + RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU")) + << "GPU output for non-Android not supported yet."; +#endif + } + + RETURN_IF_ERROR(LoadModel(cc)); + + if (gpu_inference_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + gpu_helper_ = [[MediaPipeMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(gpu_helper_); +#endif + + RETURN_IF_ERROR(LoadDelegate(cc)); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { + // Receive pre-processed tensor inputs. + if (gpu_input_) { + // Read GPU input into SSBO. +#if defined(__ANDROID__) + const auto& input_tensors = + cc->Inputs().Tag("TENSORS_GPU").Get>(); + RET_CHECK_EQ(input_tensors.size(), 1); + RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors]() -> ::mediapipe::Status { + // Explicit copy input. + tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->ssbo); + // Run inference. 
+ RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + return ::mediapipe::OkStatus(); + })); +#else + RET_CHECK_FAIL() + << "GPU input on non-Android devices is not supported yet."; +#endif + } else { + // Read CPU input into tensors. + const auto& input_tensors = + cc->Inputs().Tag("TENSORS").Get>(); + RET_CHECK_GT(input_tensors.size(), 0); + for (int i = 0; i < input_tensors.size(); ++i) { + const TfLiteTensor* input_tensor = &input_tensors[i]; + const float* input_tensor_buffer = input_tensor->data.f; + RET_CHECK(input_tensor_buffer); + + float* local_tensor_buffer = interpreter_->typed_input_tensor(i); + RET_CHECK(local_tensor_buffer); + + memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes); + } + + // Run inference. + if (gpu_inference_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status { + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + return ::mediapipe::OkStatus(); + })); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); +#endif + } else { + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + } + } + + if (gpu_output_) { +#if defined(__ANDROID__) + // Output result tensors (GPU). + auto output_tensors = absl::make_unique>(); + output_tensors->resize(gpu_data_out_.size()); + for (int i = 0; i < gpu_data_out_.size(); ++i) { + GlBuffer& tensor = output_tensors->at(i); + using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; + auto status = CreateReadWriteShaderStorageBuffer( + gpu_data_out_[i]->elements, &tensor); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->ssbo, tensor); + } + cc->Outputs() + .Tag("TENSORS_GPU") + .Add(output_tensors.release(), cc->InputTimestamp()); +#else + LOG(ERROR) << "GPU output on non-Android not supported yet."; +#endif + } else { + // Output result tensors (CPU). 
+ const auto& tensor_indexes = interpreter_->outputs(); + auto output_tensors = absl::make_unique>(); + for (int i = 0; i < tensor_indexes.size(); ++i) { + TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]); + output_tensors->emplace_back(*tensor); + } + cc->Outputs().Tag("TENSORS").Add(output_tensors.release(), + cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { + if (delegate_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { + TfLiteGpuDelegateDelete(delegate_); + gpu_data_in_.reset(); + for (int i = 0; i < gpu_data_out_.size(); ++i) { + gpu_data_out_[i].reset(); + } + return ::mediapipe::OkStatus(); + })); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + DeleteGpuDelegate(delegate_); +#endif + } + return ::mediapipe::OkStatus(); +} + +// Calculator Auxiliary Section + +::mediapipe::Status TfLiteInferenceCalculator::LoadOptions( + CalculatorContext* cc) { + // Get calculator options specified in the graph. + const auto& options = + cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>(); + + // Get model name. + if (!options.model_path().empty()) { + ASSIGN_OR_RETURN(model_path_, + mediapipe::PathToResourceAsFile(options.model_path())); + } else { + LOG(ERROR) << "Must specify path to TFLite model."; + return ::mediapipe::Status(::mediapipe::StatusCode::kNotFound, + "Must specify path to TFLite model."); + } + + // Get execution modes. + gpu_inference_ = options.use_gpu(); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteInferenceCalculator::LoadModel( + CalculatorContext* cc) { + model_ = tflite::FlatBufferModel::BuildFromFile(model_path_.c_str()); + RET_CHECK(model_); + +#if !defined(__ANDROID__) && !(defined(__APPLE__) && !TARGET_OS_OSX) + LOG(WARNING) << "GPU only supported on mobile platforms. 
Using CPU fallback."; + gpu_inference_ = false; +#endif + + if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { + const auto& op_resolver = + cc->InputSidePackets() + .Tag("CUSTOM_OP_RESOLVER") + .Get(); + tflite::InterpreterBuilder(*model_, op_resolver)(&interpreter_); + } else { + const tflite::ops::builtin::BuiltinOpResolver op_resolver; + tflite::InterpreterBuilder(*model_, op_resolver)(&interpreter_); + } + + RET_CHECK(interpreter_); + + if (!gpu_output_) { + RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( + CalculatorContext* cc) { +#if defined(__ANDROID__) + // Get input image sizes. + if (gpu_input_) { + gpu_data_in_ = absl::make_unique(); + const auto& input_indices = interpreter_->inputs(); + // TODO accept > 1. + RET_CHECK_EQ(input_indices.size(), 1); + const TfLiteTensor* tensor = interpreter_->tensor(input_indices[0]); + gpu_data_in_->elements = 1; + for (int d = 0; d < tensor->dims->size; ++d) { + gpu_data_in_->elements *= tensor->dims->data[d]; + } + // Input to model can be either RGB/RGBA only. + RET_CHECK_GE(tensor->dims->data[3], 3); + RET_CHECK_LE(tensor->dims->data[3], 4); + } + // Get output image sizes. + if (gpu_output_) { + const auto& output_indices = interpreter_->outputs(); + gpu_data_out_.resize(output_indices.size()); + for (int i = 0; i < gpu_data_out_.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + gpu_data_out_[i] = absl::make_unique(); + gpu_data_out_[i]->elements = 1; + // TODO handle *2 properly on some dialated models + for (int d = 0; d < tensor->dims->size; ++d) { + gpu_data_out_[i]->elements *= tensor->dims->data[d]; + } + } + } + // Configure and create the delegate. 
+ TfLiteGpuDelegateOptions options; + options.metadata = nullptr; + options.compile_options.precision_loss_allowed = 1; + options.compile_options.preferred_gl_object_type = + TFLITE_GL_OBJECT_TYPE_FASTEST; + options.compile_options.dynamic_batch_enabled = 0; + if (!delegate_) delegate_ = TfLiteGpuDelegateCreate(&options); + // Shader to convert GL texture to SSBO. + if (gpu_input_) { + auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( + gpu_data_in_->elements, &gpu_data_in_->ssbo); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( + delegate_, gpu_data_in_->ssbo.id(), + interpreter_->inputs()[0]), // First tensor only + kTfLiteOk); + } + // Create output SSBO buffers. + if (gpu_output_) { + interpreter_->SetAllowBufferHandleOutput(true); + const auto& output_indices = interpreter_->outputs(); + for (int i = 0; i < gpu_data_out_.size(); ++i) { + using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; + auto status = CreateReadWriteShaderStorageBuffer( + gpu_data_out_[i]->elements, &gpu_data_out_[i]->ssbo); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + RET_CHECK_EQ( + TfLiteGpuDelegateBindBufferToTensor( + delegate_, gpu_data_out_[i]->ssbo.id(), output_indices[i]), + kTfLiteOk); + } + } + // Must call this last. 
+ RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); + return ::mediapipe::OkStatus(); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + GpuDelegateOptions options; + options.allow_precision_loss = 1; + options.wait_type = GpuDelegateOptions::WaitType::kPassive; + if (!delegate_) delegate_ = NewGpuDelegate(&options); + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); + return ::mediapipe::OkStatus(); +#endif // ANDROID or iOS + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.proto b/mediapipe/calculators/tflite/tflite_inference_calculator.proto new file mode 100644 index 000000000..a2950add3 --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.proto @@ -0,0 +1,48 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Full Example: +// +// node { +// calculator: "TfLiteInferenceCalculator" +// input_stream: "TENSOR_IN:image_tensors" +// output_stream: "TENSOR_OUT:result_tensors" +// options { +// [mediapipe.TfLiteInferenceCalculatorOptions.ext] { +// model_path: "model.tflite" +// use_gpu: true +// } +// } +// } +// +message TfLiteInferenceCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TfLiteInferenceCalculatorOptions ext = 233867213; + } + + // Path to the TF Lite model (ex: /path/to/modelname.tflite). + // On mobile, this is generally just modelname.tflite. + optional string model_path = 1; + + // Whether the TF Lite GPU or CPU backend should be used. Effective only when + // input tensors are on CPU. For input tensors on GPU, GPU backend is always + // used. + optional bool use_gpu = 2 [default = false]; +} diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc b/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc new file mode 100644 index 000000000..ab2c8c87a --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc @@ -0,0 +1,123 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +#ifdef __APPLE__ +#include +#endif // defined(__APPLE__) + +namespace mediapipe { + +using ::tflite::Interpreter; + +class TfLiteInferenceCalculatorTest : public ::testing::Test { + protected: + std::unique_ptr runner_ = nullptr; +}; + +// Tests a simple add model that adds an input tensor to itself. +TEST_F(TfLiteInferenceCalculatorTest, SmokeTest) { + const int width = 8; + const int height = 8; + const int channels = 3; + + // Prepare input tensor. + std::unique_ptr interpreter(new Interpreter); + ASSERT_NE(interpreter, nullptr); + + interpreter->AddTensors(1); + interpreter->SetInputs({0}); + interpreter->SetOutputs({0}); + interpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + TfLiteQuantization()); + int t = interpreter->inputs()[0]; + TfLiteTensor* tensor = interpreter->tensor(t); + interpreter->ResizeInputTensor(t, {width, height, channels}); + interpreter->AllocateTensors(); + + float* tensor_buffer = tensor->data.f; + ASSERT_NE(tensor_buffer, nullptr); + for (int i = 0; i < width * height * channels - 1; i++) { + tensor_buffer[i] = 1; + } + + auto input_vec = absl::make_unique>(); + input_vec->emplace_back(*tensor); + + // Prepare single calculator graph to and wait for packets. 
+ CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + R"( + input_stream: "tensor_in" + node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:tensor_in" + output_stream: "TENSORS:tensor_out" + options { + [mediapipe.TfLiteInferenceCalculatorOptions.ext] { + use_gpu: false + model_path: "mediapipe/calculators/tflite/testdata/add.bin" + } + } + } + )"); + std::vector output_packets; + tool::AddVectorSink("tensor_out", &graph_config, &output_packets); + CalculatorGraph graph(graph_config); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + + // Push the tensor into the graph. + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "tensor_in", Adopt(input_vec.release()).At(Timestamp(0)))); + // Wait until the calculator done processing. + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, output_packets.size()); + + // Get and process results. + const std::vector& result_vec = + output_packets[0].Get>(); + ASSERT_EQ(1, result_vec.size()); + + const TfLiteTensor* result = &result_vec[0]; + float* result_buffer = result->data.f; + ASSERT_NE(result_buffer, nullptr); + for (int i = 0; i < width * height * channels - 1; i++) { + ASSERT_EQ(3, result_buffer[i]); + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("tensor_in")); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc new file mode 100644 index 000000000..54ad20190 --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc @@ -0,0 +1,747 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/formats/object_detection/anchor.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "tensorflow/lite/interpreter.h" + +#if defined(__ANDROID__) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#endif // ANDROID + +#if defined(__ANDROID__) +using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; +using ::tflite::gpu::gl::GlBuffer; +using ::tflite::gpu::gl::GlProgram; +using ::tflite::gpu::gl::GlShader; +#endif // ANDROID + +namespace mediapipe { + +namespace { + +constexpr int kNumInputTensorsWithAnchors = 3; +constexpr int kNumCoordsPerBox = 4; + +void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes, + std::vector* anchors) { + anchors->clear(); + for (int i = 0; i < num_boxes; ++i) { + Anchor new_anchor; + new_anchor.set_y_center(raw_anchors[i * kNumCoordsPerBox + 0]); + new_anchor.set_x_center(raw_anchors[i * kNumCoordsPerBox + 1]); + new_anchor.set_h(raw_anchors[i * 
kNumCoordsPerBox + 2]); + new_anchor.set_w(raw_anchors[i * kNumCoordsPerBox + 3]); + anchors->push_back(new_anchor); + } +} + +void ConvertAnchorsToRawValues(const std::vector& anchors, + int num_boxes, float* raw_anchors) { + CHECK_EQ(anchors.size(), num_boxes); + int box = 0; + for (auto anchor : anchors) { + raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center(); + raw_anchors[box * kNumCoordsPerBox + 1] = anchor.x_center(); + raw_anchors[box * kNumCoordsPerBox + 2] = anchor.h(); + raw_anchors[box * kNumCoordsPerBox + 3] = anchor.w(); + ++box; + } +} + +} // namespace + +// Convert result TFLite tensors from object detection models into MediaPipe +// Detections. +// +// Input: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. The vector of +// tensors can have 2 or 3 tensors. First tensor is the predicted +// raw boxes/keypoints. The size of the values must be (num_boxes +// * num_predicted_values). Second tensor is the score tensor. The +// size of the valuse must be (num_boxes * num_classes). It's +// optional to pass in a third tensor for anchors (e.g. for SSD +// models) depend on the outputs of the detection model. The size +// of anchor tensor must be (num_boxes * 4). +// TENSORS_GPU - vector of GlBuffer. +// Output: +// DETECTIONS - Result MediaPipe detections. 
+// +// Usage example: +// node { +// calculator: "TfLiteTensorsToDetectionsCalculator" +// input_stream: "TENSORS:tensors" +// input_side_packet: "ANCHORS:anchors" +// output_stream: "DETECTIONS:detections" +// options: { +// [mediapipe.TfLiteTensorsToDetectionsCalculatorOptions.ext] { +// num_classes: 91 +// num_boxes: 1917 +// num_coords: 4 +// ignore_classes: [0, 1, 2] +// x_scale: 10.0 +// y_scale: 10.0 +// h_scale: 5.0 +// w_scale: 5.0 +// } +// } +// } +class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status LoadOptions(CalculatorContext* cc); + ::mediapipe::Status GlSetup(CalculatorContext* cc); + ::mediapipe::Status DecodeBoxes(const float* raw_boxes, + const std::vector& anchors, + std::vector* boxes); + Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, + float box_xmax, float score, int class_id, + bool flip_vertically); + + int num_classes_ = 0; + int num_boxes_ = 0; + int num_coords_ = 0; + std::set ignore_classes_; + + ::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions options_; + std::vector anchors_; + +#if defined(__ANDROID__) + mediapipe::GlCalculatorHelper gpu_helper_; + std::unique_ptr decode_program_; + std::unique_ptr score_program_; + std::unique_ptr decoded_boxes_buffer_; + std::unique_ptr raw_boxes_buffer_; + std::unique_ptr raw_anchors_buffer_; + std::unique_ptr scored_boxes_buffer_; + std::unique_ptr raw_scores_buffer_; +#endif + + bool gpu_input_ = false; + bool anchors_init_ = false; +}; +REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); + +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(!cc->Inputs().GetTags().empty()); + 
RET_CHECK(!cc->Outputs().GetTags().empty()); + + if (cc->Inputs().HasTag("TENSORS")) { + cc->Inputs().Tag("TENSORS").Set>(); + } + +#if defined(__ANDROID__) + if (cc->Inputs().HasTag("TENSORS_GPU")) { + cc->Inputs().Tag("TENSORS_GPU").Set>(); + } +#endif + + if (cc->Outputs().HasTag("DETECTIONS")) { + cc->Outputs().Tag("DETECTIONS").Set>(); + } + + if (cc->InputSidePackets().UsesTags()) { + if (cc->InputSidePackets().HasTag("ANCHORS")) { + cc->InputSidePackets().Tag("ANCHORS").Set>(); + } + } + +#if defined(__ANDROID__) + RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Open( + CalculatorContext* cc) { + if (cc->Inputs().HasTag("TENSORS_GPU")) { + gpu_input_ = true; +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif + } + + RETURN_IF_ERROR(LoadOptions(cc)); + + if (gpu_input_) { + RETURN_IF_ERROR(GlSetup(cc)); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Process( + CalculatorContext* cc) { + if ((!gpu_input_ && cc->Inputs().Tag("TENSORS").IsEmpty()) || + (gpu_input_ && cc->Inputs().Tag("TENSORS_GPU").IsEmpty())) { + return ::mediapipe::OkStatus(); + } + + const bool side_packet_anchors = + cc->InputSidePackets().HasTag("ANCHORS") && + !cc->InputSidePackets().Tag("ANCHORS").IsEmpty(); + auto output_detections = absl::make_unique>(); + + std::vector boxes(num_boxes_ * num_coords_); + std::vector score_class_id_pairs(num_boxes_ * 2); + + if (gpu_input_) { +#if defined(__ANDROID__) + const auto& input_tensors = + cc->Inputs().Tag("TENSORS_GPU").Get>(); + + // Copy inputs. 
+ tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get()); + tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get()); + if (!anchors_init_) { + if (side_packet_anchors) { + const auto& anchors = + cc->InputSidePackets().Tag("ANCHORS").Get>(); + std::vector raw_anchors(num_boxes_ * kNumCoordsPerBox); + ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data()); + raw_anchors_buffer_->Write(absl::MakeSpan(raw_anchors)); + } else { + CHECK_EQ(input_tensors.size(), 3); + tflite::gpu::gl::CopyBuffer(input_tensors[2], + *raw_anchors_buffer_.get()); + } + anchors_init_ = true; + } + + // Run shaders. + RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors]() -> ::mediapipe::Status { + // Decode boxes. + decoded_boxes_buffer_->BindToIndex(0); + raw_boxes_buffer_->BindToIndex(1); + raw_anchors_buffer_->BindToIndex(2); + const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1}; + decode_program_->Dispatch(decode_workgroups); + + // Score boxes. + scored_boxes_buffer_->BindToIndex(0); + raw_scores_buffer_->BindToIndex(1); + const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1}; + score_program_->Dispatch(score_workgroups); + + return ::mediapipe::OkStatus(); + })); + + // Copy decoded boxes from GPU to CPU. + auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes)); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs)); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } +#else + LOG(ERROR) << "GPU input on non-Android not supported yet."; +#endif // defined(__ANDROID__) + } else { + const auto& input_tensors = + cc->Inputs().Tag("TENSORS").Get>(); + + const TfLiteTensor* raw_box_tensor = &input_tensors[0]; + const TfLiteTensor* raw_score_tensor = &input_tensors[1]; + + // TODO: Add flexible input tensor size handling. 
+ CHECK_EQ(raw_box_tensor->dims->size, 3); + CHECK_EQ(raw_box_tensor->dims->data[0], 1); + CHECK_EQ(raw_box_tensor->dims->data[1], num_boxes_); + CHECK_EQ(raw_box_tensor->dims->data[2], num_coords_); + CHECK_EQ(raw_score_tensor->dims->size, 3); + CHECK_EQ(raw_score_tensor->dims->data[0], 1); + CHECK_EQ(raw_score_tensor->dims->data[1], num_boxes_); + CHECK_EQ(raw_score_tensor->dims->data[2], num_classes_); + const float* raw_boxes = raw_box_tensor->data.f; + const float* raw_scores = raw_score_tensor->data.f; + + // TODO: Support other options to load anchors. + if (!anchors_init_) { + if (input_tensors.size() == kNumInputTensorsWithAnchors) { + const TfLiteTensor* anchor_tensor = &input_tensors[2]; + CHECK_EQ(anchor_tensor->dims->size, 2); + CHECK_EQ(anchor_tensor->dims->data[0], num_boxes_); + CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox); + const float* raw_anchors = anchor_tensor->data.f; + ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_); + } else if (side_packet_anchors) { + anchors_ = + cc->InputSidePackets().Tag("ANCHORS").Get>(); + } else { + return ::mediapipe::UnavailableError("No anchor data available."); + } + anchors_init_ = true; + } + RETURN_IF_ERROR(DecodeBoxes(raw_boxes, anchors_, &boxes)); + + // Filter classes by scores. + for (int i = 0; i < num_boxes_; ++i) { + int class_id = -1; + float max_score = -std::numeric_limits::max(); + // Find the top score for box i. + for (int score_idx = 0; score_idx < num_classes_; ++score_idx) { + if (ignore_classes_.find(score_idx) == ignore_classes_.end()) { + auto score = raw_scores[i * num_classes_ + score_idx]; + if (options_.sigmoid_score()) { + if (options_.has_score_clipping_thresh()) { + score = score < -options_.score_clipping_thresh() + ? -options_.score_clipping_thresh() + : score; + score = score > options_.score_clipping_thresh() + ? 
options_.score_clipping_thresh() + : score; + } + score = 1.0f / (1.0f + std::exp(-score)); + } + if (max_score < score) { + max_score = score; + class_id = score_idx; + } + } + } + score_class_id_pairs[i * 2 + 0] = max_score; + score_class_id_pairs[i * 2 + 1] = class_id; + } + } // if gpu_input_ + + // Convert to Detection. + for (int i = 0; i < num_boxes_; ++i) { + const float score = score_class_id_pairs[i * 2 + 0]; + const int class_id = score_class_id_pairs[i * 2 + 1]; + const int box_offset = i * num_coords_; + Detection detection = ConvertToDetection( + boxes[box_offset + 0], boxes[box_offset + 1], boxes[box_offset + 2], + boxes[box_offset + 3], score, class_id, options_.flip_vertically()); + // Add keypoints. + if (options_.num_keypoints() > 0) { + auto* location_data = detection.mutable_location_data(); + for (int kp_id = 0; kp_id < options_.num_keypoints() * + options_.num_values_per_keypoint(); + kp_id += options_.num_values_per_keypoint()) { + auto keypoint = location_data->add_relative_keypoints(); + const int keypoint_index = + box_offset + options_.keypoint_coord_offset() + kp_id; + keypoint->set_x(boxes[keypoint_index + 0]); + keypoint->set_y(options_.flip_vertically() + ? 
1.f - boxes[keypoint_index + 1] + : boxes[keypoint_index + 1]); + } + } + output_detections->emplace_back(detection); + } + + // Output + if (cc->Outputs().HasTag("DETECTIONS")) { + cc->Outputs() + .Tag("DETECTIONS") + .Add(output_detections.release(), cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close( + CalculatorContext* cc) { +#if defined(__ANDROID__) + gpu_helper_.RunInGlContext([this] { + decode_program_.reset(); + score_program_.reset(); + decoded_boxes_buffer_.reset(); + raw_boxes_buffer_.reset(); + raw_anchors_buffer_.reset(); + scored_boxes_buffer_.reset(); + raw_scores_buffer_.reset(); + }); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::LoadOptions( + CalculatorContext* cc) { + // Get calculator options specified in the graph. + options_ = + cc->Options<::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions>(); + + num_classes_ = options_.num_classes(); + num_boxes_ = options_.num_boxes(); + num_coords_ = options_.num_coords(); + + // Currently only support 2D when num_values_per_keypoint equals to 2. + CHECK_EQ(options_.num_values_per_keypoint(), 2); + + // Check if the output size is equal to the requested boxes and keypoints. 
+ CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() + + kNumCoordsPerBox, + num_coords_); + + for (int i = 0; i < options_.ignore_classes_size(); ++i) { + ignore_classes_.insert(options_.ignore_classes(i)); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::DecodeBoxes( + const float* raw_boxes, const std::vector& anchors, + std::vector* boxes) { + for (int i = 0; i < num_boxes_; ++i) { + const int box_offset = i * num_coords_ + options_.box_coord_offset(); + + float y_center = raw_boxes[box_offset]; + float x_center = raw_boxes[box_offset + 1]; + float h = raw_boxes[box_offset + 2]; + float w = raw_boxes[box_offset + 3]; + if (options_.reverse_output_order()) { + x_center = raw_boxes[box_offset]; + y_center = raw_boxes[box_offset + 1]; + w = raw_boxes[box_offset + 2]; + h = raw_boxes[box_offset + 3]; + } + + x_center = + x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center(); + y_center = + y_center / options_.y_scale() * anchors[i].h() + anchors[i].y_center(); + + if (options_.apply_exponential_on_box_size()) { + h = std::exp(h / options_.h_scale()) * anchors[i].h(); + w = std::exp(w / options_.w_scale()) * anchors[i].w(); + } else { + h = h / options_.h_scale() * anchors[i].h(); + w = w / options_.w_scale() * anchors[i].w(); + } + + const float ymin = y_center - h / 2.f; + const float xmin = x_center - w / 2.f; + const float ymax = y_center + h / 2.f; + const float xmax = x_center + w / 2.f; + + (*boxes)[i * num_coords_ + 0] = ymin; + (*boxes)[i * num_coords_ + 1] = xmin; + (*boxes)[i * num_coords_ + 2] = ymax; + (*boxes)[i * num_coords_ + 3] = xmax; + + if (options_.num_keypoints()) { + for (int k = 0; k < options_.num_keypoints(); ++k) { + const int offset = i * num_coords_ + options_.keypoint_coord_offset() + + k * options_.num_values_per_keypoint(); + + float keypoint_y = raw_boxes[offset]; + float keypoint_x = raw_boxes[offset + 1]; + if 
(options_.reverse_output_order()) { + keypoint_x = raw_boxes[offset]; + keypoint_y = raw_boxes[offset + 1]; + } + + (*boxes)[offset] = keypoint_x / options_.x_scale() * anchors[i].w() + + anchors[i].x_center(); + (*boxes)[offset + 1] = + keypoint_y / options_.y_scale() * anchors[i].h() + + anchors[i].y_center(); + } + } + } + return ::mediapipe::OkStatus(); +} + +Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection( + float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, + int class_id, bool flip_vertically) { + Detection detection; + detection.add_score(score); + detection.add_label_id(class_id); + + LocationData* location_data = detection.mutable_location_data(); + location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); + + LocationData::RelativeBoundingBox* relative_bbox = + location_data->mutable_relative_bounding_box(); + + relative_bbox->set_xmin(box_xmin); + relative_bbox->set_ymin(flip_vertically ? 1.f - box_ymax : box_ymin); + relative_bbox->set_width(box_xmax - box_xmin); + relative_bbox->set_height(box_ymax - box_ymin); + return detection; +} + +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GlSetup( + CalculatorContext* cc) { +#if defined(__ANDROID__) + // A shader to decode detection boxes. 
+ const std::string decode_src = absl::Substitute( + R"( +#version 310 es + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(location = 0) uniform vec4 scale; + +layout(std430, binding = 0) writeonly buffer Output { + float data[]; +} boxes; + +layout(std430, binding = 1) readonly buffer Input0 { + float data[]; +} raw_boxes; + +layout(std430, binding = 2) readonly buffer Input1 { + float data[]; +} raw_anchors; + +uint num_coords = uint($0); +int reverse_output_order = int($1); +int apply_exponential = int($2); +int box_coord_offset = int($3); +int num_keypoints = int($4); +int keypt_coord_offset = int($5); +int num_values_per_keypt = int($6); + +void main() { + uint g_idx = gl_GlobalInvocationID.x; // box index + uint box_offset = g_idx * num_coords + uint(box_coord_offset); + uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox + + float y_center, x_center, h, w; + + if (reverse_output_order == int(0)) { + y_center = raw_boxes.data[box_offset + uint(0)]; + x_center = raw_boxes.data[box_offset + uint(1)]; + h = raw_boxes.data[box_offset + uint(2)]; + w = raw_boxes.data[box_offset + uint(3)]; + } else { + x_center = raw_boxes.data[box_offset + uint(0)]; + y_center = raw_boxes.data[box_offset + uint(1)]; + w = raw_boxes.data[box_offset + uint(2)]; + h = raw_boxes.data[box_offset + uint(3)]; + } + + float anchor_yc = raw_anchors.data[anchor_offset + uint(0)]; + float anchor_xc = raw_anchors.data[anchor_offset + uint(1)]; + float anchor_h = raw_anchors.data[anchor_offset + uint(2)]; + float anchor_w = raw_anchors.data[anchor_offset + uint(3)]; + + x_center = x_center / scale.x * anchor_w + anchor_xc; + y_center = y_center / scale.y * anchor_h + anchor_yc; + + if (apply_exponential == int(1)) { + h = exp(h / scale.w) * anchor_h; + w = exp(w / scale.z) * anchor_w; + } else { + h = (h / scale.w) * anchor_h; + w = (w / scale.z) * anchor_w; + } + + float ymin = y_center - h / 2.0; + float xmin = x_center - w / 2.0; + float ymax = 
y_center + h / 2.0; + float xmax = x_center + w / 2.0; + + boxes.data[box_offset + uint(0)] = ymin; + boxes.data[box_offset + uint(1)] = xmin; + boxes.data[box_offset + uint(2)] = ymax; + boxes.data[box_offset + uint(3)] = xmax; + + if (num_keypoints > int(0)){ + for (int k = 0; k < num_keypoints; ++k) { + int kp_offset = + int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; + float kp_y, kp_x; + if (reverse_output_order == int(0)) { + kp_y = raw_boxes.data[kp_offset + int(0)]; + kp_x = raw_boxes.data[kp_offset + int(1)]; + } else { + kp_x = raw_boxes.data[kp_offset + int(0)]; + kp_y = raw_boxes.data[kp_offset + int(1)]; + } + boxes.data[kp_offset + int(0)] = kp_x / scale.x * anchor_w + anchor_xc; + boxes.data[kp_offset + int(1)] = kp_y / scale.y * anchor_h + anchor_yc; + } + } +})", + options_.num_coords(), // box xywh + options_.reverse_output_order() ? 1 : 0, + options_.apply_exponential_on_box_size() ? 1 : 0, + options_.box_coord_offset(), options_.num_keypoints(), + options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); + + // Shader program + GlShader decode_shader; + auto status = + GlShader::CompileShader(GL_COMPUTE_SHADER, decode_src, &decode_shader); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + decode_program_ = absl::make_unique(); + status = GlProgram::CreateWithShader(decode_shader, decode_program_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + // Outputs + size_t decoded_boxes_length = num_boxes_ * num_coords_; + decoded_boxes_buffer_ = absl::make_unique(); + status = CreateReadWriteShaderStorageBuffer( + decoded_boxes_length, decoded_boxes_buffer_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + // Inputs + size_t raw_boxes_length = num_boxes_ * num_coords_; + raw_boxes_buffer_ = absl::make_unique(); + status = CreateReadWriteShaderStorageBuffer(raw_boxes_length, + 
raw_boxes_buffer_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + size_t raw_anchors_length = num_boxes_ * num_coords_; + raw_anchors_buffer_ = absl::make_unique(); + status = CreateReadWriteShaderStorageBuffer(raw_anchors_length, + raw_anchors_buffer_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + // Parameters + glUseProgram(decode_program_->id()); + glUniform4f(0, options_.x_scale(), options_.y_scale(), options_.w_scale(), + options_.h_scale()); + + // A shader to score detection boxes. + const std::string score_src = absl::Substitute( + R"( +#version 310 es + +layout(local_size_x = 1, local_size_y = $0, local_size_z = 1) in; + +#define FLT_MAX 1.0e+37 + +shared float local_scores[$0]; + +layout(std430, binding = 0) writeonly buffer Output { + float data[]; +} scored_boxes; + +layout(std430, binding = 1) readonly buffer Input0 { + float data[]; +} raw_scores; + +uint num_classes = uint($0); +int apply_sigmoid = int($1); +int apply_clipping_thresh = int($2); +float clipping_thresh = float($3); +int ignore_class_0 = int($4); + +float optional_sigmoid(float x) { + if (apply_sigmoid == int(0)) return x; + if (apply_clipping_thresh == int(1)) { + x = clamp(x, -clipping_thresh, clipping_thresh); + } + x = 1.0 / (1.0 + exp(-x)); + return x; +} + +void main() { + uint g_idx = gl_GlobalInvocationID.x; // box idx + uint s_idx = gl_LocalInvocationID.y; // score/class idx + + // load all scores into shared memory + float score = raw_scores.data[g_idx * num_classes + s_idx]; + local_scores[s_idx] = optional_sigmoid(score); + memoryBarrierShared(); + barrier(); + + // find max score in shared memory + if (s_idx == uint(0)) { + float max_score = -FLT_MAX; + float max_class = -1.0; + for (int i=ignore_class_0; i max_score) { + max_score = local_scores[i]; + max_class = float(i); + } + } + scored_boxes.data[g_idx * uint(2) + uint(0)] = max_score; + scored_boxes.data[g_idx * uint(2) 
+ uint(1)] = max_class; + } +})", + num_classes_, options_.sigmoid_score() ? 1 : 0, + options_.has_score_clipping_thresh() ? 1 : 0, + options_.has_score_clipping_thresh() ? options_.score_clipping_thresh() + : 0, + ignore_classes_.size() ? 1 : 0); + + // # filter classes supported is hardware dependent. + int max_wg_size; // typically <= 1024 + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, &max_wg_size); // y-dim + CHECK_LT(num_classes_, max_wg_size) << "# classes must be < " << max_wg_size; + // TODO support better filtering. + CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; + + // Shader program + GlShader score_shader; + status = GlShader::CompileShader(GL_COMPUTE_SHADER, score_src, &score_shader); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + score_program_ = absl::make_unique(); + status = GlProgram::CreateWithShader(score_shader, score_program_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + // Outputs + size_t scored_boxes_length = num_boxes_ * 2; // score, class + scored_boxes_buffer_ = absl::make_unique(); + status = CreateReadWriteShaderStorageBuffer( + scored_boxes_length, scored_boxes_buffer_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + // Inputs + size_t raw_scores_length = num_boxes_ * num_classes_; + raw_scores_buffer_ = absl::make_unique(); + status = CreateReadWriteShaderStorageBuffer(raw_scores_length, + raw_scores_buffer_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + +#endif // defined(__ANDROID__) + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto new file mode 100644 index 000000000..ca4688086 --- /dev/null +++ 
b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto @@ -0,0 +1,71 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the TfLiteTensorsToDetectionsCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TfLiteTensorsToDetectionsCalculatorOptions { + extend .mediapipe.CalculatorOptions { + optional TfLiteTensorsToDetectionsCalculatorOptions ext = 246514968; + } + + // The number of output classes predicted by the detection model. + required int32 num_classes = 1; + // The number of output boxes predicted by the detection model. + required int32 num_boxes = 2; + // The number of output values per boxes predicted by the detection model. The + // values contain bounding boxes, keypoints, etc. + required int32 num_coords = 3; + + // The offset of keypoint coordinates in the location tensor. + optional int32 keypoint_coord_offset = 9; + // The number of predicted keypoints. + optional int32 num_keypoints = 10 [default = 0]; + // The dimension of each keypoint, e.g. number of values predicted for each + // keypoint. + optional int32 num_values_per_keypoint = 11 [default = 2]; + // The offset of box coordinates in the location tensor. + optional int32 box_coord_offset = 12 [default = 0]; + + // Parameters for decoding SSD detection model. 
+ optional float x_scale = 4 [default = 0.0]; + optional float y_scale = 5 [default = 0.0]; + optional float w_scale = 6 [default = 0.0]; + optional float h_scale = 7 [default = 0.0]; + + optional bool apply_exponential_on_box_size = 13 [default = false]; + + // Whether to reverse the order of predicted x, y from output. + // If false, the order is [y_center, x_center, h, w], if true the order is + // [x_center, y_center, w, h]. + optional bool reverse_output_order = 14 [default = false]; + // The ids of classes that should be ignored during decoding the score for + // each predicted box. + repeated int32 ignore_classes = 8; + + optional bool sigmoid_score = 15 [default = false]; + optional float score_clipping_thresh = 16; + + // Whether the detection coordinates from the input tensors should be flipped + // vertically (along the y-direction). This is useful, for example, when the + // input tensors represent detections defined with a coordinate system where + // the origin is at the top-left corner, whereas the desired detection + // representation has a bottom-left origin (e.g., in OpenGL). + optional bool flip_vertically = 18 [default = false]; +} diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc new file mode 100644 index 000000000..73258c2a7 --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc @@ -0,0 +1,589 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/resource_util.h" +#include "tensorflow/lite/interpreter.h" + +#if defined(__ANDROID__) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#endif // __ANDROID__ + +namespace { +constexpr int kWorkgroupSize = 8; // Block size for GPU shader. +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; +// Commonly used to compute the number of blocks to launch in a kernel. +int RoundUp(const int size, const int multiple) { + return (size + multiple - 1) / multiple; +} +} // namespace + +namespace mediapipe { + +#if defined(__ANDROID__) +using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; +using ::tflite::gpu::gl::GlBuffer; +using ::tflite::gpu::gl::GlProgram; +using ::tflite::gpu::gl::GlShader; +#endif // __ANDROID__ + +// Converts TFLite tensors from a tflite segmentation model to an image mask. +// +// Performs optional upscale to REFERENCE_IMAGE dimensions if provided, +// otherwise the mask is the same size as input tensor. +// +// Note: This calculator is currently GPU only, so only *_GPU tags can be used. 
+// +// Inputs: +// One of the following TENSORS tags: +// TENSORS: Vector of TfLiteTensor of type kTfLiteFloat32. +// The tensor dimensions are specified in this calculator's options. +// TENSORS_GPU: Vector of GlBuffer. +// One of the following REFERENCE_IMAGE tags: +// REFERENCE_IMAGE (optional): An ImageFrame input image, +// used only for output dimensions. +// REFERENCE_IMAGE_GPU (optional): A GpuBuffer input image, +// used only for output dimensions. +// One of the following PREV_MASK tags: +// PREV_MASK (optional): An ImageFrame input mask, Gray, RGB or RGBA. +// PREV_MASK_GPU (optional): A GpuBuffer input mask, RGBA. +// Output: +// One of the following MASK tags: +// MASK: An ImageFrame output mask, Gray, RGB or RGBA. +// MASK_GPU: A GpuBuffer output mask, RGBA. +// +// Options: +// See tflite_segmentation_calculator.proto +// +// Usage example: +// node { +// calculator: "TfLiteTensorsToSegmentationCalculator" +// input_stream: "TENSORS_GPU:tensors" +// input_stream: "IMAGE_GPU:input_video" +// output_stream: "MASK_GPU:hair_mask" +// node_options: { +// [mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] { +// tensor_in_width: 512 +// tensor_in_height: 512 +// tensor_in_channels: 2 +// combine_with_previous_ratio: 1.0 +// output_layer_index: 1 +// } +// } +// } +// +class TfLiteTensorsToSegmentationCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status LoadOptions(CalculatorContext* cc); + ::mediapipe::Status InitGpu(CalculatorContext* cc); + ::mediapipe::Status ProcessGpu(CalculatorContext* cc); + ::mediapipe::Status ProcessCpu(CalculatorContext* cc); + void GlRender(); + + ::mediapipe::TfLiteTensorsToSegmentationCalculatorOptions options_; + + int tensor_width_ = 0; + int 
tensor_height_ = 0; + int tensor_channels_ = 0; + + bool use_gpu_ = false; +#if defined(__ANDROID__) + mediapipe::GlCalculatorHelper gpu_helper_; + std::unique_ptr mask_program_with_prev_; + std::unique_ptr mask_program_no_prev_; + std::unique_ptr tensor_buffer_; + GLuint upsample_program_; +#endif // __ANDROID__ +}; +REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); + +// static +::mediapipe::Status TfLiteTensorsToSegmentationCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(!cc->Inputs().GetTags().empty()); + RET_CHECK(!cc->Outputs().GetTags().empty()); + + // Inputs CPU. + if (cc->Inputs().HasTag("TENSORS")) { + cc->Inputs().Tag("TENSORS").Set>(); + } + if (cc->Inputs().HasTag("PREV_MASK")) { + cc->Inputs().Tag("PREV_MASK").Set(); + } + if (cc->Inputs().HasTag("REFERENCE_IMAGE")) { + cc->Inputs().Tag("REFERENCE_IMAGE").Set(); + } + + // Inputs GPU. +#if defined(__ANDROID__) + if (cc->Inputs().HasTag("TENSORS_GPU")) { + cc->Inputs().Tag("TENSORS_GPU").Set>(); + } + if (cc->Inputs().HasTag("PREV_MASK_GPU")) { + cc->Inputs().Tag("PREV_MASK_GPU").Set(); + } + if (cc->Inputs().HasTag("REFERENCE_IMAGE_GPU")) { + cc->Inputs().Tag("REFERENCE_IMAGE_GPU").Set(); + } +#endif // __ANDROID__ + + // Outputs. 
+ if (cc->Outputs().HasTag("MASK")) { + cc->Outputs().Tag("MASK").Set(); + } +#if defined(__ANDROID__) + if (cc->Outputs().HasTag("MASK_GPU")) { + cc->Outputs().Tag("MASK_GPU").Set(); + } +#endif // __ANDROID__ + +#if defined(__ANDROID__) + RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Open( + CalculatorContext* cc) { + if (cc->Inputs().HasTag("TENSORS_GPU")) { + use_gpu_ = true; +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif // __ANDROID__ + } + + RETURN_IF_ERROR(LoadOptions(cc)); + + if (use_gpu_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { + RETURN_IF_ERROR(InitGpu(cc)); + return ::mediapipe::OkStatus(); + })); +#else + RET_CHECK_FAIL() + << "GPU processing on non-Android devices is not supported yet."; +#endif // __ANDROID__ + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process( + CalculatorContext* cc) { + if (use_gpu_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { + RETURN_IF_ERROR(ProcessGpu(cc)); + return ::mediapipe::OkStatus(); + })); +#endif // __ANDROID__ + } else { + RETURN_IF_ERROR(ProcessCpu(cc)); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close( + CalculatorContext* cc) { +#if defined(__ANDROID__) + gpu_helper_.RunInGlContext([this] { + if (upsample_program_) glDeleteProgram(upsample_program_); + upsample_program_ = 0; + mask_program_with_prev_.reset(); + mask_program_no_prev_.reset(); + tensor_buffer_.reset(); + }); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu( + CalculatorContext* cc) { + return ::mediapipe::UnimplementedError("CPU 
support is not implemented yet."); +} + +// Steps: +// 1. receive tensor and optional previous mask +// 2. process segmentation tensor into small mask +// 3. upsample small mask into output mask to be same size as input image +::mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu( + CalculatorContext* cc) { + if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) { + return ::mediapipe::OkStatus(); + } +#if defined(__ANDROID__) + // Get input streams. + const auto& input_tensors = + cc->Inputs().Tag("TENSORS_GPU").Get>(); + const bool has_prev_mask = cc->Inputs().HasTag("PREV_MASK_GPU") && + !cc->Inputs().Tag("PREV_MASK_GPU").IsEmpty(); + const auto& input_mask = + has_prev_mask + ? cc->Inputs().Tag("PREV_MASK_GPU").Get() + : mediapipe::GpuBuffer(); + int output_width = tensor_width_, output_height = tensor_height_; + if (cc->Inputs().HasTag("REFERENCE_IMAGE_GPU")) { + const auto& input_image = + cc->Inputs().Tag("REFERENCE_IMAGE_GPU").Get(); + output_width = input_image.width(); + output_height = input_image.height(); + } + + RET_CHECK_EQ(input_tensors.size(), 1); + + // Create initial output mask texture. + ::tflite::gpu::gl::GlTexture small_mask_texture; + ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture( + tflite::gpu::DataType::UINT8, // GL_RGBA8 + {tensor_width_, tensor_height_}, &small_mask_texture); + + // Get input previous mask. + auto input_mask_texture = has_prev_mask + ? gpu_helper_.CreateSourceTexture(input_mask) + : mediapipe::GlTexture(); + + // Copy input tensor. + tflite::gpu::gl::CopyBuffer(input_tensors[0], *tensor_buffer_); + + // Run shader, process mask tensor. 
+ { + const int output_index = 0; + glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0, + GL_WRITE_ONLY, GL_RGBA8); + tensor_buffer_->BindToIndex(2); + + const tflite::gpu::uint3 workgroups = { + RoundUp(tensor_width_, kWorkgroupSize), + RoundUp(tensor_height_, kWorkgroupSize), 1}; + + if (!has_prev_mask) { + mask_program_no_prev_->Dispatch(workgroups); + } else { + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, input_mask_texture.name()); + mask_program_with_prev_->Dispatch(workgroups); + } + } + + // Upsample small mask into output. + mediapipe::GlTexture output_texture = gpu_helper_.CreateDestinationTexture( + output_width, output_height, + mediapipe::GpuBufferFormat::kBGRA32); // actually GL_RGBA8 + + // Run shader, upsample result. + { + gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, small_mask_texture.id()); + GlRender(); + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + } + + // Send out image as GPU packet. 
+ auto output_image = output_texture.GetFrame(); + cc->Outputs() + .Tag("MASK_GPU") + .Add(output_image.release(), cc->InputTimestamp()); + + // Cleanup + input_mask_texture.Release(); + output_texture.Release(); + +#endif // __ANDROID__ + return ::mediapipe::OkStatus(); +} + +void TfLiteTensorsToSegmentationCalculator::GlRender() { +#if defined(__ANDROID__) + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // program + glUseProgram(upsample_program_); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); +#endif // __ANDROID__ +} + +::mediapipe::Status TfLiteTensorsToSegmentationCalculator::LoadOptions( + CalculatorContext* cc) { + // Get calculator options specified in the graph. 
+ options_ = + cc->Options<::mediapipe::TfLiteTensorsToSegmentationCalculatorOptions>(); + + if (!options_.has_tensor_width() || !options_.has_tensor_height() || + !options_.has_tensor_channels()) + RET_CHECK_FAIL() << "Missing tensor dimensions in options."; + + tensor_width_ = options_.tensor_width(); + tensor_height_ = options_.tensor_height(); + tensor_channels_ = options_.tensor_channels(); + RET_CHECK_EQ(tensor_channels_, 2) + << "Only 2 channel segmentation tensor currently supported"; + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu( + CalculatorContext* cc) { +#if defined(__ANDROID__) + + // A shader to process a segmentation tensor into an output mask, + // and use an optional previous mask as input. + // Currently uses 4 channels for output, + // and sets both R and A channels as mask value. + const std::string shader_src_template = + R"( +#version 310 es + +layout(local_size_x = $0, local_size_y = $0, local_size_z = 1) in; + +precision highp float; + +layout(std430, binding = 2) readonly buffer B0 { + vec2 elements[]; +} input_data; // data tensor +layout(binding = 1) uniform sampler2D input_texture; // previous mask +layout(rgba8, binding = 0) writeonly uniform highp image2D output_texture; + +uniform ivec2 out_size; + +const int output_layer_index = int($1); +const float combine_with_previous_ratio = float($2); + +// Will be replaced with either '#define READ_PREVIOUS' or empty std::string +$3 //DEFINE_READ_PREVIOUS + +void main() { + int out_width = out_size.x; + int out_height = out_size.y; + + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= out_width || gid.y >= out_height) { return; } + + int linear_index = gid.y * out_width + gid.x; + vec2 input_value = input_data.elements[linear_index]; + + // Only two channel output is supported. 
+ vec2 input_px = input_value.rg; + float shift = max(input_px.r, input_px.g); + float softmax_denom = exp(input_px.r - shift) + exp(input_px.g - shift); + float new_mask_value = + exp(input_px[output_layer_index] - shift) / softmax_denom; + + // Combine previous value with current using uncertainty^2 as mixing parameter +#ifdef READ_PREVIOUS + vec2 normalized_gid = vec2(gid) / vec2(out_width - 1, out_height - 1); + float prev_mask_value = texture(input_texture, normalized_gid).r; + + float eps = 0.001; + float uncertainty_alpha = + 1.0 + (new_mask_value * log(new_mask_value + eps) + + (1.0 - new_mask_value) * log(1.0 - new_mask_value + eps)) / + log(2.0f); + uncertainty_alpha = clamp(uncertainty_alpha, 0.0, 1.0); + // equivalent to a = 1 - (1 - a) * (1 - a); (squaring the uncertainty) + uncertainty_alpha *= 2.0 - uncertainty_alpha; + + float mixed_mask_value = new_mask_value * uncertainty_alpha + + prev_mask_value * (1.0f - uncertainty_alpha); + + // Use user provided value to mix raw value & a value mixed with previous mask + new_mask_value = mixed_mask_value * combine_with_previous_ratio + + (1.0f - combine_with_previous_ratio) * new_mask_value; +#endif // READ_PREVIOUS + + // Texture coordinates are inverted on y axis. + ivec2 output_coordinate = ivec2(gid.x, out_height - gid.y - 1); + // Set both R and A channels for convenience. + vec4 out_value = vec4(new_mask_value, 0.0, 0.0, new_mask_value); + imageStore(output_texture, output_coordinate, out_value); +})"; + + const std::string shader_src_no_previous = absl::Substitute( + shader_src_template, kWorkgroupSize, options_.output_layer_index(), + options_.combine_with_previous_ratio(), ""); + const std::string shader_src_with_previous = absl::Substitute( + shader_src_template, kWorkgroupSize, options_.output_layer_index(), + options_.combine_with_previous_ratio(), "#define READ_PREVIOUS"); + + auto status = ::tflite::gpu::OkStatus(); + + // Shader programs. 
+ GlShader shader_without_previous; + status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src_no_previous, + &shader_without_previous); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + mask_program_no_prev_ = absl::make_unique(); + status = GlProgram::CreateWithShader(shader_without_previous, + mask_program_no_prev_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + GlShader shader_with_previous; + status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src_with_previous, + &shader_with_previous); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + mask_program_with_prev_ = absl::make_unique(); + status = GlProgram::CreateWithShader(shader_with_previous, + mask_program_with_prev_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + + // Buffer storage for input tensor. + size_t tensor_length = tensor_width_ * tensor_height_ * tensor_channels_; + tensor_buffer_ = absl::make_unique(); + status = CreateReadWriteShaderStorageBuffer(tensor_length, + tensor_buffer_.get()); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + + // Parameters. + glUseProgram(mask_program_with_prev_->id()); + glUniform2i(glGetUniformLocation(mask_program_with_prev_->id(), "out_size"), + tensor_width_, tensor_height_); + glUniform1i( + glGetUniformLocation(mask_program_with_prev_->id(), "input_texture"), 1); + glUseProgram(mask_program_no_prev_->id()); + glUniform2i(glGetUniformLocation(mask_program_no_prev_->id(), "out_size"), + tensor_width_, tensor_height_); + glUniform1i( + glGetUniformLocation(mask_program_no_prev_->id(), "input_texture"), 1); + + // Vertex shader attributes. 
+ const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + // Simple pass-through shader, used for hardware upsampling. + std::string upsample_shader_base = R"( + #if __VERSION__ < 130 + #define in varying + #endif // __VERSION__ < 130 + + #ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; + #else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; + #endif // defined(GL_ES) + + in vec2 sample_coordinate; + uniform sampler2D input_data; + + void main() { + vec4 pix = texture2D(input_data, sample_coordinate); + fragColor = pix; + } +)"; + + // Program + mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, + upsample_shader_base.c_str(), NUM_ATTRIBUTES, + &attr_name[0], attr_location, &upsample_program_); + RET_CHECK(upsample_program_) << "Problem initializing the program."; + + // Parameters + glUseProgram(upsample_program_); + glUniform1i(glGetUniformLocation(upsample_program_, "input_data"), 1); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto new file mode 100644 index 000000000..9694d2c5f --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto @@ -0,0 +1,37 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+// Options for TfLiteTensorsToSegmentationCalculator, which turns a TFLite
+// segmentation tensor into an output mask.
+message TfLiteTensorsToSegmentationCalculatorOptions {
+  extend mediapipe.CalculatorOptions {
+    optional TfLiteTensorsToSegmentationCalculatorOptions ext = 252526026;
+  }
+
+  // Dimensions of input segmentation tensor to process.
+  // All three must be set; the calculator RET_CHECK-fails in Open() if any
+  // is missing.
+  required int32 tensor_width = 1;
+  required int32 tensor_height = 2;
+  // Only a value of 2 is currently accepted (checked in Open()).
+  required int32 tensor_channels = 3;
+
+  // How much to use previous mask when computing current one; range [0-1].
+  // This is a tradeoff between responsiveness (0.0) and accuracy (1.0).
+  optional float combine_with_previous_ratio = 4 [default = 1.0];
+
+  // Model specific: Channel to use for processing tensor.
+  // Used in the shader to index the 2-channel tensor value.
+  optional int32 output_layer_index = 5 [default = 1];
+}
diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD
new file mode 100644
index 000000000..b9f0875c8
--- /dev/null
+++ b/mediapipe/calculators/util/BUILD
@@ -0,0 +1,387 @@
+# Copyright 2019 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +exports_files(["LICENSE"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +proto_library( + name = "annotation_overlay_calculator_proto", + srcs = ["annotation_overlay_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/util:color_proto", + ], +) + +proto_library( + name = "detection_label_id_to_text_calculator_proto", + srcs = ["detection_label_id_to_text_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "latency_proto", + srcs = ["latency.proto"], +) + +proto_library( + name = "non_max_suppression_calculator_proto", + srcs = ["non_max_suppression_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "packet_frequency_proto", + srcs = ["packet_frequency.proto"], +) + +proto_library( + name = "packet_frequency_calculator_proto", + srcs = ["packet_frequency_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "packet_latency_calculator_proto", + srcs = ["packet_latency_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "annotation_overlay_calculator_cc_proto", + srcs = ["annotation_overlay_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/util:color_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":annotation_overlay_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "detection_label_id_to_text_calculator_cc_proto", + srcs = ["detection_label_id_to_text_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = 
["//mediapipe:__subpackages__"], + deps = [ + ":detection_label_id_to_text_calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "latency_cc_proto", + srcs = ["latency.proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":latency_proto"], +) + +mediapipe_cc_proto_library( + name = "non_max_suppression_calculator_cc_proto", + srcs = ["non_max_suppression_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":non_max_suppression_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "packet_frequency_cc_proto", + srcs = ["packet_frequency.proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":packet_frequency_proto"], +) + +mediapipe_cc_proto_library( + name = "packet_frequency_calculator_cc_proto", + srcs = ["packet_frequency_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [ + ":packet_frequency_calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "packet_latency_calculator_cc_proto", + srcs = ["packet_latency_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [ + ":packet_latency_calculator_proto", + ], +) + +cc_library( + name = "packet_frequency_calculator", + srcs = ["packet_frequency_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", + "//mediapipe/calculators/util:packet_frequency_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/time", + ], + alwayslink = 1, +) + +cc_test( + name = "packet_frequency_calculator_test", + size = "small", + srcs = 
["packet_frequency_calculator_test.cc"], + deps = [ + ":packet_frequency_calculator", + "//mediapipe/calculators/util:packet_frequency_cc_proto", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + ], +) + +cc_library( + name = "packet_latency_calculator", + srcs = ["packet_latency_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/calculators/util:latency_cc_proto", + "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], + alwayslink = 1, +) + +cc_test( + name = "packet_latency_calculator_test", + size = "small", + srcs = ["packet_latency_calculator_test.cc"], + deps = [ + ":packet_latency_calculator", + "//mediapipe/calculators/util:latency_cc_proto", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:simulation_clock", + "//mediapipe/framework/tool:simulation_clock_executor", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "annotation_overlay_calculator", + srcs = ["annotation_overlay_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + 
":annotation_overlay_calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + "//mediapipe/util:annotation_renderer", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_library( + name = "detection_label_id_to_text_calculator", + srcs = ["detection_label_id_to_text_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":detection_label_id_to_text_calculator_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/util:resource_util", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/util/android/file/base", + ], + "//conditions:default": [ + "//mediapipe/framework/port:file_helpers", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "non_max_suppression_calculator", + srcs = ["non_max_suppression_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":non_max_suppression_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:rectangle", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +proto_library( + name = 
"detections_to_render_data_calculator_proto", + srcs = ["detections_to_render_data_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/util:color_proto", + "//mediapipe/util:render_data_proto", + ], +) + +mediapipe_cc_proto_library( + name = "detections_to_render_data_calculator_cc_proto", + srcs = ["detections_to_render_data_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":detections_to_render_data_calculator_proto"], +) + +cc_library( + name = "detections_to_render_data_calculator", + srcs = ["detections_to_render_data_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":detections_to_render_data_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "detections_to_render_data_calculator_test", + size = "small", + srcs = ["detections_to_render_data_calculator_test.cc"], + deps = [ + ":detections_to_render_data_calculator", + ":detections_to_render_data_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/port:gtest_main", + 
"//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "detection_letterbox_removal_calculator", + srcs = ["detection_letterbox_removal_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "detection_letterbox_removal_calculator_test", + srcs = ["detection_letterbox_removal_calculator_test.cc"], + deps = [ + ":detection_letterbox_removal_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:validate_type", + ], +) diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc new file mode 100644 index 000000000..25c8c65fe --- /dev/null +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -0,0 +1,589 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/util/annotation_overlay_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/vector.h" +#include "mediapipe/util/annotation_renderer.h" +#include "mediapipe/util/color.pb.h" + +#if defined(__ANDROID__) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" +#endif // __ANDROID__ + +namespace mediapipe { + +namespace { + +constexpr char kInputFrameTag[] = "INPUT_FRAME"; +constexpr char kOutputFrameTag[] = "OUTPUT_FRAME"; + +constexpr char kInputFrameTagGpu[] = "INPUT_FRAME_GPU"; +constexpr char kOutputFrameTagGpu[] = "OUTPUT_FRAME_GPU"; + +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; +} // namespace + +// A calculator for rendering data on images. +// +// Inputs: +// 1. INPUT_FRAME or INPUT_FRAME_GPU (optional): An ImageFrame (or GpuBuffer) +// containing the input image. +// If output is CPU, and input isn't provided, the renderer creates a +// blank canvas with the width, height and color provided in the options. +// 2. RenderData proto on variable number of input streams. All the RenderData +// at a particular timestamp is drawn on the image in the order of their +// input streams. No tags required. +// +// Output: +// 1. OUTPUT_FRAME or OUTPUT_FRAME_GPU: A rendered ImageFrame (or GpuBuffer). +// +// For CPU input frames, only SRGBA, SRGB and GRAY8 format are supported. 
The +// output format is the same as input except for GRAY8 where the output is in +// SRGB to support annotations in color. +// +// For GPU input frames, only 4-channel images are supported. +// +// Note: When using GPU, drawing with black color is not supported. +// +// Example config (CPU): +// node { +// calculator: "AnnotationOverlayCalculator" +// input_stream: "INPUT_FRAME:image_frames" +// input_stream: "render_data_1" +// input_stream: "render_data_2" +// input_stream: "render_data_3" +// output_stream: "OUTPUT_FRAME:decorated_frames" +// options { +// [mediapipe.AnnotationOverlayCalculatorOptions.ext] { +// } +// } +// } +// +// Example config (GPU): +// node { +// calculator: "AnnotationOverlayCalculator" +// input_stream: "INPUT_FRAME_GPU:image_frames" +// input_stream: "render_data_1" +// input_stream: "render_data_2" +// input_stream: "render_data_3" +// output_stream: "OUTPUT_FRAME_GPU:decorated_frames" +// options { +// [mediapipe.AnnotationOverlayCalculatorOptions.ext] { +// } +// } +// } +// +class AnnotationOverlayCalculator : public CalculatorBase { + public: + AnnotationOverlayCalculator() = default; + ~AnnotationOverlayCalculator() override = default; + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + // From Calculator. 
+  ::mediapipe::Status Open(CalculatorContext* cc) override;
+  ::mediapipe::Status Process(CalculatorContext* cc) override;
+  ::mediapipe::Status Close(CalculatorContext* cc) override;
+
+ private:
+  // Allocates (or copies from the CPU input frame) the cv::Mat render canvas
+  // and reports the ImageFormat the output frame should use.
+  // NOTE(review): a template argument (likely std::unique_ptr<cv::Mat>) was
+  // stripped from this dump — restore before compiling.
+  ::mediapipe::Status CreateRenderTargetCpu(CalculatorContext* cc,
+                                            std::unique_ptr& image_mat,
+                                            ImageFormat::Format* target_format);
+  // Same as above, but sized from the GPU input frame (Android only).
+  ::mediapipe::Status CreateRenderTargetGpu(
+      CalculatorContext* cc, std::unique_ptr& image_mat);
+  // Uploads the drawn canvas and blends it over the GPU input frame.
+  ::mediapipe::Status RenderToGpu(CalculatorContext* cc, uchar* overlay_image);
+  // Copies the drawn canvas into an output ImageFrame.
+  ::mediapipe::Status RenderToCpu(CalculatorContext* cc,
+                                  const ImageFormat::Format& target_format,
+                                  uchar* data_image);
+
+  // Issues the fullscreen-quad draw call used for GPU blending.
+  ::mediapipe::Status GlRender(CalculatorContext* cc);
+  // One-time compilation of the blend shader and uniform setup (GL thread).
+  ::mediapipe::Status GlSetup(CalculatorContext* cc);
+
+  // Options for the calculator.
+  AnnotationOverlayCalculatorOptions options_;
+
+  // Underlying helper renderer library.
+  // NOTE(review): template argument (likely <AnnotationRenderer>) stripped in
+  // this dump.
+  std::unique_ptr renderer_;
+
+  // Number of input streams with render data.
+  int num_render_streams_;
+
+  // Indicates if image frame is available as input.
+  bool image_frame_available_ = false;
+
+  // True when INPUT_FRAME_GPU/OUTPUT_FRAME_GPU are used (Android only).
+  bool use_gpu_ = false;
+  // Set on first Process() call after GlSetup() succeeds.
+  bool gpu_initialized_ = false;
+#if defined(__ANDROID__)
+  mediapipe::GlCalculatorHelper gpu_helper_;
+  GLuint program_ = 0;
+  GLuint image_mat_tex_ = 0;  // Overlay drawing image for GPU.
+#endif  // __ANDROID__
+};
+REGISTER_CALCULATOR(AnnotationOverlayCalculator);
+
+::mediapipe::Status AnnotationOverlayCalculator::GetContract(
+    CalculatorContract* cc) {
+  CHECK_GE(cc->Inputs().NumEntries(), 1);
+
+  // CPU and GPU image inputs are mutually exclusive.
+  if (cc->Inputs().HasTag(kInputFrameTag) &&
+      cc->Inputs().HasTag(kInputFrameTagGpu)) {
+    return ::mediapipe::InternalError("Cannot have multiple input images.");
+  }
+  // GPU input and GPU output must be used together.
+  if (cc->Inputs().HasTag(kInputFrameTagGpu) !=
+      cc->Outputs().HasTag(kOutputFrameTagGpu)) {
+    return ::mediapipe::InternalError("GPU output must have GPU input.");
+  }
+
+  // Assume all inputs are render streams; adjust below.
+  int num_render_streams = cc->Inputs().NumEntries();
+
+  // Input image to render onto a copy of.
+#if defined(__ANDROID__) + if (cc->Inputs().HasTag(kInputFrameTagGpu)) { + cc->Inputs().Tag(kInputFrameTagGpu).Set(); + num_render_streams = cc->Inputs().NumEntries() - 1; + } +#endif // __ANDROID__ + if (cc->Inputs().HasTag(kInputFrameTag)) { + cc->Inputs().Tag(kInputFrameTag).Set(); + num_render_streams = cc->Inputs().NumEntries() - 1; + } + + // Data streams to render. + for (int i = 0; i < num_render_streams; ++i) { + cc->Inputs().Index(i).Set(); + } + + // Rendered image. +#if defined(__ANDROID__) + if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { + cc->Outputs().Tag(kOutputFrameTagGpu).Set(); + } +#endif // __ANDROID__ + if (cc->Outputs().HasTag(kOutputFrameTag)) { + cc->Outputs().Tag(kOutputFrameTag).Set(); + } + +#if defined(__ANDROID__) + RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + + if (cc->Inputs().HasTag(kInputFrameTagGpu) && + cc->Outputs().HasTag(kOutputFrameTagGpu)) { +#if defined(__ANDROID__) + use_gpu_ = true; +#else + RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet."; +#endif // __ANDROID__ + } + + if (cc->Inputs().HasTag(kInputFrameTagGpu) || + cc->Inputs().HasTag(kInputFrameTag)) { + image_frame_available_ = true; + num_render_streams_ = cc->Inputs().NumEntries() - 1; + } else { + image_frame_available_ = false; + RET_CHECK(options_.has_canvas_width_px()); + RET_CHECK(options_.has_canvas_height_px()); + num_render_streams_ = cc->Inputs().NumEntries(); + } + + // Initialize the helper renderer library. + renderer_ = absl::make_unique(); + renderer_->SetFlipTextVertically(options_.flip_text_vertically()); + + // Set the output header based on the input header (if present). + const char* input_tag = use_gpu_ ? kInputFrameTagGpu : kInputFrameTag; + const char* output_tag = use_gpu_ ? 
kOutputFrameTagGpu : kOutputFrameTag; + if (image_frame_available_ && + !cc->Inputs().Tag(input_tag).Header().IsEmpty()) { + const auto& input_header = + cc->Inputs().Tag(input_tag).Header().Get(); + auto* output_video_header = new VideoHeader(input_header); + cc->Outputs().Tag(output_tag).SetHeader(Adopt(output_video_header)); + } + + if (use_gpu_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AnnotationOverlayCalculator::Process( + CalculatorContext* cc) { + // Initialize render target, drawn with OpenCV. + std::unique_ptr image_mat; + ImageFormat::Format target_format; + if (use_gpu_) { + RETURN_IF_ERROR(CreateRenderTargetGpu(cc, image_mat)); + } else { + RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); + } + + // Reset the renderer with the image_mat. No copy here. + renderer_->AdoptImage(image_mat.get()); + + // Render streams onto render target. + for (int i = 0; i < num_render_streams_; ++i) { + if (cc->Inputs().Index(i).IsEmpty()) { + continue; + } + const RenderData& render_data = cc->Inputs().Index(i).Get(); + renderer_->RenderDataOnImage(render_data); + } + + if (use_gpu_) { +#if defined(__ANDROID__) + // Overlay rendered image in OpenGL, onto a copy of input. + uchar* image_mat_ptr = image_mat->data; + RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, cc, image_mat_ptr]() -> ::mediapipe::Status { + if (!gpu_initialized_) { + RETURN_IF_ERROR(GlSetup(cc)); + gpu_initialized_ = true; + } + + RETURN_IF_ERROR(RenderToGpu(cc, image_mat_ptr)); + + return ::mediapipe::OkStatus(); + })); +#endif // __ANDROID__ + } else { + // Copy the rendered image to output. 
+    uchar* image_mat_ptr = image_mat->data;
+    RETURN_IF_ERROR(RenderToCpu(cc, target_format, image_mat_ptr));
+  }
+
+  return ::mediapipe::OkStatus();
+}
+
+// Releases GL resources (blend program and overlay texture) inside the GL
+// context; no-op off Android.
+::mediapipe::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) {
+#if defined(__ANDROID__)
+  gpu_helper_.RunInGlContext([this] {
+    if (program_) glDeleteProgram(program_);
+    program_ = 0;
+    if (image_mat_tex_) glDeleteTextures(1, &image_mat_tex_);
+    image_mat_tex_ = 0;
+  });
+#endif  // __ANDROID__
+
+  return ::mediapipe::OkStatus();
+}
+
+// Copies the rendered canvas (data_image) into a freshly allocated ImageFrame
+// of the given format and emits it on OUTPUT_FRAME at the input timestamp.
+::mediapipe::Status AnnotationOverlayCalculator::RenderToCpu(
+
+    CalculatorContext* cc, const ImageFormat::Format& target_format,
+    uchar* data_image) {
+  // NOTE(review): a template argument (likely absl::make_unique<ImageFrame>)
+  // was stripped from this dump — restore before compiling.
+  auto output_frame = absl::make_unique(
+      target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight());
+
+#if defined(__ANDROID__)
+  // GL-default row alignment on Android (presumably so the buffer can be fed
+  // to GL without repacking — TODO confirm).
+  output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(),
+                              renderer_->GetImageHeight(), data_image,
+                              ImageFrame::kGlDefaultAlignmentBoundary);
+#else
+  output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(),
+                              renderer_->GetImageHeight(), data_image,
+                              ImageFrame::kDefaultAlignmentBoundary);
+#endif  // __ANDROID__
+
+  cc->Outputs()
+      .Tag(kOutputFrameTag)
+      .Add(output_frame.release(), cc->InputTimestamp());
+
+  return ::mediapipe::OkStatus();
+}
+
+// Uploads the CPU-drawn overlay, blends it over the GPU input frame in a
+// shader, and emits the result on OUTPUT_FRAME_GPU; no-op off Android.
+::mediapipe::Status AnnotationOverlayCalculator::RenderToGpu(
+    CalculatorContext* cc, uchar* overlay_image) {
+#if defined(__ANDROID__)
+  // Source and destination textures.
+  // NOTE(review): a template argument (likely Get<mediapipe::GpuBuffer>) was
+  // stripped from this dump — restore before compiling.
+  const auto& input_frame =
+      cc->Inputs().Tag(kInputFrameTagGpu).Get();
+  auto input_texture = gpu_helper_.CreateSourceTexture(input_frame);
+
+  const int width = input_frame.width(), height = input_frame.height();
+  auto output_texture = gpu_helper_.CreateDestinationTexture(
+      width, height, mediapipe::GpuBufferFormat::kBGRA32);
+
+  // Upload render target to GPU.
+ { + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); + + glBindTexture(GL_TEXTURE_2D, image_mat_tex_); + glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width, height, GL_RGB, + GL_UNSIGNED_BYTE, overlay_image); + glBindTexture(GL_TEXTURE_2D, 0); + } + + // Blend overlay image in GPU shader. + { + gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 + + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, input_texture.name()); + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, image_mat_tex_); + + RETURN_IF_ERROR(GlRender(cc)); + + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + } + + // Send out blended image as GPU packet. + auto output_frame = output_texture.GetFrame(); + cc->Outputs() + .Tag(kOutputFrameTagGpu) + .Add(output_frame.release(), cc->InputTimestamp()); + + // Cleanup + input_texture.Release(); + output_texture.Release(); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( + CalculatorContext* cc, std::unique_ptr& image_mat, + ImageFormat::Format* target_format) { + if (image_frame_available_) { + const auto& input_frame = + cc->Inputs().Tag(kInputFrameTag).Get(); + + int target_mat_type; + switch (input_frame.Format()) { + case ImageFormat::SRGBA: + *target_format = ImageFormat::SRGBA; + target_mat_type = CV_8UC4; + break; + case ImageFormat::SRGB: + *target_format = ImageFormat::SRGB; + target_mat_type = CV_8UC3; + break; + case ImageFormat::GRAY8: + *target_format = ImageFormat::SRGB; + target_mat_type = CV_8UC3; + break; + default: + return ::mediapipe::UnknownError("Unexpected image frame format."); + break; + } + + image_mat = absl::make_unique( + input_frame.Height(), input_frame.Width(), target_mat_type); + if (input_frame.Format() == ImageFormat::GRAY8) { + const int target_num_channels = + ImageFrame::NumberOfChannelsForFormat(*target_format); + for (int i = 0; i < input_frame.PixelDataSize(); i++) { + const auto& pix = 
input_frame.PixelData()[i]; + for (int c = 0; c < target_num_channels; c++) { + image_mat->data[i * target_num_channels + c] = pix; + } + } + } else { + // Make of a copy since the input frame may be consumed by other nodes. + const int buffer_size = + input_frame.Height() * input_frame.Width() * + ImageFrame::NumberOfChannelsForFormat(*target_format); + input_frame.CopyToBuffer(image_mat->data, buffer_size); + } + } else { + image_mat = absl::make_unique( + options_.canvas_height_px(), options_.canvas_width_px(), CV_8UC3, + cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), + options_.canvas_color().b())); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( + CalculatorContext* cc, std::unique_ptr& image_mat) { +#if defined(__ANDROID__) + if (image_frame_available_) { + const auto& input_frame = + cc->Inputs().Tag(kInputFrameTagGpu).Get(); + + const mediapipe::ImageFormat::Format format = + mediapipe::ImageFormatForGpuBufferFormat(input_frame.format()); + if (format != mediapipe::ImageFormat::SRGBA) + RET_CHECK_FAIL() << "Unsupported GPU input format."; + + image_mat = + absl::make_unique(input_frame.height(), input_frame.width(), + CV_8UC3, cv::Scalar(0, 0, 0, 0)); + } else { + image_mat = absl::make_unique( + options_.canvas_height_px(), options_.canvas_width_px(), CV_8UC3, + cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), + options_.canvas_color().b())); + } +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AnnotationOverlayCalculator::GlRender( + CalculatorContext* cc) { +#if defined(__ANDROID__) + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + 
// program + glUseProgram(program_); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AnnotationOverlayCalculator::GlSetup( + CalculatorContext* cc) { +#if defined(__ANDROID__) + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + // Shader to overlay a texture onto another when overlay is non-zero. 
+ const GLchar* frag_src = GLES_VERSION_COMPAT + R"( + #if __VERSION__ < 130 + #define in varying + #endif // __VERSION__ < 130 + + #ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; + #else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; + #endif // defined(GL_ES) + + in vec2 sample_coordinate; + uniform sampler2D input_frame; + uniform sampler2D overlay; + + void main() { + vec3 image_pix = texture2D(input_frame, sample_coordinate).rgb; + vec3 overlay_pix = texture2D(overlay, sample_coordinate).rgb; + vec3 out_pix = image_pix; + float mag = dot(overlay_pix.rgb, vec3(1.0)); + if (mag > 0.0) out_pix = overlay_pix; + fragColor.rgb = out_pix; + fragColor.a = 1.0; + } + )"; + + // Create shader program and set parameters + mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src, + NUM_ATTRIBUTES, (const GLchar**)&attr_name[0], + attr_location, &program_); + RET_CHECK(program_) << "Problem initializing the program."; + glUseProgram(program_); + glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); + glUniform1i(glGetUniformLocation(program_, "overlay"), 2); + + // Init texture for opencv rendered frame. 
+ const auto& input_frame = + cc->Inputs().Tag(kInputFrameTagGpu).Get(); + const int width = input_frame.width(), height = input_frame.height(); + { + glGenTextures(1, &image_mat_tex_); + glBindTexture(GL_TEXTURE_2D, image_mat_tex_); + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB8, width, height, 0, GL_RGB, + GL_UNSIGNED_BYTE, nullptr); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + glBindTexture(GL_TEXTURE_2D, 0); + } +#endif // __ANDROID__ + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.proto b/mediapipe/calculators/util/annotation_overlay_calculator.proto new file mode 100644 index 000000000..93e436110 --- /dev/null +++ b/mediapipe/calculators/util/annotation_overlay_calculator.proto @@ -0,0 +1,43 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/util/color.proto"; + +// Options for the AnnotationOverlayCalculator. 
message AnnotationOverlayCalculatorOptions {
  extend CalculatorOptions {
    optional AnnotationOverlayCalculatorOptions ext = 250607623;
  }

  // Note: field numbers start at 2; number 1 is unused in this message and
  // should not be reused.

  // The canvas width and height in pixels, and the background color. These
  // options are used only if an input stream of ImageFrame isn't provided to
  // the renderer calculator. If an input stream of ImageFrame is provided, then
  // the calculator renders the annotations on top of the provided image, else a
  // canvas is created with the dimensions and background color specified in
  // these options and the annotations are rendered on top of this canvas.
  optional int32 canvas_width_px = 2 [default = 1920];
  optional int32 canvas_height_px = 3 [default = 1080];
  optional Color canvas_color = 4;

  // Whether text should be rendered upside down. When it's set to false, text
  // is rendered normally assuming the underlying image has its origin at the
  // top-left corner. Therefore, for images with the origin at the bottom-left
  // corner this should be set to true.
  optional bool flip_text_vertically = 5 [default = false];
}
diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc
new file mode 100644
index 000000000..107e08148
--- /dev/null
+++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc
@@ -0,0 +1,112 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe//framework/packet.h" +#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/resource_util.h" + +#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \ + (defined(__APPLE__) && !TARGET_OS_OSX) +#include "mediapipe/util/android/file/base/file.h" +#include "mediapipe/util/android/file/base/helpers.h" +#else +#include "mediapipe/framework/port/file_helpers.h" +#endif + +namespace mediapipe { + +// Takes a label map (from label IDs to names), and replaces the label IDs +// in Detection protos with label names. Note that the calculator makes a copy +// of the input detections. Consider using it only when the size of input +// detections is small. +// +// Example usage: +// node { +// calculator: "DetectionLabelIdToTextCalculator" +// input_stream: "input_detections" +// output_stream: "output_detections" +// node_options: { +// [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { +// label_map_path: "labelmap.txt" +// } +// } +// } +class DetectionLabelIdToTextCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + std::unordered_map label_map_; +}; +REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); + +::mediapipe::Status DetectionLabelIdToTextCalculator::GetContract( + CalculatorContract* cc) { + cc->Inputs().Index(0).Set>(); + cc->Outputs().Index(0).Set>(); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status DetectionLabelIdToTextCalculator::Open( + CalculatorContext* cc) { + const auto& options = + 
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>(); + + std::string string_path; + ASSIGN_OR_RETURN(string_path, PathToResourceAsFile(options.label_map_path())); + std::string label_map_string; + RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); + + std::istringstream stream(label_map_string); + std::string line; + int i = 0; + while (std::getline(stream, line)) { + label_map_[i++] = line; + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status DetectionLabelIdToTextCalculator::Process( + CalculatorContext* cc) { + std::vector output_detections; + for (const auto& input_detection : + cc->Inputs().Index(0).Get>()) { + output_detections.push_back(input_detection); + Detection& output_detection = output_detections.back(); + bool has_text_label = false; + for (const int32 label_id : output_detection.label_id()) { + if (label_map_.find(label_id) != label_map_.end()) { + output_detection.add_label(label_map_[label_id]); + has_text_label = true; + } + } + // Remove label_id field if text labels exist. + if (has_text_label) { + output_detection.clear_label_id(); + } + } + cc->Outputs().Index(0).AddPacket( + MakePacket>(output_detections) + .At(cc->InputTimestamp())); + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto b/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto new file mode 100644 index 000000000..0486d1d0a --- /dev/null +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto @@ -0,0 +1,28 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

// Options for DetectionLabelIdToTextCalculator.
message DetectionLabelIdToTextCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional DetectionLabelIdToTextCalculatorOptions ext = 251889072;
  }

  // Path to a label map file for getting the actual name of detected classes.
  // One label per line; the zero-based line number is the label ID.
  optional string label_map_path = 1;
}
diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc
new file mode 100644
index 000000000..800718823
--- /dev/null
+++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc
@@ -0,0 +1,143 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; + +} // namespace + +// Adjusts detection locations on a letterboxed image to the corresponding +// locations on the same image with the letterbox removed. This is useful to map +// the detections inferred from a letterboxed image, for example, output of +// the ImageTransformationCalculator when the scale mode is FIT, back to the +// corresponding input image before letterboxing. +// +// Input: +// DETECTIONS: An std::vector representing detections on an +// letterboxed image. +// +// LETTERBOX_PADDING: An std::array representing the letterbox +// padding from the 4 sides ([left, top, right, bottom]) of the letterboxed +// image, normalized to [0.f, 1.f] by the letterboxed image dimensions. +// +// Output: +// DETECTIONS: An std::vector representing detections with their +// locations adjusted to the letterbox-removed (non-padded) image. 
+// +// Usage example: +// node { +// calculator: "DetectionLetterboxRemovalCalculator" +// input_stream: "DETECTIONS:detections" +// input_stream: "LETTERBOX_PADDING:letterbox_padding" +// output_stream: "DETECTIONS:adjusted_detections" +// } +class DetectionLetterboxRemovalCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kDetectionsTag) && + cc->Inputs().HasTag(kLetterboxPaddingTag)) + << "Missing one or more input streams."; + + cc->Inputs().Tag(kDetectionsTag).Set>(); + cc->Inputs().Tag(kLetterboxPaddingTag).Set>(); + + cc->Outputs().Tag(kDetectionsTag).Set>(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + const auto& input_detections = + cc->Inputs().Tag(kDetectionsTag).Get>(); + const auto& letterbox_padding = + cc->Inputs().Tag(kLetterboxPaddingTag).Get>(); + + const float left = letterbox_padding[0]; + const float top = letterbox_padding[1]; + const float left_and_right = letterbox_padding[0] + letterbox_padding[2]; + const float top_and_bottom = letterbox_padding[1] + letterbox_padding[3]; + + auto output_detections = absl::make_unique>(); + for (const auto& detection : input_detections) { + Detection new_detection; + new_detection.CopyFrom(detection); + LocationData::RelativeBoundingBox* relative_bbox = + new_detection.mutable_location_data() + ->mutable_relative_bounding_box(); + + relative_bbox->set_xmin( + (detection.location_data().relative_bounding_box().xmin() - left) / + (1.0f - left_and_right)); + relative_bbox->set_ymin( + (detection.location_data().relative_bounding_box().ymin() - top) / + (1.0f - top_and_bottom)); + // The size of the bounding box will change as well. 
+ relative_bbox->set_width( + detection.location_data().relative_bounding_box().width() / + (1.0f - left_and_right)); + relative_bbox->set_height( + detection.location_data().relative_bounding_box().height() / + (1.0f - top_and_bottom)); + + // Adjust keypoints as well. + for (int i = 0; + i < new_detection.mutable_location_data()->relative_keypoints_size(); + ++i) { + auto* keypoint = + new_detection.mutable_location_data()->mutable_relative_keypoints( + i); + const float new_x = (keypoint->x() - left) / (1.0f - left_and_right); + const float new_y = (keypoint->y() - top) / (1.0f - top_and_bottom); + keypoint->set_x(new_x); + keypoint->set_y(new_y); + } + + output_detections->emplace_back(new_detection); + } + + cc->Outputs() + .Tag("DETECTIONS") + .Add(output_detections.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(DetectionLetterboxRemovalCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc new file mode 100644 index 000000000..e2a04e525 --- /dev/null +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc @@ -0,0 +1,158 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { + +LocationData CreateRelativeLocationData(double xmin, double ymin, double width, + double height) { + LocationData location_data; + location_data.set_format(LocationData::RELATIVE_BOUNDING_BOX); + location_data.mutable_relative_bounding_box()->set_xmin(xmin); + location_data.mutable_relative_bounding_box()->set_ymin(ymin); + location_data.mutable_relative_bounding_box()->set_width(width); + location_data.mutable_relative_bounding_box()->set_height(height); + return location_data; +} + +Detection CreateDetection(const std::vector& labels, + const std::vector& label_ids, + const std::vector& scores, + const LocationData& location_data, + const std::string& feature_tag) { + Detection detection; + for (const auto& label : labels) { + detection.add_label(label); + } + for (const auto& label_id : label_ids) { + detection.add_label_id(label_id); + } + for (const auto& score : scores) { + detection.add_score(score); + } + *(detection.mutable_location_data()) = location_data; + detection.set_feature_tag(feature_tag); + return detection; +} + +CalculatorGraphConfig::Node GetDefaultNode() { + return ParseTextProtoOrDie(R"( + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:adjusted_detections" + )"); +} + +TEST(DetectionLetterboxRemovalCalculatorTest, PaddingLeftRight) { + CalculatorRunner runner(GetDefaultNode()); 
+ + LocationData location_data = + CreateRelativeLocationData(0.25f, 0.25f, 0.25f, 0.25f); + const std::string label = "detected_object"; + + auto detections = absl::make_unique>(); + detections->push_back( + CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag")); + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + auto padding = absl::make_unique>( + std::array{0.2f, 0.f, 0.3f, 0.f}); + runner.MutableInputs() + ->Tag("LETTERBOX_PADDING") + .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = + runner.Outputs().Tag("DETECTIONS").packets; + ASSERT_EQ(1, output.size()); + const auto& output_detections = output[0].Get>(); + + EXPECT_EQ(output_detections.size(), 1); + const auto& output_detection = output_detections[0]; + + EXPECT_EQ(output_detection.label_size(), 1); + EXPECT_EQ(output_detection.label(0), label); + EXPECT_EQ(output_detection.label_id_size(), 0); + EXPECT_EQ(output_detection.score_size(), 1); + EXPECT_EQ(output_detection.score(0), 0.3f); + + EXPECT_EQ(output_detection.location_data().format(), + LocationData::RELATIVE_BOUNDING_BOX); + EXPECT_THAT(output_detection.location_data().relative_bounding_box().xmin(), + testing::FloatNear(0.1f, 1e-5)); + EXPECT_THAT(output_detection.location_data().relative_bounding_box().ymin(), + testing::FloatNear(0.25f, 1e-5)); + EXPECT_THAT(output_detection.location_data().relative_bounding_box().width(), + testing::FloatNear(0.5f, 1e-5)); + EXPECT_THAT(output_detection.location_data().relative_bounding_box().height(), + testing::FloatNear(0.25f, 1e-5)); +} + +TEST(DetectionLetterboxRemovalCalculatorTest, PaddingTopBottom) { + CalculatorRunner runner(GetDefaultNode()); + + LocationData location_data = + CreateRelativeLocationData(0.25f, 0.25f, 0.25f, 0.25f); + const std::string label = "detected_object"; + + 
auto detections = absl::make_unique>(); + detections->push_back( + CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag")); + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + auto padding = absl::make_unique>( + std::array{0.f, 0.2f, 0.f, 0.3f}); + runner.MutableInputs() + ->Tag("LETTERBOX_PADDING") + .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = + runner.Outputs().Tag("DETECTIONS").packets; + ASSERT_EQ(1, output.size()); + const auto& output_detections = output[0].Get>(); + + EXPECT_EQ(output_detections.size(), 1); + const auto& output_detection = output_detections[0]; + + EXPECT_EQ(output_detection.location_data().format(), + LocationData::RELATIVE_BOUNDING_BOX); + EXPECT_THAT(output_detection.location_data().relative_bounding_box().xmin(), + testing::FloatNear(0.25f, 1e-5)); + EXPECT_THAT(output_detection.location_data().relative_bounding_box().ymin(), + testing::FloatNear(0.1f, 1e-5)); + EXPECT_THAT(output_detection.location_data().relative_bounding_box().width(), + testing::FloatNear(0.25f, 1e-5)); + EXPECT_THAT(output_detection.location_data().relative_bounding_box().height(), + testing::FloatNear(0.5f, 1e-5)); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator.cc b/mediapipe/calculators/util/detections_to_render_data_calculator.cc new file mode 100644 index 000000000..aa4b35089 --- /dev/null +++ b/mediapipe/calculators/util/detections_to_render_data_calculator.cc @@ -0,0 +1,359 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "mediapipe/calculators/util/detections_to_render_data_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" +namespace mediapipe { + +namespace { + +constexpr char kDetectionListTag[] = "DETECTION_LIST"; +constexpr char kDetectionVectorTag[] = "DETECTION_VECTOR"; +constexpr char kRenderDataTag[] = "RENDER_DATA"; + +constexpr char kSceneLabelLabel[] = "LABEL"; +constexpr char kSceneFeatureLabel[] = "FEATURE"; +constexpr char kSceneLocationLabel[] = "LOCATION"; +constexpr char kKeypointLabel[] = "KEYPOINT"; + +// The ratio of detection label font height to the height of detection bounding +// box. +constexpr double kLabelToBoundingBoxRatio = 0.1; + +} // namespace + +// A calculator that converts Detection proto to RenderData proto for +// visualization. +// +// Detection is the format for encoding one or more detections in an image. +// The input can be std::vector or DetectionList. +// +// Please note that only Location Data formats of BOUNDING_BOX and +// RELATIVE_BOUNDING_BOX are supported. Normalized coordinates for +// RELATIVE_BOUNDING_BOX must be between 0.0 and 1.0. 
Any incremental normalized +// coordinates calculation in this calculator is capped at 1.0. +// +// The text(s) for "label(_id),score" will be shown on top left +// corner of the bounding box. The text for "feature_tag" will be shown on +// bottom left corner of the bounding box. +// +// Example config: +// node { +// calculator: "DetectionsToRenderDataCalculator" +// input_stream: "DETECTION_LIST:detection_list" +// input_stream: "DETECTION_VECTOR:detection_vector" +// output_stream: "RENDER_DATA:render_data" +// options { +// [DetectionsToRenderDataCalculatorOptions.ext] { +// produce_empty_packet : false +// } +// } +// } +class DetectionsToRenderDataCalculator : public CalculatorBase { + public: + DetectionsToRenderDataCalculator() {} + ~DetectionsToRenderDataCalculator() override {} + DetectionsToRenderDataCalculator(const DetectionsToRenderDataCalculator&) = + delete; + DetectionsToRenderDataCalculator& operator=( + const DetectionsToRenderDataCalculator&) = delete; + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + // These utility methods are supposed to be used only by this class. No + // external client should depend on them. Due to C++ style guide unnamed + // namespace should not be used in header files. So, these has been defined + // as private static methods. 
+ static void SetRenderAnnotationColorThickness( + const DetectionsToRenderDataCalculatorOptions& options, + RenderAnnotation* render_annotation); + + static void SetTextCoordinate(bool normalized, double left, double baseline, + RenderAnnotation::Text* text); + + static void SetRectCoordinate(bool normalized, double xmin, double ymin, + double width, double height, + RenderAnnotation::Rectangle* rect); + + static void AddLabels(const Detection& detection, + const DetectionsToRenderDataCalculatorOptions& options, + float text_line_height, RenderData* render_data); + static void AddFeatureTag( + const Detection& detection, + const DetectionsToRenderDataCalculatorOptions& options, + float text_line_height, RenderData* render_data); + static void AddLocationData( + const Detection& detection, + const DetectionsToRenderDataCalculatorOptions& options, + RenderData* render_data); + static void AddDetectionToRenderData( + const Detection& detection, + const DetectionsToRenderDataCalculatorOptions& options, + RenderData* render_data); +}; +REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); + +::mediapipe::Status DetectionsToRenderDataCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) || + cc->Inputs().HasTag(kDetectionVectorTag)) + << "None of the input streams are provided."; + + if (cc->Inputs().HasTag(kDetectionListTag)) { + cc->Inputs().Tag(kDetectionListTag).Set(); + } + if (cc->Inputs().HasTag(kDetectionVectorTag)) { + cc->Inputs().Tag(kDetectionVectorTag).Set>(); + } + cc->Outputs().Tag(kRenderDataTag).Set(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status DetectionsToRenderDataCalculator::Process( + CalculatorContext* cc) { + const auto& options = cc->Options(); + const bool has_detection_from_list = + cc->Inputs().HasTag(kDetectionListTag) && !cc->Inputs() + .Tag(kDetectionListTag) + .Get() + .detection() + .empty(); + const bool has_detection_from_vector = + 
cc->Inputs().HasTag(kDetectionVectorTag) && + !cc->Inputs() + .Tag(kDetectionVectorTag) + .Get>() + .empty(); + if (!options.produce_empty_packet() && !has_detection_from_list && + !has_detection_from_vector) { + return ::mediapipe::OkStatus(); + } + + // TODO: Add score threshold to + // DetectionsToRenderDataCalculatorOptions. + auto render_data = absl::make_unique(); + render_data->set_scene_class(options.scene_class()); + if (has_detection_from_list) { + for (const auto& detection : + cc->Inputs().Tag(kDetectionListTag).Get().detection()) { + AddDetectionToRenderData(detection, options, render_data.get()); + } + } + if (has_detection_from_vector) { + for (const auto& detection : + cc->Inputs().Tag(kDetectionVectorTag).Get>()) { + AddDetectionToRenderData(detection, options, render_data.get()); + } + } + cc->Outputs() + .Tag(kRenderDataTag) + .Add(render_data.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +void DetectionsToRenderDataCalculator::SetRenderAnnotationColorThickness( + const DetectionsToRenderDataCalculatorOptions& options, + RenderAnnotation* render_annotation) { + render_annotation->mutable_color()->set_r(options.color().r()); + render_annotation->mutable_color()->set_g(options.color().g()); + render_annotation->mutable_color()->set_b(options.color().b()); + render_annotation->set_thickness(options.thickness()); +} + +void DetectionsToRenderDataCalculator::SetTextCoordinate( + bool normalized, double left, double baseline, + RenderAnnotation::Text* text) { + text->set_normalized(normalized); + text->set_left(normalized ? std::max(left, 0.0) : left); + // Normalized coordinates must be between 0.0 and 1.0, if they are used. + text->set_baseline(normalized ? 
std::min(baseline, 1.0) : baseline); +} + +void DetectionsToRenderDataCalculator::SetRectCoordinate( + bool normalized, double xmin, double ymin, double width, double height, + RenderAnnotation::Rectangle* rect) { + if (xmin + width < 0.0 || ymin + height < 0.0) return; + if (normalized) { + if (xmin > 1.0 || ymin > 1.0) return; + } + rect->set_normalized(normalized); + rect->set_left(normalized ? std::max(xmin, 0.0) : xmin); + rect->set_top(normalized ? std::max(ymin, 0.0) : ymin); + // No "xmin + width -1" because the coordinates can be relative, i.e. [0,1], + // and we don't know what 1 pixel means in term of double [0,1]. + // For consistency decided to not decrease by 1 also when it is not relative. + // However, when the coordinate is normalized it has to be between 0.0 and + // 1.0. + rect->set_right(normalized ? std::min(xmin + width, 1.0) : xmin + width); + rect->set_bottom(normalized ? std::min(ymin + height, 1.0) : ymin + height); +} + +void DetectionsToRenderDataCalculator::AddLabels( + const Detection& detection, + const DetectionsToRenderDataCalculatorOptions& options, + float text_line_height, RenderData* render_data) { + CHECK(detection.label().empty() || detection.label_id().empty()) + << "Either std::string or integer labels must be used for detection " + "but not both at the same time."; + const auto num_labels = + std::max(detection.label_size(), detection.label_id_size()); + CHECK_EQ(detection.score_size(), num_labels) + << "Number of scores and labels should match for detection."; + + // Extracts all "label(_id),score" for the detection. + std::vector label_and_scores = {}; + for (int i = 0; i < num_labels; ++i) { + std::string label_str = detection.label().empty() + ? 
absl::StrCat(detection.label_id(i)) + : detection.label(i); + std::string label_and_score = + absl::StrCat(label_str, options.text_delimiter(), detection.score(i), + options.text_delimiter()); + label_and_scores.push_back(label_and_score); + } + std::vector labels; + if (options.one_label_per_line()) { + labels.swap(label_and_scores); + } else { + labels.push_back(absl::StrJoin(label_and_scores, "")); + } + + // Add the render annotations for "label(_id),score". + for (int i = 0; i < labels.size(); ++i) { + auto label = labels.at(i); + auto* label_annotation = render_data->add_render_annotations(); + label_annotation->set_scene_tag(kSceneLabelLabel); + SetRenderAnnotationColorThickness(options, label_annotation); + auto* text = label_annotation->mutable_text(); + *text = options.text(); + text->set_display_text(label); + if (detection.location_data().format() == LocationData::BOUNDING_BOX) { + SetTextCoordinate(false, detection.location_data().bounding_box().xmin(), + detection.location_data().bounding_box().ymin() + + (i + 1) * text_line_height, + text); + } else { + text->set_font_height(text_line_height * 0.9); + SetTextCoordinate( + true, detection.location_data().relative_bounding_box().xmin(), + detection.location_data().relative_bounding_box().ymin() + + (i + 1) * text_line_height, + text); + } + } +} + +void DetectionsToRenderDataCalculator::AddFeatureTag( + const Detection& detection, + const DetectionsToRenderDataCalculatorOptions& options, + float text_line_height, RenderData* render_data) { + auto* feature_tag_annotation = render_data->add_render_annotations(); + feature_tag_annotation->set_scene_tag(kSceneFeatureLabel); + SetRenderAnnotationColorThickness(options, feature_tag_annotation); + auto* feature_tag_text = feature_tag_annotation->mutable_text(); + feature_tag_text->set_display_text(detection.feature_tag()); + if (detection.location_data().format() == LocationData::BOUNDING_BOX) { + SetTextCoordinate(false, 
detection.location_data().bounding_box().xmin(), + detection.location_data().bounding_box().ymin() + + detection.location_data().bounding_box().height(), + feature_tag_text); + } else { + feature_tag_text->set_font_height(text_line_height * 0.9); + SetTextCoordinate( + true, detection.location_data().relative_bounding_box().xmin(), + detection.location_data().relative_bounding_box().ymin() + + detection.location_data().relative_bounding_box().height(), + feature_tag_text); + } +} + +void DetectionsToRenderDataCalculator::AddLocationData( + const Detection& detection, + const DetectionsToRenderDataCalculatorOptions& options, + RenderData* render_data) { + auto* location_data_annotation = render_data->add_render_annotations(); + location_data_annotation->set_scene_tag(kSceneLocationLabel); + SetRenderAnnotationColorThickness(options, location_data_annotation); + auto* location_data_rect = location_data_annotation->mutable_rectangle(); + if (detection.location_data().format() == LocationData::BOUNDING_BOX) { + SetRectCoordinate(false, detection.location_data().bounding_box().xmin(), + detection.location_data().bounding_box().ymin(), + detection.location_data().bounding_box().width(), + detection.location_data().bounding_box().height(), + location_data_rect); + } else { + SetRectCoordinate( + true, detection.location_data().relative_bounding_box().xmin(), + detection.location_data().relative_bounding_box().ymin(), + detection.location_data().relative_bounding_box().width(), + detection.location_data().relative_bounding_box().height(), + location_data_rect); + // Keypoints are only supported in normalized/relative coordinates. 
+ if (detection.location_data().relative_keypoints_size()) { + for (int i = 0; i < detection.location_data().relative_keypoints_size(); + ++i) { + auto* keypoint_data_annotation = render_data->add_render_annotations(); + keypoint_data_annotation->set_scene_tag(kKeypointLabel); + SetRenderAnnotationColorThickness(options, keypoint_data_annotation); + auto* keypoint_data = keypoint_data_annotation->mutable_point(); + keypoint_data->set_normalized(true); + // See location_data.proto for detail. + keypoint_data->set_x( + detection.location_data().relative_keypoints(i).x()); + keypoint_data->set_y( + detection.location_data().relative_keypoints(i).y()); + } + } + } +} + +void DetectionsToRenderDataCalculator::AddDetectionToRenderData( + const Detection& detection, + const DetectionsToRenderDataCalculatorOptions& options, + RenderData* render_data) { + CHECK(detection.location_data().format() == LocationData::BOUNDING_BOX || + detection.location_data().format() == + LocationData::RELATIVE_BOUNDING_BOX) + << "Only Detection with formats of BOUNDING_BOX or RELATIVE_BOUNDING_BOX " + "are supported."; + double text_line_height; + if (detection.location_data().format() == LocationData::BOUNDING_BOX) { + text_line_height = options.text().font_height(); + } else { + // Determine the text line height based on the default label to bounding box + // ratio and the number of labels. 
+ text_line_height = + detection.location_data().relative_bounding_box().height() * + std::min(kLabelToBoundingBoxRatio, + 1 / (double)(std::max(detection.label_size(), + detection.label_id_size()) + + 1 /* for feature_tag */)); + } + AddLabels(detection, options, text_line_height, render_data); + AddFeatureTag(detection, options, text_line_height, render_data); + AddLocationData(detection, options, render_data); +} +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator.proto b/mediapipe/calculators/util/detections_to_render_data_calculator.proto new file mode 100644 index 000000000..245f48fea --- /dev/null +++ b/mediapipe/calculators/util/detections_to_render_data_calculator.proto @@ -0,0 +1,56 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/util/color.proto"; +import "mediapipe/util/render_data.proto"; + +message DetectionsToRenderDataCalculatorOptions { + extend CalculatorOptions { + optional DetectionsToRenderDataCalculatorOptions ext = 248360806; + } + + // If true, produces a RenderData packet with no annotation when the input + // packet has no detection. Otherwise, it won't produce any packet. + // Please note, regardless of this flag nothing will be produce if there is + // no input packet for a timestamp. 
+ optional bool produce_empty_packet = 1 [default = true]; + + // The delimiter to separate label(_id) and score. + optional string text_delimiter = 2 [default = ","]; + + // If true, each "label(_id),score" will be on a separate line. + // Otherwise, all "label(_id),score" will be concatenated when the detection + // has more than one label. + optional bool one_label_per_line = 3 [default = false]; + + // Rendering options for the label. + optional RenderAnnotation.Text text = 4; + + // Thickness for drawing the label(s) and the location_data(box). + optional double thickness = 5 [default = 1.0]; + + // Color for drawing the label(s), feature_tag, and the location_data(box). + optional Color color = 6; + + // An optional string that identifies this class of annotations + // for the render data output this calculator produces. If multiple + // instances of this calculator are present in the graph, this value + // should be unique among them. + optional string scene_class = 7 [default = "DETECTION"]; +} diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc new file mode 100644 index 000000000..f15fec3d0 --- /dev/null +++ b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc @@ -0,0 +1,258 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/memory/memory.h" +#include "mediapipe/calculators/util/detections_to_render_data_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { + +using ::testing::DoubleNear; + +// Error tolerance for pixels, distances, etc. +static constexpr double kErrorTolerance = 1e-5; + +void VerifyRenderAnnotationColorThickness( + const RenderAnnotation& annotation, + const DetectionsToRenderDataCalculatorOptions& options) { + EXPECT_THAT(annotation.color(), EqualsProto(options.color())); + EXPECT_EQ(annotation.thickness(), options.thickness()); +} + +LocationData CreateLocationData(int32 xmin, int32 ymin, int32 width, + int32 height) { + LocationData location_data; + location_data.set_format(LocationData::BOUNDING_BOX); + location_data.mutable_bounding_box()->set_xmin(xmin); + location_data.mutable_bounding_box()->set_ymin(ymin); + location_data.mutable_bounding_box()->set_width(width); + location_data.mutable_bounding_box()->set_height(height); + return location_data; +} + +LocationData CreateRelativeLocationData(double xmin, double ymin, double width, + double height) { + LocationData location_data; + location_data.set_format(LocationData::RELATIVE_BOUNDING_BOX); + location_data.mutable_relative_bounding_box()->set_xmin(xmin); + location_data.mutable_relative_bounding_box()->set_ymin(ymin); + 
location_data.mutable_relative_bounding_box()->set_width(width); + location_data.mutable_relative_bounding_box()->set_height(height); + return location_data; +} + +Detection CreateDetection(const std::vector& labels, + const std::vector& label_ids, + const std::vector& scores, + const LocationData& location_data, + const std::string& feature_tag) { + Detection detection; + for (const auto& label : labels) { + detection.add_label(label); + } + for (const auto& label_id : label_ids) { + detection.add_label_id(label_id); + } + for (const auto& score : scores) { + detection.add_score(score); + } + *(detection.mutable_location_data()) = location_data; + detection.set_feature_tag(feature_tag); + return detection; +} + +TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionList) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_LIST:detection_list" + output_stream: "RENDER_DATA:render_data" + )")); + + LocationData location_data = CreateLocationData(100, 200, 300, 400); + auto detections(absl::make_unique()); + *(detections->add_detection()) = + CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"); + + runner.MutableInputs() + ->Tag("DETECTION_LIST") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = + runner.Outputs().Tag("RENDER_DATA").packets; + ASSERT_EQ(1, output.size()); + const auto& actual = output[0].Get(); + EXPECT_EQ(actual.render_annotations_size(), 3); + // Labels + EXPECT_EQ(actual.render_annotations(0).text().display_text(), "label1,0.3,"); + // Feature tag + EXPECT_EQ(actual.render_annotations(1).text().display_text(), "feature_tag"); + // Location data + EXPECT_EQ(actual.render_annotations(2).rectangle().left(), 100); + EXPECT_EQ(actual.render_annotations(2).rectangle().right(), 100 + 300); + 
EXPECT_EQ(actual.render_annotations(2).rectangle().top(), 200); + EXPECT_EQ(actual.render_annotations(2).rectangle().bottom(), 200 + 400); +} + +TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) { + CalculatorRunner runner{ParseTextProtoOrDie(R"( + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_VECTOR:detection_vector" + output_stream: "RENDER_DATA:render_data" + )")}; + + LocationData location_data = CreateLocationData(100, 200, 300, 400); + auto detections(absl::make_unique>()); + detections->push_back( + CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag")); + + runner.MutableInputs() + ->Tag("DETECTION_VECTOR") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = + runner.Outputs().Tag("RENDER_DATA").packets; + ASSERT_EQ(1, output.size()); + const auto& actual = output[0].Get(); + EXPECT_EQ(actual.render_annotations_size(), 3); + // Labels + EXPECT_EQ(actual.render_annotations(0).text().display_text(), "label1,0.3,"); + // Feature tag + EXPECT_EQ(actual.render_annotations(1).text().display_text(), "feature_tag"); + // Location data + EXPECT_EQ(actual.render_annotations(2).rectangle().left(), 100); + EXPECT_EQ(actual.render_annotations(2).rectangle().right(), 100 + 300); + EXPECT_EQ(actual.render_annotations(2).rectangle().top(), 200); + EXPECT_EQ(actual.render_annotations(2).rectangle().bottom(), 200 + 400); +} + +TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) { + CalculatorRunner runner{ParseTextProtoOrDie(R"( + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_LIST:detection_list" + input_stream: "DETECTION_VECTOR:detection_vector" + output_stream: "RENDER_DATA:render_data" + )")}; + + LocationData location_data1 = CreateLocationData(100, 200, 300, 400); + auto detection_list(absl::make_unique()); + 
*(detection_list->add_detection()) = + CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1"); + runner.MutableInputs() + ->Tag("DETECTION_LIST") + .packets.push_back( + Adopt(detection_list.release()).At(Timestamp::PostStream())); + + LocationData location_data2 = CreateLocationData(600, 700, 800, 900); + auto detection_vector(absl::make_unique>()); + detection_vector->push_back( + CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2")); + runner.MutableInputs() + ->Tag("DETECTION_VECTOR") + .packets.push_back( + Adopt(detection_vector.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& actual = + runner.Outputs().Tag("RENDER_DATA").packets; + ASSERT_EQ(1, actual.size()); + // Check the feature tag for item from detection list. + EXPECT_EQ( + actual[0].Get().render_annotations(1).text().display_text(), + "feature_tag1"); + // Check the feature tag for item from detection vector. + EXPECT_EQ( + actual[0].Get().render_annotations(4).text().display_text(), + "feature_tag2"); +} + +TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { + // Check when produce_empty_packet is false. 
+ CalculatorRunner runner1{ParseTextProtoOrDie(R"( + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_LIST:detection_list" + input_stream: "DETECTION_VECTOR:detection_vector" + output_stream: "RENDER_DATA:render_data" + options { + [mediapipe.DetectionsToRenderDataCalculatorOptions.ext] { + produce_empty_packet: false + } + } + )")}; + + auto detection_list1(absl::make_unique()); + runner1.MutableInputs() + ->Tag("DETECTION_LIST") + .packets.push_back( + Adopt(detection_list1.release()).At(Timestamp::PostStream())); + + auto detection_vector1(absl::make_unique>()); + runner1.MutableInputs() + ->Tag("DETECTION_VECTOR") + .packets.push_back( + Adopt(detection_vector1.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner1.Run()) << "Calculator execution failed."; + const std::vector& exact1 = + runner1.Outputs().Tag("RENDER_DATA").packets; + ASSERT_EQ(0, exact1.size()); + + // Check when produce_empty_packet is true. + CalculatorRunner runner2{ParseTextProtoOrDie(R"( + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_LIST:detection_list" + input_stream: "DETECTION_VECTOR:detection_vector" + output_stream: "RENDER_DATA:render_data" + options { + [mediapipe.DetectionsToRenderDataCalculatorOptions.ext] { + produce_empty_packet: true + } + } + )")}; + + auto detection_list2(absl::make_unique()); + runner2.MutableInputs() + ->Tag("DETECTION_LIST") + .packets.push_back( + Adopt(detection_list2.release()).At(Timestamp::PostStream())); + + auto detection_vector2(absl::make_unique>()); + runner2.MutableInputs() + ->Tag("DETECTION_VECTOR") + .packets.push_back( + Adopt(detection_vector2.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner2.Run()) << "Calculator execution failed."; + const std::vector& exact2 = + runner2.Outputs().Tag("RENDER_DATA").packets; + ASSERT_EQ(1, exact2.size()); + EXPECT_EQ(exact2[0].Get().render_annotations_size(), 0); +} + +} // namespace mediapipe diff --git 
a/mediapipe/calculators/util/latency.proto b/mediapipe/calculators/util/latency.proto new file mode 100644 index 000000000..4b122fb19 --- /dev/null +++ b/mediapipe/calculators/util/latency.proto @@ -0,0 +1,40 @@ +// Proto messages related to latency measurement for Soapbox. +syntax = "proto2"; + +package mediapipe; + +// Contains the latency information for a packet stream in mediapipe. The +// following are provided +// 1. current latency +// 2. running average +// 3. histogram of latencies observed +// 4. cumulative sum of latencies observed +// NextId: 13 +message PacketLatency { + // Reserved tags. + reserved 1, 3 to 6; + + // Current latency (delay in microseconds wrt a reference packet). + optional int64 current_latency_usec = 8; + + // The latency histogram which stores the count recorded for each specified + // interval. + repeated int64 counts = 9; + + // Number of intervals for the latency histogram output. + optional int64 num_intervals = 10 [default = 10]; + + // Size of the histogram intervals (in microseconds). The first interval is + // [0, interval_size_usec). The last interval extends to +inf. + optional int64 interval_size_usec = 11 [default = 10000]; + + // Running average of latencies observed so far. + optional int64 avg_latency_usec = 2; + + // An identifier label for the packet. + optional string label = 7; + + // Cumulative sum of individual packet latencies of all the packets output so + // far. + optional int64 sum_latency_usec = 12; +} diff --git a/mediapipe/calculators/util/non_max_suppression_calculator.cc b/mediapipe/calculators/util/non_max_suppression_calculator.cc new file mode 100644 index 000000000..f02aa9c21 --- /dev/null +++ b/mediapipe/calculators/util/non_max_suppression_calculator.cc @@ -0,0 +1,359 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/location.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/rectangle.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +typedef std::vector Detections; +typedef std::vector> IndexedScores; + +namespace { + +constexpr char kImageTag[] = "IMAGE"; + +bool SortBySecond(const std::pair& indexed_score_0, + const std::pair& indexed_score_1) { + return (indexed_score_0.second > indexed_score_1.second); +} + +// Removes all but the max scoring label and its score from the detection. +// Returns true if the detection has at least one label. 
+bool RetainMaxScoringLabelOnly(Detection* detection) { + if (detection->label_id_size() == 0 && detection->label_size() == 0) { + return false; + } + CHECK(detection->label_id_size() == detection->score_size() || + detection->label_size() == detection->score_size()) + << "Number of scores must be equal to number of detections."; + + std::vector> indexed_scores; + for (int k = 0; k < detection->score_size(); ++k) { + indexed_scores.push_back(std::make_pair(k, detection->score(k))); + } + std::sort(indexed_scores.begin(), indexed_scores.end(), SortBySecond); + const int top_index = indexed_scores[0].first; + detection->clear_score(); + detection->add_score(indexed_scores[0].second); + if (detection->label_id_size() > top_index) { + const int top_label_id = detection->label_id(top_index); + detection->clear_label_id(); + detection->add_label_id(top_label_id); + } else { + const std::string top_label = detection->label(top_index); + detection->clear_label(); + detection->add_label(top_label); + } + + return true; +} + +// Computes an overlap similarity between two rectangles. Similarity measure is +// defined by overlap_type parameter. +float OverlapSimilarity( + const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type, + const Rectangle_f& rect1, const Rectangle_f& rect2) { + if (!rect1.Intersects(rect2)) return 0.0f; + const float intersection_area = Rectangle_f(rect1).Intersect(rect2).Area(); + float normalization; + switch (overlap_type) { + case NonMaxSuppressionCalculatorOptions::JACCARD: + normalization = Rectangle_f(rect1).Union(rect2).Area(); + break; + case NonMaxSuppressionCalculatorOptions::MODIFIED_JACCARD: + normalization = rect2.Area(); + break; + case NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION: + normalization = rect1.Area() + rect2.Area() - intersection_area; + break; + default: + LOG(FATAL) << "Unrecognized overlap type: " << overlap_type; + } + return normalization > 0.0f ? 
intersection_area / normalization : 0.0f; +} + +// Computes an overlap similarity between two locations by first extracting the +// relative box (dimension normalized by frame width/height) from the location. +float OverlapSimilarity( + const int frame_width, const int frame_height, + const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type, + const Location& location1, const Location& location2) { + const auto rect1 = location1.ConvertToRelativeBBox(frame_width, frame_height); + const auto rect2 = location2.ConvertToRelativeBBox(frame_width, frame_height); + return OverlapSimilarity(overlap_type, rect1, rect2); +} + +// Computes an overlap similarity between two locations by first extracting the +// relative box from the location. It assumes that a relative-box representation +// is already available in the location, and therefore frame width and height +// are not needed for further normalization. +float OverlapSimilarity( + const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type, + const Location& location1, const Location& location2) { + const auto rect1 = location1.GetRelativeBBox(); + const auto rect2 = location2.GetRelativeBBox(); + return OverlapSimilarity(overlap_type, rect1, rect2); +} + +} // namespace + +// A calculator performing non-maximum suppression on a set of detections. +// Inputs: +// 1. IMAGE (optional): A stream of ImageFrame used to obtain the frame size. +// No image data is used. Not needed if the detection bounding boxes are +// already represented in normalized dimensions (0.0~1.0). +// 2. A variable number of input streams of type std::vector. The +// exact number of such streams should be set via num_detection_streams +// field in the calculator options. +// +// Outputs: a single stream of type std::vector containing a subset +// of the input detections after non-maximum suppression. 
+// +// Example config: +// node { +// calculator: "NonMaxSuppressionCalculator" +// input_stream: "IMAGE:frames" +// input_stream: "detections1" +// input_stream: "detections2" +// output_stream: "detections" +// options { +// [mediapipe.NonMaxSuppressionCalculatorOptions.ext] { +// num_detection_streams: 2 +// max_num_detections: 10 +// min_suppression_threshold: 0.2 +// overlap_type: JACCARD +// } +// } +// } +class NonMaxSuppressionCalculator : public CalculatorBase { + public: + NonMaxSuppressionCalculator() = default; + ~NonMaxSuppressionCalculator() override = default; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + const auto& options = cc->Options(); + if (cc->Inputs().HasTag(kImageTag)) { + cc->Inputs().Tag(kImageTag).Set(); + } + for (int k = 0; k < options.num_detection_streams(); ++k) { + cc->Inputs().Index(k).Set(); + } + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + CHECK_GT(options_.num_detection_streams(), 0) + << "At least one detection stream need to be specified."; + CHECK_NE(options_.max_num_detections(), 0) + << "max_num_detections=0 is not a valid value. Please choose a " + << "positive number of you want to limit the number of output " + << "detections, or set -1 if you do not want any limit."; + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + // Add all input detections to the same vector. + Detections input_detections; + for (int i = 0; i < options_.num_detection_streams(); ++i) { + const auto& detections_packet = cc->Inputs().Index(i).Value(); + // Check whether this stream has a packet for this timestamp. 
+ if (detections_packet.IsEmpty()) { + continue; + } + const auto& detections = detections_packet.Get(); + + input_detections.insert(input_detections.end(), detections.begin(), + detections.end()); + } + + // Check if there are any detections at all. + if (input_detections.empty()) { + if (options_.return_empty_detections()) { + cc->Outputs().Index(0).Add(new Detections(), cc->InputTimestamp()); + } + return ::mediapipe::OkStatus(); + } + + // Remove all but the maximum scoring label from each input detection. This + // corresponds to non-maximum suppression among detections which have + // identical locations. + Detections pruned_detections; + pruned_detections.reserve(input_detections.size()); + for (auto& detection : input_detections) { + if (RetainMaxScoringLabelOnly(&detection)) { + pruned_detections.push_back(detection); + } + } + + // Copy all the scores (there is a single score in each detection after + // the above pruning) to an indexed vector for sorting. The first value is + // the index of the detection in the original vector from which the score + // stems, while the second is the actual score. + IndexedScores indexed_scores; + indexed_scores.reserve(pruned_detections.size()); + for (int index = 0; index < pruned_detections.size(); ++index) { + indexed_scores.push_back( + std::make_pair(index, pruned_detections[index].score(0))); + } + std::sort(indexed_scores.begin(), indexed_scores.end(), SortBySecond); + + const int max_num_detections = + (options_.max_num_detections() > -1) + ? options_.max_num_detections() + : static_cast(indexed_scores.size()); + // A set of detections and locations, wrapping the location data from each + // detection, which are retained after the non-maximum suppression. 
+ auto* retained_detections = new Detections(); + retained_detections->reserve(max_num_detections); + + if (options_.algorithm() == NonMaxSuppressionCalculatorOptions::WEIGHTED) { + WeightedNonMaxSuppression(indexed_scores, pruned_detections, + max_num_detections, cc, retained_detections); + } else { + NonMaxSuppression(indexed_scores, pruned_detections, max_num_detections, + cc, retained_detections); + } + + cc->Outputs().Index(0).Add(retained_detections, cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); + } + + private: + void NonMaxSuppression(const IndexedScores& indexed_scores, + const Detections& detections, int max_num_detections, + CalculatorContext* cc, Detections* output_detections) { + std::vector retained_locations; + retained_locations.reserve(max_num_detections); + // We traverse the detections by decreasing score. + for (const auto& indexed_score : indexed_scores) { + const auto& detection = detections[indexed_score.first]; + if (options_.min_score_threshold() > 0 && + detection.score(0) < options_.min_score_threshold()) { + break; + } + const Location location(detection.location_data()); + bool suppressed = false; + // The current detection is suppressed iff there exists a retained + // detection, whose location overlaps more than the specified + // threshold with the location of the current detection. 
+ for (const auto& retained_location : retained_locations) { + float similarity; + if (cc->Inputs().HasTag(kImageTag)) { + const auto& frame = cc->Inputs().Tag(kImageTag).Get(); + similarity = OverlapSimilarity(frame.Width(), frame.Height(), + options_.overlap_type(), + retained_location, location); + } else { + similarity = OverlapSimilarity(options_.overlap_type(), + retained_location, location); + } + if (similarity > options_.min_suppression_threshold()) { + suppressed = true; + break; + } + } + if (!suppressed) { + output_detections->push_back(detection); + retained_locations.push_back(location); + } + if (output_detections->size() >= max_num_detections) { + break; + } + } + } + + void WeightedNonMaxSuppression(const IndexedScores& indexed_scores, + const Detections& detections, + int max_num_detections, CalculatorContext* cc, + Detections* output_detections) { + IndexedScores remained_indexed_scores; + remained_indexed_scores.assign(indexed_scores.begin(), + indexed_scores.end()); + + IndexedScores remained; + IndexedScores candidates; + output_detections->clear(); + while (!remained_indexed_scores.empty()) { + const auto& detection = detections[remained_indexed_scores[0].first]; + if (options_.min_score_threshold() > 0 && + detection.score(0) < options_.min_score_threshold()) { + break; + } + + remained.clear(); + candidates.clear(); + const Location location(detection.location_data()); + // This includes the first box. 
+ for (const auto& indexed_score : remained_indexed_scores) { + Location rest_location(detections[indexed_score.first].location_data()); + float similarity = + OverlapSimilarity(options_.overlap_type(), rest_location, location); + if (similarity > options_.min_suppression_threshold()) { + candidates.push_back(indexed_score); + } else { + remained.push_back(indexed_score); + } + } + auto weighted_detection = detection; + if (!candidates.empty()) { + float w_xmin = 0.0f; + float w_ymin = 0.0f; + float w_xmax = 0.0f; + float w_ymax = 0.0f; + float total_score = 0.0f; + for (const auto& candidate : candidates) { + total_score += candidate.second; + const auto& bbox = detections[candidate.first] + .location_data() + .relative_bounding_box(); + w_xmin += bbox.xmin() * candidate.second; + w_ymin += bbox.ymin() * candidate.second; + w_xmax += (bbox.xmin() + bbox.width()) * candidate.second; + w_ymax += (bbox.ymin() + bbox.height()) * candidate.second; + } + auto* weighted_location = weighted_detection.mutable_location_data() + ->mutable_relative_bounding_box(); + weighted_location->set_xmin(w_xmin / total_score); + weighted_location->set_ymin(w_ymin / total_score); + weighted_location->set_width((w_xmax / total_score) - + weighted_location->xmin()); + weighted_location->set_height((w_ymax / total_score) - + weighted_location->ymin()); + } + remained_indexed_scores = std::move(remained); + output_detections->push_back(weighted_detection); + } + } + + NonMaxSuppressionCalculatorOptions options_; +}; +REGISTER_CALCULATOR(NonMaxSuppressionCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/non_max_suppression_calculator.proto b/mediapipe/calculators/util/non_max_suppression_calculator.proto new file mode 100644 index 000000000..5fa960497 --- /dev/null +++ b/mediapipe/calculators/util/non_max_suppression_calculator.proto @@ -0,0 +1,68 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Options to NonMaxSuppression calculator, which performs non-maximum +// suppression on a set of detections. +message NonMaxSuppressionCalculatorOptions { + extend CalculatorOptions { + optional NonMaxSuppressionCalculatorOptions ext = 55383100; + } + + // Number of input streams. Each input stream should contain a vector of + // detections. + optional int32 num_detection_streams = 1 [default = 1]; + + // Maximum number of detections to be returned. If -1, then all detections are + // returned. + optional int32 max_num_detections = 2 [default = -1]; + + // Minimum score of detections to be returned. + optional float min_score_threshold = 6 [default = -1.0]; + + // Jaccard similarity threshold for suppression -- a detection would suppress + // all other detections whose scores are lower and overlap by at least the + // specified threshold. + optional float min_suppression_threshold = 3 [default = 1.0]; + + // During the overlap computation, which is used to determine whether a + // rectangle suppresses another rectangle, one can use the Jaccard similarity, + // defined as the ration of the intersection over union of the two rectangles. + // Alternatively a modified version of Jaccard can be used, where the + // normalization is done by the area of the rectangle being checked for + // suppression. 
+ enum OverlapType { + UNSPECIFIED_OVERLAP_TYPE = 0; + JACCARD = 1; + MODIFIED_JACCARD = 2; + INTERSECTION_OVER_UNION = 3; + } + optional OverlapType overlap_type = 4 [default = JACCARD]; + + // Whether to put empty detection vector in output stream. + optional bool return_empty_detections = 5; + + // Algorithms that can be used to apply non-maximum suppression. + enum NmsAlgorithm { + DEFAULT = 0; + // Only supports relative bounding box for weighted NMS. + WEIGHTED = 1; + } + optional NmsAlgorithm algorithm = 7 [default = DEFAULT]; +} diff --git a/mediapipe/calculators/util/packet_frequency.proto b/mediapipe/calculators/util/packet_frequency.proto new file mode 100644 index 000000000..177a73b12 --- /dev/null +++ b/mediapipe/calculators/util/packet_frequency.proto @@ -0,0 +1,13 @@ +syntax = "proto2"; + +package mediapipe; + +// Contains the packet frequency information. +message PacketFrequency { + // Packet frequency (packets per second). + optional double packet_frequency_hz = 1; + + // A label that identifies what this packet frequency is for. Eg. "Gaze", + // "Gesture", etc. + optional string label = 2; +} diff --git a/mediapipe/calculators/util/packet_frequency_calculator.cc b/mediapipe/calculators/util/packet_frequency_calculator.cc new file mode 100644 index 000000000..f63c72fdc --- /dev/null +++ b/mediapipe/calculators/util/packet_frequency_calculator.cc @@ -0,0 +1,218 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/time/time.h" +#include "mediapipe/calculators/util/packet_frequency.pb.h" +#include "mediapipe/calculators/util/packet_frequency_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/timestamp.h" + +namespace { +constexpr int kSecondsToMicroseconds = 1000000; + +} // namespace + +namespace mediapipe { +// A MediaPipe calculator that computes the frequency (in Hertz) of incoming +// packet streams. The frequency of packets is computed over a time window +// that is configured in options. There must be one output stream corresponding +// to every input packet stream. The frequency is output as a PacketFrequency +// proto. +// +// NOTE: +// 1. For computing frequency, packet timestamps are used and not the wall +// timestamp. Hence, the calculator is best-suited for real-time applications. +// 2. When multiple input/output streams are present, the calculator must be +// used with an ImmediateInputStreamHandler. +// +// Example config: +// node { +// calculator: "PacketFrequencyCalculator" +// input_stream: "input_stream_0" +// input_stream: "input_stream_1" +// . +// . +// input_stream: "input_stream_N" +// output_stream: "packet_frequency_0" +// output_stream: "packet_frequency_1" +// . +// . +// output_stream: "packet_frequency_N" +// input_stream_handler { +// input_stream_handler: "ImmediateInputStreamHandler" +// } +// options { +// [soapbox.PacketFrequencyCalculatorOptions.ext] { +// time_window_sec: 3.0 +// label: "stream_name_0" +// label: "stream_name_1" +// . +// . 
+// label: "stream_name_N" +// } +// } +// } +class PacketFrequencyCalculator : public CalculatorBase { + public: + PacketFrequencyCalculator() {} + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + // Outputs the given framerate on the specified output stream as a + // PacketFrequency proto. + ::mediapipe::Status OutputPacketFrequency(CalculatorContext* cc, + int stream_id, double framerate_hz, + const std::string& label, + const Timestamp& input_timestamp); + + // Adds the input timestamp in the particular stream's timestamp buffer. + ::mediapipe::Status AddPacketTimestampForStream(int stream_id, + int64 timestamp); + + // For the specified input stream, clears timestamps from buffer that are + // older than the configured time_window_sec. + ::mediapipe::Status ClearOldpacketTimestamps(int stream_id, + int64 current_timestamp); + + // Options for the calculator. + PacketFrequencyCalculatorOptions options_; + + // Map where key is the input stream ID and value is the timestamp of the + // first packet received on that stream. + std::map first_timestamp_for_stream_id_usec_; + + // Map where key is the input stream ID and value is a vector that stores + // timestamps of recently received packets on the stream. Timestamps older + // than the time_window_sec are continuously deleted for all the streams. 
+ std::map> previous_timestamps_for_stream_id_; +}; +REGISTER_CALCULATOR(PacketFrequencyCalculator); + +::mediapipe::Status PacketFrequencyCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK_EQ(cc->Outputs().NumEntries(), cc->Inputs().NumEntries()); + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).SetAny(); + cc->Outputs().Index(i).Set(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketFrequencyCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + RET_CHECK_EQ(options_.label_size(), cc->Inputs().NumEntries()); + RET_CHECK_GT(options_.time_window_sec(), 0); + RET_CHECK_LE(options_.time_window_sec(), 100); + + // Initialize the stream-related data structures. + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + RET_CHECK(!options_.label(i).empty()); + previous_timestamps_for_stream_id_[i] = {}; + first_timestamp_for_stream_id_usec_[i] = -1; + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketFrequencyCalculator::Process(CalculatorContext* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + if (cc->Inputs().Index(i).IsEmpty()) { + continue; + } + RET_CHECK_OK(AddPacketTimestampForStream(/*stream_id=*/i, + cc->InputTimestamp().Value())); + RET_CHECK_OK(ClearOldpacketTimestamps(/*stream_id=*/i, + cc->InputTimestamp().Value())); + + if (first_timestamp_for_stream_id_usec_[i] < 0) { + first_timestamp_for_stream_id_usec_[i] = cc->InputTimestamp().Value(); + + // Since this is the very first packet on this stream, we don't have a + // window of time over which we can compute the packet frequency. So + // outputting packet frequency for this stream as 0 Hz. + return OutputPacketFrequency(cc, /*stream_id=*/i, /*framerate_hz=*/0.0, + options_.label(i), cc->InputTimestamp()); + } + + // If the time elapsed is less that the configured time window, then use + // that time duration instead, else use the configured time window. 
+ double time_window_usec = + std::min(static_cast(cc->InputTimestamp().Value() - + first_timestamp_for_stream_id_usec_[i]), + options_.time_window_sec() * kSecondsToMicroseconds); + + double framerate_hz = (previous_timestamps_for_stream_id_[i].size() * 1.0) / + (time_window_usec / kSecondsToMicroseconds); + + return OutputPacketFrequency(cc, /*stream_id=*/i, framerate_hz, + options_.label(i), cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketFrequencyCalculator::AddPacketTimestampForStream( + int stream_id, int64 timestamp_usec) { + if (previous_timestamps_for_stream_id_.find(stream_id) == + previous_timestamps_for_stream_id_.end()) { + return ::mediapipe::InvalidArgumentError("Input stream id is invalid"); + } + + previous_timestamps_for_stream_id_[stream_id].push_back(timestamp_usec); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketFrequencyCalculator::ClearOldpacketTimestamps( + int stream_id, int64 current_timestamp_usec) { + if (previous_timestamps_for_stream_id_.find(stream_id) == + previous_timestamps_for_stream_id_.end()) { + return ::mediapipe::InvalidArgumentError("Input stream id is invalid"); + } + + auto& timestamps_buffer = previous_timestamps_for_stream_id_[stream_id]; + int64 time_window_usec = options_.time_window_sec() * kSecondsToMicroseconds; + + timestamps_buffer.erase( + std::remove_if(timestamps_buffer.begin(), timestamps_buffer.end(), + [&time_window_usec, + ¤t_timestamp_usec](const int64 timestamp_usec) { + return current_timestamp_usec - timestamp_usec > + time_window_usec; + }), + timestamps_buffer.end()); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketFrequencyCalculator::OutputPacketFrequency( + CalculatorContext* cc, int stream_id, double framerate_hz, + const std::string& label, const Timestamp& input_timestamp) { + auto packet_frequency = absl::make_unique(); + packet_frequency->set_packet_frequency_hz(framerate_hz); + 
packet_frequency->set_label(label); + + cc->Outputs().Index(stream_id).Add(packet_frequency.release(), + input_timestamp); + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/packet_frequency_calculator.proto b/mediapipe/calculators/util/packet_frequency_calculator.proto new file mode 100644 index 000000000..e7be1c420 --- /dev/null +++ b/mediapipe/calculators/util/packet_frequency_calculator.proto @@ -0,0 +1,34 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Options for PacketFrequencyCalculator. +message PacketFrequencyCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional PacketFrequencyCalculatorOptions ext = 168468918; + } + + // Time window (in seconds) over which the packet frequency is computed. Must + // be greater than 0 and less than 100 seconds (in order to limit memory + // usage). + optional double time_window_sec = 1 [default = 3.0]; + + // Text identifiers for the input streams. 
+ repeated string label = 2; +} diff --git a/mediapipe/calculators/util/packet_frequency_calculator_test.cc b/mediapipe/calculators/util/packet_frequency_calculator_test.cc new file mode 100644 index 000000000..7e2bfa706 --- /dev/null +++ b/mediapipe/calculators/util/packet_frequency_calculator_test.cc @@ -0,0 +1,196 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/util/packet_frequency.pb.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { +namespace { + +CalculatorGraphConfig::Node GetDefaultNode() { + return ParseTextProtoOrDie(R"( + calculator: "PacketFrequencyCalculator" + input_stream: "packet_stream" + output_stream: "packet_frequency" + options { + [mediapipe.PacketFrequencyCalculatorOptions.ext] { + time_window_sec: 3.0 + label: "stream_description" + } + } + )"); +} + +CalculatorGraphConfig::Node GetNodeWithMultipleStreams() { + return ParseTextProtoOrDie(R"( + calculator: "PacketFrequencyCalculator" + input_stream: "packet_stream_0" + input_stream: "packet_stream_1" + input_stream: "packet_stream_2" + output_stream: "packet_frequency_0" + output_stream: "packet_frequency_1" + output_stream: 
"packet_frequency_2" + input_stream_handler { input_stream_handler: "ImmediateInputStreamHandler" } + options { + [mediapipe.PacketFrequencyCalculatorOptions.ext] { + time_window_sec: 3.0 + label: "stream_description_0" + label: "stream_description_1" + label: "stream_description_2" + } + } + )"); +} + +// Tests packet frequency. +TEST(PacketFrequencyCalculatorTest, MultiPacketTest) { + // Setup the calculator runner and provide integer packets as input (note that + // it doesn't have to be integer; the calculator can take any type as input). + CalculatorRunner runner(GetDefaultNode()); + + // Packet 1. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int).At(Timestamp(0))); + // Packet 2. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int).At(Timestamp(500000))); + // Packet 3. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int).At(Timestamp(1000000))); + // Packet 4. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int).At(Timestamp(1500000))); + // Packet 5. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int).At(Timestamp(3000000))); + // Packet 6. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int).At(Timestamp(4000000))); + // Packet 7. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int).At(Timestamp(9000000))); + + // Run the calculator. + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output_packets = runner.Outputs().Index(0).packets; + + // Very first packet. So frequency is zero. + const auto& output1 = output_packets[0].Get(); + EXPECT_FLOAT_EQ(output1.packet_frequency_hz(), 0.0); + EXPECT_EQ(output1.label(), "stream_description"); + + // 2 packets in the first 500ms. + const auto& output2 = output_packets[1].Get(); + EXPECT_FLOAT_EQ(output2.packet_frequency_hz(), 4.000000); + EXPECT_EQ(output2.label(), "stream_description"); + + // 3 packets in the first 1 sec. 
+ const auto& output3 = output_packets[2].Get(); + EXPECT_FLOAT_EQ(output3.packet_frequency_hz(), 3.000000); + EXPECT_EQ(output3.label(), "stream_description"); + + // 4 packets in the first 1.5 sec. + const auto& output4 = output_packets[3].Get(); + EXPECT_FLOAT_EQ(output4.packet_frequency_hz(), 2.666667); + EXPECT_EQ(output4.label(), "stream_description"); + + // 5 packets in the first 3 sec. + const auto& output5 = output_packets[4].Get(); + EXPECT_FLOAT_EQ(output5.packet_frequency_hz(), 1.666667); + EXPECT_EQ(output5.label(), "stream_description"); + + // 4 packets in the past 3 sec window. + const auto& output6 = output_packets[5].Get(); + EXPECT_FLOAT_EQ(output6.packet_frequency_hz(), 1.333333); + EXPECT_EQ(output6.label(), "stream_description"); + + // 1 packet in the past 3 sec window. + const auto& output7 = output_packets[6].Get(); + EXPECT_FLOAT_EQ(output7.packet_frequency_hz(), 0.33333334); + EXPECT_EQ(output7.label(), "stream_description"); +} + +// Tests packet frequency with multiple input/output streams. +TEST(PacketFrequencyCalculatorTest, MultiStreamTest) { + // Setup the calculator runner and provide strings as input on all streams + // (note that it doesn't have to be std::string; the calculator can take any + // type as input). + CalculatorRunner runner(GetNodeWithMultipleStreams()); + + // Packet 1 on stream 1. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new std::string).At(Timestamp(250000))); + // Packet 2 on stream 1. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new std::string).At(Timestamp(500000))); + // Packet 1 on stream 2. + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(new std::string).At(Timestamp(100000))); + // Packet 2 on stream 2. + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(new std::string).At(Timestamp(5000000))); + // Packet 1 on stream 3. 
+ runner.MutableInputs()->Index(2).packets.push_back( + Adopt(new std::string).At(Timestamp(0))); + // Packet 2 on stream 3. + runner.MutableInputs()->Index(2).packets.push_back( + Adopt(new std::string).At(Timestamp(3000000))); + + // Run the calculator. + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output_packets_stream_1 = + runner.Outputs().Index(0).packets; + const std::vector& output_packets_stream_2 = + runner.Outputs().Index(1).packets; + const std::vector& output_packets_stream_3 = + runner.Outputs().Index(2).packets; + + // First packet on stream 1. So frequency is zero. + const auto& output1 = output_packets_stream_1[0].Get(); + EXPECT_FLOAT_EQ(output1.packet_frequency_hz(), 0.0); + EXPECT_EQ(output1.label(), "stream_description_0"); + + // Second packet on stream 1. + const auto& output2 = output_packets_stream_1[1].Get(); + EXPECT_FLOAT_EQ(output2.packet_frequency_hz(), 8.000000); + EXPECT_EQ(output2.label(), "stream_description_0"); + + // First packet on stream 2. So frequency is zero. + const auto& output3 = output_packets_stream_2[0].Get(); + EXPECT_FLOAT_EQ(output3.packet_frequency_hz(), 0.0); + EXPECT_EQ(output3.label(), "stream_description_1"); + + // Second packet on stream 2. + const auto& output4 = output_packets_stream_2[1].Get(); + EXPECT_FLOAT_EQ(output4.packet_frequency_hz(), 0.33333334); + EXPECT_EQ(output4.label(), "stream_description_1"); + + // First packet on stream 3. So frequency is zero. + const auto& output5 = output_packets_stream_3[0].Get(); + EXPECT_FLOAT_EQ(output5.packet_frequency_hz(), 0.0); + EXPECT_EQ(output5.label(), "stream_description_2"); + + // Second packet on stream 3. 
+ const auto& output6 = output_packets_stream_3[1].Get(); + EXPECT_FLOAT_EQ(output6.packet_frequency_hz(), 0.66666669); + EXPECT_EQ(output6.label(), "stream_description_2"); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/util/packet_latency_calculator.cc b/mediapipe/calculators/util/packet_latency_calculator.cc new file mode 100644 index 000000000..162cc9356 --- /dev/null +++ b/mediapipe/calculators/util/packet_latency_calculator.cc @@ -0,0 +1,299 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "mediapipe/calculators/util/latency.pb.h" +#include "mediapipe/calculators/util/packet_latency_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/deps/clock.h" +#include "mediapipe/framework/deps/monotonic_clock.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { + +namespace { + +// Tag name for clock side packet. +constexpr char kClockTag[] = "CLOCK"; + +// Tag name for reference signal. 
+constexpr char kReferenceSignalTag[] = "REFERENCE_SIGNAL"; +} // namespace + +// A MediaPipe calculator that computes latency of incoming packet streams with +// respect to a reference signal (e.g image, audio frames). +// +// The latency of a packet wrt a reference packet is defined as the difference +// between arrival times of the two. A latency of X microseconds implies that +// the packet arrived X microseconds after its corresponding reference packet. +// For each packet stream, the calculator outputs the current latency, average, +// and a histogram of observed latencies so far. +// +// NOTE: +// 1) This calculator is meant to be used ONLY with an +// ImmediateInputStreamHandler. +// 2) This calculator is meant to be used only for real-time or simulated real- +// time applications. For example, the reference signal could be audio/video +// frames coming from a calculator that reads microphone/webcam data or some +// calculator that simulates real-time input. +// 3) If the packet labels are provided through options, then the number of +// labels should be exactly same as number of output_streams. If no packet +// label is defined in the node options, the calculator uses the input stream +// names. +// +// InputSidePacket (Optional): +// CLOCK: A clock for knowing current time. +// +// Inputs: +// 0- Packet stream 0 (e.g image feature 0): +// 1- Packet stream 1 (e.g image features 1): +// ... +// N- Packet stream N (e.g image features N): +// REFERENCE_SIGNAL: The reference signal from which the above packets were +// extracted (e.g image frames). +// +// Outputs: +// 0- Latency of packet stream 0. +// 1- Latency of packet stream 1. +// ... +// N- Latency of packet stream N. +// +// Example config: +// node { +// calculator: "PacketLatencyCalculator" +// input_side_packet: "monotonic_clock" +// input_stream: "packet_stream_0" +// input_stream: "packet_stream_1" +// ... 
+// input_stream: "packet_stream_N" +// input_stream: "REFERENCE_SIGNAL:camera_frames" +// output_stream: "packet_latency_0" +// output_stream: "packet_latency_1" +// ... +// output_stream: "packet_latency_N" +// options { +// [soapbox.PacketLatencyCalculatorOptions.ext] { +// num_intervals: 10 +// interval_size_usec: 10000 +// } +// } +// input_stream_handler { +// input_stream_handler: 'ImmediateInputStreamHandler' +// } +// } +class PacketLatencyCalculator : public CalculatorBase { + public: + PacketLatencyCalculator() {} + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + // Resets the histogram and running average variables by initializing them to + // zero. + void ResetStatistics(); + + // Calculator options. + PacketLatencyCalculatorOptions options_; + + // Clock object. + std::shared_ptr<::mediapipe::Clock> clock_; + + // Clock time when the first reference packet was received. + int64 first_process_time_usec_ = -1; + + // Timestamp of the first reference packet received. + int64 first_reference_timestamp_usec_ = -1; + + // Number of packet streams. + int64 num_packet_streams_ = -1; + + // Latency output for each packet stream. + std::vector packet_latencies_; + + // Running sum and count of latencies for each packet stream. This is required + // to compute the average latency. + std::vector sum_latencies_usec_; + std::vector num_latencies_; + + // Clock time when last reset was done for histogram and running average. + int64 last_reset_time_usec_ = -1; +}; +REGISTER_CALCULATOR(PacketLatencyCalculator); + +::mediapipe::Status PacketLatencyCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK_GT(cc->Inputs().NumEntries(), 1); + + // Input and output streams. 
+ int64 num_packet_streams = cc->Inputs().NumEntries() - 1; + RET_CHECK_EQ(cc->Outputs().NumEntries(), num_packet_streams); + for (int64 i = 0; i < num_packet_streams; ++i) { + cc->Inputs().Index(i).SetAny(); + cc->Outputs().Index(i).Set(); + } + + // Reference signal. + cc->Inputs().Tag(kReferenceSignalTag).SetAny(); + + // Clock side packet. + if (cc->InputSidePackets().HasTag(kClockTag)) { + cc->InputSidePackets() + .Tag(kClockTag) + .Set>(); + } + + return ::mediapipe::OkStatus(); +} + +void PacketLatencyCalculator::ResetStatistics() { + // Initialize histogram with zero counts and set running average to zero. + for (int64 i = 0; i < num_packet_streams_; ++i) { + for (int64 interval_index = 0; interval_index < options_.num_intervals(); + ++interval_index) { + packet_latencies_[i].set_counts(interval_index, 0); + } + + // Initialize the running sum and count to 0. + sum_latencies_usec_[i] = 0; + num_latencies_[i] = 0; + } +} + +::mediapipe::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + num_packet_streams_ = cc->Inputs().NumEntries() - 1; + + // Check if provided labels are of correct size. + bool labels_provided = !options_.packet_labels().empty(); + if (labels_provided) { + RET_CHECK_EQ(options_.packet_labels_size(), num_packet_streams_) + << "Input packet stream count different from output stream count."; + } + + // Check that histogram params are valid. + RET_CHECK_GT(options_.num_intervals(), 0); + RET_CHECK_GT(options_.interval_size_usec(), 0); + + // Initialize latency outputs for all streams. + packet_latencies_.resize(num_packet_streams_); + sum_latencies_usec_.resize(num_packet_streams_); + num_latencies_.resize(num_packet_streams_); + for (int64 i = 0; i < num_packet_streams_; ++i) { + // Initialize latency histograms with zero counts. 
+ packet_latencies_[i].set_num_intervals(options_.num_intervals()); + packet_latencies_[i].set_interval_size_usec(options_.interval_size_usec()); + packet_latencies_[i].mutable_counts()->Resize(options_.num_intervals(), 0); + + // Set the label for the stream. The packet labels are taken from options + // (if provided). If not, default labels are created using the input/output + // stream names. + if (labels_provided) { + packet_latencies_[i].set_label(options_.packet_labels(i)); + } else { + int64 input_stream_index = cc->Inputs().TagMap()->GetId("", i).value(); + packet_latencies_[i].set_label( + cc->Inputs().TagMap()->Names()[input_stream_index]); + } + } + + // Initialize the clock. + if (cc->InputSidePackets().HasTag(kClockTag)) { + clock_ = cc->InputSidePackets() + .Tag("CLOCK") + .Get>(); + } else { + clock_ = std::shared_ptr<::mediapipe::Clock>( + ::mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { + // Record first process timestamp if this is the first call. + if (first_process_time_usec_ < 0 && + !cc->Inputs().Tag(kReferenceSignalTag).IsEmpty()) { + first_process_time_usec_ = absl::ToUnixMicros(clock_->TimeNow()); + first_reference_timestamp_usec_ = cc->InputTimestamp().Value(); + last_reset_time_usec_ = first_process_time_usec_; + } + + if (first_process_time_usec_ < 0) { + LOG(WARNING) << "No reference packet received."; + return ::mediapipe::OkStatus(); + } + + if (options_.reset_duration_usec() > 0) { + const int64 time_now_usec = absl::ToUnixMicros(clock_->TimeNow()); + if (time_now_usec - last_reset_time_usec_ >= + options_.reset_duration_usec()) { + ResetStatistics(); + last_reset_time_usec_ = time_now_usec; + } + } + + // Update latency info if there is any incoming packet. 
+ for (int64 i = 0; i < num_packet_streams_; ++i) { + if (!cc->Inputs().Index(i).IsEmpty()) { + const auto& packet_timestamp_usec = cc->InputTimestamp().Value(); + + // Update latency statistics for this stream. + int64 current_clock_time_usec = absl::ToUnixMicros(clock_->TimeNow()); + int64 current_calibrated_timestamp_usec = + (current_clock_time_usec - first_process_time_usec_) + + first_reference_timestamp_usec_; + int64 packet_latency_usec = + current_calibrated_timestamp_usec - packet_timestamp_usec; + + // Invalid timestamps in input signals could result in negative latencies. + if (packet_latency_usec < 0) { + continue; + } + + // Update the latency, running average and histogram for this stream. + packet_latencies_[i].set_current_latency_usec(packet_latency_usec); + int64 interval_index = + packet_latency_usec / packet_latencies_[i].interval_size_usec(); + if (interval_index >= packet_latencies_[i].num_intervals()) { + interval_index = packet_latencies_[i].num_intervals() - 1; + } + packet_latencies_[i].set_counts( + interval_index, packet_latencies_[i].counts(interval_index) + 1); + sum_latencies_usec_[i] += packet_latency_usec; + num_latencies_[i] += 1; + packet_latencies_[i].set_avg_latency_usec(sum_latencies_usec_[i] / + num_latencies_[i]); + + packet_latencies_[i].set_sum_latency_usec(sum_latencies_usec_[i]); + + // Push the latency packet to output. + auto packet_latency = + absl::make_unique(packet_latencies_[i]); + cc->Outputs().Index(i).Add(packet_latency.release(), + cc->InputTimestamp()); + } + } + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/packet_latency_calculator.proto b/mediapipe/calculators/util/packet_latency_calculator.proto new file mode 100644 index 000000000..63ec5f989 --- /dev/null +++ b/mediapipe/calculators/util/packet_latency_calculator.proto @@ -0,0 +1,41 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message PacketLatencyCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional PacketLatencyCalculatorOptions ext = 172681421; + } + + // Number of intervals for the latency histogram output. + optional int64 num_intervals = 1 [default = 10]; + + // Interval size (in microseconds) for the histogram. + optional int64 interval_size_usec = 2 [default = 10000]; + + // Reset time (in microseconds) for histogram and average. The histogram and + // running average are initialized to zero periodically based on the specified + // duration. Negative value implies never resetting the statistics. + optional int64 reset_duration_usec = 3 [default = -1]; + + // Identifier labels for each input packet stream. The order of labels must + // correspond 1:1 with the input streams order. The labels are copied to the + // latency information output by the calculator. + repeated string packet_labels = 4; +} diff --git a/mediapipe/calculators/util/packet_latency_calculator_test.cc b/mediapipe/calculators/util/packet_latency_calculator_test.cc new file mode 100644 index 000000000..8c32d7dcc --- /dev/null +++ b/mediapipe/calculators/util/packet_latency_calculator_test.cc @@ -0,0 +1,488 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/time/time.h" +#include "mediapipe/calculators/util/latency.pb.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/clock.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/simulation_clock_executor.h" +#include "mediapipe/framework/tool/sink.h" + +namespace mediapipe { + +namespace { + +class PacketLatencyCalculatorTest : public ::testing::Test { + protected: + void SetupSimulationClock() { + auto executor = std::make_shared(4); + simulation_clock_ = executor->GetClock(); + MEDIAPIPE_ASSERT_OK(graph_.SetExecutor("", executor)); + } + + void InitializeSingleStreamGraph() { + graph_config_ = ParseTextProtoOrDie(R"( + input_stream: "delayed_packet_0" + input_stream: "camera_frames" + node { + calculator: "PacketLatencyCalculator" + input_side_packet: "CLOCK:clock" + input_stream: "delayed_packet_0" + input_stream: "REFERENCE_SIGNAL:camera_frames" + output_stream: "packet_latency_0" + options { + [mediapipe.PacketLatencyCalculatorOptions.ext] { + num_intervals: 3 + interval_size_usec: 4 + reset_duration_usec: 100 + packet_labels: "dummy input 0" + } + } + input_stream_handler 
{ + input_stream_handler: "ImmediateInputStreamHandler" + } + } + )"); + + mediapipe::tool::AddVectorSink("packet_latency_0", &graph_config_, + &out_0_packets_); + + // Create the simulation clock side packet. + SetupSimulationClock(); + std::map side_packet; + side_packet["clock"] = + ::mediapipe::MakePacket>( + simulation_clock_); + + // Start graph run. + MEDIAPIPE_ASSERT_OK(graph_.Initialize(graph_config_, {})); + MEDIAPIPE_ASSERT_OK(graph_.StartRun(side_packet)); + // Let Calculator::Open() calls finish before continuing. + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilIdle()); + } + + void InitializeMultipleStreamGraph() { + graph_config_ = ParseTextProtoOrDie(R"( + input_stream: "delayed_packet_0" + input_stream: "delayed_packet_1" + input_stream: "delayed_packet_2" + input_stream: "camera_frames" + node { + calculator: "PacketLatencyCalculator" + input_side_packet: "CLOCK:clock" + input_stream: "delayed_packet_0" + input_stream: "delayed_packet_1" + input_stream: "delayed_packet_2" + input_stream: "REFERENCE_SIGNAL:camera_frames" + output_stream: "packet_latency_0" + output_stream: "packet_latency_1" + output_stream: "packet_latency_2" + options { + [mediapipe.PacketLatencyCalculatorOptions.ext] { + num_intervals: 3 + interval_size_usec: 4 + packet_labels: "dummy input 0" + packet_labels: "dummy input 1" + packet_labels: "dummy input 2" + } + } + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + } + )"); + + mediapipe::tool::AddVectorSink("packet_latency_0", &graph_config_, + &out_0_packets_); + mediapipe::tool::AddVectorSink("packet_latency_1", &graph_config_, + &out_1_packets_); + mediapipe::tool::AddVectorSink("packet_latency_2", &graph_config_, + &out_2_packets_); + MEDIAPIPE_ASSERT_OK(graph_.Initialize(graph_config_, {})); + + // Create the simulation clock side packet. 
+ simulation_clock_.reset(new SimulationClock()); + std::map side_packet; + side_packet["clock"] = + ::mediapipe::MakePacket>( + simulation_clock_); + + // Start graph run. + MEDIAPIPE_ASSERT_OK(graph_.StartRun(side_packet)); + // Let Calculator::Open() calls finish before continuing. + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilIdle()); + } + + void InitializeSingleStreamGraphWithoutClock() { + graph_config_ = ParseTextProtoOrDie(R"( + input_stream: "delayed_packet_0" + input_stream: "camera_frames" + node { + calculator: "PacketLatencyCalculator" + input_stream: "delayed_packet_0" + input_stream: "REFERENCE_SIGNAL:camera_frames" + output_stream: "packet_latency_0" + options { + [mediapipe.PacketLatencyCalculatorOptions.ext] { + num_intervals: 3 + interval_size_usec: 4 + packet_labels: "dummy input 0" + } + } + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + } + )"); + + mediapipe::tool::AddVectorSink("packet_latency_0", &graph_config_, + &out_0_packets_); + + // Create the simulation clock side packet. + SetupSimulationClock(); + std::map side_packet; + side_packet["clock"] = + ::mediapipe::MakePacket>( + simulation_clock_); + + // Start graph run. + MEDIAPIPE_ASSERT_OK(graph_.Initialize(graph_config_, {})); + MEDIAPIPE_ASSERT_OK(graph_.StartRun(side_packet)); + // Let Calculator::Open() calls finish before continuing. 
+ MEDIAPIPE_ASSERT_OK(graph_.WaitUntilIdle()); + } + + PacketLatency CreatePacketLatency(const double latency_usec, + const int64 num_intervals, + const int64 interval_size_usec, + const std::vector& counts, + const int64 avg_latency_usec, + const std::string& label) { + PacketLatency latency_info; + latency_info.set_current_latency_usec(latency_usec); + latency_info.set_num_intervals(num_intervals); + latency_info.set_interval_size_usec(interval_size_usec); + int sum_counts = 0; + for (const int& count : counts) { + latency_info.add_counts(count); + sum_counts += count; + } + latency_info.set_avg_latency_usec(avg_latency_usec); + latency_info.set_sum_latency_usec(avg_latency_usec * sum_counts); + latency_info.set_label(label); + return latency_info; + } + + std::shared_ptr<::mediapipe::Clock> simulation_clock_; + CalculatorGraphConfig graph_config_; + CalculatorGraph graph_; + std::vector out_0_packets_; + std::vector out_1_packets_; + std::vector out_2_packets_; +}; + +// Calculator must not output any latency until input packets are received. +TEST_F(PacketLatencyCalculatorTest, DoesNotOutputUntilInputPacketReceived) { + // Initialize graph_. + InitializeSingleStreamGraph(); + dynamic_cast(&*simulation_clock_)->ThreadStart(); + + // Send reference packets with timestamps 0, 6 and 10 usec. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(0)))); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(6)))); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(10)))); + + dynamic_cast(&*simulation_clock_)->ThreadFinish(); + MEDIAPIPE_ASSERT_OK(graph_.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilDone()); + + // Expect zero output packets. + ASSERT_EQ(out_0_packets_.size(), 0); +} + +// Calculator must output correct latency for single stream. 
+TEST_F(PacketLatencyCalculatorTest, OutputsCorrectLatencyForSingleStream) { + // Initialize graph_. + InitializeSingleStreamGraph(); + dynamic_cast(&*simulation_clock_)->ThreadStart(); + + // Send a reference packet with timestamp 10 usec at time 12 usec. + simulation_clock_->Sleep(absl::Microseconds(12)); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(10)))); + + // Add two delayed packets with timestamp 1 and 8 resp. + simulation_clock_->Sleep(absl::Microseconds(1)); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(1)))); + simulation_clock_->Sleep(absl::Microseconds(1)); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(8)))); + + dynamic_cast(&*simulation_clock_)->ThreadFinish(); + MEDIAPIPE_ASSERT_OK(graph_.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilDone()); + + // Expect two latency packets with timestamp 1 and 8 resp. + ASSERT_EQ(out_0_packets_.size(), 2); + EXPECT_EQ(out_0_packets_[0].Timestamp().Value(), 1); + EXPECT_EQ(out_0_packets_[1].Timestamp().Value(), 8); + + EXPECT_THAT( + out_0_packets_[0].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/10, + /*num_intervals=*/3, /*interval_size_usec=*/4, + /*counts=*/{0, 0, 1}, /*avg_latency_usec=*/10, "dummy input 0"))); + + EXPECT_THAT( + out_0_packets_[1].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/4, + /*num_intervals=*/3, /*interval_size_usec=*/4, + /*counts=*/{0, 1, 1}, /*avg_latency_usec=*/7, "dummy input 0"))); +} + +// Calculator must not output latency until reference signal is received. +TEST_F(PacketLatencyCalculatorTest, DoesNotOutputUntilReferencePacketReceived) { + // Initialize graph_. + InitializeSingleStreamGraph(); + dynamic_cast(&*simulation_clock_)->ThreadStart(); + + // Add two packets with timestamp 1 and 2. 
+ MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(1)))); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(2)))); + + // Send a reference packet with timestamp 10 usec. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(10)))); + simulation_clock_->Sleep(absl::Microseconds(1)); + + // Add two delayed packets with timestamp 7 and 9 resp. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(7)))); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(9)))); + simulation_clock_->Sleep(absl::Microseconds(1)); + + dynamic_cast(&*simulation_clock_)->ThreadFinish(); + MEDIAPIPE_ASSERT_OK(graph_.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilDone()); + + // Expect two latency packets with timestamp 7 and 9 resp. The packets with + // timestamps 1 and 2 should not have any latency associated with them since + // reference signal was not sent until then. + ASSERT_EQ(out_0_packets_.size(), 2); + EXPECT_EQ(out_0_packets_[0].Timestamp().Value(), 7); + EXPECT_EQ(out_0_packets_[1].Timestamp().Value(), 9); + + EXPECT_THAT( + out_0_packets_[0].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/4, + /*num_intervals=*/3, /*interval_size_usec=*/4, + /*counts=*/{0, 1, 0}, /*avg_latency_usec=*/4, "dummy input 0"))); + + EXPECT_THAT( + out_0_packets_[1].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/2, /*num_intervals=*/3, + /*interval_size_usec=*/4, + /*counts=*/{1, 1, 0}, /*avg_latency_usec=*/3, "dummy input 0"))); +} + +// Calculator outputs latency even when a clock is not provided. +TEST_F(PacketLatencyCalculatorTest, OutputsCorrectLatencyWhenNoClock) { + // Initialize graph_. 
+ InitializeSingleStreamGraphWithoutClock(); + dynamic_cast(&*simulation_clock_)->ThreadStart(); + + // Send a reference packet with timestamp 10 usec. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(10)))); + + // Add two delayed packets with timestamp 5 and 10 resp. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(5)))); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(10)))); + + dynamic_cast(&*simulation_clock_)->ThreadFinish(); + MEDIAPIPE_ASSERT_OK(graph_.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilDone()); + + // Expect two latency packets with timestamp 5 and 10 resp. + ASSERT_EQ(out_0_packets_.size(), 2); + EXPECT_EQ(out_0_packets_[0].Timestamp().Value(), 5); + EXPECT_EQ(out_0_packets_[1].Timestamp().Value(), 10); +} + +// Calculator must output correct histograms counts for the corner bins. +TEST_F(PacketLatencyCalculatorTest, + OutputsCorrectLatencyStatisticsInTimeWindow) { + // Initialize graph_. + InitializeSingleStreamGraph(); + dynamic_cast(&*simulation_clock_)->ThreadStart(); + + // Send a reference packet with timestamp 20 usec. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(20)))); + + // Add two delayed packets with timestamp 0 and 20 resp. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(0)))); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(20)))); + + dynamic_cast(&*simulation_clock_)->ThreadFinish(); + MEDIAPIPE_ASSERT_OK(graph_.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilDone()); + + // Expect two latency packets with timestamp 0 and 20 resp. 
+ ASSERT_EQ(out_0_packets_.size(), 2); + EXPECT_EQ(out_0_packets_[0].Timestamp().Value(), 0); + EXPECT_EQ(out_0_packets_[1].Timestamp().Value(), 20); + + EXPECT_THAT( + out_0_packets_[0].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/20, /*num_intervals=*/3, + /*interval_size_usec=*/4, + /*counts=*/{0, 0, 1}, /*avg_latency_usec=*/20, "dummy input 0"))); + + EXPECT_THAT( + out_0_packets_[1].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/0, /*num_intervals=*/3, + /*interval_size_usec=*/4, + /*counts=*/{1, 0, 1}, /*avg_latency_usec=*/10, "dummy input 0"))); +} + +// Calculator must reset histogram and average after specified duration. +TEST_F(PacketLatencyCalculatorTest, ResetsHistogramAndAverageCorrectly) { + // Initialize graph_. + InitializeSingleStreamGraph(); + dynamic_cast(&*simulation_clock_)->ThreadStart(); + + // Send a reference packet with timestamp 0 usec. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(0)))); + + // Add a delayed packet with timestamp 0 usec at time 20 usec. + simulation_clock_->Sleep(absl::Microseconds(20)); + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(0)))); + + // Do a long sleep so that histogram and average are reset. + simulation_clock_->Sleep(absl::Microseconds(100)); + + // Add a delayed packet with timestamp 115 usec at time 120 usec. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(115)))); + + dynamic_cast(&*simulation_clock_)->ThreadFinish(); + MEDIAPIPE_ASSERT_OK(graph_.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilDone()); + + // Expect two latency packets with timestamp 0 and 115 resp. 
+ ASSERT_EQ(out_0_packets_.size(), 2); + EXPECT_EQ(out_0_packets_[0].Timestamp().Value(), 0); + EXPECT_EQ(out_0_packets_[1].Timestamp().Value(), 115); + + EXPECT_THAT( + out_0_packets_[0].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/20, /*num_intervals=*/3, + /*interval_size_usec=*/4, + /*counts=*/{0, 0, 1}, /*avg_latency_usec=*/20, "dummy input 0"))); + + // The new average and histogram should ignore the previous latency because + // reset has happened. + EXPECT_THAT( + out_0_packets_[1].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/5, /*num_intervals=*/3, + /*interval_size_usec=*/4, + /*counts=*/{0, 1, 0}, /*avg_latency_usec=*/5, "dummy input 0"))); +} + +// Calculator must output correct latency for multiple streams. +TEST_F(PacketLatencyCalculatorTest, OutputsCorrectLatencyForMultipleStreams) { + // Initialize graph. + InitializeMultipleStreamGraph(); + dynamic_cast(&*simulation_clock_)->ThreadStart(); + + // Send a reference packet with timestamp 10 usec. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "camera_frames", Adopt(new double()).At(Timestamp(10)))); + + // Add delayed packets on each input stream. + + // Fastest stream. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_0", Adopt(new double()).At(Timestamp(10)))); + + // Slow stream. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_1", Adopt(new double()).At(Timestamp(5)))); + + // Slowest stream. + MEDIAPIPE_ASSERT_OK(graph_.AddPacketToInputStream( + "delayed_packet_2", Adopt(new double()).At(Timestamp(0)))); + + dynamic_cast(&*simulation_clock_)->ThreadFinish(); + MEDIAPIPE_ASSERT_OK(graph_.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph_.WaitUntilDone()); + + // Expect one latency packet on each output stream. 
+ ASSERT_EQ(out_0_packets_.size(), 1); + ASSERT_EQ(out_1_packets_.size(), 1); + ASSERT_EQ(out_2_packets_.size(), 1); + EXPECT_EQ(out_0_packets_[0].Timestamp().Value(), 10); + EXPECT_EQ(out_1_packets_[0].Timestamp().Value(), 5); + EXPECT_EQ(out_2_packets_[0].Timestamp().Value(), 0); + + EXPECT_THAT( + out_0_packets_[0].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/0, /*num_intervals=*/3, + /*interval_size_usec=*/4, + /*counts=*/{1, 0, 0}, /*avg_latency_usec=*/0, "dummy input 0"))); + EXPECT_THAT( + out_1_packets_[0].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/5, /*num_intervals=*/3, + /*interval_size_usec=*/4, + /*counts=*/{0, 1, 0}, /*avg_latency_usec=*/5, "dummy input 1"))); + EXPECT_THAT( + out_2_packets_[0].Get(), + EqualsProto(CreatePacketLatency( + /*latency_usec=*/10, /*num_intervals=*/3, + /*interval_size_usec=*/4, + /*counts=*/{0, 0, 1}, /*avg_latency_usec=*/10, "dummy input 2"))); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD new file mode 100644 index 000000000..7546e5443 --- /dev/null +++ b/mediapipe/calculators/video/BUILD @@ -0,0 +1,157 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +proto_library( + name = "flow_to_image_calculator_proto", + srcs = ["flow_to_image_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "opencv_video_encoder_calculator_proto", + srcs = ["opencv_video_encoder_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "flow_to_image_calculator_cc_proto", + srcs = ["flow_to_image_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":flow_to_image_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "opencv_video_encoder_calculator_cc_proto", + srcs = ["opencv_video_encoder_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":opencv_video_encoder_calculator_proto"], +) + +cc_library( + name = "flow_to_image_calculator", + srcs = ["flow_to_image_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/calculators/video:flow_to_image_calculator_cc_proto", + "//mediapipe/calculators/video/tool:flow_quantizer_model", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats/motion:optical_flow_field", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) + +cc_library( + name = "opencv_video_decoder_calculator", + srcs = 
["opencv_video_decoder_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:opencv_video", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + ], + alwayslink = 1, +) + +cc_library( + name = "opencv_video_encoder_calculator", + srcs = ["opencv_video_encoder_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":opencv_video_encoder_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:opencv_video", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "opencv_video_decoder_calculator_test", + srcs = ["opencv_video_decoder_calculator_test.cc"], + data = ["//mediapipe/calculators/video/testdata:test_videos"], + deps = [ + ":opencv_video_decoder_calculator", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + 
"//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:parse_text_proto", + ], +) + +cc_test( + name = "opencv_video_encoder_calculator_test", + srcs = ["opencv_video_encoder_calculator_test.cc"], + data = ["//mediapipe/calculators/video/testdata:test_videos"], + deps = [ + ":opencv_video_decoder_calculator", + ":opencv_video_encoder_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:deleting_file", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:opencv_video", + "//mediapipe/framework/port:parse_text_proto", + ], +) diff --git a/mediapipe/calculators/video/flow_to_image_calculator.cc b/mediapipe/calculators/video/flow_to_image_calculator.cc new file mode 100644 index 000000000..d32319c6f --- /dev/null +++ b/mediapipe/calculators/video/flow_to_image_calculator.cc @@ -0,0 +1,114 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// MediaPipe calculator to take a flow field as input, and outputs a normalized +// RGB image where the B channel is forced to zero. +// TODO: Add video stream header for visualization + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/video/flow_to_image_calculator.pb.h" +#include "mediapipe/calculators/video/tool/flow_quantizer_model.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/motion/optical_flow_field.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe { + +// Reads optical flow fields defined in +// mediapipe/framework/formats/motion/optical_flow_field.h, +// returns a VideoFrame with 2 channels (v_x and v_y), each channel is quantized +// to 0-255. 
+// +// Example config: +// node { +// calculator: "FlowToImageCalculator" +// input_stream: "flow_fields" +// output_stream: "frames" +// options: { +// [type.googleapis.com/mediapipe.FlowToImageCalculatorOptions]:{ +// min_value: -40.0 +// max_value: 40.0 +// } +// } +// } +class FlowToImageCalculator : public CalculatorBase { + public: + FlowToImageCalculator() {} + ~FlowToImageCalculator() override {} + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + FlowQuantizerModel model_; +}; + +::mediapipe::Status FlowToImageCalculator::GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + + // Model sanity check + const auto& options = cc->Options(); + if (options.min_value() >= options.max_value()) { + return ::mediapipe::InvalidArgumentError("Invalid quantizer model."); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status FlowToImageCalculator::Open(CalculatorContext* cc) { + const auto& options = cc->Options(); + // Fill the the model_data, ideally we want to train the model, but we omit + // the step for now, and takes the (min, max) range from protobuf. + const QuantizerModelData& model_data = + ParseTextProtoOrDie( + absl::StrFormat("min_value:%f min_value:%f max_value:%f max_value:%f", + options.min_value(), options.min_value(), + options.max_value(), options.max_value())); + model_.LoadFromProto(model_data); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status FlowToImageCalculator::Process(CalculatorContext* cc) { + const auto& input = cc->Inputs().Index(0).Get(); + // Input flow is 2-channel with x-dim flow and y-dim flow. + // Convert it to a ImageFrame in SRGB space, the 3rd channel is not used (0). 
+ const cv::Mat_& flow = input.flow_data(); + std::unique_ptr output( + new ImageFrame(ImageFormat::SRGB, input.width(), input.height())); + cv::Mat image = ::mediapipe::formats::MatView(output.get()); + + for (int j = 0; j != input.height(); ++j) { + for (int i = 0; i != input.width(); ++i) { + image.at(j, i) = + cv::Vec3b(model_.Apply(flow.at(j, i).x, 0), + model_.Apply(flow.at(j, i).y, 1), 0); + } + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +REGISTER_CALCULATOR(FlowToImageCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/video/flow_to_image_calculator.proto b/mediapipe/calculators/video/flow_to_image_calculator.proto new file mode 100644 index 000000000..6d5fb8450 --- /dev/null +++ b/mediapipe/calculators/video/flow_to_image_calculator.proto @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Specifies the maximum and minimum value to truncate when normalize optical +// flow fields. 
+message FlowToImageCalculatorOptions { + extend CalculatorOptions { + optional FlowToImageCalculatorOptions ext = 69508592; + } + optional float min_value = 1 [default = -40.0]; + optional float max_value = 2 [default = 40.0]; +} diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc new file mode 100644 index 000000000..2b1f205c5 --- /dev/null +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc @@ -0,0 +1,184 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/opencv_video_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/status_util.h" + +namespace mediapipe { + +namespace { +// cv::VideoCapture set data type to unsigned char by default. Therefore, the +// image format is only related to the number of channles the cv::Mat has. 
+ImageFormat::Format GetImageFormat(int num_channels) { + ImageFormat::Format format; + switch (num_channels) { + case 1: + format = ImageFormat::GRAY8; + break; + case 3: + format = ImageFormat::SRGB; + break; + case 4: + format = ImageFormat::SRGBA; + break; + default: + format = ImageFormat::UNKNOWN; + break; + } + return format; +} +} // namespace + +// This Calculator takes no input streams and produces video packets. +// All streams and input side packets are specified using tags and all of them +// are optional. +// +// Output Streams: +// VIDEO: Output video frames (ImageFrame). +// VIDEO_PRESTREAM: +// Optional video header information output at +// Timestamp::PreStream() for the corresponding stream. +// Input Side Packets: +// INPUT_FILE_PATH: The input file path. +// +// Example config: +// node { +// calculator: "OpenCvVideoDecoderCalculator" +// input_side_packet: "INPUT_FILE_PATH:input_file_path" +// output_stream: "VIDEO:video_frames" +// output_stream: "VIDEO_PRESTREAM:video_header" +// } +class OpenCvVideoDecoderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set(); + cc->Outputs().Tag("VIDEO").Set(); + if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { + cc->Outputs().Tag("VIDEO_PRESTREAM").Set(); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + const std::string& input_file_path = + cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get(); + cap_ = absl::make_unique(input_file_path); + if (!cap_->isOpened()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Fail to open video file at " << input_file_path; + } + width_ = static_cast(cap_->get(cv::CAP_PROP_FRAME_WIDTH)); + height_ = static_cast(cap_->get(cv::CAP_PROP_FRAME_HEIGHT)); + double fps = static_cast(cap_->get(cv::CAP_PROP_FPS)); + frame_count_ = static_cast(cap_->get(cv::CAP_PROP_FRAME_COUNT)); + // 
Unfortunately, cap_->get(cv::CAP_PROP_FORMAT) always returns CV_8UC1 + // back. To get correct image format, we read the first frame from the video + // and get the number of channels. + cv::Mat frame; + cap_->read(frame); + if (frame.empty()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Fail to read any frames from the video file at " + << input_file_path; + } + format_ = GetImageFormat(frame.channels()); + if (format_ == ImageFormat::UNKNOWN) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Unsupported video format of the video file at " + << input_file_path; + } + + if (fps <= 0 || frame_count_ <= 0 || width_ <= 0 || height_ <= 0) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Fail to make video header due to the incorrect metadata from " + "the video file at " + << input_file_path; + } + auto header = absl::make_unique(); + header->format = format_; + header->width = width_; + header->height = height_; + header->frame_rate = fps; + header->duration = frame_count_ / fps; + + if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { + cc->Outputs() + .Tag("VIDEO_PRESTREAM") + .Add(header.release(), Timestamp::PreStream()); + } + // Rewind to the very first frame. + cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + auto image_frame = absl::make_unique(format_, width_, height_, + /*alignment_boundary=*/1); + // Use microsecond as the unit of time. 
+ Timestamp timestamp(cap_->get(cv::CAP_PROP_POS_MSEC) * 1000); + if (format_ == ImageFormat::GRAY8) { + cv::Mat frame = formats::MatView(image_frame.get()); + cap_->read(frame); + if (frame.empty()) { + return tool::StatusStop(); + } + } else { + cv::Mat tmp_frame; + cap_->read(tmp_frame); + if (tmp_frame.empty()) { + return tool::StatusStop(); + } + if (format_ == ImageFormat::SRGB) { + cv::cvtColor(tmp_frame, formats::MatView(image_frame.get()), + cv::COLOR_BGR2RGB); + } else if (format_ == ImageFormat::SRGBA) { + cv::cvtColor(tmp_frame, formats::MatView(image_frame.get()), + cv::COLOR_BGRA2RGBA); + } + } + cc->Outputs().Tag("VIDEO").Add(image_frame.release(), timestamp); + decoded_frames_++; + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Close(CalculatorContext* cc) override { + if (cap_ && cap_->isOpened()) { + cap_->release(); + } + if (decoded_frames_ != frame_count_) { + LOG(WARNING) << "Not all the frames are decoded (total frames: " + << frame_count_ << " vs decoded frames: " << decoded_frames_ + << ")."; + } + return ::mediapipe::OkStatus(); + } + + private: + std::unique_ptr cap_; + int width_; + int height_; + int frame_count_; + int decoded_frames_ = 0; + ImageFormat::Format format_; +}; + +REGISTER_CALCULATOR(OpenCvVideoDecoderCalculator); +} // namespace mediapipe diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc new file mode 100644 index 000000000..7d78bd728 --- /dev/null +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator_test.cc @@ -0,0 +1,161 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "VIDEO:video" + output_stream: "VIDEO_PRESTREAM:video_prestream")"); + CalculatorRunner runner(node_config); + runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/video/" + "testdata/format_MP4_AVC720P_AAC.video")); + MEDIAPIPE_EXPECT_OK(runner.Run()); + + EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + MEDIAPIPE_EXPECT_OK(runner.Outputs() + .Tag("VIDEO_PRESTREAM") + .packets[0] + .ValidateAsType()); + const mediapipe::VideoHeader& header = + runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + EXPECT_EQ(ImageFormat::SRGB, header.format); + 
EXPECT_EQ(1280, header.width); + EXPECT_EQ(640, header.height); + EXPECT_FLOAT_EQ(6.0f, header.duration); + EXPECT_FLOAT_EQ(30.0f, header.frame_rate); + EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size()); + for (int i = 0; i < 180; ++i) { + Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + cv::Mat output_mat = + formats::MatView(&(image_frame_packet.Get())); + EXPECT_EQ(1280, output_mat.size().width); + EXPECT_EQ(640, output_mat.size().height); + EXPECT_EQ(3, output_mat.channels()); + cv::Scalar s = cv::mean(output_mat); + for (int i = 0; i < 3; ++i) { + EXPECT_GT(s[i], 0); + EXPECT_LT(s[i], 255); + } + } +} + +TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "VIDEO:video" + output_stream: "VIDEO_PRESTREAM:video_prestream")"); + CalculatorRunner runner(node_config); + runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/video/" + "testdata/format_FLV_H264_AAC.video")); + MEDIAPIPE_EXPECT_OK(runner.Run()); + + EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + MEDIAPIPE_EXPECT_OK(runner.Outputs() + .Tag("VIDEO_PRESTREAM") + .packets[0] + .ValidateAsType()); + const mediapipe::VideoHeader& header = + runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + EXPECT_EQ(ImageFormat::SRGB, header.format); + EXPECT_EQ(640, header.width); + EXPECT_EQ(320, header.height); + // TODO: The actual header.duration is 6.0666666f and the frame_rate + // can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4). 
+ // EXPECT_FLOAT_EQ(6.0f, header.duration); + // EXPECT_FLOAT_EQ(30.0f, header.frame_rate); + EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size()); + for (int i = 0; i < 180; ++i) { + Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + cv::Mat output_mat = + formats::MatView(&(image_frame_packet.Get())); + EXPECT_EQ(640, output_mat.size().width); + EXPECT_EQ(320, output_mat.size().height); + EXPECT_EQ(3, output_mat.channels()); + cv::Scalar s = cv::mean(output_mat); + for (int i = 0; i < 3; ++i) { + EXPECT_GT(s[i], 0); + EXPECT_LT(s[i], 255); + } + } +} + +TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "VIDEO:video" + output_stream: "VIDEO_PRESTREAM:video_prestream")"); + CalculatorRunner runner(node_config); + runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/video/" + "testdata/format_MKV_VP8_VORBIS.video")); + MEDIAPIPE_EXPECT_OK(runner.Run()); + + EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); + MEDIAPIPE_EXPECT_OK(runner.Outputs() + .Tag("VIDEO_PRESTREAM") + .packets[0] + .ValidateAsType()); + const mediapipe::VideoHeader& header = + runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get(); + EXPECT_EQ(ImageFormat::SRGB, header.format); + EXPECT_EQ(640, header.width); + EXPECT_EQ(320, header.height); + EXPECT_FLOAT_EQ(6.0f, header.duration); + EXPECT_FLOAT_EQ(30.0f, header.frame_rate); + EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size()); + for (int i = 0; i < 180; ++i) { + Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; + cv::Mat output_mat = + formats::MatView(&(image_frame_packet.Get())); + EXPECT_EQ(640, output_mat.size().width); + EXPECT_EQ(320, output_mat.size().height); + EXPECT_EQ(3, output_mat.channels()); + 
cv::Scalar s = cv::mean(output_mat); + for (int i = 0; i < 3; ++i) { + EXPECT_GT(s[i], 0); + EXPECT_LT(s[i], 255); + } + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc new file mode 100644 index 000000000..c34a30ade --- /dev/null +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc @@ -0,0 +1,176 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "mediapipe/calculators/video/opencv_video_encoder_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/opencv_video_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/source_location.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/framework/tool/status_util.h" + +namespace mediapipe { + +// Encodes the input video stream and produces a media file. +// The media file can be output to the output_file_path specified as a side +// packet. Currently, the calculator only supports one video stream (in +// mediapipe::ImageFrame). 
+// +// Example config to generate the output video file: +// +// node { +// calculator: "OpenCvVideoEncoderCalculator" +// input_stream: "VIDEO:video" +// input_stream: "VIDEO_PRESTREAM:video_header" +// input_side_packet: "OUTPUT_FILE_PATH:output_file_path" +// node_options { +// [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { +// codec: "avc1" +// video_format: "mp4" +// } +// } +// } +class OpenCvVideoEncoderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status SetUpVideoWriter(float frame_rate, int width, int height); + + std::string output_file_path_; + int four_cc_; + std::unique_ptr writer_; +}; + +::mediapipe::Status OpenCvVideoEncoderCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag("VIDEO")); + cc->Inputs().Tag("VIDEO").Set(); + if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { + cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); + } + RET_CHECK(cc->InputSidePackets().HasTag("OUTPUT_FILE_PATH")); + cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Set(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { + OpenCvVideoEncoderCalculatorOptions options = + cc->Options(); + RET_CHECK(options.has_codec() && options.codec().length() == 4) + << "A 4-character codec code must be specified in " + "OpenCvVideoEncoderCalculatorOptions"; + const char* codec_array = options.codec().c_str(); + four_cc_ = mediapipe::fourcc(codec_array[0], codec_array[1], codec_array[2], + codec_array[3]); + RET_CHECK(!options.video_format().empty()) + << "Video format must be specified in " + "OpenCvVideoEncoderCalculatorOptions"; + output_file_path_ = + 
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Get(); + std::vector splited_file_path = + absl::StrSplit(output_file_path_, '.'); + RET_CHECK(splited_file_path.size() >= 2 && + splited_file_path[splited_file_path.size() - 1] == + options.video_format()) + << "The output file path is invalid."; + // If the video header will be available, the video metadata will be fetched + // from the video header directly. The calculator will receive the video + // header packet at timestamp prestream. + if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { + return ::mediapipe::OkStatus(); + } + return SetUpVideoWriter(options.fps(), options.width(), options.height()); +} + +::mediapipe::Status OpenCvVideoEncoderCalculator::Process( + CalculatorContext* cc) { + if (cc->InputTimestamp() == Timestamp::PreStream()) { + const VideoHeader& video_header = + cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); + return SetUpVideoWriter(video_header.frame_rate, video_header.width, + video_header.height); + } + + const ImageFrame& image_frame = + cc->Inputs().Tag("VIDEO").Value().Get(); + ImageFormat::Format format = image_frame.Format(); + cv::Mat frame; + if (format == ImageFormat::GRAY8) { + frame = formats::MatView(&image_frame); + if (frame.empty()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Receive empty frame at timestamp " + << cc->Inputs().Tag("VIDEO").Value().Timestamp() + << " in OpenCvVideoEncoderCalculator::Process()"; + } + } else { + cv::Mat tmp_frame = formats::MatView(&image_frame); + if (tmp_frame.empty()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Receive empty frame at timestamp " + << cc->Inputs().Tag("VIDEO").Value().Timestamp() + << " in OpenCvVideoEncoderCalculator::Process()"; + } + if (format == ImageFormat::SRGB) { + cv::cvtColor(tmp_frame, frame, cv::COLOR_BGR2RGB); + } else if (format == ImageFormat::SRGBA) { + cv::cvtColor(tmp_frame, frame, cv::COLOR_BGRA2RGBA); + } else { + return 
::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Unsupported image format: " << format; + } + } + writer_->write(frame); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { + if (writer_ && writer_->isOpened()) { + writer_->release(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status OpenCvVideoEncoderCalculator::SetUpVideoWriter( + float frame_rate, int width, int height) { + RET_CHECK(frame_rate > 0 && width > 0 && height > 0) + << "Invalid video metadata: frame_rate=" << frame_rate + << ", width=" << width << ", height=" << height; + writer_ = absl::make_unique( + output_file_path_, four_cc_, frame_rate, cv::Size(width, height)); + if (!writer_->isOpened()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Fail to open file at " << output_file_path_; + } + return ::mediapipe::OkStatus(); +} + +REGISTER_CALCULATOR(OpenCvVideoEncoderCalculator); +} // namespace mediapipe diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator.proto b/mediapipe/calculators/video/opencv_video_encoder_calculator.proto new file mode 100644 index 000000000..a3a7af3ac --- /dev/null +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator.proto @@ -0,0 +1,37 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message OpenCvVideoEncoderCalculatorOptions { + extend CalculatorOptions { + optional OpenCvVideoEncoderCalculatorOptions ext = 207936763; + } + // The 4-character code of the codec to encode the video. + optional string codec = 1; + + // The video format of the output video file. + optional string video_format = 2; + + // The frame rate in Hz at which the video frames are output. + optional double fps = 3; + + // Dimensions of the video in pixels. + optional int32 width = 4; + optional int32 height = 5; +} diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc new file mode 100644 index 000000000..4323ce016 --- /dev/null +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc @@ -0,0 +1,218 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/deleting_file.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/opencv_video_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { +// Temporarily disable the test. +// TODO: Investigate the “Could not open codec 'libx264'” error with +// opencv2. 
+TEST(OpenCvVideoEncoderCalculatorTest, DISABLED_TestMp4Avc720pVideo) { + CalculatorGraphConfig config = ParseTextProtoOrDie(R"( + node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "VIDEO:video" + output_stream: "VIDEO_PRESTREAM:video_prestream" + } + node { + calculator: "OpenCvVideoEncoderCalculator" + input_stream: "VIDEO:video" + input_stream: "VIDEO_PRESTREAM:video_prestream" + input_side_packet: "OUTPUT_FILE_PATH:output_file_path" + node_options { + [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { + codec: "avc1" + video_format: "mp4" + } + } + } + )"); + std::map input_side_packets; + input_side_packets["input_file_path"] = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/video/" + "testdata/format_MP4_AVC720P_AAC.video")); + const std::string output_file_path = "/tmp/tmp_video.mp4"; + DeletingFile deleting_file(output_file_path, true); + input_side_packets["output_file_path"] = + MakePacket(output_file_path); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config, input_side_packets)); + StatusOrPoller status_or_poller = + graph.AddOutputStreamPoller("video_prestream"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + Packet packet; + while (poller.Next(&packet)) { + } + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + const VideoHeader& video_header = packet.Get(); + + // Checks the generated video file has the same width, height, fps, and + // duration as the original one. 
+ cv::VideoCapture cap(output_file_path); + ASSERT_TRUE(cap.isOpened()); + EXPECT_EQ(video_header.width, + static_cast(cap.get(cv::CAP_PROP_FRAME_WIDTH))); + EXPECT_EQ(video_header.height, + static_cast(cap.get(cv::CAP_PROP_FRAME_HEIGHT))); + EXPECT_EQ(video_header.frame_rate, + static_cast(cap.get(cv::CAP_PROP_FPS))); + EXPECT_EQ(video_header.duration, + static_cast(cap.get(cv::CAP_PROP_FRAME_COUNT) / + cap.get(cv::CAP_PROP_FPS))); +} + +TEST(OpenCvVideoEncoderCalculatorTest, TestFlvH264Video) { + CalculatorGraphConfig config = ParseTextProtoOrDie(R"( + node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "VIDEO:video" + output_stream: "VIDEO_PRESTREAM:video_prestream" + } + node { + calculator: "OpenCvVideoEncoderCalculator" + input_stream: "VIDEO:video" + input_stream: "VIDEO_PRESTREAM:video_prestream" + input_side_packet: "OUTPUT_FILE_PATH:output_file_path" + node_options { + [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { + codec: "MJPG" + video_format: "avi" + } + } + } + )"); + std::map input_side_packets; + input_side_packets["input_file_path"] = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/video/" + "testdata/format_FLV_H264_AAC.video")); + const std::string output_file_path = "/tmp/tmp_video.avi"; + DeletingFile deleting_file(output_file_path, true); + input_side_packets["output_file_path"] = + MakePacket(output_file_path); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config, input_side_packets)); + StatusOrPoller status_or_poller = + graph.AddOutputStreamPoller("video_prestream"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + Packet packet; + while (poller.Next(&packet)) { + } + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + const VideoHeader& video_header = packet.Get(); + + // Checks the generated video file has the 
same width, height, fps, and + // duration as the original one. + cv::VideoCapture cap(output_file_path); + ASSERT_TRUE(cap.isOpened()); + EXPECT_EQ(video_header.width, + static_cast(cap.get(cv::CAP_PROP_FRAME_WIDTH))); + EXPECT_EQ(video_header.height, + static_cast(cap.get(cv::CAP_PROP_FRAME_HEIGHT))); + // TODO: The actual header.duration is 6.0666666f and the frame_rate + // can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4). + // EXPECT_EQ(video_header.frame_rate, + // static_cast(cap.get(cv::CAP_PROP_FPS))); + // EXPECT_EQ(video_header.duration, + // static_cast(cap.get(cv::CAP_PROP_FRAME_COUNT) / + // cap.get(cv::CAP_PROP_FPS))); +} + +TEST(OpenCvVideoEncoderCalculatorTest, TestMkvVp8Video) { + CalculatorGraphConfig config = ParseTextProtoOrDie(R"( + node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "VIDEO:video" + output_stream: "VIDEO_PRESTREAM:video_prestream" + } + node { + calculator: "OpenCvVideoEncoderCalculator" + input_stream: "VIDEO:video" + input_stream: "VIDEO_PRESTREAM:video_prestream" + input_side_packet: "OUTPUT_FILE_PATH:output_file_path" + node_options { + [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { + codec: "PIM1" + video_format: "mkv" + } + } + } + )"); + std::map input_side_packets; + input_side_packets["input_file_path"] = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/video/" + "testdata/format_MKV_VP8_VORBIS.video")); + const std::string output_file_path = "/tmp/tmp_video.mkv"; + DeletingFile deleting_file(output_file_path, true); + input_side_packets["output_file_path"] = + MakePacket(output_file_path); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config, input_side_packets)); + StatusOrPoller status_or_poller = + graph.AddOutputStreamPoller("video_prestream"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + + 
MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + Packet packet; + while (poller.Next(&packet)) { + } + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + const VideoHeader& video_header = packet.Get(); + + // Checks the generated video file has the same width, height, fps, and + // duration as the original one. + cv::VideoCapture cap(output_file_path); + ASSERT_TRUE(cap.isOpened()); + EXPECT_EQ(video_header.width, + static_cast(cap.get(cv::CAP_PROP_FRAME_WIDTH))); + EXPECT_EQ(video_header.height, + static_cast(cap.get(cv::CAP_PROP_FRAME_HEIGHT))); + EXPECT_EQ(video_header.frame_rate, + static_cast(cap.get(cv::CAP_PROP_FPS))); + EXPECT_EQ(video_header.duration, + static_cast(cap.get(cv::CAP_PROP_FRAME_COUNT) / + cap.get(cv::CAP_PROP_FPS))); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/video/testdata/BUILD b/mediapipe/calculators/video/testdata/BUILD new file mode 100644 index 000000000..cd7c3d57c --- /dev/null +++ b/mediapipe/calculators/video/testdata/BUILD @@ -0,0 +1,26 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "test_videos", + srcs = [ + "format_FLV_H264_AAC.video", + "format_MKV_VP8_VORBIS.video", + "format_MP4_AVC720P_AAC.video", + ], + visibility = ["//visibility:public"], +) diff --git a/mediapipe/calculators/video/testdata/format_FLV_H264_AAC.video b/mediapipe/calculators/video/testdata/format_FLV_H264_AAC.video new file mode 100644 index 000000000..980da555e Binary files /dev/null and b/mediapipe/calculators/video/testdata/format_FLV_H264_AAC.video differ diff --git a/mediapipe/calculators/video/testdata/format_MKV_VP8_VORBIS.video b/mediapipe/calculators/video/testdata/format_MKV_VP8_VORBIS.video new file mode 100644 index 000000000..ace0e8e6e Binary files /dev/null and b/mediapipe/calculators/video/testdata/format_MKV_VP8_VORBIS.video differ diff --git a/mediapipe/calculators/video/testdata/format_MP4_AVC720P_AAC.video b/mediapipe/calculators/video/testdata/format_MP4_AVC720P_AAC.video new file mode 100644 index 000000000..4a9cafb94 Binary files /dev/null and b/mediapipe/calculators/video/testdata/format_MP4_AVC720P_AAC.video differ diff --git a/mediapipe/calculators/video/tool/BUILD b/mediapipe/calculators/video/tool/BUILD new file mode 100644 index 000000000..422bb034a --- /dev/null +++ b/mediapipe/calculators/video/tool/BUILD @@ -0,0 +1,50 @@ +# +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe/calculators/video:__subpackages__"]) + +exports_files(["LICENSE"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +proto_library( + name = "flow_quantizer_model_proto", + srcs = ["flow_quantizer_model.proto"], +) + +mediapipe_cc_proto_library( + name = "flow_quantizer_model_cc_proto", + srcs = ["flow_quantizer_model.proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":flow_quantizer_model_proto"], +) + +cc_library( + name = "flow_quantizer_model", + srcs = ["flow_quantizer_model.cc"], + hdrs = ["flow_quantizer_model.h"], + deps = [ + "//mediapipe/calculators/video/tool:flow_quantizer_model_cc_proto", + "//mediapipe/framework:type_map", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats/motion:optical_flow_field", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/tool:status_util", + ], + alwayslink = 1, +) diff --git a/mediapipe/calculators/video/tool/flow_quantizer_model.cc b/mediapipe/calculators/video/tool/flow_quantizer_model.cc new file mode 100644 index 000000000..0cfad8539 --- /dev/null +++ b/mediapipe/calculators/video/tool/flow_quantizer_model.cc @@ -0,0 +1,78 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/video/tool/flow_quantizer_model.h" + +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/type_map.h" + +namespace mediapipe { + +// Uniform normalization to 0-255. +uint8 FlowQuantizerModel::Apply(const float val, const int channel) const { + CHECK_LT(channel, model_.min_value_size()); + const auto& min_value = model_.min_value(channel); + const auto& max_value = model_.max_value(channel); + QCHECK_GT(max_value, min_value); + float res = (val - min_value) / (max_value - min_value); + if (res < 0.0) { + res = 0.0; + } else if (res > 1.0) { + res = 1.0; + } + return static_cast(res * 255); +} + +void FlowQuantizerModel::LoadFromProto(const QuantizerModelData& data) { + QCHECK_GT(data.max_value(0), data.min_value(0)); + QCHECK_GT(data.max_value(1), data.min_value(1)); + + model_ = data; +} + +const QuantizerModelData& FlowQuantizerModel::GetModelData() const { + return model_; +} + +// Used for training, update the (min, max) range. We want to estimate the range +// of optical flow fields (Theoretically it is (-num_pixels_along_diag, +// num_pixels_along_diag)). +// TODO: Taking the min and max over all training flow fields might be +// sensitive to noise. We should use more robust statistics. +void FlowQuantizerModel::AddSampleFlowField(const OpticalFlowField& flow) { + CHECK_EQ(model_.min_value_size(), 2); + const cv::Mat_& flow_mat = flow.flow_data(); + for (int i = 0; i != flow.width(); ++i) { + for (int j = 0; j != flow.height(); ++j) { + const auto& x = flow_mat.at(i, j).x; + const auto& y = flow_mat.at(i, j).y; + // Always use the minimum and maximum value that occurred in training flow + // fields. 
+      model_.set_min_value(0, std::min(x, model_.min_value(0)));
+      model_.set_min_value(1, std::min(y, model_.min_value(1)));
+      model_.set_max_value(0, std::max(x, model_.max_value(0)));
+      model_.set_max_value(1, std::max(y, model_.max_value(1)));
+    }
+  }
+}
+
+void FlowQuantizerModel::Init() {
+  model_.Clear();
+  // Initialize the values.
+  for (int i = 0; i != 2; ++i) {
+    model_.add_min_value(std::numeric_limits<float>::max());
+    model_.add_max_value(-std::numeric_limits<float>::max());
+  }
+}
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/video/tool/flow_quantizer_model.h b/mediapipe/calculators/video/tool/flow_quantizer_model.h
new file mode 100644
index 000000000..16ae4b7ac
--- /dev/null
+++ b/mediapipe/calculators/video/tool/flow_quantizer_model.h
@@ -0,0 +1,47 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Quantization model to convert a real value float number (flow field) to a
+// 8-bit discrete number.
+#ifndef MEDIAPIPE_CALCULATORS_VIDEO_TOOL_FLOW_QUANTIZER_MODEL_H_
+#define MEDIAPIPE_CALCULATORS_VIDEO_TOOL_FLOW_QUANTIZER_MODEL_H_
+
+#include "mediapipe/calculators/video/tool/flow_quantizer_model.pb.h"
+#include "mediapipe/framework/formats/motion/optical_flow_field.h"
+#include "mediapipe/framework/port/integral_types.h"
+#include "mediapipe/framework/tool/status_util.h"
+
+namespace mediapipe {
+
+class FlowQuantizerModel {
+ public:
+  // Initializes the model proto.
+ void Init(); + // Quantizes flow field with the model. + uint8 Apply(const float val, const int channel) const; + // Loads model from proto. + void LoadFromProto(const QuantizerModelData& data); + // Gets proto from model. + const QuantizerModelData& GetModelData() const; + // Used in training. Updates the model proto by reading the flow fields. + // TODO: This model is currently manually set. Need to find a way to + // learn from flow fields directly. + void AddSampleFlowField(const OpticalFlowField& flow); + + private: + QuantizerModelData model_; +}; +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_VIDEO_TOOL_FLOW_QUANTIZER_MODEL_H_ diff --git a/mediapipe/calculators/video/tool/flow_quantizer_model.proto b/mediapipe/calculators/video/tool/flow_quantizer_model.proto new file mode 100644 index 000000000..02f3fc11a --- /dev/null +++ b/mediapipe/calculators/video/tool/flow_quantizer_model.proto @@ -0,0 +1,24 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +// Message storing min value and max value for normalization in all channels. +message QuantizerModelData { + // For all channels. 
+ repeated float min_value = 1; + repeated float max_value = 2; +} diff --git a/mediapipe/docs/Makefile b/mediapipe/docs/Makefile new file mode 100644 index 000000000..13a5e64fc --- /dev/null +++ b/mediapipe/docs/Makefile @@ -0,0 +1,21 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + rm -rf ./_build + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/mediapipe/docs/README.md b/mediapipe/docs/README.md new file mode 100644 index 000000000..199643cbe --- /dev/null +++ b/mediapipe/docs/README.md @@ -0,0 +1,2 @@ +This directory contains the source markdown files presented on +the [MediaPipe Read-the-Docs](https://mediapipe.readthedocs.io) documentation site. diff --git a/mediapipe/docs/calculator.md b/mediapipe/docs/calculator.md new file mode 100644 index 000000000..7e1f6e94c --- /dev/null +++ b/mediapipe/docs/calculator.md @@ -0,0 +1,165 @@ +## Building MediaPipe Calculators + +- [Example calculator](#example-calculator) + + +### Example calculator + +This section discusses the implementation of `PacketClonerCalculator`, which +does a relatively simple job, and is used in many calculator graphs. +`PacketClonerCalculator` simply produces a copy of its most recent input +packets on demand. + +`PacketClonerCalculator` is useful when the timestamps of arriving data packets +are not aligned perfectly. Suppose we have a room with a microphone, light +sensor and a video camera that is collecting sensory data. 
Each of the sensors
+operates independently and collects data intermittently. Suppose that the output
+of each sensor is:
+
+* microphone = loudness in decibels of sound in the room (Integer)
+* light sensor = brightness of room (Integer)
+* video camera = RGB image frame of room (ImageFrame)
+
+Our simple perception pipeline is designed to process sensory data from these 3
+sensors such that at any time when we have image frame data from the camera that
+is synchronized with the last collected microphone loudness data and light
+sensor brightness data. To do this with MediaPipe, our perception pipeline has 3
+input streams:
+
+* room_mic_signal - Each packet of data in this input stream is integer data
+  representing how loud audio is in a room with timestamp.
+* room_lighting_sensor - Each packet of data in this input stream is integer
+  data representing how brightly the room is illuminated with timestamp.
+* room_video_tick_signal - Each packet of data in this input stream is
+  imageframe of video data representing video collected from camera in the
+  room with timestamp.
+
+Below is the implementation of the `PacketClonerCalculator`. You can see
+the `GetContract()`, `Open()`, and `Process()` methods as well as the instance
+variable `current_` which holds the most recent input packets.
+
+```c++
+// This takes packets from N+1 streams, A_1, A_2, ..., A_N, B.
+// For every packet that appears in B, outputs the most recent packet from each
+// of the A_i on a separate stream.
+
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "mediapipe/framework/calculator_framework.h"
+
+namespace mediapipe {
+
+// For every packet received on the last stream, output the latest packet
+// obtained on all other streams. Therefore, if the last stream outputs at a
+// higher rate than the others, this effectively clones the packets from the
+// other streams to match the last.
+// +// Example config: +// node { +// calculator: "PacketClonerCalculator" +// input_stream: "first_base_signal" +// input_stream: "second_base_signal" +// input_stream: "tick_signal" +// output_stream: "cloned_first_base_signal" +// output_stream: "cloned_second_base_signal" +// } +// +class PacketClonerCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + const int tick_signal_index = cc->Inputs().NumEntries() - 1; + // cc->Inputs().NumEntries() returns the number of input streams + // for the PacketClonerCalculator + for (int i = 0; i < tick_signal_index; ++i) { + cc->Inputs().Index(i).SetAny(); + // cc->Inputs().Index(i) returns the input stream pointer by index + cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); + } + cc->Inputs().Index(tick_signal_index).SetAny(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + tick_signal_index_ = cc->Inputs().NumEntries() - 1; + current_.resize(tick_signal_index_); + // Pass along the header for each stream if present. + for (int i = 0; i < tick_signal_index_; ++i) { + if (!cc->Inputs().Index(i).Header().IsEmpty()) { + cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header()); + // Sets the output stream of index i header to be the same as + // the header for the input stream of index i + } + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + // Store input signals. + for (int i = 0; i < tick_signal_index_; ++i) { + if (!cc->Inputs().Index(i).Value().IsEmpty()) { + current_[i] = cc->Inputs().Index(i).Value(); + } + } + + // Output if the tick signal is non-empty. 
+    if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) {
+      for (int i = 0; i < tick_signal_index_; ++i) {
+        if (!current_[i].IsEmpty()) {
+          cc->Outputs().Index(i).AddPacket(
+              current_[i].At(cc->InputTimestamp()));
+          // Add a packet to output stream of index i a packet from inputstream i
+          // with timestamp common to all present inputs
+          //
+        } else {
+          cc->Outputs().Index(i).SetNextTimestampBound(
+              cc->InputTimestamp().NextAllowedInStream());
+          // if current_[i], 1 packet buffer for input stream i is empty, we will set
+          // next allowed timestamp for input stream i to be current timestamp + 1
+        }
+      }
+    }
+    return ::mediapipe::OkStatus();
+  }
+
+ private:
+  std::vector<Packet> current_;
+  int tick_signal_index_;
+};
+
+REGISTER_CALCULATOR(PacketClonerCalculator);
+}  // namespace mediapipe
+```
+
+Typically, a calculator has only a .cc file. No .h is required, because
+mediapipe uses registration to make calculators known to it. After you have
+defined your calculator class, register it with a macro invocation
+REGISTER_CALCULATOR(calculator_class_name).
+
+Below is a trivial MediaPipe graph that has 3 input streams, 1 node
+(PacketClonerCalculator) and 3 output streams.
+
+```proto
+input_stream: "room_mic_signal"
+input_stream: "room_lighting_sensor"
+input_stream: "room_video_tick_signal"
+
+node {
+  calculator: "PacketClonerCalculator"
+  input_stream: "room_mic_signal"
+  input_stream: "room_lighting_sensor"
+  input_stream: "room_video_tick_signal"
+  output_stream: "cloned_room_mic_signal"
+  output_stream: "cloned_lighting_sensor"
+  output_stream: "cloned_video_tick_signal"
+}
+```
+
+The diagram below shows how the `PacketClonerCalculator` defines its output
+packets based on its series of input packets.
+
+| ![Graph using PacketClonerCalculator](images/packet_cloner_calculator.png) |
+|:--:|
+| *Each time it receives a packet on its TICK input stream, the PacketClonerCalculator outputs the most recent packet from each of its input streams.
The sequence of output packets is determined by the sequene of input packets and their timestamps. The timestamps are shows along the right side of the diagram.* | + + diff --git a/mediapipe/docs/concepts.md b/mediapipe/docs/concepts.md new file mode 100644 index 000000000..37d988e6d --- /dev/null +++ b/mediapipe/docs/concepts.md @@ -0,0 +1,92 @@ +# MediaPipe Concepts + +## The basics + +### Packet + +The basic data flow unit. A packet consists of a numeric timestamp and a shared pointer to an **immutable** payload. The payload can be of any C++ type, and the payload's type is also referred to as the type of the packet. Packets are value classes and can be copied cheaply. Each copy shares ownership of the payload, with reference-counting semantics. Each copy has its own timestamp. [Details](packets.md). + +### Graph + +MediaPipe processing takes place inside a graph, which defines packet flow paths +between **nodes**. A graph can have any number of inputs and outputs, and data +flow can branch and merge. Generally data flows forward, but +[backward loops](cycles.md) are possible. + +### Nodes + +Nodes produce and/or consume packets, and they are where the bulk of the graph’s +work takes place. They are also known as “calculators”, for historical reasons. +Each node’s interface defines a number of input and output **ports**, identified by +a tag and/or an index. + +### Streams + +A stream is a connection between two nodes that carries a sequence of packets, +whose timestamps must be monotonically increasing. + +### Side packets + +A side packet connection between nodes carries a single packet (with unspecified +timestamp). It can be used to provide some data that will remain constant, +whereas a stream represents a flow of data that changes over time. + +### Packet Ports + +A port has an associated type; packets transiting through the port must be of +that type. 
An output stream port can be connected to any number of +input stream ports of the same type; each consumer receives a separate copy of +the output packets, and has its own queue, so it can consume them at its own +pace. Similarly, a side packet output port can be connected to as many side +packet input ports as desired. + +A port can be required, meaning that a connection must be made for the graph to +be valid, or optional, meaning it may remain unconnected. + +Note: even if a stream connection is required, the stream may not carry a packet for all timestamps. + +## Input and output + +Data flow can originate from **source nodes**, which have no input streams and +produce packets spontaneously (e.g. by reading from a file); or from **graph input streams**, which let an application feed packets into a graph. + +Similarly, there are **sink nodes** that receive data and write it to various +destinations (e.g. a file, a memory buffer, etc.), and an application can also +receive output from the graph using **callbacks**. + +## Runtime behavior + +### Graph lifetime + +Once a graph has been initialized, it can be **started** to begin processing +data, and can process a stream of packets until each stream is closed or the +graph is **canceled**. Then the graph can be destroyed or **started** again. + +### Node lifetime + +There are three main lifetime methods the framework will call on a node: + +- Open: called once, before the other methods. When it is called, all input + side packets required by the node will be available. +- Process: called multiple times, when a new set of inputs is available, + according to the node’s input policy. +- Close: called once, at the end. + +In addition, each calculator can define constructor and destructor, which are +useful for creating and deallocating resources that are independent of the +processed data. + +### Input policies + +The default input policy is deterministic collation of packets by timestamp. 
A node receives +all inputs for the same timestamp at the same time, in an invocation of its +Process method; and successive input sets are received in their timestamp order. This can +require delaying the processing of some packets until a packet with the same +timestamp is received on all input streams, or until it can be guaranteed that a +packet with that timestamp will not be arriving on the streams that have not +received it. + +Other policies are also available, implemented using a separate kind of +component known as an InputStreamHandler. + +See [scheduling](scheduling_sync.md) for more details. diff --git a/mediapipe/docs/conf.py b/mediapipe/docs/conf.py new file mode 100644 index 000000000..fe8352196 --- /dev/null +++ b/mediapipe/docs/conf.py @@ -0,0 +1,57 @@ +"""Configuration file for the Sphinx documentation builder. + +This file only contains a selection of the most common options. +For a full list see the documentation: +http://www.sphinx-doc.org/en/master/config +-- Path setup -------------------------------------------------------------- +If extensions (or modules to document with autodoc) are in another directory, +add these directories to sys.path here. +If the directory is relative to the documentation root, +use os.path.abspath to make it absolute, like shown here. + +""" +import sphinx_rtd_theme + + +# -- Project information ----------------------------------------------------- + +project = 'MediaPipe' +author = 'Google LLC' + +# The full version, including alpha/beta/rc tags +release = 'v0.5' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'recommonmark' +] + +master_doc = 'index' + +# Add any paths that contain templates here, relative to this directory. 
+templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] diff --git a/mediapipe/docs/cycles.md b/mediapipe/docs/cycles.md new file mode 100644 index 000000000..ab5cfc184 --- /dev/null +++ b/mediapipe/docs/cycles.md @@ -0,0 +1,128 @@ +# Cycles in MediaPipe Graphs + + + +[TOC] + +By default, MediaPipe requires calculator graphs to be acyclic and treats cycles +in a graph as errors. If a graph is intended to have cycles, the cycles need to +be annotated in the graph config. This page describes how to do that. + +NOTE: The current approach is experimental and subject to change. We welcome +your feedback. + +Please use the `CalculatorGraphTest.Cycle` unit test in +`mediapipe/framework/calculator_graph_test.cc` as sample code. Shown +below is the cyclic graph in the test. The `sum` output of the adder is the sum +of the integers generated by the integer source calculator. + +![a cyclic graph that adds a stream of integers](images/cyclic_integer_sum_graph.svg "A cyclic graph") + +This simple graph illustrates all the issues in supporting cyclic graphs. + +## Back Edge Annotation + +We require that an edge in each cycle be annotated as a back edge. 
This allows +MediaPipe’s topological sort to work, after removing all the back edges. + +There are usually multiple ways to select the back edges. Which edges are marked +as back edges affects which nodes are considered as upstream and which nodes are +considered as downstream, which in turn affects the priorities MediaPipe assigns +to the nodes. + +For example, the `CalculatorGraphTest.Cycle` test marks the `old_sum` edge as a +back edge, so the Delay node is considered as a downstream node of the adder +node and is given a higher priority. Alternatively, we could mark the `sum` +input to the delay node as the back edge, in which case the delay node would be +considered as an upstream node of the adder node and is given a lower priority. + +## Initial Packet + +For the adder calculator to be runnable when the first integer from the integer +source arrives, we need an initial packet, with value 0 and with the same +timestamp, on the `old_sum` input stream to the adder. This initial packet +should be output by the delay calculator in the `Open()` method. + +## Delay in a Loop + +Each loop should incur a delay to align the previous `sum` output with the next +integer input. This is also done by the delay node. So the delay node needs to +know the following about the timestamps of the integer source calculator: + +* The timestamp of the first output. + +* The timestamp delta between successive outputs. + +We plan to add an alternative scheduling policy that only cares about packet +ordering and ignores packet timestamps, which will eliminate this inconvenience. + +## Early Termination of a Calculator When One Input Stream is Done + +By default, MediaPipe calls the `Close()` method of a non-source calculator when +all of its input streams are done. In the example graph, we want to stop the +adder node as soon as the integer source is done. This is accomplished by +configuring the adder node with an alternative input stream hander, +`EarlyCloseInputStreamHandler`. 
+ +## Relevant Source Code + +### Delay Calculator + +Note the code in `Open()` that outputs the initial packet and the code in +`Process()` that adds a (unit) delay to input packets. As noted above, this +delay node assumes that its output stream is used alongside an input stream with +packet timestamps 0, 1, 2, 3, ... + +```c++ +class UnitDelayCalculator : public Calculator { + public: +  static ::util::Status FillExpectations( +      const CalculatorOptions& extendable_options, PacketTypeSet* inputs, +      PacketTypeSet* outputs, PacketTypeSet* input_side_packets) { +    inputs->Index(0)->Set("An integer."); +    outputs->Index(0)->Set("The input delayed by one time unit."); +    return ::mediapipe::OkStatus(); +  } + +  ::util::Status Open() final { +    Output()->Add(new int(0), Timestamp(0)); +    return ::mediapipe::OkStatus(); +  } + +  ::util::Status Process() final { +    const Packet& packet = Input()->Value(); +    Output()->AddPacket(packet.At(packet.Timestamp().NextAllowedInStream())); +    return ::mediapipe::OkStatus(); +  } +}; +``` + +### Graph Config + +Note the `back_edge` annotation and the alternative `input_stream_handler`. + +```proto +node { +  calculator: 'GlobalCountSourceCalculator' +  input_side_packet: 'global_counter' +  output_stream: 'integers' +} +node { +  calculator: 'IntAdderCalculator' +  input_stream: 'integers' +  input_stream: 'old_sum' +  input_stream_info: { +    tag_index: ':1' # 'old_sum' +    back_edge: true +  } +  output_stream: 'sum' +  input_stream_handler { +    input_stream_handler: 'EarlyCloseInputStreamHandler' +  } +} +node { +  calculator: 'UnitDelayCalculator' +  input_stream: 'sum' +  output_stream: 'old_sum' +} +``` diff --git a/mediapipe/docs/examples.md b/mediapipe/docs/examples.md new file mode 100644 index 000000000..e8d773e26 --- /dev/null +++ b/mediapipe/docs/examples.md @@ -0,0 +1,73 @@ +# Examples + +Below are code samples on how to run MediaPipe on both mobile and desktop. 
We +currently support MediaPipe APIs on mobile for Android only but will add support +for Objective-C shortly. + +## Mobile + +### Hello World! on Android + +[Hello World! on Android](./hello_world_android.md) should be the first mobile +example users go through in detail. It teaches the following: + +* Introduction of a simple MediaPipe graph running on mobile GPUs for + [Sobel edge detection]. +* Building a simple baseline Android application that displays "Hello World!". +* Adding camera preview support into the baseline application using the + Android [CameraX] API. +* Incorporating the Sobel edge detection graph to process the live camera + preview and display the processed video in real-time. + +### Object Detection with GPU on Android + +[Object Detection on GPU on Android](./object_detection_android_gpu.md) +illustrates how to use MediaPipe with a TFLite model for object detection in a +GPU-accelerated pipeline. + +### Object Detection with CPU on Android + +[Object Detection on CPU on Android](./object_detection_android_cpu.md) +illustrates using the same TFLite model in a CPU-based pipeline. This example +highlights how graphs can be easily adapted to run on CPU v.s. GPU. + +### Face Detection on Android + +[Face Detection on Android](./face_detection_android_gpu.md) illustrates how to +use MediaPipe with a TFLite model for face detection in a GPU-accelerated +pipeline. + +* The selfie face detection TFLite model is based on + ["BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs"](https://sites.google.com/view/perception-cv4arvr/blazeface). +* [Model card](https://sites.google.com/corp/view/perception-cv4arvr/blazeface#h.p_21ojPZDx3cqq). + +### Hair Segmentation on Android + +[Hair Segmentation on Android](./hair_segmentation_android_gpu.md) illustrates +how to use MediaPipe with a TFLite model for hair segmentation in a +GPU-accelerated pipeline. 
+ +* The selfie hair segmentation TFLite model is based on + ["Real-time Hair segmentation and recoloring on Mobile GPUs"](https://sites.google.com/view/perception-cv4arvr/hair-segmentation). +* [Model card](https://sites.google.com/corp/view/perception-cv4arvr/hair-segmentation#h.p_NimuO7PgHxlY). + +## Desktop + +### Hello World for C++ + +[Hello World for C++](./hello_world_desktop.md) shows how to run a simple graph +using the MediaPipe C++ APIs. + +### Preparing Data Sets with MediaSequence + +[Preparing Data Sets with MediaSequence](./media_sequence.md) shows how to use +MediaPipe for media processing to prepare video data sets for training a +TensorFlow model. + +### Object Detection on Desktop + +[Object Detection on Desktop](./object_detection_desktop.md) shows how to run +object detection models (TensorFlow and TFLite) using the MediaPipe C++ APIs. + +[Sobel edge detection]:https://en.wikipedia.org/wiki/Sobel_operator +[CameraX]:https://developer.android.com/training/camerax diff --git a/mediapipe/docs/face_detection_android_gpu.md b/mediapipe/docs/face_detection_android_gpu.md new file mode 100644 index 000000000..7878041d3 --- /dev/null +++ b/mediapipe/docs/face_detection_android_gpu.md @@ -0,0 +1,231 @@ +# Face Detection on Android + +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. This +doc focuses on the +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_android_gpu.pbtxt) +that performs face detection with TensorFlow Lite on GPU. + +![face_detection_android_gpu_gif](images/mobile/face_detection_android_gpu.gif){width="300"} + +## App + +The graph is used in the +[Face Detection GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu) +example app. 
To build the app, run: + +```bash +bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu +``` + +To further install the app on android device, run: + +```bash +adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/facedetectiongpu.apk +``` + +## Graph + +![face_detection_android_gpu_graph](images/mobile/face_detection_android_gpu.png){width="400"} + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into [MediaPipe Visualizer](https://mediapipe-viz.appspot.com/). + +```bash +# MediaPipe graph that performs object detection with TensorFlow Lite on GPU. +# Used in the example in +# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectiongpu. + +# Images on GPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish +# generating the corresponding detections before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs. 
+node { + calculator: "RealTimeFlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transforms the input image on GPU to a 320x320 image. To scale the image, by +# default it uses the STRETCH scale mode that maps the entire input image to the +# entire transformed image. As a result, image aspect ratio may be changed and +# objects in the image may be deformed (stretched or squeezed), but the object +# detection model used in this graph is agnostic to that deformation. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:throttled_input_video" + output_stream: "IMAGE_GPU:transformed_input_video" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 320 + output_height: 320 + } + } +} + +# Converts the transformed input image on GPU into an image tensor stored in +# tflite::gpu::GlBuffer. The zero_center option is set to true to normalize the +# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically +# option is set to true to account for the descrepancy between the +# representation of the input image (origin at the bottom-left corner, the +# OpenGL convention) and what the model used in this graph is expecting (origin +# at the top-left corner). +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "TENSORS_GPU:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { + zero_center: true + flip_vertically: true + } + } +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. 
+node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS_GPU:detection_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "ssdlite_object_detection.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 6 + min_scale: 0.2 + max_scale: 0.95 + input_size_height: 320 + input_size_width: 320 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 16 + strides: 32 + strides: 64 + strides: 128 + strides: 256 + strides: 512 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + aspect_ratios: 3.0 + aspect_ratios: 0.3333 + reduce_boxes_in_lowest_layer: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS_GPU:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 91 + num_boxes: 2034 + num_coords: 4 + ignore_classes: 0 + sigmoid_score: true + apply_exponential_on_box_size: true + x_scale: 10.0 + y_scale: 10.0 + h_scale: 5.0 + w_scale: 5.0 + flip_vertically: true + } + } +} + +# Performs non-max suppression to remove excessive detections. 
+node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.4 + min_score_threshold: 0.6 + max_num_detections: 3 + overlap_type: INTERSECTION_OVER_UNION + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "output_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "ssdlite_object_detection_labelmap.txt" + } + } +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_VECTOR:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the original image coming into +# the graph. Annotation drawing is performed on CPU, and the result is +# transferred to GPU and overlaid on the input image. The calculator assumes +# that image origin is always at the top-left corner and renders text +# accordingly. However, the input image has its origin at the bottom-left corner +# (OpenGL convention) and the flip_text_vertically option is set to true to +# compensate that. 
+node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME_GPU:throttled_input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME_GPU:output_video" + node_options: { + [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { + flip_text_vertically: true + } + } +} +``` diff --git a/mediapipe/docs/framework_concepts.md b/mediapipe/docs/framework_concepts.md new file mode 100644 index 000000000..34aa429e8 --- /dev/null +++ b/mediapipe/docs/framework_concepts.md @@ -0,0 +1,436 @@ +## Framework Concepts + +- [CalculatorBase](#calculatorbase) +- [Life of a Calculator](#life-of-a-calculator) +- [Identifying inputs and outputs](#identifying-inputs-and-outputs) +- [Processing](#processing) +- [GraphConfig](#graphconfig) +- [Subgraph](#subgraph) + +Each calculator is a node of a graph. We describe how to create a new +calculator, how to initialize a calculator, how to perform its calculations, +input and output streams, timestamps, and options. Each node in the graph is +implemented as a `Calculator`. The bulk of graph execution happens inside its +calculators. A calculator may receive zero or more input streams and/or side +packets and produces zero or more output streams and/or side packets. + +### CalculatorBase + +A calculator is created by defining a new sub-class of the +[`CalculatorBase`](http://github.com/google/mediapipe/mediapipe/framework/calculator_base.cc) +class, implementing a number of methods, and registering the new sub-class with +Mediapipe. At a minimum, a new calculator must implement the below four methods + +* `GetContract()` + * Calculator authors can specify the expected types of inputs and outputs of a calculator in GetContract(). When a graph is initialized, the framework calls a static method to verify if the packet types of the connected inputs and outputs match the information in this specification. +* `Open()` + * After a graph starts, the framework calls `Open()`. 
The input side packets are available to the calculator at this point. `Open()` interprets the node configuration (see the [GraphConfig](#graphconfig) section) and prepares the calculator's per-graph-run state. This function may also write packets to calculator outputs. An error during `Open()` can terminate the graph run.
+* `Process()`
+    * For a calculator with inputs, the framework calls `Process()` repeatedly whenever at least one input stream has a packet available. The framework by default guarantees that all inputs have the same timestamp (see [Scheduling mechanics](scheduling_sync.md#scheduling-mechanics) for more information). Multiple `Process()` calls can be invoked simultaneously when parallel execution is enabled. If an error occurs during `Process()`, the framework calls `Close()` and the graph run terminates.
+* `Close()`
+    * After all calls to `Process()` finish or when all input streams close, the framework calls `Close()`. This function is always called if `Open()` was called and succeeded and even if the graph run terminated because of an error. No inputs are available via any input streams during `Close()`, but it still has access to input side packets and therefore may write outputs. After `Close()` returns, the calculator should be considered a dead node. The calculator object is destroyed as soon as the graph finishes running.
+
+The following are code snippets from
+[CalculatorBase.h](http://github.com/google/mediapipe/mediapipe/framework/calculator_base.h).
+
+```c++
+class CalculatorBase {
+ public:
+  ...
+
+  // The subclasses of CalculatorBase must implement GetContract.
+  // ...
+  static ::MediaPipe::Status GetContract(CalculatorContract* cc);
+
+  // Open is called before any Process() calls, on a freshly constructed
+  // calculator. Subclasses may override this method to perform necessary
+  // setup, and possibly output Packets and/or set output streams' headers.
+  // ...
+ virtual ::MediaPipe::Status Open(CalculatorContext* cc) { + return ::MediaPipe::OkStatus(); + } + + // Processes the incoming inputs. May call the methods on cc to access + // inputs and produce outputs. + // ... + virtual ::MediaPipe::Status Process(CalculatorContext* cc) = 0; + + // Is called if Open() was called and succeeded. Is called either + // immediately after processing is complete or after a graph run has ended + // (if an error occurred in the graph). ... + virtual ::MediaPipe::Status Close(CalculatorContext* cc) { + return ::MediaPipe::OkStatus(); + } + + ... +}; +``` +### Life of a calculator + +During initialization of a MediaPipe graph, the framework calls a +`GetContract()` static method to determine what kinds of packets are expected. + +The framework constructs and destroys the entire calculator for each graph run (e.g. once per video or once per image). Expensive or large objects that remain constant across graph runs should be supplied as input side packets so the calculations are not repeated on subsequent runs. + +After initialization, for each run of the graph, the following sequence occurs: + +* `Open()` +* `Process()` (repeatedly) +* `Close()` + +The framework calls `Open()` to initialize the calculator. `Open()` should interpret any options and set up the calculator's per-graph-run state. `Open()` may obtain input side packets and write packets to calculator outputs. If appropriate, it should call `SetOffset()` to reduce potential packet buffering of input streams. + +If an error occurs during `Open()` or `Process()` (as indicated by one of them returning a non-`Ok ` status), the graph run is terminated with no further calls to the calculator's methods, and the calculator is destroyed. + +For a calculator with inputs, the framework calls `Process()` whenever at least one input has a packet available. 
The framework guarantees that inputs all have the same timestamp, that timestamps increase with each call to `Process()` and that all packets are delivered. As a consequence, some inputs may not have any packets when `Process()` is called. An input whose packet is missing appears to produce an empty packet (with no timestamp). + +The framework calls `Close()` after all calls to `Process()`. All inputs will have been exhausted, but `Close()` has access to input side packets and may write outputs. After Close returns, the calculator is destroyed. + +Calculators with no inputs are referred to as sources. A source calculator continues to have `Process()` called as long as it returns an `Ok` status. A source calculator indicates that it is exhausted by returning a stop status (i.e. MediaPipe::tool::StatusStop). + +### Identifying inputs and outputs + +The public interface to a calculator consists of a set of input streams and +output streams. In a CalculatorGraphConfiguration, the outputs from some +calculators are connected to the inputs of other calculators using named +streams. Stream names are normally lowercase, while input and output tags are +normally UPPERCASE. In the example below, the output with tag name `VIDEO` is +connected to the input with tag name `VIDEO_IN` using the stream named +`video_stream`. + +```proto +# Graph describing calculator SomeAudioVideoCalculator +node { + calculator: "SomeAudioVideoCalculator" + input_stream: "INPUT:combined_input" + output_stream: "VIDEO:video_stream" +} +node { + calculator: "SomeVideoCalculator" + input_stream: "VIDEO_IN:video_stream" + output_stream: "VIDEO_OUT:processed_video" +} +``` + +Input and output streams can be identified by index number, by tag name, or by a +combination of tag name and index number. You can see some examples of input and +output identifiers in the example below. `SomeAudioVideoCalculator` identifies +its video output by tag and its audio outputs by the combination of tag and +index. 
The input with tag `VIDEO` is connected to the stream named
+`video_stream`. The inputs with tag `AUDIO` and indices `0` and `1` are
+connected to the streams named `audio_left` and `audio_right`.
+`SomeAudioCalculator` identifies its audio inputs by index only (no tag needed).
+
+```proto
+# Graph describing calculator SomeAudioVideoCalculator
+node {
+  calculator: "SomeAudioVideoCalculator"
+  input_stream: "combined_input"
+  output_stream: "VIDEO:video_stream"
+  output_stream: "AUDIO:0:audio_left"
+  output_stream: "AUDIO:1:audio_right"
+}
+
+node {
+  calculator: "SomeAudioCalculator"
+  input_stream: "audio_left"
+  input_stream: "audio_right"
+  output_stream: "audio_energy"
+}
+```
+
+In the calculator implementation, inputs and outputs are also identified by tag
+name and index number. In the function below, inputs and outputs are identified:
+
+* By index number: The combined input stream is identified simply by index
+  `0`.
+* By tag name: The video output stream is identified by tag name "VIDEO".
+* By tag name and index number: The output audio streams are identified by the
+  combination of the tag name `AUDIO` and the index numbers `0` and `1`.
+
+```c++
+// c++ Code snippet describing the SomeAudioVideoCalculator GetContract() method
+class SomeAudioVideoCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    cc->Inputs().Index(0).SetAny();
+    // SetAny() is used to specify that whatever the type of the
+    // stream is, it's acceptable. This does not mean that any
+    // packet is acceptable. Packets in the stream still have a
+    // particular type. SetAny() has the same effect as explicitly
+    // setting the type to be the stream's type.
+    cc->Outputs().Tag("VIDEO").Set<ImageFrame>();
+    cc->Outputs().Get("AUDIO", 0).Set<Matrix>();
+    cc->Outputs().Get("AUDIO", 1).Set<Matrix>();
+    return ::mediapipe::OkStatus();
+  }
+```
+
+### Processing
+
+`Process()` called on a non-source node must return `::mediapipe::OkStatus()` to
+indicate that all went well, or any other status code to signal an error.
+
+If a non-source calculator returns `tool::StatusStop()`, then this signals the
+graph is being cancelled early. In this case, all source calculators and graph
+input streams will be closed (and remaining Packets will propagate through the
+graph).
+
+A source node in a graph will continue to have `Process()` called on it as long
+as it returns `::mediapipe::OkStatus()`. To indicate that there is no more data
+to be generated return `tool::StatusStop()`. Any other status indicates an error
+has occurred.
+
+`Close()` returns `::mediapipe::OkStatus()` to indicate success. Any other
+status indicates a failure.
+
+Here is the basic `Process()` function. It uses the `Input()` method (which can
+be used only if the calculator has a single input) to request its input data. It
+then uses `std::unique_ptr` to allocate the memory needed for the output packet,
+and does the calculations. When done it releases the pointer when adding it to
+the output stream.
+
+```c++
+::util::Status MyCalculator::Process() {
+  const Matrix& input = Input()->Get<Matrix>();
+  std::unique_ptr<Matrix> output(new Matrix(input.rows(), input.cols()));
+  // do your magic here....
+  // output->row(n) = ...
+  Output()->Add(output.release(), InputTimestamp());
+  return ::mediapipe::OkStatus();
+}
+```
+
+### GraphConfig
+
+A `GraphConfig` is a specification that describes the topology and functionality
+of a MediaPipe graph. In the specification, a node in the graph represents an
+instance of a particular calculator. All the necessary configurations of the
+node, such as its type, inputs and outputs must be described in the specification.
+Description of the node can also include several optional fields, such as +node-specific options, input policy and executor, discussed in Section +[Framework Concepts > Scheduling mechanics](scheduling_sync.md#scheduling-mechanics). + +`GraphConfig` has several other fields to configure the global graph-level +settings, eg, graph executor configs, number of threads, and maximum queue size +of input streams. Several graph-level settings are useful for tuning the +performance of the graph on different platforms (eg, desktop v.s. mobile). For +instance, on mobile, attaching a heavy model-inference calculator to a separate +executor can improve the performance of a real-time application since this +enables thread locality. + +Below is a trivial `GraphConfig` example where we have series of passthrough +calculators : + +```proto +# This graph named main_pass_throughcals_nosubgraph.pbtxt contains 4 +# passthrough calculators. +input_stream: "in" +node { + calculator: "PassThroughCalculator" + input_stream: "in" + output_stream: "out1" +} +node { + calculator: "PassThroughCalculator" + input_stream: "out1" + output_stream: "out2" +} +node { + calculator: "PassThroughCalculator" + input_stream: "out2" + output_stream: "out3" +} +node { + calculator: "PassThroughCalculator" + input_stream: "out3" + output_stream: "out4" +} +``` + +### Subgraph + +To modularize a `CalculatorGraphConfig` into sub-modules and assist with re-use +of perception solutions, a MediaPipe graph can be defined as a `Subgraph`. The +public interface to a subgraph consists of a set of input and output streams +similar to the public interface of a calculator. The subgraph can then be +included in an `CalculatorGraphConfig` as if it were a calculator. When a +MediaPipe graph is loaded from a `CalculatorGraphConfig`, each subgraph node is +replaced by the corresponding graph of calculators. As a result, the semantics +and performance of the subgraph is identical to the corresponding graph of +calculators. 
+
+Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
+
+1. Defining the subgraph.
+
+    ```proto
+    # This subgraph is defined in twopassthrough_subgraph.pbtxt
+    # that is registered in the BUILD file as "TwoPassThroughSubgraph"
+    input_stream: "out1"
+    output_stream: "out3"
+
+    node {
+      calculator: "PassThroughCalculator"
+      input_stream: "out1"
+      output_stream: "out2"
+    }
+    node {
+      calculator: "PassThroughCalculator"
+      input_stream: "out2"
+      output_stream: "out3"
+    }
+    ```
+
+The public interface to the graph consists of:
+  * Graph input streams
+  * Graph output streams
+  * Graph input side packets
+  * Graph output side packets
+
+2. Register the subgraph using BUILD rule `mediapipe_simple_subgraph`
+    * The parameter `register_as` defines the component name for the new subgraph
+
+    ```proto
+    # Small section of BUILD file for registering the "TwoPassThroughSubgraph"
+    # subgraph for use by main graph main_pass_throughcals.pbtxt
+    #
+    mediapipe_simple_subgraph(
+        name = "twopassthrough_subgraph",
+        graph = "twopassthrough_subgraph.pbtxt",
+        register_as = "TwoPassThroughSubgraph",
+        deps = [
+            "//mediapipe/calculators/core:pass_through_calculator",
+            "//mediapipe/framework:calculator_graph",
+        ],
+    )
+    ```
+
+3.
Use the subgraph in the main graph + + ```proto + # This main graph is defined in main_pass_throughcals.pbtxt + # using subgraph called "TwoPassThroughSubgraph" + # + input_stream: "in" + node { + calculator: "PassThroughCalculator" + input_stream: "in" + output_stream: "out1" + } + node { + calculator: "TwoPassThroughSubgraph" + input_stream: "out1" + output_stream: "out3" + } + node { + calculator: "PassThroughCalculator" + input_stream: "out3" + output_stream: "out4" + } + ``` + + diff --git a/mediapipe/docs/gpu.md b/mediapipe/docs/gpu.md new file mode 100644 index 000000000..6c2b07cc1 --- /dev/null +++ b/mediapipe/docs/gpu.md @@ -0,0 +1,130 @@ +## Running on GPUs + +- [Overview](#overview) +- [OpenGL Support](#graphconfig) +- [Life of a GPU calculator](#life-of-a-gpu-calculator) +- [GpuBuffer to ImageFrame converters](#gpubuffer-to-imageframe-converters) + + +### Overview +MediaPipe supports calculator nodes for GPU compute and rendering, and allows combining multiple GPU nodes, as well as mixing them with CPU based calculator nodes. There exist several GPU APIs on mobile platforms (eg, OpenGL ES, Metal and Vulkan). MediaPipe does not attempt to offer a single cross-API GPU abstraction. Individual nodes can be written using different APIs, allowing them to take advantage of platform specific features when needed. + +GPU support is essential for good performance on mobile platforms, especially for real-time video. MediaPipe enables developers to write GPU compatible calculators that support the use of GPU for: + + * On-device real-time processing, not just batch processing + * Video rendering and effects, not just analysis + +Below are the design principles for GPU support in MediaPipe + + * GPU-based calculators should be able to occur anywhere in the graph, and not necessarily be used for on-screen rendering. + * Transfer of frame data from one GPU-based calculator to another should be fast, and not incur expensive copy operations. 
+ * Transfer of frame data between CPU and GPU should be as efficient as the platform allows. + * Because different platforms may require different techniques for best performance, the API should allow flexibility in the way things are implemented behind the scenes. + * A calculator should be allowed maximum flexibility in using the GPU for all or part of its operation, combining it with the CPU if necessary. + +### OpenGL support +MediaPipe supports OpenGL ES up to version 3.2 on Android and up to ES 3.0 on iOS. In addition, MediaPipe also supports Metal on iOS. + + * MediaPipe allows graphs to run OpenGL in multiple GL contexts. For example, this can be very useful in graphs that combine a slower GPU inference path (eg, at 10 FPS) with a faster GPU rendering path (eg, at 30 FPS): since one GL context corresponds to one sequential command queue, using the same context for both tasks would reduce the rendering frame rate. One challenge MediaPipe's use of multiple contexts solves is the ability to communicate across them. An example scenario is one with an input video that is sent to both the rendering and inferences paths, and rendering needs to have access to the latest output from inference. + + * An OpenGL context cannot be accessed by multiple threads at the same time. Furthermore, switching the active GL context on the same thread can be slow on some Android devices. Therefore, our approach is to have one dedicated thread per context. Each thread issues GL commands, building up a serial command queue on its context, which is then executed by the GPU asynchronously. + +### Life of a GPU calculator + +This section presents the basic structure of the Process method of a GPU +calculator derived from base class GlSimpleCalculator. The GPU calculator +`LuminanceCalculator` is shown as an example. The method +`LuminanceCalculator::GlRender` is called from `GlSimpleCalculator::Process`. + +``` +// Converts RGB images into luminance images, still stored in RGB format. 
+// See GlSimpleCalculator for inputs, outputs and input side packets. +class LuminanceCalculator : public GlSimpleCalculator { + public: + ::mediapipe::Status GlSetup() override; + ::mediapipe::Status GlRender(const GlTexture& src, + const GlTexture& dst) override; + ::mediapipe::Status GlTeardown() override; + + private: + GLuint program_ = 0; + GLint frame_; +}; +REGISTER_CALCULATOR(LuminanceCalculator); + +::mediapipe::Status LuminanceCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // program + glUseProgram(program_); + glUniform1i(frame_, 1); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); + + return ::mediapipe::OkStatus(); +} +``` + +The design principles mentioned above have resulted in the following design +choices 
for MediaPipe GPU support:
+
+  * We have a GPU data type, called `GpuBuffer`, for representing image data, optimized for GPU usage. The exact contents of this data type are opaque and platform-specific.
+  * A low-level API based on composition, where any calculator that wants to make use of the GPU creates and owns an instance of the `GlCalculatorHelper` class. This class offers a platform-agnostic API for managing the OpenGL context, setting up textures for inputs and outputs, etc.
+  * A high-level API based on subclassing, where simple calculators implementing image filters subclass from `GlSimpleCalculator` and only need to override a couple of virtual methods with their specific OpenGL code, while the superclass takes care of all the plumbing.
+  * Data that needs to be shared between all GPU-based calculators is provided as an external input that is implemented as a graph service and is managed by the `GlCalculatorHelper` class.
+  * The combination of calculator-specific helpers and a shared graph service allows us great flexibility in managing the GPU resource: we can have a separate context per calculator, share a single context, share a lock or other synchronization primitives, etc. -- and all of this is managed by the helper and hidden from the individual calculators.
+
+### GpuBuffer to ImageFrame converters
+
+We provide two calculators called `GpuBufferToImageFrameCalculator` and `ImageFrameToGpuBufferCalculator`. These calculators convert between `ImageFrame` and `GpuBuffer`, allowing the construction of graphs that combine GPU and CPU calculators. They are supported on both iOS and Android.
+
+When possible, these calculators use platform-specific functionality to share data between the CPU and the GPU without copying.
+
+The below diagram shows the data flow in a mobile application that captures video from the camera, runs it through a MediaPipe graph, and renders the output on the screen in real time.
The dashed line indicates which parts are inside the MediaPipe graph proper. This application runs a Canny edge-detection filter on the CPU using OpenCV, and overlays it on top of the original video using the GPU. + +| ![How GPU calculators interact](images/gpu_example_graph.png) | +|:--:| +| *Video frames from the camera are fed into the graph as `GpuBuffer` packets. The input stream is accessed by two calculators in parallel. `GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`, which is then sent through a grayscale converter and a canny filter (both based on OpenCV and running on the CPU), whose output is then converted into a `GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, takes as input both the original `GpuBuffer` and the one coming out of the edge detector, and overlays them using a shader. The output is then sent back to the application using a callback calculator, and the application renders the image to the screen using OpenGL.* | + diff --git a/mediapipe/docs/hair_segmentation_android_gpu.md b/mediapipe/docs/hair_segmentation_android_gpu.md new file mode 100644 index 000000000..a34145f50 --- /dev/null +++ b/mediapipe/docs/hair_segmentation_android_gpu.md @@ -0,0 +1,194 @@ +# Hair Segmentation on Android + +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. This +doc focuses on the +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hair_segmentation/hair_segmentation_android_gpu.pbtxt) +that performs hair segmentation with TensorFlow Lite on GPU. + +![hair_segmentation_android_gpu_gif](images/mobile/hair_segmentation_android_gpu.gif){width="300"} + +## App + +The graph is used in the +[Hair Segmentation GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu) +example app. 
To build the app, run: + +```bash +bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu +``` + +To further install the app on android device, run: + +```bash +adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/hairsegmentationgpu.apk +``` + +## Graph + +![hair_segmentation_android_gpu_graph](images/mobile/hair_segmentation_android_gpu.png){width="600"} + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into [MediaPipe Visualizer](https://mediapipe-viz.appspot.com/). + +```bash +# MediaPipe graph that performs hair segmentation with TensorFlow Lite on GPU. +# Used in the example in +# mediapipie/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu. + +# Images on GPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToSegmentationCalculator downstream in the graph to finish +# generating the corresponding hair mask before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToSegmentationCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs. 
+node { + calculator: "RealTimeFlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:hair_mask" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transforms the input image on GPU to a 512x512 image. To scale the image, by +# default it uses the STRETCH scale mode that maps the entire input image to the +# entire transformed image. As a result, image aspect ratio may be changed and +# objects in the image may be deformed (stretched or squeezed), but the hair +# segmentation model used in this graph is agnostic to that deformation. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:throttled_input_video" + output_stream: "IMAGE_GPU:transformed_input_video" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 512 + output_height: 512 + } + } +} + +# Waits for a mask from the previous round of hair segmentation to be fed back +# as an input, and caches it. Upon the arrival of an input image, it checks if +# there is a mask cached, and sends out the mask with the timestamp replaced by +# that of the input image. This is needed so that the "current image" and the +# "previous mask" share the same timestamp, and as a result can be synchronized +# and combined in the subsequent calculator. Note that upon the arrival of the +# very first input frame, an empty packet is sent out to jump start the feedback +# loop. +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:throttled_input_video" + input_stream: "LOOP:hair_mask" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:previous_hair_mask" +} + +# Embeds the hair mask generated from the previous round of hair segmentation +# as the alpha channel of the current input image. 
+node { + calculator: "SetAlphaCalculator" + input_stream: "IMAGE_GPU:transformed_input_video" + input_stream: "ALPHA_GPU:previous_hair_mask" + output_stream: "IMAGE_GPU:mask_embedded_input_video" +} + +# Converts the transformed input image on GPU into an image tensor stored in +# tflite::gpu::GlBuffer. The zero_center option is set to false to normalize the +# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. The flip_vertically +# option is set to true to account for the descrepancy between the +# representation of the input image (origin at the bottom-left corner, the +# OpenGL convention) and what the model used in this graph is expecting (origin +# at the top-left corner). With the max_num_channels option set to 4, all 4 RGBA +# channels are contained in the image tensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:mask_embedded_input_video" + output_stream: "TENSORS_GPU:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { + zero_center: false + flip_vertically: true + max_num_channels: 4 + } + } +} + +# Generates a single side packet containing a TensorFlow Lite op resolver that +# supports custom ops needed by the model used in this graph. +node { + calculator: "TfLiteCustomOpResolverCalculator" + output_side_packet: "op_resolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] { + use_gpu: true + } + } +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# tensor representing the hair segmentation, which has the same width and height +# as the input image tensor. 
+node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS_GPU:segmentation_tensor" + input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "hair_segmentation.tflite" + use_gpu: true + } + } +} + +# Decodes the segmentation tensor generated by the TensorFlow Lite model into a +# mask of values in [0.f, 1.f], stored in the R channel of a GPU buffer. It also +# takes the mask generated previously as another input to improve the temporal +# consistency. +node { + calculator: "TfLiteTensorsToSegmentationCalculator" + input_stream: "TENSORS_GPU:segmentation_tensor" + input_stream: "PREV_MASK_GPU:previous_hair_mask" + output_stream: "MASK_GPU:hair_mask" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] { + tensor_width: 512 + tensor_height: 512 + tensor_channels: 2 + combine_with_previous_ratio: 0.9 + output_layer_index: 1 + } + } +} + +# Colors the hair segmentation with the color specified in the option. +node { + calculator: "RecolorCalculator" + input_stream: "IMAGE_GPU:throttled_input_video" + input_stream: "MASK_GPU:hair_mask" + output_stream: "IMAGE_GPU:output_video" + node_options: { + [type.googleapis.com/mediapipe.RecolorCalculatorOptions] { + color { r: 0 g: 0 b: 255 } + mask_channel: RED + } + } +} +``` diff --git a/mediapipe/docs/hello_world_android.md b/mediapipe/docs/hello_world_android.md new file mode 100644 index 000000000..6b0f7c69a --- /dev/null +++ b/mediapipe/docs/hello_world_android.md @@ -0,0 +1,728 @@ +# Hello World! in MediaPipe on Android + +## Introduction + +This codelab uses MediaPipe on an Android device. + +### What you will learn + +How to develop an Android application that uses MediaPipe and run a MediaPipe +graph on Android. 
+ +### What you will build + +A simple camera app for real-time Sobel edge detection applied to a live video +stream on an Android device. + +![edge_detection_android_gpu_gif](images/mobile/edge_detection_android_gpu.gif){width="300"} + +## Setup + +1. Install MediaPipe on your system, see [MediaPipe installation guide] for + details. +2. Install Android Development SDK and Android NDK. See how to do so in + [Setting up Android SDK and NDK]. +3. Enable [developer options] on your Android device. +4. Setup [Bazel] on your system to build and deploy the Android app. + +## Graph for edge detection + +We will be using the following graph, [`edge_detection_android_gpu.pbtxt`]: + +``` +input_stream: "input_video" +output_stream: "output_video" + +node: { + calculator: "LuminanceCalculator" + input_stream: "input_video" + output_stream: "luma_video" +} + +node: { + calculator: "SobelEdgesCalculator" + input_stream: "luma_video" + output_stream: "output_video" +} +``` + +A visualization of the graph is shown below: + +![edge_detection_android_gpu_graph](images/mobile/edge_detection_android_graph_gpu.png){width="200"} + +This graph has a single input stream named `input_video` for all incoming frames +that will be provided by your device's camera. + +The first node in the graph, `LuminanceCalculator`, takes a single packet (image +frame) and applies a change in luminance using an OpenGL shader. The resulting +image frame is sent to the `luma_video` output stream. + +The second node, `SobelEdgesCalculator` applies edge detection to incoming +packets in the `luma_video` stream and outputs results in `output_video` output +stream. + +Our Android application will display the output image frames of the +`sobel_video` stream. + +## Initial minimal application setup + +We first start with an simple Android application that displays "Hello World!" +on the screen. You may skip this step if you are familiar with building Android +applications using `bazel`. 
+ +Create a new directory where you will create your Android application. For +example, the complete code of this tutorial can be found at +`mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu`. +We will refer to this path as `$APPLICATION_PATH` throughout the codelab. + +Note that in the path to the application: + +* The application is named `edgedetectiongpu`. +* The `$PACKAGE_PATH` of the application is + `com.google.mediapipe.apps.edgdetectiongpu`. This is used in code snippets in + this tutorial, so please remember to use your own `$PACKAGE_PATH` when you + copy/use the code snippets. + +Add a file `activity_main.xml` to `$APPLICATION_PATH/res/layout`. This displays +a [`TextView`] on the full screen of the application with the string `Hello +World!`: + +``` + + + + + + +``` + +Add a simple `MainActivity.java` to `$APPLICATION_PATH` which loads the content +of the `activity_main.xml` layout as shown below: + +``` +package com.google.mediapipe.apps.edgedetectiongpu; + +import android.os.Bundle; +import androidx.appcompat.app.AppCompatActivity; + +/** Bare-bones main activity. */ +public class MainActivity extends AppCompatActivity { + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + } +} +``` + +Add a manifest file, `AndroidManifest.xml` to `$APPLICATION_PATH`, which +launches `MainActivity` on application start: + +``` + + + + + + + + + + + + + + + +``` + +To get `@string/app_name`, we need to add a file `strings.xml` to +`$APPLICATION_PATH/res/values/`: + +``` + + Edge Detection GPU + +``` + +Also, in our application we are using a `Theme.AppCompat` theme in the app, so +we need appropriate theme references. 
Add `colors.xml` to +`$APPLICATION_PATH/res/values/`: + +``` + + + #008577 + #00574B + #D81B60 + +``` + +Add `styles.xml` to `$APPLICATION_PATH/res/values/`: + +``` + + + + + + +``` + +To build the application, add a `BUILD` file to `$APPLICATION_PATH`: + +``` +android_library( + name = "mediapipe_lib", + srcs = glob(["*.java"]), + manifest = "AndroidManifest.xml", + resource_files = glob(["res/**"]), + deps = [ + "//third_party:android_constraint_layout", + "//third_party:androidx_appcompat", + ], +) + +android_binary( + name = "edgedetectiongpu", + aapt_version = "aapt2", + manifest = "AndroidManifest.xml", + manifest_values = {"applicationId": "com.google.mediapipe.apps.edgedetectiongpu"}, + multidex = "native", + deps = [ + ":mediapipe_lib", + ], +) + +``` + +The `android_library` rule adds dependencies for `MainActivity`, resource files +and `AndroidManifest.xml`. + +The `android_binary` rule, uses the `mediapipe_lib` Android library generated to +build a binary APK for installation on your Android device. + +To build the app, use the following command: + +``` +bazel build -c opt --config=android_arm64 $APPLICATION_PATH +``` + +Install the generated APK file using `adb install`. For example: + +``` +adb install bazel-bin/$APPLICATION_PATH/edgedetectiongpu.apk +``` + +Open the application on your device. It should display a screen with the text +`Hello World!`. + +![bazel_hello_world_android](images/mobile/bazel_hello_world_android.png){width="300"} + +## Using the camera via `CameraX` + +### Camera Permissions + +To use the camera in our application, we need to request the user to provide +access to the camera. To request camera permissions, add the following to +`AndroidManifest.xml`: + +``` + + + +``` + +Change the minimum SDK version to `21` and target SDK version to `27` in the +same file: + +``` + +``` + +This ensures that the user is prompted to request camera permission and enables +us to use the [CameraX] library for camera access. 
+ +To request camera permissions, we can use a utility provided by MediaPipe +components, namely [`PermissionHelper`]. To use it, add a dependency +`"//mediapipe/java/com/google/mediapipe/components:android_components"` in the +`mediapipe_lib` rule in `BUILD`. + +To use the `PermissionHelper` in `MainActivity`, add the following line to the +`onCreate` function: + +``` +PermissionHelper.checkAndRequestCameraPermissions(this); +``` + +This prompts the user with a dialog on the screen to request for permissions to +use the camera in this application. + +Add the following code to handle the user response: + +``` +@Override +public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults); +} + +@Override +protected void onResume() { + super.onResume(); + if (PermissionHelper.cameraPermissionsGranted(this)) { + startCamera(); + } +} + +public void startCamera() {} +``` + +We will leave the `startCamera()` method empty for now. When the user responds +to the prompt, the `MainActivity` will resume and `onResume()` will be called. +The code will confirm that permissions for using the camera have been granted, +and then will start the camera. + +Rebuild and install the application. You should now see a prompt requesting +access to the camera for the application. + +Note: If the there is no dialog prompt, uninstall and reinstall the application. +This may also happen if you haven't changed the `minSdkVersion` and +`targetSdkVersion` in the `AndroidManifest.xml` file. + +### Camera Access + +With camera permissions available, we can start and fetch frames from the +camera. + +To view the frames from the camera we will use a [`SurfaceView`]. Each frame +from the camera will be stored in a [`SurfaceTexture`] object. 
To use these, we +first need to change the layout of our application. + +Remove the entire [`TextView`] code block from +`$APPLICATION_PATH/res/layout/activity_main.xml` and add the following code +instead: + +``` + + + +``` + +This code block has a new [`FrameLayout`] named `preview_display_layout` and a +[`TextView`] nested inside it, named `no_camera_access_preview`. When camera +access permissions are not granted, our application will display the +[`TextView`] with a string message, stored in the variable `no_camera_access`. +Add the following line in the `$APPLICATION_PATH/res/values/strings.xml` file: + +``` +Please grant camera permissions. +``` + +When the user doesn't grant camera permission, the screen will now look like +this: + +![missing_camera_permission_android](images/mobile/missing_camera_permission_android.png){width="300"} + +Now, we will add the [`SurfaceTexture`] and [`SurfaceView`] objects to +`MainActivity`: + +``` +private SurfaceTexture previewFrameTexture; +private SurfaceView previewDisplayView; +``` + +In the `onCreate(Bundle)` function, add the following two lines _before_ +requesting camera permissions: + +``` +previewDisplayView = new SurfaceView(this); +setupPreviewDisplayView(); +``` + +And now add the code defining `setupPreviewDisplayView()`: + +``` +private void setupPreviewDisplayView() { + previewDisplayView.setVisibility(View.GONE); + ViewGroup viewGroup = findViewById(R.id.preview_display_layout); + viewGroup.addView(previewDisplayView); +} +``` + +We define a new [`SurfaceView`] object and add it to the +`preview_display_layout` [`FrameLayout`] object so that we can use it to display +the camera frames using a [`SurfaceTexture`] object named `previewFrameTexture`. + +To use `previewFrameTexture` for getting camera frames, we will use [CameraX]. +MediaPipe provides a utility named [`CameraXPreviewHelper`] to use [CameraX]. +This class updates a listener when camera is started via +`onCameraStarted(@Nullable SurfaceTexture)`. 
+ +To use this utility, modify the `BUILD` file to add a dependency on +`"//mediapipe/java/com/google/mediapipe/components:android_camerax_helper"`. + +Now import [`CameraXPreviewHelper`] and add the following line to +`MainActivity`: + +``` +private CameraXPreviewHelper cameraHelper; +``` + +Now, we can add our implementation to `startCamera()`: + +``` +public void startCamera() { + cameraHelper = new CameraXPreviewHelper(); + cameraHelper.setOnCameraStartedListener( + surfaceTexture -> { + previewFrameTexture = surfaceTexture; + // Make the display view visible to start showing the preview. + previewDisplayView.setVisibility(View.VISIBLE); + }); +} +``` + +This creates a new [`CameraXPreviewHelper`] object and adds an anonymous +listener on the object. When `cameraHelper` signals that the camera has started +and a `surfaceTexture` to grab frames is available, we save that +`surfaceTexture` as `previewFrameTexture`, and make the `previewDisplayView` +visible so that we can start seeing frames from the `previewFrameTexture`. + +However, before starting the camera, we need to decide which camera we want to +use. [`CameraXPreviewHelper`] inherits from [`CameraHelper`] which provides two +options, `FRONT` and `BACK`. We will use `BACK` camera for this application to +perform edge detection on a live scene that we view from the camera. + +Add the following line to define `CAMERA_FACING` for our application, + +``` +private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.BACK; +``` + +`CAMERA_FACING` is a static variable as we will use the same camera throughout +the application from start to finish. + +Now add the following line at the end of the `startCamera()` function: + +``` +cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null); +``` + +At this point, the application should build successfully. 
However, when you run +the application on your device, you will see a black screen (even though camera +permissions have been granted). This is because even though we save the +`surfaceTexture` variable provided by the [`CameraXPreviewHelper`], the +`previewSurfaceView` doesn't use its output and display it on screen yet. + +Since we want to use the frames in a MediaPipe graph, we will not add code to +view the camera output directly in this tutorial. Instead, we skip ahead to how +we can send camera frames for processing to a MediaPipe graph and display the +output of the graph on the screen. + +## `ExternalTextureConverter` setup + +A [`SurfaceTexture`] captures image frames from a stream as an OpenGL ES +texture. To use a MediaPipe graph, frames captured from the camera should be +stored in a regular Open GL texture object. MediaPipe provides a class, +[`ExternalTextureConverter`] to convert the image stored in a [`SurfaceTexture`] +object to a regular OpenGL texture object. + +To use [`ExternalTextureConverter`], we also need an `EGLContext`, which is +created and managed by an [`EglManager`] object. Add a dependency to the `BUILD` +file to use [`EglManager`], `"//mediapipe/java/com/google/mediapipe/glutil"`. + +In `MainActivity`, add the following declarations: + +``` +private EglManager eglManager; +private ExternalTextureConverter converter; +``` + +In the `onCreate(Bundle)` function, add a statement to initialize the +`eglManager` object before requesting camera permissions: + +``` +eglManager = new EglManager(null); +``` + +Recall that we defined the `onResume()` function in `MainActivity` to confirm +camera permissions have been granted and call `startCamera()`. Before this +check, add the following line in `onResume()` to initialize the `converter` +object: + +``` +converter = new ExternalTextureConverter(eglManager.getContext()); +``` + +This `converter` now uses the `GLContext` managed by `eglManager`. 
+ +We also need to override the `onPause()` function in the `MainActivity` so that +if the application goes into a paused state, we close the `converter` properly: + +``` +@Override +protected void onPause() { + super.onPause(); + converter.close(); +} +``` + +To pipe the output of `previewFrameTexture` to the `converter`, add the +following block of code to `setupPreviewDisplayView()`: + +``` +previewDisplayView + .getHolder() + .addCallback( + new SurfaceHolder.Callback() { + @Override + public void surfaceCreated(SurfaceHolder holder) {} + + @Override + public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) { + // (Re-)Compute the ideal size of the camera-preview display (the area that the + // camera-preview frames get rendered onto, potentially with scaling and rotation) + // based on the size of the SurfaceView that contains the display. + Size viewSize = new Size(width, height); + Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize); + + // Connect the converter to the camera-preview frames as its input (via + // previewFrameTexture), and configure the output width and height as the computed + // display size. + converter.setSurfaceTextureAndAttachToGLContext( + previewFrameTexture, displaySize.getWidth(), displaySize.getHeight()); + } + + @Override + public void surfaceDestroyed(SurfaceHolder holder) {} + }); +``` + +In this code block, we add a custom [`SurfaceHolder.Callback`] to +`previewDisplayView` and implement the `surfaceChanged(SurfaceHolder holder, int +format, int width, int height)` function to compute an appropriate display size +of the camera frames on the device screen and to tie the `previewFrameTexture` +object and send frames of the computed `displaySize` to the `converter`. + +We are now ready to use camera frames in a MediaPipe graph. 
+ +## Using a MediaPipe graph in Android + +### Add relevant dependencies + +To use a MediaPipe graph, we need to add dependencies to the MediaPipe framework +on Android. We will first add a build rule to build a `cc_binary` using JNI code +of the MediaPipe framework and then build a `cc_library` rule to use this binary +in our application. Add the following code block to your `BUILD` file: + +``` +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) +``` + +Add the dependency `":mediapipe_jni_lib"` to the `mediapipe_lib` build rule in +the `BUILD` file. + +Next, we need to add dependencies specific to the MediaPipe graph we want to use +in the application. + +First, add dependencies to all calculator code in the `libmediapipe_jni.so` +build rule: + +``` +"//mediapipe/graphs/edge_detection:android_calculators", +``` + +MediaPipe graphs are `.pbtxt` files, but to use them in the application, we need +to use the `mediapipe_binary_graph` build rule to generate a `.binarypb` file. +We can then use an application specific alias for the graph via the `genrule` +build rule. Add the following `genrule` to use an alias for the edge detection +graph: + +``` +genrule( + name = "binary_graph", + srcs = ["//mediapipe/graphs/edge_detection:android_gpu_binary_graph"], + outs = ["edgedetectiongpu.binarypb"], + cmd = "cp $< $@", +) +``` + +Then in the `mediapipe_lib` build rule, add assets: + +``` +assets = [ + ":binary_graph", +], +assets_dir = "", +``` + +In the `assets` build rule, you can also add other assets such as TensorFlowLite +models used in your graph. + +Now, the `MainActivity` needs to load the MediaPipe framework. Also, the +framework uses OpenCV, so `MainActvity` should also load `OpenCV`. 
Use the +following code in `MainActivity` (inside the class, but not inside any function) +to load both dependencies: + +``` +static { + // Load all native libraries needed by the app. + System.loadLibrary("mediapipe_jni"); + System.loadLibrary("opencv_java4"); +} +``` + +### Use the graph in `MainActivity` + +First, we need to load the asset which contains the `.binarypb` compiled from +the `.pbtxt` file of the graph. To do this, we can use a MediaPipe utility, +[`AndroidAssetUtil`]. + +Initialize the asset manager in `onCreate(Bundle)` before initializing +`eglManager`: + +``` +// Initilize asset manager so that MediaPipe native libraries can access the app assets, e.g., +// binary graphs. +AndroidAssetUtil.initializeNativeAssetManager(this); +``` + +Declare a static variable with the graph name, the name of the input stream and +the name of the output stream: + +``` +private static final String BINARY_GRAPH_NAME = "edgedetectiongpu.binarypb"; +private static final String INPUT_VIDEO_STREAM_NAME = "input_video"; +private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; +``` + +Now, we need to setup a [`FrameProcessor`] object that sends camera frames +prepared by the `converter` to the MediaPipe graph and runs the graph, prepares +the output and then updates the `previewDisplayView` to display the output. Add +the following code to declare the `FrameProcessor`: + +``` +private FrameProcessor processor; +``` + +and initialize it in `onCreate(Bundle)` after initializing `eglManager`: + +``` +processor = + new FrameProcessor( + this, + eglManager.getNativeContext(), + BINARY_GRAPH_NAME, + INPUT_VIDEO_STREAM_NAME, + OUTPUT_VIDEO_STREAM_NAME); +``` + +The `processor` needs to consume the converted frames from the `converter` for +processing. 
Add the following line to `onResume()` after initializing the +`converter`: + +``` +converter.setConsumer(processor); +``` + +The `processor` should send its output to `previewDisplayView` To do this, add +the following function definitions to our custom [`SurfaceHolder.Callback`]: + +``` +@Override +public void surfaceCreated(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(holder.getSurface()); +} + +@Override +public void surfaceDestroyed(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(null); +} +``` + +When the `SurfaceHolder` is created, we had the `Surface` to the +`VideoSurfaceOutput` of the `processor`. When it is destroyed, we remove it from +the `VideoSurfaceOutput` of the `processor`. + +And that's it! You should now be able to successfully build and run the +application on the device and see Sobel edge detection running on a live camera +feed! Congrats! + +![edge_detection_android_gpu_gif](images/mobile/edge_detection_android_gpu.gif){width="300"} + +If you ran into any issues, please see the full code of the tutorial +[here](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu). 
+ +[`AndroidAssetUtil`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/framework/AndroidAssetUtil.java +[Bazel]:https://bazel.build/ +[`CameraHelper`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/CameraHelper.java +[CameraX]:https://developer.android.com/training/camerax +[`CameraXPreviewHelper`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java +[developer options]:https://developer.android.com/studio/debug/dev-options +[`edge_detection_android_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_android_gpu.pbtxt +[`EdgeDetectionGPU` example]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/ +[`EglManager`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/glutil/EglManager.java +[`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java +[`FrameLayout`]:https://developer.android.com/reference/android/widget/FrameLayout +[`FrameProcessor`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java +[MediaPipe installation guide]:./install.md +[`PermissionHelper`]: https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java +[Setting up Android SDK and NDK]:./install.md#setting-up-android-sdk-and-ndk +[`SurfaceHolder.Callback`]:https://developer.android.com/reference/android/view/SurfaceHolder.Callback.html +[`SurfaceView`]:https://developer.android.com/reference/android/view/SurfaceView +[`SurfaceView`]:https://developer.android.com/reference/android/view/SurfaceView 
+[`SurfaceTexture`]:https://developer.android.com/reference/android/graphics/SurfaceTexture +[`TextView`]:https://developer.android.com/reference/android/widget/TextView diff --git a/mediapipe/docs/hello_world_desktop.md b/mediapipe/docs/hello_world_desktop.md new file mode 100644 index 000000000..366fa3b17 --- /dev/null +++ b/mediapipe/docs/hello_world_desktop.md @@ -0,0 +1,115 @@ +## Hello World for C++ + +1. Ensure you have a working version of MediaPipe. See + [installation instructions](./install.md). + +2. To run the [`hello world`] example: + + ```bash + $ git clone https://github.com/google/mediapipe/mediapipe.git + $ cd mediapipe + + # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is not supported currently. + $ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \ + mediapipe/examples/desktop/hello_world:hello_world + + # It should print 10 rows of Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + ``` + +3. The [`hello world`] example uses a simple MediaPipe graph in the + `PrintHelloWorld()` function, defined in a [`CalculatorGraphConfig`] proto. + + ```C++ + ::mediapipe::Status PrintHelloWorld() { + // Configures a simple graph, which concatenates 2 PassThroughCalculators. + CalculatorGraphConfig config = ParseTextProtoOrDie(R"( + input_stream: "in" + output_stream: "out" + node { + calculator: "PassThroughCalculator" + input_stream: "in" + output_stream: "out1" + } + node { + calculator: "PassThroughCalculator" + input_stream: "out1" + output_stream: "out" + } + )"); + ``` + + You can visualize this graph using + [MediaPipe Visualizer](https://mediapipe-viz.appspot.com) by pasting the + CalculatorGraphConfig content below into the visualizer. See + [here](./visualizer.md) for help on the visualizer. 
+ + ```bash + input_stream: "in" + output_stream: "out" + node { + calculator: "PassThroughCalculator" + input_stream: "in" + output_stream: "out1" + } + node { + calculator: "PassThroughCalculator" + input_stream: "out1" + output_stream: "out" + } + ``` + + This graph consists of 1 graph input stream (`in`) and 1 graph output stream + (`out`), and 2 [`PassThroughCalculator`]s connected serially. + + ![hello_world.cc graph](./images/hello_world_graph.png){width="200"} + +4. Before running the graph, an `OutputStreamPoller` object is connected to the + output stream in order to later retrieve the graph output, and a graph run + is started with [`StartRun`]. + + ```c++ + CalculatorGraph graph; + RETURN_IF_ERROR(graph.Initialize(config)); + ASSIGN_OR_RETURN(OutputStreamPoller poller, + graph.AddOutputStreamPoller("out")); + RETURN_IF_ERROR(graph.StartRun({})); + ``` + +5. The example then creates 10 packets (each packet contains a string "Hello + World!" with Timestamp values ranging from 0, 1, ... 9) using the + [`MakePacket`] function, adds each packet into the graph through the `in` + input stream, and finally closes the input stream to finish the graph run. + + ```c++ + for (int i = 0; i < 10; ++i) { + RETURN_IF_ERROR(graph.AddPacketToInputStream("in", MakePacket("Hello World!").At(Timestamp(i)))); + } + RETURN_IF_ERROR(graph.CloseInputStream("in")); + ``` + +6. Through the `OutputStreamPoller` object the example then retrieves all 10 + packets from the output stream, gets the string content out of each packet + and prints it to the output log. 
+ + ```c++ + mediapipe::Packet packet; + while (poller.Next(&packet)) { + LOG(INFO) << packet.Get(); + } + ``` + +[`hello world`]: https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/hello_world/hello_world.cc +[`CalculatorGraphConfig`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto +[`PassThroughCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/pass_through_calculator.cc +[`MakePacket`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/packet.h +[`StartRun`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_graph.h diff --git a/mediapipe/docs/help.md b/mediapipe/docs/help.md new file mode 100644 index 000000000..3e104b25d --- /dev/null +++ b/mediapipe/docs/help.md @@ -0,0 +1,19 @@ +## Getting help + +- [Technical questions](#technical-questions) +- [Bugs and Feature requests](#bugs-and-feature-requests) + +Below are the various ways to get help + +### Technical questions + +For help with technical or algorithmic questions, visit +[Stack Overflow](https://stackoverflow.com/questions/tagged/mediapipe) to find +answers and support from the MediaPipe community. + +### Bugs and Feature requests + +To report bugs or make feature requests, +[file an issue on Github](https://github.com/google/mediapipe/mediapipe/issues). 
+Please choose the appropriate repository for the project from the +[MediaPipe repo](https://github.com/google/mediapipe/mediapipe) diff --git a/mediapipe/docs/how_to_questions.md b/mediapipe/docs/how_to_questions.md new file mode 100644 index 000000000..d05fc1f68 --- /dev/null +++ b/mediapipe/docs/how_to_questions.md @@ -0,0 +1,143 @@ +## Questions and Answers + +- [How to convert ImageFrames and GpuBuffers](#how-to-convert-imageframes-and-gpubuffers) +- [How to visualize perceived results](#how-to-visualize-perception-results) +- [How to run calculators in parallel](#how-to-run-calculators-in-parallel) +- [Output timestamps when using ImmediateInputStreamHandler](#output-timestamps-when-using-immediateinputstreamhandler) +- [How to change settings at runtime](#how-to-change-settings-at-runtime) +- [How to process real-time input streams](#how-to-process-real-time-input-streams) +- [Can I run MediaPipe on MS Windows?](#can-i-run-mediapipe-on-ms-windows) + +### How to convert ImageFrames and GpuBuffers + +The Calculators [`ImageFrameToGpuBufferCalculator`] and +[`GpuBufferToImageFrameCalculator`] convert back and forth between packets of +type [`ImageFrame`] and [`GpuBuffer`]. [`ImageFrame`] refers to image data in +CPU memory in any of a number of bitmap image formats. [`GpuBuffer`] refers to +image data in GPU memory. You can find more detail in the Framework Concepts +section +[GpuBuffer to ImageFrame converters](./gpu.md). +You can see an example in: + + * [`object_detection_android_cpu.pbtxt`] + +### How to visualize perception results + +The [`AnnotationOverlayCalculator`] allows perception results, such as boudning +boxes, arrows, and ovals, to be superimposed on the video frames aligned with +the recognized objects. The results can be displayed in a diagnostic window when +running on a workstation, or in a texture frame when running on device. You can +see an example use of [`AnnotationOverlayCalculator`] in: + + * [`face_detection_android_gpu.pbtxt`]. 
+ +### How to run calculators in parallel + +Within a calculator graph, MediaPipe routinely runs separate calculator nodes +in parallel. MediaPipe maintains a pool of threads, and runs each calculator +as soon as a thread is available and all of it's inputs are ready. Each +calculator instance is only run for one set of inputs at a time, so most +calculators need only to be *thread-compatible* and not *thread-safe*. + +In order to enable one calculator to process multiple inputs in parallel, there +are two possible approaches: + +1. Define multiple calulator nodes and dispatch input packets to all nodes. +2. Make the calculator thread-safe and configure its [`max_in_flight`] setting. + +The first approach can be followed using the calculators designed to distribute +packets across other calculators, such as [`RoundRobinDemuxCalculator`]. A +single [`RoundRobinDemuxCalculator`] can distribute successive packets across +several identically configured [`ScaleImageCalculator`] nodes. + +The second approach allows up to [`max_in_flight`] invocations of the +[`CalculatorBase::Process`] method on the same calculator node. The output +packets from [`CalculatorBase::Process`] are automatically ordered by timestamp +before they are passed along to downstream calculators. + +With either aproach, you must be aware that the calculator running in parallel +cannot maintain internal state in the same way as a normal sequential +calculator. + +### Output timestamps when using ImmediateInputStreamHandler + +The [`ImmediateInputStreamHandler`] delivers each packet as soon as it arrives +at an input stream. As a result, it can deliver a packet +with a higher timestamp from one input stream before delivering a packet with a +lower timestamp from a different input stream. If these input timestamps are +both used for packets sent to one output stream, that output stream will +complain that the timestamps are not monotonically increasing. 
In order to +remedy this, the calculator must take care to output a packet only after +processing is complete for its timestamp. This could be accomplished by waiting +until input packets have been received from all inputstreams for that timestamp, +or by ignoring a packet that arrives with a timestamp that has already been +processed. + +### How to change settings at runtime + +There are two main approaches to changing the settings of a calculator graph +while the application is running: + +1. Restart the calculator graph with modified [`CalculatorGraphConfig`]. +2. Send new calculator options through packets on graph input-streams. + +The first approach has the advantage of leveraging [`CalculatorGraphConfig`] +processing tools such as "subgraphs". The second approach has the advantage of +allowing active calculators and packets to remain in-flight while settings +change. Mediapipe contributors are currently investigating alternative approaches +to achieve both of these adantages. + +### How to process realtime input streams + +The mediapipe framework can be used to process data streams either online or +offline. For offline processing, packets are pushed into the graph as soon as +calculators are ready to process those packets. For online processing, one +packet for each frame is pushed into the graph as that frame is recorded. + +The MediaPipe framework requires only that successive packets be assigned +monotonically increasing timestamps. By convention, realtime calculators and +graphs use the recording time or the presentation time as the timestamp for each +packet, with each timestamp representing microseconds since +`Jan/1/1970:00:00:00`. This allows packets from various sources to be processed +in a gloablly consistent order. + +Normally for offline processing, every input packet is processed and processing +continues as long as necessary. 
For online processing, it is often necessary to +drop input packets in order to keep pace with the arrival of input data frames. +When inputs arrive too frequently, the recommended technique for dropping +packets is to use the MediaPipe calculators designed specifically for this +purpose such as [`RealTimeFlowLimiterCalculator`] and [`PacketClonerCalculator`]. + +For online processing, it is also necessary to promptly determine when processing +can proceed. MediaPipe supports this by propagating timestamp bounds between +calculators. Timestamp bounds indicate timestamp intervals that will contain no +input packets, and they allow calculators to begin processing for those +timestamps immediately. Calculators designed for realtime processing should +carefully calculate timestamp bounds in order to begin processing as promptly as +possible. For example, the [`MakePairCalculator`] uses the `SetOffset` API to +propagate timestamp bounds from input streams to output streams. + +### Can I run MediaPipe on MS Windows? + +Currently MediaPipe portability supports Debian Linux, Ubuntu Linux, +MacOS, Android, and iOS. The core of MediaPipe framework is a C++ library +conforming to the C++11 standard, so it is relatively easy to port to +additional platforms. 
+ +[`object_detection_android_cpu.pbtxt`]: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_android_cpu.pbtxt + +[`ImageFrame`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/formats/image_frame.h +[`GpuBuffer`]: https://github.com/google/mediapipe/tree/master/mediapipe/gpu/gpu_buffer.h +[`GpuBufferToImageFrameCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc +[`ImageFrameToGpuBufferCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +[`AnnotationOverlayCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/util/annotation_overlay_calculator.cc +[`face_detection_android_gpu.pbtxt`]: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_android_gpu.pbtxt +[`CalculatorBase::Process`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_base.h +[`max_in_flight`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto +[`RoundRobinDemuxCalculator`]: https://github.com/google/mediapipe/tree/master//mediapipe/calculators/core/round_robin_demux_calculator.cc +[`ScaleImageCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/image/scale_image_calculator.cc +[`ImmediateInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc +[`CalculatorGraphConfig`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto +[`RealTimeFlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +[`PacketClonerCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/packet_cloner_calculator.cc +[`MakePairCalculator`]: 
https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/make_pair_calculator.cc diff --git a/mediapipe/docs/images/console_error.png b/mediapipe/docs/images/console_error.png new file mode 100644 index 000000000..749fdf7a9 Binary files /dev/null and b/mediapipe/docs/images/console_error.png differ diff --git a/mediapipe/docs/images/cyclic_integer_sum_graph.svg b/mediapipe/docs/images/cyclic_integer_sum_graph.svg new file mode 100644 index 000000000..ac7d42a77 --- /dev/null +++ b/mediapipe/docs/images/cyclic_integer_sum_graph.svg @@ -0,0 +1,4 @@ + + + + diff --git a/mediapipe/docs/images/editor_view.png b/mediapipe/docs/images/editor_view.png new file mode 100644 index 000000000..bfeac701f Binary files /dev/null and b/mediapipe/docs/images/editor_view.png differ diff --git a/mediapipe/docs/images/gpu_example_graph.png b/mediapipe/docs/images/gpu_example_graph.png new file mode 100644 index 000000000..e6f995e5a Binary files /dev/null and b/mediapipe/docs/images/gpu_example_graph.png differ diff --git a/mediapipe/docs/images/graph_visual.png b/mediapipe/docs/images/graph_visual.png new file mode 100644 index 000000000..691d8df20 Binary files /dev/null and b/mediapipe/docs/images/graph_visual.png differ diff --git a/mediapipe/docs/images/hello_world_graph.png b/mediapipe/docs/images/hello_world_graph.png new file mode 100644 index 000000000..c36aafc08 Binary files /dev/null and b/mediapipe/docs/images/hello_world_graph.png differ diff --git a/mediapipe/docs/images/mediapipe_small.png b/mediapipe/docs/images/mediapipe_small.png new file mode 100644 index 000000000..85c284129 Binary files /dev/null and b/mediapipe/docs/images/mediapipe_small.png differ diff --git a/mediapipe/docs/images/mobile/bazel_hello_world_android.png b/mediapipe/docs/images/mobile/bazel_hello_world_android.png new file mode 100644 index 000000000..dd50be3e7 Binary files /dev/null and b/mediapipe/docs/images/mobile/bazel_hello_world_android.png differ diff --git 
a/mediapipe/docs/images/mobile/edge_detection_android_gpu.gif b/mediapipe/docs/images/mobile/edge_detection_android_gpu.gif new file mode 100644 index 000000000..4192eb224 Binary files /dev/null and b/mediapipe/docs/images/mobile/edge_detection_android_gpu.gif differ diff --git a/mediapipe/docs/images/mobile/edge_detection_android_graph_gpu.png b/mediapipe/docs/images/mobile/edge_detection_android_graph_gpu.png new file mode 100644 index 000000000..0555c2d13 Binary files /dev/null and b/mediapipe/docs/images/mobile/edge_detection_android_graph_gpu.png differ diff --git a/mediapipe/docs/images/mobile/face_detection_android_gpu.gif b/mediapipe/docs/images/mobile/face_detection_android_gpu.gif new file mode 100644 index 000000000..28ae7d51c Binary files /dev/null and b/mediapipe/docs/images/mobile/face_detection_android_gpu.gif differ diff --git a/mediapipe/docs/images/mobile/face_detection_android_gpu.png b/mediapipe/docs/images/mobile/face_detection_android_gpu.png new file mode 100644 index 000000000..e73f1b52a Binary files /dev/null and b/mediapipe/docs/images/mobile/face_detection_android_gpu.png differ diff --git a/mediapipe/docs/images/mobile/face_detection_android_gpu_small.gif b/mediapipe/docs/images/mobile/face_detection_android_gpu_small.gif new file mode 100644 index 000000000..4af8ff180 Binary files /dev/null and b/mediapipe/docs/images/mobile/face_detection_android_gpu_small.gif differ diff --git a/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.gif b/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.gif new file mode 100644 index 000000000..fa727e429 Binary files /dev/null and b/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.gif differ diff --git a/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.png b/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.png new file mode 100644 index 000000000..461b0e3a9 Binary files /dev/null and b/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.png differ 
diff --git a/mediapipe/docs/images/mobile/missing_camera_permission_android.png b/mediapipe/docs/images/mobile/missing_camera_permission_android.png new file mode 100644 index 000000000..9e35aebaa Binary files /dev/null and b/mediapipe/docs/images/mobile/missing_camera_permission_android.png differ diff --git a/mediapipe/docs/images/mobile/object_detection_android_cpu.gif b/mediapipe/docs/images/mobile/object_detection_android_cpu.gif new file mode 100644 index 000000000..fc5eb7fd6 Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_android_cpu.gif differ diff --git a/mediapipe/docs/images/mobile/object_detection_android_cpu.png b/mediapipe/docs/images/mobile/object_detection_android_cpu.png new file mode 100644 index 000000000..2efcdd9b1 Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_android_cpu.png differ diff --git a/mediapipe/docs/images/mobile/object_detection_android_gpu.gif b/mediapipe/docs/images/mobile/object_detection_android_gpu.gif new file mode 100644 index 000000000..76be3a5c2 Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_android_gpu.gif differ diff --git a/mediapipe/docs/images/mobile/object_detection_android_gpu.png b/mediapipe/docs/images/mobile/object_detection_android_gpu.png new file mode 100644 index 000000000..603d82dba Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_android_gpu.png differ diff --git a/mediapipe/docs/images/mobile/object_detection_desktop_tflite.png b/mediapipe/docs/images/mobile/object_detection_desktop_tflite.png new file mode 100644 index 000000000..f987f1db3 Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_desktop_tflite.png differ diff --git a/mediapipe/docs/images/object_detection_desktop_tensorflow.png b/mediapipe/docs/images/object_detection_desktop_tensorflow.png new file mode 100644 index 000000000..e1a363f16 Binary files /dev/null and 
b/mediapipe/docs/images/object_detection_desktop_tensorflow.png differ diff --git a/mediapipe/docs/images/object_detection_desktop_tflite.png b/mediapipe/docs/images/object_detection_desktop_tflite.png new file mode 100644 index 000000000..2dfa5084a Binary files /dev/null and b/mediapipe/docs/images/object_detection_desktop_tflite.png differ diff --git a/mediapipe/docs/images/packet_cloner_calculator.png b/mediapipe/docs/images/packet_cloner_calculator.png new file mode 100644 index 000000000..f2c2102ff Binary files /dev/null and b/mediapipe/docs/images/packet_cloner_calculator.png differ diff --git a/mediapipe/docs/images/side_packet.png b/mediapipe/docs/images/side_packet.png new file mode 100644 index 000000000..5155835c0 Binary files /dev/null and b/mediapipe/docs/images/side_packet.png differ diff --git a/mediapipe/docs/images/side_packet_code.png b/mediapipe/docs/images/side_packet_code.png new file mode 100644 index 000000000..88a610305 Binary files /dev/null and b/mediapipe/docs/images/side_packet_code.png differ diff --git a/mediapipe/docs/images/special_nodes.png b/mediapipe/docs/images/special_nodes.png new file mode 100644 index 000000000..bcb7763c0 Binary files /dev/null and b/mediapipe/docs/images/special_nodes.png differ diff --git a/mediapipe/docs/images/special_nodes_code.png b/mediapipe/docs/images/special_nodes_code.png new file mode 100644 index 000000000..148c54a3b Binary files /dev/null and b/mediapipe/docs/images/special_nodes_code.png differ diff --git a/mediapipe/docs/images/startup_screen.png b/mediapipe/docs/images/startup_screen.png new file mode 100644 index 000000000..a841ee759 Binary files /dev/null and b/mediapipe/docs/images/startup_screen.png differ diff --git a/mediapipe/docs/images/stream_code.png b/mediapipe/docs/images/stream_code.png new file mode 100644 index 000000000..eabcbfe3f Binary files /dev/null and b/mediapipe/docs/images/stream_code.png differ diff --git a/mediapipe/docs/images/stream_ui.png 
b/mediapipe/docs/images/stream_ui.png new file mode 100644 index 000000000..553e75143 Binary files /dev/null and b/mediapipe/docs/images/stream_ui.png differ diff --git a/mediapipe/docs/images/upload_button.png b/mediapipe/docs/images/upload_button.png new file mode 100644 index 000000000..086f8379b Binary files /dev/null and b/mediapipe/docs/images/upload_button.png differ diff --git a/mediapipe/docs/index.rst b/mediapipe/docs/index.rst new file mode 100644 index 000000000..163f13276 --- /dev/null +++ b/mediapipe/docs/index.rst @@ -0,0 +1,60 @@ +MediaPipe +===================================== +`MediaPipe `_ is a graph-based framework for +building multimodal (video, audio, and sensor) applied machine learning pipelines. +MediaPipe is cross-platform running on mobile devices, workstations and servers, +and supports mobile GPU acceleration. With MediaPipe, an applied +machine learning pipeline can be built as a graph of modular components, +including, for instance, inference models and media processing functions. Sensory +data such as audio and video streams enter the graph, and perceived descriptions +such as object-localization and face-landmark streams exit the graph. An example +graph that performs real-time face detection on mobile GPU is shown below. + +.. image:: images/mobile/face_detection_android_gpu.png + :width: 400 + :alt: Example MediaPipe graph + +MediaPipe is designed for machine learning (ML) practitioners, including +researchers, students, and software developers, who implement production-ready +ML applications, publish code accompanying research work, and build technology +prototypes. The main use case for MediaPipe is rapid prototyping of applied +machine learning pipelines with inference models and other reusable components. +MediaPipe also facilitates the deployment of machine learning technology into +demos and applications on a wide variety of different hardware platforms +(e.g., Android, iOS, workstations). 
+
+APIs for MediaPipe
+  * Calculator API in C++
+  * Graph Construction API in ProtoBuf
+  * (Coming Soon) Graph Construction API in C++
+  * Graph Execution API in C++
+  * Graph Execution API in Java (Android)
+  * (Coming Soon) Graph Execution API in Objective-C (iOS)
+
+User Documentation
+==================
+
+.. toctree::
+   :maxdepth: 3
+
+   install
+   concepts
+   calculator
+   Examples
+   visualizer
+   measure_performance
+   how_to_questions
+   troubleshooting
+   help
+   framework_concepts
+   gpu
+   scheduling_sync
+   license
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
+
diff --git a/mediapipe/docs/install.md b/mediapipe/docs/install.md
new file mode 100644
index 000000000..63177b20a
--- /dev/null
+++ b/mediapipe/docs/install.md
@@ -0,0 +1,384 @@
+## Installing MediaPipe
+
+Choose your operating system:
+
+- [Dependencies](#dependencies)
+- [Installing on Debian and Ubuntu](#installing-on-debian-and-ubuntu)
+- [Installing on CentOS](#installing-on-centos)
+- [Installing on macOS](#installing-on-macos)
+- [Installing using Docker](#installing-using-docker)
+- [Setting up Android SDK and NDK](#setting-up-android-sdk-and-ndk)
+
+### Dependencies
+
+Required libraries
+
+* Prefer OpenCV 3.x and above but can work with OpenCV 2.x (deprecation in the
+  future)
+
+* Bazel 0.23 and above
+
+* gcc and g++ version other than 6.3 and 7.3 (if you need TensorFlow
+  calculators/demos)
+
+* Android SDK release 28.0.3 and above
+
+* Android NDK r18b and above
+
+### Installing on Debian and Ubuntu
+
+1. Checkout mediapipe repository
+
+    ```bash
+    $ git clone https://github.com/google/mediapipe/mediapipe.git
+
+    # Change directory into mediapipe root directory
+    $ cd mediapipe
+    ```
+
+2. Install Bazel
+
+    Option 1. Use package manager tool to install the latest version of Bazel.
+
+    ```bash
+    $ sudo apt-get install bazel
+
+    # Run 'bazel version' to check version of bazel installed
+    ```
+
+    Option 2. 
Follow Bazel's + [documentation](https://docs.bazel.build/versions/master/install-ubuntu.html) + to install any version of Bazel manually. + +3. Install OpenCV + + Option 1. Use package manager tool to install the pre-compiled OpenCV + libraries. + + Note that Debian 9 and Ubuntu 16.04 provide OpenCV 2.4.9. You may want to + take option 2 or 3 to install OpenCV 3 or above. + + ```bash + $ sudo apt-get install libopencv-core-dev libopencv-highgui-dev \ + libopencv-imgproc-dev libopencv-video-dev + ``` + + Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source + and modify MediaPipe's OpenCV config. + + Option 3. Follow OpenCV's + [documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html) + to manually build OpenCV from source code. + + You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to point + MediaPipe to your own OpenCV libraries. For example, if OpenCV 4 is + installed in "/usr/local/", you need to update the "linux_opencv" + new_local_repository rule in [`WORKSAPCE`] and "opencv" cc_library rule in + [`opencv_linux.BUILD`] to be: + + ```bash + new_local_repository( + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr/local", + ) + + cc_library( + name = "opencv", + srcs = glob( + [ + "lib/libopencv_core.so*", + "lib/libopencv_highgui.so*", + "lib/libopencv_imgcodecs.so*", + "lib/libopencv_imgproc.so*", + "lib/libopencv_video.so*", + "lib/libopencv_videoio.so*", + + ], + ), + hdrs = glob(["include/opencv4/**/*.h*"]), + includes = ["include/opencv4/"], + linkstatic = 1, + visibility = ["//visibility:public"], + ) + + ``` + +4. Run the hello world desktop example + + ```bash + # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported + $ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \ + mediapipe/examples/desktop/hello_world:hello_world + + # Should print: + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! 
+ # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + ``` + +### Installing on CentOS + +1. Checkout mediapipe repository + + ```bash + $ git clone https://github.com/google/mediapipe/mediapipe.git + + # Change directory into mediapipe root directory + $ cd mediapipe + ``` + +2. Install Bazel + + Follow Bazel's + [documentation](https://docs.bazel.build/versions/master/install-redhat.html) + to install Bazel manually. + +3. Install OpenCV + + Option 1. Use package manager tool to install the pre-compiled version. + + Note that yum installs OpenCV 2.4.5, which may have an opencv/gstreamer + [issue](https://github.com/opencv/opencv/issues/4592). + + ```bash + $ sudo yum install opencv-devel + ``` + + Option 2. Build OpenCV from source code. + + You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to point + MediaPipe to your own OpenCV libraries. For example, if OpenCV 4 is + installed in "/usr/local/", you need to update the "linux_opencv" + new_local_repository rule in [`WORKSAPCE`] and "opencv" cc_library rule in + [`opencv_linux.BUILD`] to be: + + ```bash + new_local_repository( + name = "linux_opencv", + build_file = "@//third_party:opencv_linux.BUILD", + path = "/usr/local", + ) + + cc_library( + name = "opencv", + srcs = glob( + [ + "lib/libopencv_core.so*", + "lib/libopencv_highgui.so*", + "lib/libopencv_imgcodecs.so*", + "lib/libopencv_imgproc.so*", + "lib/libopencv_video.so*", + "lib/libopencv_videoio.so*", + + ], + ), + hdrs = glob(["include/opencv4/**/*.h*"]), + includes = ["include/opencv4/"], + linkstatic = 1, + visibility = ["//visibility:public"], + ) + + ``` + +4. Run the hello world desktop example + + ```bash + # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported + $ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \ + mediapipe/examples/desktop/hello_world:hello_world + + # Should print: + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! 
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    ```
+
+### Installing on macOS
+
+1. Checkout mediapipe repository
+
+    ```bash
+    $ git clone https://github.com/google/mediapipe/mediapipe.git
+
+    $ cd mediapipe
+    ```
+
+2. Install Bazel
+
+    Option 1. Use package manager tool to install the latest version of Bazel.
+
+    ```bash
+    $ brew install bazel
+
+    # Run 'bazel version' to check version of bazel installed
+    ```
+
+    Option 2. Follow Bazel's
+    [documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
+    to install any version of Bazel manually.
+
+3. Install OpenCV
+
+    Use package manager tool to install the pre-compiled OpenCV libraries.
+
+    ```bash
+    $ brew install opencv
+    ```
+
+4. Run the hello world desktop example
+
+    ```bash
+    # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported
+    $ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \
+    mediapipe/examples/desktop/hello_world:hello_world
+
+    # Should print:
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    # Hello World!
+    ```
+
+### Installing using Docker
+
+This will use a Docker image that will isolate MediaPipe's installation from the rest of the system.
+
+1. [Install Docker](https://docs.docker.com/install/#supported-platforms) on
+   your host system
+
+2. Build a docker image with tag "mediapipe"
+
+    ```bash
+    $ git clone https://github.com/google/mediapipe/mediapipe.git
+    $ cd mediapipe
+    $ docker build --tag=mediapipe .
+
+    # Should print:
+    # Sending build context to Docker daemon 147.8MB
+    # Step 1/9 : FROM ubuntu:latest
+    # latest: Pulling from library/ubuntu
+    # 6abc03819f3e: Pull complete
+    # 05731e63f211: Pull complete
+    # ........
+    # See http://bazel.build/docs/getting-started.html to start a new project!
+    # Removing intermediate container 82901b5e79fa
+    # ---> f5d5f402071b
+    # Step 9/9 : COPY . 
/mediapipe/ + # ---> a95c212089c5 + # Successfully built a95c212089c5 + # Successfully tagged mediapipe:latest + ``` + +3. Run the hello world desktop example in docker + + ```bash + $ docker run -it --name mediapipe mediapipe:latest + + root@bca08b91ff63:/mediapipe# bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' mediapipe/examples/desktop/hello_world:hello_world + + # Should print: + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + # Hello World! + ``` + + + + +### Setting up Android Studio with MediaPipe + +The steps below use Android Studio to build and install a MediaPipe demo app. + +1. Install and launch android studio. + +2. Select `Configure` | `SDK Manager` | `SDK Platforms` + + * verify that an Android SDK is installed + * note the Android SDK Location such as `/usr/local/home/Android/Sdk` + +3. Select `Configure` | `SDK Manager` | `SDK Tools` + + * verify that an Android NDK is installed + * note the Android NDK Location such as `/usr/local/home/Android/Sdk/ndk-bundle` + +4. Set environment variables `$ANDROID_HOME` and `$ANDROID_NDK_HOME` to point to + the installed SDK and NDK. + + ```bash + export ANDROID_HOME=/usr/local/home/Android/Sdk + export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk-bundle + ``` + +5. Select `Configure` | `Plugins` install `Bazel`. + +6. Select `Import Bazel Project` + + * select `Workspace`: `/path/to/mediapipe` + * select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` + * select `Finish` + +7. Connect an android device to the workstation. + +8. 
Select `Run...` | `Edit Configurations...` + + * enter Target Expression: + `//mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu` + * enter Bazel command: `mobile-install` + * enter Bazel flags: `-c opt --config=android_arm64` select `Run` + +### Setting up Android SDK and NDK + +If Android SDK and NDK are installed (likely by Android Studio), please set +$ANDROID_HOME and $ANDROID_NDK_HOME to point to the installed SDK and NDK. + +```bash +export ANDROID_HOME= +export ANDROID_NDK_HOME= +``` + +Otherwise, please run [`setup_android_sdk_and_ndk.sh`] to download and setup +Android SDK and NDK for MediaPipe before building any Android demos. + +[`WORKSAPCE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE +[`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD +[`setup_opencv.sh`]: https://github.com/google/mediapipe/tree/master/setup_opencv.sh +[`setup_android_sdk_and_ndk.sh`]: https://github.com/google/mediapipe/tree/master/setup_android_sdk_and_ndk.sh diff --git a/mediapipe/docs/license.md b/mediapipe/docs/license.md new file mode 100644 index 000000000..9a98f8910 --- /dev/null +++ b/mediapipe/docs/license.md @@ -0,0 +1,205 @@ +License +=============== +Copyright 2019 The MediaPipe Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2017, The MediaPipe Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.\n diff --git a/mediapipe/docs/measure_performance.md b/mediapipe/docs/measure_performance.md new file mode 100644 index 000000000..e4b2bde58 --- /dev/null +++ b/mediapipe/docs/measure_performance.md @@ -0,0 +1,18 @@ +# Measuring Performance + +*Coming soon.* + +MediaPipe includes APIs for gathering aggregate performance data and +event timing data for CPU and GPU operations. These API's can be found at: + + + + * [`GraphProfiler`](https://github.com/google/mediapipe/tree/master/mediapipe/framework/profiler/graph_profiler.h): + Accumulates for each running calculator a histogram of latencies for + Process calls. + * [`GraphTracer`](https://github.com/google/mediapipe/tree/master/mediapipe/framework/profiler/graph_tracer.h): + Records for each running calculator and each processed packet a series + of timed events including the start and finish of each Process call. + +Future mediapipe releases will include tools for visualizing and analysing +the latency histograms and timed events captured by these API's. diff --git a/mediapipe/docs/media_sequence.md b/mediapipe/docs/media_sequence.md new file mode 100644 index 000000000..2b737d684 --- /dev/null +++ b/mediapipe/docs/media_sequence.md @@ -0,0 +1,198 @@ +## Preparing Data Sets with MediaSequence + +MediaPipe is useful and general framework for media processing that can +assist with research, development, and deployment of ML models. 
This example +focuses on development by demonstrating how to prepare video data for training +a TensorFlow model. + +The MediaSequence library provides an extensive set of tools for storing data in +TensorFlow.SequenceExamples. SequenceExamples provide matched semantics to most +video tasks and are efficient to use with TensorFlow. The sequence semantics +allow for a variable number of annotations per frame, which is necessary for +tasks like video object detection, but very difficult to encode in +TensorFlow.Examples. The goal of MediaSequence is to simplify working with +SequenceExamples and to automate common preparation tasks. Much more information +is available about the MediaSequence pipeline, including how to use it to +process new data sets, in the [documentation](https://github.com/google/mediapipe/tree/master/mediapipe/util/sequence/README.md). + +### Preparing an example data set + +1. Check out the MediaPipe repository + + ```bash + git clone https://github.com/google/mediapipe.git + cd mediapipe + ``` + +1. Compile the MediaSequence demo C++ binary + + ```bash + bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo --define 'MEDIAPIPE_DISABLE_GPU=1' + ``` + + MediaSequence uses C++ binaries to improve multimedia processing speed and + encourage a strong separation between annotations and the image data or + other features. The binary code is very general in that it reads from files + into input side packets and writes output side packets to files when + completed, but it also links in all of the calculators necessary for + the MediaPipe graphs preparing the Charades data set. + +1. Download and prepare the data set through Python + + To run this step, you must have Python 2.7 or 3.5+ installed with the + TensorFlow 1.19+ package installed. 
+ + ```bash + python -m mediapipe.examples.desktop.media_sequence.demo_dataset \ + --path_to_demo_data=/tmp/demo_data/ \ + --path_to_mediapipe_binary=bazel-bin/mediapipe/examples/desktop/media_sequence/media_sequence_demo \ + --path_to_graph_directory=mediapipe/graphs/media_sequence/ + ``` + + The arguments define where data is stored. `--path_to_demo_data` defines + where the data will be downloaded to and where prepared data will be + generated. `--path_to_mediapipe_binary` is the path to the binary built in + the previous step. `--path_to_graph_directory` defines where to look for + MediaPipe graphs during processing. + + Running this module + 1. Downloads videos from the internet. + 1. For each annotation in a CSV, creates a structured metadata file. + 1. Runs MediaPipe to extract images as defined by the metadata. + 1. Stores the results in a numbered set of TFRecords files. + + MediaSequence uses SequenceExamples as the format of both inputs and + outputs. Annotations are encoded as inputs in a SequenceExample of metadata + that defines the labels and the path to the corresponding video file. This + metadata is passed as input to the C++ `media_sequence_demo` binary, and the + output is a SequenceExample filled with images and annotations ready for + model training. + +1. Reading the data in TensorFlow + + To read the data in TensorFlow, first add the repo to your PYTHONPATH + + ```bash + PYTHONPATH="${PYTHONPATH}:$(pwd)" + ``` + + and then you can import the data set in Python. 
+ + ```python + import tensorflow as tf + from mediapipe.examples.desktop.media_sequence.demo_dataset import DemoDataset + demo_data_path = '/tmp/demo_data/' + with tf.Graph().as_default(): + d = DemoDataset(demo_data_path) + dataset = d.as_dataset("test") + # implement additional processing and batching here + output = dataset.make_one_shot_iterator().get_next() + + with tf.Session() as sess: + output_ = sess.run(output) + ``` + +### Preparing a practical data set +As an example of processing a practical data set, a similar set of commands will +prepare the [Charades data set](https://allenai.org/plato/charades/). The +Charades data set is a data set of human action recognition collected with and +maintained by the Allen Institute for Artificial Intelligence. To follow this +code lab, you must abide by the [license](https://allenai.org/plato/charades/license.txt) +for the Charades data set provided by the Allen Institute. + +The Charades data set is large (~150 GB), and will take considerable time to +download and process (4-8 hours). + +```bash +bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo --define 'MEDIAPIPE_DISABLE_GPU=1' + +python -m mediapipe.examples.desktop.media_sequence.demo_dataset \ + --alsologtostderr \ + --path_to_charades_data=/tmp/demo_data/ \ + --path_to_mediapipe_binary=bazel-bin/mediapipe/examples/desktop/media_sequence/media_sequence_demo \ + --path_to_graph_directory=mediapipe/graphs/media_sequence/ +``` + +### Preparing your own data set +The process for preparing your own data set is described in the [MediaSequence +documentation](https://github.com/google/mediapipe/tree/master/mediapipe/util/sequence/README.md). +The Python code for Charades can easily be modified to process most annotations, +but the MediaPipe processing warrants further discussion. MediaSequence uses +MediaPipe graphs to extract features related to the metadata or previously +extracted data. 
Each graph can focus on extracting a single type of feature, and +graphs can be chained together to extract derived features in a composable way. +For example, one graph may extract images from a video at 10 fps and another +graph extract images at 24 fps. A subsequent graph can extract ResNet-50 +features from the output of either preceding graph. MediaPipe enables a +composable interface of data process for machine learning at multiple levels. + +The MediaPipe graph with brief annotations for adding images to a data set is as +follows. Common changes would be to change the frame_rate or encoding quality of +frames. + +``` +# Convert the string input into a decoded SequenceExample. +node { + calculator: "StringToSequenceExampleCalculator" + input_side_packet: "STRING:input_sequence_example" + output_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example" +} + +# Unpack the data path and clip timing from the SequenceExample. +node { + calculator: "UnpackMediaSequenceCalculator" + input_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example" + output_side_packet: "DATA_PATH:input_video_path" + output_side_packet: "RESAMPLER_OPTIONS:packet_resampler_options" + options { + [mediapipe.UnpackMediaSequenceCalculatorOptions.ext]: { + base_packet_resampler_options { + frame_rate: 24.0 + base_timestamp: 0 + } + } + } +} + +# Decode the entire video. +node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_video_path" + output_stream: "VIDEO:decoded_frames" +} + +# Extract the subset of frames we want to keep. +node { + calculator: "PacketResamplerCalculator" + input_stream: "decoded_frames" + output_stream: "sampled_frames" + input_side_packet: "OPTIONS:packet_resampler_options" +} + +# Encode the images to store in the SequenceExample. 
+node { + calculator: "OpenCvImageEncoderCalculator" + input_stream: "sampled_frames" + output_stream: "encoded_frames" + node_options { + [type.googleapis.com/mediapipe.OpenCvImageEncoderCalculatorOptions]: { + quality: 80 + } + } +} + +# Store the images in the SequenceExample. +node { + calculator: "PackMediaSequenceCalculator" + input_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example" + output_side_packet: "SEQUENCE_EXAMPLE:sequence_example_to_serialize" + input_stream: "IMAGE:encoded_frames" +} + +# Serialize the SequenceExample to a string for storage. +node { + calculator: "StringToSequenceExampleCalculator" + input_side_packet: "SEQUENCE_EXAMPLE:sequence_example_to_serialize" + output_side_packet: "STRING:output_sequence_example" +} +``` diff --git a/mediapipe/docs/object_detection_android_cpu.md b/mediapipe/docs/object_detection_android_cpu.md new file mode 100644 index 000000000..7dfe67a60 --- /dev/null +++ b/mediapipe/docs/object_detection_android_cpu.md @@ -0,0 +1,254 @@ +# Object Detection on CPU on Android + +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. This +doc focuses on the +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_android_cpu.pbtxt) +that performs object detection with TensorFlow Lite on CPU. + +This is very similar to the +[Object Detection on GPU on Android](object_detection_android_gpu.md) example +except that at the beginning and the end of the graph it performs GPU-to-CPU and +CPU-to-GPU image transfer respectively. As a result, the rest of graph, which +shares the same configuration as the +[GPU graph](images/mobile/object_detection_android_gpu.png), runs entirely on +CPU. 
+ +![object_detection_android_cpu_gif](images/mobile/object_detection_android_cpu.gif){width="300"} + +## App + +The graph is used in the +[Object Detection CPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu) +example app. To build the app, run: + +```bash +bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu +``` + +To further install the app on android device, run: + +```bash +adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/objectdetectioncpu.apk +``` + +## Graph + +![object_detection_android_cpu_graph](images/mobile/object_detection_android_cpu.png){width="400"} + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into [MediaPipe Visualizer](https://mediapipe-viz.appspot.com/). + +```bash +# MediaPipe graph that performs object detection with TensorFlow Lite on CPU. +# Used in the example in +# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectioncpu. + +# Images on GPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Transfers the input image from GPU to CPU memory for the purpose of +# demonstrating a CPU-based pipeline. Note that the input image on GPU has the +# origin defined at the bottom-left corner (OpenGL convention). As a result, +# the transferred image on CPU also shares the same representation. +node: { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "input_video" + output_stream: "input_video_cpu" +} + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish +# generating the corresponding detections before it passes through another +# image. 
All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs. +node { + calculator: "RealTimeFlowLimiterCalculator" + input_stream: "input_video_cpu" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video_cpu" +} + +# Transforms the input image on CPU to a 320x320 image. To scale the image, by +# default it uses the STRETCH scale mode that maps the entire input image to the +# entire transformed image. As a result, image aspect ratio may be changed and +# objects in the image may be deformed (stretched or squeezed), but the object +# detection model used in this graph is agnostic to that deformation. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:throttled_input_video_cpu" + output_stream: "IMAGE:transformed_input_video_cpu" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 320 + output_height: 320 + } + } +} + +# Converts the transformed input image on CPU into an image tensor as a +# TfLiteTensor. The zero_center option is set to true to normalize the +# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. 
The flip_vertically +# option is set to true to account for the descrepancy between the +# representation of the input image (origin at the bottom-left corner) and what +# the model used in this graph is expecting (origin at the top-left corner). +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:transformed_input_video_cpu" + output_stream: "TENSORS:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { + zero_center: true + flip_vertically: true + } + } +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:detection_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "ssdlite_object_detection.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 6 + min_scale: 0.2 + max_scale: 0.95 + input_size_height: 320 + input_size_width: 320 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 16 + strides: 32 + strides: 64 + strides: 128 + strides: 256 + strides: 512 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + aspect_ratios: 3.0 + aspect_ratios: 0.3333 + reduce_boxes_in_lowest_layer: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. 
+node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 91 + num_boxes: 2034 + num_coords: 4 + ignore_classes: 0 + sigmoid_score: true + apply_exponential_on_box_size: true + x_scale: 10.0 + y_scale: 10.0 + h_scale: 5.0 + w_scale: 5.0 + flip_vertically: true + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.4 + min_score_threshold: 0.6 + max_num_detections: 3 + overlap_type: INTERSECTION_OVER_UNION + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "output_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "ssdlite_object_detection_labelmap.txt" + } + } +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_VECTOR:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the CPU copy of the original +# image coming into the graph. The calculator assumes that image origin is +# always at the top-left corner and renders text accordingly. 
However, the input +# image has its origin at the bottom-left corner (OpenGL convention) and the +# flip_text_vertically option is set to true to compensate that. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:throttled_input_video_cpu" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video_cpu" + node_options: { + [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { + flip_text_vertically: true + } + } +} + +# Transfers the annotated image from CPU back to GPU memory, to be sent out of +# the graph. +node: { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "output_video_cpu" + output_stream: "output_video" +} +``` diff --git a/mediapipe/docs/object_detection_android_gpu.md b/mediapipe/docs/object_detection_android_gpu.md new file mode 100644 index 000000000..3c5de429d --- /dev/null +++ b/mediapipe/docs/object_detection_android_gpu.md @@ -0,0 +1,231 @@ +# Object Detection on GPU on Android + +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. This +doc focuses on the +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_android_gpu.pbtxt) +that performs object detection with TensorFlow Lite on GPU. + +![object_detection_android_gpu_gif](images/mobile/object_detection_android_gpu.gif){width="300"} + +## App + +The graph is used in the +[Object Detection GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu) +example app. 
To build the app, run: + +```bash +bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu +``` + +To further install the app on android device, run: + +```bash +adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/objectdetectiongpu.apk +``` + +## Graph + +![object_detection_android_gpu_graph](images/mobile/object_detection_android_gpu.png){width="400"} + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into [MediaPipe Visualizer](https://mediapipe-viz.appspot.com/). + +```bash +# MediaPipe graph that performs object detection with TensorFlow Lite on GPU. +# Used in the example in +# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectiongpu. + +# Images on GPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish +# generating the corresponding detections before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs. 
+node { + calculator: "RealTimeFlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transforms the input image on GPU to a 320x320 image. To scale the image, by +# default it uses the STRETCH scale mode that maps the entire input image to the +# entire transformed image. As a result, image aspect ratio may be changed and +# objects in the image may be deformed (stretched or squeezed), but the object +# detection model used in this graph is agnostic to that deformation. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:throttled_input_video" + output_stream: "IMAGE_GPU:transformed_input_video" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 320 + output_height: 320 + } + } +} + +# Converts the transformed input image on GPU into an image tensor stored in +# tflite::gpu::GlBuffer. The zero_center option is set to true to normalize the +# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically +# option is set to true to account for the descrepancy between the +# representation of the input image (origin at the bottom-left corner, the +# OpenGL convention) and what the model used in this graph is expecting (origin +# at the top-left corner). +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "TENSORS_GPU:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { + zero_center: true + flip_vertically: true + } + } +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. 
+node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS_GPU:detection_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "ssdlite_object_detection.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 6 + min_scale: 0.2 + max_scale: 0.95 + input_size_height: 320 + input_size_width: 320 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 16 + strides: 32 + strides: 64 + strides: 128 + strides: 256 + strides: 512 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + aspect_ratios: 3.0 + aspect_ratios: 0.3333 + reduce_boxes_in_lowest_layer: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS_GPU:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 91 + num_boxes: 2034 + num_coords: 4 + ignore_classes: 0 + sigmoid_score: true + apply_exponential_on_box_size: true + x_scale: 10.0 + y_scale: 10.0 + h_scale: 5.0 + w_scale: 5.0 + flip_vertically: true + } + } +} + +# Performs non-max suppression to remove excessive detections. 
+node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.4 + min_score_threshold: 0.6 + max_num_detections: 3 + overlap_type: INTERSECTION_OVER_UNION + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "output_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "ssdlite_object_detection_labelmap.txt" + } + } +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_VECTOR:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the original image coming into +# the graph. Annotation drawing is performed on CPU, and the result is +# transferred to GPU and overlaid on the input image. The calculator assumes +# that image origin is always at the top-left corner and renders text +# accordingly. However, the input image has its origin at the bottom-left corner +# (OpenGL convention) and the flip_text_vertically option is set to true to +# compensate that. 
+node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME_GPU:throttled_input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME_GPU:output_video" + node_options: { + [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { + flip_text_vertically: true + } + } +} +``` diff --git a/mediapipe/docs/object_detection_desktop.md b/mediapipe/docs/object_detection_desktop.md new file mode 100644 index 000000000..2590578bf --- /dev/null +++ b/mediapipe/docs/object_detection_desktop.md @@ -0,0 +1,428 @@ +## Object Detection on Desktop + +This is an example of using MediaPipe to run object detection models (TensorFlow +and TensorFlow Lite) and render bounding boxes on the detected objects. To know +more about the object detection models and TensorFlow-to-TFLite model +conversion, please refer to the model [`README file`]. Moreover, if you are +interested in running the same TensorfFlow Lite model on Android, please see the +[Object Detection on GPU on Android](object_detection_android_gpu.md) and +[Object Detection on CPU on Android](object_detection_android_cpu.md) examples. + +### TensorFlow Model + +To build and run the TensorFlow example on desktop, run: + +```bash +# Note that this command also builds TensorFlow targets from scratch, it may +# take a long time (e.g., up to 30 mins) to build for the first time. +$ bazel build -c opt \ + --define 'MEDIAPIPE_DISABLE_GPU=1' \ + --define 'no_aws_support=true' \ + mediapipe/examples/desktop/object_detection:object_detection_tensorflow + +# It should print: +# Target //mediapipe/examples/desktop/object_detection:object_detection_tensorflow up-to-date: +# bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow +# INFO: Elapsed time: 172.262s, Critical Path: 125.68s +# INFO: 2675 processes: 2673 linux-sandbox, 2 local. +# INFO: Build completed successfully, 2807 total actions + +# Replace and . 
+# You can find a test video in mediapipe/examples/desktop/object_detection. +$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \ + --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \ + --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path> +``` + +#### Graph + +![graph visualization](images/object_detection_desktop_tensorflow.png){width="800"} + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into +[MediaPipe Visualizer](https://mediapipe-viz.appspot.com). + +```bash +# MediaPipe graph that performs object detection on desktop with TensorFlow +# on CPU. +# Used in the example in +# mediapipe/examples/desktop/object_detection:object_detection_tensorflow. + +# Decodes an input video file into images and a video header. +node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_video_path" + output_stream: "VIDEO:input_video" + output_stream: "VIDEO_PRESTREAM:input_video_header" +} + +# Converts the input image into an image tensor as a tensorflow::Tensor. +node { + calculator: "ImageFrameToTensorCalculator" + input_stream: "input_video" + output_stream: "image_tensor" +} + +# Generates a single side packet containing a TensorFlow session from a saved +# model. The directory path that contains the saved model is specified in the +# saved_model_path option, and the name of the saved model file has to be +# "saved_model.pb". 
+node { + calculator: "TensorFlowSessionFromSavedModelCalculator" + output_side_packet: "SESSION:object_detection_session" + node_options: { + [type.googleapis.com/mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions]: { + saved_model_path: "mediapipe/models/object_detection_saved_model" + } + } +} + +# Runs a TensorFlow session (specified as an input side packet) that takes an +# image tensor and outputs multiple tensors that describe the objects detected +# in the image. The batch_size option is set to 1 to disable batching entirely. +# Note that the particular TensorFlow model used in this session handles image +# scaling internally before the object-detection inference, and therefore no +# additional calculator for image transformation is needed in this MediaPipe +# graph. +node: { + calculator: "TensorFlowInferenceCalculator" + input_side_packet: "SESSION:object_detection_session" + input_stream: "INPUTS:image_tensor" + output_stream: "DETECTION_BOXES:detection_boxes_tensor" + output_stream: "DETECTION_CLASSES:detection_classes_tensor" + output_stream: "DETECTION_SCORES:detection_scores_tensor" + output_stream: "NUM_DETECTIONS:num_detections_tensor" + node_options: { + [type.googleapis.com/mediapipe.TensorFlowInferenceCalculatorOptions]: { + batch_size: 1 + } + } +} + +# Decodes the detection tensors from the TensorFlow model into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "ObjectDetectionTensorsToDetectionsCalculator" + input_stream: "BOXES:detection_boxes_tensor" + input_stream: "SCORES:detection_scores_tensor" + input_stream: "CLASSES:detection_classes_tensor" + input_stream: "NUM_DETECTIONS:num_detections_tensor" + output_stream: "DETECTIONS:detections" +} + +# Performs non-max suppression to remove excessive detections. 
+node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.4 + min_score_threshold: 0.6 + max_num_detections: 10 + overlap_type: INTERSECTION_OVER_UNION + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "output_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "mediapipe/models/ssdlite_object_detection_labelmap.txt" + } + } +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_VECTOR:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the original image coming into +# the graph. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video" +} + +# Encodes the annotated images into a video file, adopting properties specified +# in the input video header, e.g., video framerate. 
+node { + calculator: "OpenCvVideoEncoderCalculator" + input_stream: "VIDEO:output_video" + input_stream: "VIDEO_PRESTREAM:input_video_header" + input_side_packet: "OUTPUT_FILE_PATH:output_video_path" + node_options: { + [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { + codec: "avc1" + video_format: "mp4" + } + } +} +``` + +### TensorFlow Lite Model + +To build and run the TensorFlow Lite example on desktop, run: + +```bash +$ bazel build -c opt --define 'MEDIAPIPE_DISABLE_GPU=1' \ + mediapipe/examples/desktop/object_detection:object_detection_tflite + +# It should print: +# Target //mediapipe/examples/desktop/object_detection:object_detection_tflite up-to-date: +# bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite +# INFO: Elapsed time: 36.417s, Critical Path: 23.22s +# INFO: 711 processes: 710 linux-sandbox, 1 local. +# INFO: Build completed successfully, 734 total actions + +# Replace <input video path> and <output video path>. +# You can find a test video in mediapipe/examples/desktop/object_detection. +$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \ + --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \ + --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path> +``` + +#### Graph + +![graph visualization](images/object_detection_desktop_tflite.png){width="400"} + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into +[MediaPipe Visualizer](https://mediapipe-viz.appspot.com). + +```bash +# MediaPipe graph that performs object detection on desktop with TensorFlow Lite
# on CPU. +# Used in the example in +# mediapipe/examples/desktop/object_detection:object_detection_tflite. + +# max_queue_size limits the number of packets enqueued on any input stream +# by throttling inputs to the graph. This makes the graph only process one +# frame at a time.
+max_queue_size: 1 + +# Decodes an input video file into images and a video header. +node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_video_path" + output_stream: "VIDEO:input_video" + output_stream: "VIDEO_PRESTREAM:input_video_header" +} + +# Transforms the input image on CPU to a 320x320 image. To scale the image, by +# default it uses the STRETCH scale mode that maps the entire input image to the +# entire transformed image. As a result, image aspect ratio may be changed and +# objects in the image may be deformed (stretched or squeezed), but the object +# detection model used in this graph is agnostic to that deformation. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:input_video" + output_stream: "IMAGE:transformed_input_video" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 320 + output_height: 320 + } + } +} + +# Converts the transformed input image on CPU into an image tensor as a +# TfLiteTensor. The zero_center option is set to true to normalize the +# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:transformed_input_video" + output_stream: "TENSORS:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { + zero_center: true + } + } +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. 
+node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:detection_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/ssdlite_object_detection.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 6 + min_scale: 0.2 + max_scale: 0.95 + input_size_height: 320 + input_size_width: 320 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 16 + strides: 32 + strides: 64 + strides: 128 + strides: 256 + strides: 512 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + aspect_ratios: 3.0 + aspect_ratios: 0.3333 + reduce_boxes_in_lowest_layer: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 91 + num_boxes: 2034 + num_coords: 4 + ignore_classes: 0 + apply_exponential_on_box_size: true + + x_scale: 10.0 + y_scale: 10.0 + h_scale: 5.0 + w_scale: 5.0 + } + } +} + +# Performs non-max suppression to remove excessive detections. 
+node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.4 + min_score_threshold: 0.6 + max_num_detections: 5 + overlap_type: INTERSECTION_OVER_UNION + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "output_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "mediapipe/models/ssdlite_object_detection_labelmap.txt" + } + } +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTION_VECTOR:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the original image coming into +# the graph. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video" +} + +# Encodes the annotated images into a video file, adopting properties specified +# in the input video header, e.g., video framerate. 
+node { + calculator: "OpenCvVideoEncoderCalculator" + input_stream: "VIDEO:output_video" + input_stream: "VIDEO_PRESTREAM:input_video_header" + input_side_packet: "OUTPUT_FILE_PATH:output_video_path" + node_options: { + [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { + codec: "avc1" + video_format: "mp4" + } + } +} +``` + +### Known issues with OpenCV 2 + +Note that OpenCV 2 may not be able to render an mp4 file and returns the +following error message: + +``` +[libx264 @ 0x7fe6eadf49a0] broken ffmpeg default settings detected +[libx264 @ 0x7fe6eadf49a0] use an encoding preset (e.g. -vpre medium) +[libx264 @ 0x7fe6eadf49a0] preset usage: -vpre -vpre +[libx264 @ 0x7fe6eadf49a0] speed presets are listed in x264 --help +[libx264 @ 0x7fe6eadf49a0] profile is optional; x264 defaults to high +Could not open codec 'libx264': Unspecified errorE0612 19:40:09.067003 2089 simple_run_graph_main.cc:64] Fail to run the graph: CalculatorGraph::Run() failed in Run: +Calculator::Process() for node "[OpenCvVideoEncoderCalculator, OpenCvVideoEncoderCalculator with node ID: 7 and input streams: ]" failed: ; Fail to open file at ... +``` + +In that case, please change the OpenCvVideoEncoderCalculator option in either +the [`TensorFlow graph`] or the [`TensorFlow Lite graph`] to the following and +in the command line specify the output video to be a .mkv file. 
+ +```bash +node { + calculator: "OpenCvVideoEncoderCalculator" + input_stream: "VIDEO:output_video" + input_stream: "VIDEO_PRESTREAM:input_video_header" + input_side_packet: "OUTPUT_FILE_PATH:output_video_path" + node_options { + [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { + codec: "MPEG" + video_format: "mkv" + } + } +} +``` + +[`README file`]:https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md +[`TensorFlow graph`]: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt +[`TensorFlow Lite graph`]: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt diff --git a/mediapipe/docs/packets.md b/mediapipe/docs/packets.md new file mode 100644 index 000000000..2e52ae956 --- /dev/null +++ b/mediapipe/docs/packets.md @@ -0,0 +1,19 @@ +### Packets + +- [Creating a packet](#creating-a-packet) + +Each calculator is a node of a graph. We describe how to create a new calculator, how to initialize a calculator, how to perform its calculations, input and output streams, timestamps, and options. + +#### Creating a packet +Packets are generally created with `MediaPipe::Adopt()` (from packet.h). + +```c++ +// Create some data. +auto data = gtl::MakeUnique("constructor_argument"); +// Create a packet to own the data. +Packet p = Adopt(data.release()); +// Make a new packet with the same data and a different timestamp.
+Packet p2 = p.At(Timestamp::PostStream()); +``` + +Data within a packet is accessed with `Packet::Get()` diff --git a/mediapipe/docs/scheduling_sync.md b/mediapipe/docs/scheduling_sync.md new file mode 100644 index 000000000..382a9a326 --- /dev/null +++ b/mediapipe/docs/scheduling_sync.md @@ -0,0 +1,164 @@ +# Framework Architecture + +## Scheduling mechanics + +Data processing in a MediaPipe graph occurs inside processing nodes defined as +[`CalculatorBase`] subclasses. The scheduling system decides when each +calculator should run. + +Each graph has at least one **scheduler queue**. Each scheduler queue has +exactly one **executor**. Nodes are statically assigned to a queue (and +therefore to an executor). By default there is one queue, whose executor is a +thread pool with a number of threads based on the system’s capabilities. + +Each node has a scheduling state, which can be *not ready*, *ready*, or +*running*. A readiness function determines whether a node is ready to run. This +function is invoked at graph initialization, whenever a node finishes running, +and whenever the state of a node’s inputs changes. + +The readiness function used depends on the type of node. A node with no stream +inputs is known as a **source node**; source nodes are always ready to run, +until they tell the framework they have no more data to output, at which point +they are closed. + +Non-source nodes are ready if they have inputs to process, and if those inputs +form a valid input set according to the conditions set by the node’s **input +policy** (discussed below). Most nodes use the default input policy, but some +nodes specify a different one. + +Note: Because changing the input policy changes the guarantees the calculator’s +code can expect from its inputs, it is not generally possible to mix and match +calculators with arbitrary input policies. Thus a calculator that uses a special +input policy should be written for it, and declare it in its contract. 
+ +When a node becomes ready, a task is added to the corresponding scheduler queue, +which is a priority queue. The priority function is currently fixed, and takes +into account static properties of the nodes and their topological sorting within +the graph. For example, nodes closer to the output side of the graph have higher +priority, while source nodes have the lowest priority. + +Each queue is served by an executor, which is responsible for actually running +the task by invoking the calculator’s code. Different executors can be provided +and configured; this can be used to customize the use of execution resources, +e.g. by running certain nodes on lower-priority threads. + +## Timestamp Synchronization + +MediaPipe graph execution is decentralized: there is no global clock, and +different nodes can process data from different timestamps at the same time. +This allows higher throughput via pipelining. + +However, time information is very important for many perception workflows. Nodes +that receive multiple input streams generally need to coordinate them in some +way. For example, an object detector may output a list of boundary rectangles +from a frame, and this information may be fed into a rendering node, which +should process it together with the original frame. + +Therefore, one of the key responsibilities of the MediaPipe framework is to +provide input synchronization for nodes. In terms of framework mechanics, the +primary role of a timestamp is to serve as a **synchronization key**. + +Furthermore, MediaPipe is designed to support deterministic operations, which is +important in many scenarios (testing, simulation, batch processing, etc.), while +allowing graph authors to relax determinism where needed to meet real-time +constraints. + +The two objectives of synchronization and determinism underlie several design +choices. 
Notably, the packets pushed into a given stream must have monotonically +increasing timestamps: this is not just a useful assumption for many nodes, but +it is also relied upon by the synchronization logic. Each stream has a +**timestamp bound**, which is the lowest possible timestamp allowed for a new +packet on the stream. When a packet with timestamp `T` arrives, the bound +automatically advances to `T+1`, reflecting the monotonic requirement. This +allows the framework to know for certain that no more packets with timestamp +lower than `T` will arrive. + +## Input policies + +Synchronization is handled locally on each node, using the input policy +specified by the node. + +The default input policy, defined by [`DefaultInputStreamHandler`], provides +deterministic synchronization of inputs, with the following guarantees: + +* If packets with the same timestamp are provided on multiple input streams, + they will always be processed together regardless of their arrival order in + real time. + +* Input sets are processed in strictly ascending timestamp order. + +* No packets are dropped, and the processing is fully deterministic. + +* The node becomes ready to process data as soon as possible given the + guarantees above. + +Note: An important consequence of this is that if the calculator always uses the +current input timestamp when outputting packets, the output will inherently obey +the monotonically increasing timestamp requirement. + +Warning: On the other hand, it is not guaranteed that an input packet will +always be available for all streams. + +To explain how it works, we need to introduce the definition of a settled +timestamp. We say that a timestamp in a stream is *settled* if it is lower than +the timestamp bound. In other words, a timestamp is settled for a stream once the +state of the input at that timestamp is irrevocably known: either there is a +packet, or there is the certainty that a packet with that timestamp will not +arrive.
+ +Note: For this reason, MediaPipe also allows a stream producer to explicitly +advance the timestamp bound farther than what the last packet implies, i.e. to +provide a tighter bound. This can allow the downstream nodes to settle their +inputs sooner. + +A timestamp is settled across multiple streams if it is settled on each of those +streams. Furthermore, if a timestamp is settled it implies that all previous +timestamps are also settled. Thus settled timestamps can be processed +deterministically in ascending order. + +Given this definition, a calculator with the default input policy is ready if +there is a timestamp which is settled across all input streams and contains a +packet on at least one input stream. The input policy provides all available +packets for a settled timestamp as a single *input set* to the calculator. + +One consequence of this deterministic behavior is that, for nodes with multiple +input streams, there can be a theoretically unbounded wait for a timestamp to be +settled, and an unbounded number of packets can be buffered in the meantime. +(Consider a node with two input streams, one of which keeps sending packets +while the other sends nothing and does not advance the bound.) + +Therefore, we also provide for custom input policies: for example, splitting the +inputs in different synchronization sets defined by +[`SyncSetInputStreamHandler`], or avoiding synchronization altogether and +processing inputs immediately as they arrive defined by +[`ImmediateInputStreamHandler`]. + +## Flow control + +There are two main flow control mechanisms. A backpressure mechanism throttles +the execution of upstream nodes when the packets buffered on a stream reach a +(configurable) limit defined by [`CalculatorGraphConfig::max_queue_size`]. This +mechanism maintains deterministic behavior, and includes a deadlock avoidance +system that relaxes configured limits when needed.
+ +The second system consists of inserting special nodes which can drop packets +according to real-time constraints (typically using custom input policies) +defined by [`RealTimeFlowLimiterCalculator`]. For example, a common pattern +places a flow-control node at the input of a subgraph, with a loopback +connection from the final output to the flow-control node. The flow-control node +is thus able to keep track of how many timestamps are being processed in the +downstream graph, and drop packets if this count hits a (configurable) limit; +and since packets are dropped upstream, we avoid the wasted work that would +result from partially processing a timestamp and then dropping packets between +intermediate stages. + +This calculator-based approach gives the graph author control of where packets +can be dropped, and allows flexibility in adapting and customizing the graph’s +behavior depending on resource constraints. + +[`CalculatorBase`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_base.h +[`DefaultInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/default_input_stream_handler.h +[`SyncSetInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/sync_set_input_stream_handler.h +[`ImmediateInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/immediate_input_stream_handler.h +[`CalculatorGraphConfig::max_queue_size`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto +[`RealTimeFlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc diff --git a/mediapipe/docs/troubleshooting.md b/mediapipe/docs/troubleshooting.md new file mode 100644 index 000000000..a3ad61e72 --- /dev/null +++ b/mediapipe/docs/troubleshooting.md @@ -0,0 +1,144 @@ +# Troubleshooting + +- [Native 
method not found](#native-method-not-found) +- [No registered calculator found](#no-registered-calculator-found) +- [Out Of Memory error](#out-of-memory-error) +- [Graph hangs](#graph-hangs) +- [Calculator is scheduled infrequently](#calculator-is-scheduled-infrequently) +- [Output timing is uneven](#output-timing-is-uneven) +- [CalculatorGraph lags behind inputs](#calculatorgraph-lags-behind-inputs) + +## Native method not found + +The error message: + +``` +java.lang.UnsatisfiedLinkError: No implementation found for void com.google.wick.Wick.nativeWick +``` + +usually indicates that a needed native library, such as `/libwickjni.so` has not +been loaded or has not been included in the dependencies of the app or cannot be +found for some reason. Note that Java requires every native library to be +explicitly loaded using the function `System.loadLibrary`. + +## No registered calculator found + +The error message: + +``` +No registered object with name: OurNewCalculator; Unable to find Calculator "OurNewCalculator" +``` + +usually indicates that `OurNewCalculator` is referenced by name in a +[`CalculatorGraphConfig`] but that the library target for OurNewCalculator has +not been linked to the application binary. When a new calculator is added to a +calculator graph, that calculator must also be added as a build dependency of +the applications using the calculator graph. + +This error is caught at runtime because calculator graphs reference their +calculators by name through the field `CalculatorGraphConfig::Node:calculator`. +When the library for a calculator is linked into an application binary, the +calculator is automatically registered by name through the +[`REGISTER_CALCULATOR`] macro using the [`registration.h`] library. Note that +[`REGISTER_CALCULATOR`] can register a calculator with a namespace prefix, +identical to its C++ namespace. In this case, the calculator graph must also use +the same namespace prefix.
+ +## Out Of Memory error + +Exhausting memory can be a symptom of too many packets accumulating inside a +running MediaPipe graph. This can occur for a number of reasons, such as: + +1. Some calculators in the graph simply can't keep pace with the arrival of + packets from a realtime input stream such as a video camera. +2. Some calculators are waiting for packets that will never arrive. + +For problem (1), it may be necessary to drop some old packets in order to + process the more recent packets. For some hints, see: +[How to process realtime input streams](how_to_questions.md#how-to-process-realtime-input-streams) + +For problem (2), it could be that one input stream is lacking packets for some +reason. A device or a calculator may be misconfigured or may produce packets +only sporadically. This can cause downstream calculators to wait for many +packets that will never arrive, which in turn causes packets to accumulate on +some of their input streams. MediaPipe addresses this sort of problem using +"timestamp bounds". For some hints see: +[How to process realtime input streams](how_to_questions.md#how-to-process-realtime-input-streams) + +The MediaPipe setting [`CalculatorGraphConfig::max_queue_size`] limits the +number of packets enqueued on any input stream by throttling inputs to the +graph. For realtime input streams, the number of packets queued at an input +stream should almost always be zero or one. If this is not the case, you may see +the following warning message: + +``` +Resolved a deadlock by increasing max_queue_size of input stream +``` + +Also, the setting [`CalculatorGraphConfig::report_deadlock`] can be set to cause +graph run to fail and surface the deadlock as an error, such that max_queue_size +acts as a memory usage limit. + +## Graph hangs + +Many applications will call [`CalculatorGraph::CloseAllPacketSources`] and +[`CalculatorGraph::WaitUntilDone`] to finish or suspend execution of a MediaPipe +graph.
The objective here is to allow any pending calculators or packets to +complete processing, and then to shutdown the graph. If all goes well, every +stream in the graph will reach [`Timestamp::Done`], and every calculator will +reach [`CalculatorBase::Close`], and then [`CalculatorGraph::WaitUntilDone`] +will complete successfully. + +If some calculators or streams cannot reach state [`Timestamp::Done`] or +[`CalculatorBase::Close`], then the method [`CalculatorGraph::Cancel`] can be +called to terminate the graph run without waiting for all pending calculators +and packets to complete. + +## Output timing is uneven + +Some realtime MediaPipe graphs produce a series of video frames for viewing as a +video effect or as a video diagnostic. Sometimes, a MediaPipe graph will produce +these frames in clusters, for example when several output frames are +extrapolated from the same cluster of input frames. If the outputs are presented +as they are produced, some output frames are immediately replaced by later +frames in the same cluster, which makes the results hard to see and evaluate +visually. In cases like this, the output visualization can be improved by +presenting the frames at even intervals in real time. + +MediaPipe addresses this use case by mapping timestamps to points in real time. +Each timestamp indicates a time in microseconds, and a calculator such as +`LiveClockSyncCalculator` can delay the output of packets to match their +timestamps. This sort of calculator adjusts the timing of outputs such that: + +1. The time between outputs corresponds to the time between timestamps as + closely as possible. +2. Outputs are produced with the smallest delay possible. + +## CalculatorGraph lags behind inputs + +For many realtime MediaPipe graphs, low latency is an objective. MediaPipe +supports "pipelined" style parallel processing in order to begin processing of +each packet as early as possible. 
Normally the lowest possible latency is the +total time required by each calculator along a "critical path" of successive +calculators. The latency of a MediaPipe graph could be worse than the ideal +due to delays introduced to display frames at even intervals as described in +[Output timing is uneven](troubleshooting.md#output-timing-is-uneven). + +If some of the calculators in the graph cannot keep pace with the realtime input +streams, then latency will continue to increase, and it becomes necessary to +drop some input packets. The recommended technique is to use the MediaPipe +calculators designed specifically for this purpose such as +[`RealTimeFlowLimiterCalculator`] as described in +[How to process realtime input streams](how_to_questions.md#how-to-process-realtime-input-streams). + +[`CalculatorGraphConfig`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto +[`CalculatorGraphConfig::max_queue_size`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto +[`CalculatorGraphConfig::report_deadlock`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto +[`REGISTER_CALCULATOR`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_registry.h +[`registration.h`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/deps/registration.h +[`CalculatorGraph::CloseAllPacketSources`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_graph.h +[`CalculatorGraph::Cancel`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_graph.h +[`CalculatorGraph::WaitUntilDone`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_graph.h +[`Timestamp::Done`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/timestamp.h +[`CalculatorBase::Close`]:
https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_base.h +[`RealTimeFlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc diff --git a/mediapipe/docs/visualizer.md b/mediapipe/docs/visualizer.md new file mode 100644 index 000000000..0ba3f9403 --- /dev/null +++ b/mediapipe/docs/visualizer.md @@ -0,0 +1,61 @@ +## Visualizing MediaPipe Graphs + +- [Working within the editor](#working-within-the-editor) +- [Understanding the Graph](#understanding-the-graph) + +To help users understand the structure of their calculator graphs and to +understand the overall behavior of their machine learning inference pipelines, +we have built the [MediaPipe Visualizer](https://mediapipe-viz.appspot.com/) that is available online. + +* A graph view allows users to see a connected calculator graph as expressed + through a graph configuration that is pasted into the graph editor or + uploaded. The user can visualize and troubleshoot a graph they have created. + + ![Startup screen](./images/startup_screen.png){width="800"} + +### Working within the editor + +Getting Started: + +The graph can be modified by adding and editing code in the Editor view. + +![Editor UI](./images/editor_view.png){width="600"} + +* Pressing the "New" button in the upper right corner will clear any existing + code in the Editor window. + + ![New Button](./images/upload_button.png){width="300"} + +* Pressing the "Upload" button will prompt the user to select a local PBTXT + file, which will overwrite the current code within the editor. + +* Alternatively, code can be pasted directly into the editor window. + +* Errors and informational messages will appear in the Feedback window. + + ![Error Msg](./images/console_error.png){width="400"} + +### Understanding the Graph + +The visualizer graph shows the connections between calculator nodes.
+ +* Streams exit from the bottom of the calculator producing the stream and + enter the top of any calculator receiving the stream. (Notice the use of the + key, "input_stream" and "output_stream"). + + ![Stream UI](./images/stream_ui.png){width="350"} + ![Stream_code](./images/stream_code.png){width="350"} + +* Sidepackets work the same, except that they exit a node on the right and + enter on the left. (Notice the use of the key, "input_side_packet" and + "output_side_packet"). + + ![Sidepacket UI](./images/side_packet.png){width="350"} + ![Sidepacket_code](./images/side_packet_code.png){width="350"} + +* There are special nodes that represent inputs and outputs to the graph and + can supply either side packets or streams. + + ![Special nodes](./images/special_nodes.png){width="350"} + ![Special nodes](./images/special_nodes_code.png){width="350"} + diff --git a/mediapipe/examples/__init__.py b/mediapipe/examples/__init__.py new file mode 100644 index 000000000..6db73bc52 --- /dev/null +++ b/mediapipe/examples/__init__.py @@ -0,0 +1,14 @@ +"""Copyright 2019 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/mediapipe/examples/android/README.md b/mediapipe/examples/android/README.md new file mode 100644 index 000000000..136d37a3f --- /dev/null +++ b/mediapipe/examples/android/README.md @@ -0,0 +1,4 @@ +MediaPipe Examples +================== + +This directory contains MediaPipe Android example applications. 
Please see [src/java/com/google/mediapipe/apps/README.md](src/java/com/google/mediapipe/apps/README.md) for details. diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/README.md b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/README.md new file mode 100644 index 000000000..4a5a1cd33 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/README.md @@ -0,0 +1,26 @@ +MediaPipe Examples +================== + +This directory contains MediaPipe Android example applications for different use cases. The applications use CameraX API to access the camera. + +## Use Cases + +| Use Case | Directory | +|---------------------------------------|:-----------------------------------:| +| Edge Detection on GPU | edgedetectiongpu | +| Face Detection on CPU | facedetectioncpu | +| Face Detection on GPU | facedetectiongpu | +| Object Detection on CPU | objectdetectioncpu | +| Object Detection on GPU | objectdetectiongpu | + +For instance, to build an example app for face detection on CPU, run: + +```bash +bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu +``` + +To further install the app on an Android device, run: + +```bash +adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/facedetectioncpu.apk +``` diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/AndroidManifest.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/AndroidManifest.xml new file mode 100644 index 000000000..2b19bb84c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/AndroidManifest.xml @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/BUILD 
b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/BUILD new file mode 100644 index 000000000..21ee273a0 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/BUILD @@ -0,0 +1,77 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/edge_detection:android_calculators", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +# Maps the binary graph to an alias (e.g., the app name) for convenience so that the alias can be +# easily incorporated into the app via, for example, +# MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". 
+genrule( + name = "binary_graph", + srcs = ["//mediapipe/graphs/edge_detection:android_gpu_binary_graph"], + outs = ["edgedetectiongpu.binarypb"], + cmd = "cp $< $@", +) + +android_library( + name = "mediapipe_lib", + srcs = glob(["*.java"]), + assets = [ + ":binary_graph", + ], + assets_dir = "", + manifest = "AndroidManifest.xml", + resource_files = glob(["res/**"]), + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", + "//mediapipe/java/com/google/mediapipe/components:android_components", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/glutil", + "//third_party:android_constraint_layout", + "//third_party:androidx_appcompat", + "//third_party:opencv", + "@androidx_concurrent_futures//jar", + "@com_google_guava_android//jar", + ], +) + +android_binary( + name = "edgedetectiongpu", + aapt_version = "aapt2", + manifest = "AndroidManifest.xml", + manifest_values = {"applicationId": "com.google.mediapipe.apps.edgedetectiongpu"}, + multidex = "native", + deps = [ + ":mediapipe_lib", + ], +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/MainActivity.java new file mode 100644 index 000000000..7ee302137 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/MainActivity.java @@ -0,0 +1,158 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.apps.edgedetectiongpu; + +import android.graphics.SurfaceTexture; +import android.os.Bundle; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.SurfaceHolder; +import android.view.SurfaceView; +import android.view.View; +import android.view.ViewGroup; +import com.google.mediapipe.components.CameraHelper; +import com.google.mediapipe.components.CameraXPreviewHelper; +import com.google.mediapipe.components.ExternalTextureConverter; +import com.google.mediapipe.components.FrameProcessor; +import com.google.mediapipe.components.PermissionHelper; +import com.google.mediapipe.framework.AndroidAssetUtil; +import com.google.mediapipe.glutil.EglManager; + +/** Bare-bones main activity. */ +public class MainActivity extends AppCompatActivity { + + private static final String BINARY_GRAPH_NAME = "edgedetectiongpu.binarypb"; + private static final String INPUT_VIDEO_STREAM_NAME = "input_video"; + private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; + private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.BACK; + + static { + // Load all native libraries needed by the app. + System.loadLibrary("mediapipe_jni"); + System.loadLibrary("opencv_java4"); + } + + // {@link SurfaceTexture} where the camera-preview frames can be accessed. + private SurfaceTexture previewFrameTexture; + // Sends camera-preview frames into a MediaPipe graph for processing, and displays the processed + // frames onto a {@link Surface}. 
+ private FrameProcessor processor; + // {@link SurfaceView} that displays the camera-preview frames processed by a MediaPipe graph. + private SurfaceView previewDisplayView; + + // Creates and manages an {@link EGLContext}. + private EglManager eglManager; + // Converts the GL_TEXTURE_EXTERNAL_OES texture from Android camera into a regular texture to be + // consumed by {@link FrameProcessor} and the underlying MediaPipe graph. + private ExternalTextureConverter converter; + + // Handles camera access via the {@link CameraX} Jetpack support library. + private CameraXPreviewHelper cameraHelper; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + previewDisplayView = new SurfaceView(this); + setupPreviewDisplayView(); + + // Initilize asset manager so that MediaPipe native libraries can access the app assets, e.g., + // binary graphs. + AndroidAssetUtil.initializeNativeAssetManager(this); + + eglManager = new EglManager(null); + processor = + new FrameProcessor( + this, + eglManager.getNativeContext(), + BINARY_GRAPH_NAME, + INPUT_VIDEO_STREAM_NAME, + OUTPUT_VIDEO_STREAM_NAME); + + PermissionHelper.checkAndRequestCameraPermissions(this); + } + + @Override + protected void onResume() { + super.onResume(); + converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setConsumer(processor); + if (PermissionHelper.cameraPermissionsGranted(this)) { + startCamera(); + } + } + + @Override + protected void onPause() { + super.onPause(); + converter.close(); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + public void startCamera() { + cameraHelper = new CameraXPreviewHelper(); + 
cameraHelper.setOnCameraStartedListener( + surfaceTexture -> { + previewFrameTexture = surfaceTexture; + // Make the display view visible to start showing the preview. This triggers the + // SurfaceHolder.Callback added to (the holder of) previewDisplayView. + previewDisplayView.setVisibility(View.VISIBLE); + }); + cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null); + } + + private void setupPreviewDisplayView() { + previewDisplayView.setVisibility(View.GONE); + ViewGroup viewGroup = findViewById(R.id.preview_display_layout); + viewGroup.addView(previewDisplayView); + + previewDisplayView + .getHolder() + .addCallback( + new SurfaceHolder.Callback() { + @Override + public void surfaceCreated(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(holder.getSurface()); + } + + @Override + public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) { + // (Re-)Compute the ideal size of the camera-preview display (the area that the + // camera-preview frames get rendered onto, potentially with scaling and rotation) + // based on the size of the SurfaceView that contains the display. + Size viewSize = new Size(width, height); + Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize); + + // Connect the converter to the camera-preview frames as its input (via + // previewFrameTexture), and configure the output width and height as the computed + // display size. 
+ converter.setSurfaceTextureAndAttachToGLContext( + previewFrameTexture, displaySize.getWidth(), displaySize.getHeight()); + } + + @Override + public void surfaceDestroyed(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(null); + } + }); + } +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/layout/activity_main.xml new file mode 100644 index 000000000..22240a2d6 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/layout/activity_main.xml @@ -0,0 +1,20 @@ + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/colors.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/colors.xml new file mode 100644 index 000000000..69b22338c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/strings.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/strings.xml new file mode 100644 index 000000000..ac720f510 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/strings.xml @@ -0,0 +1,4 @@ + + Edge Detection GPU + Please grant camera permissions. 
+ diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/styles.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/styles.xml new file mode 100644 index 000000000..5885930df --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/AndroidManifest.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/AndroidManifest.xml new file mode 100644 index 000000000..fb99539da --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/AndroidManifest.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD new file mode 100644 index 000000000..0c1a5be65 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD @@ -0,0 +1,83 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/face_detection:android_calculators", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +# Maps the binary graph to an alias (e.g., the app name) for convenience so that the alias can be +# easily incorporated into the app via, for example, +# MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". +genrule( + name = "binary_graph", + srcs = ["//mediapipe/graphs/face_detection:android_cpu_binary_graph"], + outs = ["facedetectioncpu.binarypb"], + cmd = "cp $< $@", +) + +android_library( + name = "mediapipe_lib", + srcs = glob(["*.java"]), + assets = [ + ":binary_graph", + "//mediapipe/models:facedetector_front.tflite", + "//mediapipe/models:facedetector_front_labelmap.txt", + ], + assets_dir = "", + manifest = "AndroidManifest.xml", + resource_files = glob(["res/**"]), + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", + "//mediapipe/java/com/google/mediapipe/components:android_components", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/glutil", + "//third_party:android_constraint_layout", + "//third_party:androidx_appcompat", + "//third_party:opencv", + "@androidsdk//com.android.support:recyclerview-v7-25.0.0", + "@androidsdk//com.android.support:support-v4-25.0.0", + "@androidx_concurrent_futures//jar", + "@androidx_lifecycle//jar", + "@com_google_code_findbugs//jar", + "@com_google_guava_android//jar", + ], +) + +android_binary( + name = "facedetectioncpu", + aapt_version = "aapt2", + manifest = "AndroidManifest.xml", + manifest_values = {"applicationId": 
"com.google.mediapipe.apps.facedetectioncpu"}, + multidex = "native", + deps = [ + ":mediapipe_lib", + ], +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/MainActivity.java new file mode 100644 index 000000000..a0dc964f3 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/MainActivity.java @@ -0,0 +1,159 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.apps.facedetectioncpu; + +import android.graphics.SurfaceTexture; +import android.os.Bundle; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.SurfaceHolder; +import android.view.SurfaceView; +import android.view.View; +import android.view.ViewGroup; +import com.google.mediapipe.components.CameraHelper; +import com.google.mediapipe.components.CameraXPreviewHelper; +import com.google.mediapipe.components.ExternalTextureConverter; +import com.google.mediapipe.components.FrameProcessor; +import com.google.mediapipe.components.PermissionHelper; +import com.google.mediapipe.framework.AndroidAssetUtil; +import com.google.mediapipe.glutil.EglManager; + +/** Main activity of MediaPipe example apps. 
*/ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private static final String BINARY_GRAPH_NAME = "facedetectioncpu.binarypb"; + private static final String INPUT_VIDEO_STREAM_NAME = "input_video"; + private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; + private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.FRONT; + + static { + // Load all native libraries needed by the app. + System.loadLibrary("mediapipe_jni"); + System.loadLibrary("opencv_java4"); + } + + // {@link SurfaceTexture} where the camera-preview frames can be accessed. + private SurfaceTexture previewFrameTexture; + // {@link SurfaceView} that displays the camera-preview frames processed by a MediaPipe graph. + private SurfaceView previewDisplayView; + + // Creates and manages an {@link EGLContext}. + private EglManager eglManager; + // Sends camera-preview frames into a MediaPipe graph for processing, and displays the processed + // frames onto a {@link Surface}. + private FrameProcessor processor; + // Converts the GL_TEXTURE_EXTERNAL_OES texture from Android camera into a regular texture to be + // consumed by {@link FrameProcessor} and the underlying MediaPipe graph. + private ExternalTextureConverter converter; + + // Handles camera access via the {@link CameraX} Jetpack support library. + private CameraXPreviewHelper cameraHelper; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + previewDisplayView = new SurfaceView(this); + setupPreviewDisplayView(); + + // Initilize asset manager so that MediaPipe native libraries can access the app assets, e.g., + // binary graphs. 
+ AndroidAssetUtil.initializeNativeAssetManager(this); + + eglManager = new EglManager(null); + processor = + new FrameProcessor( + this, + eglManager.getNativeContext(), + BINARY_GRAPH_NAME, + INPUT_VIDEO_STREAM_NAME, + OUTPUT_VIDEO_STREAM_NAME); + + PermissionHelper.checkAndRequestCameraPermissions(this); + } + + @Override + protected void onResume() { + super.onResume(); + converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setConsumer(processor); + if (PermissionHelper.cameraPermissionsGranted(this)) { + startCamera(); + } + } + + @Override + protected void onPause() { + super.onPause(); + converter.close(); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + private void setupPreviewDisplayView() { + previewDisplayView.setVisibility(View.GONE); + ViewGroup viewGroup = findViewById(R.id.preview_display_layout); + viewGroup.addView(previewDisplayView); + + previewDisplayView + .getHolder() + .addCallback( + new SurfaceHolder.Callback() { + @Override + public void surfaceCreated(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(holder.getSurface()); + } + + @Override + public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) { + // (Re-)Compute the ideal size of the camera-preview display (the area that the + // camera-preview frames get rendered onto, potentially with scaling and rotation) + // based on the size of the SurfaceView that contains the display. + Size viewSize = new Size(width, height); + Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize); + + // Connect the converter to the camera-preview frames as its input (via + // previewFrameTexture), and configure the output width and height as the computed + // display size. 
+ converter.setSurfaceTextureAndAttachToGLContext( + previewFrameTexture, displaySize.getWidth(), displaySize.getHeight()); + } + + @Override + public void surfaceDestroyed(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(null); + } + }); + } + + private void startCamera() { + cameraHelper = new CameraXPreviewHelper(); + cameraHelper.setOnCameraStartedListener( + surfaceTexture -> { + previewFrameTexture = surfaceTexture; + // Make the display view visible to start showing the preview. This triggers the + // SurfaceHolder.Callback added to (the holder of) previewDisplayView. + previewDisplayView.setVisibility(View.VISIBLE); + }); + cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null); + } +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/layout/activity_main.xml new file mode 100644 index 000000000..22240a2d6 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/layout/activity_main.xml @@ -0,0 +1,20 @@ + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/colors.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/colors.xml new file mode 100644 index 000000000..69b22338c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/strings.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/strings.xml new file mode 100644 index 000000000..0c678f5cb --- /dev/null +++ 
b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/strings.xml @@ -0,0 +1,4 @@ + + Face Detection CPU + Please grant camera permissions. + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/styles.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/styles.xml new file mode 100644 index 000000000..5885930df --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/AndroidManifest.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/AndroidManifest.xml new file mode 100644 index 000000000..27afedd3d --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/AndroidManifest.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD new file mode 100644 index 000000000..7728b3bd7 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD @@ -0,0 +1,83 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/face_detection:android_calculators", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +# Maps the binary graph to an alias (e.g., the app name) for convenience so that the alias can be +# easily incorporated into the app via, for example, +# MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". +genrule( + name = "binary_graph", + srcs = ["//mediapipe/graphs/face_detection:android_gpu_binary_graph"], + outs = ["facedetectiongpu.binarypb"], + cmd = "cp $< $@", +) + +android_library( + name = "mediapipe_lib", + srcs = glob(["*.java"]), + assets = [ + ":binary_graph", + "//mediapipe/models:facedetector_front.tflite", + "//mediapipe/models:facedetector_front_labelmap.txt", + ], + assets_dir = "", + manifest = "AndroidManifest.xml", + resource_files = glob(["res/**"]), + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", + "//mediapipe/java/com/google/mediapipe/components:android_components", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/glutil", + "//third_party:android_constraint_layout", + "//third_party:androidx_appcompat", + "//third_party:opencv", + "@androidsdk//com.android.support:recyclerview-v7-25.0.0", + "@androidsdk//com.android.support:support-v4-25.0.0", + "@androidx_concurrent_futures//jar", + "@androidx_lifecycle//jar", + "@com_google_code_findbugs//jar", + "@com_google_guava_android//jar", + ], +) + +android_binary( + name = "facedetectiongpu", + aapt_version = "aapt2", + manifest = "AndroidManifest.xml", + manifest_values = {"applicationId": 
"com.google.mediapipe.apps.facedetectiongpu"}, + multidex = "native", + deps = [ + ":mediapipe_lib", + ], +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/MainActivity.java new file mode 100644 index 000000000..d232992fb --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/MainActivity.java @@ -0,0 +1,159 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.apps.facedetectiongpu; + +import android.graphics.SurfaceTexture; +import android.os.Bundle; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.SurfaceHolder; +import android.view.SurfaceView; +import android.view.View; +import android.view.ViewGroup; +import com.google.mediapipe.components.CameraHelper; +import com.google.mediapipe.components.CameraXPreviewHelper; +import com.google.mediapipe.components.ExternalTextureConverter; +import com.google.mediapipe.components.FrameProcessor; +import com.google.mediapipe.components.PermissionHelper; +import com.google.mediapipe.framework.AndroidAssetUtil; +import com.google.mediapipe.glutil.EglManager; + +/** Main activity of MediaPipe example apps. 
*/ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private static final String BINARY_GRAPH_NAME = "facedetectiongpu.binarypb"; + private static final String INPUT_VIDEO_STREAM_NAME = "input_video"; + private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; + private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.FRONT; + + static { + // Load all native libraries needed by the app. + System.loadLibrary("mediapipe_jni"); + System.loadLibrary("opencv_java4"); + } + + // {@link SurfaceTexture} where the camera-preview frames can be accessed. + private SurfaceTexture previewFrameTexture; + // {@link SurfaceView} that displays the camera-preview frames processed by a MediaPipe graph. + private SurfaceView previewDisplayView; + + // Creates and manages an {@link EGLContext}. + private EglManager eglManager; + // Sends camera-preview frames into a MediaPipe graph for processing, and displays the processed + // frames onto a {@link Surface}. + private FrameProcessor processor; + // Converts the GL_TEXTURE_EXTERNAL_OES texture from Android camera into a regular texture to be + // consumed by {@link FrameProcessor} and the underlying MediaPipe graph. + private ExternalTextureConverter converter; + + // Handles camera access via the {@link CameraX} Jetpack support library. + private CameraXPreviewHelper cameraHelper; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + previewDisplayView = new SurfaceView(this); + setupPreviewDisplayView(); + + // Initilize asset manager so that MediaPipe native libraries can access the app assets, e.g., + // binary graphs. 
+ AndroidAssetUtil.initializeNativeAssetManager(this); + + eglManager = new EglManager(null); + processor = + new FrameProcessor( + this, + eglManager.getNativeContext(), + BINARY_GRAPH_NAME, + INPUT_VIDEO_STREAM_NAME, + OUTPUT_VIDEO_STREAM_NAME); + + PermissionHelper.checkAndRequestCameraPermissions(this); + } + + @Override + protected void onResume() { + super.onResume(); + converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setConsumer(processor); + if (PermissionHelper.cameraPermissionsGranted(this)) { + startCamera(); + } + } + + @Override + protected void onPause() { + super.onPause(); + converter.close(); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + private void setupPreviewDisplayView() { + previewDisplayView.setVisibility(View.GONE); + ViewGroup viewGroup = findViewById(R.id.preview_display_layout); + viewGroup.addView(previewDisplayView); + + previewDisplayView + .getHolder() + .addCallback( + new SurfaceHolder.Callback() { + @Override + public void surfaceCreated(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(holder.getSurface()); + } + + @Override + public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) { + // (Re-)Compute the ideal size of the camera-preview display (the area that the + // camera-preview frames get rendered onto, potentially with scaling and rotation) + // based on the size of the SurfaceView that contains the display. + Size viewSize = new Size(width, height); + Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize); + + // Connect the converter to the camera-preview frames as its input (via + // previewFrameTexture), and configure the output width and height as the computed + // display size. 
+ converter.setSurfaceTextureAndAttachToGLContext( + previewFrameTexture, displaySize.getWidth(), displaySize.getHeight()); + } + + @Override + public void surfaceDestroyed(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(null); + } + }); + } + + private void startCamera() { + cameraHelper = new CameraXPreviewHelper(); + cameraHelper.setOnCameraStartedListener( + surfaceTexture -> { + previewFrameTexture = surfaceTexture; + // Make the display view visible to start showing the preview. This triggers the + // SurfaceHolder.Callback added to (the holder of) previewDisplayView. + previewDisplayView.setVisibility(View.VISIBLE); + }); + cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null); + } +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/layout/activity_main.xml new file mode 100644 index 000000000..22240a2d6 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/layout/activity_main.xml @@ -0,0 +1,20 @@ + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/colors.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/colors.xml new file mode 100644 index 000000000..69b22338c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/strings.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/strings.xml new file mode 100644 index 000000000..25f08adfc --- /dev/null +++ 
b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/strings.xml @@ -0,0 +1,4 @@ + + Face Detection GPU + Please grant camera permissions. + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/styles.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/styles.xml new file mode 100644 index 000000000..5885930df --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/AndroidManifest.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/AndroidManifest.xml new file mode 100644 index 000000000..27b432ed4 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/AndroidManifest.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD new file mode 100644 index 000000000..071ccf986 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD @@ -0,0 +1,82 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/hair_segmentation:android_calculators", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +# Maps the binary graph to an alias (e.g., the app name) for convenience so that the alias can be +# easily incorporated into the app via, for example, +# MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". +genrule( + name = "binary_graph", + srcs = ["//mediapipe/graphs/hair_segmentation:android_gpu_binary_graph"], + outs = ["hairsegmentationgpu.binarypb"], + cmd = "cp $< $@", +) + +android_library( + name = "mediapipe_lib", + srcs = glob(["*.java"]), + assets = [ + ":binary_graph", + "//mediapipe/models:hair_segmentation.tflite", + ], + assets_dir = "", + manifest = "AndroidManifest.xml", + resource_files = glob(["res/**"]), + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", + "//mediapipe/java/com/google/mediapipe/components:android_components", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/glutil", + "//third_party:android_constraint_layout", + "//third_party:androidx_appcompat", + "//third_party:opencv", + "@androidsdk//com.android.support:recyclerview-v7-25.0.0", + "@androidsdk//com.android.support:support-v4-25.0.0", + "@androidx_concurrent_futures//jar", + "@androidx_lifecycle//jar", + "@com_google_code_findbugs//jar", + "@com_google_guava_android//jar", + ], +) + +android_binary( + name = "hairsegmentationgpu", + aapt_version = "aapt2", + manifest = "AndroidManifest.xml", + manifest_values = 
{"applicationId": "com.google.mediapipe.apps.hairsegmentationgpu"}, + multidex = "native", + deps = [ + ":mediapipe_lib", + ], +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/MainActivity.java new file mode 100644 index 000000000..c33311ffb --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/MainActivity.java @@ -0,0 +1,159 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.apps.hairsegmentationgpu; + +import android.graphics.SurfaceTexture; +import android.os.Bundle; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.SurfaceHolder; +import android.view.SurfaceView; +import android.view.View; +import android.view.ViewGroup; +import com.google.mediapipe.components.CameraHelper; +import com.google.mediapipe.components.CameraXPreviewHelper; +import com.google.mediapipe.components.ExternalTextureConverter; +import com.google.mediapipe.components.FrameProcessor; +import com.google.mediapipe.components.PermissionHelper; +import com.google.mediapipe.framework.AndroidAssetUtil; +import com.google.mediapipe.glutil.EglManager; + +/** Main activity of MediaPipe example apps. 
*/ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private static final String BINARY_GRAPH_NAME = "hairsegmentationgpu.binarypb"; + private static final String INPUT_VIDEO_STREAM_NAME = "input_video"; + private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; + private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.FRONT; + + static { + // Load all native libraries needed by the app. + System.loadLibrary("mediapipe_jni"); + System.loadLibrary("opencv_java4"); + } + + // {@link SurfaceTexture} where the camera-preview frames can be accessed. + private SurfaceTexture previewFrameTexture; + // {@link SurfaceView} that displays the camera-preview frames processed by a MediaPipe graph. + private SurfaceView previewDisplayView; + + // Creates and manages an {@link EGLContext}. + private EglManager eglManager; + // Sends camera-preview frames into a MediaPipe graph for processing, and displays the processed + // frames onto a {@link Surface}. + private FrameProcessor processor; + // Converts the GL_TEXTURE_EXTERNAL_OES texture from Android camera into a regular texture to be + // consumed by {@link FrameProcessor} and the underlying MediaPipe graph. + private ExternalTextureConverter converter; + + // Handles camera access via the {@link CameraX} Jetpack support library. + private CameraXPreviewHelper cameraHelper; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + previewDisplayView = new SurfaceView(this); + setupPreviewDisplayView(); + + // Initilize asset manager so that MediaPipe native libraries can access the app assets, e.g., + // binary graphs. 
+ AndroidAssetUtil.initializeNativeAssetManager(this); + + eglManager = new EglManager(null); + processor = + new FrameProcessor( + this, + eglManager.getNativeContext(), + BINARY_GRAPH_NAME, + INPUT_VIDEO_STREAM_NAME, + OUTPUT_VIDEO_STREAM_NAME); + + PermissionHelper.checkAndRequestCameraPermissions(this); + } + + @Override + protected void onResume() { + super.onResume(); + converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setConsumer(processor); + if (PermissionHelper.cameraPermissionsGranted(this)) { + startCamera(); + } + } + + @Override + protected void onPause() { + super.onPause(); + converter.close(); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + private void setupPreviewDisplayView() { + previewDisplayView.setVisibility(View.GONE); + ViewGroup viewGroup = findViewById(R.id.preview_display_layout); + viewGroup.addView(previewDisplayView); + + previewDisplayView + .getHolder() + .addCallback( + new SurfaceHolder.Callback() { + @Override + public void surfaceCreated(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(holder.getSurface()); + } + + @Override + public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) { + // (Re-)Compute the ideal size of the camera-preview display (the area that the + // camera-preview frames get rendered onto, potentially with scaling and rotation) + // based on the size of the SurfaceView that contains the display. + Size viewSize = new Size(width, height); + Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize); + + // Connect the converter to the camera-preview frames as its input (via + // previewFrameTexture), and configure the output width and height as the computed + // display size. 
+ converter.setSurfaceTextureAndAttachToGLContext( + previewFrameTexture, displaySize.getWidth(), displaySize.getHeight()); + } + + @Override + public void surfaceDestroyed(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(null); + } + }); + } + + private void startCamera() { + cameraHelper = new CameraXPreviewHelper(); + cameraHelper.setOnCameraStartedListener( + surfaceTexture -> { + previewFrameTexture = surfaceTexture; + // Make the display view visible to start showing the preview. This triggers the + // SurfaceHolder.Callback added to (the holder of) previewDisplayView. + previewDisplayView.setVisibility(View.VISIBLE); + }); + cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null); + } +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/layout/activity_main.xml new file mode 100644 index 000000000..22240a2d6 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/layout/activity_main.xml @@ -0,0 +1,20 @@ + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/colors.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/colors.xml new file mode 100644 index 000000000..69b22338c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/strings.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/strings.xml new file mode 100644 index 000000000..41cce9899 --- /dev/null +++ 
b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/strings.xml @@ -0,0 +1,4 @@ + + Hair Segmentation GPU + Please grant camera permissions. + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/styles.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/styles.xml new file mode 100644 index 000000000..5885930df --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/AndroidManifest.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/AndroidManifest.xml new file mode 100644 index 000000000..5b40791e4 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/AndroidManifest.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD new file mode 100644 index 000000000..3cab05ff2 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD @@ -0,0 +1,83 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/object_detection:android_calculators", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +# Maps the binary graph to an alias (e.g., the app name) for convenience so that the alias can be +# easily incorporated into the app via, for example, +# MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". +genrule( + name = "binary_graph", + srcs = ["//mediapipe/graphs/object_detection:android_cpu_binary_graph"], + outs = ["objectdetectioncpu.binarypb"], + cmd = "cp $< $@", +) + +android_library( + name = "mediapipe_lib", + srcs = glob(["*.java"]), + assets = [ + ":binary_graph", + "//mediapipe/models:ssdlite_object_detection.tflite", + "//mediapipe/models:ssdlite_object_detection_labelmap.txt", + ], + assets_dir = "", + manifest = "AndroidManifest.xml", + resource_files = glob(["res/**"]), + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", + "//mediapipe/java/com/google/mediapipe/components:android_components", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/glutil", + "//third_party:android_constraint_layout", + "//third_party:androidx_appcompat", + "//third_party:opencv", + "@androidsdk//com.android.support:recyclerview-v7-25.0.0", + "@androidsdk//com.android.support:support-v4-25.0.0", + "@androidx_concurrent_futures//jar", + "@androidx_lifecycle//jar", + "@com_google_code_findbugs//jar", + "@com_google_guava_android//jar", + ], +) + +android_binary( + name = "objectdetectioncpu", + aapt_version = 
"aapt2", + manifest = "AndroidManifest.xml", + manifest_values = {"applicationId": "com.google.mediapipe.apps.objectdetectioncpu"}, + multidex = "native", + deps = [ + ":mediapipe_lib", + ], +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/MainActivity.java new file mode 100644 index 000000000..2cbbe7cd5 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/MainActivity.java @@ -0,0 +1,159 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.apps.objectdetectioncpu; + +import android.graphics.SurfaceTexture; +import android.os.Bundle; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.SurfaceHolder; +import android.view.SurfaceView; +import android.view.View; +import android.view.ViewGroup; +import com.google.mediapipe.components.CameraHelper; +import com.google.mediapipe.components.CameraXPreviewHelper; +import com.google.mediapipe.components.ExternalTextureConverter; +import com.google.mediapipe.components.FrameProcessor; +import com.google.mediapipe.components.PermissionHelper; +import com.google.mediapipe.framework.AndroidAssetUtil; +import com.google.mediapipe.glutil.EglManager; + +/** Main activity of MediaPipe example apps. 
*/ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private static final String BINARY_GRAPH_NAME = "objectdetectioncpu.binarypb"; + private static final String INPUT_VIDEO_STREAM_NAME = "input_video"; + private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; + private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.BACK; + + static { + // Load all native libraries needed by the app. + System.loadLibrary("mediapipe_jni"); + System.loadLibrary("opencv_java4"); + } + + // {@link SurfaceTexture} where the camera-preview frames can be accessed. + private SurfaceTexture previewFrameTexture; + // {@link SurfaceView} that displays the camera-preview frames processed by a MediaPipe graph. + private SurfaceView previewDisplayView; + + // Creates and manages an {@link EGLContext}. + private EglManager eglManager; + // Sends camera-preview frames into a MediaPipe graph for processing, and displays the processed + // frames onto a {@link Surface}. + private FrameProcessor processor; + // Converts the GL_TEXTURE_EXTERNAL_OES texture from Android camera into a regular texture to be + // consumed by {@link FrameProcessor} and the underlying MediaPipe graph. + private ExternalTextureConverter converter; + + // Handles camera access via the {@link CameraX} Jetpack support library. + private CameraXPreviewHelper cameraHelper; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + previewDisplayView = new SurfaceView(this); + setupPreviewDisplayView(); + + // Initilize asset manager so that MediaPipe native libraries can access the app assets, e.g., + // binary graphs. 
+ AndroidAssetUtil.initializeNativeAssetManager(this); + + eglManager = new EglManager(null); + processor = + new FrameProcessor( + this, + eglManager.getNativeContext(), + BINARY_GRAPH_NAME, + INPUT_VIDEO_STREAM_NAME, + OUTPUT_VIDEO_STREAM_NAME); + + PermissionHelper.checkAndRequestCameraPermissions(this); + } + + @Override + protected void onResume() { + super.onResume(); + converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setConsumer(processor); + if (PermissionHelper.cameraPermissionsGranted(this)) { + startCamera(); + } + } + + @Override + protected void onPause() { + super.onPause(); + converter.close(); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + private void setupPreviewDisplayView() { + previewDisplayView.setVisibility(View.GONE); + ViewGroup viewGroup = findViewById(R.id.preview_display_layout); + viewGroup.addView(previewDisplayView); + + previewDisplayView + .getHolder() + .addCallback( + new SurfaceHolder.Callback() { + @Override + public void surfaceCreated(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(holder.getSurface()); + } + + @Override + public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) { + // (Re-)Compute the ideal size of the camera-preview display (the area that the + // camera-preview frames get rendered onto, potentially with scaling and rotation) + // based on the size of the SurfaceView that contains the display. + Size viewSize = new Size(width, height); + Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize); + + // Connect the converter to the camera-preview frames as its input (via + // previewFrameTexture), and configure the output width and height as the computed + // display size. 
+ converter.setSurfaceTextureAndAttachToGLContext( + previewFrameTexture, displaySize.getWidth(), displaySize.getHeight()); + } + + @Override + public void surfaceDestroyed(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(null); + } + }); + } + + private void startCamera() { + cameraHelper = new CameraXPreviewHelper(); + cameraHelper.setOnCameraStartedListener( + surfaceTexture -> { + previewFrameTexture = surfaceTexture; + // Make the display view visible to start showing the preview. This triggers the + // SurfaceHolder.Callback added to (the holder of) previewDisplayView. + previewDisplayView.setVisibility(View.VISIBLE); + }); + cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null); + } +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/layout/activity_main.xml new file mode 100644 index 000000000..22240a2d6 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/layout/activity_main.xml @@ -0,0 +1,20 @@ + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/colors.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/colors.xml new file mode 100644 index 000000000..69b22338c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/strings.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/strings.xml new file mode 100644 index 000000000..86b06b44b --- /dev/null +++ 
b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/strings.xml @@ -0,0 +1,4 @@ + + Object Detection CPU + Please grant camera permissions. + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/styles.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/styles.xml new file mode 100644 index 000000000..5885930df --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/AndroidManifest.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/AndroidManifest.xml new file mode 100644 index 000000000..decc87de9 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/AndroidManifest.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD new file mode 100644 index 000000000..39a3d1523 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD @@ -0,0 +1,83 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/object_detection:android_calculators", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +# Maps the binary graph to an alias (e.g., the app name) for convenience so that the alias can be +# easily incorporated into the app via, for example, +# MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". +genrule( + name = "binary_graph", + srcs = ["//mediapipe/graphs/object_detection:android_gpu_binary_graph"], + outs = ["objectdetectiongpu.binarypb"], + cmd = "cp $< $@", +) + +android_library( + name = "mediapipe_lib", + srcs = glob(["*.java"]), + assets = [ + ":binary_graph", + "//mediapipe/models:ssdlite_object_detection.tflite", + "//mediapipe/models:ssdlite_object_detection_labelmap.txt", + ], + assets_dir = "", + manifest = "AndroidManifest.xml", + resource_files = glob(["res/**"]), + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", + "//mediapipe/java/com/google/mediapipe/components:android_components", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/glutil", + "//third_party:android_constraint_layout", + "//third_party:androidx_appcompat", + "//third_party:opencv", + "@androidsdk//com.android.support:recyclerview-v7-25.0.0", + "@androidsdk//com.android.support:support-v4-25.0.0", + "@androidx_concurrent_futures//jar", + "@androidx_lifecycle//jar", + "@com_google_code_findbugs//jar", + "@com_google_guava_android//jar", + ], +) + +android_binary( + name = "objectdetectiongpu", + aapt_version = 
"aapt2", + manifest = "AndroidManifest.xml", + manifest_values = {"applicationId": "com.google.mediapipe.apps.objectdetectiongpu"}, + multidex = "native", + deps = [ + ":mediapipe_lib", + ], +) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/MainActivity.java new file mode 100644 index 000000000..9d4324fde --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/MainActivity.java @@ -0,0 +1,159 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.apps.objectdetectiongpu; + +import android.graphics.SurfaceTexture; +import android.os.Bundle; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.SurfaceHolder; +import android.view.SurfaceView; +import android.view.View; +import android.view.ViewGroup; +import com.google.mediapipe.components.CameraHelper; +import com.google.mediapipe.components.CameraXPreviewHelper; +import com.google.mediapipe.components.ExternalTextureConverter; +import com.google.mediapipe.components.FrameProcessor; +import com.google.mediapipe.components.PermissionHelper; +import com.google.mediapipe.framework.AndroidAssetUtil; +import com.google.mediapipe.glutil.EglManager; + +/** Main activity of MediaPipe example apps. 
 */ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private static final String BINARY_GRAPH_NAME = "objectdetectiongpu.binarypb"; + private static final String INPUT_VIDEO_STREAM_NAME = "input_video"; + private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; + private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.BACK; + + static { + // Load all native libraries needed by the app. + System.loadLibrary("mediapipe_jni"); + System.loadLibrary("opencv_java4"); + } + + // {@link SurfaceTexture} where the camera-preview frames can be accessed. + private SurfaceTexture previewFrameTexture; + // {@link SurfaceView} that displays the camera-preview frames processed by a MediaPipe graph. + private SurfaceView previewDisplayView; + + // Creates and manages an {@link EGLContext}. + private EglManager eglManager; + // Sends camera-preview frames into a MediaPipe graph for processing, and displays the processed + // frames onto a {@link Surface}. + private FrameProcessor processor; + // Converts the GL_TEXTURE_EXTERNAL_OES texture from Android camera into a regular texture to be + // consumed by {@link FrameProcessor} and the underlying MediaPipe graph. + private ExternalTextureConverter converter; + + // Handles camera access via the {@link CameraX} Jetpack support library. + private CameraXPreviewHelper cameraHelper; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + previewDisplayView = new SurfaceView(this); + setupPreviewDisplayView(); + + // Initialize asset manager so that MediaPipe native libraries can access the app assets, e.g., + // binary graphs. 
+ AndroidAssetUtil.initializeNativeAssetManager(this); + + eglManager = new EglManager(null); + processor = + new FrameProcessor( + this, + eglManager.getNativeContext(), + BINARY_GRAPH_NAME, + INPUT_VIDEO_STREAM_NAME, + OUTPUT_VIDEO_STREAM_NAME); + + PermissionHelper.checkAndRequestCameraPermissions(this); + } + + @Override + protected void onResume() { + super.onResume(); + converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setConsumer(processor); + if (PermissionHelper.cameraPermissionsGranted(this)) { + startCamera(); + } + } + + @Override + protected void onPause() { + super.onPause(); + converter.close(); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + private void setupPreviewDisplayView() { + previewDisplayView.setVisibility(View.GONE); + ViewGroup viewGroup = findViewById(R.id.preview_display_layout); + viewGroup.addView(previewDisplayView); + + previewDisplayView + .getHolder() + .addCallback( + new SurfaceHolder.Callback() { + @Override + public void surfaceCreated(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(holder.getSurface()); + } + + @Override + public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) { + // (Re-)Compute the ideal size of the camera-preview display (the area that the + // camera-preview frames get rendered onto, potentially with scaling and rotation) + // based on the size of the SurfaceView that contains the display. + Size viewSize = new Size(width, height); + Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize); + + // Connect the converter to the camera-preview frames as its input (via + // previewFrameTexture), and configure the output width and height as the computed + // display size. 
+ converter.setSurfaceTextureAndAttachToGLContext( + previewFrameTexture, displaySize.getWidth(), displaySize.getHeight()); + } + + @Override + public void surfaceDestroyed(SurfaceHolder holder) { + processor.getVideoSurfaceOutput().setSurface(null); + } + }); + } + + private void startCamera() { + cameraHelper = new CameraXPreviewHelper(); + cameraHelper.setOnCameraStartedListener( + surfaceTexture -> { + previewFrameTexture = surfaceTexture; + // Make the display view visible to start showing the preview. This triggers the + // SurfaceHolder.Callback added to (the holder of) previewDisplayView. + previewDisplayView.setVisibility(View.VISIBLE); + }); + cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null); + } +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/layout/activity_main.xml new file mode 100644 index 000000000..22240a2d6 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/layout/activity_main.xml @@ -0,0 +1,20 @@ + + + + + + + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/colors.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/colors.xml new file mode 100644 index 000000000..69b22338c --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/strings.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/strings.xml new file mode 100644 index 000000000..a6c688177 --- /dev/null +++ 
b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/strings.xml @@ -0,0 +1,4 @@ + + Object Detection GPU + Please grant camera permissions. + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/styles.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/styles.xml new file mode 100644 index 000000000..5885930df --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/mediapipe/examples/desktop/BUILD b/mediapipe/examples/desktop/BUILD new file mode 100644 index 000000000..829601df7 --- /dev/null +++ b/mediapipe/examples/desktop/BUILD @@ -0,0 +1,31 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
 + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +cc_library( + name = "simple_run_graph_main", + srcs = ["simple_run_graph_main.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:map_util", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/examples/desktop/README.md b/mediapipe/examples/desktop/README.md new file mode 100644 index 000000000..21cb9b2a7 --- /dev/null +++ b/mediapipe/examples/desktop/README.md @@ -0,0 +1,48 @@ +**Hello World** + +To build the "Hello World" example, use: + +``` +bazel build -c opt mediapipe/examples/desktop/hello_world:hello_world +``` + +and then run it using: + +``` +bazel-bin/mediapipe/examples/desktop/hello_world/hello_world --logtostderr +``` + +**TFlite Object Detection** + +To build the object detection demo using a TFLite model on desktop, use: + +``` +bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tflite --define 'MEDIAPIPE_DISABLE_GPU=1' +``` + +and run it using: + +``` +bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \ + --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \ + --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file \ + --alsologtostderr +``` + +**TensorFlow Object Detection** + +To build the object detection demo using a TensorFlow model on desktop, use: + +``` +bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tensorflow \ + --define 'MEDIAPIPE_DISABLE_GPU=1' +``` + +and run it using: + +``` +bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \ + 
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \ + --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file + --alsologtostderr +``` diff --git a/mediapipe/examples/desktop/__init__.py b/mediapipe/examples/desktop/__init__.py new file mode 100644 index 000000000..6db73bc52 --- /dev/null +++ b/mediapipe/examples/desktop/__init__.py @@ -0,0 +1,14 @@ +"""Copyright 2019 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/mediapipe/examples/desktop/hello_world/BUILD b/mediapipe/examples/desktop/hello_world/BUILD new file mode 100644 index 000000000..ff36a24f0 --- /dev/null +++ b/mediapipe/examples/desktop/hello_world/BUILD @@ -0,0 +1,30 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +cc_binary( + name = "hello_world", + srcs = ["hello_world.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:calculator_graph", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) diff --git a/mediapipe/examples/desktop/hello_world/hello_world.cc b/mediapipe/examples/desktop/hello_world/hello_world.cc new file mode 100644 index 000000000..0fe378d4b --- /dev/null +++ b/mediapipe/examples/desktop/hello_world/hello_world.cc @@ -0,0 +1,65 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// A simple example to print out "Hello World!" from a MediaPipe graph. + +#include "mediapipe/framework/calculator_graph.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +::mediapipe::Status PrintHelloWorld() { + // Configures a simple graph, which concatenates 2 PassThroughCalculators. 
+ CalculatorGraphConfig config = ParseTextProtoOrDie(R"( + input_stream: "in" + output_stream: "out" + node { + calculator: "PassThroughCalculator" + input_stream: "in" + output_stream: "out1" + } + node { + calculator: "PassThroughCalculator" + input_stream: "out1" + output_stream: "out" + } + )"); + + CalculatorGraph graph; + RETURN_IF_ERROR(graph.Initialize(config)); + ASSIGN_OR_RETURN(OutputStreamPoller poller, + graph.AddOutputStreamPoller("out")); + RETURN_IF_ERROR(graph.StartRun({})); + // Give 10 input packets that contains the same std::string "Hello World!". + for (int i = 0; i < 10; ++i) { + RETURN_IF_ERROR(graph.AddPacketToInputStream( + "in", MakePacket("Hello World!").At(Timestamp(i)))); + } + // Close the input stream "in". + RETURN_IF_ERROR(graph.CloseInputStream("in")); + mediapipe::Packet packet; + // Get the output packets std::string. + while (poller.Next(&packet)) { + LOG(INFO) << packet.Get(); + } + return graph.WaitUntilDone(); +} +} // namespace mediapipe + +int main(int argc, char** argv) { + CHECK(mediapipe::PrintHelloWorld().ok()); + return 0; +} diff --git a/mediapipe/examples/desktop/media_sequence/BUILD b/mediapipe/examples/desktop/media_sequence/BUILD new file mode 100644 index 000000000..c2a39f758 --- /dev/null +++ b/mediapipe/examples/desktop/media_sequence/BUILD @@ -0,0 +1,39 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +cc_library( + name = "run_graph_file_io_main", + srcs = ["run_graph_file_io_main.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:map_util", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "media_sequence_demo", + deps = [ + ":run_graph_file_io_main", + "//mediapipe/graphs/media_sequence:clipped_images_from_file_at_24fps_calculators", + ], +) diff --git a/mediapipe/examples/desktop/media_sequence/README.md b/mediapipe/examples/desktop/media_sequence/README.md new file mode 100644 index 000000000..6be9014db --- /dev/null +++ b/mediapipe/examples/desktop/media_sequence/README.md @@ -0,0 +1,51 @@ +# Preparing data sets for machine learning with MediaPipe +We include two pipelines to prepare data sets for training TensorFlow models. + +Using these data sets is split into two parts. First, the data set is +constructed in with a Python script and MediaPipe C++ binary. The C++ binary +should be compiled by the end user because the preparation for different data +sets requires different MediaPipe calculator dependencies. The result of running +the script is a data set of TFRecord files on disk. The second stage is reading +the data from TensorFlow into a tf.data.Dataset. Both pipelines can be imported +and support a simple call to as_dataset() to make the data available. + +### Demo data set +To generate the demo dataset you must have Tensorflow [version >= 1.19] +installed. Then the media_sequence_demo binary must be built from the top +directory in the mediapipe repo and the command to build the data set must be +run from the same directory. 
+``` +bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo \ + --define=MEDIAPIPE_DISABLE_GPU=1 + +python -m mediapipe.examples.desktop.media_sequence.demo_dataset \ + --alsologtostderr \ + --path_to_demo_data=/tmp/demo_data/ \ + --path_to_mediapipe_binary=bazel-bin/mediapipe/examples/desktop/\ +media_sequence/media_sequence_demo \ + --path_to_graph_directory=mediapipe/graphs/media_sequence/ +``` + +### Charades data set + +The Charades data set is ready for training and/or evaluating action recognition +models in TensorFlow. You may only use this script in ways that comply with the +Allen Institute for Artificial Intelligence's [license for the Charades data +set.](https://allenai.org/plato/charades/license.txt) + +To generate the Charades dataset you must have Tensorflow [version >= 1.19] +installed. Then the media_sequence_demo binary must be built from the top +directory in the mediapipe repo and the command to build the data set must be +run from the same directory. + +``` +bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo \ + --define=MEDIAPIPE_DISABLE_GPU=1 + +python -m mediapipe.examples.desktop.media_sequence.charades_dataset \ + --alsologtostderr \ + --path_to_charades_data=/tmp/charades_data/ \ + --path_to_mediapipe_binary=bazel-bin/mediapipe/examples/desktop/\ +media_sequence/media_sequence_demo \ + --path_to_graph_directory=mediapipe/graphs/media_sequence/ +``` diff --git a/mediapipe/examples/desktop/media_sequence/__init__.py b/mediapipe/examples/desktop/media_sequence/__init__.py new file mode 100644 index 000000000..6db73bc52 --- /dev/null +++ b/mediapipe/examples/desktop/media_sequence/__init__.py @@ -0,0 +1,14 @@ +"""Copyright 2019 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/mediapipe/examples/desktop/media_sequence/charades_dataset.py b/mediapipe/examples/desktop/media_sequence/charades_dataset.py new file mode 100644 index 000000000..cb94e07eb --- /dev/null +++ b/mediapipe/examples/desktop/media_sequence/charades_dataset.py @@ -0,0 +1,517 @@ +r"""Copyright 2019 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Code to download and parse the Charades dataset for TensorFlow models. + +The [Charades data set](https://allenai.org/plato/charades/) is a data set of +human action recognition collected with and maintained by the Allen Institute +for Artificial Intelligence. This script downloads and prepares the data set for +training a TensorFlow model. To use this script, you must abide by the +[license](https://allenai.org/plato/charades/license.txt) for the Charades data +set provided by the Allen Institute. The license for this script only covers +this code and not the data set. + +Running this code as a module generates the data set on disk. First, the +required files are downloaded (_download_data). 
 Then, for each split in the +data set (generate_examples), the metadata is generated from the annotations for +each example (_generate_metadata), and MediaPipe is used to fill in the video +frames (_run_mediapipe). The data set is written to disk as a set of numbered +TFRecord files. If the download is disrupted, the incomplete files will need to +be removed before running the script again. This pattern can be reproduced and +modified to generate most video data sets. + +Generating the data on disk will probably take 4-8 hours and requires 150 GB of +disk space. (Image compression quality is the primary determiner of disk usage.) +After generating the data, the 30 GB of compressed video data can be deleted. + +Once the data is on disk, reading the data as a tf.data.Dataset is accomplished +with the following lines: + + charades = Charades("charades_data_path") + dataset = charades.as_dataset("test") + # implement additional processing and batching here + images_and_labels = dataset.make_one_shot_iterator().get_next() + images = images_and_labels["images"] + labels = images_and_labels["classification_target"] + label_weights = images_and_labels["indicator_matrix"] + +This data is structured for per-frame action classification where images is +the sequence of images, labels are the sequence of classification targets, and +label_weights is 1 for valid frames and 0 for padded frames (if any). See +as_dataset() for more details. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import csv +import os +import random +import subprocess +import sys +import tempfile +import urllib +import zipfile +from absl import app +from absl import flags +from absl import logging +import tensorflow as tf +from mediapipe.util.sequence import media_sequence as ms + + +DATA_URL_ANNOTATIONS = "http://ai2-website.s3.amazonaws.com/data/Charades.zip" +DATA_URL_VIDEOS = "http://ai2-website.s3.amazonaws.com/data/Charades_v1_480.zip" +DATA_URL_LICENSE = "https://allenai.org/plato/charades/license.txt" +CITATION = r"""@article{sigurdsson2016hollywood, +author = {Gunnar A. Sigurdsson and G{\"u}l Varol and Xiaolong Wang and Ivan Laptev and Ali Farhadi and Abhinav Gupta}, +title = {Hollywood in Homes: Crowdsourcing Data Collection for Activity Understanding}, +journal = {ArXiv e-prints}, +eprint = {1604.01753}, +year = {2016}, +url = {http://arxiv.org/abs/1604.01753}, +}""" +SECONDS_TO_MICROSECONDS = 1000000 +GRAPHS = ["clipped_images_from_file_at_24fps.pbtxt"] +SPLITS = { + "train": ("charades_v1_train_records", # base name for sharded files + "Charades_v1_train.csv", # path to csv of annotations + 1000, # number of shards + 7986), # number of examples + "test": ("charades_v1_test_records", + "Charades_v1_test.csv", + 100, + 1864), +} +NUM_CLASSES = 157 +CLASS_LABEL_OFFSET = 1 + + +class Charades(object): + """Generates and loads the Charades data set.""" + + def __init__(self, path_to_data): + if not path_to_data: + raise ValueError("You must supply the path to the data directory.") + self.path_to_data = path_to_data + + def as_dataset(self, split, shuffle=False, repeat=False, + serialized_prefetch_size=32, decoded_prefetch_size=32): + """Returns Charades as a tf.data.Dataset. + + After running this function, calling padded_batch() on the Dataset object + will produce batches of data, but additional preprocessing may be desired. 
+ If using padded_batch, the indicator_matrix output distinguishes valid + from padded frames. + + Args: + split: either "train" or "test" + shuffle: if true, shuffles both files and examples. + repeat: if true, repeats the data set forever. + serialized_prefetch_size: the buffer size for reading from disk. + decoded_prefetch_size: the buffer size after decoding. + Returns: + A tf.data.Dataset object with the following structure: { + "images": uint8 tensor, shape [time, height, width, channels] + "segment_matrix": binary tensor of segments, shape [time, num_segments]. + See one_hot_segments() for details. + "indicator_matrix": binary tensor indicating valid frames, + shape [time, 1]. If padded with zeros to align sizes, the indicator + marks where segments is valid. + "classification_target": binary tensor of classification targets, + shape [time, 158 classes]. More than one value in a row can be 1.0 if + segments overlap. + "example_id": a unique string id for each example, shape []. + "sampling_rate": the frame rate for each sequence, shape []. + "gt_segment_seconds": the start and end time of each segment, + shape [num_segments, 2]. + "gt_segment_classes": the class labels for each segment, + shape [num_segments]. + "num_segments": the number of segments in the example, shape []. + "num_timesteps": the number of timesteps in the example, shape []. + "images": the [time, height, width, channels] tensor of images. 
+ """ + def parse_fn(sequence_example): + """Parses a Charades example.""" + context_features = { + ms.get_example_id_key(): ms.get_example_id_default_parser(), + ms.get_segment_start_index_key(): ( + ms.get_segment_start_index_default_parser()), + ms.get_segment_end_index_key(): ( + ms.get_segment_end_index_default_parser()), + ms.get_segment_label_index_key(): ( + ms.get_segment_label_index_default_parser()), + ms.get_segment_label_string_key(): ( + ms.get_segment_label_string_default_parser()), + ms.get_segment_start_timestamp_key(): ( + ms.get_segment_start_timestamp_default_parser()), + ms.get_segment_end_timestamp_key(): ( + ms.get_segment_end_timestamp_default_parser()), + ms.get_image_frame_rate_key(): ( + ms.get_image_frame_rate_default_parser()), + } + + sequence_features = { + ms.get_image_encoded_key(): ms.get_image_encoded_default_parser() + } + parsed_context, parsed_sequence = tf.io.parse_single_sequence_example( + sequence_example, context_features, sequence_features) + + sequence_length = tf.shape(parsed_sequence[ms.get_image_encoded_key()])[0] + num_segments = tf.shape( + parsed_context[ms.get_segment_label_index_key()])[0] + # segments matrix and targets for training. + segments_matrix, indicator = one_hot_segments( + tf.sparse_tensor_to_dense( + parsed_context[ms.get_segment_start_index_key()]), + tf.sparse_tensor_to_dense( + parsed_context[ms.get_segment_end_index_key()]), + sequence_length) + + classification_target = timepoint_classification_target( + segments_matrix, + tf.sparse_tensor_to_dense( + parsed_context[ms.get_segment_label_index_key()] + ) + CLASS_LABEL_OFFSET, + NUM_CLASSES + CLASS_LABEL_OFFSET) + + # [segments, 2] start and end time in seconds. 
+ gt_segment_seconds = tf.to_float(tf.concat( + [tf.expand_dims(tf.sparse_tensor_to_dense(parsed_context[ + ms.get_segment_start_timestamp_key()]), 1), + tf.expand_dims(tf.sparse_tensor_to_dense(parsed_context[ + ms.get_segment_end_timestamp_key()]), 1)], + 1)) / float(SECONDS_TO_MICROSECONDS) + gt_segment_classes = tf.sparse_tensor_to_dense(parsed_context[ + ms.get_segment_label_index_key()]) + CLASS_LABEL_OFFSET + example_id = parsed_context[ms.get_example_id_key()] + sampling_rate = parsed_context[ms.get_image_frame_rate_key()] + + images = tf.map_fn(tf.image.decode_jpeg, + parsed_sequence[ms.get_image_encoded_key()], + back_prop=False, + dtype=tf.uint8) + + output_dict = { + "segment_matrix": segments_matrix, + "indicator_matrix": indicator, + "classification_target": classification_target, + "example_id": example_id, + "sampling_rate": sampling_rate, + "gt_segment_seconds": gt_segment_seconds, + "gt_segment_classes": gt_segment_classes, + "num_segments": num_segments, + "num_timesteps": sequence_length, + "images": images, + } + return output_dict + + if split not in SPLITS: + raise ValueError("Split %s not in %s" % split, str(SPLITS.keys())) + all_shards = tf.io.gfile.glob( + os.path.join(self.path_to_data, SPLITS[split][0] + "-*-of-*")) + random.shuffle(all_shards) + all_shards_dataset = tf.data.Dataset.from_tensor_slices(all_shards) + cycle_length = min(16, len(all_shards)) + dataset = all_shards_dataset.apply( + tf.contrib.data.parallel_interleave( + tf.data.TFRecordDataset, + cycle_length=cycle_length, + block_length=1, sloppy=True, + buffer_output_elements=serialized_prefetch_size)) + dataset = dataset.prefetch(serialized_prefetch_size) + if shuffle: + dataset = dataset.shuffle(serialized_prefetch_size) + if repeat: + dataset = dataset.repeat() + dataset = dataset.map(parse_fn) + dataset = dataset.prefetch(decoded_prefetch_size) + return dataset + + def generate_examples(self, + path_to_mediapipe_binary, path_to_graph_directory): + """Downloads data and 
generates sharded TFRecords. + + Downloads the data files, generates metadata, and processes the metadata + with MediaPipe to produce tf.SequenceExamples for training. The resulting + files can be read with as_dataset(). After running this function the + original data files can be deleted. + + Args: + path_to_mediapipe_binary: Path to the compiled binary for the BUILD target + mediapipe/examples/desktop/demo:media_sequence_demo. + path_to_graph_directory: Path to the directory with MediaPipe graphs in + mediapipe/graphs/media_sequence/. + """ + if not path_to_mediapipe_binary: + raise ValueError( + "You must supply the path to the MediaPipe binary for " + "mediapipe/examples/desktop/demo:media_sequence_demo.") + if not path_to_graph_directory: + raise ValueError( + "You must supply the path to the directory with MediaPipe graphs in " + "mediapipe/graphs/media_sequence/.") + logging.info("Downloading data.") + annotation_dir, video_dir = self._download_data() + for name, annotations, shards, _ in SPLITS.values(): + annotation_file = os.path.join( + annotation_dir, annotations) + logging.info("Generating metadata for split: %s", name) + all_metadata = list(self._generate_metadata(annotation_file, video_dir)) + random.seed(47) + random.shuffle(all_metadata) + shard_names = [os.path.join(self.path_to_data, name + "-%05d-of-%05d" % ( + i, shards)) for i in range(shards)] + writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names] + with _close_on_exit(writers) as writers: + for i, seq_ex in enumerate(all_metadata): + print("Processing example %d of %d (%d%%) \r" % ( + i, len(all_metadata), i * 100 / len(all_metadata)), end="") + for graph in GRAPHS: + graph_path = os.path.join(path_to_graph_directory, graph) + seq_ex = self._run_mediapipe( + path_to_mediapipe_binary, seq_ex, graph_path) + writers[i % len(writers)].write(seq_ex.SerializeToString()) + logging.info("Data extraction complete.") + + def _generate_metadata(self, annotations_file, video_dir): 
+ """For each row in the annotation CSV, generates the corresponding metadata. + + Args: + annotations_file: path to the file of Charades CSV annotations. + video_dir: path to the directory of video files referenced by the + annotations. + Yields: + Each tf.SequenceExample of metadata, ready to pass to MediaPipe. + """ + with open(annotations_file, "r") as annotations: + reader = csv.DictReader(annotations) + for row in reader: + metadata = tf.train.SequenceExample() + filepath = os.path.join(video_dir, "%s.mp4" % row["id"]) + actions = row["actions"].split(";") + action_indices = [] + action_strings = [] + action_start_times = [] + action_end_times = [] + for action in actions: + if not action: + continue + string, start, end = action.split(" ") + action_indices.append(int(string[1:])) + action_strings.append(bytes23(string)) + action_start_times.append(int(float(start) * SECONDS_TO_MICROSECONDS)) + action_end_times.append(int(float(end) * SECONDS_TO_MICROSECONDS)) + ms.set_example_id(bytes23(row["id"]), metadata) + ms.set_clip_data_path(bytes23(filepath), metadata) + ms.set_clip_start_timestamp(0, metadata) + ms.set_clip_end_timestamp( + int(float(row["length"]) * SECONDS_TO_MICROSECONDS), metadata) + ms.set_segment_start_timestamp(action_start_times, metadata) + ms.set_segment_end_timestamp(action_end_times, metadata) + ms.set_segment_label_string(action_strings, metadata) + ms.set_segment_label_index(action_indices, metadata) + yield metadata + + def _download_data(self): + """Downloads and extracts data if not already available.""" + if sys.version_info >= (3, 0): + urlretrieve = urllib.request.urlretrieve + else: + urlretrieve = urllib.urlretrieve + logging.info("Creating data directory.") + tf.io.gfile.makedirs(self.path_to_data) + logging.info("Downloading license.") + local_license_path = os.path.join( + self.path_to_data, DATA_URL_LICENSE.split("/")[-1]) + if not tf.io.gfile.exists(local_license_path): + urlretrieve(DATA_URL_LICENSE, local_license_path) + 
logging.info("Downloading annotations.") + local_annotations_path = os.path.join( + self.path_to_data, DATA_URL_ANNOTATIONS.split("/")[-1]) + if not tf.io.gfile.exists(local_annotations_path): + urlretrieve(DATA_URL_ANNOTATIONS, local_annotations_path) + logging.info("Downloading videos.") + local_videos_path = os.path.join( + self.path_to_data, DATA_URL_VIDEOS.split("/")[-1]) + if not tf.io.gfile.exists(local_videos_path): + urlretrieve(DATA_URL_VIDEOS, local_videos_path, progress_hook) + logging.info("Extracting annotations.") + # return video dir and annotation_dir by removing .zip from the path. + annotations_dir = local_annotations_path[:-4] + if not tf.io.gfile.exists(annotations_dir): + with zipfile.ZipFile(local_annotations_path) as annotations_zip: + annotations_zip.extractall(self.path_to_data) + logging.info("Extracting videos.") + video_dir = local_videos_path[:-4] + if not tf.io.gfile.exists(video_dir): + with zipfile.ZipFile(local_videos_path) as videos_zip: + videos_zip.extractall(self.path_to_data) + return annotations_dir, video_dir + + def _run_mediapipe(self, path_to_mediapipe_binary, sequence_example, graph): + """Runs MediaPipe over MediaSequence tf.train.SequenceExamples. + + Args: + path_to_mediapipe_binary: Path to the compiled binary for the BUILD target + mediapipe/examples/desktop/demo:media_sequence_demo. + sequence_example: The SequenceExample with metadata or partial data file. + graph: The path to the graph that extracts data to add to the + SequenceExample. + Returns: + A copy of the input SequenceExample with additional data fields added + by the MediaPipe graph. + Raises: + RuntimeError: if MediaPipe returns an error or fails to run the graph. 
+ """ + if not path_to_mediapipe_binary: + raise ValueError("--path_to_mediapipe_binary must be specified.") + input_fd, input_filename = tempfile.mkstemp() + output_fd, output_filename = tempfile.mkstemp() + cmd = [path_to_mediapipe_binary, + "--calculator_graph_config_file=%s" % graph, + "--input_side_packets=input_sequence_example=%s" % input_filename, + "--output_side_packets=output_sequence_example=%s" % output_filename] + with open(input_filename, "wb") as input_file: + input_file.write(sequence_example.SerializeToString()) + mediapipe_output = subprocess.check_output(cmd) + if b"Failed to run the graph" in mediapipe_output: + raise RuntimeError(mediapipe_output) + with open(output_filename, "rb") as output_file: + output_example = tf.train.SequenceExample() + output_example.ParseFromString(output_file.read()) + os.close(input_fd) + os.remove(input_filename) + os.close(output_fd) + os.remove(output_filename) + return output_example + + +def one_hot_segments(start_indices, end_indices, num_samples): + """Returns a one-hot, float matrix of segments at each timestep. + + All integers in the inclusive range of start_indices and end_indices are used. + This allows start and end timestamps to be mapped to the same index and the + segment will not be omitted. + + Args: + start_indices: a 1d tensor of integer indices for the start of each + segement. + end_indices: a tensor of integer indices for the end of each segment. + Must be the same shape as start_indices. Values should be >= start_indices + but not strictly enforced. + num_samples: the number of rows in the output. Indices should be < + num_samples, but this is not strictly enforced. + Returns: + (segments, indicator) + segments: A [num_samples, num_elements(start_indices)] tensor where in each + column the rows with indices >= start_indices[column] and + <= end_indices[column] are 1.0 and all other values are 0.0. + indicator: a tensor of 1.0 values with shape [num_samples, 1]. 
If padded + with zeros to align sizes, the indicator marks where segments is valid. + """ + start_indices = tf.convert_to_tensor(start_indices) + end_indices = tf.convert_to_tensor(end_indices) + start_indices.shape.assert_is_compatible_with(end_indices.shape) + start_indices.shape.assert_has_rank(1) + end_indices.shape.assert_has_rank(1) + # create a matrix of the index at each row with a column per segment. + indices = tf.to_int64( + tf.tile( + tf.transpose(tf.expand_dims(tf.range(num_samples), 0)), + [1, tf.shape(start_indices)[0]])) + # switch to one hot encoding of segments (includes start and end indices) + segments = tf.to_float( + tf.logical_and( + tf.greater_equal(indices, start_indices), + tf.less_equal(indices, end_indices))) + # create a tensors of ones everywhere there's an annotation. If padded with + # zeros later, element-wise multiplication of the loss will mask out the + # padding. + indicator = tf.ones(shape=[num_samples, 1], dtype=tf.float32) + return segments, indicator + + +def timepoint_classification_target(segments, segment_classes, num_classes): + """Produces a classification target at each timepoint. + + If no segments are present at a time point, the first class is set to 1.0. + This should be used as a background class unless segments are always present. + + Args: + segments: a [time, num_segments] tensor that is 1.0 at indices within + each segment and 0.0 elsewhere. + segment_classes: a [num_segments] tensor with the class index of each + segment. + num_classes: the number of classes (must be >= max(segment_classes) + 1) + Returns: + a [time, num_classes] tensor. In the final output, more than one + value in a row can be 1.0 if segments overlap. + """ + num_segments = tf.shape(segments)[1] + matrix_of_class_indices = tf.to_int32( + segments * tf.to_float(tf.expand_dims(segment_classes, 0))) + # First column will have one count per zero segment. Correct this to be 0 + # unless no segments are present. 
+ one_hot = tf.reduce_sum(tf.one_hot(matrix_of_class_indices, num_classes), 1) + normalizer = tf.concat([ + tf.ones(shape=[1, 1], dtype=tf.float32) / tf.to_float(num_segments), + tf.ones(shape=[1, num_classes - 1], dtype=tf.float32) + ], 1) + corrected_one_hot = tf.floor(one_hot * normalizer) + return corrected_one_hot + + +def progress_hook(blocks, block_size, total_size): + print("Downloaded %d%% of %d bytes (%d blocks)\r" % ( + blocks * block_size / total_size * 100, total_size, blocks), end="") + + +def bytes23(string): + """Creates a bytes string in either Python 2 or 3.""" + if sys.version_info >= (3, 0): + return bytes(string, "utf8") + else: + return bytes(string) + + +@contextlib.contextmanager +def _close_on_exit(writers): + """Call close on all writers on exit.""" + try: + yield writers + finally: + for writer in writers: + writer.close() + + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + Charades(flags.FLAGS.path_to_charades_data).generate_examples( + flags.FLAGS.path_to_mediapipe_binary, + flags.FLAGS.path_to_graph_directory) + +if __name__ == "__main__": + flags.DEFINE_string("path_to_charades_data", + "", + "Path to directory to write data to.") + flags.DEFINE_string("path_to_mediapipe_binary", + "", + "Path to the MediaPipe run_graph_file_io_main binary.") + flags.DEFINE_string("path_to_graph_directory", + "", + "Path to directory containing the graph files.") + app.run(main) diff --git a/mediapipe/examples/desktop/media_sequence/demo_dataset.py b/mediapipe/examples/desktop/media_sequence/demo_dataset.py new file mode 100644 index 000000000..627149be3 --- /dev/null +++ b/mediapipe/examples/desktop/media_sequence/demo_dataset.py @@ -0,0 +1,317 @@ +r"""Copyright 2019 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +A demo data set constructed with MediaSequence and MediaPipe. + +This code demonstrates the steps for constructing a data set with MediaSequence. +This code has two functions. First, it can be run as a module to download and +prepare a toy dataset. Second, it can be imported and used to provide a +tf.data.Dataset reading that data from disk via as_dataset(). + +Running as a module prepares the data in three stages via generate_examples(). +First, the actual data files are downloaded. If the download is disrupted, the +incomplete files will need to be removed before running the script again. +Second, the annotations are parsed and reformated into metadata as described in +the MediaSequence documentation. Third, MediaPipe is run to extract subsequences +of frames for subsequent training via _run_mediapipe(). + +The toy data set is classifying a clip as a panning shot of galaxy or nebula +from videos releasued under the [Creative Commons Attribution 4.0 International +license](http://creativecommons.org/licenses/by/4.0/) on the ESA/Hubble site. +(The use of these ESA/Hubble materials does not imply the endorsement by +ESA/Hubble or any ESA/Hubble employee of a commercial product or service.) Each +video is split into 5 or 6 ten-second clips with a label of "galaxy" or "nebula" +and downsampled to 10 frames per second. (The last clip for each test example is +only 6 seconds.) There is one video of each class in each of the training and +testing splits. 
+ +Reading the data as a tf.data.Dataset is accomplished with the following lines: + + demo = DemoDataset("demo_data_path") + dataset = demo.as_dataset("test") + # implement additional processing and batching here + images_and_labels = dataset.make_one_shot_iterator().get_next() + images = images_and_labels["images"] + labels = image_and_labels["labels"] +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import csv +import os +import random +import subprocess +import sys +import tempfile +import urllib + +from absl import app +from absl import flags +from absl import logging +import tensorflow as tf + +from mediapipe.util.sequence import media_sequence as ms + +SPLITS = { + "train": + """url,label index,label string,duration,credits +https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1608c.mp4,0,nebula,50,"ESA/Hubble; Music: Johan B. Monell" +https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1712b.mp4,1,galaxy,50,"ESA/Hubble, Digitized Sky Survey, Nick Risinger (skysurvey.org) Music: Johan B Monell" +""", + "test": + """url,label index,label string,duration,credits +https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1301b.m4v,0,nebula,56,"NASA, ESA. Acknowledgement: Josh Lake" +https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1305b.m4v,1,galaxy,56,"NASA, ESA, Digitized Sky Survey 2. Acknowledgement: A. 
van der Hoeven" +""" +} +NUM_CLASSES = 2 +NUM_SHARDS = 2 +SECONDS_PER_EXAMPLE = 10 +MICROSECONDS_PER_SECOND = 1000000 +TF_RECORD_PATTERN = "demo_space_dataset_%s_tfrecord" +GRAPHS = ["clipped_images_from_file_at_24fps.pbtxt"] + + +class DemoDataset(object): + """Generates and loads a demo data set.""" + + def __init__(self, path_to_data): + if not path_to_data: + raise ValueError("You must supply the path to the data directory.") + self.path_to_data = path_to_data + + def as_dataset(self, + split, + shuffle=False, + repeat=False, + serialized_prefetch_size=32, + decoded_prefetch_size=32): + """Returns the dataset as a tf.data.Dataset. + + Args: + split: either "train" or "test" + shuffle: if true, shuffles both files and examples. + repeat: if true, repeats the data set forever. + serialized_prefetch_size: the buffer size for reading from disk. + decoded_prefetch_size: the buffer size after decoding. + + Returns: + A tf.data.Dataset object with the following structure: { + "images": uint8 tensor, shape [time, height, width, channels] + "labels": one hot encoded label tensor, shape [2] + "id": a unique string id for each example, shape [] + } + """ + + def parse_fn(sequence_example): + """Parses a clip classification example.""" + context_features = { + ms.get_example_id_key(): + ms.get_example_id_default_parser(), + ms.get_clip_label_index_key(): + ms.get_clip_label_index_default_parser(), + ms.get_clip_label_string_key(): + ms.get_clip_label_string_default_parser() + } + sequence_features = { + ms.get_image_encoded_key(): ms.get_image_encoded_default_parser(), + } + parsed_context, parsed_sequence = tf.io.parse_single_sequence_example( + sequence_example, context_features, sequence_features) + example_id = parsed_context[ms.get_example_id_key()] + classification_target = tf.one_hot( + tf.sparse_tensor_to_dense( + parsed_context[ms.get_clip_label_index_key()]), NUM_CLASSES) + images = tf.map_fn( + tf.image.decode_jpeg, + parsed_sequence[ms.get_image_encoded_key()], 
+ back_prop=False, + dtype=tf.uint8) + return { + "id": example_id, + "labels": classification_target, + "images": images, + } + + if split not in SPLITS: + raise ValueError("split '%s' is unknown." % split) + all_shards = tf.io.gfile.glob( + os.path.join(self.path_to_data, TF_RECORD_PATTERN % split + "-*-of-*")) + if shuffle: + random.shuffle(all_shards) + all_shards_dataset = tf.data.Dataset.from_tensor_slices(all_shards) + cycle_length = min(16, len(all_shards)) + dataset = all_shards_dataset.apply( + tf.contrib.data.parallel_interleave( + tf.data.TFRecordDataset, + cycle_length=cycle_length, + block_length=1, + sloppy=True, + buffer_output_elements=serialized_prefetch_size)) + dataset = dataset.prefetch(serialized_prefetch_size) + if shuffle: + dataset = dataset.shuffle(serialized_prefetch_size) + if repeat: + dataset = dataset.repeat() + dataset = dataset.map(parse_fn) + dataset = dataset.prefetch(decoded_prefetch_size) + return dataset + + def generate_examples(self, path_to_mediapipe_binary, + path_to_graph_directory): + """Downloads data and generates sharded TFRecords. + + Downloads the data files, generates metadata, and processes the metadata + with MediaPipe to produce tf.SequenceExamples for training. The resulting + files can be read with as_dataset(). After running this function the + original data files can be deleted. + + Args: + path_to_mediapipe_binary: Path to the compiled binary for the BUILD target + mediapipe/examples/desktop/demo:media_sequence_demo. + path_to_graph_directory: Path to the directory with MediaPipe graphs in + mediapipe/graphs/media_sequence/. 
+ """ + if not path_to_mediapipe_binary: + raise ValueError("You must supply the path to the MediaPipe binary for " + "mediapipe/examples/desktop/demo:media_sequence_demo.") + if not path_to_graph_directory: + raise ValueError( + "You must supply the path to the directory with MediaPipe graphs in " + "mediapipe/graphs/media_sequence/.") + logging.info("Downloading data.") + tf.io.gfile.makedirs(self.path_to_data) + if sys.version_info >= (3, 0): + urlretrieve = urllib.request.urlretrieve + else: + urlretrieve = urllib.urlretrieve + for split in SPLITS: + reader = csv.DictReader(SPLITS[split].split("\n")) + all_metadata = [] + for row in reader: + url = row["url"] + basename = url.split("/")[-1] + local_path = os.path.join(self.path_to_data, basename) + if not tf.io.gfile.exists(local_path): + urlretrieve(url, local_path) + + for start_time in range(0, int(row["duration"]), SECONDS_PER_EXAMPLE): + metadata = tf.train.SequenceExample() + ms.set_example_id(bytes23(basename + "_" + str(start_time)), + metadata) + ms.set_clip_data_path(bytes23(local_path), metadata) + ms.set_clip_start_timestamp(start_time * MICROSECONDS_PER_SECOND, + metadata) + ms.set_clip_end_timestamp( + (start_time + SECONDS_PER_EXAMPLE) * MICROSECONDS_PER_SECOND, + metadata) + ms.set_clip_label_index((int(row["label index"]),), metadata) + ms.set_clip_label_string((bytes23(row["label string"]),), + metadata) + all_metadata.append(metadata) + random.seed(47) + random.shuffle(all_metadata) + shard_names = [self._indexed_shard(split, i) for i in range(NUM_SHARDS)] + writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names] + with _close_on_exit(writers) as writers: + for i, seq_ex in enumerate(all_metadata): + for graph in GRAPHS: + graph_path = os.path.join(path_to_graph_directory, graph) + seq_ex = self._run_mediapipe(path_to_mediapipe_binary, seq_ex, + graph_path) + writers[i % len(writers)].write(seq_ex.SerializeToString()) + + def _indexed_shard(self, split, index): + 
"""Constructs a sharded filename.""" + return os.path.join( + self.path_to_data, + TF_RECORD_PATTERN % split + "-%05d-of-%05d" % (index, NUM_SHARDS)) + + def _run_mediapipe(self, path_to_mediapipe_binary, sequence_example, graph): + """Runs MediaPipe over MediaSequence tf.train.SequenceExamples. + + Args: + path_to_mediapipe_binary: Path to the compiled binary for the BUILD target + mediapipe/examples/desktop/demo:media_sequence_demo. + sequence_example: The SequenceExample with metadata or partial data file. + graph: The path to the graph that extracts data to add to the + SequenceExample. + + Returns: + A copy of the input SequenceExample with additional data fields added + by the MediaPipe graph. + Raises: + RuntimeError: if MediaPipe returns an error or fails to run the graph. + """ + if not path_to_mediapipe_binary: + raise ValueError("--path_to_mediapipe_binary must be specified.") + input_fd, input_filename = tempfile.mkstemp() + output_fd, output_filename = tempfile.mkstemp() + cmd = [ + path_to_mediapipe_binary, + "--calculator_graph_config_file=%s" % graph, + "--input_side_packets=input_sequence_example=%s" % input_filename, + "--output_side_packets=output_sequence_example=%s" % output_filename + ] + with open(input_filename, "wb") as input_file: + input_file.write(sequence_example.SerializeToString()) + mediapipe_output = subprocess.check_output(cmd) + if b"Failed to run the graph" in mediapipe_output: + raise RuntimeError(mediapipe_output) + with open(output_filename, "rb") as output_file: + output_example = tf.train.SequenceExample() + output_example.ParseFromString(output_file.read()) + os.close(input_fd) + os.remove(input_filename) + os.close(output_fd) + os.remove(output_filename) + return output_example + + +def bytes23(string): + """Creates a bytes string in either Python 2 or 3.""" + if sys.version_info >= (3, 0): + return bytes(string, "utf8") + else: + return bytes(string) + + +@contextlib.contextmanager +def _close_on_exit(writers): + """Call 
close on all writers on exit.""" + try: + yield writers + finally: + for writer in writers: + writer.close() + + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + DemoDataset(flags.FLAGS.path_to_demo_data).generate_examples( + flags.FLAGS.path_to_mediapipe_binary, flags.FLAGS.path_to_graph_directory) + + +if __name__ == "__main__": + flags.DEFINE_string("path_to_demo_data", "", + "Path to directory to write data to.") + flags.DEFINE_string("path_to_mediapipe_binary", "", + "Path to the MediaPipe run_graph_file_io_main binary.") + flags.DEFINE_string("path_to_graph_directory", "", + "Path to directory containing the graph files.") + app.run(main) diff --git a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc new file mode 100644 index 000000000..9a8e44dab --- /dev/null +++ b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc @@ -0,0 +1,93 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// A simple main function to run a MediaPipe graph. Input side packets are read +// from files provided via the command line and output side packets are written +// to disk. 
+ +#include "absl/strings/str_split.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/commandlineflags.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/map_util.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" + +DEFINE_string( + calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); +DEFINE_string(input_side_packets, "", + "Comma-separated list of key=value pairs specifying side packets " + "and corresponding file paths for the CalculatorGraph. The side " + "packets are read from the files and fed to the graph as strings " + "even if they represent doubles, floats, etc."); +DEFINE_string(output_side_packets, "", + "Comma-separated list of key=value pairs specifying the output " + "side packets and paths to write to disk for the " + "CalculatorGraph."); + +::mediapipe::Status RunMediaPipeGraph() { + std::string calculator_graph_config_contents; + RETURN_IF_ERROR(mediapipe::file::GetContents( + FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; + mediapipe::CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + calculator_graph_config_contents); + std::map input_side_packets; + std::vector kv_pairs = + absl::StrSplit(FLAGS_input_side_packets, ','); + for (const std::string& kv_pair : kv_pairs) { + std::vector name_and_value = absl::StrSplit(kv_pair, '='); + RET_CHECK(name_and_value.size() == 2); + RET_CHECK(!::mediapipe::ContainsKey(input_side_packets, name_and_value[0])); + std::string input_side_packet_contents; + RETURN_IF_ERROR(mediapipe::file::GetContents(name_and_value[1], + &input_side_packet_contents)); + input_side_packets[name_and_value[0]] = + ::mediapipe::MakePacket(input_side_packet_contents); + } + LOG(INFO) << "Initialize the calculator 
graph."; + mediapipe::CalculatorGraph graph; + RETURN_IF_ERROR(graph.Initialize(config, input_side_packets)); + LOG(INFO) << "Start running the calculator graph."; + RETURN_IF_ERROR(graph.Run()); + LOG(INFO) << "Gathering output side packets."; + kv_pairs = absl::StrSplit(FLAGS_output_side_packets, ','); + for (const std::string& kv_pair : kv_pairs) { + std::vector name_and_value = absl::StrSplit(kv_pair, '='); + RET_CHECK(name_and_value.size() == 2); + ::mediapipe::StatusOr<::mediapipe::Packet> output_packet = + graph.GetOutputSidePacket(name_and_value[0]); + RET_CHECK(output_packet.ok()) + << "Packet " << name_and_value[0] << " was not available."; + const std::string& serialized_string = + output_packet.ValueOrDie().Get(); + RETURN_IF_ERROR( + mediapipe::file::SetContents(name_and_value[1], serialized_string)); + } + return ::mediapipe::OkStatus(); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + ::mediapipe::Status run_status = RunMediaPipeGraph(); + if (!run_status.ok()) { + LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + } else { + LOG(INFO) << "Success!"; + } + return 0; +} diff --git a/mediapipe/examples/desktop/object_detection/BUILD b/mediapipe/examples/desktop/object_detection/BUILD new file mode 100644 index 000000000..eda9e7023 --- /dev/null +++ b/mediapipe/examples/desktop/object_detection/BUILD @@ -0,0 +1,70 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +cc_library( + name = "object_detection_tensorflow_deps", + deps = [ + "@org_tensorflow//tensorflow/c/kernels:bitcast_op", + "@org_tensorflow//tensorflow/core:direct_session", + "@org_tensorflow//tensorflow/core/kernels:argmax_op", + "@org_tensorflow//tensorflow/core/kernels:bias_op", + "@org_tensorflow//tensorflow/core/kernels:cast_op", + "@org_tensorflow//tensorflow/core/kernels:concat_op", + "@org_tensorflow//tensorflow/core/kernels:constant_op", + "@org_tensorflow//tensorflow/core/kernels:control_flow_ops", + "@org_tensorflow//tensorflow/core/kernels:conv_ops", + "@org_tensorflow//tensorflow/core/kernels:cwise_op", + "@org_tensorflow//tensorflow/core/kernels:depthwise_conv_op", + "@org_tensorflow//tensorflow/core/kernels:fused_batch_norm_op", + "@org_tensorflow//tensorflow/core/kernels:gather_op", + "@org_tensorflow//tensorflow/core/kernels:identity_op", + "@org_tensorflow//tensorflow/core/kernels:matmul_op", + "@org_tensorflow//tensorflow/core/kernels:non_max_suppression_op", + "@org_tensorflow//tensorflow/core/kernels:pack_op", + "@org_tensorflow//tensorflow/core/kernels:reduction_ops", + "@org_tensorflow//tensorflow/core/kernels:relu_op", + "@org_tensorflow//tensorflow/core/kernels:reshape_op", + "@org_tensorflow//tensorflow/core/kernels:resize_bilinear_op", + "@org_tensorflow//tensorflow/core/kernels:sequence_ops", + "@org_tensorflow//tensorflow/core/kernels:shape_ops", + "@org_tensorflow//tensorflow/core/kernels:slice_op", + "@org_tensorflow//tensorflow/core/kernels:split_op", + "@org_tensorflow//tensorflow/core/kernels:tensor_array_ops", + "@org_tensorflow//tensorflow/core/kernels:tile_ops", + "@org_tensorflow//tensorflow/core/kernels:topk_op", + "@org_tensorflow//tensorflow/core/kernels:transpose_op", + 
"@org_tensorflow//tensorflow/core/kernels:unpack_op", + ], +) + +cc_binary( + name = "object_detection_tensorflow", + deps = [ + ":object_detection_tensorflow_deps", + "//mediapipe/examples/desktop:simple_run_graph_main", + "//mediapipe/graphs/object_detection:desktop_tensorflow_calculators", + ], +) + +cc_binary( + name = "object_detection_tflite", + deps = [ + "//mediapipe/examples/desktop:simple_run_graph_main", + "//mediapipe/graphs/object_detection:desktop_tflite_calculators", + ], +) diff --git a/mediapipe/examples/desktop/object_detection/test_video.mp4 b/mediapipe/examples/desktop/object_detection/test_video.mp4 new file mode 100644 index 000000000..c706232d3 Binary files /dev/null and b/mediapipe/examples/desktop/object_detection/test_video.mp4 differ diff --git a/mediapipe/examples/desktop/simple_run_graph_main.cc b/mediapipe/examples/desktop/simple_run_graph_main.cc new file mode 100644 index 000000000..5b6eb2876 --- /dev/null +++ b/mediapipe/examples/desktop/simple_run_graph_main.cc @@ -0,0 +1,69 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// A simple main function to run a MediaPipe graph. 
+ +#include "absl/strings/str_split.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/commandlineflags.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/map_util.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" + +DEFINE_string( + calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); + +DEFINE_string(input_side_packets, "", + "Comma-separated list of key=value pairs specifying side packets " + "for the CalculatorGraph. All values will be treated as the " + "string type even if they represent doubles, floats, etc."); + +::mediapipe::Status RunMediaPipeGraph() { + std::string calculator_graph_config_contents; + RETURN_IF_ERROR(mediapipe::file::GetContents( + FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; + mediapipe::CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + calculator_graph_config_contents); + std::map input_side_packets; + std::vector kv_pairs = + absl::StrSplit(FLAGS_input_side_packets, ','); + for (const std::string& kv_pair : kv_pairs) { + std::vector name_and_value = absl::StrSplit(kv_pair, '='); + RET_CHECK(name_and_value.size() == 2); + RET_CHECK(!::mediapipe::ContainsKey(input_side_packets, name_and_value[0])); + input_side_packets[name_and_value[0]] = + ::mediapipe::MakePacket(name_and_value[1]); + } + LOG(INFO) << "Initialize the calculator graph."; + mediapipe::CalculatorGraph graph; + RETURN_IF_ERROR(graph.Initialize(config, input_side_packets)); + LOG(INFO) << "Start running the calculator graph."; + return graph.Run(); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + ::mediapipe::Status run_status = RunMediaPipeGraph(); + if (!run_status.ok()) { + LOG(ERROR) << "Failed to run 
the graph: " << run_status.message(); + } else { + LOG(INFO) << "Success!"; + } + return 0; +} diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD new file mode 100644 index 000000000..184687e61 --- /dev/null +++ b/mediapipe/framework/BUILD @@ -0,0 +1,1638 @@ +# +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_py_proto_library") + +package_group( + name = "mediapipe_internal", + packages = [ + "//java/com/google/mediapipe/framework/...", + "//mediapipe/...", + ], +) + +exports_files([ + "transitive_protos.bzl", + "encode_binary_proto.bzl", +]) + +proto_library( + name = "calculator_proto", + srcs = ["calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:mediapipe_options_proto", + "//mediapipe/framework:packet_factory_proto", + "//mediapipe/framework:packet_generator_proto", + "//mediapipe/framework:status_handler_proto", + "//mediapipe/framework:stream_handler_proto", + "@protobuf_archive//:any_proto", + ], +) + +proto_library( + name = "calculator_options_proto", + srcs = ["calculator_options.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], +) + +proto_library( 
+ name = "calculator_contract_test_proto", + testonly = 1, + srcs = ["calculator_contract_test.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "calculator_profile_proto", + srcs = ["calculator_profile.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "mediapipe_options_proto", + srcs = ["mediapipe_options.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], +) + +proto_library( + name = "packet_factory_proto", + srcs = ["packet_factory.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], +) + +proto_library( + name = "packet_generator_proto", + srcs = ["packet_generator.proto"], + visibility = [ + "//mediapipe:__subpackages__", + "//mediapipe/packet_generator:__pkg__", + ], +) + +proto_library( + name = "packet_test_proto", + testonly = 1, + srcs = ["packet_test.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], +) + +proto_library( + name = "status_handler_proto", + srcs = ["status_handler.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = ["//mediapipe/framework:mediapipe_options_proto"], +) + +proto_library( + name = "stream_handler_proto", + srcs = ["stream_handler.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = ["//mediapipe/framework:mediapipe_options_proto"], +) + +proto_library( + name = "test_calculators_proto", + testonly = 1, + srcs = ["test_calculators.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "thread_pool_executor_proto", + srcs = ["thread_pool_executor.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = ["//mediapipe/framework:mediapipe_options_proto"], +) + +mediapipe_cc_proto_library( + name = "calculator_cc_proto", + srcs = 
["calculator.proto"], + cc_deps = [ + ":calculator_options_cc_proto", + ":mediapipe_options_cc_proto", + ":packet_factory_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", + ":stream_handler_cc_proto", + "@protobuf_archive//:cc_wkt_protos", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "calculator_contract_test_cc_proto", + testonly = 1, + srcs = ["calculator_contract_test.proto"], + cc_deps = [":calculator_cc_proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":calculator_contract_test_proto"], +) + +mediapipe_cc_proto_library( + name = "calculator_options_cc_proto", + srcs = ["calculator_options.proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":calculator_options_proto"], +) + +mediapipe_cc_proto_library( + name = "calculator_profile_cc_proto", + srcs = ["calculator_profile.proto"], + cc_deps = [":calculator_cc_proto"], + visibility = [ + "//mediapipe/framework:__subpackages__", + "//mediapipe/java/com/google/mediapipe/framework:__subpackages__", + ], + deps = [":calculator_profile_proto"], +) + +mediapipe_cc_proto_library( + name = "mediapipe_options_cc_proto", + srcs = [":mediapipe_options.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":mediapipe_options_proto"], +) + +mediapipe_cc_proto_library( + name = "packet_factory_cc_proto", + srcs = ["packet_factory.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":packet_factory_proto"], +) + +mediapipe_cc_proto_library( + name = "packet_generator_cc_proto", + srcs = ["packet_generator.proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":packet_generator_proto"], +) + +mediapipe_cc_proto_library( + name = "packet_test_cc_proto", + testonly = 1, + srcs = ["packet_test.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":packet_test_proto"], +) + 
+mediapipe_cc_proto_library( + name = "status_handler_cc_proto", + srcs = ["status_handler.proto"], + cc_deps = [":mediapipe_options_cc_proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":status_handler_proto"], +) + +mediapipe_cc_proto_library( + name = "stream_handler_cc_proto", + srcs = ["stream_handler.proto"], + cc_deps = [":mediapipe_options_cc_proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":stream_handler_proto"], +) + +mediapipe_cc_proto_library( + name = "test_calculators_cc_proto", + testonly = 1, + srcs = ["test_calculators.proto"], + cc_deps = [":calculator_cc_proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":test_calculators_proto"], +) + +mediapipe_cc_proto_library( + name = "thread_pool_executor_cc_proto", + srcs = ["thread_pool_executor.proto"], + cc_deps = [":mediapipe_options_cc_proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [":thread_pool_executor_proto"], +) + +mediapipe_py_proto_library( + name = "mediapipe_options_py_pb2", + srcs = ["mediapipe_options.proto"], + api_version = 2, + proto_deps = [":mediapipe_options_proto"], + visibility = ["//mediapipe:__subpackages__"], +) + +mediapipe_py_proto_library( + name = "stream_handler_py_pb2", + srcs = ["stream_handler.proto"], + api_version = 2, + proto_deps = [":stream_handler_proto"], + py_proto_deps = [":mediapipe_options_py_pb2"], + visibility = ["//mediapipe:__subpackages__"], +) + +mediapipe_py_proto_library( + name = "packet_generator_py_pb2", + srcs = ["packet_generator.proto"], + api_version = 2, + proto_deps = [":packet_generator_proto"], + py_proto_deps = [":mediapipe_options_py_pb2"], + visibility = ["//mediapipe:__subpackages__"], +) + +mediapipe_py_proto_library( + name = "calculator_py_pb2", + srcs = [ + "calculator.proto", + "calculator_options.proto", + "packet_factory.proto", + "status_handler.proto", + ], + api_version = 2, + proto_deps = [":calculator_proto"], + 
py_proto_deps = [ + ":mediapipe_options_py_pb2", + ":packet_generator_py_pb2", + ":stream_handler_py_pb2", + ], + visibility = ["//mediapipe:__subpackages__"], +) + +java_lite_proto_library( + name = "calculator_java_proto_lite", + strict_deps = 0, + visibility = [":mediapipe_internal"], + deps = [":calculator_proto"], +) + +java_lite_proto_library( + name = "calculator_profile_java_proto_lite", + visibility = [":mediapipe_internal"], + deps = [":calculator_profile_proto"], +) + +cc_library( + name = "calculator_base", + srcs = ["calculator_base.cc"], + hdrs = ["calculator_base.h"], + visibility = [ + ":mediapipe_internal", + ], + deps = [ + ":calculator_context", + ":calculator_contract", + ":port", + ":timestamp", + "//mediapipe/framework/deps:registration", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "calculator_context", + srcs = ["calculator_context.cc"], + hdrs = ["calculator_context.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":calculator_state", + ":counter", + ":graph_service", + ":input_stream_shard", + ":output_stream_shard", + ":packet", + ":packet_set", + ":port", + ":timestamp", + "//mediapipe/framework/port:any_proto", + "//mediapipe/framework/port:status", + ], +) + +cc_library( + name = "calculator_context_manager", + srcs = ["calculator_context_manager.cc"], + hdrs = ["calculator_context_manager.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":calculator_context", + ":calculator_state", + ":timestamp", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "calculator_contract", + srcs = ["calculator_contract.cc"], + hdrs = ["calculator_contract.h"], + visibility = [ + ":mediapipe_internal", + ], + deps = 
[ + ":graph_service", + ":packet_type", + ":port", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework/port:any_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:options_util", + "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "legacy_calculator_support", + srcs = ["legacy_calculator_support.cc"], + hdrs = ["legacy_calculator_support.h"], + visibility = [ + ":mediapipe_internal", + ], + deps = [ + ":calculator_context", + ":calculator_contract", + ], +) + +cc_library( + name = "calculator_framework", + hdrs = ["calculator_framework.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":calculator_base", + ":calculator_graph", + ":calculator_registry", + ":counter_factory", + ":input_stream", + ":output_side_packet", + ":output_stream", + ":packet", + ":packet_generator", + ":packet_generator_graph", + ":packet_set", + ":packet_type", + ":port", + ":status_handler", + ":subgraph", + ":timestamp", + ":validated_graph_config", + "//mediapipe/framework/tool:sink", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:validate", + "//mediapipe/framework/tool:validate_name", + ], +) + +cc_library( + name = "calculator_graph", + srcs = [ + "calculator_graph.cc", + "scheduler.cc", + ], + hdrs = [ + "calculator_graph.h", + "scheduler.h", + ], + defines = select({ + "//conditions:default": [], + "//mediapipe/gpu:disable_gpu": [ + "MEDIAPIPE_DISABLE_GPU", + ], + }), + visibility = [ + ":mediapipe_internal", + ], + deps = [ + ":calculator_base", + ":counter_factory", + ":delegating_executor", + ":mediapipe_profiling", + ":executor", + ":graph_output_stream", + ":input_stream_manager", + ":input_stream_shard", + ":graph_service", + ":output_stream", + ":output_stream_manager", + 
":output_stream_poller", + ":output_stream_shard", + ":packet", + ":packet_generator", + ":packet_generator_graph", + ":packet_set", + ":packet_type", + ":port", + ":scheduler_queue", + ":status_handler", + ":thread_pool_executor", + ":timestamp", + ":validated_graph_config", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework:packet_factory_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", + "//mediapipe/gpu:graph_support", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "//mediapipe/framework:calculator_node", + "//mediapipe/framework:output_side_packet_impl", + "//mediapipe/framework/profiler:graph_profiler", + "//mediapipe/framework/tool:fill_packet_set", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:tag_map", + "//mediapipe/framework/tool:validate", + "//mediapipe/framework/tool:validate_name", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/util:cpu_util", + ] + select({ + "//conditions:default": [ + "//mediapipe/gpu:gpu_shared_data_internal", + "//mediapipe/gpu:gpu_service", + ], + "//mediapipe/gpu:disable_gpu": [], + }), +) + +cc_library( + name = "calculator_node", + srcs = ["calculator_node.cc"], + hdrs = ["calculator_node.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":calculator_base", + ":calculator_context", + ":calculator_context_manager", + ":calculator_registry_util", + 
":calculator_state", + ":counter_factory", + ":input_side_packet_handler", + ":input_stream_handler", + ":input_stream_manager", + ":input_stream_shard", + ":legacy_calculator_support", + ":mediapipe_profiling", + ":output_side_packet_impl", + ":output_stream_handler", + ":output_stream_manager", + ":output_stream_shard", + ":packet", + ":packet_set", + ":packet_type", + ":port", + ":timestamp", + ":validated_graph_config", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:stream_handler_cc_proto", + "//mediapipe/framework/deps:registration", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/framework/profiler:graph_profiler", + "//mediapipe/framework/stream_handler:default_input_stream_handler", + "//mediapipe/framework/stream_handler:in_order_output_stream_handler", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:tag_map", + "//mediapipe/framework/tool:validate_name", + "//mediapipe/gpu:graph_support", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "calculator_registry", + hdrs = ["calculator_registry.h"], + visibility = [ + ":mediapipe_internal", + ], + deps = [ + ":calculator_base", + ], +) + +cc_library( + name = "calculator_registry_util", + srcs = ["calculator_registry_util.cc"], + hdrs = ["calculator_registry_util.h"], + visibility = [ + ":mediapipe_internal", + ], + deps = [ + ":calculator_base", + ":calculator_context", + ":calculator_state", + ":collection", + ":collection_item_id", + ":packet_set", + ":timestamp", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/framework/tool:tag_map", + ], +) + 
+cc_library( + name = "calculator_runner", + testonly = 1, + srcs = ["calculator_runner.cc"], + hdrs = ["calculator_runner.h"], + visibility = ["//visibility:public"], + deps = [ + ":calculator_framework", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "calculator_state", + srcs = ["calculator_state.cc"], + hdrs = ["calculator_state.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":counter", + ":counter_factory", + ":graph_service", + ":input_stream", + ":output_stream", + ":packet", + ":packet_set", + ":port", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/port:any_proto", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/tool:options_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "camera_intrinsics", + hdrs = ["camera_intrinsics.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "collection", + hdrs = ["collection.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":type_map", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/tool:tag_map", + "//mediapipe/framework/tool:validate_name", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "collection_item_id", + srcs = ["collection_item_id.cc"], + hdrs = ["collection_item_id.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/deps:intops", + ], +) + +cc_library( + name = "counter", + hdrs = ["counter.h"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework/port:integral_types"], +) + +cc_library( + name = 
"counter_factory", + srcs = ["counter_factory.cc"], + hdrs = ["counter_factory.h"], + visibility = ["//visibility:public"], + deps = [ + ":counter", + ":port", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:map_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "delegating_executor", + srcs = ["delegating_executor.cc"], + hdrs = ["delegating_executor.h"], + visibility = ["//visibility:public"], + deps = [ + ":executor", + ], +) + +cc_library( + name = "demangle", + hdrs = ["demangle.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "mediapipe_profiling", + hdrs = [ + "mediapipe_profiling.h", + "platform_specific_profiling.h", + ], + visibility = [ + ":mediapipe_internal", + ], + deps = [ + "//mediapipe/framework/profiler:graph_profiler", + ], +) + +cc_library( + name = "executor", + srcs = ["executor.cc"], + hdrs = ["executor.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework/deps:registration", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + ], +) + +cc_library( + name = "graph_output_stream", + srcs = ["graph_output_stream.cc"], + hdrs = ["graph_output_stream.h"], + visibility = [ + ":mediapipe_internal", + ], + deps = [ + ":input_stream_handler", + ":input_stream_manager", + ":output_stream_manager", + ":packet", + ":packet_set", + ":packet_type", + ":timestamp", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "graph_service", + hdrs = ["graph_service.h"], + visibility = [":mediapipe_internal"], + deps = [ + 
"@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "input_side_packet_handler", + srcs = ["input_side_packet_handler.cc"], + hdrs = ["input_side_packet_handler.h"], + visibility = ["//visibility:public"], + deps = [ + ":collection_item_id", + ":packet", + ":packet_set", + ":packet_type", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:fill_packet_set", + ], +) + +cc_library( + name = "input_stream", + hdrs = ["input_stream.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":packet", + ":port", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "input_stream_handler", + srcs = ["input_stream_handler.cc"], + hdrs = ["input_stream_handler.h"], + visibility = [ + ":mediapipe_internal", + "//research/interaction/mediapipe/calculators:__pkg__", + ], + deps = [ + ":calculator_context", + ":calculator_context_manager", + ":collection", + ":collection_item_id", + ":input_stream_manager", + ":input_stream_shard", + ":packet", + ":packet_set", + ":packet_type", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:mediapipe_profiling", + "//mediapipe/framework/deps:registration", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "input_stream_manager", + srcs = ["input_stream_manager.cc"], + hdrs = ["input_stream_manager.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":packet", + ":packet_type", + ":port", + ":timestamp", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "input_stream_shard", + srcs = ["input_stream_shard.cc"], + hdrs = ["input_stream_shard.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":input_stream", + ":packet", + ":packet_type", + ":port", + ":timestamp", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "lifetime_tracker", + testonly = 1, + hdrs = ["lifetime_tracker.h"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "output_side_packet", + hdrs = ["output_side_packet.h"], + visibility = ["//visibility:public"], + deps = [ + ":packet", + ], +) + +cc_library( + name = "output_side_packet_impl", + srcs = ["output_side_packet_impl.cc"], + hdrs = ["output_side_packet_impl.h"], + visibility = ["//visibility:public"], + deps = [ + ":collection_item_id", + ":input_side_packet_handler", + ":output_side_packet", + ":packet", + ":packet_type", + ":port", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + ], +) + +cc_library( + name = "output_stream", + hdrs = ["output_stream.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":packet", + ":port", + ":timestamp", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "output_stream_handler", + srcs = ["output_stream_handler.cc"], + hdrs = ["output_stream_handler.h"], + visibility = [ + ":mediapipe_internal", + ], + deps = [ + ":calculator_context_manager", + ":collection", + ":collection_item_id", + ":output_stream_manager", + ":output_stream_shard", + ":packet_set", + ":packet_type", + ":timestamp", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework/deps:registration", + 
"//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "output_stream_manager", + srcs = ["output_stream_manager.cc"], + hdrs = ["output_stream_manager.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":input_stream_handler", + ":output_stream_shard", + ":packet", + ":packet_type", + ":port", + ":timestamp", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "output_stream_poller", + hdrs = ["output_stream_poller.h"], + visibility = ["//visibility:public"], + deps = [ + ":graph_output_stream", + ], +) + +cc_library( + name = "output_stream_shard", + srcs = ["output_stream_shard.cc"], + hdrs = ["output_stream_shard.h"], + visibility = [":mediapipe_internal"], + deps = [ + ":output_stream", + ":packet", + ":packet_type", + ":port", + ":timestamp", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) + +# Defines Packet, a data carrier used throughout the framework. 
+cc_library( + name = "packet", + srcs = ["packet.cc"], + hdrs = ["packet.h"], + visibility = ["//visibility:public"], + deps = [ + ":port", + ":timestamp", + ":type_map", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/framework/tool:type_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "packet_generator", + hdrs = ["packet_generator.h"], + visibility = ["//visibility:public"], + deps = [ + ":packet", + ":packet_set", + ":packet_type", + ":port", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework/deps:registration", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "packet_generator_graph", + srcs = ["packet_generator_graph.cc"], + hdrs = ["packet_generator_graph.h"], + visibility = ["//visibility:public"], + deps = [ + ":delegating_executor", + ":executor", + ":packet", + ":packet_generator", + ":packet_type", + ":port", + ":thread_pool_executor", + ":validated_graph_config", + "//mediapipe/framework:packet_factory_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "packet_set", + hdrs = ["packet_set.h"], + visibility = ["//visibility:public"], + deps = [ + 
":collection", + ":packet", + ], +) + +cc_library( + name = "packet_type", + srcs = ["packet_type.cc"], + hdrs = ["packet_type.h"], + visibility = ["//visibility:public"], + deps = [ + ":packet", + ":packet_set", + ":type_map", + "//mediapipe/framework:collection", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:map_util", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:validate_name", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "port", + hdrs = ["port.h"], + visibility = [ + "//mediapipe/framework:__subpackages__", + "//mediapipe/framework/port:__pkg__", + "//mediapipe/util:__pkg__", + ], +) + +cc_library( + name = "scheduler_queue", + srcs = ["scheduler_queue.cc"], + hdrs = [ + "scheduler_queue.h", + "scheduler_shared.h", + ], + visibility = [":mediapipe_internal"], + deps = [ + ":calculator_context", + ":calculator_node", + ":executor", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "status_handler", + hdrs = ["status_handler.h"], + visibility = ["//visibility:public"], + deps = [ + ":packet_set", + ":packet_type", + ":port", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework/deps:registration", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "subgraph", + srcs = ["subgraph.cc"], + hdrs = ["subgraph.h"], + visibility = ["//visibility:public"], + deps = [ + ":port", + "//mediapipe/framework:calculator_cc_proto", + 
"//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework/deps:registration", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/framework/tool:calculator_graph_template_cc_proto", + "//mediapipe/framework/tool:template_expander", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "test_calculators", + testonly = 1, + srcs = ["test_calculators.cc"], + visibility = ["//visibility:public"], + deps = [ + ":calculator_framework", + "//mediapipe/framework:test_calculators_cc_proto", + "//mediapipe/framework/deps:mathutil", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@eigen_archive//:eigen", + ], + alwayslink = 1, +) + +cc_library( + name = "test_service", + testonly = 1, + srcs = ["test_service.cc"], + hdrs = ["test_service.h"], + visibility = ["//visibility:public"], + deps = [ + ":calculator_contract", + ":calculator_framework", + ":graph_service", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], +) + +cc_library( + name = "thread_pool_executor", + srcs = ["thread_pool_executor.cc"], + hdrs = ["thread_pool_executor.h"], + visibility = ["//visibility:public"], + deps = [ + ":executor", + "//mediapipe/framework:thread_pool_executor_cc_proto", + "//mediapipe/framework/deps:thread_options", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/framework/port:threadpool", + "//mediapipe/util:cpu_util", + ], +) + +cc_library( + name = "timestamp", + srcs = ["timestamp.cc"], + hdrs = ["timestamp.h"], + 
visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/deps:intops", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "throttler", + hdrs = ["throttler.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "type_map", + hdrs = ["type_map.h"], + visibility = ["//visibility:public"], + deps = [ + ":demangle", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:type_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) + +cc_library( + name = "validated_graph_config", + srcs = ["validated_graph_config.cc"], + hdrs = ["validated_graph_config.h"], + visibility = ["//visibility:public"], + deps = [ + ":calculator_base", + ":calculator_contract", + ":calculator_registry_util", + ":legacy_calculator_support", + ":packet", + ":packet_generator", + ":packet_set", + ":packet_type", + ":port", + ":status_handler", + ":subgraph", + ":timestamp", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework:stream_handler_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:map_util", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:topologicalsorter", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:subgraph_expansion", + "//mediapipe/framework/tool:validate", + "//mediapipe/framework/tool:validate_name", + 
"@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "graph_validation", + hdrs = ["graph_validation.h"], + visibility = ["//visibility:public"], + deps = [ + ":calculator_framework", + "//mediapipe/framework/port:status", + ], +) + +# cc tests +cc_test( + name = "calculator_base_test", + size = "medium", + srcs = ["calculator_base_test.cc"], + linkstatic = 1, + deps = [ + ":calculator_base", + ":calculator_context", + ":calculator_context_manager", + ":calculator_registry", + ":calculator_state", + ":output_stream", + ":output_stream_manager", + ":output_stream_shard", + ":packet_set", + ":packet_type", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:tag_map_helper", + ], +) + +cc_test( + name = "calculator_contract_test", + srcs = ["calculator_contract_test.cc"], + linkstatic = 1, + deps = [ + ":calculator_contract", + ":calculator_contract_test_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + +cc_test( + name = "calculator_node_test", + size = "small", + srcs = ["calculator_node_test.cc"], + linkstatic = 1, + deps = [ + ":calculator_framework", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:calculator_node", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:source", + "@com_google_absl//absl/memory", + ], +) + +cc_test( + name = "calculator_graph_event_loop_test", + size = "small", + srcs = 
["calculator_graph_event_loop_test.cc"], + deps = [ + ":calculator_framework", + ":calculator_graph", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:sink", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "calculator_graph_stopping_test", + size = "small", + srcs = ["calculator_graph_stopping_test.cc"], + deps = [ + ":calculator_framework", + ":calculator_graph", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:sink", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "calculator_parallel_execution_test", + srcs = ["calculator_parallel_execution_test.cc"], + deps = [ + ":calculator_framework", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "calculator_runner_test", + size = "medium", + srcs = ["calculator_runner_test.cc"], + deps = [ + ":calculator_base", + ":calculator_registry", + ":calculator_runner", + ":input_stream", + ":output_stream", + 
":packet_type", + ":timestamp", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "calculator_context_test", + size = "medium", + srcs = ["calculator_context_test.cc"], + linkstatic = 1, + deps = [ + ":calculator_context", + ":calculator_context_manager", + ":calculator_state", + ":output_stream", + ":output_stream_manager", + ":output_stream_shard", + ":packet_set", + ":packet_type", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/testdata:night_light_calculator_cc_proto", + "//mediapipe/framework/testdata:sky_light_calculator_cc_proto", + "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:tag_map_helper", + ], +) + +cc_test( + name = "calculator_graph_test", + size = "small", + srcs = [ + "calculator_graph_test.cc", + ], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":calculator_framework", + ":calculator_graph", + ":collection_item_id", + ":counter_factory", + ":executor", + ":input_stream_handler", + ":lifetime_tracker", + ":output_stream_poller", + ":packet_set", + ":packet_type", + ":status_handler", + ":subgraph", + ":test_calculators", + ":thread_pool_executor", + ":timestamp", + ":type_map", + "//mediapipe/calculators/core:counting_source_calculator", + "//mediapipe/calculators/core:mux_calculator", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + 
"//mediapipe/framework/stream_handler:barrier_input_stream_handler", + "//mediapipe/framework/stream_handler:early_close_input_stream_handler", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/stream_handler:mux_input_stream_handler", + "//mediapipe/framework/tool:sink", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "collection_test", + size = "small", + srcs = ["collection_test.cc"], + linkstatic = 1, + deps = [ + ":collection", + ":packet_set", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:tag_map_helper", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "graph_service_test", + size = "small", + srcs = [ + "graph_service_test.cc", + ], + visibility = ["//visibility:public"], + deps = [ + ":calculator_contract", + ":calculator_framework", + ":graph_service", + ":test_service", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:sink", + ], +) + +cc_test( + name = "input_stream_manager_test", + size = "small", + srcs = ["input_stream_manager_test.cc"], + linkstatic = 1, + deps = [ + ":input_stream_manager", + ":input_stream_shard", + ":lifetime_tracker", + ":packet", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/memory", + ], +) + +cc_test( + name = "output_stream_manager_test", + size = "small", + srcs = ["output_stream_manager_test.cc"], + linkstatic = 1, + deps = [ + ":input_stream_handler", + ":input_stream_manager", + ":output_stream_manager", + ":output_stream_shard", + ":packet", + 
"//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/stream_handler:default_input_stream_handler", + "//mediapipe/framework/tool:tag_map_helper", + "@com_google_absl//absl/memory", + ], +) + +cc_test( + name = "packet_delete_test", + size = "small", + srcs = ["packet_delete_test.cc"], + copts = [ + "-Werror", + ], + linkstatic = 1, + deps = [ + ":packet", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "executor_external_build_test", + size = "small", + srcs = ["executor_external_build_test.cc"], + linkstatic = 1, + deps = [ + ":executor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:threadpool", + "@com_google_absl//absl/memory", + ], +) + +cc_test( + name = "packet_test", + size = "medium", + srcs = ["packet_test.cc"], + linkstatic = 1, + deps = [ + ":packet", + ":packet_test_cc_proto", + ":type_map", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "packet_generator_test", + size = "small", + srcs = ["packet_generator_test.cc"], + deps = [ + ":packet_generator", + ":packet_type", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "timestamp_test", + size = "small", + srcs = ["timestamp_test.cc"], + linkstatic = 1, + deps = [ + ":timestamp", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "graph_validation_test", + srcs = ["graph_validation_test.cc"], + deps = [ + ":calculator_contract_test_cc_proto", + ":calculator_framework", + ":graph_validation", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework/deps:message_matchers", + 
"//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:template_parser", + ], +) diff --git a/mediapipe/framework/calculator.proto b/mediapipe/framework/calculator.proto new file mode 100644 index 000000000..01fe859f9 --- /dev/null +++ b/mediapipe/framework/calculator.proto @@ -0,0 +1,421 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from mediapipe/framework/calculator.proto. +// The forked proto must remain identical to the original proto and should be +// ONLY used by mediapipe open source project. +syntax = "proto3"; + +package mediapipe; + +import public "mediapipe/framework/calculator_options.proto"; + +import "google/protobuf/any.proto"; +import "mediapipe/framework/mediapipe_options.proto"; +import "mediapipe/framework/packet_factory.proto"; +import "mediapipe/framework/packet_generator.proto"; +import "mediapipe/framework/status_handler.proto"; +import "mediapipe/framework/stream_handler.proto"; + +option java_package = "com.google.mediapipe.proto"; +option java_outer_classname = "CalculatorProto"; + +// Describes a MediaPipe Executor. +message ExecutorConfig { + // The name of the executor (used by a CalculatorGraphConfig::Node or + // PacketGeneratorConfig to specify which executor it will execute on). + // This field must be unique within a CalculatorGraphConfig. 
If this field + // is omitted or is an empty string, the ExecutorConfig describes the + // default executor. + // + // NOTE: The names "default" and "gpu" are reserved and must not be used. + string name = 1; + // The registered type of the executor. For example: "ThreadPoolExecutor". + // The framework will create an executor of this type (with the options in + // the options field) for the CalculatorGraph. + // + // The ExecutorConfig for the default executor may omit this field and let + // the framework choose an appropriate executor type. Note: If the options + // field is used in this case, it should contain the + // ThreadPoolExecutorOptions. + // + // If the ExecutorConfig for an additional (non-default) executor omits this + // field, the executor must be created outside the CalculatorGraph and + // passed to the CalculatorGraph for use. + string type = 2; + // The options passed to the Executor. The extension in the options field + // must match the type field. For example, if the type field is + // "ThreadPoolExecutor", then the options field should contain the + // ThreadPoolExecutorOptions. + MediaPipeOptions options = 3; +} + +// A collection of input data to a CalculatorGraph. +message InputCollection { + // The name of the input collection. Name must match [a-z_][a-z0-9_]* + string name = 1; + // The names of each side packet. The number of side_packet_name + // must match the number of packets generated by the input file. + repeated string side_packet_name = 2; + // DEPRECATED: old way of referring to side_packet_name. + repeated string external_input_name = 1002; + + // The input can be specified in several ways. + enum InputType { + // An invalid default value. This value is guaranteed to be the + // lowest enum value (i.e. don't add negative enum values). + UNKNOWN = 0; + // A recordio where each record is a serialized PacketManagerConfig. 
+ // Each PacketManagerConfig must have the same number of packet + // factories in it as the number of side packet names. Furthermore, + // the output side packet name field in each PacketFactoryConfig + // must not be set. This is the most general input, and allows + // multiple side packet values to be set in arbitrarily complicated + // ways before each run. + RECORDIO = 1; + // A recordio where each record is a serialized packet payload. + // For example a recordio of serialized OmniaFeature protos dumped + // from Omnia. + FOREIGN_RECORDIO = 2; + // A text file where each line is a comma separated list. The number + // of elements for each csv string must be the same as the number + // of side_packet_name (and the order must match). Each line must + // be less than 1MiB in size. Lines comprising of only whitespace + // or only whitespace and a pound comment will be skipped. + FOREIGN_CSV_TEXT = 3; + // This and all higher values are invalid. Update this value to + // always be larger than any other enum values you add. + INVALID_UPPER_BOUND = 4; + } + // Sets the source of the input collection data. + // The default value is UNKNOWN. + InputType input_type = 3; + // A file name pointing to the data. The format of the data is + // specified by the "input_type" field. Multiple shards may be + // specified using @N or glob expressions. + string file_name = 4; +} + +// A convenient way to specify a number of InputCollections. +message InputCollectionSet { + repeated InputCollection input_collection = 1; +} + +// Additional information about an input stream. +message InputStreamInfo { + // A description of the input stream. + // This description uses the Calculator visible specification of + // a stream. The format is a tag, then an index with both being + // optional. If the tag is missing it is assumed to be "" and if + // the index is missing then it is assumed to be 0. If the index + // is provided then a colon (':') must be used. 
+ // Examples: + // "TAG" -> tag "TAG", index 0 + // "" -> tag "", index 0 + // ":0" -> tag "", index 0 + // ":3" -> tag "", index 3 + // "VIDEO:0" -> tag "VIDEO", index 0 + // "VIDEO:2" -> tag "VIDEO", index 2 + string tag_index = 1; + // Whether the input stream is a back edge. + // By default, MediaPipe requires graphs to be acyclic and treats cycles in a + // graph as errors. To allow MediaPipe to accept a cyclic graph, set the + // back_edge fields of the input streams that are back edges to true. A + // cyclic graph usually has an obvious forward direction, and a back edge + // goes in the opposite direction. For a formal definition of a back edge, + // please see https://en.wikipedia.org/wiki/Depth-first_search. + bool back_edge = 2; +} + +// Configs for the profiler for a calculator. Not applicable to subgraphs. +message ProfilerConfig { + // Size of the runtimes histogram intervals (in microseconds) to generate the + // histogram of the Process() time. The last interval extends to +inf. + // If not specified, the interval is 1000000 usec = 1 sec. + int64 histogram_interval_size_usec = 1; + + // Number of intervals to generate the histogram of the Process() runtime. + // If not specified, one interval is used. + int64 num_histogram_intervals = 2; + + // TODO: clean up after migration to MediaPipeProfiler. + // DEPRECATED: If true, the profiler also profiles the input output latency. + // Should be true only if the packet timestamps corresponds to the + // microseconds wall time from epoch. + bool enable_input_output_latency = 3 [deprecated = true]; + + // If true, the profiler starts profiling when graph is initialized. + bool enable_profiler = 4; + + // If true, the profiler also profiles the stream latency and input-output + // latency. + // No-op if enable_profiler is false. 
+ bool enable_stream_latency = 5; + + // If true, the profiler uses packet timestamp (as production time and source + // production time) for packets added by calling + // CalculatorGraph::AddPacketToInputStream(). + // If false, uses profiler's clock. + bool use_packet_timestamp_for_added_packet = 6; + + // The maximum number of trace events buffered in memory. + int64 trace_log_capacity = 7; + + // Trace event types that are not logged. + repeated int32 trace_event_types_disabled = 8; + + // The output directory and base-name prefix for trace log files. + // Log files are written to: StrCat(trace_log_path, index, ".binarypb") + string trace_log_path = 9; + + // The number of trace log files retained. + // The trace log files are named "trace_0.log" through "trace_k.log". + // The default value specifies 2 output files retained. + int32 trace_log_count = 10; + + // The interval in microseconds between trace log output. + // The value -1 specifies output only when the graph is closed. + // The default value specifies trace log output once every 1 sec. + int64 trace_log_interval_usec = 11; + + // The interval in microseconds between TimeNow and the highest times + // included in trace log output. This margin allows time for events + // to be appended to the TraceBuffer. + int64 trace_log_margin_usec = 12; + + // True specifies an event for each calculator invocation. + // False specifies a separate event for each start and finish time. + bool trace_log_duration_events = 13; + + // The number of trace log intervals per file. The total log duration is: + // trace_log_interval_usec * trace_log_file_count * trace_log_interval_count. + // The default value specifies 10 intervals per file. + int32 trace_log_interval_count = 14; + + // An option to turn ON/OFF writing trace files to disk. Saving trace files to + // disk is enabled by default. + bool trace_log_disabled = 15; +} + +// Describes the topology and function of a MediaPipe Graph. 
The graph of +// Nodes must be a Directed Acyclic Graph (DAG) except as annotated by +// "back_edge" in InputStreamInfo. Use a mediapipe::CalculatorGraph object to +// run the graph. +message CalculatorGraphConfig { + // A single node in the DAG. + message Node { + // The name of the node. This field is optional and doesn't generally + // need to be specified, but does improve error messaging. + string name = 1; + // The registered type of a calculator (provided via REGISTER_CALCULATOR), + // or of a subgraph (via REGISTER_MEDIAPIPE_GRAPH). + string calculator = 2; + // A Calculator can choose to access its input streams, output + // streams, and input side packets either by tag or by index. If the + // calculator chooses indexes then it will receive the streams or side + // packets in the same order as they are specified in this proto. + // If the calculator chooses to use tags then it must specify a + // tag along with each name. The field is given as "TAG:name". + // Meaning a tag name followed by a colon followed by the name. + // Tags use only upper case letters, numbers, and underscores, whereas + // names use only lower case letters, numbers, and underscores. + // Example: + // Node { + // calculator: "SomeAudioVideoCalculator" + // # This calculator accesses its inputs by index (no tag needed). + // input_stream: "combined_input" + // # This calculator accesses its outputs by tags, so all + // # output_streams must specify a tag. + // output_stream: "AUDIO:audio_stream" + // output_stream: "VIDEO:video_stream" + // # This calculator accesses its input side packets by tag. + // input_side_packet: "MODEL:model_01" + // } + + // String(s) representing "TAG:name" of the stream(s) from which the current + // node will get its inputs. "TAG:" part is optional, see above. + // A calculator with no input stream is a source. + repeated string input_stream = 3; + // String(s) representing "TAG:name" of the stream(s) produced by this node. 
+ // "TAG:" part is optional, see above. These must be different from any + // other output_streams specified for other nodes in the graph. + repeated string output_stream = 4; + // String(s) representing "TAG:name" of the input side packet(s). + // "TAG:" part is optional, see above. + repeated string input_side_packet = 5; + // String(s) representing "TAG:name" of the output side packet(s). Only + // used by subgraphs. + // "TAG:" part is optional, see above. + repeated string output_side_packet = 6; + // The options passed to the Calculator, in proto2 syntax. + CalculatorOptions options = 7; + // The options passed to the Calculator, in proto3 syntax. + // Each node_options message must have a different message type. + // If the same message type is specified in |options| and |node_options|, + // only the message in |options| is used. + repeated google.protobuf.Any node_options = 8; + + // Note: the following fields are only applicable to calculators, not + // subgraphs. + + // For a Source Calculator (i.e. a calculator with no inputs), + // this is the "layer" on which the calculator is executed. For a + // non-source calculator (i.e. a calculator with one or more input + // streams) this field has no effect. The sources on each layer + // are completely exhausted before Process() is called on any source + // calculator on a higher numbered layer. + // Example: + // Decoder -> Median Frame (requires all frames) -> Image Subtraction + // ---------------------------------------> + // The entire video will be buffered on the edge from the decoder + // to the Image subtraction. To fix this problem, layers can be used. + // Decoder (layer 0) -> Median Frame -> Image Subtraction + // Decoder (layer 1) -----------------> + // The frames from layer 0 will no longer be buffered, but the video + // will be decoded again instead. Note, that different options can + // be used in the second decoder. 
+ int32 source_layer = 9; + // Optional parameter that allows the user to indicate to the scheduler that + // this node has a buffering behavior (i.e. waits for a bunch of packets + // before emitting any) and specify the size of the buffer that is built up. + // The scheduler will then try to keep the maximum size of any input queues + // in the graph to remain below the maximum of all buffer_size_hints and + // max_queue_size (if specified). The ideal value is typically something + // larger than the actual number of buffered packets to maintain pipelining. + // The default value 0 indicates that the node has no buffering behavior. + int32 buffer_size_hint = 10; + // Config for this node's InputStreamHandler. + // If unspecified, the graph-level input stream handler will be used. + InputStreamHandlerConfig input_stream_handler = 11; + // Config for this node's OutputStreamHandler. + // If unspecified, the graph-level output stream handler will be used. + OutputStreamHandlerConfig output_stream_handler = 12; + // Additional information about an input stream. The |name| field of the + // InputStreamInfo must match an input_stream. + repeated InputStreamInfo input_stream_info = 13; + // Set the executor which the calculator will execute on. + string executor = 14; + // TODO: Remove from Node when switched to Profiler. + // DEPRECATED: Configs for the profiler. + ProfilerConfig profiler_config = 15 [deprecated = true]; + // The maximum number of invocations that can be executed in parallel. + // If not specified, the limit is one invocation. + int32 max_in_flight = 16; + // DEPRECATED: For backwards compatibility we allow users to + // specify the old name for "input_side_packet" in proto configs. + // These are automatically converted to input_side_packets during + // config canonicalization. + repeated string external_input = 1005; + } + + // The nodes. + repeated Node node = 1; + // Create a side packet using a PacketFactory. 
This side packet is + // created as close to the worker that does the work as possible. A + // PacketFactory is basically a PacketGenerator that takes no input side + // packets and produces a single output side packet. + repeated PacketFactoryConfig packet_factory = 6; + // Configs for PacketGenerators. Generators take zero or more + // input side packets and produce any number of output side + // packets. For example, MediaDecoderCalculator takes an input + // side packet with type DeletingFile. However, most users want + // to specify videos by ContentIdHex (i.e. video id). By using + // the VideoIdToLocalFileGenerator, a user can specify a video id + // (as a string) and obtain a DeletingFile to use with the decoder. + // PacketGenerators can take as a input side packet the output side + // packet of another PacketGenerator. The graph of PacketGenerators + // must be a directed acyclic graph. + repeated PacketGeneratorConfig packet_generator = 7; + // Number of threads for running calculators in multithreaded mode. + // If not specified, the scheduler will pick an appropriate number + // of threads depending on the number of available processors. + // To run on the calling thread, specify "ApplicationThreadExecutor" + // see: http://g3doc/mediapipe/g3doc/running.md. + int32 num_threads = 8; + // Configs for StatusHandlers that will be called after each call to + // Run() on the graph. StatusHandlers take zero or more input side + // packets and the ::util::Status returned by a graph run. For example, + // a StatusHandler could store information about graph failures and + // their causes for later monitoring. Note that graph failures during + // initialization may cause required input side packets (created by a + // PacketFactory or PacketGenerator) to be missing. In these cases, + // the handler with missing input side packets will be skipped. + repeated StatusHandlerConfig status_handler = 9; + // Specify input streams to the entire graph. 
Streams specified here may have + // packets added to them using CalculatorGraph::AddPacketToInputStream. This + // works much like a source calculator, except that the source is outside of + // the mediapipe graph. + repeated string input_stream = 10; + // Output streams for the graph when used as a subgraph. + repeated string output_stream = 15; + // Input side packets for the graph when used as a subgraph. + repeated string input_side_packet = 16; + // Output side packets for the graph when used as a subgraph. + repeated string output_side_packet = 17; + // Maximum queue size of any input stream in the graph. This can be used to + // control the memory usage of a MediaPipe graph by preventing fast sources + // from flooding the graph with packets. Any source that is connected to an + // input stream that has hit its maximum capacity will not be scheduled until + // the queue size falls under the specified limits, or if the scheduler queue + // is empty and no other nodes are running (to prevent possible deadlocks due + // to a incorrectly specified value). This global parameter is set to 100 + // packets by default to enable pipelining. If any node indicates that it + // buffers packets before emitting them, then the max(node_buffer_size, + // max_queue_size) is used. Set this parameter to -1 to disable throttling + // (i.e. the graph will use as much memory as it requires). If not specified, + // the limit is 100 packets. + int32 max_queue_size = 11; + // If true, the graph run fails with an error when throttling prevents all + // calculators from running. If false, max_queue_size for an input stream + // is adjusted when throttling prevents all calculators from running. + bool report_deadlock = 21; + // Config for this graph's InputStreamHandler. + // If unspecified, the framework will automatically install the default + // handler, which works as follows. 
+ // The calculator's Process() method is called for timestamp t when: + // - at least one stream has a packet available at t; and, + // - all other streams either have packets at t, or it is known that they will + // not have packets at t (i.e. their next timestamp bound is greater than t). + // The handler then provides all available packets with timestamp t, with no + // preprocessing. + InputStreamHandlerConfig input_stream_handler = 12; + // Config for this graph's OutputStreamHandler. + // If unspecified, the default output stream handler will be automatically + // installed by the framework which does not modify any outgoing packets. + OutputStreamHandlerConfig output_stream_handler = 13; + // Configs for Executors. + // The names of the executors must be distinct. The default executor, whose + // name is the empty string, is predefined. The num_threads field of the + // CalculatorGraphConfig specifies the number of threads in the default + // executor. If the config for the default executor is specified, the + // CalculatorGraphConfig must not have the num_threads field. + repeated ExecutorConfig executor = 14; + // The default profiler-config for all calculators. If set, this defines the + // profiling settings such as num_histogram_intervals for every calculator in + // the graph. Each of these settings can be overridden by the + // |profiler_config| specified for a node. + ProfilerConfig profiler_config = 18; + + // The namespace used for class name lookup within this graph. + // An unqualified or partially qualified class name is looked up in + // this namespace first and then in enclosing namespaces. + string package = 19; + + // The type name for the graph config, used for registering and referencing + // the graph config. + string type = 20; + + // Can be used for annotating a graph. 
+ MediaPipeOptions options = 1001; +} diff --git a/mediapipe/framework/calculator_base.cc b/mediapipe/framework/calculator_base.cc new file mode 100644 index 000000000..ecc0685ac --- /dev/null +++ b/mediapipe/framework/calculator_base.cc @@ -0,0 +1,36 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definitions for CalculatorBase. + +#include "mediapipe/framework/calculator_base.h" + +#include + +namespace mediapipe { + +CalculatorBase::CalculatorBase() {} + +CalculatorBase::~CalculatorBase() {} + +Timestamp CalculatorBase::SourceProcessOrder( + const CalculatorContext* cc) const { + Timestamp result = Timestamp::Max(); + for (const OutputStreamShard& output : cc->Outputs()) { + result = std::min(result, output.NextTimestampBound()); + } + return result; +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_base.h b/mediapipe/framework/calculator_base.h new file mode 100644 index 000000000..e0ca42170 --- /dev/null +++ b/mediapipe/framework/calculator_base.h @@ -0,0 +1,224 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Defines CalculatorBase, the base class for feature computation. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_ + +#include + +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_contract.h" +#include "mediapipe/framework/deps/registration.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { + +// Experimental: CalculatorBase will eventually replace Calculator as the +// base class of leaf (non-subgraph) nodes in a CalculatorGraph. +// +// The base calculator class. A subclass must, at a minimum, provide the +// implementation of GetContract(), Process(), and register the calculator +// using REGISTER_CALCULATOR(MyClass). +// +// The framework calls four primary functions on a calculator. +// On initialization of the graph, a static function is called. +// GetContract() +// Then, for each run of the graph on a set of input side packets, the +// following sequence will occur. +// Open() +// Process() (repeatedly) +// Close() +// +// The entire calculator is constructed and destroyed for each graph run +// (set of input side packets, which could mean once per video, or once +// per image). Any expensive operations and large objects should be +// input side packets. +// +// The framework calls Open() to initialize the calculator. 
+// If appropriate, Open() should call cc->SetOffset() or +// cc->Outputs().Get(id)->SetNextTimestampBound() to allow the framework to +// better optimize packet queueing. +// +// The framework calls Process() for every packet received on the input +// streams. The framework guarantees that cc->InputTimestamp() will +// increase with every call to Process(). An empty packet will be on the +// input stream if there is no packet on a particular input stream (but +// some other input stream has a packet). +// +// The framework calls Close() after all calls to Process(). +// +// Calculators with no inputs are referred to as "sources" and are handled +// slightly differently than non-sources (see the function comments for +// Process() for more details). +// +// Calculators must be thread-compatible. +// The framework does not call the non-const methods of a calculator from +// multiple threads at the same time. However, the thread that calls the +// methods of a calculator is not fixed. Therefore, calculators should not +// use ThreadLocal objects. +class CalculatorBase { + public: + CalculatorBase(); + virtual ~CalculatorBase(); + + // The subclasses of CalculatorBase must implement GetContract. + // The calculator cannot be registered without it. Notice that although + // this function is static the registration macro provides access to + // each subclass' GetContract function. + // + // static ::mediapipe::Status GetContract(CalculatorContract* cc); + // + // GetContract fills in the calculator's contract with the framework, such + // as its expectations of what packets it will receive. When this function + // is called, the numbers of inputs, outputs, and input side packets will + // have already been determined by the calculator graph. You can use + // indexes, tags, or tag:index to access input streams, output streams, + // or input side packets. 
+  //
+  // Example (uses tags for inputs and indexes for outputs and input side
+  // packets):
+  //   cc->Inputs().Tag("VIDEO").Set<ImageFrame>("Input Image Frames.");
+  //   cc->Inputs().Tag("AUDIO").Set<Matrix>("Input Audio Frames.");
+  //   cc->Outputs().Index(0).Set<FooBarFeature>("Output FooBar feature.");
+  //   cc->InputSidePackets().Index(0).Set<FooBarModel>(
+  //       "Model used for FooBar feature extraction.");
+  //
+  // Example (same number and type of outputs as inputs):
+  //   for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
+  //     // SetAny() is used to specify that whatever the type of the
+  //     // stream is, it's acceptable. This does not mean that any
+  //     // packet is acceptable. Packets in the stream still have a
+  //     // particular type. SetAny() has the same effect as explicitly
+  //     // setting the type to be the stream's type.
+  //     cc->Inputs().Index(i).SetAny(StrCat("Generic Input Stream ", i));
+  //     // Set each output to accept the same specific type as the
+  //     // corresponding input.
+  //     cc->Outputs().Index(i).SetSameAs(
+  //         &cc->Inputs().Index(i), StrCat("Generic Output Stream ", i));
+  //   }
+
+  // Open is called before any Process() calls, on a freshly constructed
+  // calculator. Subclasses may override this method to perform necessary
+  // setup, and possibly output Packets and/or set output streams' headers.
+  // Must return ::mediapipe::OkStatus() to indicate success. On failure any
+  // other status code can be returned. If failure is returned then the
+  // framework will call neither Process() nor Close() on the calculator (so
+  // any necessary cleanup should be done before returning failure or in the
+  // destructor).
+  virtual ::mediapipe::Status Open(CalculatorContext* cc) {
+    return ::mediapipe::OkStatus();
+  }
+
+  // Processes the incoming inputs. May call the methods on cc to access
+  // inputs and produce outputs.
+  //
+  // Process() called on a non-source node must return
+  // ::mediapipe::OkStatus() to indicate that all went well, or any other
+  // status code to signal an error.
+ // For example: + // ::mediapipe::UnknownError("Failure Message"); + // Notice the convenience functions in util/task/canonical_errors.h . + // If a non-source Calculator returns tool::StatusStop(), then this + // signals the graph is being cancelled early. In this case, all + // source Calculators and graph input streams will be closed (and + // remaining Packets will propagate through the graph). + // + // A source node will continue to have Process() called on it as long + // as it returns ::mediapipe::OkStatus(). To indicate that there is + // no more data to be generated return tool::StatusStop(). Any other + // status indicates an error has occurred. + virtual ::mediapipe::Status Process(CalculatorContext* cc) = 0; + + // Is called if Open() was called and succeeded. Is called either + // immediately after processing is complete or after a graph run has ended + // (if an error occurred in the graph). Must return ::mediapipe::OkStatus() + // to indicate success. On failure any other status code can be returned. + // Packets may be output during a call to Close(). However, output packets + // are silently discarded if Close() is called after a graph run has ended. + // + // NOTE: If Close() needs to perform an action only when processing is + // complete, Close() must check if cc->GraphStatus() is OK. + virtual ::mediapipe::Status Close(CalculatorContext* cc) { + return ::mediapipe::OkStatus(); + } + + // Returns a value according to which the framework selects + // the next source calculator to Process(); smaller value means + // Process() first. The default implementation returns the smallest + // NextTimestampBound value over all the output streams, but subclasses + // may override this. If a calculator is not a source, this method is + // not called. + // TODO: Does this method need to be virtual? No Calculator + // subclasses override the SourceProcessOrder method. 
+  virtual Timestamp SourceProcessOrder(const CalculatorContext* cc) const;
+};
+
+// Registry of CalculatorBase factories; REGISTER_CALCULATOR inserts a
+// factory returning std::unique_ptr<CalculatorBase> under the subclass name.
+using CalculatorBaseRegistry =
+    GlobalFactoryRegistry<std::unique_ptr<CalculatorBase>>;
+
+namespace internal {
+
+// Gives access to the static functions within subclasses of CalculatorBase.
+// This adds functionality akin to virtual static functions.
+class StaticAccessToCalculatorBase {
+ public:
+  virtual ~StaticAccessToCalculatorBase() {}
+  // Forwards to the subclass' static GetContract().
+  virtual ::mediapipe::Status GetContract(CalculatorContract* cc) = 0;
+};
+
+using StaticAccessToCalculatorBaseRegistry =
+    GlobalFactoryRegistry<std::unique_ptr<StaticAccessToCalculatorBase>>;
+
+// Functions for checking that the calculator has the required GetContract.
+// Overload resolution: the first overload participates only if
+// &T::GetContract is well-formed; the variadic overload is the fallback.
+template <class T>
+constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
+  typedef ::mediapipe::Status (*GetContractType)(CalculatorContract* cc);
+  return std::is_same<decltype(&T::GetContract), GetContractType>::value;
+}
+template <class T>
+constexpr bool CalculatorHasGetContract(...) {
+  return false;
+}
+
+// Provides access to the static functions within a specific subclass
+// of CalculatorBase.
+template <class CalculatorBaseSubclass>
+class StaticAccessToCalculatorBaseTyped : public StaticAccessToCalculatorBase {
+ public:
+  static_assert(std::is_base_of<::mediapipe::CalculatorBase,
+                                CalculatorBaseSubclass>::value,
+                "Classes registered with REGISTER_CALCULATOR must be "
+                "subclasses of ::mediapipe::CalculatorBase.");
+  static_assert(CalculatorHasGetContract<CalculatorBaseSubclass>(nullptr),
+                "GetContract() must be defined with the correct signature in "
+                "every calculator.");
+
+  // Provides access to the static function GetContract within a specific
+  // subclass of CalculatorBase.
+  ::mediapipe::Status GetContract(CalculatorContract* cc) final {
+    // CalculatorBaseSubclass must implement this function, since it is not
+    // implemented in the parent class.
+ return CalculatorBaseSubclass::GetContract(cc); + } +}; + +} // namespace internal + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_ diff --git a/mediapipe/framework/calculator_base_test.cc b/mediapipe/framework/calculator_base_test.cc new file mode 100644 index 000000000..474c010f4 --- /dev/null +++ b/mediapipe/framework/calculator_base_test.cc @@ -0,0 +1,228 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_base.h" + +// TODO: Move protos in another CL after the C++ code migration. 
+#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_registry.h" +#include "mediapipe/framework/calculator_state.h" +#include "mediapipe/framework/output_stream.h" +#include "mediapipe/framework/output_stream_manager.h" +#include "mediapipe/framework/output_stream_shard.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/status_util.h" +#include "mediapipe/framework/tool/tag_map_helper.h" + +namespace mediapipe { + +namespace test_ns { + +// A calculator which does nothing but accepts any number of input/output +// streams and input side packets. +class DeadEndCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).SetAny(); + } + for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { + cc->Outputs().Index(i).SetAny(); + } + for (int i = 0; i < cc->InputSidePackets().NumEntries(); ++i) { + cc->InputSidePackets().Index(i).SetAny(); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + if (cc->Inputs().NumEntries() > 0) { + return ::mediapipe::OkStatus(); + } else { + // This is a source calculator, but we don't produce any outputs. 
+      return tool::StatusStop();
+    }
+  }
+};
+REGISTER_CALCULATOR(::mediapipe::test_ns::DeadEndCalculator);
+
+namespace whitelisted_ns {
+
+// A no-op calculator used to exercise registration from inside a
+// whitelisted namespace.
+class DeadCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    return ::mediapipe::OkStatus();
+  }
+  ::mediapipe::Status Open(CalculatorContext* cc) override {
+    return ::mediapipe::OkStatus();
+  }
+  ::mediapipe::Status Process(CalculatorContext* cc) override {
+    return ::mediapipe::OkStatus();
+  }
+};
+
+}  // namespace whitelisted_ns
+}  // namespace test_ns
+
+// A no-op calculator registered at the top-level mediapipe namespace.
+class EndCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    return ::mediapipe::OkStatus();
+  }
+  ::mediapipe::Status Open(CalculatorContext* cc) override {
+    return ::mediapipe::OkStatus();
+  }
+  ::mediapipe::Status Process(CalculatorContext* cc) override {
+    return ::mediapipe::OkStatus();
+  }
+};
+REGISTER_CALCULATOR(::mediapipe::EndCalculator);
+
+namespace {
+
+TEST(CalculatorTest, SourceProcessOrder) {
+  internal::Collection<OutputStreamManager> output_stream_managers(
+      tool::CreateTagMap(2).ValueOrDie());
+
+  PacketType output0_type;
+  PacketType output1_type;
+  output0_type.SetAny();
+  output1_type.SetAny();
+
+  MEDIAPIPE_ASSERT_OK(
+      output_stream_managers.Index(0).Initialize("output0", &output0_type));
+  MEDIAPIPE_ASSERT_OK(
+      output_stream_managers.Index(1).Initialize("output1", &output1_type));
+
+  PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie());
+
+  CalculatorState calculator_state("Node", /*node_id=*/0, "Calculator",
+                                   CalculatorGraphConfig::Node(), nullptr);
+
+  calculator_state.SetInputSidePackets(&input_side_packets);
+
+  CalculatorContextManager calculator_context_manager;
+  CalculatorContext calculator_context(&calculator_state,
+                                       tool::CreateTagMap({}).ValueOrDie(),
+                                       output_stream_managers.TagMap());
+  InputStreamShardSet& input_set = calculator_context.Inputs();
+  OutputStreamShardSet& output_set = calculator_context.Outputs();
+ output_set.Index(0).SetSpec(output_stream_managers.Index(0).Spec()); + output_set.Index(0).SetNextTimestampBound(Timestamp(10)); + output_set.Index(1).SetSpec(output_stream_managers.Index(1).Spec()); + output_set.Index(1).SetNextTimestampBound(Timestamp(11)); + CalculatorContextManager().PushInputTimestampToContext( + &calculator_context, Timestamp::Unstarted()); + + InputStreamSet input_streams(input_set.TagMap()); + OutputStreamSet output_streams(output_set.TagMap()); + for (CollectionItemId id = input_streams.BeginId(); + id < input_streams.EndId(); ++id) { + input_streams.Get(id) = &input_set.Get(id); + } + for (CollectionItemId id = output_streams.BeginId(); + id < output_streams.EndId(); ++id) { + output_streams.Get(id) = &output_set.Get(id); + } + calculator_state.SetInputStreamSet(&input_streams); + calculator_state.SetOutputStreamSet(&output_streams); + + test_ns::DeadEndCalculator calculator; + EXPECT_EQ(Timestamp(10), calculator.SourceProcessOrder(&calculator_context)); + output_set.Index(0).SetNextTimestampBound(Timestamp(100)); + EXPECT_EQ(Timestamp(11), calculator.SourceProcessOrder(&calculator_context)); +} + +// Tests registration of a calculator within a namespace. +// DeadEndCalculator is registered in namespace "mediapipe::test_ns". 
+TEST(CalculatorTest, CreateByName) {
+  MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByName(  //
+      "mediapipe.test_ns.DeadEndCalculator"));
+
+  MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByName(  //
+      ".mediapipe.test_ns.DeadEndCalculator"));
+
+  MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace(  //
+      "alpha", ".mediapipe.test_ns.DeadEndCalculator"));
+
+  MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace(  //
+      "alpha", "mediapipe.test_ns.DeadEndCalculator"));
+
+  MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace(  //
+      "mediapipe", "mediapipe.test_ns.DeadEndCalculator"));
+
+  MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace(  //
+      "mediapipe.test_ns.sub_ns", "DeadEndCalculator"));
+
+  // An unqualified name is not found outside its registration namespace.
+  EXPECT_EQ(CalculatorBaseRegistry::CreateByNameInNamespace(  //
+                "mediapipe", "DeadEndCalculator")
+                .status()
+                .code(),
+            ::mediapipe::StatusCode::kNotFound);
+
+  EXPECT_EQ(CalculatorBaseRegistry::CreateByName(  //
+                "DeadEndCalculator")
+                .status()
+                .code(),
+            ::mediapipe::StatusCode::kNotFound);
+}
+
+// Tests registration of a calculator within a whitelisted namespace.
+TEST(CalculatorTest, CreateByNameWhitelisted) {
+  // Reset the registration namespace whitelist.
+  *const_cast<std::unordered_set<std::string>*>(
+      &NamespaceWhitelist::TopNamespaces()) = std::unordered_set<std::string>{
+      "mediapipe::test_ns::whitelisted_ns",
+      "mediapipe",
+  };
+
+  // Register a whitelisted calculator.
+  CalculatorBaseRegistry::Register(
+      "::mediapipe::test_ns::whitelisted_ns::DeadCalculator",
+      absl::make_unique< ::mediapipe::test_ns::whitelisted_ns::DeadCalculator>);
+
+  // A whitelisted calculator can be found in its own namespace.
+ MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // + "", "mediapipe.test_ns.whitelisted_ns.DeadCalculator")); + MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // + "mediapipe.sub_ns", "test_ns.whitelisted_ns.DeadCalculator")); + MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // + "mediapipe.sub_ns", "mediapipe.EndCalculator")); + + // A whitelisted calculator can be found in the top-level namespace. + MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // + "", "DeadCalculator")); + MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // + "mediapipe", "DeadCalculator")); + MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // + "mediapipe.test_ns.sub_ns", "DeadCalculator")); + MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // + "", "EndCalculator")); + MEDIAPIPE_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // + "mediapipe.test_ns.sub_ns", "EndCalculator")); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_context.cc b/mediapipe/framework/calculator_context.cc new file mode 100644 index 000000000..deaad0357 --- /dev/null +++ b/mediapipe/framework/calculator_context.cc @@ -0,0 +1,76 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_context.h" + +namespace mediapipe { + +const std::string& CalculatorContext::CalculatorType() const { + CHECK(calculator_state_); + return calculator_state_->CalculatorType(); +} + +const CalculatorOptions& CalculatorContext::Options() const { + CHECK(calculator_state_); + return calculator_state_->Options(); +} + +const std::string& CalculatorContext::NodeName() const { + CHECK(calculator_state_); + return calculator_state_->NodeName(); +} + +int CalculatorContext::NodeId() const { + CHECK(calculator_state_); + return calculator_state_->NodeId(); +} + +Counter* CalculatorContext::GetCounter(const std::string& name) { + CHECK(calculator_state_); + return calculator_state_->GetCounter(name); +} + +const PacketSet& CalculatorContext::InputSidePackets() const { + return calculator_state_->InputSidePackets(); +} + +OutputSidePacketSet& CalculatorContext::OutputSidePackets() { + return calculator_state_->OutputSidePackets(); +} + +InputStreamShardSet& CalculatorContext::Inputs() { return inputs_; } + +const InputStreamShardSet& CalculatorContext::Inputs() const { return inputs_; } + +OutputStreamShardSet& CalculatorContext::Outputs() { return outputs_; } + +const OutputStreamShardSet& CalculatorContext::Outputs() const { + return outputs_; +} + +void CalculatorContext::SetOffset(TimestampDiff offset) { + for (auto& stream : outputs_) { + stream.SetOffset(offset); + } +} + +const InputStreamSet& CalculatorContext::InputStreams() const { + return calculator_state_->InputStreams(); +} + +const OutputStreamSet& CalculatorContext::OutputStreams() const { + return calculator_state_->OutputStreams(); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_context.h b/mediapipe/framework/calculator_context.h new file mode 100644 index 000000000..9f12c0133 --- /dev/null +++ b/mediapipe/framework/calculator_context.h @@ -0,0 +1,178 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTEXT_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTEXT_H_ + +#include +#include +#include +#include + +#include "mediapipe/framework/calculator_state.h" +#include "mediapipe/framework/counter.h" +#include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/input_stream_shard.h" +#include "mediapipe/framework/output_stream_shard.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/any_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { + +// A CalculatorContext provides information about the graph it is running +// inside of through a number of accessor functions: Inputs(), Outputs(), +// InputSidePackets(), Options(), etc. +// +// CalculatorBase APIs, such as CalculatorBase::Open(CalculatorContext* cc), +// CalculatorBase::Process(CalculatorContext* cc), and +// CalculatorBase::Close(CalculatorContext* cc), will only interact with +// its own CalculatorContext object for exchanging data with the framework. 
+class CalculatorContext { + public: + CalculatorContext(CalculatorState* calculator_state, + std::shared_ptr input_tag_map, + std::shared_ptr output_tag_map) + : calculator_state_(calculator_state), + inputs_(std::move(input_tag_map)), + outputs_(std::move(output_tag_map)) {} + + CalculatorContext(const CalculatorContext&) = delete; + CalculatorContext& operator=(const CalculatorContext&) = delete; + + const std::string& NodeName() const; + int NodeId() const; + const std::string& CalculatorType() const; + // Returns the options given to this calculator. The Calculator or + // CalculatorBase implementation may get its options by calling + // GetExtension() on the result. + const CalculatorOptions& Options() const; + + // Returns the options given to this calculator. Template argument T must + // be the type of the protobuf extension message or the protobuf::Any + // message containing the options. + template + const T& Options() const { + return calculator_state_->Options(); + } + + // Returns a counter using the graph's counter factory. The counter's name is + // the passed-in name, prefixed by the calculator node's name (if present) or + // the calculator's type (if not). + Counter* GetCounter(const std::string& name); + + // Returns the current input timestamp, or Timestamp::Unset if there are + // no input packets. + Timestamp InputTimestamp() const { + return input_timestamps_.empty() ? Timestamp::Unset() + : input_timestamps_.front(); + } + + // Returns a reference to the input side packet set. + const PacketSet& InputSidePackets() const; + // Returns a reference to the output side packet collection. + OutputSidePacketSet& OutputSidePackets(); + // Returns a reference to the input stream collection. + // You may consume or move the value packets from the Inputs. + InputStreamShardSet& Inputs(); + // Returns a const reference to the input stream collection. + const InputStreamShardSet& Inputs() const; + // Returns a reference to the output stream collection. 
+ OutputStreamShardSet& Outputs(); + // Returns a const reference to the output stream collection. + const OutputStreamShardSet& Outputs() const; + + // Sets this packet timestamp offset for Packets going to all outputs. + // If you only want to set the offset for a single output stream then + // use OutputStream::SetOffset() directly. + void SetOffset(TimestampDiff offset); + + // Returns the status of the graph run. + // + // NOTE: This method should only be called during CalculatorBase::Close(). + ::mediapipe::Status GraphStatus() const { return graph_status_; } + + ProfilingContext* GetProfilingContext() const { + return calculator_state_->GetSharedProfilingContext().get(); + } + + template + class ServiceBinding { + public: + bool IsAvailable() { + return calculator_state_->IsServiceAvailable(service_); + } + T& GetObject() { return calculator_state_->GetServiceObject(service_); } + + ServiceBinding(CalculatorState* calculator_state, + const GraphService& service) + : calculator_state_(calculator_state), service_(service) {} + + private: + CalculatorState* calculator_state_; + const GraphService& service_; + }; + + template + ServiceBinding Service(const GraphService& service) { + return ServiceBinding(calculator_state_, service); + } + + private: + int NumberOfTimestamps() const { + return static_cast(input_timestamps_.size()); + } + + bool HasInputTimestamp() const { return !input_timestamps_.empty(); } + + // Adds a new input timestamp by the friend class CalculatorContextManager. + void PushInputTimestamp(Timestamp input_timestamp) { + input_timestamps_.push(input_timestamp); + } + + void PopInputTimestamp() { + CHECK(!input_timestamps_.empty()); + input_timestamps_.pop(); + } + + void SetGraphStatus(const ::mediapipe::Status& status) { + graph_status_ = status; + } + + // Interface for the friend class Calculator. 
+ const InputStreamSet& InputStreams() const; + const OutputStreamSet& OutputStreams() const; + + // Stores the shared data across all CalculatorContext objects, including + // input side packets, calculator options, node name, etc. + // TODO: Removes unnecessary fields from CalculatorState after + // migrating all clients to CalculatorContext. + CalculatorState* calculator_state_; + InputStreamShardSet inputs_; + OutputStreamShardSet outputs_; + // The queue of timestamp values to Process() in this calculator context. + std::queue input_timestamps_; + + // The status of the graph run. Only used when Close() is called. + ::mediapipe::Status graph_status_; + + // Accesses CalculatorContext for setting input timestamp. + friend class CalculatorContextManager; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTEXT_H_ diff --git a/mediapipe/framework/calculator_context_manager.cc b/mediapipe/framework/calculator_context_manager.cc new file mode 100644 index 000000000..93ee9855e --- /dev/null +++ b/mediapipe/framework/calculator_context_manager.cc @@ -0,0 +1,111 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_context_manager.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +void CalculatorContextManager::Initialize( + CalculatorState* calculator_state, + std::shared_ptr input_tag_map, + std::shared_ptr output_tag_map, + bool calculator_run_in_parallel) { + CHECK(calculator_state); + calculator_state_ = calculator_state; + input_tag_map_ = std::move(input_tag_map); + output_tag_map_ = std::move(output_tag_map); + calculator_run_in_parallel_ = calculator_run_in_parallel; +} + +::mediapipe::Status CalculatorContextManager::PrepareForRun( + std::function<::mediapipe::Status(CalculatorContext*)> + setup_shards_callback) { + setup_shards_callback_ = std::move(setup_shards_callback); + default_context_ = absl::make_unique( + calculator_state_, input_tag_map_, output_tag_map_); + return setup_shards_callback_(default_context_.get()); +} + +void CalculatorContextManager::CleanupAfterRun() { + default_context_ = nullptr; + absl::MutexLock lock(&contexts_mutex_); + active_contexts_.clear(); + idle_contexts_.clear(); +} + +CalculatorContext* CalculatorContextManager::GetDefaultCalculatorContext() + const { + CHECK(default_context_.get()); + return default_context_.get(); +} + +CalculatorContext* CalculatorContextManager::GetFrontCalculatorContext( + Timestamp* context_input_timestamp) { + CHECK(calculator_run_in_parallel_); + absl::MutexLock lock(&contexts_mutex_); + CHECK(!active_contexts_.empty()); + *context_input_timestamp = active_contexts_.begin()->first; + return active_contexts_.begin()->second.get(); +} + +CalculatorContext* CalculatorContextManager::PrepareCalculatorContext( + Timestamp input_timestamp) { + if (!calculator_run_in_parallel_) { + return GetDefaultCalculatorContext(); + } + absl::MutexLock lock(&contexts_mutex_); + CHECK(!::mediapipe::ContainsKey(active_contexts_, input_timestamp)) + << "Multiple 
invocations with the same timestamps are not allowed with " + "parallel execution, input_timestamp = " + << input_timestamp; + CalculatorContext* calculator_context = nullptr; + if (idle_contexts_.empty()) { + auto new_context = absl::make_unique( + calculator_state_, input_tag_map_, output_tag_map_); + MEDIAPIPE_CHECK_OK(setup_shards_callback_(new_context.get())); + calculator_context = new_context.get(); + active_contexts_.emplace(input_timestamp, std::move(new_context)); + } else { + // Retrieves an inactive calculator context from idle_contexts_. + calculator_context = idle_contexts_.front().get(); + active_contexts_.emplace(input_timestamp, + std::move(idle_contexts_.front())); + idle_contexts_.pop_front(); + } + return calculator_context; +} + +void CalculatorContextManager::RecycleCalculatorContext() { + absl::MutexLock lock(&contexts_mutex_); + // The first element in active_contexts_ will be recycled. + auto iter = active_contexts_.begin(); + idle_contexts_.push_back(std::move(iter->second)); + active_contexts_.erase(iter); +} + +bool CalculatorContextManager::HasActiveContexts() { + if (!calculator_run_in_parallel_) { + return false; + } + absl::MutexLock lock(&contexts_mutex_); + return !active_contexts_.empty(); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_context_manager.h b/mediapipe/framework/calculator_context_manager.h new file mode 100644 index 000000000..1d93e797f --- /dev/null +++ b/mediapipe/framework/calculator_context_manager.h @@ -0,0 +1,146 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTEXT_MANAGER_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTEXT_MANAGER_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_state.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/tag_map.h" + +namespace mediapipe { + +// Calculator context manager owns and manages all calculator context objects of +// a calculator node. +class CalculatorContextManager { + public: + CalculatorContextManager() {} + + void Initialize(CalculatorState* calculator_state, + std::shared_ptr input_tag_map, + std::shared_ptr output_tag_map, + bool calculator_run_in_parallel); + + // Sets the callback that can setup the input and output stream shards in a + // newly constructed calculator context. Then, initializes the default + // calculator context. + ::mediapipe::Status PrepareForRun( + std::function<::mediapipe::Status(CalculatorContext*)> + setup_shards_callback); + + // Invoked by CalculatorNode::CleanupAfterRun(). + void CleanupAfterRun() LOCKS_EXCLUDED(contexts_mutex_); + + // Returns true if the default calculator context has been initialized. 
+ bool HasDefaultCalculatorContext() const { + return default_context_ != nullptr; + } + + // Returns a pointer to the default calculator context that is used for + // sequential execution. A source node should always reuse its default + // calculator context. + CalculatorContext* GetDefaultCalculatorContext() const; + + // Returns the context with the smallest input timestamp in active_contexts_. + // The input timestamp of the calculator context is returned in + // *context_input_timestamp. + CalculatorContext* GetFrontCalculatorContext( + Timestamp* context_input_timestamp) LOCKS_EXCLUDED(contexts_mutex_); + + // For sequential execution, returns a pointer to the default calculator + // context. For parallel execution, creates or reuses a calculator context, + // and inserts the calculator context with the given input timestamp into + // active_contexts_. Returns a pointer to the prepared calculator context. + // The ownership of the calculator context object isn't tranferred to the + // caller. + CalculatorContext* PrepareCalculatorContext(Timestamp input_timestamp) + LOCKS_EXCLUDED(contexts_mutex_); + + // Removes the context with the smallest input timestamp from active_contexts_ + // and moves the calculator context to idle_contexts_. The caller must + // guarantee that the output shards in the calculator context have been + // propagated before calling this function. + void RecycleCalculatorContext() LOCKS_EXCLUDED(contexts_mutex_); + + // Returns true if active_contexts_ is non-empty. 
+ bool HasActiveContexts() LOCKS_EXCLUDED(contexts_mutex_); + + int NumberOfContextTimestamps( + const CalculatorContext& calculator_context) const { + return calculator_context.NumberOfTimestamps(); + } + + bool ContextHasInputTimestamp( + const CalculatorContext& calculator_context) const { + return calculator_context.HasInputTimestamp(); + } + + void PushInputTimestampToContext(CalculatorContext* calculator_context, + Timestamp input_timestamp) { + CHECK(calculator_context); + calculator_context->PushInputTimestamp(input_timestamp); + } + + void PopInputTimestampFromContext(CalculatorContext* calculator_context) { + CHECK(calculator_context); + calculator_context->PopInputTimestamp(); + } + + void SetGraphStatusInContext(CalculatorContext* calculator_context, + const ::mediapipe::Status& status) { + CHECK(calculator_context); + calculator_context->SetGraphStatus(status); + } + + private: + CalculatorState* calculator_state_; + std::shared_ptr input_tag_map_; + std::shared_ptr output_tag_map_; + bool calculator_run_in_parallel_; + + // The callback to setup the input and output stream shards in a newly + // constructed calculator context. + // NOTE: This callback invokes input/output stream handler methods. + // The callback is used to break the circular dependency between + // calculator context manager and input/output stream handlers. + std::function<::mediapipe::Status(CalculatorContext*)> setup_shards_callback_; + + // The default calculator context that is always reused for sequential + // execution. It is also used by Open() and Close() method of a parallel + // calculator. + std::unique_ptr default_context_; + // The mutex for synchronizing the operations on active_contexts_ and + // idle_contexts_ during parallel execution. + absl::Mutex contexts_mutex_; + // A map from input timestamps to calculator contexts. + std::map> active_contexts_ + GUARDED_BY(contexts_mutex_); + // Idle calculator contexts that are ready for reuse. 
+ std::deque> idle_contexts_ + GUARDED_BY(contexts_mutex_); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTEXT_MANAGER_H_ diff --git a/mediapipe/framework/calculator_context_test.cc b/mediapipe/framework/calculator_context_test.cc new file mode 100644 index 000000000..044e10310 --- /dev/null +++ b/mediapipe/framework/calculator_context_test.cc @@ -0,0 +1,146 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_context.h" + +// TODO: Move protos in another CL after the C++ code migration. 
+#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_state.h" +#include "mediapipe/framework/output_stream_manager.h" +#include "mediapipe/framework/output_stream_shard.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/testdata/night_light_calculator.pb.h" +#include "mediapipe/framework/testdata/sky_light_calculator.pb.h" +#include "mediapipe/framework/tool/tag_map_helper.h" + +namespace mediapipe { + +namespace test_ns { + +std::string Proto3GraphStr() { + static std::string kProto3GraphStr = R"( + node { + calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values" + options { + [mediapipe.NightLightCalculatorOptions.ext] { + base_timestamp: 123 + output_header: PASS_HEADER + jitter: 0.123 + } + } + } + node { + calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values_also" + node_options: { + [type.googleapis.com/mediapipe.NightLightCalculatorOptions] { + base_timestamp: 123 + output_header: PASS_HEADER + jitter: 0.123 + } + } + } + node { + calculator: "SkyLightCalculator" + node_options: { + [type.googleapis.com/mediapipe.SkyLightCalculatorOptions] { + sky_color: "sky_blue" + } + } + } + node { + calculator: "SkyLightCalculator" + input_side_packet: "label" + input_stream: "values" + output_stream: "labelled_timestamps" + node_options: { + [type.googleapis.com/mediapipe.SkyLightCalculatorOptions] { + sky_color: "light_blue" + sky_grid: 2 + sky_grid: 4 + sky_grid: 8 + } + } + } + node { + calculator: "MakeVectorCalculator" + input_stream: "labelled_timestamps" + output_stream: "timestamp_vectors" + } + )"; + return kProto3GraphStr; +} + +std::unique_ptr 
MakeCalculatorState( + const CalculatorGraphConfig::Node& node_config, int node_id) { + auto result = absl::make_unique( + "Node", node_id, "Calculator", node_config, nullptr); + return result; +} + +std::unique_ptr MakeCalculatorContext( + CalculatorState* calculator_state) { + return absl::make_unique( + calculator_state, tool::CreateTagMap({}).ValueOrDie(), + tool::CreateTagMap({}).ValueOrDie()); +} + +TEST(CalculatorTest, NodeId) { + mediapipe::CalculatorGraphConfig config = + ParseTextProtoOrDie(Proto3GraphStr()); + + auto calculator_state_0 = MakeCalculatorState(config.node(0), 0); + auto cc_0 = MakeCalculatorContext(&*calculator_state_0); + auto calculator_state_1 = MakeCalculatorState(config.node(1), 1); + auto cc_1 = MakeCalculatorContext(&*calculator_state_1); + auto calculator_state_3 = MakeCalculatorState(config.node(3), 3); + auto cc_3 = MakeCalculatorContext(&*calculator_state_3); + + EXPECT_EQ(cc_0->NodeId(), calculator_state_0->NodeId()); + EXPECT_EQ(cc_1->NodeId(), calculator_state_1->NodeId()); + EXPECT_EQ(cc_3->NodeId(), calculator_state_3->NodeId()); +} + +TEST(CalculatorTest, GetOptions) { + mediapipe::CalculatorGraphConfig config = + ParseTextProtoOrDie(Proto3GraphStr()); + + auto calculator_state_0 = MakeCalculatorState(config.node(0), 0); + auto cc_0 = MakeCalculatorContext(&*calculator_state_0); + auto calculator_state_1 = MakeCalculatorState(config.node(1), 1); + auto cc_1 = MakeCalculatorContext(&*calculator_state_1); + auto calculator_state_3 = MakeCalculatorState(config.node(3), 3); + auto cc_3 = MakeCalculatorContext(&*calculator_state_3); + + // Get a proto2 options extension from Node::options. + EXPECT_EQ(cc_0->Options().jitter(), 0.123); + + // Get a proto2 options extension from Node::node_options. + EXPECT_EQ(cc_1->Options().jitter(), 0.123); + + // Get a proto3 options protobuf::Any from Node::node_options. 
+ EXPECT_EQ(cc_3->Options().sky_color(), + "light_blue"); +} + +} // namespace test_ns +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_contract.cc b/mediapipe/framework/calculator_contract.cc new file mode 100644 index 000000000..6e0e44749 --- /dev/null +++ b/mediapipe/framework/calculator_contract.cc @@ -0,0 +1,140 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_contract.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/framework/tool/tag_map.h" + +namespace mediapipe { + +::mediapipe::Status CalculatorContract::Initialize( + const CalculatorGraphConfig::Node& node) { + std::vector<::mediapipe::Status> statuses; + + auto input_stream_statusor = tool::TagMap::Create(node.input_stream()); + if (!input_stream_statusor.ok()) { + statuses.push_back(std::move(input_stream_statusor).status()); + } + auto output_stream_statusor = tool::TagMap::Create(node.output_stream()); + if (!output_stream_statusor.ok()) { + statuses.push_back(std::move(output_stream_statusor).status()); + } + auto input_side_packet_statusor = + tool::TagMap::Create(node.input_side_packet()); + if (!input_side_packet_statusor.ok()) { + statuses.push_back(std::move(input_side_packet_statusor).status()); + } + auto output_side_packet_statusor = + 
tool::TagMap::Create(node.output_side_packet()); + if (!output_side_packet_statusor.ok()) { + statuses.push_back(std::move(output_side_packet_statusor).status()); + } + + if (!statuses.empty()) { + auto builder = ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + << "Unable to initialize TagMaps for node."; + for (const auto& status : statuses) { + builder << "\n" << status.message(); + } +#if !(defined(MEDIAPIPE_LITE) || defined(MEDIAPIPE_MOBILE)) + builder << "\nFor calculator:\n"; + builder << node.DebugString(); +#endif // !(MEDIAPIPE_LITE || MEDIAPIPE_MOBILE) + return std::move(builder); + } + + node_config_ = &node; + options_.Initialize(*node_config_); + // Create the PacketTypeSets. + inputs_ = absl::make_unique( + std::move(input_stream_statusor).ValueOrDie()); + outputs_ = absl::make_unique( + std::move(output_stream_statusor).ValueOrDie()); + input_side_packets_ = absl::make_unique( + std::move(input_side_packet_statusor).ValueOrDie()); + output_side_packets_ = absl::make_unique( + std::move(output_side_packet_statusor).ValueOrDie()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorContract::Initialize( + const PacketGeneratorConfig& node) { + std::vector<::mediapipe::Status> statuses; + + auto input_side_packet_statusor = + tool::TagMap::Create(node.input_side_packet()); + if (!input_side_packet_statusor.ok()) { + statuses.push_back(std::move(input_side_packet_statusor).status()); + } + auto output_side_packet_statusor = + tool::TagMap::Create(node.output_side_packet()); + if (!output_side_packet_statusor.ok()) { + statuses.push_back(std::move(output_side_packet_statusor).status()); + } + + if (!statuses.empty()) { + auto builder = UnknownErrorBuilder(MEDIAPIPE_LOC) + << "NodeTypeInfo Initialization failed."; + for (const auto& status : statuses) { + builder << "\n" << status.message(); + } +#if !(defined(MEDIAPIPE_LITE) || defined(MEDIAPIPE_MOBILE)) + builder << "\nFor packet_generator:\n"; + builder << node.DebugString(); +#endif 
// !(MEDIAPIPE_LITE || MEDIAPIPE_MOBILE) + return std::move(builder); + } + + input_side_packets_ = absl::make_unique( + std::move(input_side_packet_statusor).ValueOrDie()); + output_side_packets_ = absl::make_unique( + std::move(output_side_packet_statusor).ValueOrDie()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorContract::Initialize( + const StatusHandlerConfig& node) { + std::vector<::mediapipe::Status> statuses; + + auto input_side_packet_statusor = + tool::TagMap::Create(node.input_side_packet()); + if (!input_side_packet_statusor.ok()) { + statuses.push_back(std::move(input_side_packet_statusor).status()); + } + + if (!statuses.empty()) { + auto builder = ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + << "NodeTypeInfo Initialization failed."; + for (const auto& status : statuses) { + builder << "\n" << status.message(); + } +#if !(defined(MEDIAPIPE_LITE) || defined(MEDIAPIPE_MOBILE)) + builder << "\nFor status_handler:\n"; + builder << node.DebugString(); +#endif // !(MEDIAPIPE_LITE || MEDIAPIPE_MOBILE) + return std::move(builder); + } + + input_side_packets_ = absl::make_unique( + std::move(input_side_packet_statusor).ValueOrDie()); + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_contract.h b/mediapipe/framework/calculator_contract.h new file mode 100644 index 000000000..692cf601f --- /dev/null +++ b/mediapipe/framework/calculator_contract.h @@ -0,0 +1,149 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTRACT_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTRACT_H_ + +#include +#include +#include +#include + +// TODO: Move protos in another CL after the C++ code migration. +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/mediapipe_options.pb.h" +#include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/any_proto.h" +#include "mediapipe/framework/status_handler.pb.h" +#include "mediapipe/framework/tool/options_util.h" + +namespace mediapipe { + +// CalculatorContract contains the expectations and properties of a Node +// object, such as the expected packet types of input and output streams and +// input and output side packets. +// +// Setters and getters are available for specifying an InputStreamHandler and +// it's options from inside a calculator's GetContract() method. Ex: +// cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); +// MediaPipeOptions options; +// options.MutableExtension(FixedSizeInputStreamHandlerOptions::ext) +// ->set_fixed_min_size(2); +// cc->SetInputStreamHandlerOptions(options); +// +class CalculatorContract { + public: + ::mediapipe::Status Initialize(const CalculatorGraphConfig::Node& node); + ::mediapipe::Status Initialize(const PacketGeneratorConfig& node); + ::mediapipe::Status Initialize(const StatusHandlerConfig& node); + + // Returns the options given to this node. + const CalculatorOptions& Options() const { return node_config_->options(); } + + // Returns the options given to this calculator. Template argument T must + // be the type of the protobuf extension message or the protobuf::Any + // message containing the options. 
+ template + const T& Options() const { + return options_.Get(); + } + + // Returns the PacketTypeSet for the input streams. + PacketTypeSet& Inputs() { return *inputs_; } + const PacketTypeSet& Inputs() const { return *inputs_; } + + // Returns the PacketTypeSet for the output streams. + PacketTypeSet& Outputs() { return *outputs_; } + const PacketTypeSet& Outputs() const { return *outputs_; } + + // Returns the PacketTypeSet for the input side packets. + PacketTypeSet& InputSidePackets() { return *input_side_packets_; } + const PacketTypeSet& InputSidePackets() const { return *input_side_packets_; } + + // Returns the PacketTypeSet for the output side packets. + PacketTypeSet& OutputSidePackets() { return *output_side_packets_; } + const PacketTypeSet& OutputSidePackets() const { + return *output_side_packets_; + } + + // Set this Node's default InputStreamHandler. + // If there is an InputStreamHandler specified in the graph (.pbtxt) for this + // Node, then the graph's InputStreamHandler will take priority. + void SetInputStreamHandler(const std::string& name) { + input_stream_handler_ = name; + } + void SetInputStreamHandlerOptions(const MediaPipeOptions& options) { + input_stream_handler_options_ = options; + } + + // Returns the name of this Nodes's InputStreamHandler, or empty std::string + // if none is set. + std::string GetInputStreamHandler() const { return input_stream_handler_; } + + // Returns the MediaPipeOptions of this Node's InputStreamHandler, or empty + // options if none is set. + MediaPipeOptions GetInputStreamHandlerOptions() const { + return input_stream_handler_options_; + } + + class GraphServiceRequest { + public: + // APIs that should be used by calculators. + GraphServiceRequest& Optional() { + optional_ = true; + return *this; + } + + // Internal use. 
+ GraphServiceRequest(const GraphServiceBase& service) : service_(service) {} + + const GraphServiceBase& Service() const { return service_; } + + bool IsOptional() const { return optional_; } + + private: + GraphServiceBase service_; + bool optional_ = false; + }; + + GraphServiceRequest& UseService(const GraphServiceBase& service) { + auto it = service_requests_.emplace(service.key, service).first; + return it->second; + } + + const std::map& ServiceRequests() const { + return service_requests_; + } + + private: + template + void GetNodeOptions(T* result) const; + + const CalculatorGraphConfig::Node* node_config_ = nullptr; + tool::OptionsMap options_; + std::unique_ptr inputs_; + std::unique_ptr outputs_; + std::unique_ptr input_side_packets_; + std::unique_ptr output_side_packets_; + std::string input_stream_handler_; + MediaPipeOptions input_stream_handler_options_; + std::map service_requests_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_CONTRACT_H_ diff --git a/mediapipe/framework/calculator_contract_test.cc b/mediapipe/framework/calculator_contract_test.cc new file mode 100644 index 000000000..c3eb5d6c7 --- /dev/null +++ b/mediapipe/framework/calculator_contract_test.cc @@ -0,0 +1,101 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_contract.h" + +// TODO: Move protos in another CL after the C++ code migration. 
+#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_contract_test.pb.h" +#include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/status_handler.pb.h" + +namespace mediapipe { + +namespace { + +TEST(CalculatorContractTest, Calculator) { + const CalculatorGraphConfig::Node node = + ::mediapipe::ParseTextProtoOrDie(R"( + calculator: "MixtureOfExpertsFusionCalculator" + input_stream: "FRAME:fdense_pca_moe_aggregated_detection" + input_stream: "FNET:fnet_logreg_aggregated_detection" + input_stream: "EGRAPH:egraph_segment_aggregated_detection" + input_stream: "VIDEO:fdense_averaged_pca_moe_v2_detection" + input_side_packet: "FUSION_MODEL:egraph_topical_packet_factory" + output_stream: "egraph_topical_detection" + )"); + CalculatorContract contract; + MEDIAPIPE_EXPECT_OK(contract.Initialize(node)); + EXPECT_EQ(contract.Inputs().NumEntries(), 4); + EXPECT_EQ(contract.Outputs().NumEntries(), 1); + EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1); + EXPECT_EQ(contract.OutputSidePackets().NumEntries(), 0); +} + +TEST(CalculatorContractTest, CalculatorOptions) { + const CalculatorGraphConfig::Node node = + ::mediapipe::ParseTextProtoOrDie(R"( + calculator: "CalculatorTestCalculator" + input_stream: "DATA:ycbcr_frames" + input_stream: "VIDEO_HEADER:ycbcr_frames_prestream" + output_stream: "DATA:ycbcr_downsampled" + output_stream: "VIDEO_HEADER:ycbcr_downsampled_prestream" + options { + [mediapipe.CalculatorContractTestOptions.ext] { test_field: 1.0 } + })"); + CalculatorContract contract; + MEDIAPIPE_EXPECT_OK(contract.Initialize(node)); + const auto& test_options = + contract.Options().GetExtension(CalculatorContractTestOptions::ext); + EXPECT_EQ(test_options.test_field(), 1.0); + 
EXPECT_EQ(contract.Inputs().NumEntries(), 2); + EXPECT_EQ(contract.Outputs().NumEntries(), 2); + EXPECT_EQ(contract.InputSidePackets().NumEntries(), 0); + EXPECT_EQ(contract.OutputSidePackets().NumEntries(), 0); +} + +TEST(CalculatorContractTest, PacketGenerator) { + const PacketGeneratorConfig node = + ::mediapipe::ParseTextProtoOrDie(R"( + packet_generator: "DaredevilLabeledTimeSeriesGenerator" + input_side_packet: "labeled_time_series" + output_side_packet: "time_series_header" + output_side_packet: "input_matrix" + output_side_packet: "label_set" + output_side_packet: "content_fingerprint" + )"); + CalculatorContract contract; + MEDIAPIPE_EXPECT_OK(contract.Initialize(node)); + EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1); + EXPECT_EQ(contract.OutputSidePackets().NumEntries(), 4); +} + +TEST(CalculatorContractTest, StatusHandler) { + const StatusHandlerConfig node = + ::mediapipe::ParseTextProtoOrDie(R"( + status_handler: "TaskInjectorStatusHandler" + input_side_packet: "ROW:cid" + input_side_packet: "SPEC:task_specification" + )"); + CalculatorContract contract; + MEDIAPIPE_EXPECT_OK(contract.Initialize(node)); + EXPECT_EQ(contract.InputSidePackets().NumEntries(), 2); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_contract_test.proto b/mediapipe/framework/calculator_contract_test.proto new file mode 100644 index 000000000..d86d99fca --- /dev/null +++ b/mediapipe/framework/calculator_contract_test.proto @@ -0,0 +1,30 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from mediapipe/framework/mediapipe_options.proto. +// The forked proto must remain identical to the original proto and should be +// ONLY used by mediapipe open source project. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message CalculatorContractTestOptions { + extend CalculatorOptions { + optional CalculatorContractTestOptions ext = 188754615; + } + optional double test_field = 1 [default = -1.0]; +} diff --git a/mediapipe/framework/calculator_framework.h b/mediapipe/framework/calculator_framework.h new file mode 100644 index 000000000..afb73fb30 --- /dev/null +++ b/mediapipe/framework/calculator_framework.h @@ -0,0 +1,76 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This header is used to include the core portions of the calculator +// framework. The comments that follow describe the main classes within +// the framework and how they interact. +// +// Calculator: A class which clients subclass to do actual work. 
+// It receives input and produces output which may go to many other +// Calculators connected in a directed acyclic graph. +// +// CalculatorGraph: A class which sets up a CalculatorGraphConfig and +// runs it. This is the controller class which governs the top level +// behavior of the framework and how things are run. +// +// CalculatorNode: A class which keeps track of a single Calculator and +// framework level details that the client does not need to worry about +// (such as how to advertise that the Calculator is blocked or unblocked). +// +// InputStream: A class which holds the next value in an input stream +// for a Calculator to use and provides access to the stream header. +// It is the superclass of InputStreamImpl which holds implementation +// details for the framework. +// +// InputStreamImpl: All information for the input stream. +// A CalculatorNode and OutputStreamImpl has access to this information, +// but the Calculator does not. +// +// OutputStream: A class which gets the output packets from a Calculator +// and relays them to the next calculators or the framework. +// +// OutputStreamImpl: The framework level information for an OutputStream. +// A CalculatorNode has access to this information but the Calculator +// does not. +// +// CalculatorState: Data class to hold information the Calculator needs +// access to. This data persists across multiple runs of the graph, +// whereas the Calculators will be destroyed and recreated. 
+ +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_FRAMEWORK_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_FRAMEWORK_H_ + +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_graph.h" +#include "mediapipe/framework/calculator_registry.h" +#include "mediapipe/framework/counter_factory.h" +#include "mediapipe/framework/input_stream.h" +#include "mediapipe/framework/output_side_packet.h" +#include "mediapipe/framework/output_stream.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_generator.h" +#include "mediapipe/framework/packet_generator_graph.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/status_handler.h" +#include "mediapipe/framework/subgraph.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/sink.h" +#include "mediapipe/framework/tool/status_util.h" +#include "mediapipe/framework/tool/validate.h" +#include "mediapipe/framework/tool/validate_name.h" +#include "mediapipe/framework/validated_graph_config.h" + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_FRAMEWORK_H_ diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc new file mode 100644 index 000000000..265c8ad60 --- /dev/null +++ b/mediapipe/framework/calculator_graph.cc @@ -0,0 +1,1306 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_graph.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/container/fixed_array.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/counter_factory.h" +#include "mediapipe/framework/delegating_executor.h" +#include "mediapipe/framework/input_stream_manager.h" +#include "mediapipe/framework/mediapipe_profiling.h" +#include "mediapipe/framework/packet_generator.h" +#include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/core_proto_inc.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/source_location.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/framework/status_handler.h" +#include "mediapipe/framework/status_handler.pb.h" +#include "mediapipe/framework/thread_pool_executor.h" +#include "mediapipe/framework/thread_pool_executor.pb.h" +#include "mediapipe/framework/tool/fill_packet_set.h" +#include "mediapipe/framework/tool/status_util.h" +#include "mediapipe/framework/tool/tag_map.h" +#include "mediapipe/framework/tool/validate.h" +#include "mediapipe/framework/tool/validate_name.h" +#include "mediapipe/framework/validated_graph_config.h" +#include 
"mediapipe/gpu/graph_support.h" +#include "mediapipe/util/cpu_util.h" +#ifndef MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_shared_data_internal.h" +#endif // !defined(MEDIAPIPE_DISABLE_GPU) + +namespace mediapipe { + +namespace { + +// Forcefully terminates the framework when the number of errors exceeds this +// threshold. +constexpr int kMaxNumAccumulatedErrors = 1000; +constexpr char kApplicationThreadExecutorType[] = "ApplicationThreadExecutor"; + +} // namespace + +void CalculatorGraph::ScheduleAllOpenableNodes() { + // This method can only be called before the scheduler_.Start() call and the + // graph input streams' SetHeader() calls because it is safe to call + // node->ReadyForOpen() only before any node or graph input stream has + // propagated header packets or generated output side packets, either of + // which may cause a downstream node to be scheduled for OpenNode(). + for (CalculatorNode& node : *nodes_) { + if (node.ReadyForOpen()) { + scheduler_.ScheduleNodeForOpen(&node); + } + } +} + +void CalculatorGraph::GraphInputStream::SetHeader(const Packet& header) { + shard_.SetHeader(header); + manager_->PropagateHeader(); + manager_->LockIntroData(); +} + +void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() { + // Since GraphInputStream doesn't allow SetOffset() and + // SetNextTimestampBound(), the timestamp bound to propagate is only + // determined by the timestamp of the output packets. 
+ CHECK(!shard_.IsEmpty()) << "Shard with name \"" << manager_->Name() + << "\" failed"; + manager_->PropagateUpdatesToMirrors( + shard_.LastAddedPacketTimestamp().NextAllowedInStream(), &shard_); +} + +void CalculatorGraph::GraphInputStream::Close() { + if (!shard_.IsEmpty()) { + manager_->PropagateUpdatesToMirrors(Timestamp::Done(), &shard_); + } + manager_->Close(); +} + +CalculatorGraph::CalculatorGraph() + : profiler_(std::make_shared()), scheduler_(this) { + counter_factory_ = absl::make_unique(); +} + +CalculatorGraph::CalculatorGraph(const CalculatorGraphConfig& config) + : CalculatorGraph() { + counter_factory_ = absl::make_unique(); + MEDIAPIPE_CHECK_OK(Initialize(config)); +} + +// Defining the destructor here lets us use incomplete types in the header; +// they only need to be fully visible here, where their destructor is +// instantiated. +CalculatorGraph::~CalculatorGraph() {} + +::mediapipe::Status CalculatorGraph::InitializePacketGeneratorGraph( + const std::map& side_packets) { + // Create and initialize the output side packets. + if (!validated_graph_->OutputSidePacketInfos().empty()) { + output_side_packets_ = absl::make_unique( + validated_graph_->OutputSidePacketInfos().size()); + } + for (int index = 0; index < validated_graph_->OutputSidePacketInfos().size(); + ++index) { + const EdgeInfo& edge_info = + validated_graph_->OutputSidePacketInfos()[index]; + RETURN_IF_ERROR(output_side_packets_[index].Initialize( + edge_info.name, edge_info.packet_type)); + } + + // If use_application_thread_ is true, the default executor is a + // DelegatingExecutor. This DelegatingExecutor is tightly coupled to + // scheduler_ and therefore cannot be used by packet_generator_graph_. 
+ Executor* default_executor = nullptr; + if (!use_application_thread_) { + default_executor = executors_[""].get(); + CHECK(default_executor); + } + // If default_executor is nullptr, then packet_generator_graph_ will create + // its own DelegatingExecutor to use the application thread. + return packet_generator_graph_.Initialize(validated_graph_.get(), + default_executor, side_packets); +} + +::mediapipe::Status CalculatorGraph::InitializeStreams() { + any_packet_type_.SetAny(); + + // Create and initialize the input streams. + input_stream_managers_ = absl::make_unique( + validated_graph_->InputStreamInfos().size()); + for (int index = 0; index < validated_graph_->InputStreamInfos().size(); + ++index) { + const EdgeInfo& edge_info = validated_graph_->InputStreamInfos()[index]; + RETURN_IF_ERROR(input_stream_managers_[index].Initialize( + edge_info.name, edge_info.packet_type, edge_info.back_edge)); + } + + // Create and initialize the output streams. + output_stream_managers_ = absl::make_unique( + validated_graph_->OutputStreamInfos().size()); + for (int index = 0; index < validated_graph_->OutputStreamInfos().size(); + ++index) { + const EdgeInfo& edge_info = validated_graph_->OutputStreamInfos()[index]; + RETURN_IF_ERROR(output_stream_managers_[index].Initialize( + edge_info.name, edge_info.packet_type)); + } + + // Initialize GraphInputStreams. 
+ int graph_input_stream_count = 0; + ASSIGN_OR_RETURN( + auto input_tag_map, + tool::TagMap::Create(validated_graph_->Config().input_stream())); + for (const auto& stream_name : input_tag_map->Names()) { + RET_CHECK(!::mediapipe::ContainsKey(graph_input_streams_, stream_name)) + .SetNoLogging() + << "CalculatorGraph Initialization failed, graph input stream \"" + << stream_name << "\" was specified twice."; + int output_stream_index = validated_graph_->OutputStreamIndex(stream_name); + RET_CHECK_LE(0, output_stream_index).SetNoLogging(); + const EdgeInfo& edge_info = + validated_graph_->OutputStreamInfos()[output_stream_index]; + RET_CHECK(NodeTypeInfo::NodeType::GRAPH_INPUT_STREAM == + edge_info.parent_node.type) + .SetNoLogging(); + + graph_input_streams_[stream_name] = absl::make_unique( + &output_stream_managers_[output_stream_index]); + + // Assign a virtual node ID to each graph input stream so we can treat + // these as regular nodes for throttling. + graph_input_stream_node_ids_[stream_name] = + validated_graph_->CalculatorInfos().size() + graph_input_stream_count; + ++graph_input_stream_count; + } + + // Set the default mode for graph input streams. + { + absl::MutexLock lock(&full_input_streams_mutex_); + graph_input_stream_add_mode_ = GraphInputStreamAddMode::WAIT_TILL_NOT_FULL; + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::InitializeCalculatorNodes() { + // Check if the user has specified a maximum queue size for an input stream. + max_queue_size_ = validated_graph_->Config().max_queue_size(); + max_queue_size_ = max_queue_size_ ? max_queue_size_ : 100; + + // Use a local variable to avoid needing to lock errors_. + std::vector<::mediapipe::Status> errors; + + // Create and initialize all the nodes in the graph. 
+ nodes_ = absl::make_unique>( + validated_graph_->CalculatorInfos().size()); + for (int node_id = 0; node_id < validated_graph_->CalculatorInfos().size(); + ++node_id) { + // buffer_size_hint will be positive if one was specified in + // the graph proto. + int buffer_size_hint = 0; + const ::mediapipe::Status result = (*nodes_)[node_id].Initialize( + validated_graph_.get(), node_id, input_stream_managers_.get(), + output_stream_managers_.get(), output_side_packets_.get(), + &buffer_size_hint, profiler_); + if (buffer_size_hint > 0) { + max_queue_size_ = std::max(max_queue_size_, buffer_size_hint); + } + if (!result.ok()) { + // Collect as many errors as we can before failing. + errors.push_back(result); + } + } + if (!errors.empty()) { + return tool::CombinedStatus( + "CalculatorGraph::InitializeCalculatorNodes failed: ", errors); + } + + VLOG(2) << "Maximum input stream queue size based on graph config: " + << max_queue_size_; + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::InitializeProfiler() { + profiler_->Initialize(*validated_graph_); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::InitializeExecutors() { + // If the ExecutorConfig for the default executor leaves the executor type + // unspecified, default_executor_options points to the + // ThreadPoolExecutorOptions in that ExecutorConfig. Otherwise, + // default_executor_options is null. 
+ const ThreadPoolExecutorOptions* default_executor_options = nullptr; + bool use_application_thread = false; + for (const ExecutorConfig& executor_config : + validated_graph_->Config().executor()) { + if (::mediapipe::ContainsKey(executors_, executor_config.name())) { + if (!executor_config.type().empty()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "ExecutorConfig for \"" << executor_config.name() + << "\" has a \"type\" field but is also provided to the graph " + "with a CalculatorGraph::SetExecutor() call."; + } + continue; + } + if (executor_config.name().empty()) { + // Executor name "" refers to the default executor. + if (executor_config.type().empty()) { + // For the default executor, an unspecified type means letting the + // framework choose an appropriate executor type. + default_executor_options = &executor_config.options().GetExtension( + ThreadPoolExecutorOptions::ext); + continue; + } + if (executor_config.type() == kApplicationThreadExecutorType) { + // For the default executor, the type "ApplicationThreadExecutor" means + // running synchronously on the calling thread. + use_application_thread = true; + continue; + } + } + if (executor_config.type().empty()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "ExecutorConfig for \"" << executor_config.name() + << "\" does not have a \"type\" field. 
The executor \"" + << executor_config.name() + << "\" must be provided to the graph with a " + "CalculatorGraph::SetExecutor() call."; + } + // clang-format off + ASSIGN_OR_RETURN(Executor* executor, + ExecutorRegistry::CreateByNameInNamespace( + validated_graph_->Package(), + executor_config.type(), executor_config.options())); + // clang-format on + MEDIAPIPE_CHECK_OK(SetExecutorInternal( + executor_config.name(), std::shared_ptr(executor))); + } + + if (!::mediapipe::ContainsKey(executors_, "")) { + RETURN_IF_ERROR(InitializeDefaultExecutor(*default_executor_options, + use_application_thread)); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::InitializeDefaultExecutor( + const ThreadPoolExecutorOptions& default_executor_options, + bool use_application_thread) { + // If specified, run synchronously on the calling thread. + if (use_application_thread) { + use_application_thread_ = true; + MEDIAPIPE_CHECK_OK(SetExecutorInternal( + "", std::make_shared( + std::bind(&internal::Scheduler::AddApplicationThreadTask, + &scheduler_, std::placeholders::_1)))); + return ::mediapipe::OkStatus(); + } + + // Check the number of threads specified in the proto. + int num_threads = default_executor_options.num_threads(); + + // If the default (0 or -1) was specified, pick a suitable number of threads + // depending on the number of processors in this system and the number of + // calculators and packet generators in the calculator graph. 
+ if (num_threads == 0 || num_threads == -1) { + num_threads = std::min( + mediapipe::NumCPUCores(), + std::max({validated_graph_->Config().node().size(), + validated_graph_->Config().packet_generator().size(), 1})); + } + RETURN_IF_ERROR( + CreateDefaultThreadPool(default_executor_options, num_threads)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::Initialize( + std::unique_ptr validated_graph, + const std::map& side_packets) { + RET_CHECK(!initialized_).SetNoLogging() + << "CalculatorGraph can be initialized only once."; + RET_CHECK(validated_graph->Initialized()).SetNoLogging() + << "validated_graph is not initialized."; + validated_graph_ = std::move(validated_graph); + + RETURN_IF_ERROR(InitializeExecutors()); + RETURN_IF_ERROR(InitializePacketGeneratorGraph(side_packets)); + RETURN_IF_ERROR(InitializeStreams()); + RETURN_IF_ERROR(InitializeCalculatorNodes()); +#ifdef MEDIAPIPE_PROFILER_AVAILABLE + RETURN_IF_ERROR(InitializeProfiler()); +#endif + + initialized_ = true; + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::Initialize( + const CalculatorGraphConfig& input_config) { + return Initialize(input_config, {}); +} + +::mediapipe::Status CalculatorGraph::Initialize( + const CalculatorGraphConfig& input_config, + const std::map& side_packets) { + auto validated_graph = absl::make_unique(); + RETURN_IF_ERROR(validated_graph->Initialize(input_config)); + return Initialize(std::move(validated_graph), side_packets); +} + +::mediapipe::Status CalculatorGraph::Initialize( + const std::vector& input_configs, + const std::vector& input_templates, + const std::map& side_packets, + const std::string& graph_type, const Subgraph::SubgraphOptions* options) { + auto validated_graph = absl::make_unique(); + RETURN_IF_ERROR(validated_graph->Initialize(input_configs, input_templates, + graph_type, options)); + return Initialize(std::move(validated_graph), side_packets); +} + +::mediapipe::Status 
CalculatorGraph::ObserveOutputStream( + const std::string& stream_name, + std::function<::mediapipe::Status(const Packet&)> packet_callback) { + RET_CHECK(initialized_).SetNoLogging() + << "CalculatorGraph is not initialized."; + // TODO Allow output observers to be attached by graph level + // tag/index. + int output_stream_index = validated_graph_->OutputStreamIndex(stream_name); + if (output_stream_index < 0) { + return ::mediapipe::NotFoundErrorBuilder(MEDIAPIPE_LOC) + << "Unable to attach observer to output stream \"" << stream_name + << "\" because it doesn't exist."; + } + auto observer = absl::make_unique(); + RETURN_IF_ERROR(observer->Initialize( + stream_name, &any_packet_type_, std::move(packet_callback), + &output_stream_managers_[output_stream_index])); + graph_output_streams_.push_back(std::move(observer)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::StatusOr +CalculatorGraph::AddOutputStreamPoller(const std::string& stream_name) { + RET_CHECK(initialized_).SetNoLogging() + << "CalculatorGraph is not initialized."; + int output_stream_index = validated_graph_->OutputStreamIndex(stream_name); + if (output_stream_index < 0) { + return ::mediapipe::NotFoundErrorBuilder(MEDIAPIPE_LOC) + << "Unable to attach observer to output stream \"" << stream_name + << "\" because it doesn't exist."; + } + auto internal_poller = std::make_shared(); + RETURN_IF_ERROR(internal_poller->Initialize( + stream_name, &any_packet_type_, + std::bind(&CalculatorGraph::UpdateThrottledNodes, this, + std::placeholders::_1, std::placeholders::_2), + &output_stream_managers_[output_stream_index])); + OutputStreamPoller poller(internal_poller); + graph_output_streams_.push_back(std::move(internal_poller)); + return std::move(poller); +} + +::mediapipe::StatusOr CalculatorGraph::GetOutputSidePacket( + const std::string& packet_name) { + int side_packet_index = validated_graph_->OutputSidePacketIndex(packet_name); + if (side_packet_index < 0) { + return 
::mediapipe::NotFoundErrorBuilder(MEDIAPIPE_LOC) + << "Unable to get the output side packet \"" << packet_name + << "\" because it doesn't exist."; + } + Packet output_packet; + if (scheduler_.IsTerminated()) { + // Side-packets from calculators can be retrieved only after the graph is + // done. + output_packet = output_side_packets_[side_packet_index].GetPacket(); + } + if (output_packet.IsEmpty()) { + // See if it exists in the base packets that come from PacketGenerators. + // TODO: Update/remove this after b/119671096 is resolved. + auto base_packets = packet_generator_graph_.BasePackets(); + auto base_packet_iter = base_packets.find(packet_name); + auto current_run_side_packet_iter = + current_run_side_packets_.find(packet_name); + if (base_packet_iter != base_packets.end() && + !base_packet_iter->second.IsEmpty()) { + output_packet = base_packet_iter->second; + } else if (current_run_side_packet_iter != + current_run_side_packets_.end() && + !current_run_side_packet_iter->second.IsEmpty()) { + output_packet = current_run_side_packet_iter->second; + } else { + return ::mediapipe::UnavailableErrorBuilder(MEDIAPIPE_LOC) + << "The output side packet \"" << packet_name + << "\" is unavailable."; + } + } + return output_packet; +} + +::mediapipe::Status CalculatorGraph::Run( + const std::map& extra_side_packets) { + RET_CHECK(graph_input_streams_.empty()).SetNoLogging() + << "When using graph input streams, call StartRun() instead of Run() so " + "that AddPacketToInputStream() and CloseInputStream() can be called."; + RETURN_IF_ERROR(StartRun(extra_side_packets, {})); + return WaitUntilDone(); +} + +::mediapipe::Status CalculatorGraph::StartRun( + const std::map& extra_side_packets, + const std::map& stream_headers) { + RET_CHECK(initialized_).SetNoLogging() + << "CalculatorGraph is not initialized."; + RETURN_IF_ERROR(PrepareForRun(extra_side_packets, stream_headers)); + RETURN_IF_ERROR(profiler_->Start(executors_[""].get())); + scheduler_.Start(); + return 
::mediapipe::OkStatus(); +} + +#ifndef MEDIAPIPE_DISABLE_GPU +::mediapipe::Status CalculatorGraph::SetGpuResources( + std::shared_ptr<::mediapipe::GpuResources> resources) { + RET_CHECK(!ContainsKey(service_packets_, kGpuService.key)) + << "The GPU resources have already been configured."; + service_packets_[kGpuService.key] = + MakePacket>( + std::move(resources)); + return ::mediapipe::OkStatus(); +} + +std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources() + const { + auto service_iter = service_packets_.find(kGpuService.key); + if (service_iter == service_packets_.end()) return nullptr; + return service_iter->second.Get>(); +} + +::mediapipe::StatusOr> +CalculatorGraph::PrepareGpu(const std::map& side_packets) { + std::map additional_side_packets; + bool update_sp = false; + bool uses_gpu = false; + for (const auto& node : *nodes_) { + if (node.UsesGpu()) { + uses_gpu = true; + break; + } + } + if (uses_gpu) { + auto service_iter = service_packets_.find(kGpuService.key); + bool has_service = service_iter != service_packets_.end(); + + auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName); + // Workaround for b/116875321: CalculatorRunner provides an empty packet, + // instead of just leaving it unset. 
+ bool has_legacy_sp = legacy_sp_iter != side_packets.end() && + !legacy_sp_iter->second.IsEmpty(); + + std::shared_ptr<::mediapipe::GpuResources> gpu_resources; + if (has_service) { + if (has_legacy_sp) { + LOG(WARNING) + << "::mediapipe::GpuSharedData provided as a side packet while the " + << "graph already had one; ignoring side packet"; + } + gpu_resources = service_iter->second + .Get>(); + update_sp = true; + } else { + if (has_legacy_sp) { + gpu_resources = + legacy_sp_iter->second.Get<::mediapipe::GpuSharedData*>() + ->gpu_resources; + } else { + ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create()); + update_sp = true; + } + service_packets_[kGpuService.key] = + MakePacket>(gpu_resources); + } + + // Create or replace the legacy side packet if needed. + if (update_sp) { + legacy_gpu_shared_.reset(new ::mediapipe::GpuSharedData(gpu_resources)); + additional_side_packets[kGpuSharedSidePacketName] = + MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get()); + } + + // Set up executors. 
+ for (auto& node : *nodes_) { + if (node.UsesGpu()) { + gpu_resources->PrepareGpuNode(&node); + } + } + for (const auto& name_executor : gpu_resources->GetGpuExecutors()) { + RETURN_IF_ERROR( + SetExecutorInternal(name_executor.first, name_executor.second)); + } + } + return additional_side_packets; +} +#endif // !defined(MEDIAPIPE_DISABLE_GPU) + +::mediapipe::Status CalculatorGraph::PrepareForRun( + const std::map& extra_side_packets, + const std::map& stream_headers) { + if (VLOG_IS_ON(1)) { + for (const auto& item : extra_side_packets) { + VLOG(1) << "Adding extra_side_packet with name: " << item.first; + } + } + + { + absl::MutexLock lock(&error_mutex_); + errors_.clear(); + has_error_ = false; + } + num_closed_graph_input_streams_ = 0; + + std::map additional_side_packets; +#ifndef MEDIAPIPE_DISABLE_GPU + ASSIGN_OR_RETURN(additional_side_packets, PrepareGpu(extra_side_packets)); +#endif // !defined(MEDIAPIPE_DISABLE_GPU) + + const std::map* input_side_packets; + if (!additional_side_packets.empty()) { + additional_side_packets.insert(extra_side_packets.begin(), + extra_side_packets.end()); + input_side_packets = &additional_side_packets; + } else { + input_side_packets = &extra_side_packets; + } + + current_run_side_packets_.clear(); + ::mediapipe::Status generator_status = packet_generator_graph_.RunGraphSetup( + *input_side_packets, ¤t_run_side_packets_); + + CallStatusHandlers(GraphRunState::PRE_RUN, generator_status); + + if (!generator_status.ok()) { + return generator_status; + } + + // If there was an error on the CallStatusHandlers (PRE_RUN), it was stored + // in the error list. We return immediately notifying this to the caller. 
+ ::mediapipe::Status error_status; + if (has_error_) { + GetCombinedErrors(&error_status); + LOG(ERROR) << error_status; + return error_status; + } + + if (VLOG_IS_ON(1)) { + std::vector input_side_packet_names; + for (const auto& item : current_run_side_packets_) { + input_side_packet_names.push_back(item.first); + } + VLOG(1) << "Final input side packet names are: " + << absl::StrJoin(input_side_packet_names, ","); + } + + Executor* default_executor = nullptr; + if (!use_application_thread_) { + default_executor = executors_[""].get(); + RET_CHECK(default_executor); + } + scheduler_.Reset(); + + { + absl::MutexLock lock(&full_input_streams_mutex_); + // Initialize a count per source node to store the number of input streams + // that are full and are affected by the source node. A node is considered + // to be throttled if the count corresponding to this node is non-zero. + // i.e. there is at least one affected stream which is full. We treat the + // graph input streams as nodes because they might need to be throttled. 
+ full_input_streams_.clear(); + full_input_streams_.resize(validated_graph_->CalculatorInfos().size() + + graph_input_streams_.size()); + } + + for (auto& item : graph_input_streams_) { + item.second->PrepareForRun( + std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1)); + } + for (int index = 0; index < validated_graph_->OutputSidePacketInfos().size(); + ++index) { + output_side_packets_[index].PrepareForRun( + std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1)); + } + for (CalculatorNode& node : *nodes_) { + InputStreamManager::QueueSizeCallback queue_size_callback = + std::bind(&CalculatorGraph::UpdateThrottledNodes, this, + std::placeholders::_1, std::placeholders::_2); + node.SetQueueSizeCallbacks(queue_size_callback, queue_size_callback); + scheduler_.AssignNodeToSchedulerQueue(&node); + const ::mediapipe::Status result = node.PrepareForRun( + current_run_side_packets_, service_packets_, + std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_, + &node), + std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_, + &node), + std::bind(&internal::Scheduler::ScheduleNodeIfNotThrottled, &scheduler_, + &node, std::placeholders::_1), + std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1), + counter_factory_.get()); + if (!result.ok()) { + // Collect as many errors as we can before failing. + RecordError(result); + } + } + for (auto& graph_output_stream : graph_output_streams_) { + graph_output_stream->PrepareForRun( + [&graph_output_stream, this] { + ::mediapipe::Status status = graph_output_stream->Notify(); + if (!status.ok()) { + RecordError(status); + } + scheduler_.EmittedObservedOutput(); + }, + [this](::mediapipe::Status status) { RecordError(status); }); + } + + if (GetCombinedErrors(&error_status)) { + LOG(ERROR) << error_status; + CleanupAfterRun(&error_status); + return error_status; + } + + // Ensure that the latest value of max queue size is passed to all input + // streams. 
+ for (auto& node : *nodes_) { + node.SetMaxInputStreamQueueSize(max_queue_size_); + } + + // Allow graph input streams to override the global max queue size. + for (const auto& name_max : graph_input_stream_max_queue_size_) { + std::unique_ptr* stream = + ::mediapipe::FindOrNull(graph_input_streams_, name_max.first); + RET_CHECK(stream).SetNoLogging() << absl::Substitute( + "SetInputStreamMaxQueueSize called on \"$0\" which is not a " + "graph input stream.", + name_max.first); + (*stream)->SetMaxQueueSize(name_max.second); + } + + for (CalculatorNode& node : *nodes_) { + if (node.IsSource()) { + scheduler_.AddUnopenedSourceNode(&node); + has_sources_ = true; + } + } + + VLOG(2) << "Opening calculators."; + // Open the calculators. + ScheduleAllOpenableNodes(); + + // Header has to be set after the above preparation, since the header is + // propagated to the connected streams. In addition, setting the header + // packet may make a node ready for OpenNode(), and we should not schedule + // OpenNode() before the ScheduleAllOpenableNodes() call. + for (auto& item : graph_input_streams_) { + auto header = stream_headers.find(item.first); + if (header != stream_headers.end()) { + item.second->SetHeader(header->second); + } else { + // SetHeader() not only sets the header but also propagates it to the + // mirrors. Propagate the header to mirrors even if the header is empty + // to inform mirrors that they can proceed. 
+ item.second->SetHeader(Packet()); + } + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::WaitUntilIdle() { + if (has_sources_) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "WaitUntilIdle called on a graph with source nodes."; + } + RETURN_IF_ERROR(scheduler_.WaitUntilIdle()); + VLOG(2) << "Scheduler idle."; + ::mediapipe::Status status = ::mediapipe::OkStatus(); + if (GetCombinedErrors(&status)) { + LOG(ERROR) << status; + } + return status; +} + +::mediapipe::Status CalculatorGraph::WaitUntilDone() { + VLOG(2) << "Waiting for scheduler to terminate..."; + RETURN_IF_ERROR(scheduler_.WaitUntilDone()); + VLOG(2) << "Scheduler terminated."; + + return FinishRun(); +} + +::mediapipe::Status CalculatorGraph::WaitForObservedOutput() { + return scheduler_.WaitForObservedOutput(); +} + +::mediapipe::Status CalculatorGraph::AddPacketToInputStream( + const std::string& stream_name, const Packet& packet) { + return AddPacketToInputStreamInternal(stream_name, packet); +} + +::mediapipe::Status CalculatorGraph::AddPacketToInputStream( + const std::string& stream_name, Packet&& packet) { + return AddPacketToInputStreamInternal(stream_name, std::move(packet)); +} + +// We avoid having two copies of this code for AddPacketToInputStream( +// const Packet&) and AddPacketToInputStream(Packet &&) by having this +// internal-only templated version. T&& is a forwarding reference here, so +// std::forward will deduce the correct type as we pass along packet. 
+template +::mediapipe::Status CalculatorGraph::AddPacketToInputStreamInternal( + const std::string& stream_name, T&& packet) { + std::unique_ptr* stream = + ::mediapipe::FindOrNull(graph_input_streams_, stream_name); + RET_CHECK(stream).SetNoLogging() << absl::Substitute( + "AddPacketToInputStream called on input stream \"$0\" which is not a " + "graph input stream.", + stream_name); + int node_id = + ::mediapipe::FindOrDie(graph_input_stream_node_ids_, stream_name); + CHECK_GE(node_id, validated_graph_->CalculatorInfos().size()); + { + absl::MutexLock lock(&full_input_streams_mutex_); + if (graph_input_stream_add_mode_ == + GraphInputStreamAddMode::ADD_IF_NOT_FULL) { + if (has_error_) { + ::mediapipe::Status error_status; + GetCombinedErrors("Graph has errors: ", &error_status); + return error_status; + } + // Return with StatusUnavailable if this stream is being throttled. + if (!full_input_streams_[node_id].empty()) { + return ::mediapipe::UnavailableErrorBuilder(MEDIAPIPE_LOC) + << "Graph is throttled."; + } + } else if (graph_input_stream_add_mode_ == + GraphInputStreamAddMode::WAIT_TILL_NOT_FULL) { + // Wait until this stream is not being throttled. + // TODO: instead of checking has_error_, we could just check + // if the graph is done. That could also be indicated by returning an + // error from WaitUntilGraphInputStreamUnthrottled. + while (!has_error_ && !full_input_streams_[node_id].empty()) { + // TODO: allow waiting for a specific stream? + scheduler_.WaitUntilGraphInputStreamUnthrottled( + &full_input_streams_mutex_); + } + if (has_error_) { + ::mediapipe::Status error_status; + GetCombinedErrors("Graph has errors: ", &error_status); + return error_status; + } + } + } + + // Adding profiling info for a new packet entering the graph. 
+ const std::string* stream_id = &(*stream)->GetManager()->Name(); + profiler_->LogEvent(TraceEvent(TraceEvent::PROCESS) + .set_is_finish(true) + .set_input_ts(packet.Timestamp()) + .set_stream_id(stream_id) + .set_packet_ts(packet.Timestamp()) + .set_packet_data_id(&packet)); + + // InputStreamManager is thread safe. GraphInputStream is not, so this method + // should not be called by multiple threads concurrently. Note that this could + // potentially lead to the max queue size being exceeded by one packet at most + // because we don't have the lock over the input stream. + (*stream)->AddPacket(std::forward(packet)); + if (has_error_) { + ::mediapipe::Status error_status; + GetCombinedErrors("Graph has errors: ", &error_status); + return error_status; + } + (*stream)->PropagateUpdatesToMirrors(); + + VLOG(2) << "Packet added directly to: " << stream_name; + // Note: one reason why we need to call the scheduler here is that we have + // re-throttled the graph input streams, and we may need to unthrottle them + // again if the graph is still idle. Unthrottling basically only lets in one + // packet at a time. TODO: add test. + scheduler_.AddedPacketToGraphInputStream(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::SetInputStreamMaxQueueSize( + const std::string& stream_name, int max_queue_size) { + // graph_input_streams_ has not been filled in yet, so we'll check this when + // it is applied when the graph is started. 
+ graph_input_stream_max_queue_size_[stream_name] = max_queue_size; + return ::mediapipe::OkStatus(); +} + +bool CalculatorGraph::HasInputStream(const std::string& stream_name) { + return ::mediapipe::FindOrNull(graph_input_streams_, stream_name) != nullptr; +} + +::mediapipe::Status CalculatorGraph::CloseInputStream( + const std::string& stream_name) { + std::unique_ptr* stream = + ::mediapipe::FindOrNull(graph_input_streams_, stream_name); + RET_CHECK(stream).SetNoLogging() << absl::Substitute( + "CloseInputStream called on input stream \"$0\" which is not a graph " + "input stream.", + stream_name); + // The following IsClosed() and Close() sequence is not atomic. Multiple + // threads cannot call CloseInputStream() on the same stream_name at the same + // time. + if ((*stream)->IsClosed()) { + return ::mediapipe::OkStatus(); + } + + (*stream)->Close(); + + if (++num_closed_graph_input_streams_ == graph_input_streams_.size()) { + scheduler_.ClosedAllGraphInputStreams(); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::CloseAllInputStreams() { + for (auto& item : graph_input_streams_) { + item.second->Close(); + } + + num_closed_graph_input_streams_ = graph_input_streams_.size(); + scheduler_.ClosedAllGraphInputStreams(); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::CloseAllPacketSources() { + for (auto& item : graph_input_streams_) { + item.second->Close(); + } + + num_closed_graph_input_streams_ = graph_input_streams_.size(); + scheduler_.ClosedAllGraphInputStreams(); + scheduler_.CloseAllSourceNodes(); + + return ::mediapipe::OkStatus(); +} + +void CalculatorGraph::RecordError(const ::mediapipe::Status& error) { + VLOG(2) << "RecordError called with " << error; + { + absl::MutexLock lock(&error_mutex_); + errors_.push_back(error); + has_error_ = true; + scheduler_.SetHasError(true); + for (const auto& stream : graph_output_streams_) { + stream->NotifyError(); + } + if (errors_.size() > 
kMaxNumAccumulatedErrors) { + for (const ::mediapipe::Status& error : errors_) { + LOG(ERROR) << error; + } + LOG(FATAL) << "Forcefully aborting to prevent the framework running out " + "of memory."; + } + } +} + +bool CalculatorGraph::GetCombinedErrors(::mediapipe::Status* error_status) { + return GetCombinedErrors("CalculatorGraph::Run() failed in Run: ", + error_status); +} + +bool CalculatorGraph::GetCombinedErrors(const std::string& error_prefix, + ::mediapipe::Status* error_status) { + absl::MutexLock lock(&error_mutex_); + if (!errors_.empty()) { + *error_status = tool::CombinedStatus(error_prefix, errors_); + return true; + } + return false; +} + +void CalculatorGraph::CallStatusHandlers(GraphRunState graph_run_state, + const ::mediapipe::Status& status) { + for (int status_handler_index = 0; + status_handler_index < validated_graph_->Config().status_handler_size(); + ++status_handler_index) { + const auto& handler_config = + validated_graph_->Config().status_handler(status_handler_index); + const auto& handler_type = handler_config.status_handler(); + + const auto& status_handler_info = + validated_graph_->StatusHandlerInfos()[status_handler_index]; + const PacketTypeSet& packet_type_set = + status_handler_info.InputSidePacketTypes(); + ::mediapipe::StatusOr> packet_set_statusor = + tool::FillPacketSet(packet_type_set, current_run_side_packets_, + nullptr); + if (!packet_set_statusor.ok()) { + RecordError(::mediapipe::StatusBuilder( + std::move(packet_set_statusor).status(), MEDIAPIPE_LOC) + .SetPrepend() + << "Skipping run of " << handler_type << ": "); + continue; + } + ::mediapipe::StatusOr< + std::unique_ptr> + static_access_statusor = internal::StaticAccessToStatusHandlerRegistry:: + CreateByNameInNamespace(validated_graph_->Package(), handler_type); + CHECK(static_access_statusor.ok()) << handler_type << " is not registered."; + auto static_access = std::move(static_access_statusor).ValueOrDie(); + ::mediapipe::Status handler_result; + if 
(graph_run_state == GraphRunState::PRE_RUN) { + handler_result = static_access->HandlePreRunStatus( + handler_config.options(), *packet_set_statusor.ValueOrDie(), status); + } else { // POST_RUN + handler_result = static_access->HandleStatus( + handler_config.options(), *packet_set_statusor.ValueOrDie(), status); + } + if (!handler_result.ok()) { + ::mediapipe::StatusBuilder builder(std::move(handler_result), + MEDIAPIPE_LOC); + builder.SetPrepend() << handler_type; + if (graph_run_state == GraphRunState::PRE_RUN) { + builder << "::HandlePreRunStatus failed: "; + } else { // POST_RUN + builder << "::HandleStatus failed: "; + } + RecordError(builder); + } + } +} + +int CalculatorGraph::GetMaxInputStreamQueueSize() { return max_queue_size_; } + +void CalculatorGraph::UpdateThrottledNodes(InputStreamManager* stream, + bool* stream_was_full) { + // TODO Change the throttling code to use the index directly + // rather than looking up a stream name. + int node_index = validated_graph_->OutputStreamToNode(stream->Name()); + std::unordered_set owned_set; + const std::unordered_set* upstream_nodes; + if (node_index >= validated_graph_->CalculatorInfos().size()) { + // TODO just create a NodeTypeInfo object for each virtual node. + owned_set.insert(node_index); + upstream_nodes = &owned_set; + } else { + upstream_nodes = + &validated_graph_->CalculatorInfos()[node_index].AncestorSources(); + } + CHECK(upstream_nodes); + std::vector nodes_to_schedule; + + { + absl::MutexLock lock(&full_input_streams_mutex_); + // Note that the change in stream status is recomputed here within the + // MutexLock in order to avoid interference between callbacks arriving + // out of order. + // Note that |stream_was_full| is maintained by the node throttling logic + // in this function and is guarded by full_input_streams_mutex_. 
+ bool stream_is_full = stream->IsFull(); + if (*stream_was_full != stream_is_full) { + for (int node_id : *upstream_nodes) { + VLOG(2) << "Stream \"" << stream->Name() << "\" is " + << (stream_is_full ? "throttling" : "no longer throttling") + << " node with node ID " << node_id; + ::mediapipe::LogEvent( + profiler_.get(), + TraceEvent(stream_is_full ? TraceEvent::THROTTLED + : TraceEvent::UNTHROTTLED) + .set_stream_id(&stream->Name())); + bool was_throttled = !full_input_streams_[node_id].empty(); + if (stream_is_full) { + DCHECK_EQ(full_input_streams_[node_id].count(stream), 0); + full_input_streams_[node_id].insert(stream); + } else { + DCHECK_EQ(full_input_streams_[node_id].count(stream), 1); + full_input_streams_[node_id].erase(stream); + } + + bool is_throttled = !full_input_streams_[node_id].empty(); + bool is_graph_input_stream = + node_id >= validated_graph_->CalculatorInfos().size(); + if (is_graph_input_stream) { + // Making these calls while holding full_input_streams_mutex_ + // ensures they are correctly serialized. + // Note: !is_throttled implies was_throttled, but not vice versa. + if (!is_throttled) { + scheduler_.UnthrottledGraphInputStream(); + } else if (!was_throttled && is_throttled) { + scheduler_.ThrottledGraphInputStream(); + } + } else { + if (!is_throttled) { + CalculatorNode& node = (*nodes_)[node_id]; + // Add this node to the scheduler queue if possible. 
+ if (node.Active() && !node.Closed()) { + nodes_to_schedule.emplace_back(&node); + } + } + } + } + } + *stream_was_full = stream_is_full; + } + + if (!nodes_to_schedule.empty()) { + scheduler_.ScheduleUnthrottledReadyNodes(nodes_to_schedule); + } +} + +bool CalculatorGraph::IsNodeThrottled(int node_id) { + absl::MutexLock lock(&full_input_streams_mutex_); + return max_queue_size_ != -1 && !full_input_streams_[node_id].empty(); +} + +bool CalculatorGraph::UnthrottleSources() { + // NOTE: We can be sure that this function will grow input streams enough + // to unthrottle at least one source node. The current stream queue sizes + // will remain unchanged until at least one source node becomes unthrottled. + // This is a sufficient because succesfully growing at least one full input + // stream during each call to UnthrottleSources will eventually resolve + // each deadlock. + std::unordered_set full_streams; + { + absl::MutexLock lock(&full_input_streams_mutex_); + for (std::unordered_set& s : full_input_streams_) { + if (!s.empty()) { + full_streams.insert(s.begin(), s.end()); + } + } + } + for (InputStreamManager* stream : full_streams) { + // The queue size of a graph output stream shouldn't change. Throttling + // should continue until the caller of the graph output stream consumes + // enough packets. + bool is_graph_output_stream = false; + for (auto& graph_output_stream : graph_output_streams_) { + if (stream == graph_output_stream->input_stream()) { + is_graph_output_stream = true; + break; + } + } + if (is_graph_output_stream) { + continue; + } + if (Config().report_deadlock()) { + RecordError(::mediapipe::UnavailableError(absl::StrCat( + "Detected a deadlock due to input throttling for: \"", stream->Name(), + "\". All calculators are idle while packet sources remain active " + "and throttled. 
Consider adjusting \"max_queue_size\" or " + "\"resolve_deadlock\"."))); + continue; + } + int new_size = stream->QueueSize() + 1; + stream->SetMaxQueueSize(new_size); + LOG_EVERY_N(WARNING, 100) + << "Resolved a deadlock by increasing max_queue_size of input stream: " + << stream->Name() << " to: " << new_size + << ". Consider increasing max_queue_size for better performance."; + } + return !full_streams.empty(); +} + +CalculatorGraph::GraphInputStreamAddMode +CalculatorGraph::GetGraphInputStreamAddMode() const { + absl::MutexLock lock(&full_input_streams_mutex_); + return graph_input_stream_add_mode_; +} + +void CalculatorGraph::SetGraphInputStreamAddMode(GraphInputStreamAddMode mode) { + absl::MutexLock lock(&full_input_streams_mutex_); + graph_input_stream_add_mode_ = mode; +} + +void CalculatorGraph::Cancel() { + // TODO This function should return ::mediapipe::Status. + scheduler_.Cancel(); +} + +void CalculatorGraph::Pause() { scheduler_.Pause(); } + +void CalculatorGraph::Resume() { scheduler_.Resume(); } + +::mediapipe::Status CalculatorGraph::SetServicePacket( + const GraphServiceBase& service, Packet p) { + // TODO: check that the graph has not been started! 
+ service_packets_[service.key] = std::move(p); + return ::mediapipe::OkStatus(); +} + +Packet CalculatorGraph::GetServicePacket(const GraphServiceBase& service) { + auto it = service_packets_.find(service.key); + if (it == service_packets_.end()) { + return {}; + } + return it->second; +} + +::mediapipe::Status CalculatorGraph::SetExecutorInternal( + const std::string& name, std::shared_ptr executor) { + if (!executors_.emplace(name, executor).second) { + return ::mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC) + << "SetExecutor must be called only once for the executor \"" << name + << "\""; + } + if (name.empty()) { + scheduler_.SetExecutor(executor.get()); + } else { + RETURN_IF_ERROR(scheduler_.SetNonDefaultExecutor(name, executor.get())); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorGraph::SetExecutor( + const std::string& name, std::shared_ptr executor) { + RET_CHECK(!initialized_) + << "SetExecutor can only be called before Initialize()"; + if (IsReservedExecutorName(name)) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "\"" << name << "\" is a reserved executor name."; + } + return SetExecutorInternal(name, std::move(executor)); +} + +::mediapipe::Status CalculatorGraph::CreateDefaultThreadPool( + const ThreadPoolExecutorOptions& default_executor_options, + int num_threads) { + MediaPipeOptions extendable_options; + ThreadPoolExecutorOptions* options = + extendable_options.MutableExtension(ThreadPoolExecutorOptions::ext); + *options = default_executor_options; + options->set_num_threads(num_threads); + // clang-format off + ASSIGN_OR_RETURN(Executor* executor, + ThreadPoolExecutor::Create(extendable_options)); + // clang-format on + return SetExecutorInternal("", std::shared_ptr(executor)); +} + +// static +bool CalculatorGraph::IsReservedExecutorName(const std::string& name) { + return ValidatedGraphConfig::IsReservedExecutorName(name); +} + +::mediapipe::Status CalculatorGraph::FinishRun() { + // 
Check for any errors that may have occurred. + ::mediapipe::Status status = ::mediapipe::OkStatus(); + RETURN_IF_ERROR(profiler_->Stop()); + GetCombinedErrors(&status); + CleanupAfterRun(&status); + return status; +} + +void CalculatorGraph::CleanupAfterRun(::mediapipe::Status* status) { + for (auto& item : graph_input_streams_) { + item.second->Close(); + } + + CallStatusHandlers(GraphRunState::POST_RUN, *status); + if (has_error_) { + // Obtain the combined status again, so that it includes the new errors + // added by CallStatusHandlers. + GetCombinedErrors(status); + CHECK(!status->ok()); + } else { + MEDIAPIPE_CHECK_OK(*status); + } + + for (CalculatorNode& node : *nodes_) { + node.CleanupAfterRun(*status); + } + + for (auto& graph_output_stream : graph_output_streams_) { + graph_output_stream->input_stream()->Close(); + } + + scheduler_.CleanupAfterRun(); + + { + absl::MutexLock lock(&error_mutex_); + errors_.clear(); + has_error_ = false; + } + + { + absl::MutexLock lock(&full_input_streams_mutex_); + full_input_streams_.clear(); + } + // Note: output_side_packets_ and current_run_side_packets_ are not cleared + // in order to enable GetOutputSidePacket after WaitUntilDone. 
+} + +const OutputStreamManager* CalculatorGraph::FindOutputStreamManager( + const std::string& name) { + return &output_stream_managers_ + .get()[validated_graph_->OutputStreamIndex(name)]; +} + +namespace { +void PrintTimingToInfo(const std::string& label, int64 timer_value) { + const int64 total_seconds = timer_value / 1000000ll; + const int64 days = total_seconds / (3600ll * 24ll); + const int64 hours = (total_seconds / 3600ll) % 24ll; + const int64 minutes = (total_seconds / 60ll) % 60ll; + const int64 seconds = total_seconds % 60ll; + const int64 milliseconds = (timer_value / 1000ll) % 1000ll; + LOG(INFO) << label << " took " + << absl::StrFormat( + "%02lld days, %02lld:%02lld:%02lld.%03lld (total seconds: " + "%lld.%06lld)", + days, hours, minutes, seconds, milliseconds, total_seconds, + timer_value % 1000000ll); +} + +bool MetricElementComparator(const std::pair& e1, + const std::pair& e2) { + return e1.second > e2.second; +} +} // namespace + +::mediapipe::Status CalculatorGraph::GetCalculatorProfiles( + std::vector* profiles) const { + return profiler_->GetCalculatorProfiles(profiles); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h new file mode 100644 index 000000000..f8a1cb8a2 --- /dev/null +++ b/mediapipe/framework/calculator_graph.h @@ -0,0 +1,647 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// Declares CalculatorGraph, which links Calculators into a directed acyclic +// graph, and allows its evaluation. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_GRAPH_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/container/fixed_array.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_node.h" +#include "mediapipe/framework/counter_factory.h" +#include "mediapipe/framework/executor.h" +#include "mediapipe/framework/graph_output_stream.h" +#include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/mediapipe_profiling.h" +#include "mediapipe/framework/output_side_packet_impl.h" +#include "mediapipe/framework/output_stream.h" +#include "mediapipe/framework/output_stream_manager.h" +#include "mediapipe/framework/output_stream_poller.h" +#include "mediapipe/framework/output_stream_shard.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/packet_generator_graph.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/scheduler.h" +#include "mediapipe/framework/thread_pool_executor.pb.h" + +#ifndef MEDIAPIPE_DISABLE_GPU +namespace mediapipe { +class GpuResources; +class GpuSharedData; +} // namespace mediapipe +#endif // !defined(MEDIAPIPE_DISABLE_GPU) + +namespace mediapipe { + +typedef ::mediapipe::StatusOr StatusOrPoller; + +// The class representing a DAG of calculator nodes. +// +// CalculatorGraph is the primary API for the MediaPipe Framework. 
+// In general, CalculatorGraph should be used if the only thing you need +// to do is run the graph (without pushing data in or extracting it as +// the graph runs). +// +// Example: +// // Build dependency "//mediapipe/framework:calculator_framework". +// +// #include "mediapipe/framework/calculator_framework.h" +// +// mediapipe::CalculatorGraphConfig config; +// RETURN_IF_ERROR(mediapipe::tool::ParseGraphFromString(THE_CONFIG, +// &config)); mediapipe::CalculatorGraph graph; +// RETURN_IF_ERROR(graph.Initialize(config)); +// +// std::map extra_side_packets; +// extra_side_packets["video_id"] = mediapipe::MakePacket( +// "3edb9503834e9b42"); +// RETURN_IF_ERROR(graph.Run(extra_side_packets)); +// +// // Run again (demonstrating the more concise initializer list syntax). +// RETURN_IF_ERROR(graph.Run( +// {{"video_id", mediapipe::MakePacket("Ex-uGhDzue4")}})); +// // See mediapipe/framework/graph_runner.h for an interface +// // to insert and extract packets from a graph as it runs. +class CalculatorGraph { + public: + // Defines possible modes for adding a packet to a graph input stream. + // WAIT_TILL_NOT_FULL can be used to control the memory usage of a graph by + // avoiding adding a new packet until all dependent input streams fall below + // the maximum queue size specified in the graph configuration. + // ADD_IF_NOT_FULL could also be used to control the latency if used in a + // real-time graph (e.g. drop camera frames if the MediaPipe graph queues are + // full). + enum class GraphInputStreamAddMode { + // Blocks and waits until none of the affected streams + // are full. Note that if max_queue_size is set to -1, the packet will be + // added regardless of queue size. + WAIT_TILL_NOT_FULL, + // Returns and does not add packet if any affected input + // stream is full. + ADD_IF_NOT_FULL + }; + + // Creates an uninitialized graph. 
+ CalculatorGraph(); + CalculatorGraph(const CalculatorGraph&) = delete; + CalculatorGraph& operator=(const CalculatorGraph&) = delete; + + // Initializes the graph from its proto description (using Initialize()) + // and crashes if something goes wrong. + explicit CalculatorGraph(const CalculatorGraphConfig& config); + virtual ~CalculatorGraph(); + + // Initializes the graph from a its proto description. + // side_packets that are provided at this stage are common across all Run() + // invocations and could be used to execute PacketGenerators immediately. + ::mediapipe::Status Initialize( + const CalculatorGraphConfig& config, + const std::map& side_packets); + + // Convenience version which does not take side packets. + ::mediapipe::Status Initialize(const CalculatorGraphConfig& config); + + // Initializes the CalculatorGraph from the specified graph and subgraph + // configs. Template graph and subgraph configs can be specified through + // |input_templates|. Every subgraph must have its graph type specified in + // CalclatorGraphConfig.type. A subgraph can be instantiated directly by + // specifying its type in |graph_type|. A template graph can be instantiated + // directly by specifying its template arguments in |arguments|. + ::mediapipe::Status Initialize( + const std::vector& configs, + const std::vector& templates, + const std::map& side_packets = {}, + const std::string& graph_type = "", + const Subgraph::SubgraphOptions* options = nullptr); + + // Resturns the canonicalized CalculatorGraphConfig for this graph. + const CalculatorGraphConfig& Config() const { + return validated_graph_->Config(); + } + + // Observes the named output stream. packet_callback will be invoked on every + // packet emitted by the output stream. Can only be called before Run() or + // StartRun(). + // TODO: Rename to AddOutputStreamCallback. 
+ ::mediapipe::Status ObserveOutputStream( + const std::string& stream_name, + std::function<::mediapipe::Status(const Packet&)> packet_callback); + + // Adds an OutputStreamPoller for a stream. This provides a synchronous, + // polling API for accessing a stream's output. For asynchronous output, use + // ObserveOutputStream. See also the helpers in tool/sink.h. + StatusOrPoller AddOutputStreamPoller(const std::string& stream_name); + + // Gets output side packet by name after the graph is done. However, base + // packets (generated by PacketGenerators) can be retrieved before + // graph is done. Returns error if the graph is still running (for non-base + // packets) or the output side packet is not found or empty. + ::mediapipe::StatusOr GetOutputSidePacket( + const std::string& packet_name); + + // Runs the graph after adding the given extra input side packets. All + // arguments are forgotten after Run() returns. + // Run() is a blocking call and will return when all calculators are done. + virtual ::mediapipe::Status Run( + const std::map& extra_side_packets); + + // Run the graph without adding any input side packets. + ::mediapipe::Status Run() { return Run({}); } + + // Start a run of the graph. StartRun, WaitUntilDone, HasError, + // AddPacketToInputStream, and CloseInputStream allow more control over + // the execution of the graph run. You can insert packets directly into + // a stream while the graph is running. Once StartRun has been called, + // the graph will continue to run until WaitUntilDone() is called. + // If StartRun returns an error, then the graph is not started and a + // subsequent call to StartRun can be attempted. 
+ // + // Example: + // RETURN_IF_ERROR(graph.StartRun(...)); + // while (true) { + // if (graph.HasError() || want_to_stop) break; + // RETURN_IF_ERROR(graph.AddPacketToInputStream(...)); + // } + // for (const std::string& stream : streams) { + // RETURN_IF_ERROR(graph.CloseInputStream(stream)); + // } + // RETURN_IF_ERROR(graph.WaitUntilDone()); + ::mediapipe::Status StartRun( + const std::map& extra_side_packets) { + return StartRun(extra_side_packets, {}); + } + + // In addition to the above StartRun, add additional parameter to set the + // stream header before running. + // Note: We highly discourage the use of stream headers, this is added for the + // compatibility of existing calculators that use headers during Open(). + ::mediapipe::Status StartRun( + const std::map& extra_side_packets, + const std::map& stream_headers); + + // Wait for the current run to finish (block the current thread + // until all source calculators have returned StatusStop(), all + // graph_input_streams_ have been closed, and no more calculators can + // be run). This function can be called only after StartRun(). + ::mediapipe::Status WaitUntilDone(); + + // Wait until the running graph is in the idle mode, which is when nothing can + // be scheduled and nothing is running in the worker threads. This function + // can be called only after StartRun(). + // NOTE: The graph must not have any source nodes because source nodes prevent + // the running graph from becoming idle until the source nodes are done. + ::mediapipe::Status WaitUntilIdle(); + + // Wait until a packet is emitted on one of the observed output streams. + // Returns immediately if a packet has already been emitted since the last + // call to this function. + // Returns OutOfRangeError if the graph terminated while waiting. + ::mediapipe::Status WaitForObservedOutput(); + + // Quick non-locking means of checking if the graph has encountered an error. 
+ bool HasError() const { return has_error_; } + + // Add a Packet to a graph input stream based on the graph input stream add + // mode. If the mode is ADD_IF_NOT_FULL, the packet will not be added if any + // queue exceeds max_queue_size specified by the graph config and will return + // StatusUnavailable. The WAIT_TILL_NOT_FULL mode (default) will block until + // the queues fall below the max_queue_size before adding the packet. If the + // mode is max_queue_size is -1, then the packet is added regardless of the + // sizes of the queues in the graph. The input stream must have been specified + // in the configuration as a graph level input_stream. On error, nothing is + // added. + ::mediapipe::Status AddPacketToInputStream(const std::string& stream_name, + const Packet& packet); + + // Same as the l-value version of this function by the same name, but moves + // the r-value referenced packet into the stream instead of copying it over. + // This allows the graph to take exclusive ownership of the packet, which may + // allow more memory optimizations. Note that, if an error is returned, the + // packet may remain valid. In particular, when using the ADD_IF_NOT_FULL + // mode with a full queue, this will return StatusUnavailable and the caller + // may try adding the packet again later. + ::mediapipe::Status AddPacketToInputStream(const std::string& stream_name, + Packet&& packet); + + // Sets the queue size of a graph input stream, overriding the graph default. + ::mediapipe::Status SetInputStreamMaxQueueSize(const std::string& stream_name, + int max_queue_size); + + // Check if an input stream exists in the graph + bool HasInputStream(const std::string& name); + + // Close a graph input stream. If the graph has any graph input streams + // then Run() will not return until all the graph input streams have + // been closed (and all packets propagate through the graph). 
+ // Note that multiple threads cannot call CloseInputStream() on the same + // stream_name at the same time. + ::mediapipe::Status CloseInputStream(const std::string& stream_name); + + // Closes all the graph input streams. + // TODO: deprecate this function in favor of CloseAllPacketSources. + ::mediapipe::Status CloseAllInputStreams(); + + // Closes all the graph input streams and source calculator nodes. + ::mediapipe::Status CloseAllPacketSources(); + + // Returns the pointer to the stream with the given name, or dies if none + // exists. The result remains owned by the CalculatorGraph. + ABSL_DEPRECATED( + "Prefer using a Calculator to get information of all sorts out of the " + "graph.") + const OutputStreamManager* FindOutputStreamManager(const std::string& name); + + // Returns the ProfilingContext assocoaited with the CalculatorGraph. + ProfilingContext* profiler() { return profiler_.get(); } + // Collects the runtime profile for Open(), Process(), and Close() of each + // calculator in the graph. May be called at any time after the graph has been + // initialized. + ABSL_DEPRECATED("Use profiler()->GetCalculatorProfiles() instead") + ::mediapipe::Status GetCalculatorProfiles( + std::vector*) const; + + // Set the type of counter used in this graph. + void SetCounterFactory(CounterFactory* factory) { + counter_factory_.reset(factory); + } + CounterFactory* GetCounterFactory() { return counter_factory_.get(); } + + // Callback when an error is encountered. + // Adds the error to the vector of errors. + void RecordError(const ::mediapipe::Status& error) + LOCKS_EXCLUDED(error_mutex_); + + // Returns the maximum input stream queue size. + int GetMaxInputStreamQueueSize(); + + // Get the mode for adding packets to an input stream. + GraphInputStreamAddMode GetGraphInputStreamAddMode() const; + + // Set the mode for adding packets to an input stream. 
+ void SetGraphInputStreamAddMode(GraphInputStreamAddMode mode); + + // Aborts the scheduler if the graph is not terminated; no-op otherwise. + void Cancel(); + + // Pauses the scheduler. Only used by calculator graph testing. + ABSL_DEPRECATED( + "CalculatorGraph will not allow external callers to explictly pause and " + "resume a graph.") + void Pause(); + + // Resumes the scheduler. Only used by calculator graph testing. + ABSL_DEPRECATED( + "CalculatorGraph will not allow external callers to explictly pause and " + "resume a graph.") + void Resume(); + + // Sets the executor that will run the nodes assigned to the executor + // named |name|. If |name| is empty, this sets the default executor. Must + // be called before the graph is initialized. + ::mediapipe::Status SetExecutor(const std::string& name, + std::shared_ptr executor); + + // WARNING: the following public methods are exposed to Scheduler only. + + // Return true if all the graph input streams have been closed. + bool GraphInputStreamsClosed() { + return num_closed_graph_input_streams_ == graph_input_streams_.size(); + } + + // Returns true if this node or graph input stream is connected to + // any input stream whose queue has hit maximum capacity. + bool IsNodeThrottled(int node_id) LOCKS_EXCLUDED(full_input_streams_mutex_); + + // If any active source node or graph input stream is throttled and not yet + // closed, increases the max_queue_size for each full input stream in the + // graph. + // Returns true if at least one max_queue_size has been grown. + bool UnthrottleSources() LOCKS_EXCLUDED(full_input_streams_mutex_); + + // Returns the scheduler's runtime measures for overhead measurement. + // Only meant for test purposes. + internal::SchedulerTimes GetSchedulerTimes() { + return scheduler_.GetSchedulerTimes(); + } + +#ifndef MEDIAPIPE_DISABLE_GPU + // Returns a pointer to the GpuResources in use, if any. + // Only meant for internal use. 
+ std::shared_ptr<::mediapipe::GpuResources> GetGpuResources() const; + + ::mediapipe::Status SetGpuResources( + std::shared_ptr<::mediapipe::GpuResources> resources); + + // Helper for PrepareForRun. If it returns a non-empty map, those packets + // must be added to the existing side packets, replacing existing values + // that have the same key. + ::mediapipe::StatusOr> PrepareGpu( + const std::map& side_packets); +#endif // !defined(MEDIAPIPE_DISABLE_GPU) + template + ::mediapipe::Status SetServiceObject(const GraphService& service, + std::shared_ptr object) { + return SetServicePacket(service, + MakePacket>(std::move(object))); + } + + template + std::shared_ptr GetServiceObject(const GraphService& service) { + Packet p = GetServicePacket(service); + if (p.IsEmpty()) return nullptr; + return p.Get>(); + } + + // Only the Java API should call this directly. + ::mediapipe::Status SetServicePacket(const GraphServiceBase& service, + Packet p); + + private: + // GraphRunState is used as a parameter in the function CallStatusHandlers. + enum class GraphRunState { + // State of the graph before the run; see status_handler.h for details. + PRE_RUN, + // State of the graph after after the run; set by CleanUpAfterRun. + POST_RUN, + }; + + // The graph input streams (which have packets added to them from + // outside the graph). Since these will be connected directly to a + // node's input streams they are implemented as "output" streams. + // Based on the assumption that all the graph input packets must be added to a + // graph input stream sequentially, a GraphInputStream object only contains + // one reusable output stream shard. 
+ class GraphInputStream { + public: + explicit GraphInputStream(OutputStreamManager* manager) + : manager_(manager) { + shard_.SetSpec(manager_->Spec()); + } + + void PrepareForRun( + std::function error_callback) { + manager_->PrepareForRun(std::move(error_callback)); + } + + void SetMaxQueueSize(int max_queue_size) { + manager_->SetMaxQueueSize(max_queue_size); + } + + void SetHeader(const Packet& header); + + void AddPacket(const Packet& packet) { shard_.AddPacket(packet); } + + void AddPacket(Packet&& packet) { shard_.AddPacket(std::move(packet)); } + + void PropagateUpdatesToMirrors(); + + void Close(); + + bool IsClosed() const { return manager_->IsClosed(); } + + OutputStreamManager* GetManager() { return manager_; } + + private: + OutputStreamManager* manager_ = nullptr; + OutputStreamShard shard_; + }; + + // Initializes the graph from a ValidatedGraphConfig object. + ::mediapipe::Status Initialize( + std::unique_ptr validated_graph, + const std::map& side_packets); + + // AddPacketToInputStreamInternal template is called by either + // AddPacketToInputStream(Packet&& packet) or + // AddPacketToInputStream(const Packet& packet). + template + ::mediapipe::Status AddPacketToInputStreamInternal( + const std::string& stream_name, T&& packet); + + // Sets the executor that will run the nodes assigned to the executor + // named |name|. If |name| is empty, this sets the default executor. + // Does not check that the graph is uninitialized and |name| is not a + // reserved executor name. + ::mediapipe::Status SetExecutorInternal(const std::string& name, + std::shared_ptr executor); + + // If the num_threads field in default_executor_options is not specified, + // assigns a reasonable value based on system configuration and the graph. + // Then, creates the default thread pool if appropriate. + // + // Only called by InitializeExecutors(). 
+ ::mediapipe::Status InitializeDefaultExecutor( + const ThreadPoolExecutorOptions& default_executor_options, + bool use_application_thread); + + // Creates a thread pool as the default executor. The num_threads argument + // overrides the num_threads field in default_executor_options. + ::mediapipe::Status CreateDefaultThreadPool( + const ThreadPoolExecutorOptions& default_executor_options, + int num_threads); + + // Returns true if |name| is a reserved executor name. + static bool IsReservedExecutorName(const std::string& name); + + // Helper functions for Initialize(). + ::mediapipe::Status InitializeExecutors(); + ::mediapipe::Status InitializePacketGeneratorGraph( + const std::map& side_packets); + ::mediapipe::Status InitializeStreams(); + ::mediapipe::Status InitializeProfiler(); + ::mediapipe::Status InitializeCalculatorNodes(); + + // Iterates through all nodes and schedules any that can be opened. + void ScheduleAllOpenableNodes(); + + // Does the bulk of the work for StartRun but does not start the scheduler. + ::mediapipe::Status PrepareForRun( + const std::map& extra_side_packets, + const std::map& stream_headers); + + // Cleans up any remaining state after the run and returns any errors that may + // have occurred during the run. Called after the scheduler has terminated. + ::mediapipe::Status FinishRun(); + + // Cleans up any remaining state after the run. All status handlers run here + // if their requested input side packets exist. + // The original |*status| is passed to all the status handlers. If any status + // handler fails, it appends its error to errors_, and CleanupAfterRun sets + // |*status| to the new combined errors on return. + void CleanupAfterRun(::mediapipe::Status* status) + LOCKS_EXCLUDED(error_mutex_); + + // Combines errors into a status. Returns true if the vector of errors is + // non-empty. 
+ bool GetCombinedErrors(const std::string& error_prefix, + ::mediapipe::Status* error_status); + // Convenience overload which specifies a default error prefix. + bool GetCombinedErrors(::mediapipe::Status* error_status); + + // Calls HandlePreRunStatus or HandleStatus on the StatusHandlers. Which one + // is called depends on the GraphRunState parameter (PRE_RUN or POST_RUN). + // current_run_side_packets_ must be set before this function is called. + // On error, has_error_ will be set. + void CallStatusHandlers(GraphRunState graph_run_state, + const ::mediapipe::Status& status); + + // Callback function to throttle or unthrottle source nodes when a stream + // becomes full or non-full. A node is throttled (i.e. prevented being + // scheduled) if it has caused a downstream input queue to become full. Note + // that all sources (including graph input streams) that affect this stream + // will be throttled. A node is unthrottled (i.e. added to the scheduler + // queue) if all downstream input queues have become non-full. + // + // This method is invoked from an input stream when its queue becomes full or + // non-full. However, since streams are not allowed to hold any locks while + // invoking a callback, this method must re-lock the stream and query its + // status before taking any action. + void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full); + + Packet GetServicePacket(const GraphServiceBase& service); +#ifndef MEDIAPIPE_DISABLE_GPU + // Owns the legacy GpuSharedData if we need to create one for backwards + // compatibility. + std::unique_ptr<::mediapipe::GpuSharedData> legacy_gpu_shared_; +#endif // !defined(MEDIAPIPE_DISABLE_GPU) + + // True if the graph was initialized. + bool initialized_ = false; + + // A packet type that has SetAny() called on it. + PacketType any_packet_type_; + + // The ValidatedGraphConfig object defining this CalculatorGraph. 
+ std::unique_ptr validated_graph_; + + // The PacketGeneratorGraph to use to generate all the input side packets. + PacketGeneratorGraph packet_generator_graph_; + + // True if the graph has source nodes. + bool has_sources_ = false; + + // A flat array of InputStreamManager/OutputStreamManager/ + // OutputSidePacketImpl/CalculatorNode corresponding to the input/output + // stream indexes, output side packet indexes, and calculator indexes + // respectively in validated_graph_. + // Once allocated these structures must not be reallocated since + // internal structures may point to individual entries in the array. + std::unique_ptr input_stream_managers_; + std::unique_ptr output_stream_managers_; + std::unique_ptr output_side_packets_; + std::unique_ptr> nodes_; + + // The graph output streams. + std::vector> + graph_output_streams_; + + // Maximum queue size for an input stream. This is used by the scheduler to + // restrict memory usage. + int max_queue_size_ = -1; + + // Mode for adding packets to a graph input stream. Set to block until all + // affected input streams are not full by default. + GraphInputStreamAddMode graph_input_stream_add_mode_ + GUARDED_BY(full_input_streams_mutex_); + + // For a source node or graph input stream (specified using id), + // this stores the set of dependent input streams that have hit their + // maximum capacity. Graph input streams are also treated as nodes. + // A node is scheduled only if this set is empty. Similarly, a packet + // is added to a graph input stream only if this set is empty. + // Note that this vector contains an unused entry for each non-source node. + std::vector> full_input_streams_ + GUARDED_BY(full_input_streams_mutex_); + + // Maps stream names to graph input stream objects. + std::unordered_map> + graph_input_streams_; + + // Maps graph input streams to their virtual node ids. + std::unordered_map graph_input_stream_node_ids_; + + // Maps graph input streams to their max queue size. 
+ std::unordered_map graph_input_stream_max_queue_size_; + + // The factory for making counters associated with this graph. + std::unique_ptr counter_factory_; + + // Executors for the scheduler, keyed by the executor's name. The default + // executor's name is the empty std::string. + std::map> executors_; + + // The processed input side packet map for this run. + std::map current_run_side_packets_; + + std::map service_packets_; + + // Vector of errors encountered while running graph. Always use RecordError() + // to add an error to this vector. + std::vector<::mediapipe::Status> errors_ GUARDED_BY(error_mutex_); + + // True if the default executor uses the application thread. + bool use_application_thread_ = false; + + // Condition variable that waits until all input streams that depend on a + // graph input stream are below the maximum queue size. + absl::CondVar wait_to_add_packet_cond_var_ + GUARDED_BY(full_input_streams_mutex_); + + // Mutex for the vector of errors. + absl::Mutex error_mutex_; + + // Status variable to indicate if the graph has encountered an error. + std::atomic has_error_; + + // Mutex for full_input_streams_. + mutable absl::Mutex full_input_streams_mutex_; + + // Number of closed graph input streams. This is a separate variable because + // it is not safe to hold a lock on the scheduler while calling Close() on an + // input stream. Hence, we decouple the closing of the stream and checking its + // status. + // TODO: update this comment. + std::atomic num_closed_graph_input_streams_; + + // The graph tracing and profiling interface. It is owned by the + // CalculatorGraph using a shared_ptr in order to allow threadsafe access + // to the ProfilingContext from clients that may outlive the CalculatorGraph + // such as GlContext. It is declared here before the Scheduler so that it + // remains available during the Scheduler destructor. 
+ std::shared_ptr profiler_; + + internal::Scheduler scheduler_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_GRAPH_H_ diff --git a/mediapipe/framework/calculator_graph_event_loop_test.cc b/mediapipe/framework/calculator_graph_event_loop_test.cc new file mode 100644 index 000000000..835a38b13 --- /dev/null +++ b/mediapipe/framework/calculator_graph_event_loop_test.cc @@ -0,0 +1,546 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_graph.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/core_proto_inc.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/sink.h" +#include "mediapipe/framework/tool/status_util.h" + +namespace mediapipe { + +namespace { + +class CalculatorGraphEventLoopTest : public testing::Test { + public: + void AddThreadSafeVectorSink(const Packet& packet) { + absl::WriterMutexLock lock(&output_packets_mutex_); + output_packets_.push_back(packet); + } + + protected: + std::vector output_packets_ GUARDED_BY(output_packets_mutex_); + absl::Mutex output_packets_mutex_; +}; + +// Allows blocking of the Process() call by locking the blocking_mutex passed to +// the input side packet. Used to force input stream queues to build up for +// testing. 
+class BlockingPassThroughCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + cc->InputSidePackets().Index(0).Set>(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + mutex_ = GetFromUniquePtr(cc->InputSidePackets().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + mutex_->Lock(); + cc->Outputs().Index(0).AddPacket( + cc->Inputs().Index(0).Value().At(cc->InputTimestamp())); + mutex_->Unlock(); + return ::mediapipe::OkStatus(); + } + + private: + absl::Mutex* mutex_; +}; + +REGISTER_CALCULATOR(BlockingPassThroughCalculator); + +struct SimpleHeader { + int width; + int height; +}; + +class UsingHeaderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + if (cc->Inputs().Index(0).Header().IsEmpty()) { + return ::mediapipe::UnknownError("No stream header present."); + } + + const SimpleHeader& header = + cc->Inputs().Index(0).Header().Get(); + std::unique_ptr output_header(new SimpleHeader); + output_header->width = header.width; + output_header->height = header.height; + + cc->Outputs().Index(0).SetHeader(Adopt(output_header.release())); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->Outputs().Index(0).AddPacket( + cc->Inputs().Index(0).Value().At(cc->InputTimestamp())); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(UsingHeaderCalculator); + +TEST_F(CalculatorGraphEventLoopTest, WellProvisionedEventLoop) { + CalculatorGraphConfig graph_config; + 
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString( + R"( + node { + calculator: "PassThroughCalculator" + input_stream: "input_numbers" + output_stream: "output_numbers" + } + node { + calculator: "CallbackCalculator" + input_stream: "output_numbers" + input_side_packet: "CALLBACK:callback" + } + input_stream: "input_numbers" + )", + &graph_config)); + + // Start MediaPipe graph. + CalculatorGraph graph(graph_config); + MEDIAPIPE_ASSERT_OK(graph.StartRun( + {{"callback", MakePacket>(std::bind( + &CalculatorGraphEventLoopTest::AddThreadSafeVectorSink, + this, std::placeholders::_1))}})); + + // Insert 100 packets at the rate the calculator can keep up with. + for (int i = 0; i < 100; ++i) { + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "input_numbers", Adopt(new int(i)).At(Timestamp(i)))); + // Wait for all packets to be received by the sink. + while (true) { + { + absl::ReaderMutexLock lock(&output_packets_mutex_); + if (output_packets_.size() > i) { + break; + } + } + absl::SleepFor(absl::Microseconds(1)); + } + } + // Check partial results. + { + absl::ReaderMutexLock lock(&output_packets_mutex_); + ASSERT_EQ(100, output_packets_.size()); + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(i, output_packets_[i].Get()); + } + } + + // Insert 100 more packets at rate the graph can't keep up. + for (int i = 100; i < 200; ++i) { + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "input_numbers", Adopt(new int(i)).At(Timestamp(i)))); + } + // Don't wait but just close the input stream. + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input_numbers")); + // Wait properly via the API until the graph is done. + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + // Check final results. + { + absl::ReaderMutexLock lock(&output_packets_mutex_); + ASSERT_EQ(200, output_packets_.size()); + for (int i = 0; i < 200; ++i) { + EXPECT_EQ(i, output_packets_[i].Get()); + } + } +} + +// Pass-Through calculator that fails upon receiving the 10th packet. 
+class FailingPassThroughCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + Timestamp timestamp = cc->InputTimestamp(); + if (timestamp.Value() == 9) { + return ::mediapipe::UnknownError( + "Meant to fail (magicstringincludedhere)."); + } + cc->Outputs().Index(0).AddPacket( + cc->Inputs().Index(0).Value().At(timestamp)); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(FailingPassThroughCalculator); + +TEST_F(CalculatorGraphEventLoopTest, FailingEventLoop) { + CalculatorGraphConfig graph_config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString( + R"( + node { + calculator: "FailingPassThroughCalculator" + input_stream: "input_numbers" + output_stream: "output_numbers" + } + node { + calculator: "CallbackCalculator" + input_stream: "output_numbers" + input_side_packet: "CALLBACK:callback" + } + input_stream: "input_numbers")", + &graph_config)); + + // Start MediaPipe graph. + CalculatorGraph graph(graph_config); + MEDIAPIPE_ASSERT_OK(graph.StartRun( + {{"callback", MakePacket>(std::bind( + &CalculatorGraphEventLoopTest::AddThreadSafeVectorSink, + this, std::placeholders::_1))}})); + + // Insert packets. + ::mediapipe::Status status; + for (int i = 0; true; ++i) { + status = graph.AddPacketToInputStream("input_numbers", + Adopt(new int(i)).At(Timestamp(i))); + if (!status.ok()) { + ASSERT_TRUE(graph.HasError()); // Graph failed. + ASSERT_THAT( + status.message(), + testing::HasSubstr("Meant to fail (magicstringincludedhere).")); + break; + } + } + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input_numbers")); + status = graph.WaitUntilDone(); + ASSERT_THAT(status.message(), + testing::HasSubstr("Meant to fail (magicstringincludedhere).")); +} + +// Test the step by step mode. 
+TEST_F(CalculatorGraphEventLoopTest, StepByStepSchedulerLoop) { + CalculatorGraphConfig graph_config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString( + R"( + node { + calculator: "PassThroughCalculator" + input_stream: "input_numbers" + output_stream: "output_numbers" + } + node { + calculator: "CallbackCalculator" + input_stream: "output_numbers" + input_side_packet: "CALLBACK:callback" + } + input_stream: "input_numbers" + )", + &graph_config)); + + // Start MediaPipe graph. + CalculatorGraph graph(graph_config); + MEDIAPIPE_ASSERT_OK(graph.StartRun( + {{"callback", MakePacket>(std::bind( + &CalculatorGraphEventLoopTest::AddThreadSafeVectorSink, + this, std::placeholders::_1))}})); + + // Add packet one at a time, we should be able to syncrhonize the output for + // each addition in the step by step mode. + for (int i = 0; i < 100; ++i) { + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "input_numbers", Adopt(new int(i)).At(Timestamp(i)))); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + absl::ReaderMutexLock lock(&output_packets_mutex_); + ASSERT_EQ(i + 1, output_packets_.size()); + } + // Don't wait but just close the input stream. + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input_numbers")); + // Wait properly via the API until the graph is done. + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +// Test setting the stream header. 
+TEST_F(CalculatorGraphEventLoopTest, SetStreamHeader) { + CalculatorGraphConfig graph_config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString( + R"( + node { + calculator: "UsingHeaderCalculator" + input_stream: "input_numbers" + output_stream: "output_numbers" + } + node { + calculator: "CallbackCalculator" + input_stream: "output_numbers" + input_side_packet: "CALLBACK:callback" + } + input_stream: "input_numbers" + )", + &graph_config)); + + CalculatorGraph graph(graph_config); + MEDIAPIPE_ASSERT_OK(graph.StartRun( + {{"callback", MakePacket>(std::bind( + &CalculatorGraphEventLoopTest::AddThreadSafeVectorSink, + this, std::placeholders::_1))}})); + + ::mediapipe::Status status = graph.WaitUntilIdle(); + // Expect to fail if header not set. + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnknown); + EXPECT_THAT(status.message(), + testing::HasSubstr("No stream header present.")); + + CalculatorGraph graph2(graph_config); + std::unique_ptr header(new SimpleHeader); + header->width = 320; + header->height = 240; + // With stream header set, the StartRun should succeed. + MEDIAPIPE_ASSERT_OK(graph2.StartRun( + {{"callback", MakePacket>(std::bind( + &CalculatorGraphEventLoopTest::AddThreadSafeVectorSink, + this, std::placeholders::_1))}}, + {{"input_numbers", Adopt(header.release())}})); + // Don't wait but just close the input stream. + MEDIAPIPE_ASSERT_OK(graph2.CloseInputStream("input_numbers")); + // Wait properly via the API until the graph is done. + MEDIAPIPE_ASSERT_OK(graph2.WaitUntilDone()); +} + +// Test ADD_IF_NOT_FULL mode for graph input streams (by creating more packets +// than the queue will support). At least some of these attempts should fail. 
+TEST_F(CalculatorGraphEventLoopTest, TryToAddPacketToInputStream) { + CalculatorGraphConfig graph_config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString( + R"( + node { + calculator: "BlockingPassThroughCalculator" + input_stream: "input_numbers" + output_stream: "output_numbers" + input_side_packet: "blocking_mutex" + } + node { + calculator: "CallbackCalculator" + input_stream: "output_numbers" + input_side_packet: "CALLBACK:callback" + } + input_stream: "input_numbers" + num_threads: 2 + max_queue_size: 1 + )", + &graph_config)); + + absl::Mutex* mutex = new absl::Mutex(); + Packet mutex_side_packet = AdoptAsUniquePtr(mutex); + + CalculatorGraph graph(graph_config); + graph.SetGraphInputStreamAddMode( + CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL); + + // Start MediaPipe graph. + MEDIAPIPE_ASSERT_OK(graph.StartRun( + {{"callback", MakePacket>(std::bind( + &CalculatorGraphEventLoopTest::AddThreadSafeVectorSink, + this, std::placeholders::_1))}, + {"blocking_mutex", mutex_side_packet}})); + + constexpr int kNumInputPackets = 2; + constexpr int kMaxQueueSize = 1; + + // Lock the mutex so that the BlockingPassThroughCalculator cannot read any of + // these packets. + mutex->Lock(); + int fail_count = 0; + // Expect at least kNumInputPackets - kMaxQueueSize - 1 attempts to add + // packets to fail since the queue builds up. The -1 is because our throttling + // mechanism could be off by 1 at most due to the order of acquisition of + // locks. + for (int i = 0; i < kNumInputPackets; ++i) { + ::mediapipe::Status status = graph.AddPacketToInputStream( + "input_numbers", Adopt(new int(i)).At(Timestamp(i))); + if (!status.ok()) { + ++fail_count; + } + } + mutex->Unlock(); + + EXPECT_GE(fail_count, kNumInputPackets - kMaxQueueSize - 1); + // Don't wait but just close the input stream. + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input_numbers")); + // Wait properly via the API until the graph is done. 
+ MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +// Verify that "max_queue_size: -1" disables throttling of graph-input-streams. +TEST_F(CalculatorGraphEventLoopTest, ThrottlingDisabled) { + CalculatorGraphConfig graph_config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString( + R"( + node { + calculator: "BlockingPassThroughCalculator" + input_stream: "input_numbers" + output_stream: "output_numbers" + input_side_packet: "blocking_mutex" + } + input_stream: "input_numbers" + max_queue_size: -1 + )", + &graph_config)); + + absl::Mutex* mutex = new absl::Mutex(); + Packet mutex_side_packet = AdoptAsUniquePtr(mutex); + + CalculatorGraph graph(graph_config); + graph.SetGraphInputStreamAddMode( + CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL); + + // Start MediaPipe graph. + MEDIAPIPE_ASSERT_OK(graph.StartRun({{"blocking_mutex", mutex_side_packet}})); + + // Lock the mutex so that the BlockingPassThroughCalculator cannot read any + // of these packets. + mutex->Lock(); + for (int i = 0; i < 10; ++i) { + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "input_numbers", Adopt(new int(i)).At(Timestamp(i)))); + } + mutex->Unlock(); + MEDIAPIPE_EXPECT_OK(graph.CloseInputStream("input_numbers")); + MEDIAPIPE_EXPECT_OK(graph.WaitUntilDone()); +} + +// Verify that the graph input stream throttling code still works if we run the +// graph twice. 
+TEST_F(CalculatorGraphEventLoopTest, ThrottleGraphInputStreamTwice) { + CalculatorGraphConfig graph_config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString( + R"( + node { + calculator: "BlockingPassThroughCalculator" + input_stream: "input_numbers" + output_stream: "output_numbers" + input_side_packet: "blocking_mutex" + } + input_stream: "input_numbers" + max_queue_size: 1 + )", + &graph_config)); + + absl::Mutex* mutex = new absl::Mutex(); + Packet mutex_side_packet = AdoptAsUniquePtr(mutex); + + CalculatorGraph graph(graph_config); + graph.SetGraphInputStreamAddMode( + CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL); + + // Run the graph twice. + for (int i = 0; i < 2; ++i) { + // Start MediaPipe graph. + MEDIAPIPE_ASSERT_OK( + graph.StartRun({{"blocking_mutex", mutex_side_packet}})); + + // Lock the mutex so that the BlockingPassThroughCalculator cannot read any + // of these packets. + mutex->Lock(); + ::mediapipe::Status status = ::mediapipe::OkStatus(); + for (int i = 0; i < 10; ++i) { + status = graph.AddPacketToInputStream("input_numbers", + Adopt(new int(i)).At(Timestamp(i))); + if (!status.ok()) { + break; + } + } + mutex->Unlock(); + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnavailable); + EXPECT_THAT(status.message(), testing::HasSubstr("Graph is throttled.")); + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input_numbers")); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + } +} + +// Test WAIT_TILL_NOT_FULL mode (default mode) for graph input streams (by +// creating more packets than the queue will support). All packets sent to the +// graph should be processed. 
+TEST_F(CalculatorGraphEventLoopTest, WaitToAddPacketToInputStream) { + CalculatorGraphConfig graph_config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString( + R"( + node { + calculator: "PassThroughCalculator" + input_stream: "input_numbers" + output_stream: "output_numbers" + } + node { + calculator: "CallbackCalculator" + input_stream: "output_numbers" + input_side_packet: "CALLBACK:callback" + } + input_stream: "input_numbers" + num_threads: 2 + max_queue_size: 10 + )", + &graph_config)); + + // Start MediaPipe graph. + CalculatorGraph graph(graph_config); + MEDIAPIPE_ASSERT_OK(graph.StartRun( + {{"callback", MakePacket>(std::bind( + &CalculatorGraphEventLoopTest::AddThreadSafeVectorSink, + this, std::placeholders::_1))}})); + + constexpr int kNumInputPackets = 20; + // All of these packets should be accepted by the graph. + int fail_count = 0; + for (int i = 0; i < kNumInputPackets; ++i) { + ::mediapipe::Status status = graph.AddPacketToInputStream( + "input_numbers", Adopt(new int(i)).At(Timestamp(i))); + if (!status.ok()) { + ++fail_count; + } + } + + EXPECT_EQ(0, fail_count); + + // Don't wait but just close the input stream. + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input_numbers")); + // Wait properly via the API until the graph is done. + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + + absl::ReaderMutexLock lock(&output_packets_mutex_); + ASSERT_EQ(kNumInputPackets, output_packets_.size()); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_graph_stopping_test.cc b/mediapipe/framework/calculator_graph_stopping_test.cc new file mode 100644 index 000000000..37f9b3171 --- /dev/null +++ b/mediapipe/framework/calculator_graph_stopping_test.cc @@ -0,0 +1,383 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_graph.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/core_proto_inc.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/sink.h" +#include "mediapipe/framework/tool/status_util.h" + +namespace mediapipe {} + +namespace testing_ns { +using ::mediapipe::CalculatorBase; +using ::mediapipe::CalculatorContext; +using ::mediapipe::CalculatorContract; +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::GetFromUniquePtr; +using ::mediapipe::InputStreamShardSet; +using ::mediapipe::MakePacket; +using ::mediapipe::OutputStreamShardSet; +using ::mediapipe::Timestamp; +namespace proto_ns = ::mediapipe::proto_ns; +using ::mediapipe::CalculatorGraph; +using ::mediapipe::Packet; + +class InfiniteSequenceCalculator : public mediapipe::CalculatorBase { + public: + static ::mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) { + cc->Outputs().Tag("OUT").Set(); + cc->Outputs().Tag("EVENT").Set(); + return ::mediapipe::OkStatus(); + } + ::mediapipe::Status Open(CalculatorContext* cc) override { + 
cc->Outputs().Tag("EVENT").AddPacket(MakePacket(1).At(Timestamp(1))); + return ::mediapipe::OkStatus(); + } + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->Outputs().Tag("OUT").AddPacket( + MakePacket(count_).At(Timestamp(count_))); + count_++; + return ::mediapipe::OkStatus(); + } + ::mediapipe::Status Close(CalculatorContext* cc) override { + cc->Outputs().Tag("EVENT").AddPacket(MakePacket(2).At(Timestamp(2))); + return ::mediapipe::OkStatus(); + } + + private: + int count_ = 0; +}; +REGISTER_CALCULATOR(::testing_ns::InfiniteSequenceCalculator); + +class StoppingPassThroughCalculator : public mediapipe::CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { + cc->Inputs().Get("", i).SetAny(); + cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); + } + cc->Outputs().Tag("EVENT").Set(); + return ::mediapipe::OkStatus(); + } + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->Outputs().Tag("EVENT").AddPacket(MakePacket(1).At(Timestamp(1))); + return ::mediapipe::OkStatus(); + } + ::mediapipe::Status Process(CalculatorContext* cc) override { + for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { + if (!cc->Inputs().Get("", i).IsEmpty()) { + cc->Outputs().Get("", i).AddPacket(cc->Inputs().Get("", i).Value()); + } + } + return (++count_ <= max_count_) ? ::mediapipe::OkStatus() + : ::mediapipe::tool::StatusStop(); + } + ::mediapipe::Status Close(CalculatorContext* cc) override { + cc->Outputs().Tag("EVENT").AddPacket(MakePacket(2).At(Timestamp(2))); + return ::mediapipe::OkStatus(); + } + + private: + int count_ = 0; + int max_count_ = 10; +}; +REGISTER_CALCULATOR(::testing_ns::StoppingPassThroughCalculator); + +// A simple Semaphore for synchronizing test threads. 
+class AtomicSemaphore { + public: + AtomicSemaphore(int64_t supply) : supply_(supply) {} + void Acquire(int64_t amount) { + while (supply_.fetch_sub(amount) - amount < 0) { + Release(amount); + } + } + void Release(int64_t amount) { supply_ += amount; } + + private: + std::atomic supply_; +}; + +// A ProcessFunction that passes through all packets. +::mediapipe::Status DoProcess(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + for (int i = 0; i < inputs.NumEntries(); ++i) { + if (!inputs.Index(i).Value().IsEmpty()) { + outputs->Index(i).AddPacket(inputs.Index(i).Value()); + } + } + return ::mediapipe::OkStatus(); +} + +typedef std::function<::mediapipe::Status(const InputStreamShardSet&, + OutputStreamShardSet*)> + ProcessFunction; + +// A Calculator that delegates its Process function to a callback function. +class ProcessCallbackCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).SetAny(); + cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(0)); + } + cc->InputSidePackets().Index(0).Set>(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + callback_ = + *GetFromUniquePtr(cc->InputSidePackets().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + return callback_(cc->Inputs(), &(cc->Outputs())); + } + + private: + ProcessFunction callback_; +}; +REGISTER_CALCULATOR(::testing_ns::ProcessCallbackCalculator); + +// Tests CloseAllPacketSources. 
+TEST(CalculatorGraphStoppingTest, CloseAllPacketSources) { + CalculatorGraphConfig graph_config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( + max_queue_size: 5 + input_stream: 'input' + node { + calculator: 'InfiniteSequenceCalculator' + output_stream: 'OUT:count' + output_stream: 'EVENT:event' + } + node { + calculator: 'StoppingPassThroughCalculator' + input_stream: 'count' + input_stream: 'input' + output_stream: 'count_out' + output_stream: 'input_out' + output_stream: 'EVENT:event_out' + } + package: 'testing_ns' + )", + &graph_config)); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config, {})); + + // Observe output packets, and call CloseAllPacketSources after kNumPackets. + std::vector out_packets; + std::vector count_packets; + std::vector event_packets; + std::vector event_out_packets; + int kNumPackets = 8; + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( // + "input_out", [&](const Packet& packet) { + out_packets.push_back(packet); + if (out_packets.size() >= kNumPackets) { + MEDIAPIPE_EXPECT_OK(graph.CloseAllPacketSources()); + } + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( // + "count_out", [&](const Packet& packet) { + count_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( // + "event", [&](const Packet& packet) { + event_packets.push_back(packet.Get()); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( // + "event_out", [&](const Packet& packet) { + event_out_packets.push_back(packet.Get()); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + for (int i = 0; i < kNumPackets; ++i) { + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "input", MakePacket(i).At(Timestamp(i)))); + } + + // The graph run should complete with no error status. 
+ MEDIAPIPE_EXPECT_OK(graph.WaitUntilDone()); + EXPECT_EQ(kNumPackets, out_packets.size()); + EXPECT_LE(kNumPackets, count_packets.size()); + std::vector expected_events = {1, 2}; + EXPECT_EQ(event_packets, expected_events); + EXPECT_EQ(event_out_packets, expected_events); +} + +// Verify that deadlock due to throttling can be reported. +TEST(CalculatorGraphStoppingTest, DeadlockReporting) { + CalculatorGraphConfig config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( + input_stream: 'in_1' + input_stream: 'in_2' + max_queue_size: 2 + node { + calculator: 'ProcessCallbackCalculator' + input_stream: 'in_1' + input_stream: 'in_2' + output_stream: 'out_1' + output_stream: 'out_2' + input_side_packet: 'callback_1' + } + package: 'testing_ns' + report_deadlock: true + )", + &config)); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + graph.SetGraphInputStreamAddMode( + CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL); + std::vector out_packets; + MEDIAPIPE_ASSERT_OK( + graph.ObserveOutputStream("out_1", [&out_packets](const Packet& packet) { + out_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + + // Lambda that waits for a local semaphore. + AtomicSemaphore semaphore(0); + ProcessFunction callback_1 = [&semaphore](const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + semaphore.Acquire(1); + return DoProcess(inputs, outputs); + }; + + // Lambda that adds a packet to the calculator graph. + auto add_packet = [&graph](std::string s, int i) { + return graph.AddPacketToInputStream(s, MakePacket(i).At(Timestamp(i))); + }; + + // Start the graph. + MEDIAPIPE_ASSERT_OK(graph.StartRun({ + {"callback_1", AdoptAsUniquePtr(new auto(callback_1))}, + })); + + // Add 3 packets to "in_1" with no packets on "in_2". + // This causes throttling and deadlock with max_queue_size 2. 
+ semaphore.Release(3); + MEDIAPIPE_EXPECT_OK(add_packet("in_1", 1)); + MEDIAPIPE_EXPECT_OK(add_packet("in_1", 2)); + EXPECT_FALSE(add_packet("in_1", 3).ok()); + + ::mediapipe::Status status = graph.WaitUntilIdle(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnavailable); + EXPECT_THAT( + status.message(), + testing::HasSubstr("Detected a deadlock due to input throttling")); + + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + EXPECT_FALSE(graph.WaitUntilDone().ok()); + ASSERT_EQ(0, out_packets.size()); +} + +// Verify that input streams grow due to deadlock resolution. +TEST(CalculatorGraphStoppingTest, DeadlockResolution) { + CalculatorGraphConfig config; + ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( + input_stream: 'in_1' + input_stream: 'in_2' + max_queue_size: 2 + node { + calculator: 'ProcessCallbackCalculator' + input_stream: 'in_1' + input_stream: 'in_2' + output_stream: 'out_1' + output_stream: 'out_2' + input_side_packet: 'callback_1' + } + package: 'testing_ns' + )", + &config)); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + graph.SetGraphInputStreamAddMode( + CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL); + std::vector out_packets; + MEDIAPIPE_ASSERT_OK( + graph.ObserveOutputStream("out_1", [&out_packets](const Packet& packet) { + out_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + + // Lambda that waits for a local semaphore. + AtomicSemaphore semaphore(0); + ProcessFunction callback_1 = [&semaphore](const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + semaphore.Acquire(1); + return DoProcess(inputs, outputs); + }; + + // Lambda that adds a packet to the calculator graph. + auto add_packet = [&graph](std::string s, int i) { + return graph.AddPacketToInputStream(s, MakePacket(i).At(Timestamp(i))); + }; + + // Start the graph. 
+ MEDIAPIPE_ASSERT_OK(graph.StartRun({ + {"callback_1", AdoptAsUniquePtr(new auto(callback_1))}, + })); + + // Add 9 packets to "in_1" with no packets on "in_2". + // This grows the input stream "in_1" to max-queue-size 10. + semaphore.Release(9); + for (int i = 1; i <= 9; ++i) { + MEDIAPIPE_EXPECT_OK(add_packet("in_1", i)); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + } + + // Advance the timestamp-bound and flush "in_1". + semaphore.Release(1); + MEDIAPIPE_EXPECT_OK(add_packet("in_2", 30)); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + + // Fill up input stream "in_1", with the semaphore blocked and deadlock + // resolution disabled. + for (int i = 11; i < 23; ++i) { + MEDIAPIPE_EXPECT_OK(add_packet("in_1", i)); + } + + // Adding any more packets fails with error "Graph is throttled". + graph.SetGraphInputStreamAddMode( + CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL); + EXPECT_FALSE(add_packet("in_1", 23).ok()); + + // Allow the 12 blocked calls to "callback_1" to complete. + semaphore.Release(12); + + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_EQ(21, out_packets.size()); +} + +} // namespace testing_ns diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc new file mode 100644 index 000000000..2976a1fec --- /dev/null +++ b/mediapipe/framework/calculator_graph_test.cc @@ -0,0 +1,5618 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_graph.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/fixed_array.h" +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/time/clock.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/counter_factory.h" +#include "mediapipe/framework/executor.h" +#include "mediapipe/framework/input_stream_handler.h" +#include "mediapipe/framework/lifetime_tracker.h" +#include "mediapipe/framework/mediapipe_options.pb.h" +#include "mediapipe/framework/output_stream_poller.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/status_handler.h" +#include "mediapipe/framework/subgraph.h" +#include "mediapipe/framework/thread_pool_executor.h" +#include "mediapipe/framework/thread_pool_executor.pb.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/sink.h" +#include "mediapipe/framework/tool/status_util.h" +#include "mediapipe/framework/type_map.h" + +namespace mediapipe { + +namespace { + +// Pass packets through. Note that it calls SetOffset() in Process() +// instead of Open(). 
+class SetOffsetInProcessCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + // Input: arbitrary Packets. + // Output: copy of the input. + cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + cc->GetCounter("PassThrough")->Increment(); + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(SetOffsetInProcessCalculator); + +// A Calculator that outputs the square of its input packet (an int). +class SquareIntCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int value = cc->Inputs().Index(0).Value().Get(); + cc->Outputs().Index(0).Add(new int(value * value), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(SquareIntCalculator); + +// A Calculator that selects an output stream from "OUTPUT:0", "OUTPUT:1", ..., +// using the integer value (0, 1, ...) in the packet on the "SELECT" input +// stream, and passes the packet on the "INPUT" input stream to the selected +// output stream. +// +// This calculator is called "Timed" because it sets the next timestamp bound on +// the unselected outputs. 
+class DemuxTimedCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); + cc->Inputs().Tag("SELECT").Set(); + PacketType* data_input = &cc->Inputs().Tag("INPUT"); + data_input->SetAny(); + for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT"); + id < cc->Outputs().EndId("OUTPUT"); ++id) { + cc->Outputs().Get(id).SetSameAs(data_input); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + select_input_ = cc->Inputs().GetId("SELECT", 0); + data_input_ = cc->Inputs().GetId("INPUT", 0); + output_base_ = cc->Outputs().GetId("OUTPUT", 0); + num_outputs_ = cc->Outputs().NumEntries("OUTPUT"); + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int select = cc->Inputs().Get(select_input_).Get(); + RET_CHECK(0 <= select && select < num_outputs_); + const Timestamp next_timestamp_bound = + cc->InputTimestamp().NextAllowedInStream(); + for (int i = 0; i < num_outputs_; ++i) { + if (i == select) { + cc->Outputs() + .Get(output_base_ + i) + .AddPacket(cc->Inputs().Get(data_input_).Value()); + } else { + cc->Outputs() + .Get(output_base_ + i) + .SetNextTimestampBound(next_timestamp_bound); + } + } + return ::mediapipe::OkStatus(); + } + + private: + CollectionItemId select_input_; + CollectionItemId data_input_; + CollectionItemId output_base_; + int num_outputs_ = 0; +}; + +REGISTER_CALCULATOR(DemuxTimedCalculator); + +// A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ..., +// using the integer value (0, 1, ...) in the packet on the "SELECT" input +// stream, and passes the packet on the selected input stream to the "OUTPUT" +// output stream. +// +// This calculator is called "Timed" because it requires next timestamp bound +// propagation on the unselected inputs. 
+class MuxTimedCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag("SELECT").Set(); + CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT"); + PacketType* data_input0 = &cc->Inputs().Get(data_input_id); + data_input0->SetAny(); + ++data_input_id; + for (; data_input_id < cc->Inputs().EndId("INPUT"); ++data_input_id) { + cc->Inputs().Get(data_input_id).SetSameAs(data_input0); + } + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); + cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + select_input_ = cc->Inputs().GetId("SELECT", 0); + data_input_base_ = cc->Inputs().GetId("INPUT", 0); + num_data_inputs_ = cc->Inputs().NumEntries("INPUT"); + output_ = cc->Outputs().GetId("OUTPUT", 0); + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int select = cc->Inputs().Get(select_input_).Get(); + RET_CHECK(0 <= select && select < num_data_inputs_); + cc->Outputs().Get(output_).AddPacket( + cc->Inputs().Get(data_input_base_ + select).Value()); + return ::mediapipe::OkStatus(); + } + + private: + CollectionItemId select_input_; + CollectionItemId data_input_base_; + int num_data_inputs_ = 0; + CollectionItemId output_; +}; + +REGISTER_CALCULATOR(MuxTimedCalculator); + +// A Calculator that adds the integer values in the packets in all the input +// streams and outputs the sum to the output stream. 
+class IntAdderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).Set(); + } + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int sum = 0; + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + sum += cc->Inputs().Index(i).Get(); + } + cc->Outputs().Index(0).Add(new int(sum), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(IntAdderCalculator); + +// A Calculator that adds the float values in the packets in all the input +// streams and outputs the sum to the output stream. +class FloatAdderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).Set(); + } + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + float sum = 0.0; + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + sum += cc->Inputs().Index(i).Get(); + } + cc->Outputs().Index(0).Add(new float(sum), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(FloatAdderCalculator); + +// A Calculator that multiplies the integer values in the packets in all the +// input streams and outputs the product to the output stream. 
+class IntMultiplierCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).Set(); + } + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int product = 1; + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + product *= cc->Inputs().Index(i).Get(); + } + cc->Outputs().Index(0).Add(new int(product), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(IntMultiplierCalculator); + +// A Calculator that multiplies the float value in an input packet by a float +// constant scalar (specified in a side packet) and outputs the product to the +// output stream. +class FloatScalarMultiplierCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + cc->InputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + scalar_ = cc->InputSidePackets().Index(0).Get(); + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + float value = cc->Inputs().Index(0).Value().Get(); + cc->Outputs().Index(0).Add(new float(scalar_ * value), + cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } + + private: + float scalar_; +}; +REGISTER_CALCULATOR(FloatScalarMultiplierCalculator); + +// A Calculator that converts an integer to a float. 
+class IntToFloatCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int value = cc->Inputs().Index(0).Value().Get(); + cc->Outputs().Index(0).Add(new float(static_cast(value)), + cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(IntToFloatCalculator); + +// A Calculator that passes an input packet through if it contains an even +// integer. +class EvenIntFilterCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int value = cc->Inputs().Index(0).Get(); + if (value % 2 == 0) { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + } else { + cc->Outputs().Index(0).SetNextTimestampBound( + cc->InputTimestamp().NextAllowedInStream()); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(EvenIntFilterCalculator); + +// A Calculator that passes packets through or not, depending on a second +// input. The first input stream's packets are only propagated if the second +// input stream carries the value true. 
+class ValveCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Inputs().Index(1).Set(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + if (cc->Inputs().Index(1).Get()) { + cc->GetCounter("PassThrough")->Increment(); + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + } else { + cc->GetCounter("Block")->Increment(); + // The next timestamp bound is the minimum timestamp that the next packet + // can have, so, if we want to inform the downstream that no packet at + // InputTimestamp() is coming, we need to set it to the next value. + // We could also just call SetOffset(TimestampDiff(0)) in Open, and then + // we would not have to call this manually. + cc->Outputs().Index(0).SetNextTimestampBound( + cc->InputTimestamp().NextAllowedInStream()); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(ValveCalculator); + +// A Calculator that simply passes its input Packets and header through, +// but shifts the timestamp. +class TimeShiftCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + cc->InputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + // Input: arbitrary Packets. + // Output: copy of the input. 
+ cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); + shift_ = cc->InputSidePackets().Index(0).Get(); + cc->SetOffset(shift_); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->GetCounter("PassThrough")->Increment(); + cc->Outputs().Index(0).AddPacket( + cc->Inputs().Index(0).Value().At(cc->InputTimestamp() + shift_)); + return ::mediapipe::OkStatus(); + } + + private: + TimestampDiff shift_; +}; +REGISTER_CALCULATOR(TimeShiftCalculator); + +template +class TypedEmptySourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).SetAny(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->Outputs().Index(0).Add(new OutputType(), Timestamp::PostStream()); + return tool::StatusStop(); + } +}; +typedef TypedEmptySourceCalculator StringEmptySourceCalculator; +typedef TypedEmptySourceCalculator IntEmptySourceCalculator; +REGISTER_CALCULATOR(StringEmptySourceCalculator); +REGISTER_CALCULATOR(IntEmptySourceCalculator); + +template +class TypedSinkCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } +}; +typedef TypedSinkCalculator StringSinkCalculator; +typedef TypedSinkCalculator IntSinkCalculator; +REGISTER_CALCULATOR(StringSinkCalculator); +REGISTER_CALCULATOR(IntSinkCalculator); + +// Output kNumOutputPackets packets, the value of each being the next +// value in the counter provided as an input side packet. An optional +// second input side packet will, if true, cause this calculator to +// output the first of the kNumOutputPackets packets during Open. 
+class GlobalCountSourceCalculator : public CalculatorBase { + public: + static const int kNumOutputPackets; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets().Index(0).Set*>(); + if (cc->InputSidePackets().NumEntries() >= 2) { + cc->InputSidePackets().Index(1).Set(); + } + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + if (cc->InputSidePackets().NumEntries() >= 2 && + cc->InputSidePackets().Index(1).Get()) { + OutputOne(cc); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + OutputOne(cc); + if (local_count_ >= kNumOutputPackets) { + return tool::StatusStop(); + } else { + return ::mediapipe::OkStatus(); + } + } + + private: + void OutputOne(CalculatorContext* cc) { + std::atomic* counter = + cc->InputSidePackets().Index(0).Get*>(); + int count = counter->fetch_add(1, std::memory_order_relaxed); + cc->Outputs().Index(0).Add(new int(count), Timestamp(local_count_)); + ++local_count_; + } + + int64 local_count_ = 0; +}; +const int GlobalCountSourceCalculator::kNumOutputPackets = 5; +REGISTER_CALCULATOR(GlobalCountSourceCalculator); + +static const int kTestSequenceLength = 15; + +// Outputs the integers 0, 1, 2, 3, ..., 14, all with timestamp 0. 
+class TestSequence1SourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->Outputs().Index(0).Add(new int(count_), Timestamp(0)); + ++count_; + ++num_outputs_; + if (num_outputs_ >= kTestSequenceLength) { + return tool::StatusStop(); + } else { + return ::mediapipe::OkStatus(); + } + } + + private: + int count_ = 0; + int num_outputs_ = 0; +}; +REGISTER_CALCULATOR(TestSequence1SourceCalculator); + +// Outputs the integers 1, 2, 3, 4, ..., 15, with decreasing timestamps +// 100, 99, 98, 97, .... +class TestSequence2SourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->Outputs().Index(0).Add(new int(count_), Timestamp(timestamp_)); + ++count_; + ++num_outputs_; + --timestamp_; + if (num_outputs_ >= kTestSequenceLength) { + return tool::StatusStop(); + } else { + return ::mediapipe::OkStatus(); + } + } + + private: + int count_ = 1; + int num_outputs_ = 0; + int timestamp_ = 100; +}; +REGISTER_CALCULATOR(TestSequence2SourceCalculator); + +// Outputs the integers 0, 1, 2 repeatedly for a total of 15 outputs. 
+class Modulo3SourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->Outputs().Index(0).Add(new int(count_ % 3), Timestamp(count_ % 3)); + ++count_; + ++num_outputs_; + if (num_outputs_ >= kTestSequenceLength) { + return tool::StatusStop(); + } else { + return ::mediapipe::OkStatus(); + } + } + + private: + int count_ = 0; + int num_outputs_ = 0; +}; +REGISTER_CALCULATOR(Modulo3SourceCalculator); + +// A source calculator that outputs 100 packets all at once and stops. The +// number of output packets (100) is deliberately chosen to be equal to +// max_queue_size, which fills the input streams connected to this source +// calculator and causes the MediaPipe scheduler to throttle this source +// calculator. +class OutputAllSourceCalculator : public CalculatorBase { + public: + static const int kNumOutputPackets = 100; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + for (int i = 0; i < kNumOutputPackets; ++i) { + cc->Outputs().Index(0).Add(new int(0), Timestamp(i)); + } + return tool::StatusStop(); + } +}; +REGISTER_CALCULATOR(OutputAllSourceCalculator); + +// A source calculator that outputs one packet at a time. The total number of +// output packets needs to be large enough to eventually fill an input stream +// connected to this source calculator and to force the MediaPipe scheduler to +// run this source calculator as a throttled source when the graph cannot make +// progress. 
+class OutputOneAtATimeSourceCalculator : public CalculatorBase { + public: + static const int kNumOutputPackets = 1000; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + if (index_ < kNumOutputPackets) { + cc->Outputs().Index(0).Add(new int(0), Timestamp(index_)); + ++index_; + return ::mediapipe::OkStatus(); + } + return tool::StatusStop(); + } + + private: + int index_ = 0; +}; +REGISTER_CALCULATOR(OutputOneAtATimeSourceCalculator); + +// A calculator that passes through one out of every 101 input packets and +// discards the rest. The decimation ratio (101) is carefully chosen to be +// greater than max_queue_size (100) so that an input stream parallel to the +// input stream connected to this calculator can become full. +class DecimatorCalculator : public CalculatorBase { + public: + static const int kDecimationRatio = 101; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + if (index_ % kDecimationRatio == 0) { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + } + ++index_; + return ::mediapipe::OkStatus(); + } + + private: + int index_ = 0; +}; +REGISTER_CALCULATOR(DecimatorCalculator); + +// An error will be produced in Open() if ERROR_ON_OPEN is true. Otherwise, +// this calculator simply passes its input packets through, unchanged. 
+class ErrorOnOpenCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + if (cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get()) { + return ::mediapipe::NotFoundError("expected error"); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(ErrorOnOpenCalculator); + +// A calculator that outputs an initial packet of value 0 at time 0 in the +// Open() method, and then delays each input packet by one time unit in the +// Process() method. The input stream and output stream have the integer type. +class UnitDelayCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->Outputs().Index(0).Add(new int(0), Timestamp(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + const Packet& packet = cc->Inputs().Index(0).Value(); + cc->Outputs().Index(0).AddPacket( + packet.At(packet.Timestamp().NextAllowedInStream())); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(UnitDelayCalculator); + +class UnitDelayUntimedCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->Outputs().Index(0).Add(new 
int(0), Timestamp::Min()); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(UnitDelayUntimedCalculator); + +class FloatUnitDelayCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->Outputs().Index(0).Add(new float(0.0), Timestamp(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + const Packet& packet = cc->Inputs().Index(0).Value(); + cc->Outputs().Index(0).AddPacket( + packet.At(packet.Timestamp().NextAllowedInStream())); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(FloatUnitDelayCalculator); + +// A sink calculator that asserts its input stream is empty in Open() and +// discards input packets in Process(). +class AssertEmptyInputInOpenCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + RET_CHECK(cc->Inputs().Index(0).Value().IsEmpty()); + RET_CHECK_EQ(cc->Inputs().Index(0).Value().Timestamp(), Timestamp::Unset()); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(AssertEmptyInputInOpenCalculator); + +// A slow sink calculator that expects 10 input integers with the values +// 0, 1, ..., 9. 
+class SlowCountingSinkCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::SleepFor(absl::Milliseconds(10)); + int value = cc->Inputs().Index(0).Get(); + CHECK_EQ(value, counter_); + ++counter_; + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Close(CalculatorContext* cc) override { + CHECK_EQ(10, counter_); + return ::mediapipe::OkStatus(); + } + + private: + int counter_ = 0; +}; +REGISTER_CALCULATOR(SlowCountingSinkCalculator); + +template +class TypedStatusHandler : public StatusHandler { + public: + ~TypedStatusHandler() override = 0; + static ::mediapipe::Status FillExpectations( + const MediaPipeOptions& extendable_options, + PacketTypeSet* input_side_packets) { + input_side_packets->Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status HandlePreRunStatus( + const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, // + const ::mediapipe::Status& pre_run_status) { + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status HandleStatus( + const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, // + const ::mediapipe::Status& run_status) { + return ::mediapipe::OkStatus(); + } +}; +typedef TypedStatusHandler StringStatusHandler; +typedef TypedStatusHandler Uint32StatusHandler; +REGISTER_STATUS_HANDLER(StringStatusHandler); +REGISTER_STATUS_HANDLER(Uint32StatusHandler); + +// A std::string generator that will succeed. 
+class StaticCounterStringGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, + PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { + for (int i = 0; i < input_side_packets->NumEntries(); ++i) { + input_side_packets->Index(i).SetAny(); + } + output_side_packets->Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + output_side_packets->Index(0) = MakePacket("fixed_string"); + ++num_packets_generated_; + return ::mediapipe::OkStatus(); + } + + static int NumPacketsGenerated() { return num_packets_generated_; } + + private: + static int num_packets_generated_; +}; + +int StaticCounterStringGenerator::num_packets_generated_ = 0; + +REGISTER_PACKET_GENERATOR(StaticCounterStringGenerator); + +// A failing PacketGenerator and Calculator to verify that status handlers get +// called. Both claim to output strings but instead always fail. +class FailingPacketGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, + PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { + for (int i = 0; i < input_side_packets->NumEntries(); ++i) { + input_side_packets->Index(i).SetAny(); + } + output_side_packets->Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + return ::mediapipe::UnknownError("this always fails."); + } +}; +REGISTER_PACKET_GENERATOR(FailingPacketGenerator); + +// Passes the integer through if it is positive. 
+class EnsurePositivePacketGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, + PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { + for (int i = 0; i < input_side_packets->NumEntries(); ++i) { + input_side_packets->Index(i).Set(); + output_side_packets->Index(i).Set(); + } + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + for (int i = 0; i < input_side_packets.NumEntries(); ++i) { + if (input_side_packets.Index(i).Get() > 0) { + output_side_packets->Index(i) = input_side_packets.Index(i); + } else { + return ::mediapipe::UnknownError( + absl::StrCat("Integer ", i, " was not positive.")); + } + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_PACKET_GENERATOR(EnsurePositivePacketGenerator); + +// A Status handler which takes an int side packet and fails in pre run +// if that packet is FailableStatusHandler::kFailPreRun and fails post +// run if that value is FailableStatusHandler::kFailPostRun. If the +// int is any other value then no failures occur. 
+class FailableStatusHandler : public StatusHandler { + public: + enum { + kOk = 0, + kFailPreRun = 1, + kFailPostRun = 2, + }; + + static ::mediapipe::Status FillExpectations( + const MediaPipeOptions& extendable_options, + PacketTypeSet* input_side_packets) { + input_side_packets->Index(0).Set(); + return ::mediapipe::OkStatus(); + } + static ::mediapipe::Status HandlePreRunStatus( + const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, + const ::mediapipe::Status& pre_run_status) { + if (input_side_packets.Index(0).Get() == kFailPreRun) { + return ::mediapipe::UnknownError( + "FailableStatusHandler failing pre run as intended."); + } else { + return ::mediapipe::OkStatus(); + } + } + static ::mediapipe::Status HandleStatus( + const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, + const ::mediapipe::Status& run_status) { + if (input_side_packets.Index(0).Get() == kFailPostRun) { + return ::mediapipe::UnknownError( + "FailableStatusHandler failing post run as intended."); + } else { + return ::mediapipe::OkStatus(); + } + } +}; +REGISTER_STATUS_HANDLER(FailableStatusHandler); + +class FailingSourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + return ::mediapipe::UnknownError("this always fails."); + } +}; +REGISTER_CALCULATOR(FailingSourceCalculator); + +// A simple Semaphore for synchronizing test threads. 
+class AtomicSemaphore { + public: + AtomicSemaphore(int64_t supply) : supply_(supply) {} + void Acquire(int64_t amount) { + while (supply_.fetch_sub(amount) - amount < 0) { + Release(amount); + } + } + void Release(int64_t amount) { supply_ += amount; } + + private: + std::atomic supply_; +}; + +// This calculator posts to a semaphore when it starts its Process method, +// and waits on a different semaphore before returning from Process. +// This allows a test to run code when the calculator is running Process +// without having to depend on any specific timing. +class SemaphoreCalculator : public CalculatorBase { + public: + using Semaphore = AtomicSemaphore; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + cc->InputSidePackets().Tag("POST_SEM").Set(); + cc->InputSidePackets().Tag("WAIT_SEM").Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->InputSidePackets().Tag("POST_SEM").Get()->Release(1); + cc->InputSidePackets().Tag("WAIT_SEM").Get()->Acquire(1); + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(SemaphoreCalculator); + +// A calculator that has no input streams and output streams, runs only once, +// and takes 20 milliseconds to run. 
+class OneShot20MsCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::SleepFor(absl::Milliseconds(20)); + return tool::StatusStop(); + } +}; +REGISTER_CALCULATOR(OneShot20MsCalculator); + +// A source calculator that alternates between outputting an integer (0, 1, 2, +// ..., 100) and setting the next timestamp bound. The timestamps of the output +// packets and next timestamp bounds are 0, 10, 20, 30, ... +// +// T=0 Output 0 +// T=10 Set timestamp bound +// T=20 Output 1 +// T=30 Set timestamp bound +// ... +// T=2000 Output 100 +class OutputAndBoundSourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + counter_ = 0; + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + Timestamp timestamp(counter_); + if (counter_ % 20 == 0) { + cc->Outputs().Index(0).AddPacket( + MakePacket(counter_ / 20).At(timestamp)); + } else { + cc->Outputs().Index(0).SetNextTimestampBound(timestamp); + } + if (counter_ == 2000) { + return tool::StatusStop(); + } + counter_ += 10; + return ::mediapipe::OkStatus(); + } + + private: + int counter_; +}; +REGISTER_CALCULATOR(OutputAndBoundSourceCalculator); + +// A calculator that outputs an initial packet of value 0 at time 0 in the +// Open() method, and then delays each input packet by 20 time units in the +// Process() method. The input stream and output stream have the integer type. 
+class Delay20Calculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(20)); + cc->Outputs().Index(0).AddPacket(MakePacket(0).At(Timestamp(0))); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + const Packet& packet = cc->Inputs().Index(0).Value(); + Timestamp timestamp = packet.Timestamp() + 20; + cc->Outputs().Index(0).AddPacket(packet.At(timestamp)); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(Delay20Calculator); + +// A source calculator that outputs a packet containing the return value of +// pthread_self() (the pthread id of the current thread). +class PthreadSelfSourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->Outputs().Index(0).AddPacket( + MakePacket(pthread_self()).At(Timestamp(0))); + return tool::StatusStop(); + } +}; +REGISTER_CALCULATOR(PthreadSelfSourceCalculator); + +// A source calculator for testing the Calculator::InputTimestamp() method. +// It outputs five int packets with timestamps 0, 1, 2, 3, 4. +class CheckInputTimestampSourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source + // and non-source nodes. 
+ ::mediapipe::Status Open(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() always returns Timestamp(0) in Process() for source + // nodes. + ::mediapipe::Status Process(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(0)); + cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_)); + ++count_; + if (count_ >= 5) { + return tool::StatusStop(); + } else { + return ::mediapipe::OkStatus(); + } + } + + // InputTimestamp() returns Timestamp::Done() in Close() for both source + // and non-source nodes. + ::mediapipe::Status Close(CalculatorContext* cc) final { + // Must use CHECK instead of RET_CHECK in Close(), because the framework + // may call the Close() method of a source node with .IgnoreError(). + CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); + return ::mediapipe::OkStatus(); + } + + private: + int count_ = 0; +}; +REGISTER_CALCULATOR(CheckInputTimestampSourceCalculator); + +// A sink calculator for testing the Calculator::InputTimestamp() method. +// It expects to consume the output of a CheckInputTimestampSourceCalculator. +class CheckInputTimestampSinkCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source + // and non-source nodes. + ::mediapipe::Status Open(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() returns the timestamp of input packets in Process() for + // non-source nodes. 
+ ::mediapipe::Status Process(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), + cc->Inputs().Index(0).Value().Timestamp()); + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(count_)); + ++count_; + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() returns Timestamp::Done() in Close() for both source + // and non-source nodes. + ::mediapipe::Status Close(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); + return ::mediapipe::OkStatus(); + } + + private: + int count_ = 0; +}; +REGISTER_CALCULATOR(CheckInputTimestampSinkCalculator); + +// A source calculator for testing the Calculator::InputTimestamp() method. +// It outputs int packets with timestamps 0, 1, 2, ... until being closed by +// the framework. +class CheckInputTimestamp2SourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source + // and non-source nodes. + ::mediapipe::Status Open(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() always returns Timestamp(0) in Process() for source + // nodes. + ::mediapipe::Status Process(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(0)); + cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_)); + ++count_; + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() returns Timestamp::Done() in Close() for both source + // and non-source nodes. + ::mediapipe::Status Close(CalculatorContext* cc) final { + // Must use CHECK instead of RET_CHECK in Close(), because the framework + // may call the Close() method of a source node with .IgnoreError(). 
+ CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); + return ::mediapipe::OkStatus(); + } + + private: + int count_ = 0; +}; +REGISTER_CALCULATOR(CheckInputTimestamp2SourceCalculator); + +// A sink calculator for testing the Calculator::InputTimestamp() method. +// It expects to consume the output of a CheckInputTimestamp2SourceCalculator. +// It returns tool::StatusStop() after consuming five input packets. +class CheckInputTimestamp2SinkCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source + // and non-source nodes. + ::mediapipe::Status Open(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); + return ::mediapipe::OkStatus(); + } + + // InputTimestamp() returns the timestamp of input packets in Process() for + // non-source nodes. + ::mediapipe::Status Process(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), + cc->Inputs().Index(0).Value().Timestamp()); + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(count_)); + ++count_; + if (count_ >= 5) { + return tool::StatusStop(); + } else { + return ::mediapipe::OkStatus(); + } + } + + // InputTimestamp() returns Timestamp::Done() in Close() for both source + // and non-source nodes. + ::mediapipe::Status Close(CalculatorContext* cc) final { + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); + return ::mediapipe::OkStatus(); + } + + private: + int count_ = 0; +}; +REGISTER_CALCULATOR(CheckInputTimestamp2SinkCalculator); + +// Takes an input stream packet and passes it (with timestamp removed) as an +// output side packet. 
+class OutputSidePacketInProcessCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set( + cc->Inputs().Index(0).Value().At(Timestamp::Unset())); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator); + +// Takes an input stream packet and counts the number of the packets it +// receives. Outputs the total number of packets as a side packet in Close. +class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->OutputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + ++count_; + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Close(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set( + MakePacket(count_).At(Timestamp::Unset())); + return ::mediapipe::OkStatus(); + } + + int count_ = 0; +}; +REGISTER_CALCULATOR(CountAndOutputSummarySidePacketInCloseCalculator); + +// Takes an input stream packet and passes it (with timestamp intact) as an +// output side packet. This triggers an error in the graph. 
+class OutputSidePacketWithTimestampCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set(cc->Inputs().Index(0).Value()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(OutputSidePacketWithTimestampCalculator); + +// Generates an output side packet containing the integer 1. +class IntegerOutputSidePacketCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->OutputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set(MakePacket(1)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + LOG(FATAL) << "Not reached."; + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(IntegerOutputSidePacketCalculator); + +// Generates an output side packet containing the sum of the two integer input +// side packets. 
+class SidePacketAdderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets().Index(0).Set(); + cc->InputSidePackets().Index(1).Set(); + cc->OutputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set( + MakePacket(cc->InputSidePackets().Index(1).Get() + + cc->InputSidePackets().Index(0).Get())); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + LOG(FATAL) << "Not reached."; + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(SidePacketAdderCalculator); + +// Produces an output packet with the PostStream timestamp containing the +// input side packet. +class SidePacketToStreamPacketCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->Outputs().Index(0).AddPacket( + cc->InputSidePackets().Index(0).At(Timestamp::PostStream())); + cc->Outputs().Index(0).Close(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + return ::mediapipe::tool::StatusStop(); + } +}; +REGISTER_CALCULATOR(SidePacketToStreamPacketCalculator); + +// A calculator checks if either of two input streams contains a packet and +// sends the packet to the single output stream with the same timestamp. 
+class SimpleMuxCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); + cc->Inputs().Index(0).SetAny(); + cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0)); + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + data_input_base_ = cc->Inputs().BeginId(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int select_packet_index = -1; + if (!cc->Inputs().Index(0).IsEmpty()) { + select_packet_index = 0; + } else if (!cc->Inputs().Index(1).IsEmpty()) { + select_packet_index = 1; + } + if (select_packet_index != -1) { + cc->Outputs().Index(0).AddPacket( + cc->Inputs().Get(data_input_base_ + select_packet_index).Value()); + } + return ::mediapipe::OkStatus(); + } + + private: + CollectionItemId data_input_base_; +}; +REGISTER_CALCULATOR(SimpleMuxCalculator); + +// Mock status handler that reports the number of times HandleStatus was called +// by modifying the int in its input side packet. 
+class IncrementingStatusHandler : public StatusHandler { + public: + static ::mediapipe::Status FillExpectations( + const MediaPipeOptions& extendable_options, + PacketTypeSet* input_side_packets) { + input_side_packets->Tag("EXTRA").SetAny().Optional(); + input_side_packets->Tag("COUNTER1").Set>(); + input_side_packets->Tag("COUNTER2").Set>(); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status HandlePreRunStatus( + const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, // + const ::mediapipe::Status& pre_run_status) { + int* counter = GetFromUniquePtr(input_side_packets.Tag("COUNTER1")); + (*counter)++; + return pre_run_status_result_; + } + + static ::mediapipe::Status HandleStatus( + const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, // + const ::mediapipe::Status& run_status) { + int* counter = GetFromUniquePtr(input_side_packets.Tag("COUNTER2")); + (*counter)++; + return post_run_status_result_; + } + + static void SetPreRunStatusResult(const ::mediapipe::Status& status) { + pre_run_status_result_ = status; + } + + static void SetPostRunStatusResult(const ::mediapipe::Status& status) { + post_run_status_result_ = status; + } + + private: + // Return values of HandlePreRunStatus() and HandleSTatus(), respectively. + static ::mediapipe::Status pre_run_status_result_; + static ::mediapipe::Status post_run_status_result_; +}; + +::mediapipe::Status IncrementingStatusHandler::pre_run_status_result_ = + ::mediapipe::OkStatus(); +::mediapipe::Status IncrementingStatusHandler::post_run_status_result_ = + ::mediapipe::OkStatus(); + +REGISTER_STATUS_HANDLER(IncrementingStatusHandler); + +// A simple executor that runs tasks directly on the current thread. +// NOTE: If CurrentThreadExecutor is used, some CalculatorGraph methods may +// behave differently. For example, CalculatorGraph::StartRun will run the +// graph rather than returning immediately after starting the graph. 
+// Similarly, CalculatorGraph::AddPacketToInputStream will run the graph +// (until it's idle) rather than returning immediately after adding the packet +// to the graph input stream. +class CurrentThreadExecutor : public Executor { + public: + ~CurrentThreadExecutor() override { + CHECK(!executing_); + CHECK(tasks_.empty()); + } + + void Schedule(std::function task) override { + if (executing_) { + // Queue the task for later execution (after the currently-running task + // returns) rather than running the task immediately. This is especially + // important for a source node (which can be rescheduled immediately after + // running) to avoid an indefinitely-deep call stack. + tasks_.emplace_back(std::move(task)); + } else { + CHECK(tasks_.empty()); + executing_ = true; + task(); + while (!tasks_.empty()) { + task = tasks_.front(); + tasks_.pop_front(); + task(); + } + executing_ = false; + } + } + + private: + // True if the executor is executing tasks. + bool executing_ = false; + // The tasks to execute. + std::deque> tasks_; +}; + +// Returns a CalculatorGraphConfig used by tests. +CalculatorGraphConfig GetConfig() { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + # The graph configuration. We list the nodes in an arbitrary (not + # topologically-sorted) order to verify that CalculatorGraph can + # handle such configurations. 
+ node { + calculator: "RangeCalculator" + output_stream: "range3" + output_stream: "range3_sum" + output_stream: "range3_mean" + input_side_packet: "node_3_converted" + } + node { + calculator: "RangeCalculator" + output_stream: "range5" + output_stream: "range5_sum" + output_stream: "range5_mean" + input_side_packet: "node_5_converted" + } + node { + calculator: "MergeCalculator" + input_stream: "range3" + input_stream: "range5_copy" + output_stream: "merge" + } + node { + calculator: "MergeCalculator" + input_stream: "range3_sum" + input_stream: "range5_sum" + output_stream: "merge_sum" + } + node { + calculator: "PassThroughCalculator" + input_stream: "range3_stddev" + input_stream: "range5_stddev" + output_stream: "range3_stddev_2" + output_stream: "range5_stddev_2" + } + node { + calculator: "PassThroughCalculator" + input_stream: "A:range3_stddev_2" + input_stream: "range5_stddev_2" + output_stream: "A:range3_stddev_3" + output_stream: "range5_stddev_3" + } + node { + calculator: "PassThroughCalculator" + input_stream: "B:range3_stddev_3" + input_stream: "B:1:range5_stddev_3" + output_stream: "B:range3_stddev_4" + output_stream: "B:1:range5_stddev_4" + } + node { + calculator: "MergeCalculator" + input_stream: "range3_stddev_4" + input_stream: "range5_stddev_4" + output_stream: "merge_stddev" + } + node { + calculator: "StdDevCalculator" + input_stream: "DATA:range3" + input_stream: "MEAN:range3_mean" + output_stream: "range3_stddev" + } + node { + calculator: "StdDevCalculator" + input_stream: "DATA:range5" + input_stream: "MEAN:range5_mean" + output_stream: "range5_stddev" + } + node { + name: "copy_range5" + calculator: "PassThroughCalculator" + input_stream: "range5" + output_stream: "range5_copy" + } + node { + calculator: "SaverCalculator" + input_stream: "merge" + output_stream: "final" + } + node { + calculator: "SaverCalculator" + input_stream: "merge_sum" + output_stream: "final_sum" + } + node { + calculator: "SaverCalculator" + input_stream: 
"merge_stddev" + output_stream: "final_stddev" + } + packet_generator { + packet_generator: "IntSplitterPacketGenerator" + input_side_packet: "node_3" + output_side_packet: "node_3_converted" + } + packet_generator { + packet_generator: "TaggedIntSplitterPacketGenerator" + input_side_packet: "node_5" + output_side_packet: "HIGH:unused_high" + output_side_packet: "LOW:unused_low" + output_side_packet: "PAIR:node_5_converted" + } + )"); + return config; +} + +// |graph| points to an empty CalculatorGraph object created by the default +// constructor, before CalculatorGraph::Initialize() is called. +void RunComprehensiveTest(CalculatorGraph* graph, + const CalculatorGraphConfig& the_config, + bool define_node_5) { + CalculatorGraphConfig proto(the_config); + Packet dumped_final_sum_packet; + Packet dumped_final_packet; + Packet dumped_final_stddev_packet; + tool::AddPostStreamPacketSink("final", &proto, &dumped_final_packet); + tool::AddPostStreamPacketSink("final_sum", &proto, &dumped_final_sum_packet); + tool::AddPostStreamPacketSink("final_stddev", &proto, + &dumped_final_stddev_packet); + MEDIAPIPE_ASSERT_OK(graph->Initialize(proto)); + + std::map extra_side_packets; + extra_side_packets.emplace("node_3", Adopt(new uint64((15LL << 32) | 3))); + if (define_node_5) { + extra_side_packets.emplace("node_5", Adopt(new uint64((15LL << 32) | 5))); + } + + // Call graph->Run() several times, to make sure that the appropriate + // cleanup happens between iterations. + for (int iteration = 0; iteration < 2; ++iteration) { + LOG(INFO) << "Loop iteration " << iteration; + dumped_final_sum_packet = Packet(); + dumped_final_stddev_packet = Packet(); + dumped_final_packet = Packet(); + MEDIAPIPE_ASSERT_OK(graph->Run(extra_side_packets)); + // The merger will output the timestamp and all ints output from + // the range calculators. The saver will concatenate together the + // strings with a '/' deliminator. 
+ EXPECT_EQ( + "Timestamp(0) 300 500/" + "Timestamp(3) 301 empty/" + "Timestamp(5) empty 501/" + "Timestamp(6) 302 empty/" + "Timestamp(9) 303 empty/" + "Timestamp(10) empty 502/" + "Timestamp(12) 304 empty/" + "Timestamp(15) 305 503", + dumped_final_packet.Get()); + // Verify that the headers got set correctly. + EXPECT_EQ( + "RangeCalculator3 RangeCalculator5", + graph->FindOutputStreamManager("merge")->Header().Get()); + // Verify that sum packets get correctly processed. + // (The first is a sum of all the 3's output and the second of all + // the 5's). + EXPECT_EQ(absl::StrCat(Timestamp::PostStream().DebugString(), " 1815 2006"), + dumped_final_sum_packet.Get()); + EXPECT_EQ(4 * (iteration + 1), graph->GetCounterFactory() + ->GetCounter("copy_range5-PassThrough") + ->Get()); + // Verify that stddev packets get correctly processed. + // The standard deviation computed as: + // sqrt(sum((x-mean(x))**2 / length(x))) + // for x = 300:305 is 1.707825 (multiplied by 100 and rounded it is 171) + // for x = 500:503 is 1.118034 (multiplied by 100 and rounded it is 112) + EXPECT_EQ(absl::StrCat(Timestamp::PostStream().DebugString(), " 171 112"), + dumped_final_stddev_packet.Get()); + + EXPECT_EQ(4 * (iteration + 1), graph->GetCounterFactory() + ->GetCounter("copy_range5-PassThrough") + ->Get()); + } + LOG(INFO) << "After Loop Runs."; + // Verify that the graph can still run (but not successfully) when + // one of the nodes is caused to fail. 
+ extra_side_packets.clear(); + extra_side_packets.emplace("node_3", Adopt(new uint64((15LL << 32) | 0))); + if (define_node_5) { + extra_side_packets.emplace("node_5", Adopt(new uint64((15LL << 32) | 5))); + } + dumped_final_sum_packet = Packet(); + dumped_final_stddev_packet = Packet(); + dumped_final_packet = Packet(); + LOG(INFO) << "Expect an error to be logged here."; + ASSERT_FALSE(graph->Run(extra_side_packets).ok()); + LOG(INFO) << "Error should have been logged."; +} + +TEST(CalculatorGraph, BadInitialization) { + CalculatorGraphConfig proto = GetConfig(); + CalculatorGraph graph; + // Force the config to have a missing Calculator. + proto.mutable_node(1)->clear_calculator(); + ASSERT_FALSE(graph.Initialize(proto).ok()); +} + +TEST(CalculatorGraph, BadRun) { + CalculatorGraphConfig proto = GetConfig(); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(proto)); + // Don't set the input side packets. + EXPECT_FALSE(graph.Run().ok()); +} + +TEST(CalculatorGraph, RunsCorrectly) { + CalculatorGraph graph; + CalculatorGraphConfig proto = GetConfig(); + RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); +} + +TEST(CalculatorGraph, RunsCorrectlyOnApplicationThread) { + CalculatorGraph graph; + CalculatorGraphConfig proto = GetConfig(); + // Force application thread to be used. + proto.set_num_threads(0); + RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); +} + +TEST(CalculatorGraph, RunsCorrectlyWithExternalExecutor) { + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.SetExecutor("", std::make_shared(1))); + CalculatorGraphConfig proto = GetConfig(); + RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); +} + +// This test verifies that the MediaPipe framework calls Executor::AddTask() +// without holding any mutex, because CurrentThreadExecutor::AddTask() may +// result in a recursive call to itself. 
+TEST(CalculatorGraph, RunsCorrectlyWithCurrentThreadExecutor) { + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.SetExecutor("", std::make_shared())); + CalculatorGraphConfig proto = GetConfig(); + RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); +} + +TEST(CalculatorGraph, RunsCorrectlyWithNonDefaultExecutors) { + CalculatorGraph graph; + // Add executors "second" and "third". + MEDIAPIPE_ASSERT_OK( + graph.SetExecutor("second", std::make_shared(1))); + MEDIAPIPE_ASSERT_OK( + graph.SetExecutor("third", std::make_shared(1))); + CalculatorGraphConfig proto = GetConfig(); + ExecutorConfig* executor = proto.add_executor(); + executor->set_name("second"); + executor = proto.add_executor(); + executor->set_name("third"); + for (int i = 0; i < proto.node_size(); ++i) { + switch (i % 3) { + case 0: + // Use the default executor. + break; + case 1: + proto.mutable_node(i)->set_executor("second"); + break; + case 2: + proto.mutable_node(i)->set_executor("third"); + break; + } + } + RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); +} + +TEST(CalculatorGraph, RunsCorrectlyWithMultipleExecutors) { + CalculatorGraph graph; + // Add executors "second" and "third". + CalculatorGraphConfig proto = GetConfig(); + ExecutorConfig* executor = proto.add_executor(); + executor->set_name("second"); + executor->set_type("ThreadPoolExecutor"); + MediaPipeOptions* options = executor->mutable_options(); + ThreadPoolExecutorOptions* extension = + options->MutableExtension(ThreadPoolExecutorOptions::ext); + extension->set_num_threads(1); + executor = proto.add_executor(); + executor->set_name("third"); + executor->set_type("ThreadPoolExecutor"); + options = executor->mutable_options(); + extension = options->MutableExtension(ThreadPoolExecutorOptions::ext); + extension->set_num_threads(1); + for (int i = 0; i < proto.node_size(); ++i) { + switch (i % 3) { + case 0: + // Use the default executor. 
+ break; + case 1: + proto.mutable_node(i)->set_executor("second"); + break; + case 2: + proto.mutable_node(i)->set_executor("third"); + break; + } + } + RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); +} + +// Packet generator for an arbitrary unit64 packet. +class Uint64PacketGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, + PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { + output_side_packets->Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5)); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_PACKET_GENERATOR(Uint64PacketGenerator); + +TEST(CalculatorGraph, GeneratePacket) { + CalculatorGraph graph; + CalculatorGraphConfig proto = GetConfig(); + PacketGeneratorConfig* generator = proto.add_packet_generator(); + generator->set_packet_generator("Uint64PacketGenerator"); + generator->add_output_side_packet("node_5"); + RunComprehensiveTest(&graph, proto, false); +} + +TEST(CalculatorGraph, TypeMismatch) { + CalculatorGraphConfig config; + CalculatorGraphConfig::Node* node = config.add_node(); + node->add_output_stream("stream_a"); + node = config.add_node(); + node->add_input_stream("stream_a"); + std::unique_ptr graph; + + // Type matches, expect success. + config.mutable_node(0)->set_calculator("StringEmptySourceCalculator"); + config.mutable_node(1)->set_calculator("StringSinkCalculator"); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + MEDIAPIPE_EXPECT_OK(graph->Run()); + graph.reset(nullptr); + + // Type matches, expect success. 
+ config.mutable_node(0)->set_calculator("IntEmptySourceCalculator"); + config.mutable_node(1)->set_calculator("IntSinkCalculator"); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + MEDIAPIPE_EXPECT_OK(graph->Run()); + graph.reset(nullptr); + + // Type mismatch, expect non-crashing failure. + config.mutable_node(0)->set_calculator("StringEmptySourceCalculator"); + config.mutable_node(1)->set_calculator("IntSinkCalculator"); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + EXPECT_FALSE(graph->Run().ok()); + graph.reset(nullptr); + + // Type mismatch, expect non-crashing failure. + config.mutable_node(0)->set_calculator("IntEmptySourceCalculator"); + config.mutable_node(1)->set_calculator("StringSinkCalculator"); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + EXPECT_FALSE(graph->Run().ok()); + graph.reset(nullptr); +} + +TEST(CalculatorGraph, LayerOrdering) { + CalculatorGraphConfig config; + CalculatorGraphConfig::Node* node; + node = config.add_node(); + node->set_calculator("GlobalCountSourceCalculator"); + node->add_input_side_packet("global_counter"); + node->add_output_stream("count_layer_0_node_0"); + node->set_source_layer(0); + node = config.add_node(); + node->set_calculator("GlobalCountSourceCalculator"); + node->add_input_side_packet("global_counter"); + node->add_output_stream("count_layer_1_node_0"); + node->set_source_layer(1); + node = config.add_node(); + node->set_calculator("GlobalCountSourceCalculator"); + node->add_input_side_packet("global_counter"); + node->add_output_stream("count_layer_1_node_1"); + node->set_source_layer(1); + node = config.add_node(); + node->set_calculator("GlobalCountSourceCalculator"); + node->add_input_side_packet("global_counter"); + node->add_output_stream("count_layer_2_node_0"); + node->set_source_layer(2); + + // Set num threads to 1 because we rely on sequential execution for this test. 
+ config.set_num_threads(1); + + std::vector dump_layer_0_node_0; + std::vector dump_layer_1_node_0; + std::vector dump_layer_1_node_1; + std::vector dump_layer_2_node_0; + tool::AddVectorSink("count_layer_0_node_0", &config, &dump_layer_0_node_0); + tool::AddVectorSink("count_layer_1_node_0", &config, &dump_layer_1_node_0); + tool::AddVectorSink("count_layer_1_node_1", &config, &dump_layer_1_node_1); + tool::AddVectorSink("count_layer_2_node_0", &config, &dump_layer_2_node_0); + + auto graph = absl::make_unique(); + + std::atomic global_counter(0); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph->Run(input_side_packets)); + graph.reset(nullptr); + + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, + dump_layer_0_node_0.size()); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, + dump_layer_1_node_0.size()); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, + dump_layer_1_node_1.size()); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, + dump_layer_2_node_0.size()); + + // Check layer 0. + for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { + EXPECT_EQ(i, dump_layer_0_node_0[i].Get()); + EXPECT_EQ(Timestamp(i), dump_layer_0_node_0[i].Timestamp()); + } + // Check layer 1 is interleaved (arbitrarily). 
+ for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { + EXPECT_TRUE(GlobalCountSourceCalculator::kNumOutputPackets + i * 2 == + dump_layer_1_node_0[i].Get() || + GlobalCountSourceCalculator::kNumOutputPackets + i * 2 + 1 == + dump_layer_1_node_0[i].Get()); + EXPECT_TRUE(GlobalCountSourceCalculator::kNumOutputPackets + i * 2 == + dump_layer_1_node_1[i].Get() || + GlobalCountSourceCalculator::kNumOutputPackets + i * 2 + 1 == + dump_layer_1_node_1[i].Get()); + EXPECT_EQ(Timestamp(i), dump_layer_1_node_0[i].Timestamp()); + EXPECT_EQ(Timestamp(i), dump_layer_1_node_1[i].Timestamp()); + } + // Check layer 2. + for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { + EXPECT_EQ(3 * GlobalCountSourceCalculator::kNumOutputPackets + i, + dump_layer_2_node_0[i].Get()); + EXPECT_EQ(Timestamp(i), dump_layer_2_node_0[i].Timestamp()); + } + + EXPECT_EQ( + 20, + input_side_packets["global_counter"].Get*>()->load()); +} + +// Tests for status handler input verification. +TEST(CalculatorGraph, StatusHandlerInputVerification) { + // Status handlers with all inputs present should be OK. 
+ auto graph = absl::make_unique(); + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + packet_generator { + packet_generator: "StaticCounterStringGenerator" + output_side_packet: "created_by_factory" + } + packet_generator { + packet_generator: "TaggedIntSplitterPacketGenerator" + input_side_packet: "a_uint64" + output_side_packet: "HIGH:generated_by_generator" + output_side_packet: "LOW:unused_low" + output_side_packet: "PAIR:unused_pair" + } + status_handler { + status_handler: "Uint32StatusHandler" + input_side_packet: "generated_by_generator" + } + status_handler { + status_handler: "StringStatusHandler" + input_side_packet: "created_by_factory" + } + status_handler { + status_handler: "StringStatusHandler" + input_side_packet: "extra_string" + } + )"); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + Packet extra_string = Adopt(new std::string("foo")); + Packet a_uint64 = Adopt(new uint64(0)); + MEDIAPIPE_EXPECT_OK( + graph->Run({{"extra_string", extra_string}, {"a_uint64", a_uint64}})); + + // Should fail verification when missing a required input. The generator is + // OK, but the StringStatusHandler is missing its input. + EXPECT_FALSE(graph->Run({{"a_uint64", a_uint64}}).ok()); + + // Should fail verification when the type of an already created packet is + // wrong. Here we give the uint64 packet instead of the std::string packet to + // the StringStatusHandler. + EXPECT_FALSE( + graph->Run({{"extra_string", a_uint64}, {"a_uint64", a_uint64}}).ok()); + + // Should fail verification when the type of a packet generated by a base + // packet factory is wrong. Everything is correct except we add a status + // handler expecting a uint32 but give it the std::string from the packet + // factory. 
+ auto* invalid_handler = config.add_status_handler(); + invalid_handler->set_status_handler("Uint32StatusHandler"); + invalid_handler->add_input_side_packet("created_by_factory"); + graph.reset(new CalculatorGraph()); + ::mediapipe::Status status = graph->Initialize(config); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("Uint32StatusHandler"), + // The problematic input side packet. + testing::HasSubstr("created_by_factory"), + // Actual type. + testing::HasSubstr("string"), + // Expected type. + testing::HasSubstr( + MediaPipeTypeStringOrDemangled()))); + + // Should fail verification when the type of a to-be-generated packet is + // wrong. The added handler now expects a std::string but will receive the + // uint32 generated by the existing generator. + invalid_handler->set_status_handler("StringStatusHandler"); + invalid_handler->set_input_side_packet(0, "generated_by_generator"); + graph.reset(new CalculatorGraph()); + // This is caught earlier, when the type of the PacketGenerator output + // is compared to the type of the StatusHandler input. + + status = graph->Initialize(config); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("StringStatusHandler"), + // The problematic input side packet. + testing::HasSubstr("generated_by_generator"), + // Actual type. + testing::HasSubstr("string"), + // Expected type. 
+ testing::HasSubstr( + MediaPipeTypeStringOrDemangled()))); +} + +TEST(CalculatorGraph, GenerateInInitialize) { + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + packet_generator { + packet_generator: "StaticCounterStringGenerator" + input_side_packet: "created_by_factory" + input_side_packet: "input_in_initialize" + output_side_packet: "foo1" + } + packet_generator { + packet_generator: "StaticCounterStringGenerator" + input_side_packet: "created_by_factory" + input_side_packet: "input_in_initialize" + input_side_packet: "foo1" + output_side_packet: "foo2" + } + packet_generator { + packet_generator: "StaticCounterStringGenerator" + input_side_packet: "created_by_factory" + input_side_packet: "input_in_initialize" + input_side_packet: "input_in_run" + output_side_packet: "foo3" + } + packet_generator { + packet_generator: "StaticCounterStringGenerator" + input_side_packet: "created_by_factory" + input_side_packet: "input_in_initialize" + input_side_packet: "input_in_run" + input_side_packet: "foo3" + output_side_packet: "foo4" + } + )"); + int initial_count = StaticCounterStringGenerator::NumPacketsGenerated(); + MEDIAPIPE_ASSERT_OK(graph.Initialize( + config, + {{"created_by_factory", MakePacket("default string")}, + {"input_in_initialize", MakePacket(10)}})); + EXPECT_EQ(initial_count + 2, + StaticCounterStringGenerator::NumPacketsGenerated()); + MEDIAPIPE_ASSERT_OK(graph.Run({{"input_in_run", MakePacket(11)}})); + EXPECT_EQ(initial_count + 4, + StaticCounterStringGenerator::NumPacketsGenerated()); + MEDIAPIPE_ASSERT_OK(graph.Run({{"input_in_run", MakePacket(12)}})); + EXPECT_EQ(initial_count + 6, + StaticCounterStringGenerator::NumPacketsGenerated()); +} + +// Resets the counters in the input side packets used in the HandlersRun test. +// The value of all these counters will be set to the integer zero, as required +// at the beginning of the test. 
+void ResetCounters(std::map* input_side_packets) { + (*input_side_packets)["no_input_counter1"] = AdoptAsUniquePtr(new int(0)); + (*input_side_packets)["no_input_counter2"] = AdoptAsUniquePtr(new int(0)); + (*input_side_packets)["available_input_counter1"] = + AdoptAsUniquePtr(new int(0)); + (*input_side_packets)["available_input_counter2"] = + AdoptAsUniquePtr(new int(0)); + (*input_side_packets)["unavailable_input_counter1"] = + AdoptAsUniquePtr(new int(0)); + (*input_side_packets)["unavailable_input_counter2"] = + AdoptAsUniquePtr(new int(0)); +} + +// Tests that status handlers run. +// - We specify three status handlers: one taking no input side packets, one +// taking +// an input side packet that is always provided in the call to Run(), and one +// that takes the input side packet that will not be produced by the +// FailingPacketGenerator. The first two should proccess their PRE-RUN status +// but not their POST-RUN status, the third one should not process either of +// them since the graph execution fails before the PRE-RUN step. +// - We then replace the FailingPacketGenerator with a non-failing generator, +// and should have all three handlers running both PRE and POST-RUN (after the +// FailingSourceCalculator fails). +// - We test that all three status handlers run (with both status) at the end of +// a successful graph run. +// - Finally, we verify that when the status handler fails (either on PRE or +// POST run), but the calculators don't, we still receive errors from the +// calculator run. 
+TEST(CalculatorGraph, HandlersRun) { + std::unique_ptr graph(new CalculatorGraph()); + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + packet_generator { + packet_generator: "FailingPacketGenerator" + output_side_packet: "unavailable" + } + node { calculator: "FailingSourceCalculator" output_stream: "output" } + status_handler { + status_handler: "IncrementingStatusHandler" + input_side_packet: "COUNTER1:no_input_counter1" + input_side_packet: "COUNTER2:no_input_counter2" + } + status_handler { + status_handler: "IncrementingStatusHandler" + input_side_packet: "COUNTER1:available_input_counter1" + input_side_packet: "COUNTER2:available_input_counter2" + input_side_packet: "EXTRA:available_string" + } + status_handler { + status_handler: "IncrementingStatusHandler" + input_side_packet: "COUNTER1:unavailable_input_counter1" + input_side_packet: "COUNTER2:unavailable_input_counter2" + input_side_packet: "EXTRA:unavailable" + } + )"); + std::map input_side_packets( + {{"unused_input", AdoptAsUniquePtr(new int(0))}, + {"no_input_counter1", AdoptAsUniquePtr(new int(0))}, + {"no_input_counter2", AdoptAsUniquePtr(new int(0))}, + {"available_input_counter1", AdoptAsUniquePtr(new int(0))}, + {"available_input_counter2", AdoptAsUniquePtr(new int(0))}, + {"unavailable_input_counter1", AdoptAsUniquePtr(new int(0))}, + {"unavailable_input_counter2", AdoptAsUniquePtr(new int(0))}, + {"available_string", Adopt(new std::string("foo"))}}); + + // When the graph fails in initialize (even because of a PacketGenerator + // returning an error), status handlers should not be run. 
+ ASSERT_THAT(graph->Initialize(config).ToString(), + testing::HasSubstr("FailingPacketGenerator")); + EXPECT_EQ(0, + *GetFromUniquePtr(input_side_packets.at("no_input_counter1"))); + EXPECT_EQ(0, + *GetFromUniquePtr(input_side_packets.at("no_input_counter2"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("available_input_counter1"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("available_input_counter2"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter1"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter2"))); + + // Add an input side packet to the packet generator so that it doesn't + // run at initialize time. + config.mutable_packet_generator(0)->add_input_side_packet("unused_input"); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + ResetCounters(&input_side_packets); + EXPECT_THAT(graph->Run(input_side_packets).ToString(), + testing::HasSubstr("FailingPacketGenerator")); + EXPECT_EQ(1, + *GetFromUniquePtr(input_side_packets.at("no_input_counter1"))); + EXPECT_EQ(0, + *GetFromUniquePtr(input_side_packets.at("no_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("available_input_counter1"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("available_input_counter2"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter1"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter2"))); + + // Replace the failing packet generator with something that works. All three + // status handlers should now process both the PRE and POST-RUN status. 
+ config.mutable_packet_generator(0)->set_packet_generator( + "StaticCounterStringGenerator"); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + ResetCounters(&input_side_packets); + // The entire graph should still fail because of the FailingSourceCalculator. + EXPECT_THAT(graph->Run(input_side_packets).ToString(), + testing::HasSubstr("FailingSourceCalculator")); + EXPECT_EQ(1, + *GetFromUniquePtr(input_side_packets.at("no_input_counter1"))); + EXPECT_EQ(1, + *GetFromUniquePtr(input_side_packets.at("no_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("available_input_counter1"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("available_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter1"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter2"))); + + // Replace the failing calculator with something that works. All three + // status handlers should still process both PRE and POST-RUN status as part + // of the successful graph run. + config.mutable_node(0)->set_calculator("StringEmptySourceCalculator"); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + ResetCounters(&input_side_packets); + MEDIAPIPE_EXPECT_OK(graph->Run(input_side_packets)); + EXPECT_EQ(1, + *GetFromUniquePtr(input_side_packets.at("no_input_counter1"))); + EXPECT_EQ(1, + *GetFromUniquePtr(input_side_packets.at("no_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("available_input_counter1"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("available_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter1"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter2"))); + + ::mediapipe::Status run_status; + + // Make status handlers fail. The graph should fail. 
+ // First, when the PRE_run fails + IncrementingStatusHandler::SetPreRunStatusResult( + ::mediapipe::InternalError("Fail at pre-run")); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + ResetCounters(&input_side_packets); + run_status = graph->Run(input_side_packets); + EXPECT_TRUE(run_status.code() == ::mediapipe::StatusCode::kInternal); + EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Fail at pre-run")); + EXPECT_EQ(1, + *GetFromUniquePtr(input_side_packets.at("no_input_counter1"))); + EXPECT_EQ(0, + *GetFromUniquePtr(input_side_packets.at("no_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("available_input_counter1"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("available_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter1"))); + EXPECT_EQ(0, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter2"))); + + // Second, when the POST_run fails + IncrementingStatusHandler::SetPreRunStatusResult(::mediapipe::OkStatus()); + IncrementingStatusHandler::SetPostRunStatusResult( + ::mediapipe::InternalError("Fail at post-run")); + graph.reset(new CalculatorGraph()); + MEDIAPIPE_ASSERT_OK(graph->Initialize(config)); + ResetCounters(&input_side_packets); + run_status = graph->Run(input_side_packets); + EXPECT_TRUE(run_status.code() == ::mediapipe::StatusCode::kInternal); + EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Fail at post-run")); + EXPECT_EQ(1, + *GetFromUniquePtr(input_side_packets.at("no_input_counter1"))); + EXPECT_EQ(1, + *GetFromUniquePtr(input_side_packets.at("no_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("available_input_counter1"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("available_input_counter2"))); + EXPECT_EQ(1, *GetFromUniquePtr( + input_side_packets.at("unavailable_input_counter1"))); + EXPECT_EQ(1, *GetFromUniquePtr( + 
input_side_packets.at("unavailable_input_counter2"))); +} + +// Test that calling SetOffset() in Calculator::Process() results in the +// ::mediapipe::StatusCode::kFailedPrecondition error. +TEST(CalculatorGraph, SetOffsetInProcess) { + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + node { + calculator: 'SetOffsetInProcessCalculator' + input_stream: 'in' + output_stream: 'out' + } + )"); + + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_EXPECT_OK(graph.StartRun({})); + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("in", MakePacket(0).At(Timestamp(0)))); + ::mediapipe::Status status = graph.WaitUntilIdle(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(::mediapipe::StatusCode::kFailedPrecondition, status.code()); +} + +// Test that MediaPipe releases input packets when it is done with them. +TEST(CalculatorGraph, InputPacketLifetime) { + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + node { + calculator: 'PassThroughCalculator' + input_stream: 'in' + output_stream: 'mid' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'mid' + output_stream: 'out' + } + )"); + + LifetimeTracker tracker; + Timestamp timestamp = Timestamp(0); + auto new_packet = [×tamp, &tracker] { + return Adopt(tracker.MakeObject().release()).At(++timestamp); + }; + + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_EXPECT_OK(graph.StartRun({})); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream("in", new_packet())); + MEDIAPIPE_EXPECT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(0, tracker.live_count()); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream("in", new_packet())); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream("in", new_packet())); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream("in", new_packet())); + MEDIAPIPE_EXPECT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(0, tracker.live_count()); + 
MEDIAPIPE_EXPECT_OK(graph.CloseInputStream("in")); + MEDIAPIPE_EXPECT_OK(graph.WaitUntilDone()); +} + +// Test that SetNextTimestampBound propagates. +TEST(CalculatorGraph, SetNextTimestampBoundPropagation) { + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + input_stream: 'gate' + node { + calculator: 'ValveCalculator' + input_stream: 'in' + input_stream: 'gate' + output_stream: 'gated' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'gated' + output_stream: 'passed' + } + node { + calculator: 'TimeShiftCalculator' + input_stream: 'passed' + output_stream: 'shifted' + input_side_packet: 'shift' + } + node { + calculator: 'MergeCalculator' + input_stream: 'in' + input_stream: 'shifted' + output_stream: 'merged' + } + node { + name: 'merged_output' + calculator: 'PassThroughCalculator' + input_stream: 'merged' + output_stream: 'out' + } + )"); + + Timestamp timestamp = Timestamp(0); + auto send_inputs = [&graph, ×tamp](int input, bool pass) { + ++timestamp; + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in", MakePacket(input).At(timestamp))); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "gate", MakePacket(pass).At(timestamp))); + }; + + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK( + graph.StartRun({{"shift", MakePacket(0)}})); + + auto pass_counter = + graph.GetCounterFactory()->GetCounter("ValveCalculator-PassThrough"); + auto block_counter = + graph.GetCounterFactory()->GetCounter("ValveCalculator-Block"); + auto merged_counter = + graph.GetCounterFactory()->GetCounter("merged_output-PassThrough"); + + send_inputs(1, true); + send_inputs(2, true); + send_inputs(3, false); + send_inputs(4, false); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + + // Verify that MergeCalculator was able to run even when the gated branch + // was blocked. 
+ EXPECT_EQ(2, pass_counter->Get()); + EXPECT_EQ(2, block_counter->Get()); + EXPECT_EQ(4, merged_counter->Get()); + + send_inputs(5, true); + send_inputs(6, false); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_EQ(3, pass_counter->Get()); + EXPECT_EQ(3, block_counter->Get()); + EXPECT_EQ(6, merged_counter->Get()); + + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + + // Now test with time shift + MEDIAPIPE_ASSERT_OK( + graph.StartRun({{"shift", MakePacket(-1)}})); + + send_inputs(7, true); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + + // The merger should have run only once now, at timestamp 6, with inputs + // . If we do not respect the offset and unblock the merger for + // timestamp 7 too, then it will have run twice, with 6: and + // 7: <7, null>. + EXPECT_EQ(4, pass_counter->Get()); + EXPECT_EQ(3, block_counter->Get()); + EXPECT_EQ(7, merged_counter->Get()); + + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_EQ(4, pass_counter->Get()); + EXPECT_EQ(3, block_counter->Get()); + EXPECT_EQ(8, merged_counter->Get()); +} + +// Both input streams of the calculator node have the same next timestamp +// bound. One input stream has a packet at that timestamp. The other input +// stream is empty. We should not run the Process() method of the node in this +// case. 
+TEST(CalculatorGraph, NotAllInputPacketsAtNextTimestampBoundAvailable) { + // + // in0_unfiltered in1_to_be_filtered + // | | + // | V + // | +-----------------------+ + // | |EvenIntFilterCalculator| + // | +-----------------------+ + // | | + // \ / + // \ / in1_filtered + // \ / + // | | + // V V + // +------------------+ + // |IntAdderCalculator| + // +------------------+ + // | + // V + // out + // + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in0_unfiltered' + input_stream: 'in1_to_be_filtered' + node { + calculator: 'EvenIntFilterCalculator' + input_stream: 'in1_to_be_filtered' + output_stream: 'in1_filtered' + } + node { + calculator: 'IntAdderCalculator' + input_stream: 'in0_unfiltered' + input_stream: 'in1_filtered' + output_stream: 'out' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("out", &config, &packet_dump); + + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + + Timestamp timestamp = Timestamp(0); + + // We send an integer with timestamp 1 to the in0_unfiltered input stream of + // the IntAdderCalculator. We then send an even integer with timestamp 1 to + // the EvenIntFilterCalculator. This packet will go through and + // the IntAdderCalculator will run. The next timestamp bounds of both the + // input streams of the IntAdderCalculator will become 2. + + ++timestamp; // Timestamp 1. + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in0_unfiltered", MakePacket(1).At(timestamp))); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in1_to_be_filtered", MakePacket(2).At(timestamp))); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + EXPECT_EQ(3, packet_dump[0].Get()); + + // We send an odd integer with timestamp 2 to the EvenIntFilterCalculator. 
+ // This packet will be filtered out and the next timestamp bound of the + // in1_filtered input stream of the IntAdderCalculator will become 3. + + ++timestamp; // Timestamp 2. + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in1_to_be_filtered", MakePacket(3).At(timestamp))); + + // We send an integer with timestamp 3 to the in0_unfiltered input stream of + // the IntAdderCalculator. MediaPipe should propagate the next timestamp bound + // across the IntAdderCalculator but should not run its Process() method. + + ++timestamp; // Timestamp 3. + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in0_unfiltered", MakePacket(3).At(timestamp))); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + + // We send an even integer with timestamp 3 to the IntAdderCalculator. This + // packet will go through and the IntAdderCalculator will run. + + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in1_to_be_filtered", MakePacket(4).At(timestamp))); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(2, packet_dump.size()); + EXPECT_EQ(7, packet_dump[1].Get()); + + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_EQ(2, packet_dump.size()); +} + +// Demonstrate an if-then-else graph. +TEST(CalculatorGraph, IfThenElse) { + // This graph has an if-then-else structure. The left branch, selected by the + // select value 0, applies a double (multiply by 2) operation. The right + // branch, selected by the select value 1, applies a square operation. The + // left branch also has some no-op PassThroughCalculators to make the lengths + // of the two branches different. 
+ CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + input_stream: 'select' + node { + calculator: 'DemuxTimedCalculator' + input_stream: 'INPUT:in' + input_stream: 'SELECT:select' + output_stream: 'OUTPUT:0:left' + output_stream: 'OUTPUT:1:right' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'left' + output_stream: 'left1' + } + node { + calculator: 'DoubleIntCalculator' + input_stream: 'left1' + output_stream: 'left2' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'left2' + output_stream: 'left3' + } + node { + calculator: 'SquareIntCalculator' + input_stream: 'right' + output_stream: 'right1' + } + node { + calculator: 'MuxTimedCalculator' + input_stream: 'INPUT:0:left3' + input_stream: 'INPUT:1:right1' + input_stream: 'SELECT:select' + output_stream: 'OUTPUT:out' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("out", &config, &packet_dump); + + Timestamp timestamp = Timestamp(0); + auto send_inputs = [&graph, ×tamp](int input, int select) { + ++timestamp; + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in", MakePacket(input).At(timestamp))); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "select", MakePacket(select).At(timestamp))); + }; + + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + + // If the "select" input is 0, we apply a double operation. If "select" is 1, + // we apply a square operation. To make the code easier to understand, define + // symbolic names for the select values. 
+ const int kApplyDouble = 0; + const int kApplySquare = 1; + + send_inputs(1, kApplyDouble); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + EXPECT_EQ(2, packet_dump[0].Get()); + + send_inputs(2, kApplySquare); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(2, packet_dump.size()); + EXPECT_EQ(4, packet_dump[1].Get()); + + send_inputs(3, kApplyDouble); + send_inputs(4, kApplyDouble); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(4, packet_dump.size()); + EXPECT_EQ(6, packet_dump[2].Get()); + EXPECT_EQ(8, packet_dump[3].Get()); + + send_inputs(5, kApplySquare); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(5, packet_dump.size()); + EXPECT_EQ(25, packet_dump[4].Get()); + + send_inputs(6, kApplyDouble); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(6, packet_dump.size()); + EXPECT_EQ(12, packet_dump[5].Get()); + + send_inputs(7, kApplySquare); + send_inputs(8, kApplySquare); + send_inputs(9, kApplySquare); + send_inputs(10, kApplyDouble); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(10, packet_dump.size()); + EXPECT_EQ(49, packet_dump[6].Get()); + EXPECT_EQ(64, packet_dump[7].Get()); + EXPECT_EQ(81, packet_dump[8].Get()); + EXPECT_EQ(20, packet_dump[9].Get()); + + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_EQ(10, packet_dump.size()); +} + +// A simple output selecting test calculator, which omits timestamp bounds +// for the unselected outputs. 
+class DemuxUntimedCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); + cc->Inputs().Tag("INPUT").SetAny(); + cc->Inputs().Tag("SELECT").Set(); + for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT"); + id < cc->Outputs().EndId("OUTPUT"); ++id) { + cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag("INPUT")); + } + return ::mediapipe::OkStatus(); + } + ::mediapipe::Status Process(CalculatorContext* cc) final { + int index = cc->Inputs().Tag("SELECT").Get(); + if (!cc->Inputs().Tag("INPUT").IsEmpty()) { + cc->Outputs() + .Get("OUTPUT", index) + .AddPacket(cc->Inputs().Tag("INPUT").Value()); + } else { + cc->Outputs() + .Get("OUTPUT", index) + .SetNextTimestampBound(cc->InputTimestamp() + 1); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(DemuxUntimedCalculator); + +// Demonstrate an if-then-else graph. This test differs from the IfThenElse test +// in that it uses optional input streams instead of next timestamp bound +// propagation. +TEST(CalculatorGraph, IfThenElse2) { + // This graph has an if-then-else structure. The left branch, selected by the + // select value 0, applies a double (multiply by 2) operation. The right + // branch, selected by the select value 1, applies a square operation. The + // left branch also has some no-op PassThroughCalculators to make the lengths + // of the two branches different. 
+ CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + input_stream: 'select' + node { + calculator: 'DemuxUntimedCalculator' + input_stream: 'INPUT:in' + input_stream: 'SELECT:select' + output_stream: 'OUTPUT:0:left' + output_stream: 'OUTPUT:1:right' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'left' + output_stream: 'left1' + } + node { + calculator: 'DoubleIntCalculator' + input_stream: 'left1' + output_stream: 'left2' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'left2' + output_stream: 'left3' + } + node { + calculator: 'SquareIntCalculator' + input_stream: 'right' + output_stream: 'right1' + } + node { + calculator: 'MuxCalculator' + input_stream: 'INPUT:0:left3' + input_stream: 'INPUT:1:right1' + input_stream: 'SELECT:select' + output_stream: 'OUTPUT:out' + input_stream_handler { input_stream_handler: 'MuxInputStreamHandler' } + } + )"); + std::vector packet_dump; + tool::AddVectorSink("out", &config, &packet_dump); + + Timestamp timestamp = Timestamp(0); + auto send_inputs = [&graph, ×tamp](int input, int select) { + ++timestamp; + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in", MakePacket(input).At(timestamp))); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "select", MakePacket(select).At(timestamp))); + }; + + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + + // If the "select" input is 0, we apply a double operation. If "select" is 1, + // we apply a square operation. To make the code easier to understand, define + // symbolic names for the select values. 
+ const int kApplyDouble = 0; + const int kApplySquare = 1; + + send_inputs(1, kApplyDouble); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + EXPECT_EQ(2, packet_dump[0].Get()); + + send_inputs(2, kApplySquare); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(2, packet_dump.size()); + EXPECT_EQ(4, packet_dump[1].Get()); + + send_inputs(3, kApplyDouble); + send_inputs(4, kApplyDouble); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(4, packet_dump.size()); + EXPECT_EQ(6, packet_dump[2].Get()); + EXPECT_EQ(8, packet_dump[3].Get()); + + send_inputs(5, kApplySquare); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(5, packet_dump.size()); + EXPECT_EQ(25, packet_dump[4].Get()); + + send_inputs(6, kApplyDouble); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(6, packet_dump.size()); + EXPECT_EQ(12, packet_dump[5].Get()); + + send_inputs(7, kApplySquare); + send_inputs(8, kApplySquare); + send_inputs(9, kApplySquare); + send_inputs(10, kApplyDouble); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(10, packet_dump.size()); + EXPECT_EQ(49, packet_dump[6].Get()); + EXPECT_EQ(64, packet_dump[7].Get()); + EXPECT_EQ(81, packet_dump[8].Get()); + EXPECT_EQ(20, packet_dump[9].Get()); + + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_EQ(10, packet_dump.size()); +} + +// A regression test for bug 28321551. The scheduler should be able to run +// the calculator graph to completion without hanging. The test merely checks +// that CalculatorGraph::Run() returns. +TEST(CalculatorGraph, ClosedSourceNodeShouldNotBeUnthrottled) { + // This calculator graph has two source nodes. The first source node, + // OutputAllSourceCalculator, outputs a lot of packets in one shot and stops. + // The second source node, OutputOneAtATimeSourceCalculator, outputs one + // packet at a time. 
But it is connected to a node, DecimatorCalculator, + // that discards most of its input packets and only rarely outputs a packet. + // The sink node, MergeCalculator, receives three input streams, two from + // the two source nodes and one from DecimatorCalculator. The two input + // streams connected to the two source nodes will become full, and the + // MediaPipe scheduler will throttle the source nodes. + // + // The MediaPipe scheduler should not schedule a closed source node, even if + // the source node filled an input stream and the input stream changes from + // being "full" to "not full". + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + num_threads: 1 + max_queue_size: 100 + node { + calculator: 'OutputAllSourceCalculator' + output_stream: 'first_stream' + } + node { + calculator: 'OutputOneAtATimeSourceCalculator' + output_stream: 'second_stream' + } + node { + calculator: 'DecimatorCalculator' + input_stream: 'second_stream' + output_stream: 'decimated_second_stream' + } + node { + calculator: 'MergeCalculator' + input_stream: 'first_stream' + input_stream: 'second_stream' + input_stream: 'decimated_second_stream' + output_stream: 'output' + } + )"); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run()); +} + +// Tests that a calculator can output a packet in the Open() method. +// +// The initial output packet generated by UnitDelayCalculator::Open() causes +// the following to happen before the scheduler starts to run: +// - The downstream PassThroughCalculator becomes ready and is added to the +// scheduler queue. +// - Since max_queue_size is set to 1, the GlobalCountSourceCalculator is +// throttled. +// The scheduler should be able to run the graph from this initial state. 
+TEST(CalculatorGraph, OutputPacketInOpen) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + max_queue_size: 1 + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + output_stream: 'integers' + } + node { + calculator: 'UnitDelayCalculator' + input_stream: 'integers' + output_stream: 'delayed_integers' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'delayed_integers' + output_stream: 'output' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("output", &config, &packet_dump); + + std::atomic global_counter(1); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run(input_side_packets)); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets + 1, + packet_dump.size()); + EXPECT_EQ(0, packet_dump[0].Get()); + EXPECT_EQ(Timestamp(0), packet_dump[0].Timestamp()); + for (int i = 1; i <= GlobalCountSourceCalculator::kNumOutputPackets; ++i) { + EXPECT_EQ(i, packet_dump[i].Get()); + EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); + } +} + +// Tests that a calculator can output a packet in the Open() method. +// +// The initial output packet generated by UnitDelayCalculator::Open() causes +// the following to happen before the scheduler starts to run: +// - The downstream MergeCalculator does not become ready because its second +// input stream has no packet. +// - Since max_queue_size is set to 1, the GlobalCountSourceCalculator is +// throttled. +// The scheduler must schedule a throttled source node from the beginning. 
+TEST(CalculatorGraph, OutputPacketInOpen2) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + max_queue_size: 1 + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + output_stream: 'integers' + } + node { + calculator: 'UnitDelayCalculator' + input_stream: 'integers' + output_stream: 'delayed_integers' + } + node { + calculator: 'MergeCalculator' + input_stream: 'delayed_integers' + input_stream: 'integers' + output_stream: 'output' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("output", &config, &packet_dump); + + std::atomic global_counter(1); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run(input_side_packets)); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets + 1, + packet_dump.size()); + int i; + for (i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { + std::string expected = + absl::Substitute("Timestamp($0) $1 $2", + packet_dump[i].Timestamp().DebugString(), i, i + 1); + EXPECT_EQ(expected, packet_dump[i].Get()); + EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); + } + std::string expected = absl::Substitute( + "Timestamp($0) $1 empty", packet_dump[i].Timestamp().DebugString(), i); + EXPECT_EQ(expected, packet_dump[i].Get()); + EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); +} + +// Tests that no packets are available on input streams in Open(), even if the +// upstream calculator outputs a packet in Open(). +TEST(CalculatorGraph, EmptyInputInOpen) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + max_queue_size: 1 + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + output_stream: 'integers' + } + # UnitDelayCalculator outputs a packet during Open(). 
+ node { + calculator: 'UnitDelayCalculator' + input_stream: 'integers' + output_stream: 'delayed_integers' + } + node { + calculator: 'AssertEmptyInputInOpenCalculator' + input_stream: 'delayed_integers' + } + node { + calculator: 'AssertEmptyInputInOpenCalculator' + input_stream: 'integers' + } + )"); + + std::atomic global_counter(1); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_EXPECT_OK(graph.Run(input_side_packets)); +} + +// Test for b/33568859. +TEST(CalculatorGraph, UnthrottleRespectsLayers) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + max_queue_size: 1 + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + output_stream: 'integers0' + source_layer: 0 + } + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + input_side_packet: 'output_in_open' + output_stream: 'integers1' + source_layer: 1 + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'integers1' + output_stream: 'integers1passthrough' + } + )"); + + std::vector layer0_packets; + std::vector layer1_packets; + tool::AddVectorSink("integers0", &config, &layer0_packets); + tool::AddVectorSink("integers1passthrough", &config, &layer1_packets); + + std::atomic global_counter(0); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + // TODO: Set this value to true. When the calculator outputs a + // packet in Open, it will trigget b/33568859, and the test will fail. Use + // this test to verify that b/33568859 is fixed. 
+ constexpr bool kOutputInOpen = true; + input_side_packets["output_in_open"] = MakePacket(kOutputInOpen); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run(input_side_packets)); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, + layer0_packets.size()); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, + layer1_packets.size()); + // Check that we ran things in the expected order. + int count = 0; + if (kOutputInOpen) { + EXPECT_EQ(count, layer1_packets[0].Get()); + EXPECT_EQ(Timestamp(0), layer1_packets[0].Timestamp()); + ++count; + } + for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; + ++i, ++count) { + EXPECT_EQ(count, layer0_packets[i].Get()); + EXPECT_EQ(Timestamp(i), layer0_packets[i].Timestamp()); + } + for (int i = kOutputInOpen ? 1 : 0; + i < GlobalCountSourceCalculator::kNumOutputPackets; ++i, ++count) { + EXPECT_EQ(count, layer1_packets[i].Get()); + EXPECT_EQ(Timestamp(i), layer1_packets[i].Timestamp()); + } +} + +// The graph calculates the sum of all the integers output by the source node +// so far. The graph has one cycle. 
+TEST(CalculatorGraph, Cycle) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + output_stream: 'integers' + } + node { + calculator: 'IntAdderCalculator' + input_stream: 'integers' + input_stream: 'old_sum' + input_stream_info: { + tag_index: ':1' # 'old_sum' + back_edge: true + } + output_stream: 'sum' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + node { + calculator: 'UnitDelayCalculator' + input_stream: 'sum' + output_stream: 'old_sum' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("sum", &config, &packet_dump); + + std::atomic global_counter(1); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run(input_side_packets)); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, packet_dump.size()); + int sum = 0; + for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { + sum += i + 1; + EXPECT_EQ(sum, packet_dump[i].Get()); + EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); + } +} + +// The graph calculates the sum of all the integers output by the source node +// so far. The graph has one cycle. +// +// The difference from the "Cycle" test is that the graph is scheduled with +// packet timestamps ignored. 
+TEST(CalculatorGraph, CycleUntimed) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream_handler { + input_stream_handler: 'BarrierInputStreamHandler' + } + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + output_stream: 'integers' + } + node { + calculator: 'IntAdderCalculator' + input_stream: 'integers' + input_stream: 'old_sum' + input_stream_info: { + tag_index: ':1' # 'old_sum' + back_edge: true + } + output_stream: 'sum' + } + node { + calculator: 'UnitDelayUntimedCalculator' + input_stream: 'sum' + output_stream: 'old_sum' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("sum", &config, &packet_dump); + + std::atomic global_counter(1); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run(input_side_packets)); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, packet_dump.size()); + int sum = 0; + for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { + sum += i + 1; + EXPECT_EQ(sum, packet_dump[i].Get()); + } +} + +// This unit test is a direct form I implementation of Example 6.2 of +// Discrete-Time Signal Processing, 3rd Ed., shown in Figure 6.6. The system +// function of the linear time-invariant (LTI) system is +// H(z) = (1 + 2 * z^-1) / (1 - 1.5 * z^-1 + 0.9 * z^-2) +// The graph has two cycles. 
+TEST(CalculatorGraph, DirectFormI) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + output_stream: 'integers' + } + node { + calculator: 'IntToFloatCalculator' + input_stream: 'integers' + output_stream: 'x' + } + node { + calculator: 'FloatUnitDelayCalculator' + input_stream: 'x' + output_stream: 'a' + } + node { + calculator: 'FloatScalarMultiplierCalculator' + input_stream: 'a' + output_stream: 'b' + input_side_packet: 'b1' + } + node { + calculator: 'FloatAdderCalculator' + input_stream: 'x' + input_stream: 'b' + output_stream: 'c' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + node { + calculator: 'FloatAdderCalculator' + input_stream: 'c' + input_stream: 'f' + input_stream_info: { + tag_index: ':1' # 'f' + back_edge: true + } + output_stream: 'y' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + node { + calculator: 'FloatUnitDelayCalculator' + input_stream: 'y' + output_stream: 'd' + } + node { + calculator: 'FloatScalarMultiplierCalculator' + input_stream: 'd' + output_stream: 'e' + input_side_packet: 'a1' + } + node { + calculator: 'FloatUnitDelayCalculator' + input_stream: 'd' + output_stream: 'g' + } + node { + calculator: 'FloatScalarMultiplierCalculator' + input_stream: 'g' + output_stream: 'h' + input_side_packet: 'a2' + } + node { + calculator: 'FloatAdderCalculator' + input_stream: 'e' + input_stream: 'h' + output_stream: 'f' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + )"); + std::vector packet_dump; + tool::AddVectorSink("y", &config, &packet_dump); + + std::atomic global_counter(1); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + input_side_packets["a2"] = Adopt(new float(-0.9)); + input_side_packets["a1"] = Adopt(new float(1.5)); + 
input_side_packets["b1"] = Adopt(new float(2.0)); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run(input_side_packets)); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, packet_dump.size()); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, 5); + EXPECT_FLOAT_EQ(1.0, packet_dump[0].Get()); + EXPECT_FLOAT_EQ(5.5, packet_dump[1].Get()); + EXPECT_FLOAT_EQ(14.35, packet_dump[2].Get()); + EXPECT_FLOAT_EQ(26.575, packet_dump[3].Get()); + EXPECT_FLOAT_EQ(39.9475, packet_dump[4].Get()); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); + } +} + +// This unit test is a direct form II implementation of Example 6.2 of +// Discrete-Time Signal Processing, 3rd Ed., shown in Figure 6.7. The system +// function of the linear time-invariant (LTI) system is +// H(z) = (1 + 2 * z^-1) / (1 - 1.5 * z^-1 + 0.9 * z^-2) +// The graph has two cycles. +TEST(CalculatorGraph, DirectFormII) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'GlobalCountSourceCalculator' + input_side_packet: 'global_counter' + output_stream: 'integers' + } + node { + calculator: 'IntToFloatCalculator' + input_stream: 'integers' + output_stream: 'x' + } + node { + calculator: 'FloatAdderCalculator' + input_stream: 'x' + input_stream: 'f' + input_stream_info: { + tag_index: ':1' # 'f' + back_edge: true + } + output_stream: 'a' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + node { + calculator: 'FloatUnitDelayCalculator' + input_stream: 'a' + output_stream: 'b' + } + node { + calculator: 'FloatScalarMultiplierCalculator' + input_stream: 'b' + output_stream: 'd' + input_side_packet: 'a1' + } + node { + calculator: 'FloatUnitDelayCalculator' + input_stream: 'b' + output_stream: 'c' + } + node { + calculator: 'FloatScalarMultiplierCalculator' + input_stream: 'c' + output_stream: 'e' + input_side_packet: 'a2' + } + 
node { + calculator: 'FloatAdderCalculator' + input_stream: 'd' + input_stream: 'e' + output_stream: 'f' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + node { + calculator: 'FloatScalarMultiplierCalculator' + input_stream: 'b' + output_stream: 'g' + input_side_packet: 'b1' + } + node { + calculator: 'FloatAdderCalculator' + input_stream: 'a' + input_stream: 'g' + output_stream: 'y' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + )"); + std::vector packet_dump; + tool::AddVectorSink("y", &config, &packet_dump); + + std::atomic global_counter(1); + std::map input_side_packets; + input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); + input_side_packets["a2"] = Adopt(new float(-0.9)); + input_side_packets["a1"] = Adopt(new float(1.5)); + input_side_packets["b1"] = Adopt(new float(2.0)); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run(input_side_packets)); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, packet_dump.size()); + ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, 5); + EXPECT_FLOAT_EQ(1.0, packet_dump[0].Get()); + EXPECT_FLOAT_EQ(5.5, packet_dump[1].Get()); + EXPECT_FLOAT_EQ(14.35, packet_dump[2].Get()); + EXPECT_FLOAT_EQ(26.575, packet_dump[3].Get()); + EXPECT_FLOAT_EQ(39.9475, packet_dump[4].Get()); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); + } +} + +// Calculates the dot products of two streams of three-dimensional vectors. +TEST(CalculatorGraph, DotProduct) { + // The use of BarrierInputStreamHandler in this graph aligns the input + // packets to a calculator by arrival order rather than by timestamp. 
+ CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream_handler { + input_stream_handler: 'BarrierInputStreamHandler' + } + node { + calculator: 'TestSequence1SourceCalculator' + output_stream: 'test_sequence_1' + } + node { + calculator: 'TestSequence2SourceCalculator' + output_stream: 'test_sequence_2' + } + node { + calculator: 'Modulo3SourceCalculator' + output_stream: 'select_0_1_2' + } + node { + calculator: 'DemuxUntimedCalculator' + input_stream: 'INPUT:test_sequence_1' + input_stream: 'SELECT:select_0_1_2' + output_stream: 'OUTPUT:0:x_1' + output_stream: 'OUTPUT:1:y_1' + output_stream: 'OUTPUT:2:z_1' + } + node { + calculator: 'DemuxUntimedCalculator' + input_stream: 'INPUT:test_sequence_2' + input_stream: 'SELECT:select_0_1_2' + output_stream: 'OUTPUT:0:x_2' + output_stream: 'OUTPUT:1:y_2' + output_stream: 'OUTPUT:2:z_2' + } + node { + calculator: 'IntMultiplierCalculator' + input_stream: 'x_1' + input_stream: 'x_2' + output_stream: 'x_product' + } + node { + calculator: 'IntMultiplierCalculator' + input_stream: 'y_1' + input_stream: 'y_2' + output_stream: 'y_product' + } + node { + calculator: 'IntMultiplierCalculator' + input_stream: 'z_1' + input_stream: 'z_2' + output_stream: 'z_product' + } + node { + calculator: 'IntAdderCalculator' + input_stream: 'x_product' + input_stream: 'y_product' + input_stream: 'z_product' + output_stream: 'dot_product' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("dot_product", &config, &packet_dump); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run()); + + // The calculator graph performs the following computation: + // test_sequence_1 is split into x_1, y_1, z_1. + // test_sequence_2 is split into x_2, y_2, z_2. 
+ // x_product = x_1 * x_2 + // y_product = y_1 * y_2 + // z_product = z_1 * z_2 + // dot_product = x_product + y_product + z_product + // + // The values in these streams are: + // test_sequence_1: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 + // test_sequence_2: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // x_1: 0, 3, 6, 9, 12 + // x_2: 1, 4, 7, 10, 13 + // x_product: 0, 12, 42, 90, 156 + // y_1: 1, 4, 7, 10, 13 + // y_2: 2, 5, 8, 11, 14 + // y_product: 2, 20, 56, 110, 182 + // z_1: 2, 5, 8, 11, 14 + // z_2: 3, 6, 9, 12, 15 + // z_product: 6, 30, 72, 132, 210 + // dot_product: 8, 62, 170, 332, 548 + + ASSERT_EQ(kTestSequenceLength / 3, packet_dump.size()); + const int expected[] = {8, 62, 170, 332, 548}; + for (int i = 0; i < packet_dump.size(); ++i) { + EXPECT_EQ(expected[i], packet_dump[i].Get()); + } +} + +TEST(CalculatorGraph, TerminatesOnCancelWithOpenGraphInputStreams) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'PassThroughCalculator' + input_stream: 'in_a' + input_stream: 'in_b' + output_stream: 'out_a' + output_stream: 'out_b' + } + input_stream: 'in_a' + input_stream: 'in_b' + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in_a", MakePacket(1).At(Timestamp(1)))); + MEDIAPIPE_EXPECT_OK(graph.CloseInputStream("in_a")); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in_b", MakePacket(2).At(Timestamp(2)))); + MEDIAPIPE_EXPECT_OK(graph.WaitUntilIdle()); + graph.Cancel(); + // This tests that the graph doesn't deadlock on WaitUntilDone (because + // the scheduler thread is sleeping). 
+ ::mediapipe::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kCancelled); +} + +TEST(CalculatorGraph, TerminatesOnCancelAfterPause) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'PassThroughCalculator' + input_stream: 'in' + output_stream: 'out' + } + input_stream: 'in' + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + graph.Pause(); + // Make the PassThroughCalculator runnable while the scheduler is paused. + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("in", MakePacket(1).At(Timestamp(1)))); + // Now cancel the graph run. A non-empty scheduler queue should not prevent + // the scheduler from terminating. + graph.Cancel(); + // Any attempt to pause the scheduler after the graph run is cancelled should + // be ignored. + graph.Pause(); + // This tests that the graph doesn't deadlock on WaitUntilDone (because + // the scheduler thread is sleeping). + ::mediapipe::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kCancelled); +} + +// A PacketGenerator that simply passes its input Packets through +// unchanged. The inputs may be specified by tag or index. The outputs +// must match the inputs exactly. Any options may be specified and will +// also be ignored. 
+class PassThroughGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, PacketTypeSet* inputs, + PacketTypeSet* outputs) { + if (!inputs->TagMap()->SameAs(*outputs->TagMap())) { + return ::mediapipe::InvalidArgumentError( + "Input and outputs to PassThroughGenerator must use the same tags " + "and indexes."); + } + for (CollectionItemId id = inputs->BeginId(); id < inputs->EndId(); ++id) { + inputs->Get(id).SetAny(); + outputs->Get(id).SetSameAs(&inputs->Get(id)); + } + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + for (CollectionItemId id = input_side_packets.BeginId(); + id < input_side_packets.EndId(); ++id) { + output_side_packets->Get(id) = input_side_packets.Get(id); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_PACKET_GENERATOR(PassThroughGenerator); + +TEST(CalculatorGraph, SharePacketGeneratorGraph) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count1' + input_side_packet: 'MAX_COUNT:max_count1' + } + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count2' + input_side_packet: 'MAX_COUNT:max_count2' + } + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count3' + input_side_packet: 'MAX_COUNT:max_count3' + } + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count4' + input_side_packet: 'MAX_COUNT:max_count4' + } + node { + calculator: 'PassThroughCalculator' + input_side_packet: 'MAX_COUNT:max_count5' + output_side_packet: 'MAX_COUNT:max_count6' + } + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count5' + input_side_packet: 'MAX_COUNT:max_count6' + } + packet_generator { + packet_generator: 'PassThroughGenerator' + input_side_packet: 
'max_count1' + output_side_packet: 'max_count2' + } + packet_generator { + packet_generator: 'PassThroughGenerator' + input_side_packet: 'max_count4' + output_side_packet: 'max_count5' + } + )"); + + // At this point config is a standard config which specifies both + // calculators and packet_factories/packet_genators. The following + // code is an example of reusing side packets across a number of + // CalculatorGraphs. It is particularly informative to note how each + // side packet is created. + // + // max_count1 is set for all graphs by a PacketFactory in the config. + // The side packet is created by generator_graph.InitializeGraph(). + // + // max_count2 is set for all graphs by a PacketGenerator in the config. + // The side packet is created by generator_graph.InitializeGraph() + // because max_count1 is available at that time. + // + // max_count3 is set for all graphs by directly being specified as an + // argument to generator_graph.InitializeGraph(). + // + // max_count4 is set per graph because it is directly specified as an + // argument to generator_graph.ProcessGraph(). + // + // max_count5 is set per graph by a PacketGenerator which is run when + // generator_graph.ProcessGraph() is run (because max_count4 isn't + // available until then). + + // Before anything else, split the graph config into two parts, one + // with the PacketFactory and PacketGenerator config and the other + // with the Calculator config. + CalculatorGraphConfig calculator_config = config; + calculator_config.clear_packet_factory(); + calculator_config.clear_packet_generator(); + CalculatorGraphConfig generator_config = config; + generator_config.clear_node(); + + // Next, create a ValidatedGraphConfig for both configs. 
+ ValidatedGraphConfig validated_calculator_config; + MEDIAPIPE_ASSERT_OK( + validated_calculator_config.Initialize(calculator_config)); + ValidatedGraphConfig validated_generator_config; + MEDIAPIPE_ASSERT_OK(validated_generator_config.Initialize(generator_config)); + + // Create a PacketGeneratorGraph. Side packets max_count1, max_count2, + // and max_count3 are created upon initialization. + // Note that validated_generator_config must outlive generator_graph. + PacketGeneratorGraph generator_graph; + MEDIAPIPE_ASSERT_OK( + generator_graph.Initialize(&validated_generator_config, nullptr, + {{"max_count1", MakePacket(10)}, + {"max_count3", MakePacket(20)}})); + ASSERT_THAT(generator_graph.BasePackets(), + testing::ElementsAre(testing::Key("max_count1"), + testing::Key("max_count2"), + testing::Key("max_count3"))); + + // Create a bunch of graphs. + std::vector> graphs; + for (int i = 0; i < 100; ++i) { + graphs.emplace_back(absl::make_unique()); + // Do not pass extra side packets here. + // Note that validated_calculator_config must outlive the graph. + MEDIAPIPE_ASSERT_OK(graphs.back()->Initialize(calculator_config, {})); + } + // Run a bunch of graphs, reusing side packets max_count1, max_count2, + // and max_count3. The side packet max_count4 is added per run, + // and triggers the execution of a packet generator which generates + // max_count5. + for (int i = 0; i < 100; ++i) { + std::map all_side_packets; + // Creates max_count4 and max_count5. + MEDIAPIPE_ASSERT_OK(generator_graph.RunGraphSetup( + {{"max_count4", MakePacket(30 + i)}}, &all_side_packets)); + ASSERT_THAT(all_side_packets, + testing::ElementsAre( + testing::Key("max_count1"), testing::Key("max_count2"), + testing::Key("max_count3"), testing::Key("max_count4"), + testing::Key("max_count5"))); + // Pass all the side packets prepared by generator_graph here. + MEDIAPIPE_ASSERT_OK(graphs[i]->Run(all_side_packets)); + // TODO Verify the actual output. + } + + // Destroy all the graphs. 
+ graphs.clear(); +} + +TEST(CalculatorGraph, RecoverAfterRunError) { + PacketGeneratorGraph generator_graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + name: 'calculator1' + calculator: 'CountingSourceCalculator' + output_stream: 'count1' + input_side_packet: 'MAX_COUNT:max_count2' + input_side_packet: 'ERROR_COUNT:max_error2' + } + packet_generator { + packet_generator: 'EnsurePositivePacketGenerator' + input_side_packet: 'max_count1' + output_side_packet: 'max_count2' + input_side_packet: 'max_error1' + output_side_packet: 'max_error2' + } + status_handler { + status_handler: 'FailableStatusHandler' + input_side_packet: 'status_handler_command' + } + )"); + + int packet_count = 0; + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config, {})); + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( + "count1", [&packet_count](const Packet& packet) { + ++packet_count; + return ::mediapipe::OkStatus(); + })); + // Set ERROR_COUNT higher than MAX_COUNT and hence the calculator will + // finish successfully. + packet_count = 0; + MEDIAPIPE_ASSERT_OK( + graph.Run({{"max_count1", MakePacket(10)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}})); + EXPECT_EQ(packet_count, 10); + // Fail in PacketGenerator::Generate(). + // Negative max_count1 will cause EnsurePositivePacketGenerator to fail. + ASSERT_FALSE(graph + .Run({{"max_count1", MakePacket(-1)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}}) + .ok()); + packet_count = 0; + MEDIAPIPE_ASSERT_OK( + graph.Run({{"max_count1", MakePacket(10)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}})); + EXPECT_EQ(packet_count, 10); + // Fail in PacketGenerator::Generate() also fail in StatusHandler. 
+ ASSERT_FALSE(graph + .Run({{"max_count1", MakePacket(-1)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kFailPreRun)}}) + .ok()); + packet_count = 0; + MEDIAPIPE_ASSERT_OK( + graph.Run({{"max_count1", MakePacket(10)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}})); + EXPECT_EQ(packet_count, 10); + ASSERT_FALSE( + graph + .Run({{"max_count1", MakePacket(-1)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kFailPostRun)}}) + .ok()); + packet_count = 0; + MEDIAPIPE_ASSERT_OK( + graph.Run({{"max_count1", MakePacket(10)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}})); + EXPECT_EQ(packet_count, 10); + // Fail in Calculator::Process(). + ASSERT_FALSE(graph + .Run({{"max_count1", MakePacket(1000)}, + {"max_error1", MakePacket(10)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}}) + .ok()); + packet_count = 0; + MEDIAPIPE_ASSERT_OK( + graph.Run({{"max_count1", MakePacket(10)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}})); + EXPECT_EQ(packet_count, 10); + // Fail in Calculator::Process() also fail in StatusHandler. 
+ ASSERT_FALSE(graph + .Run({{"max_count1", MakePacket(1000)}, + {"max_error1", MakePacket(10)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kFailPreRun)}}) + .ok()); + packet_count = 0; + MEDIAPIPE_ASSERT_OK( + graph.Run({{"max_count1", MakePacket(10)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}})); + EXPECT_EQ(packet_count, 10); + ASSERT_FALSE( + graph + .Run({{"max_count1", MakePacket(1000)}, + {"max_error1", MakePacket(10)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kFailPostRun)}}) + .ok()); + packet_count = 0; + MEDIAPIPE_ASSERT_OK( + graph.Run({{"max_count1", MakePacket(10)}, + {"max_error1", MakePacket(20)}, + {"status_handler_command", + MakePacket(FailableStatusHandler::kOk)}})); + EXPECT_EQ(packet_count, 10); +} + +TEST(CalculatorGraph, SetInputStreamMaxQueueSizeWorksSlowCalculator) { + using Semaphore = SemaphoreCalculator::Semaphore; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'SemaphoreCalculator' + input_stream: 'in' + output_stream: 'out' + input_side_packet: 'POST_SEM:post_sem' + input_side_packet: 'WAIT_SEM:wait_sem' + } + node { + calculator: 'SemaphoreCalculator' + input_stream: 'in_2' + output_stream: 'out_2' + input_side_packet: 'POST_SEM:post_sem_busy' + input_side_packet: 'WAIT_SEM:wait_sem_busy' + } + input_stream: 'in' + input_stream: 'in_2' + max_queue_size: 100 + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + graph.SetGraphInputStreamAddMode( + CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL); + MEDIAPIPE_ASSERT_OK(graph.SetInputStreamMaxQueueSize("in", 1)); + + Semaphore calc_entered_process(0); + Semaphore calc_can_exit_process(0); + Semaphore calc_entered_process_busy(0); + Semaphore calc_can_exit_process_busy(0); + MEDIAPIPE_ASSERT_OK(graph.StartRun({ + {"post_sem", MakePacket(&calc_entered_process)}, + {"wait_sem", 
MakePacket(&calc_can_exit_process)}, + {"post_sem_busy", MakePacket(&calc_entered_process_busy)}, + {"wait_sem_busy", MakePacket(&calc_can_exit_process_busy)}, + })); + + Timestamp timestamp(0); + // Prevent deadlock resolution by running the "busy" SemaphoreCalculator + // for the duration of the test. + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("in_2", MakePacket(0).At(timestamp))); + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("in", MakePacket(0).At(timestamp++))); + for (int i = 1; i < 20; ++i, ++timestamp) { + // Wait for the calculator to begin its Process call. + calc_entered_process.Acquire(1); + // Now the calculator is stuck processing a packet. We can queue up + // another one. + MEDIAPIPE_EXPECT_OK( + graph.AddPacketToInputStream("in", MakePacket(i).At(timestamp))); + // We should be prevented from adding another, since the queue is now full. + ::mediapipe::Status status = graph.AddPacketToInputStream( + "in", MakePacket(i).At(timestamp + 1)); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnavailable); + // Allow calculator to complete its Process call. + calc_can_exit_process.Release(1); + } + // Allow the final Process call to complete. + calc_can_exit_process.Release(1); + calc_can_exit_process_busy.Release(1); + + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +// Verify the scheduler unthrottles the graph input stream to avoid a deadlock, +// and won't enter a busy loop. +TEST(CalculatorGraph, AddPacketNoBusyLoop) { + // The DecimatorCalculator ouputs 1 out of every 101 input packets and drops + // the rest, without setting the next timestamp bound on its output. As a + // result, the MergeCalculator is not runnable in between and packets on its + // "in" input stream will be queued and exceed the max queue size. 
+ // + // in + // | + // / \ + // / \ + // / \ + // | \ + // v | + // +---------+ | + // 101:1 |Decimator| | <== Packet buildup + // +---------+ | + // | | + // v v + // +----------+ + // | Merge | + // +----------+ + // | + // v + // out + // + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + max_queue_size: 1 + node { + calculator: 'DecimatorCalculator' + input_stream: 'in' + output_stream: 'decimated_in' + } + node { + calculator: 'MergeCalculator' + input_stream: 'decimated_in' + input_stream: 'in' + output_stream: 'out' + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + graph.SetGraphInputStreamAddMode( + CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL); + std::vector out_packets; // Packets from the output stream "out". + MEDIAPIPE_ASSERT_OK( + graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { + out_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + + const int kDecimationRatio = DecimatorCalculator::kDecimationRatio; + // To leave the graph input stream "in" in the throttled state, kNumPackets + // can be any value other than a multiple of kDecimationRatio plus one. + const int kNumPackets = 2 * kDecimationRatio; + for (int i = 0; i < kNumPackets; ++i) { + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "in", MakePacket(i).At(Timestamp(i)))); + } + + // The graph input stream "in" is throttled. Wait until the graph is idle. + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + // Check that Pause() does not block forever trying to acquire a mutex. + // This is a regression test for an old bug. + graph.Pause(); + graph.Resume(); + + MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + + // The expected output packets are: + // "Timestamp(0) 0 0" + // "Timestamp(1) empty 1" + // ... 
+ // "Timestamp(100) empty 100" + // "Timestamp(101) 101 101" + // "Timestamp(102) empty 102" + // ... + // "Timestamp(201) empty 201" + ASSERT_EQ(kNumPackets, out_packets.size()); + for (int i = 0; i < out_packets.size(); ++i) { + std::string format = (i % kDecimationRatio == 0) ? "Timestamp($0) $0 $0" + : "Timestamp($0) empty $0"; + std::string expected = absl::Substitute(format, i); + EXPECT_EQ(expected, out_packets[i].Get()); + EXPECT_EQ(Timestamp(i), out_packets[i].Timestamp()); + } +} + +namespace nested_ns { + +typedef std::function<::mediapipe::Status(const InputStreamShardSet&, + OutputStreamShardSet*)> + ProcessFunction; + +// A Calculator that delegates its Process function to a callback function. +class ProcessCallbackCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).SetAny(); + cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(0)); + } + cc->InputSidePackets().Index(0).Set>(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + callback_ = + *GetFromUniquePtr(cc->InputSidePackets().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + return callback_(cc->Inputs(), &(cc->Outputs())); + } + + private: + ProcessFunction callback_; +}; +REGISTER_CALCULATOR(::mediapipe::nested_ns::ProcessCallbackCalculator); + +} // namespace nested_ns + +TEST(CalculatorGraph, CalculatorInNamepsace) { + CalculatorGraphConfig config; + CHECK(proto_ns::TextFormat::ParseFromString(R"( + input_stream: 'in_a' + node { + calculator: 'mediapipe.nested_ns.ProcessCallbackCalculator' + input_stream: 'in_a' + output_stream: 'out_a' + input_side_packet: 'callback_1' + } + )", + &config)); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + nested_ns::ProcessFunction callback_1; + MEDIAPIPE_ASSERT_OK( + 
graph.StartRun({{"callback_1", AdoptAsUniquePtr(new auto(callback_1))}})); + MEDIAPIPE_EXPECT_OK(graph.WaitUntilIdle()); +} + +// A ProcessFunction that passes through all packets. +::mediapipe::Status DoProcess(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { + for (int i = 0; i < inputs.NumEntries(); ++i) { + if (!inputs.Index(i).Value().IsEmpty()) { + outputs->Index(i).AddPacket(inputs.Index(i).Value()); + } + } + return ::mediapipe::OkStatus(); +} + +TEST(CalculatorGraph, ObserveOutputStream) { + const int max_count = 10; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count' + input_side_packet: 'MAX_COUNT:max_count' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'count' + output_stream: 'mid' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'mid' + output_stream: 'out' + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.Initialize(config, {{"max_count", MakePacket(max_count)}})); + // Observe the internal output stream "count" and the unconnected output + // stream "out". + std::vector count_packets; // Packets from the output stream "count". + std::vector out_packets; // Packets from the output stream "out". 
+ MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( + "count", [&count_packets](const Packet& packet) { + count_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK( + graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { + out_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.Run()); + ASSERT_EQ(max_count, count_packets.size()); + for (int i = 0; i < count_packets.size(); ++i) { + EXPECT_EQ(i, count_packets[i].Get()); + EXPECT_EQ(Timestamp(i), count_packets[i].Timestamp()); + } + ASSERT_EQ(max_count, out_packets.size()); + for (int i = 0; i < out_packets.size(); ++i) { + EXPECT_EQ(i, out_packets[i].Get()); + EXPECT_EQ(Timestamp(i), out_packets[i].Timestamp()); + } +} + +class PassThroughSubgraph : public Subgraph { + public: + ::mediapipe::StatusOr GetConfig( + const SubgraphOptions& options) override { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'INPUT:input' + output_stream: 'OUTPUT:output' + node { + calculator: 'PassThroughCalculator' + input_stream: 'input' + output_stream: 'output' + } + )"); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(PassThroughSubgraph); + +TEST(CalculatorGraph, ObserveOutputStreamSubgraph) { + const int max_count = 10; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count' + input_side_packet: 'MAX_COUNT:max_count' + } + node { + calculator: 'PassThroughSubgraph' + input_stream: 'INPUT:count' + output_stream: 'OUTPUT:out' + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.Initialize(config, {{"max_count", MakePacket(max_count)}})); + // Observe the unconnected output stream "out". + std::vector out_packets; // Packets from the output stream "out". 
+ MEDIAPIPE_ASSERT_OK( + graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { + out_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.Run()); + ASSERT_EQ(max_count, out_packets.size()); + for (int i = 0; i < out_packets.size(); ++i) { + EXPECT_EQ(i, out_packets[i].Get()); + EXPECT_EQ(Timestamp(i), out_packets[i].Timestamp()); + } +} + +TEST(CalculatorGraph, ObserveOutputStreamError) { + const int max_count = 10; + const int fail_count = 6; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count' + input_side_packet: 'MAX_COUNT:max_count' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'count' + output_stream: 'mid' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'mid' + output_stream: 'out' + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.Initialize(config, {{"max_count", MakePacket(max_count)}})); + // Observe the internal output stream "count" and the unconnected output + // stream "out". + std::vector count_packets; // Packets from the output stream "count". + std::vector out_packets; // Packets from the output stream "out". + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( + "count", [&count_packets](const Packet& packet) { + count_packets.push_back(packet); + if (count_packets.size() >= fail_count) { + return ::mediapipe::UnknownError("Expected. 
MagicString-eatnhuea"); + } else { + return ::mediapipe::OkStatus(); + } + })); + MEDIAPIPE_ASSERT_OK( + graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { + out_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + ::mediapipe::Status status = graph.Run(); + ASSERT_THAT(status.message(), testing::HasSubstr("MagicString-eatnhuea")); + ASSERT_EQ(fail_count, count_packets.size()); + for (int i = 0; i < count_packets.size(); ++i) { + EXPECT_EQ(i, count_packets[i].Get()); + EXPECT_EQ(Timestamp(i), count_packets[i].Timestamp()); + } +} + +TEST(CalculatorGraph, ObserveOutputStreamNonexistent) { + const int max_count = 10; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count' + input_side_packet: 'MAX_COUNT:max_count' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'count' + output_stream: 'mid' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'mid' + output_stream: 'out' + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.Initialize(config, {{"max_count", MakePacket(max_count)}})); + // Observe the internal output stream "count". + std::vector count_packets; // Packets from the output stream "count". + ::mediapipe::Status status = graph.ObserveOutputStream( + "not_found", [&count_packets](const Packet& packet) { + count_packets.push_back(packet); + return ::mediapipe::OkStatus(); + }); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kNotFound); + EXPECT_THAT(status.message(), testing::HasSubstr("not_found")); +} + +// Verify that after a fast source node is closed, a slow sink node can +// consume all the accumulated input packets. In other words, closing an +// output stream still allows its mirrors to process all the received packets. 
+TEST(CalculatorGraph, FastSourceSlowSink) {
+  const int max_count = 10;
+  // Two worker threads and a generous queue let the fast source run far
+  // ahead of the slow sink before the source closes.
+  CalculatorGraphConfig config =
+      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+        num_threads: 2
+        max_queue_size: 100
+        node {
+          calculator: 'CountingSourceCalculator'
+          output_stream: 'out'
+          input_side_packet: 'MAX_COUNT:max_count'
+        }
+        node { calculator: 'SlowCountingSinkCalculator' input_stream: 'out' }
+      )");
+  CalculatorGraph graph;
+  MEDIAPIPE_ASSERT_OK(
+      graph.Initialize(config, {{"max_count", MakePacket<int>(max_count)}}));
+  MEDIAPIPE_EXPECT_OK(graph.Run());
+}
+
+TEST(CalculatorGraph, GraphFinishesWhilePaused) {
+  // The graph contains only one node, and the node runs only once. This test
+  // sets up the following sequence of events (all times in milliseconds):
+  //
+  // Application thread        Worker thread
+  //
+  // T=0  graph.StartRun       OneShot20MsCalculator::Process starts
+  // T=10 graph.Pause
+  // T=20                      OneShot20MsCalculator::Process ends.
+  //                           So graph finishes running while paused.
+  // T=30 graph.Resume
+  //
+  // graph.WaitUntilDone must not block forever.
+  CalculatorGraphConfig config =
+      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+        node { calculator: 'OneShot20MsCalculator' }
+      )");
+  CalculatorGraph graph;
+  MEDIAPIPE_ASSERT_OK(graph.Initialize(config));
+  MEDIAPIPE_EXPECT_OK(graph.StartRun({}));
+  absl::SleepFor(absl::Milliseconds(10));
+  graph.Pause();
+  absl::SleepFor(absl::Milliseconds(20));
+  graph.Resume();
+  MEDIAPIPE_EXPECT_OK(graph.WaitUntilDone());
+}
+
+// There should be no memory leaks, no error messages (requires manual
+// inspection of the test log), etc.
+TEST(CalculatorGraph, ConstructAndDestruct) { CalculatorGraph graph; }
+
+// A regression test for b/36364314. UnitDelayCalculator outputs a packet in
+// Open(). ErrorOnOpenCalculator fails in Open() if ERROR_ON_OPEN is true.
+TEST(CalculatorGraph, RecoverAfterPreviousFailInOpen) { + const int max_count = 10; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CountingSourceCalculator' + output_stream: 'a' + input_side_packet: 'MAX_COUNT:max_count' + } + node { + calculator: 'UnitDelayCalculator' + input_stream: 'a' + output_stream: 'b' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'b' + output_stream: 'c' + } + node { + calculator: 'ErrorOnOpenCalculator' + input_stream: 'c' + output_stream: 'd' + input_side_packet: 'ERROR_ON_OPEN:fail' + } + node { calculator: 'IntSinkCalculator' input_stream: 'd' } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.Initialize(config, {{"max_count", MakePacket(max_count)}})); + for (int i = 0; i < 2; ++i) { + EXPECT_FALSE(graph.Run({{"fail", MakePacket(true)}}).ok()); + MEDIAPIPE_EXPECT_OK(graph.Run({{"fail", MakePacket(false)}})); + } +} + +TEST(CalculatorGraph, PropagateBoundLoop) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'OutputAndBoundSourceCalculator' + output_stream: 'integers' + } + node { + calculator: 'IntAdderCalculator' + input_stream: 'integers' + input_stream: 'old_sum' + input_stream_info: { + tag_index: ':1' # 'old_sum' + back_edge: true + } + output_stream: 'sum' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + node { + calculator: 'Delay20Calculator' + input_stream: 'sum' + output_stream: 'old_sum' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("sum", &config, &packet_dump); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run()); + ASSERT_EQ(101, packet_dump.size()); + int sum = 0; + for (int i = 0; i < 101; ++i) { + sum += i; + EXPECT_EQ(sum, packet_dump[i].Get()); + EXPECT_EQ(Timestamp(i * 20), packet_dump[i].Timestamp()); + } +} + +TEST(CalculatorGraph, ReuseValidatedGraphConfig) { + 
CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + packet_generator { + packet_generator: "StaticCounterStringGenerator" + input_side_packet: "created_by_factory" + input_side_packet: "input_in_initialize" + output_side_packet: "foo1" + } + packet_generator { + packet_generator: "StaticCounterStringGenerator" + input_side_packet: "created_by_factory" + input_side_packet: "input_in_initialize" + input_side_packet: "foo1" + output_side_packet: "foo2" + } + packet_generator { + packet_generator: "StaticCounterStringGenerator" + input_side_packet: "created_by_factory" + input_side_packet: "input_in_initialize" + input_side_packet: "input_in_run" + output_side_packet: "foo3" + } + packet_generator { + packet_generator: "StaticCounterStringGenerator" + input_side_packet: "created_by_factory" + input_side_packet: "input_in_initialize" + input_side_packet: "input_in_run" + input_side_packet: "foo3" + output_side_packet: "foo4" + } + node { + calculator: "GlobalCountSourceCalculator" + input_side_packet: "global_counter" + output_stream: "unused" + } + )"); + ValidatedGraphConfig validated_graph; + MEDIAPIPE_ASSERT_OK(validated_graph.Initialize(config)); + + std::atomic global_counter(0); + Packet global_counter_packet = Adopt(new auto(&global_counter)); + + absl::FixedArray graphs(30); + for (int i = 0; i < graphs.size(); ++i) { + CalculatorGraph& graph = graphs[i]; + int initial_generator_count = + StaticCounterStringGenerator::NumPacketsGenerated(); + int initial_calculator_count = global_counter.load(); + MEDIAPIPE_ASSERT_OK(graph.Initialize( + config, + {{"created_by_factory", MakePacket("default string")}, + {"input_in_initialize", MakePacket(10)}, + {"global_counter", global_counter_packet}})); + EXPECT_EQ(initial_generator_count + 2, + StaticCounterStringGenerator::NumPacketsGenerated()); + EXPECT_EQ(initial_calculator_count, global_counter.load()); + } + for (int k = 0; k < 10; ++k) { + for (int i = 0; i < graphs.size(); ++i) { + 
CalculatorGraph& graph = graphs[i]; + int initial_generator_count = + StaticCounterStringGenerator::NumPacketsGenerated(); + int initial_calculator_count = global_counter.load(); + MEDIAPIPE_ASSERT_OK(graph.Run({{"input_in_run", MakePacket(11)}})); + EXPECT_EQ(initial_generator_count + 2, + StaticCounterStringGenerator::NumPacketsGenerated()); + EXPECT_EQ(initial_calculator_count + + GlobalCountSourceCalculator::kNumOutputPackets, + global_counter.load()); + } + } +} + +class TestRangeStdDevSubgraph : public Subgraph { + public: + ::mediapipe::StatusOr GetConfig( + const SubgraphOptions& options) override { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_side_packet: 'node_converted' + output_stream: 'DATA:range' + output_stream: 'SUM:range_sum' + output_stream: 'MEAN:range_mean' + output_stream: 'STDDEV:range_stddev' + node { + calculator: 'RangeCalculator' + output_stream: 'range' + output_stream: 'range_sum' + output_stream: 'range_mean' + input_side_packet: 'node_converted' + } + node { + calculator: 'StdDevCalculator' + input_stream: 'DATA:range' + input_stream: 'MEAN:range_mean' + output_stream: 'range_stddev' + } + )"); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(TestRangeStdDevSubgraph); + +class TestMergeSaverSubgraph : public Subgraph { + public: + ::mediapipe::StatusOr GetConfig( + const SubgraphOptions& options) override { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'DATA1:range1' + input_stream: 'DATA2:range2' + output_stream: 'MERGE:merge' + output_stream: 'FINAL:final' + node { + name: 'merger' + calculator: 'MergeCalculator' + input_stream: 'range1' + input_stream: 'range2' + output_stream: 'merge' + } + node { + calculator: 'SaverCalculator' + input_stream: 'merge' + output_stream: 'final' + } + )"); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(TestMergeSaverSubgraph); + +CalculatorGraphConfig GetConfigWithSubgraphs() { + CalculatorGraphConfig proto = + 
::mediapipe::ParseTextProtoOrDie(R"( + # Ensure stream name for FindOutputStreamManager + output_stream: 'MERGE:merge' + packet_generator { + packet_generator: 'IntSplitterPacketGenerator' + input_side_packet: 'node_3' + output_side_packet: 'node_3_converted' + } + packet_generator { + packet_generator: 'TaggedIntSplitterPacketGenerator' + input_side_packet: 'node_5' + output_side_packet: 'HIGH:unused_high' + output_side_packet: 'LOW:unused_low' + output_side_packet: 'PAIR:node_5_converted' + } + node { + calculator: 'TestRangeStdDevSubgraph' + input_side_packet: 'node_3_converted' + output_stream: 'DATA:range3' + output_stream: 'SUM:range3_sum' + output_stream: 'MEAN:range3_mean' + output_stream: 'STDDEV:range3_stddev' + } + node { + calculator: 'TestRangeStdDevSubgraph' + input_side_packet: 'node_5_converted' + output_stream: 'DATA:range5' + output_stream: 'SUM:range5_sum' + output_stream: 'MEAN:range5_mean' + output_stream: 'STDDEV:range5_stddev' + } + node { + name: 'copy_range5' + calculator: 'PassThroughCalculator' + input_stream: 'range5' + output_stream: 'range5_copy' + } + node { + calculator: 'TestMergeSaverSubgraph' + input_stream: 'DATA1:range3' + input_stream: 'DATA2:range5_copy' + output_stream: 'MERGE:merge' + output_stream: 'FINAL:final' + } + node { + calculator: 'TestMergeSaverSubgraph' + input_stream: 'DATA1:range3_sum' + input_stream: 'DATA2:range5_sum' + output_stream: 'FINAL:final_sum' + } + node { + calculator: 'TestMergeSaverSubgraph' + input_stream: 'DATA1:range3_stddev' + input_stream: 'DATA2:range5_stddev' + output_stream: 'FINAL:final_stddev' + } + )"); + return proto; +} + +TEST(CalculatorGraph, RunsCorrectlyWithSubgraphs) { + CalculatorGraph graph; + CalculatorGraphConfig proto = GetConfigWithSubgraphs(); + RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); +} + +TEST(CalculatorGraph, SetExecutorTwice) { + // SetExecutor must not be called more than once for the same executor name. 
+ CalculatorGraph graph; + MEDIAPIPE_EXPECT_OK( + graph.SetExecutor("xyz", std::make_shared(1))); + MEDIAPIPE_EXPECT_OK( + graph.SetExecutor("abc", std::make_shared(1))); + ::mediapipe::Status status = + graph.SetExecutor("xyz", std::make_shared(1)); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kAlreadyExists); + EXPECT_THAT(status.message(), testing::HasSubstr("xyz")); +} + +TEST(CalculatorGraph, ReservedNameSetExecutor) { + // A reserved executor name such as "__gpu" must not be used. + CalculatorGraph graph; + ::mediapipe::Status status = + graph.SetExecutor("__gpu", std::make_shared(1)); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), + testing::HasSubstr("reserved"))); +} + +TEST(CalculatorGraph, ReservedNameExecutorConfig) { + // A reserved executor name such as "__gpu" must not be used. + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + executor { + name: '__gpu' + type: 'ThreadPoolExecutor' + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'in' + output_stream: 'out' + } + )"); + ::mediapipe::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), + testing::HasSubstr("reserved"))); +} + +TEST(CalculatorGraph, ReservedNameNodeExecutor) { + // A reserved executor name such as "__gpu" must not be used. 
+ CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + node { + calculator: 'PassThroughCalculator' + executor: '__gpu' + input_stream: 'in' + output_stream: 'out' + } + )"); + ::mediapipe::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), + testing::HasSubstr("reserved"))); +} + +TEST(CalculatorGraph, NonExistentExecutor) { + // Any executor used by a calculator node must either be created by the + // graph (which requires an ExecutorConfig with a "type" field) or be + // provided to the graph with a CalculatorGraph::SetExecutor() call. + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + node { + calculator: 'PassThroughCalculator' + executor: 'xyz' + input_stream: 'in' + output_stream: 'out' + } + )"); + ::mediapipe::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("xyz"), + testing::HasSubstr("not declared"))); +} + +TEST(CalculatorGraph, UndeclaredExecutor) { + // Any executor used by a calculator node must be declared in an + // ExecutorConfig, even if the executor is provided to the graph with a + // CalculatorGraph::SetExecutor() call. 
+ CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.SetExecutor("xyz", std::make_shared(1))); + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + node { + calculator: 'PassThroughCalculator' + executor: 'xyz' + input_stream: 'in' + output_stream: 'out' + } + )"); + ::mediapipe::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("xyz"), + testing::HasSubstr("not declared"))); +} + +TEST(CalculatorGraph, UntypedExecutorDeclaredButNotSet) { + // If an executor is declared without a "type" field, it must be provided to + // the graph with a CalculatorGraph::SetExecutor() call. + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + executor { name: 'xyz' } + node { + calculator: 'PassThroughCalculator' + executor: 'xyz' + input_stream: 'in' + output_stream: 'out' + } + )"); + ::mediapipe::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("xyz"), + testing::HasSubstr("SetExecutor"))); +} + +TEST(CalculatorGraph, DuplicateExecutorConfig) { + // More than one ExecutorConfig cannot have the same name. 
+ CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.SetExecutor("xyz", std::make_shared(1))); + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + executor { name: 'xyz' } + executor { name: 'xyz' } + node { + calculator: 'PassThroughCalculator' + executor: 'xyz' + input_stream: 'in' + output_stream: 'out' + } + )"); + ::mediapipe::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("xyz"), + testing::HasSubstr("duplicate"))); +} + +TEST(CalculatorGraph, TypedExecutorDeclaredAndSet) { + // If an executor is declared with a "type" field, it must not be provided + // to the graph with a CalculatorGraph::SetExecutor() call. + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK( + graph.SetExecutor("xyz", std::make_shared(1))); + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + executor { + name: 'xyz' + type: 'ThreadPoolExecutor' + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: 'PassThroughCalculator' + executor: 'xyz' + input_stream: 'in' + output_stream: 'out' + } + )"); + ::mediapipe::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("xyz"), + testing::HasSubstr("SetExecutor"))); +} + +// The graph-level num_threads field and the ExecutorConfig for the default +// executor must not both be specified. 
+TEST(CalculatorGraph, NumThreadsAndDefaultExecutorConfig) {
+  CalculatorGraph graph;
+  CalculatorGraphConfig config =
+      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+        input_stream: 'in'
+        num_threads: 1
+        executor {
+          type: 'ThreadPoolExecutor'
+          options {
+            [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
+          }
+        }
+        node {
+          calculator: 'PassThroughCalculator'
+          input_stream: 'in'
+          output_stream: 'mid'
+        }
+        node {
+          calculator: 'PassThroughCalculator'
+          input_stream: 'mid'
+          output_stream: 'out'
+        }
+      )");
+  ::mediapipe::Status status = graph.Initialize(config);
+  EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument);
+  EXPECT_THAT(status.message(),
+              testing::AllOf(testing::HasSubstr("num_threads"),
+                             testing::HasSubstr("default executor")));
+}
+
+// The graph-level num_threads field and the ExecutorConfig for a non-default
+// executor may coexist.
+TEST(CalculatorGraph, NumThreadsAndNonDefaultExecutorConfig) {
+  CalculatorGraph graph;
+  CalculatorGraphConfig config =
+      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+        input_stream: 'in'
+        num_threads: 1
+        executor {
+          name: 'xyz'
+          type: 'ThreadPoolExecutor'
+          options {
+            [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
+          }
+        }
+        node {
+          calculator: 'PassThroughCalculator'
+          input_stream: 'in'
+          output_stream: 'mid'
+        }
+        node {
+          calculator: 'PassThroughCalculator'
+          executor: 'xyz'
+          input_stream: 'mid'
+          output_stream: 'out'
+        }
+      )");
+  MEDIAPIPE_EXPECT_OK(graph.Initialize(config));
+}
+
+// Verifies that the application thread is used only when
+// "ApplicationThreadExecutor" is specified. In this test
+// "ApplicationThreadExecutor" is specified in the ExecutorConfig for the
+// default executor.
+TEST(CalculatorGraph, RunWithNumThreadsInExecutorConfig) { + const struct { + std::string executor_type; + int num_threads; + bool use_app_thread_is_expected; + } cases[] = {{"ApplicationThreadExecutor", 0, true}, + {"", 0, false}, + {"ThreadPoolExecutor", 1, false}}; + + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + executor { + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 0 } + } + } + node { calculator: 'PthreadSelfSourceCalculator' output_stream: 'out' } + )"); + ThreadPoolExecutorOptions* default_executor_options = + config.mutable_executor(0)->mutable_options()->MutableExtension( + ThreadPoolExecutorOptions::ext); + for (int i = 0; i < ABSL_ARRAYSIZE(cases); ++i) { + default_executor_options->set_num_threads(cases[i].num_threads); + config.mutable_executor(0)->clear_type(); + if (cases[i].executor_type != "") { + config.mutable_executor(0)->set_type(cases[i].executor_type); + } + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + Packet out_packet; + MEDIAPIPE_ASSERT_OK( + graph.ObserveOutputStream("out", [&out_packet](const Packet& packet) { + out_packet = packet; + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.Run()); + EXPECT_EQ(cases[i].use_app_thread_is_expected, + out_packet.Get() == pthread_self()) + << "for case " << i; + } +} + +TEST(CalculatorGraph, CalculatorGraphNotInitialized) { + CalculatorGraph graph; + EXPECT_FALSE(graph.Run().ok()); +} + +TEST(CalculatorGraph, SimulateAssertFailure) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + num_threads: 2 + node { + calculator: 'PassThroughCalculator' + input_stream: 'in_a' + input_stream: 'in_b' + output_stream: 'out_a' + output_stream: 'out_b' + } + input_stream: 'in_a' + input_stream: 'in_b' + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_EXPECT_OK(graph.WaitUntilIdle()); + + // End the test 
here to simulate an ASSERT_ failure, which will skip the + // rest of the test and exit the test function immediately. The test should + // not hang in the CalculatorGraph destructor. +} + +// Verifies Calculator::InputTimestamp() returns the expected value in Open(), +// Process(), and Close() for both source and non-source nodes. In this test +// the source node stops the graph. +TEST(CalculatorGraph, CheckInputTimestamp) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CheckInputTimestampSourceCalculator' + output_stream: 'integer' + } + node { + calculator: 'CheckInputTimestampSinkCalculator' + input_stream: 'integer' + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run()); +} + +// Verifies Calculator::InputTimestamp() returns the expected value in Open(), +// Process(), and Close() for both source and non-source nodes. In this test +// the sink node stops the graph, which causes the framework to close the +// source node. +TEST(CalculatorGraph, CheckInputTimestamp2) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CheckInputTimestamp2SourceCalculator' + output_stream: 'integer' + } + node { + calculator: 'CheckInputTimestamp2SinkCalculator' + input_stream: 'integer' + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run()); +} + +TEST(CalculatorGraph, CheckBatchProcessingBoundPropagation) { + // The timestamp bound sent by OutputAndBoundSourceCalculator shouldn't be + // directly propagated to the output stream when PassThroughCalculator has + // anything in its default calculator context for batch processing. Otherwise, + // the sink calculator's input stream should report packet timestamp + // mismatches. 
+ CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'OutputAndBoundSourceCalculator' + output_stream: 'integers' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'integers' + output_stream: 'output' + input_stream_handler { + input_stream_handler: "DefaultInputStreamHandler" + options: { + [mediapipe.DefaultInputStreamHandlerOptions.ext]: { + batch_size: 10 + } + } + } + } + node { calculator: 'IntSinkCalculator' input_stream: 'output' } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.Run()); +} + +TEST(CalculatorGraph, OutputSidePacketInProcess) { + const int64 offset = 100; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "offset" + node { + calculator: "OutputSidePacketInProcessCalculator" + input_stream: "offset" + output_side_packet: "offset" + } + node { + calculator: "SidePacketToStreamPacketCalculator" + output_stream: "output" + input_side_packet: "offset" + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + std::vector output_packets; + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( + "output", [&output_packets](const Packet& packet) { + output_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + + // Run the graph twice. 
+ for (int run = 0; run < 2; ++run) { + output_packets.clear(); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "offset", MakePacket(offset).At(Timestamp(0)))); + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("offset")); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_EQ(1, output_packets.size()); + EXPECT_EQ(offset, output_packets[0].Get().Value()); + } +} + +TEST(CalculatorGraph, OutputSidePacketAlreadySet) { + const int64 offset = 100; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "offset" + node { + calculator: "OutputSidePacketInProcessCalculator" + input_stream: "offset" + output_side_packet: "offset" + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + // Send two input packets to cause OutputSidePacketInProcessCalculator to + // set the output side packet twice. + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "offset", MakePacket(offset).At(Timestamp(0)))); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "offset", MakePacket(offset).At(Timestamp(1)))); + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("offset")); + + ::mediapipe::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kAlreadyExists); + EXPECT_THAT(status.message(), testing::HasSubstr("was already set.")); +} + +TEST(CalculatorGraph, OutputSidePacketWithTimestamp) { + const int64 offset = 100; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "offset" + node { + calculator: "OutputSidePacketWithTimestampCalculator" + input_stream: "offset" + output_side_packet: "offset" + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + // The OutputSidePacketWithTimestampCalculator neglects to clear the + // timestamp in the input packet when it copies the input packet to the + // 
output side packet. The timestamp value should appear in the error + // message. + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "offset", MakePacket(offset).At(Timestamp(237)))); + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("offset")); + ::mediapipe::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), testing::HasSubstr("has a timestamp 237.")); +} + +TEST(CalculatorGraph, OutputSidePacketConsumedBySourceNode) { + const int max_count = 10; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "max_count" + node { + calculator: "OutputSidePacketInProcessCalculator" + input_stream: "max_count" + output_side_packet: "max_count" + } + node { + calculator: "CountingSourceCalculator" + output_stream: "count" + input_side_packet: "MAX_COUNT:max_count" + } + node { + calculator: "PassThroughCalculator" + input_stream: "count" + output_stream: "output" + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + std::vector output_packets; + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( + "output", [&output_packets](const Packet& packet) { + output_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + // Wait until the graph is idle so that + // Scheduler::TryToScheduleNextSourceLayer() gets called. + // Scheduler::TryToScheduleNextSourceLayer() should not activate source + // nodes that haven't been opened. We can't call graph.WaitUntilIdle() + // because the graph has a source node. 
+ absl::SleepFor(absl::Milliseconds(10)); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "max_count", MakePacket(max_count).At(Timestamp(0)))); + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("max_count")); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_EQ(max_count, output_packets.size()); + for (int i = 0; i < output_packets.size(); ++i) { + EXPECT_EQ(i, output_packets[i].Get()); + EXPECT_EQ(Timestamp(i), output_packets[i].Timestamp()); + } +} + +TEST(CalculatorGraph, GraphInputStreamWithTag) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "VIDEO_METADATA:video_metadata" + input_stream: "max_count" + node { + calculator: "PassThroughCalculator" + input_stream: "FIRST_INPUT:video_metadata" + input_stream: "max_count" + output_stream: "FIRST_INPUT:output_0" + output_stream: "output_1" + } + )"); + std::vector packet_dump; + tool::AddVectorSink("output_0", &config, &packet_dump); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + for (int i = 0; i < 5; ++i) { + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "video_metadata", MakePacket(i).At(Timestamp(i)))); + } + MEDIAPIPE_ASSERT_OK(graph.CloseAllPacketSources()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_EQ(5, packet_dump.size()); +} + +// Returns the first packet of the input stream. 
+class FirstPacketFilterCalculator : public CalculatorBase { + public: + FirstPacketFilterCalculator() {} + ~FirstPacketFilterCalculator() override {} + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + if (!seen_first_packet_) { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + cc->Outputs().Index(0).Close(); + seen_first_packet_ = true; + } + return ::mediapipe::OkStatus(); + } + + private: + bool seen_first_packet_ = false; +}; +REGISTER_CALCULATOR(FirstPacketFilterCalculator); + +TEST(CalculatorGraph, SourceLayerInversion) { + // There are three CountingSourceCalculators, indexed 0, 1, and 2. Each of + // them outputs 10 packets. + // + // CountingSourceCalculator 0 should output 0, 1, 2, 3, ..., 9. + // CountingSourceCalculator 1 should output 100, 101, 102, 103, ..., 109. + // CountingSourceCalculator 2 should output 0, 100, 200, 300, ..., 900. + // However, there is a source layer inversion. + // CountingSourceCalculator 0 is in source layer 0. + // CountingSourceCalculator 1 is in source layer 1. + // CountingSourceCalculator 2 is in source layer 0, but consumes an output + // side packet generated by a downstream calculator of + // CountingSourceCalculator 1. + // + // This graph will deadlock when CountingSourceCalculator 0 runs to + // completion and CountingSourceCalculator 1 cannot be activated because + // CountingSourceCalculator 2 cannot be opened. + + const int max_count = 10; + const int initial_value1 = 100; + // Set num_threads to 1 to force sequential execution for deterministic + // outputs. 
+ CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + num_threads: 1 + node { + calculator: "CountingSourceCalculator" + output_stream: "count0" + input_side_packet: "MAX_COUNT:max_count" + source_layer: 0 + } + + node { + calculator: "CountingSourceCalculator" + output_stream: "count1" + input_side_packet: "MAX_COUNT:max_count" + input_side_packet: "INITIAL_VALUE:initial_value1" + source_layer: 1 + } + node { + calculator: "FirstPacketFilterCalculator" + input_stream: "count1" + output_stream: "first_count1" + } + node { + calculator: "OutputSidePacketInProcessCalculator" + input_stream: "first_count1" + output_side_packet: "increment2" + } + + node { + calculator: "CountingSourceCalculator" + output_stream: "count2" + input_side_packet: "MAX_COUNT:max_count" + input_side_packet: "INCREMENT:increment2" + source_layer: 0 + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize( + config, {{"max_count", MakePacket(max_count)}, + {"initial_value1", MakePacket(initial_value1)}})); + ::mediapipe::Status status = graph.Run(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnknown); + EXPECT_THAT(status.message(), testing::HasSubstr("deadlock")); +} + +// Tests a graph of packet-generator-like calculators, which have no input +// streams and no output streams. 
+TEST(CalculatorGraph, PacketGeneratorLikeCalculators) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: "IntegerOutputSidePacketCalculator" + output_side_packet: "one" + } + node { + calculator: "IntegerOutputSidePacketCalculator" + output_side_packet: "another_one" + } + node { + calculator: "SidePacketAdderCalculator" + input_side_packet: "one" + input_side_packet: "another_one" + output_side_packet: "two" + } + node { + calculator: "IntegerOutputSidePacketCalculator" + output_side_packet: "yet_another_one" + } + node { + calculator: "SidePacketAdderCalculator" + input_side_packet: "two" + input_side_packet: "yet_another_one" + output_side_packet: "three" + } + node { + calculator: "SidePacketToStreamPacketCalculator" + input_side_packet: "three" + output_stream: "output" + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + std::vector output_packets; + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( + "output", [&output_packets](const Packet& packet) { + output_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.Run()); + ASSERT_EQ(1, output_packets.size()); + EXPECT_EQ(3, output_packets[0].Get()); + EXPECT_EQ(Timestamp::PostStream(), output_packets[0].Timestamp()); +} + +TEST(CalculatorGraph, OutputSummarySidePacketInClose) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_packets" + node { + calculator: "CountAndOutputSummarySidePacketInCloseCalculator" + input_stream: "input_packets" + output_side_packet: "num_of_packets" + } + node { + calculator: "SidePacketToStreamPacketCalculator" + input_side_packet: "num_of_packets" + output_stream: "output" + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + std::vector output_packets; + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream( + "output", [&output_packets](const Packet& packet) { + 
output_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + + // Run the graph twice. + int max_count = 100; + for (int run = 0; run < 1; ++run) { + output_packets.clear(); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + for (int i = 0; i < max_count; ++i) { + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "input_packets", MakePacket(i).At(Timestamp(i)))); + } + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input_packets")); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_EQ(1, output_packets.size()); + EXPECT_EQ(max_count, output_packets[0].Get()); + EXPECT_EQ(Timestamp::PostStream(), output_packets[0].Timestamp()); + } +} + +TEST(CalculatorGraph, GetOutputSidePacket) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_packets" + node { + calculator: "CountAndOutputSummarySidePacketInCloseCalculator" + input_stream: "input_packets" + output_side_packet: "num_of_packets" + } + packet_generator { + packet_generator: "Uint64PacketGenerator" + output_side_packet: "output_uint64" + } + packet_generator { + packet_generator: "IntSplitterPacketGenerator" + input_side_packet: "input_uint64" + output_side_packet: "output_uint32_pair" + } + )"); + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + // Check a packet generated by the PacketGenerator, which is available after + // graph initialization, can be fetched before graph starts. + ::mediapipe::StatusOr status_or_packet = + graph.GetOutputSidePacket("output_uint64"); + MEDIAPIPE_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + // IntSplitterPacketGenerator is missing its input side packet and we + // won't be able to get its output side packet now. + status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); + EXPECT_EQ(::mediapipe::StatusCode::kUnavailable, + status_or_packet.status().code()); + // Run the graph twice. 
+ int max_count = 100; + std::map extra_side_packets; + extra_side_packets.insert({"input_uint64", MakePacket(1123)}); + for (int run = 0; run < 1; ++run) { + MEDIAPIPE_ASSERT_OK(graph.StartRun(extra_side_packets)); + status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); + MEDIAPIPE_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + for (int i = 0; i < max_count; ++i) { + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "input_packets", MakePacket(i).At(Timestamp(i)))); + } + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input_packets")); + + // Should return NOT_FOUND for invalid side packets. + status_or_packet = graph.GetOutputSidePacket("unknown"); + EXPECT_FALSE(status_or_packet.ok()); + EXPECT_EQ(::mediapipe::StatusCode::kNotFound, + status_or_packet.status().code()); + // Should return UNAVAILABLE before graph is done for valid non-base + // packets. + status_or_packet = graph.GetOutputSidePacket("num_of_packets"); + EXPECT_FALSE(status_or_packet.ok()); + EXPECT_EQ(::mediapipe::StatusCode::kUnavailable, + status_or_packet.status().code()); + // Should stil return a base even before graph is done. + status_or_packet = graph.GetOutputSidePacket("output_uint64"); + MEDIAPIPE_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + + // Check packets are available after graph is done. + status_or_packet = graph.GetOutputSidePacket("num_of_packets"); + MEDIAPIPE_ASSERT_OK(status_or_packet); + EXPECT_EQ(max_count, status_or_packet.ValueOrDie().Get()); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + // Should still return a base packet after graph is done. 
+ status_or_packet = graph.GetOutputSidePacket("output_uint64"); + MEDIAPIPE_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + // Should still return a non-base packet after graph is done. + status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); + MEDIAPIPE_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + } +} + +constexpr int kDefaultMaxCount = 1000; + +TEST(CalculatorGraph, TestPollPacket) { + CalculatorGraphConfig config; + CalculatorGraphConfig::Node* node = config.add_node(); + node->set_calculator("CountingSourceCalculator"); + node->add_output_stream("output"); + node->add_input_side_packet("MAX_COUNT:max_count"); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + auto status_or_poller = graph.AddOutputStreamPoller("output"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + MEDIAPIPE_ASSERT_OK( + graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); + Packet packet; + int num_packets = 0; + while (poller.Next(&packet)) { + EXPECT_EQ(num_packets, packet.Get()); + ++num_packets; + } + MEDIAPIPE_ASSERT_OK(graph.CloseAllPacketSources()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_FALSE(poller.Next(&packet)); + EXPECT_EQ(kDefaultMaxCount, num_packets); +} + +TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) { + CalculatorGraphConfig config; + CalculatorGraphConfig::Node* node = config.add_node(); + node->set_calculator("CountingSourceCalculator"); + node->add_output_stream("output"); + node->add_input_side_packet("MAX_COUNT:max_count"); + + for (int queue_size = 1; queue_size < 10; ++queue_size) { + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + auto status_or_poller = graph.AddOutputStreamPoller("output"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = 
std::move(status_or_poller.ValueOrDie()); + poller.SetMaxQueueSize(queue_size); + MEDIAPIPE_ASSERT_OK( + graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); + Packet packet; + int num_packets = 0; + while (poller.Next(&packet)) { + EXPECT_EQ(num_packets, packet.Get()); + ++num_packets; + } + MEDIAPIPE_ASSERT_OK(graph.CloseAllPacketSources()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_FALSE(poller.Next(&packet)); + EXPECT_EQ(kDefaultMaxCount, num_packets); + } +} + +TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) { + CalculatorGraphConfig config; + CalculatorGraphConfig::Node* node1 = config.add_node(); + node1->set_calculator("CountingSourceCalculator"); + node1->add_output_stream("stream1"); + node1->add_input_side_packet("MAX_COUNT:max_count"); + CalculatorGraphConfig::Node* node2 = config.add_node(); + node2->set_calculator("PassThroughCalculator"); + node2->add_input_stream("stream1"); + node2->add_output_stream("stream2"); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + auto status_or_poller1 = graph.AddOutputStreamPoller("stream1"); + ASSERT_TRUE(status_or_poller1.ok()); + OutputStreamPoller poller1 = std::move(status_or_poller1.ValueOrDie()); + auto status_or_poller2 = graph.AddOutputStreamPoller("stream2"); + ASSERT_TRUE(status_or_poller2.ok()); + OutputStreamPoller poller2 = std::move(status_or_poller2.ValueOrDie()); + MEDIAPIPE_ASSERT_OK( + graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); + Packet packet1; + Packet packet2; + int num_packets1 = 0; + int num_packets2 = 0; + int running_pollers = 2; + while (running_pollers > 0) { + if (poller1.Next(&packet1)) { + EXPECT_EQ(num_packets1++, packet1.Get()); + } else { + --running_pollers; + } + if (poller2.Next(&packet2)) { + EXPECT_EQ(num_packets2++, packet2.Get()); + } else { + --running_pollers; + } + } + MEDIAPIPE_ASSERT_OK(graph.CloseAllPacketSources()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + 
EXPECT_FALSE(poller1.Next(&packet1)); + EXPECT_FALSE(poller2.Next(&packet2)); + EXPECT_EQ(kDefaultMaxCount, num_packets1); + EXPECT_EQ(kDefaultMaxCount, num_packets2); +} + +// Ensure that when a custom input stream handler is used to handle packets from +// input streams, an error message is outputted with the appropriate link to +// resolve the issue when the calculator doesn't handle inputs in monotonically +// increasing order of timestamps. +TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) { + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'input0' + input_stream: 'input1' + node { + calculator: 'SimpleMuxCalculator' + input_stream: 'input0' + input_stream: 'input1' + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + output_stream: 'output' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("output", &config, &packet_dump); + + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + + // Send packets to input stream "input0" at timestamps 0 and 1 consecutively. + Timestamp input0_timestamp = Timestamp(0); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "input0", MakePacket(1).At(input0_timestamp))); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + EXPECT_EQ(1, packet_dump[0].Get()); + + ++input0_timestamp; + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "input0", MakePacket(3).At(input0_timestamp))); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(2, packet_dump.size()); + EXPECT_EQ(3, packet_dump[1].Get()); + + // Send a packet to input stream "input1" at timestamp 0 after sending two + // packets at timestamps 0 and 1 to input stream "input0". This will result + // in a mismatch in timestamps as the SimpleMuxCalculator doesn't handle + // inputs from all streams in monotonically increasing order of timestamps. 
+ Timestamp input1_timestamp = Timestamp(0); + MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream( + "input1", MakePacket(2).At(input1_timestamp))); + ::mediapipe::Status run_status = graph.WaitUntilIdle(); + EXPECT_THAT( + run_status.ToString(), + testing::AllOf( + // The core problem. + testing::HasSubstr("timestamp mismatch on a calculator"), + testing::HasSubstr( + "timestamps that are not strictly monotonically increasing"), + // Link to the possible solution. + testing::HasSubstr("ImmediateInputStreamHandler class comment"))); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_node.cc b/mediapipe/framework/calculator_node.cc new file mode 100644 index 000000000..42fd7e06a --- /dev/null +++ b/mediapipe/framework/calculator_node.cc @@ -0,0 +1,825 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_node.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_registry_util.h" +#include "mediapipe/framework/counter_factory.h" +#include "mediapipe/framework/input_stream_manager.h" +#include "mediapipe/framework/mediapipe_profiling.h" +#include "mediapipe/framework/output_stream_manager.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/source_location.h" +#include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/status_util.h" +#include "mediapipe/framework/tool/tag_map.h" +#include "mediapipe/framework/tool/validate_name.h" +#include "mediapipe/gpu/graph_support.h" + +namespace mediapipe { + +namespace { + +const PacketType* GetPacketType(const PacketTypeSet& packet_type_set, + const std::string& tag, const int index) { + CollectionItemId id; + if (tag.empty()) { + id = packet_type_set.GetId("", index); + } else { + id = packet_type_set.GetId(tag, 0); + } + CHECK(id.IsValid()) << "Internal mediapipe error."; + return &packet_type_set.Get(id); +} + +} // namespace + +CalculatorNode::CalculatorNode() {} + +Timestamp CalculatorNode::SourceProcessOrder( + const CalculatorContext* cc) const { + return calculator_->SourceProcessOrder(cc); +} + +::mediapipe::Status CalculatorNode::Initialize( + const 
ValidatedGraphConfig* validated_graph, int node_id, + InputStreamManager* input_stream_managers, + OutputStreamManager* output_stream_managers, + OutputSidePacketImpl* output_side_packets, int* buffer_size_hint, + std::shared_ptr profiling_context) { + RET_CHECK(buffer_size_hint) << "buffer_size_hint is NULL"; + node_id_ = node_id; + validated_graph_ = validated_graph; + profiling_context_ = profiling_context; + + const CalculatorGraphConfig::Node& node_config = + validated_graph_->Config().node(node_id_); + name_ = CanonicalNodeName(validated_graph_->Config(), node_id_); + + max_in_flight_ = node_config.max_in_flight(); + max_in_flight_ = max_in_flight_ ? max_in_flight_ : 1; + if (!node_config.executor().empty()) { + executor_ = node_config.executor(); + } + source_layer_ = node_config.source_layer(); + + const NodeTypeInfo& node_type_info = + validated_graph_->CalculatorInfos()[node_id_]; + + uses_gpu_ = + node_type_info.InputSidePacketTypes().HasTag(kGpuSharedTagName) || + ContainsKey(node_type_info.Contract().ServiceRequests(), kGpuService.key); + + // TODO Propagate types between calculators when SetAny is used. + + RETURN_IF_ERROR(InitializeOutputSidePackets( + node_type_info.OutputSidePacketTypes(), output_side_packets)); + + RETURN_IF_ERROR(InitializeInputSidePackets(output_side_packets)); + + RETURN_IF_ERROR(InitializeOutputStreamHandler( + node_config.output_stream_handler(), node_type_info.OutputStreamTypes())); + RETURN_IF_ERROR(InitializeOutputStreams(output_stream_managers)); + + calculator_state_ = absl::make_unique( + name_, node_id_, node_config.calculator(), node_config, + profiling_context_); + + // Inform the scheduler that this node has buffering behavior and that the + // maximum input queue size should be adjusted accordingly. 
+ *buffer_size_hint = node_config.buffer_size_hint(); + + calculator_context_manager_.Initialize( + calculator_state_.get(), node_type_info.InputStreamTypes().TagMap(), + node_type_info.OutputStreamTypes().TagMap(), + /*calculator_run_in_parallel=*/max_in_flight_ > 1); + + // The graph specified InputStreamHandler takes priority. + const bool graph_specified = + node_config.input_stream_handler().has_input_stream_handler(); + const bool calc_specified = !(node_type_info.GetInputStreamHandler().empty()); + + // Only use calculator ISH if available, and if the graph ISH is not set. + InputStreamHandlerConfig handler_config; + const bool use_calc_specified = calc_specified && !graph_specified; + if (use_calc_specified) { + *(handler_config.mutable_input_stream_handler()) = + node_type_info.GetInputStreamHandler(); + *(handler_config.mutable_options()) = + node_type_info.GetInputStreamHandlerOptions(); + } + + // Use calculator or graph specified InputStreamHandler, or the default ISH + // already set from graph. + RETURN_IF_ERROR(InitializeInputStreamHandler( + use_calc_specified ? 
handler_config : node_config.input_stream_handler(), + node_type_info.InputStreamTypes())); + + return InitializeInputStreams(input_stream_managers, output_stream_managers); +} + +::mediapipe::Status CalculatorNode::InitializeOutputSidePackets( + const PacketTypeSet& output_side_packet_types, + OutputSidePacketImpl* output_side_packets) { + output_side_packets_ = + absl::make_unique(output_side_packet_types.TagMap()); + const NodeTypeInfo& node_type_info = + validated_graph_->CalculatorInfos()[node_id_]; + int base_index = node_type_info.OutputSidePacketBaseIndex(); + RET_CHECK_LE(0, base_index); + for (CollectionItemId id = output_side_packets_->BeginId(); + id < output_side_packets_->EndId(); ++id) { + output_side_packets_->GetPtr(id) = + &output_side_packets[base_index + id.value()]; + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorNode::InitializeInputSidePackets( + OutputSidePacketImpl* output_side_packets) { + const NodeTypeInfo& node_type_info = + validated_graph_->CalculatorInfos()[node_id_]; + int base_index = node_type_info.InputSidePacketBaseIndex(); + RET_CHECK_LE(0, base_index); + // Set all the mirrors. + for (CollectionItemId id = node_type_info.InputSidePacketTypes().BeginId(); + id < node_type_info.InputSidePacketTypes().EndId(); ++id) { + int output_side_packet_index = + validated_graph_->InputSidePacketInfos()[base_index + id.value()] + .upstream; + if (output_side_packet_index < 0) { + // Not generated by a graph node. Comes from an extra side packet + // provided to the graph. 
+ continue; + } + OutputSidePacketImpl* origin_output_side_packet = + &output_side_packets[output_side_packet_index]; + VLOG(2) << "Adding mirror for input side packet with id " << id.value() + << " and flat index " << base_index + id.value() + << " which will be connected to output side packet with flat index " + << output_side_packet_index; + origin_output_side_packet->AddMirror(&input_side_packet_handler_, id); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorNode::InitializeOutputStreams( + OutputStreamManager* output_stream_managers) { + RET_CHECK(output_stream_managers) << "output_stream_managers is NULL"; + const NodeTypeInfo& node_type_info = + validated_graph_->CalculatorInfos()[node_id_]; + RET_CHECK_LE(0, node_type_info.OutputStreamBaseIndex()); + OutputStreamManager* current_output_stream_managers = + &output_stream_managers[node_type_info.OutputStreamBaseIndex()]; + return output_stream_handler_->InitializeOutputStreamManagers( + current_output_stream_managers); +} + +::mediapipe::Status CalculatorNode::InitializeInputStreams( + InputStreamManager* input_stream_managers, + OutputStreamManager* output_stream_managers) { + RET_CHECK(input_stream_managers) << "input_stream_managers is NULL"; + RET_CHECK(output_stream_managers) << "output_stream_managers is NULL"; + const NodeTypeInfo& node_type_info = + validated_graph_->CalculatorInfos()[node_id_]; + RET_CHECK_LE(0, node_type_info.InputStreamBaseIndex()); + InputStreamManager* current_input_stream_managers = + &input_stream_managers[node_type_info.InputStreamBaseIndex()]; + RETURN_IF_ERROR(input_stream_handler_->InitializeInputStreamManagers( + current_input_stream_managers)); + + // Set all the mirrors. 
+ for (CollectionItemId id = node_type_info.InputStreamTypes().BeginId(); + id < node_type_info.InputStreamTypes().EndId(); ++id) { + int output_stream_index = + validated_graph_ + ->InputStreamInfos()[node_type_info.InputStreamBaseIndex() + + id.value()] + .upstream; + RET_CHECK_LE(0, output_stream_index); + OutputStreamManager* origin_output_stream_manager = + &output_stream_managers[output_stream_index]; + VLOG(2) << "Adding mirror for input stream with id " << id.value() + << " and flat index " + << node_type_info.InputStreamBaseIndex() + id.value() + << " which will be connected to output stream with flat index " + << output_stream_index; + origin_output_stream_manager->AddMirror(input_stream_handler_.get(), id); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorNode::InitializeInputStreamHandler( + const InputStreamHandlerConfig& handler_config, + const PacketTypeSet& input_stream_types) { + const ProtoString& input_stream_handler_name = + handler_config.input_stream_handler(); + RET_CHECK(!input_stream_handler_name.empty()); + ASSIGN_OR_RETURN(input_stream_handler_, + InputStreamHandlerRegistry::CreateByNameInNamespace( + validated_graph_->Package(), input_stream_handler_name, + input_stream_types.TagMap(), + &calculator_context_manager_, handler_config.options(), + /*calculator_run_in_parallel=*/max_in_flight_ > 1), + _ << "\"" << input_stream_handler_name + << "\" is not a registered input stream handler."); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorNode::InitializeOutputStreamHandler( + const OutputStreamHandlerConfig& handler_config, + const PacketTypeSet& output_stream_types) { + const ProtoString& output_stream_handler_name = + handler_config.output_stream_handler(); + RET_CHECK(!output_stream_handler_name.empty()); + ASSIGN_OR_RETURN(output_stream_handler_, + OutputStreamHandlerRegistry::CreateByNameInNamespace( + validated_graph_->Package(), output_stream_handler_name, + 
output_stream_types.TagMap(), + &calculator_context_manager_, handler_config.options(), + /*calculator_run_in_parallel=*/max_in_flight_ > 1), + _ << "\"" << output_stream_handler_name + << "\" is not a registered output stream handler."); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorNode::ConnectShardsToStreams( + CalculatorContext* calculator_context) { + RET_CHECK(calculator_context); + RETURN_IF_ERROR( + input_stream_handler_->SetupInputShards(&calculator_context->Inputs())); + return output_stream_handler_->SetupOutputShards( + &calculator_context->Outputs()); +} + +void CalculatorNode::SetExecutor(const std::string& executor) { + absl::MutexLock status_lock(&status_mutex_); + CHECK_LT(status_, kStateOpened); + executor_ = executor; +} + +bool CalculatorNode::Prepared() const { + absl::MutexLock status_lock(&status_mutex_); + return status_ >= kStatePrepared; +} + +bool CalculatorNode::Opened() const { + absl::MutexLock status_lock(&status_mutex_); + return status_ >= kStateOpened; +} + +bool CalculatorNode::Active() const { + absl::MutexLock status_lock(&status_mutex_); + return status_ >= kStateActive; +} + +bool CalculatorNode::Closed() const { + absl::MutexLock status_lock(&status_mutex_); + return status_ >= kStateClosed; +} + +void CalculatorNode::SetMaxInputStreamQueueSize(int max_queue_size) { + CHECK(input_stream_handler_); + input_stream_handler_->SetMaxQueueSize(max_queue_size); +} + +::mediapipe::Status CalculatorNode::PrepareForRun( + const std::map& all_side_packets, + const std::map& service_packets, + std::function ready_for_open_callback, + std::function source_node_opened_callback, + std::function schedule_callback, + std::function error_callback, + CounterFactory* counter_factory) { + RET_CHECK(ready_for_open_callback) << "ready_for_open_callback is NULL"; + RET_CHECK(schedule_callback) << "schedule_callback is NULL"; + RET_CHECK(error_callback) << "error_callback is NULL"; + calculator_state_->ResetBetweenRuns(); + + 
ready_for_open_callback_ = std::move(ready_for_open_callback); + source_node_opened_callback_ = std::move(source_node_opened_callback); + input_stream_handler_->PrepareForRun( + [this]() { CalculatorNode::InputStreamHeadersReady(); }, + [this]() { CalculatorNode::CheckIfBecameReady(); }, + std::move(schedule_callback), error_callback); + output_stream_handler_->PrepareForRun(error_callback); + + const PacketTypeSet* input_side_packet_types = + &validated_graph_->CalculatorInfos()[node_id_].InputSidePacketTypes(); + RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun( + input_side_packet_types, all_side_packets, + [this]() { CalculatorNode::InputSidePacketsReady(); }, + std::move(error_callback))); + calculator_state_->SetInputSidePackets( + &input_side_packet_handler_.InputSidePackets()); + calculator_state_->SetOutputSidePackets(output_side_packets_.get()); + calculator_state_->SetCounterFactory(counter_factory); + + const auto& contract = + validated_graph_->CalculatorInfos()[node_id_].Contract(); + for (const auto& svc_req : contract.ServiceRequests()) { + const auto& req = svc_req.second; + std::string key{req.Service().key}; + auto it = service_packets.find(key); + if (it == service_packets.end()) { + RET_CHECK(req.IsOptional()) + << "required service '" << key << "' was not provided"; + } else { + calculator_state_->SetServicePacket(key, it->second); + } + } + + RETURN_IF_ERROR(calculator_context_manager_.PrepareForRun(std::bind( + &CalculatorNode::ConnectShardsToStreams, this, std::placeholders::_1))); + + auto calculator_statusor = CreateCalculator( + input_stream_handler_->InputTagMap(), + output_stream_handler_->OutputTagMap(), validated_graph_->Package(), + calculator_state_.get(), + calculator_context_manager_.GetDefaultCalculatorContext()); + if (!calculator_statusor.ok()) { + return calculator_statusor.status(); + } + calculator_ = std::move(calculator_statusor).ValueOrDie(); + + needs_to_close_ = false; + + { + absl::MutexLock 
status_lock(&status_mutex_); + status_ = kStatePrepared; + scheduling_state_ = kIdle; + current_in_flight_ = 0; + input_stream_headers_ready_called_ = false; + input_side_packets_ready_called_ = false; + input_stream_headers_ready_ = + (input_stream_handler_->UnsetHeaderCount() == 0); + input_side_packets_ready_ = + (input_side_packet_handler_.MissingInputSidePacketCount() == 0); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorNode::OpenNode() { + VLOG(2) << "CalculatorNode::OpenNode() for " << DebugName(); + + CalculatorContext* default_context = + calculator_context_manager_.GetDefaultCalculatorContext(); + InputStreamShardSet* inputs = &default_context->Inputs(); + // The upstream calculators may set the headers in the output streams during + // Calculator::Open(), needs to update the header packets in input stream + // shards. + input_stream_handler_->UpdateInputShardHeaders(inputs); + OutputStreamShardSet* outputs = &default_context->Outputs(); + output_stream_handler_->PrepareOutputs(Timestamp::Unstarted(), outputs); + calculator_context_manager_.PushInputTimestampToContext( + default_context, Timestamp::Unstarted()); + + ::mediapipe::Status result; + + { + MEDIAPIPE_PROFILING(OPEN, default_context); + LegacyCalculatorSupport::Scoped s(default_context); + result = calculator_->Open(default_context); + } + + calculator_context_manager_.PopInputTimestampFromContext(default_context); + if (IsSource()) { + // A source node has a dummy input timestamp of 0 for Process(). This input + // timestamp is not popped until Close() is called. 
+ calculator_context_manager_.PushInputTimestampToContext(default_context, + Timestamp(0)); + } + + LOG_IF(FATAL, result == tool::StatusStop()) << absl::Substitute( + "Open() on node \"$0\" returned tool::StatusStop() which should only be " + "used to signal that a source node is done producing data.", + DebugName()); + RETURN_IF_ERROR(result).SetPrepend() << absl::Substitute( + "Calculator::Open() for node \"$0\" failed: ", DebugName()); + needs_to_close_ = true; + + output_stream_handler_->Open(outputs); + + { + absl::MutexLock status_lock(&status_mutex_); + status_ = kStateOpened; + } + + return ::mediapipe::OkStatus(); +} + +void CalculatorNode::ActivateNode() { + absl::MutexLock status_lock(&status_mutex_); + CHECK_EQ(status_, kStateOpened) << DebugName(); + status_ = kStateActive; +} + +void CalculatorNode::CloseInputStreams() { + { + absl::MutexLock status_lock(&status_mutex_); + if (status_ == kStateClosed) { + return; + } + } + VLOG(2) << "Closing node " << DebugName() << " input streams."; + + // Clear the input queues and prevent the upstream nodes from filling them + // back in. We may still get ProcessNode called on us after this. 
+ input_stream_handler_->Close(); +} + +void CalculatorNode::CloseOutputStreams(OutputStreamShardSet* outputs) { + { + absl::MutexLock status_lock(&status_mutex_); + if (status_ == kStateClosed) { + return; + } + } + VLOG(2) << "Closing node " << DebugName() << " output streams."; + output_stream_handler_->Close(outputs); +} + +::mediapipe::Status CalculatorNode::CloseNode( + const ::mediapipe::Status& graph_status, bool graph_run_ended) { + { + absl::MutexLock status_lock(&status_mutex_); + RET_CHECK_NE(status_, kStateClosed) + << "CloseNode() must only be called once."; + } + + CloseInputStreams(); + CalculatorContext* default_context = + calculator_context_manager_.GetDefaultCalculatorContext(); + OutputStreamShardSet* outputs = &default_context->Outputs(); + output_stream_handler_->PrepareOutputs(Timestamp::Done(), outputs); + if (IsSource()) { + calculator_context_manager_.PopInputTimestampFromContext(default_context); + calculator_context_manager_.PushInputTimestampToContext(default_context, + Timestamp::Done()); + } + calculator_context_manager_.SetGraphStatusInContext(default_context, + graph_status); + + ::mediapipe::Status result; + + { + MEDIAPIPE_PROFILING(CLOSE, default_context); + LegacyCalculatorSupport::Scoped s(default_context); + result = calculator_->Close(default_context); + } + needs_to_close_ = false; + + LOG_IF(FATAL, result == tool::StatusStop()) << absl::Substitute( + "Close() on node \"$0\" returned tool::StatusStop() which should only be " + "used to signal that a source node is done producing data.", + DebugName()); + + // If the graph run has ended, we are cleaning up after the run and don't + // need to propagate updates to mirrors, so we can skip this + // CloseOutputStreams() call. CleanupAfterRun() will close the output + // streams. 
+ if (!graph_run_ended) { + CloseOutputStreams(outputs); + } + + { + absl::MutexLock status_lock(&status_mutex_); + status_ = kStateClosed; + } + + RETURN_IF_ERROR(result).SetPrepend() << absl::Substitute( + "Calculator::Close() for node \"$0\" failed: ", DebugName()); + + VLOG(2) << "Closed node " << DebugName(); + return ::mediapipe::OkStatus(); +} + +void CalculatorNode::CleanupAfterRun(const ::mediapipe::Status& graph_status) { + if (needs_to_close_) { + calculator_context_manager_.PushInputTimestampToContext( + calculator_context_manager_.GetDefaultCalculatorContext(), + Timestamp::Done()); + CloseNode(graph_status, /*graph_run_ended=*/true).IgnoreError(); + } + calculator_ = nullptr; + // All pending output packets are automatically dropped when calculator + // context manager destroys all calculator context objects. + calculator_context_manager_.CleanupAfterRun(); + + CloseInputStreams(); + // All output stream shards have been destroyed by calculator context manager. + CloseOutputStreams(/*outputs=*/nullptr); + + { + absl::MutexLock lock(&status_mutex_); + status_ = kStateUninitialized; + scheduling_state_ = kIdle; + current_in_flight_ = 0; + } +} + +void CalculatorNode::SchedulingLoop() { + int max_allowance = 0; + { + absl::MutexLock lock(&status_mutex_); + if (status_ == kStateClosed) { + scheduling_state_ = kIdle; + return; + } + max_allowance = max_in_flight_ - current_in_flight_; + } + while (true) { + Timestamp input_bound; + // input_bound is set to a meaningful value iff the latest readiness of the + // node is kNotReady when ScheduleInvocations() returns. + input_stream_handler_->ScheduleInvocations(max_allowance, &input_bound); + if (input_bound != Timestamp::Unset()) { + // Updates the minimum timestamp for which a new packet could possibly + // arrive. 
+ output_stream_handler_->UpdateTaskTimestampBound(input_bound); + } + + { + absl::MutexLock lock(&status_mutex_); + if (scheduling_state_ == kSchedulingPending && + current_in_flight_ < max_in_flight_) { + max_allowance = max_in_flight_ - current_in_flight_; + scheduling_state_ = kScheduling; + } else { + scheduling_state_ = kIdle; + break; + } + } + } +} + +bool CalculatorNode::ReadyForOpen() const { + absl::MutexLock lock(&status_mutex_); + return input_stream_headers_ready_ && input_side_packets_ready_; +} + +void CalculatorNode::InputStreamHeadersReady() { + bool ready_for_open = false; + { + absl::MutexLock lock(&status_mutex_); + CHECK_EQ(status_, kStatePrepared) << DebugName(); + CHECK(!input_stream_headers_ready_called_); + input_stream_headers_ready_called_ = true; + input_stream_headers_ready_ = true; + ready_for_open = input_side_packets_ready_; + } + if (ready_for_open) { + ready_for_open_callback_(); + } +} + +void CalculatorNode::InputSidePacketsReady() { + bool ready_for_open = false; + { + absl::MutexLock lock(&status_mutex_); + CHECK_EQ(status_, kStatePrepared) << DebugName(); + CHECK(!input_side_packets_ready_called_); + input_side_packets_ready_called_ = true; + input_side_packets_ready_ = true; + ready_for_open = input_stream_headers_ready_; + } + if (ready_for_open) { + ready_for_open_callback_(); + } +} + +void CalculatorNode::CheckIfBecameReady() { + { + absl::MutexLock lock(&status_mutex_); + // Doesn't check if status_ is kStateActive since the function can only be + // invoked by non-source nodes. + if (status_ != kStateOpened) { + return; + } + if (scheduling_state_ == kIdle && current_in_flight_ < max_in_flight_) { + scheduling_state_ = kScheduling; + } else { + if (scheduling_state_ == kScheduling) { + // Changes the state to scheduling pending if another thread is doing + // the scheduling. 
+ scheduling_state_ = kSchedulingPending; + } + return; + } + } + SchedulingLoop(); +} + +void CalculatorNode::NodeOpened() { + if (IsSource()) { + source_node_opened_callback_(); + } else if (input_stream_handler_->NumInputStreams() != 0) { + // A node with input streams may have received input packets generated by + // the upstreams nodes' Open() or Process() methods. Check if the node is + // ready to run. + CheckIfBecameReady(); + } +} + +void CalculatorNode::EndScheduling() { + { + absl::MutexLock lock(&status_mutex_); + if (status_ != kStateOpened && status_ != kStateActive) { + return; + } + --current_in_flight_; + CHECK_GE(current_in_flight_, 0); + + if (scheduling_state_ == kScheduling) { + // Changes the state to scheduling pending if another thread is doing the + // scheduling. + scheduling_state_ = kSchedulingPending; + return; + } else if (scheduling_state_ == kSchedulingPending) { + // Quits when another thread is already doing the scheduling. + return; + } + scheduling_state_ = kScheduling; + } + SchedulingLoop(); +} + +bool CalculatorNode::TryToBeginScheduling() { + absl::MutexLock lock(&status_mutex_); + if (current_in_flight_ < max_in_flight_) { + ++current_in_flight_; + return true; + } + return false; +} + +std::string CalculatorNode::DebugInputStreamNames() const { + return input_stream_handler_->DebugStreamNames(); +} + +std::string CalculatorNode::DebugName() const { + DCHECK(calculator_state_); + + const std::string first_output_stream_name = + output_stream_handler_->FirstStreamName(); + if (!first_output_stream_name.empty()) { + // A calculator is unique by its output streams (one of them is + // sufficient) unless it is a sink. For readability, its type name is + // included. + return absl::Substitute( + "[$0, $1 with output stream: $2]", calculator_state_->NodeName(), + calculator_state_->CalculatorType(), first_output_stream_name); + } + // If it is a sink, its full node spec is returned. 
+ return absl::Substitute( + "[$0, $1 with node ID: $2 and $3]", calculator_state_->NodeName(), + calculator_state_->CalculatorType(), node_id_, DebugInputStreamNames()); +} + +// TODO: Split this function. +::mediapipe::Status CalculatorNode::ProcessNode( + CalculatorContext* calculator_context) { + if (IsSource()) { + // This is a source Calculator. + if (Closed()) { + return ::mediapipe::OkStatus(); + } + + const Timestamp input_timestamp = calculator_context->InputTimestamp(); + + OutputStreamShardSet* outputs = &calculator_context->Outputs(); + output_stream_handler_->PrepareOutputs(input_timestamp, outputs); + + VLOG(2) << "Calling Calculator::Process() for node: " << DebugName(); + ::mediapipe::Status result; + + { + MEDIAPIPE_PROFILING(PROCESS, calculator_context); + LegacyCalculatorSupport::Scoped s(calculator_context); + result = calculator_->Process(calculator_context); + } + + bool node_stopped = false; + if (!result.ok()) { + if (result == tool::StatusStop()) { + // Needs to call CloseNode(). + node_stopped = true; + } else { + return ::mediapipe::StatusBuilder(result, MEDIAPIPE_LOC).SetPrepend() + << absl::Substitute( + "Calculator::Process() for node \"$0\" failed: ", + DebugName()); + } + } + output_stream_handler_->PostProcess(input_timestamp); + if (node_stopped) { + RETURN_IF_ERROR( + CloseNode(::mediapipe::OkStatus(), /*graph_run_ended=*/false)); + } + return ::mediapipe::OkStatus(); + } else { + // This is not a source Calculator. 
+ InputStreamShardSet* const inputs = &calculator_context->Inputs(); + OutputStreamShardSet* const outputs = &calculator_context->Outputs(); + ::mediapipe::Status result = + ::mediapipe::InternalError("Calculator context has no input packets."); + + int num_invocations = calculator_context_manager_.NumberOfContextTimestamps( + *calculator_context); + RET_CHECK(num_invocations <= 1 || max_in_flight_ <= 1) + << "num_invocations:" << num_invocations + << ", max_in_flight_:" << max_in_flight_; + for (int i = 0; i < num_invocations; ++i) { + const Timestamp input_timestamp = calculator_context->InputTimestamp(); + // The node is ready for Process(). + if (input_timestamp.IsAllowedInStream()) { + input_stream_handler_->FinalizeInputSet(input_timestamp, inputs); + output_stream_handler_->PrepareOutputs(input_timestamp, outputs); + + VLOG(2) << "Calling Calculator::Process() for node: " << DebugName(); + + { + MEDIAPIPE_PROFILING(PROCESS, calculator_context); + LegacyCalculatorSupport::Scoped s( + calculator_context); + result = calculator_->Process(calculator_context); + } + + // Removes one packet from each shard and progresses to the next input + // timestamp. + input_stream_handler_->ClearCurrentInputs(calculator_context); + + // Nodes are allowed to return StatusStop() to cause the termination + // of the graph. This is different from an error in that it will + // ensure that all sources will be closed and that packets in input + // streams will be processed before the graph is terminated. 
+ if (!result.ok() && result != tool::StatusStop()) { + return ::mediapipe::StatusBuilder(result, MEDIAPIPE_LOC).SetPrepend() + << absl::Substitute( + "Calculator::Process() for node \"$0\" failed: ", + DebugName()); + } + output_stream_handler_->PostProcess(input_timestamp); + if (result == tool::StatusStop()) { + return result; + } + } else if (input_timestamp == Timestamp::Done()) { + // Some or all the input streams are closed and there are not enough + // open input streams for Process(). So this node needs to be closed + // too. + // If the streams are closed, there shouldn't be more input. + CHECK_EQ(calculator_context_manager_.NumberOfContextTimestamps( + *calculator_context), + 1); + return CloseNode(::mediapipe::OkStatus(), /*graph_run_ended=*/false); + } else { + RET_CHECK_FAIL() + << "Invalid input timestamp in ProcessNode(). timestamp: " + << input_timestamp; + } + } + return result; + } +} + +void CalculatorNode::SetQueueSizeCallbacks( + InputStreamManager::QueueSizeCallback becomes_full_callback, + InputStreamManager::QueueSizeCallback becomes_not_full_callback) { + CHECK(input_stream_handler_); + input_stream_handler_->SetQueueSizeCallbacks( + std::move(becomes_full_callback), std::move(becomes_not_full_callback)); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_node.h b/mediapipe/framework/calculator_node.h new file mode 100644 index 000000000..fd17d4ada --- /dev/null +++ b/mediapipe/framework/calculator_node.h @@ -0,0 +1,373 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Declares CalculatorNode which is internally used by the Calculator framework +// (in particular, CalculatorGraph and Calculator) to perform the computations. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_NODE_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_NODE_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_state.h" +#include "mediapipe/framework/input_side_packet_handler.h" +#include "mediapipe/framework/input_stream_handler.h" +#include "mediapipe/framework/legacy_calculator_support.h" +#include "mediapipe/framework/output_side_packet_impl.h" +#include "mediapipe/framework/output_stream_handler.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/stream_handler.pb.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/validate_name.h" +#include "mediapipe/framework/validated_graph_config.h" + +namespace mediapipe { + +class CounterFactory; +class InputStreamManager; +class OutputStreamManager; + +namespace internal { +class SchedulerQueue; +} // namespace internal + +class CalculatorNode { + public: + // Handy typedef for a map from the name of an output stream to the set of ids + // of upstream sources that affect it. 
+ typedef std::unordered_map> + OutputStreamToSourcesMap; + + CalculatorNode(); + CalculatorNode(const CalculatorNode&) = delete; + CalculatorNode& operator=(const CalculatorNode&) = delete; + int Id() const { return node_id_; } + + // Returns a value according to which the scheduler queue determines the + // relative priority between runnable source nodes; a smaller value means + // running first. If a node is not a source, this method is not called. + Timestamp SourceProcessOrder(const CalculatorContext* cc) const; + + // Retrieves a std::string name for the node. If the node's name was set in + // the calculator graph config, it will be returned. Otherwise, a + // human-readable std::string that uniquely identifies the node is returned, + // e.g. + // "[FooBarCalculator with first output stream \"foo_bar_output\"]" for + // non-sink nodes and "[FooBarCalculator with node ID: 42 and input streams: + // \"foo_bar_input\"]" for sink nodes. This name should be used in error + // messages where more context info is helpful. + std::string DebugName() const; + + // Name of the executor which the node will execute on. If empty, the node + // will execute on the default executor. + const std::string& Executor() const { return executor_; } + + // Changes the executor a node is assigned to. + void SetExecutor(const std::string& executor); + + // Calls Process() on the Calculator corresponding to this node. + ::mediapipe::Status ProcessNode(CalculatorContext* calculator_context); + + // Initializes the node. The buffer_size_hint argument is + // set to the value specified in the graph proto for this field. + // input_stream_managers/output_stream_managers is expected to point to + // a contiguous flat array with Input/OutputStreamManagers corresponding + // to the input/output stream indexes in validated_graph. 
+ // output_side_packets is expected to point to a contiguous flat array with + // OutputSidePacketImpls corresponding to the output side packet indexes in + // validated_graph. + ::mediapipe::Status Initialize( + const ValidatedGraphConfig* validated_graph, int node_id, + InputStreamManager* input_stream_managers, + OutputStreamManager* output_stream_managers, + OutputSidePacketImpl* output_side_packets, int* buffer_size_hint, + std::shared_ptr profiling_context); + + // Sets up the node at the beginning of CalculatorGraph::Run(). This + // method is executed before any OpenNode() calls to the nodes + // within a CalculatorGraph. Creates a Calculator, and clears the + // input queues. Sets the callback to run when the node wants to + // schedule itself for later processing (in the order determined by + // the priority queue). ready_for_open_callback is called when OpenNode() + // can be scheduled. source_node_opened_callback is called when a source + // node is opened. schedule_callback is passed to the InputStreamHandler + // and is called each time a new invocation can be scheduled. + ::mediapipe::Status PrepareForRun( + const std::map& all_side_packets, + const std::map& service_packets, + std::function ready_for_open_callback, + std::function source_node_opened_callback, + std::function schedule_callback, + std::function error_callback, + CounterFactory* counter_factory) LOCKS_EXCLUDED(status_mutex_); + // Opens the node. + ::mediapipe::Status OpenNode() LOCKS_EXCLUDED(status_mutex_); + // Called when a source node's layer becomes active. + void ActivateNode() LOCKS_EXCLUDED(status_mutex_); + // Cleans up the node after the CalculatorGraph has been run. Deletes + // the Calculator managed by this node. graph_status is the status of + // the graph run. + void CleanupAfterRun(const ::mediapipe::Status& graph_status) + LOCKS_EXCLUDED(status_mutex_); + + // Returns true iff PrepareForRun() has been called (and types verified). 
+ bool Prepared() const LOCKS_EXCLUDED(status_mutex_); + // Returns true iff Open() has been called on the calculator. + bool Opened() const LOCKS_EXCLUDED(status_mutex_); + // Returns true iff a source calculator's layer is active. + bool Active() const LOCKS_EXCLUDED(status_mutex_); + // Returns true iff Close() has been called on the calculator. + bool Closed() const LOCKS_EXCLUDED(status_mutex_); + + // Returns true iff this is a source node. + // + // A source node has no input streams but has at least one output stream. A + // node with no input streams and no output streams is essentially a packet + // generator and is not a source node. + bool IsSource() const { + return input_stream_handler_->NumInputStreams() == 0 && + output_stream_handler_->NumOutputStreams() != 0; + } + + int source_layer() const { return source_layer_; } + + // Checks if the node can be scheduled; if so, increases current_in_flight_ + // and returns true; otherwise, returns false. + // If true is returned, the scheduler must commit to executing the node, and + // then call EndScheduling when finished running it. + // If false is returned, the scheduler must not execute the node. + // This method is thread-safe. + bool TryToBeginScheduling() LOCKS_EXCLUDED(status_mutex_); + + // Subtracts one from current_in_flight_ to allow a new invocation to be + // scheduled. Then, it checks scheduling_state_ and invokes SchedulingLoop() + // if necessary. This method is thread-safe. + // TODO: this could be done implicitly by the call to ProcessNode + // or CloseNode. + void EndScheduling() LOCKS_EXCLUDED(status_mutex_); + + // Returns true if OpenNode() can be scheduled. + bool ReadyForOpen() const LOCKS_EXCLUDED(status_mutex_); + + // Called by the InputStreamHandler when all the input stream headers + // become available. + void InputStreamHeadersReady() LOCKS_EXCLUDED(status_mutex_); + + // Called by the InputSidePacketHandler when all the input side packets + // become available. 
+ void InputSidePacketsReady() LOCKS_EXCLUDED(status_mutex_); + + // Checks scheduling_state_, and then invokes SchedulingLoop() if necessary. + // This method is thread-safe. + void CheckIfBecameReady() LOCKS_EXCLUDED(status_mutex_); + + // Called by SchedulerQueue when a node is opened. + void NodeOpened() LOCKS_EXCLUDED(status_mutex_); + + // Returns whether this is a GPU calculator node. + bool UsesGpu() const { return uses_gpu_; } + + // Returns the scheduler queue the node is assigned to. + internal::SchedulerQueue* GetSchedulerQueue() const { + return scheduler_queue_; + } + // Sets the scheduler queue the node is assigned to. + void SetSchedulerQueue(internal::SchedulerQueue* queue) { + scheduler_queue_ = queue; + } + + // Sets callbacks in the scheduler that should be invoked when an input queue + // becomes full/non-full. + void SetQueueSizeCallbacks( + InputStreamManager::QueueSizeCallback becomes_full_callback, + InputStreamManager::QueueSizeCallback becomes_not_full_callback); + + // Sets each of this node's input streams to use the specified + // max_queue_size to trigger callbacks. + void SetMaxInputStreamQueueSize(int max_queue_size); + + // Closes the node's calculator and input and output streams. + // graph_status is the current status of the graph run. graph_run_ended + // indicates whether the graph run has ended. + ::mediapipe::Status CloseNode(const ::mediapipe::Status& graph_status, + bool graph_run_ended) + LOCKS_EXCLUDED(status_mutex_); + + // Returns a pointer to the default calculator context that is used for + // sequential execution. A source node should always reuse its default + // calculator context. + CalculatorContext* GetDefaultCalculatorContext() const { + return calculator_context_manager_.GetDefaultCalculatorContext(); + } + + const CalculatorState& GetCalculatorState() const { + return *calculator_state_; + } + + private: + // Sets up the output side packets from the master flat array. 
+ ::mediapipe::Status InitializeOutputSidePackets( + const PacketTypeSet& output_side_packet_types, + OutputSidePacketImpl* output_side_packets); + // Connects the input side packets as mirrors on the output side packets. + // Output side packets are looked up in the master flat array which is + // provided. + ::mediapipe::Status InitializeInputSidePackets( + OutputSidePacketImpl* output_side_packets); + // Sets up the output streams from the master flat array. + ::mediapipe::Status InitializeOutputStreams( + OutputStreamManager* output_stream_managers); + // Sets up the input streams and connects them as mirrors on the + // output streams. Both input streams and output streams are looked + // up in the master flat arrays which are provided. + ::mediapipe::Status InitializeInputStreams( + InputStreamManager* input_stream_managers, + OutputStreamManager* output_stream_managers); + + ::mediapipe::Status InitializeInputStreamHandler( + const InputStreamHandlerConfig& handler_config, + const PacketTypeSet& input_stream_types); + ::mediapipe::Status InitializeOutputStreamHandler( + const OutputStreamHandlerConfig& handler_config, + const PacketTypeSet& output_stream_types); + + // Connects the input/output stream shards in the given calculator context to + // the input/output streams of the node. + ::mediapipe::Status ConnectShardsToStreams( + CalculatorContext* calculator_context); + + // The general scheduling logic shared by EndScheduling() and + // CheckIfBecameReady(). + // Inside the function, a while loop keeps preparing CalculatorContexts and + // scheduling the node until 1) the node becomes not ready or 2) the max + // number of in flight invocations is reached. It also attempts to propagate + // the latest input timestamp bound if no invocations can be scheduled. + void SchedulingLoop(); + + // Closes the input and output streams. 
+ void CloseInputStreams() LOCKS_EXCLUDED(status_mutex_); + void CloseOutputStreams(OutputStreamShardSet* outputs) + LOCKS_EXCLUDED(status_mutex_); + // Get a std::string describing the input streams. + std::string DebugInputStreamNames() const; + + // The calculator. + std::unique_ptr calculator_; + // Keeps data which a Calculator subclass needs access to. + std::unique_ptr calculator_state_; + + int node_id_ = -1; + std::string name_; // Optional user-defined name + // Name of the executor which the node will execute on. If empty, the node + // will execute on the default executor. + std::string executor_; + // The layer a source calculator operates on. + int source_layer_ = 0; + // The status of the current Calculator that this CalculatorNode + // is wrapping. kStateActive is currently used only for source nodes. + enum NodeStatus { + kStateUninitialized = 0, + kStatePrepared = 1, + kStateOpened = 2, + kStateActive = 3, + kStateClosed = 4 + }; + NodeStatus status_ GUARDED_BY(status_mutex_){kStateUninitialized}; + + // The max number of invocations that can be scheduled in parallel. + int max_in_flight_ = 1; + // The following two variables are used for the concurrency control of node + // scheduling. + // + // The number of invocations that are scheduled but not finished. + int current_in_flight_ GUARDED_BY(status_mutex_) = 0; + // SchedulingState incidates the current state of the node scheduling process. + // There are four possible transitions: + // (a) From kIdle to kScheduling. + // Any thread that makes this transition becomes the scheduling thread and + // will be responsible for preparing and scheduling all possible invocations. + // (b) From kScheduling to kSchedulingPending. + // Any thread, except the scheduling thread, can make this transition. + // kSchedulingPending indicates that some recent changes require the + // scheduling thread to recheck the node readiness after current scheduling + // iteration. 
+ // (c) From kSchedulingPending to kScheduling. + // Made by the scheduling thread to indicate that it has already caught up + // with all the recent changes that can affect node readiness. + // (d) From kScheduling to kIdle. Made by the scheduling thread when there is + // no more scheduling work to be done. + enum SchedulingState { + kIdle = 0, // + kScheduling = 1, // + kSchedulingPending = 2 + }; + SchedulingState scheduling_state_ GUARDED_BY(status_mutex_) = kIdle; + + std::function ready_for_open_callback_; + std::function source_node_opened_callback_; + bool input_stream_headers_ready_called_ GUARDED_BY(status_mutex_) = false; + bool input_side_packets_ready_called_ GUARDED_BY(status_mutex_) = false; + bool input_stream_headers_ready_ GUARDED_BY(status_mutex_) = false; + bool input_side_packets_ready_ GUARDED_BY(status_mutex_) = false; + + // Owns and manages all CalculatorContext objects. + CalculatorContextManager calculator_context_manager_; + + std::shared_ptr profiling_context_; + + // Mutex for node status. + mutable absl::Mutex status_mutex_; + + // Manages the set of input side packets. + InputSidePacketHandler input_side_packet_handler_; + + // Collection of all OutputSidePacket objects. + std::unique_ptr output_side_packets_; + + std::unique_ptr input_stream_handler_; + + std::unique_ptr output_stream_handler_; + + // Whether this is a GPU calculator. + bool uses_gpu_ = false; + + // True if CleanupAfterRun() needs to call CloseNode(). + bool needs_to_close_ = false; + + internal::SchedulerQueue* scheduler_queue_ = nullptr; + + const ValidatedGraphConfig* validated_graph_ = nullptr; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_NODE_H_ diff --git a/mediapipe/framework/calculator_node_test.cc b/mediapipe/framework/calculator_node_test.cc new file mode 100644 index 000000000..899436ba8 --- /dev/null +++ b/mediapipe/framework/calculator_node_test.cc @@ -0,0 +1,575 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_node.h" + +#include + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +class CountCalculator : public CalculatorBase { + public: + CountCalculator() { ++num_constructed_; } + ~CountCalculator() override { ++num_destroyed_; } + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + ++num_fill_expectations_; + cc->Inputs().Get(cc->Inputs().BeginId()).Set(); + cc->Outputs().Get(cc->Outputs().BeginId()).Set(); + cc->InputSidePackets().Get(cc->InputSidePackets().BeginId()).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + ++num_open_; + // Simulate doing nontrivial work to ensure that the time spent in the + // method will register on streamz each time it is called. 
+ usleep(100); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + ++num_process_; + int input_stream_int = cc->Inputs().Get(cc->Inputs().BeginId()).Get(); + int side_packet_int = + cc->InputSidePackets().Get(cc->InputSidePackets().BeginId()).Get(); + cc->Outputs() + .Get(cc->Outputs().BeginId()) + .AddPacket(MakePacket(input_stream_int + side_packet_int) + .At(cc->InputTimestamp())); + // Simulate doing nontrivial work to ensure that the time spent in the + // method will register on streamz each time it is called. + usleep(100); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Close(CalculatorContext* cc) override { + ++num_close_; + // Simulate doing nontrivial work to ensure that the time spent in the + // method will register on streamz each time it is called. + usleep(100); + return ::mediapipe::OkStatus(); + } + + static int num_constructed_; + static int num_fill_expectations_; + static int num_open_; + static int num_process_; + static int num_close_; + static int num_destroyed_; +}; +REGISTER_CALCULATOR(CountCalculator); + +int CountCalculator::num_constructed_ = 0; +int CountCalculator::num_fill_expectations_ = 0; +int CountCalculator::num_open_ = 0; +int CountCalculator::num_process_ = 0; +int CountCalculator::num_close_ = 0; +int CountCalculator::num_destroyed_ = 0; + +void SourceNodeOpenedNoOp() {} + +void CheckFail(const ::mediapipe::Status& status) { + LOG(FATAL) << "The test triggered the error callback with status: " << status; +} + +class CalculatorNodeTest : public ::testing::Test { + public: + void ReadyForOpen(int* count) { ++(*count); } + + void Notification(CalculatorContext* cc, int* count) { + CHECK(cc); + cc_ = cc; + ++(*count); + } + + protected: + void InitializeEnvironment(bool use_tags) { + CountCalculator::num_constructed_ = 0; + CountCalculator::num_fill_expectations_ = 0; + CountCalculator::num_open_ = 0; + CountCalculator::num_process_ = 0; + 
CountCalculator::num_close_ = 0; + CountCalculator::num_destroyed_ = 0; + + std::string first_two_nodes_string = + "node {\n" // Node index 0 + " calculator: \"SidePacketsToStreamsCalculator\"\n" + " input_side_packet: \"input_b\"\n" // Input side packet index 0 + " output_stream: \"unused_stream\"\n" // Output stream 0 + "}\n" + "node {\n" // Node index 1 + " calculator: \"PassThroughCalculator\"\n" + " input_stream: \"unused_stream\"\n" // Input stream index 0 + " output_stream: \"stream_a\"\n" // Output stream index 1 + " input_side_packet: \"input_a\"\n" // Input side packet index 1 + " input_side_packet: \"input_b\"\n" // Input side packet index 2 + "}\n"; + CalculatorGraphConfig graph_config; + // Add the test for the node under test. + if (use_tags) { + graph_config = ::mediapipe::ParseTextProtoOrDie( + first_two_nodes_string + + "node {\n" // Node index 2 + " calculator: \"CountCalculator\"\n" + " input_stream: \"INPUT_TAG:stream_a\"\n" // Input stream index 1 + " output_stream: \"OUTPUT_TAG:stream_b\"\n" // Output stream index 2 + // Input side packet index 3 + " input_side_packet: \"INPUT_SIDE_PACKET_TAG:input_a\"\n" + "}\n"); + } else { + graph_config = ::mediapipe::ParseTextProtoOrDie( + first_two_nodes_string + + "node {\n" // Node index 2 + " calculator: \"CountCalculator\"\n" + " input_stream: \"stream_a\"\n" // Input stream index 1 + " output_stream: \"stream_b\"\n" // Output stream index 2 + " input_side_packet: \"input_a\"\n" // Input side packet index 3 + "}\n"); + } + MEDIAPIPE_CHECK_OK(validated_graph_.Initialize(graph_config)); + MEDIAPIPE_CHECK_OK(InitializeStreams()); + + input_side_packets_.emplace("input_a", Adopt(new int(42))); + input_side_packets_.emplace("input_b", Adopt(new int(42))); + + node_.reset(new CalculatorNode()); + MEDIAPIPE_ASSERT_OK(node_->Initialize( + &validated_graph_, 2, input_stream_managers_.get(), + output_stream_managers_.get(), output_side_packets_.get(), + &buffer_size_hint_, graph_profiler_)); + } + + 
::mediapipe::Status PrepareNodeForRun() { + return node_->PrepareForRun( // + input_side_packets_, // + service_packets_, // + std::bind(&CalculatorNodeTest::ReadyForOpen, // + this, // + &ready_for_open_count_), // + SourceNodeOpenedNoOp, // + std::bind(&CalculatorNodeTest::Notification, // + this, std::placeholders::_1, // + &schedule_count_), // + CheckFail, // + nullptr); + } + + ::mediapipe::Status InitializeStreams() { + // START OF: code is copied from + // CalculatorGraph::InitializePacketGeneratorGraph. + // Create and initialize the output side packets. + output_side_packets_ = absl::make_unique( + validated_graph_.OutputSidePacketInfos().size()); + for (int index = 0; index < validated_graph_.OutputSidePacketInfos().size(); + ++index) { + const EdgeInfo& edge_info = + validated_graph_.OutputSidePacketInfos()[index]; + RETURN_IF_ERROR(output_side_packets_[index].Initialize( + edge_info.name, edge_info.packet_type)); + } + // END OF: code is copied from + // CalculatorGraph::InitializePacketGeneratorGraph. + + // START OF: code is copied from CalculatorGraph::InitializeStreams. + // Create and initialize the input streams. + input_stream_managers_.reset( + new InputStreamManager[validated_graph_.InputStreamInfos().size()]); + for (int index = 0; index < validated_graph_.InputStreamInfos().size(); + ++index) { + const EdgeInfo& edge_info = validated_graph_.InputStreamInfos()[index]; + RETURN_IF_ERROR(input_stream_managers_[index].Initialize( + edge_info.name, edge_info.packet_type, edge_info.back_edge)); + } + + // Create and initialize the output streams. 
+ output_stream_managers_.reset( + new OutputStreamManager[validated_graph_.OutputStreamInfos().size()]); + for (int index = 0; index < validated_graph_.OutputStreamInfos().size(); + ++index) { + const EdgeInfo& edge_info = validated_graph_.OutputStreamInfos()[index]; + RETURN_IF_ERROR(output_stream_managers_[index].Initialize( + edge_info.name, edge_info.packet_type)); + } + // END OF: code is copied from CalculatorGraph::InitializeStreams. + + stream_a_manager_ = &output_stream_managers_[1]; + stream_b_manager_ = &output_stream_managers_[2]; + return ::mediapipe::OkStatus(); + } + + virtual void SimulateParentOpenNode() { stream_a_manager_->LockIntroData(); } + + virtual void TestCleanupAfterRunTwice(); + + std::map input_side_packets_; + std::map service_packets_; + + std::unique_ptr input_stream_managers_; + std::unique_ptr output_stream_managers_; + std::unique_ptr output_side_packets_; + + // A pointer to the output stream manager for stream_a. + // An alias for &output_stream_managers_[1]. + OutputStreamManager* stream_a_manager_; + // A pointer to the output stream manager for stream_b. + // An alias for &output_stream_managers_[2]. + OutputStreamManager* stream_b_manager_; + + std::unique_ptr node_; + + ValidatedGraphConfig validated_graph_; + std::shared_ptr graph_profiler_ = + std::make_shared(); + + int ready_for_open_count_ = 0; + int schedule_count_ = 0; + + int buffer_size_hint_ = -1; + // Stores the CalculatorContext passed to the ready_callback_ of node_, and we + // pass this to node_->ProcessNode(). 
+ CalculatorContext* cc_; +}; + +TEST_F(CalculatorNodeTest, Initialize) { + InitializeEnvironment(/*use_tags=*/false); + EXPECT_EQ(2, node_->Id()); + EXPECT_THAT(node_->DebugName(), + ::testing::AllOf(::testing::HasSubstr("CountCalculator"), + ::testing::HasSubstr("stream_b"))); + + EXPECT_FALSE(node_->Prepared()); + EXPECT_FALSE(node_->Opened()); + EXPECT_FALSE(node_->Closed()); + + EXPECT_EQ(0, CountCalculator::num_constructed_); + EXPECT_EQ(1, CountCalculator::num_fill_expectations_); + EXPECT_EQ(0, CountCalculator::num_open_); + EXPECT_EQ(0, CountCalculator::num_process_); + EXPECT_EQ(0, CountCalculator::num_close_); + EXPECT_EQ(0, CountCalculator::num_destroyed_); +} + +TEST_F(CalculatorNodeTest, PrepareForRun) { + InitializeEnvironment(/*use_tags=*/false); + MEDIAPIPE_ASSERT_OK(PrepareNodeForRun()); + + EXPECT_TRUE(node_->Prepared()); + EXPECT_FALSE(node_->Opened()); + EXPECT_FALSE(node_->Closed()); + + EXPECT_EQ(0, ready_for_open_count_); + EXPECT_EQ(0, schedule_count_); + + EXPECT_EQ(1, CountCalculator::num_constructed_); + EXPECT_EQ(1, CountCalculator::num_fill_expectations_); + EXPECT_EQ(0, CountCalculator::num_open_); + EXPECT_EQ(0, CountCalculator::num_process_); + EXPECT_EQ(0, CountCalculator::num_close_); + EXPECT_EQ(0, CountCalculator::num_destroyed_); +} + +TEST_F(CalculatorNodeTest, Open) { + InitializeEnvironment(/*use_tags=*/false); + MEDIAPIPE_ASSERT_OK(PrepareNodeForRun()); + + EXPECT_EQ(0, ready_for_open_count_); + SimulateParentOpenNode(); + MEDIAPIPE_EXPECT_OK(node_->OpenNode()); + + EXPECT_TRUE(node_->Prepared()); + EXPECT_TRUE(node_->Opened()); + EXPECT_FALSE(node_->Closed()); + + // Nodes are not immediately scheduled upon opening. 
+ EXPECT_EQ(0, schedule_count_); + + EXPECT_EQ(1, CountCalculator::num_constructed_); + EXPECT_EQ(1, CountCalculator::num_fill_expectations_); + EXPECT_EQ(1, CountCalculator::num_open_); + EXPECT_EQ(0, CountCalculator::num_process_); + EXPECT_EQ(0, CountCalculator::num_close_); + EXPECT_EQ(0, CountCalculator::num_destroyed_); +} + +TEST_F(CalculatorNodeTest, Process) { + InitializeEnvironment(/*use_tags=*/false); + MEDIAPIPE_ASSERT_OK(PrepareNodeForRun()); + + SimulateParentOpenNode(); + MEDIAPIPE_EXPECT_OK(node_->OpenNode()); + + OutputStreamShard stream_a_shard; + stream_a_shard.SetSpec(stream_a_manager_->Spec()); + stream_a_shard.Add(new int(1), Timestamp(1)); + stream_a_manager_->PropagateUpdatesToMirrors(Timestamp(2), &stream_a_shard); + EXPECT_EQ(1, schedule_count_); + // Expects that a CalculatorContext has been prepared. + EXPECT_NE(nullptr, cc_); + EXPECT_TRUE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + + cc_ = nullptr; + node_->EndScheduling(); + EXPECT_EQ(1, schedule_count_); + // Expects that no CalculatorContext is prepared by EndScheduling(). 
+ EXPECT_EQ(nullptr, cc_); + + EXPECT_TRUE(node_->Prepared()); + EXPECT_TRUE(node_->Opened()); + EXPECT_FALSE(node_->Closed()); + + EXPECT_EQ(1, schedule_count_); + + EXPECT_EQ(1, CountCalculator::num_constructed_); + EXPECT_EQ(1, CountCalculator::num_fill_expectations_); + EXPECT_EQ(1, CountCalculator::num_open_); + EXPECT_EQ(1, CountCalculator::num_process_); + EXPECT_EQ(0, CountCalculator::num_close_); + EXPECT_EQ(0, CountCalculator::num_destroyed_); +} + +TEST_F(CalculatorNodeTest, ProcessSeveral) { + InitializeEnvironment(/*use_tags=*/false); + MEDIAPIPE_ASSERT_OK(PrepareNodeForRun()); + + SimulateParentOpenNode(); + MEDIAPIPE_EXPECT_OK(node_->OpenNode()); + + OutputStreamShard stream_a_shard; + stream_a_shard.SetSpec(stream_a_manager_->Spec()); + stream_a_shard.Add(new int(1), Timestamp(1)); + stream_a_manager_->PropagateUpdatesToMirrors(Timestamp(2), &stream_a_shard); + + EXPECT_EQ(1, schedule_count_); + EXPECT_TRUE(node_->TryToBeginScheduling()); + EXPECT_NE(nullptr, cc_); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + EXPECT_EQ(1, schedule_count_); + + stream_a_manager_->ResetShard(&stream_a_shard); + stream_a_shard.Add(new int(2), Timestamp(4)); + stream_a_shard.Add(new int(3), Timestamp(8)); + stream_a_manager_->PropagateUpdatesToMirrors(Timestamp(9), &stream_a_shard); + // The packet at Timestamp 8 is left in the input queue. + + EXPECT_EQ(2, schedule_count_); + EXPECT_TRUE(node_->TryToBeginScheduling()); + // Expects that a CalculatorContext has been prepared. + EXPECT_NE(nullptr, cc_); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + EXPECT_EQ(3, schedule_count_); + EXPECT_TRUE(node_->TryToBeginScheduling()); + + stream_a_manager_->ResetShard(&stream_a_shard); + stream_a_shard.Add(new int(4), Timestamp(16)); + stream_a_manager_->PropagateUpdatesToMirrors(Timestamp(17), &stream_a_shard); + // The packet at Timestamp 16 is left in the input queue. 
+ + EXPECT_EQ(3, schedule_count_); + // The max parallelism is already reached. + EXPECT_FALSE(node_->TryToBeginScheduling()); + EXPECT_NE(nullptr, cc_); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + EXPECT_EQ(4, schedule_count_); + EXPECT_TRUE(node_->TryToBeginScheduling()); + + EXPECT_NE(nullptr, cc_); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + + cc_ = nullptr; + node_->EndScheduling(); + // Expects that no CalculatorContext is prepared by EndScheduling(). + EXPECT_EQ(nullptr, cc_); + EXPECT_EQ(4, schedule_count_); + + EXPECT_TRUE(node_->Prepared()); + EXPECT_TRUE(node_->Opened()); + EXPECT_FALSE(node_->Closed()); + + EXPECT_EQ(1, CountCalculator::num_constructed_); + EXPECT_EQ(1, CountCalculator::num_fill_expectations_); + EXPECT_EQ(1, CountCalculator::num_open_); + EXPECT_EQ(4, CountCalculator::num_process_); + EXPECT_EQ(0, CountCalculator::num_close_); + EXPECT_EQ(0, CountCalculator::num_destroyed_); +} + +TEST_F(CalculatorNodeTest, Close) { + InitializeEnvironment(/*use_tags=*/false); + MEDIAPIPE_ASSERT_OK(PrepareNodeForRun()); + + SimulateParentOpenNode(); + MEDIAPIPE_EXPECT_OK(node_->OpenNode()); + + OutputStreamShard stream_a_shard; + stream_a_shard.SetSpec(stream_a_manager_->Spec()); + stream_a_shard.Add(new int(1), Timestamp(1)); + stream_a_manager_->PropagateUpdatesToMirrors(Timestamp(2), &stream_a_shard); + EXPECT_TRUE(node_->TryToBeginScheduling()); + stream_a_manager_->Close(); + // The max parallelism is already reached. 
+ EXPECT_FALSE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + + EXPECT_TRUE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + EXPECT_TRUE(node_->Closed()); + EXPECT_EQ(2, schedule_count_); + + node_->EndScheduling(); + + EXPECT_TRUE(node_->Prepared()); + EXPECT_TRUE(node_->Opened()); + EXPECT_TRUE(node_->Closed()); + + EXPECT_EQ(2, schedule_count_); + + EXPECT_EQ(1, CountCalculator::num_constructed_); + EXPECT_EQ(1, CountCalculator::num_fill_expectations_); + EXPECT_EQ(1, CountCalculator::num_open_); + EXPECT_EQ(1, CountCalculator::num_process_); + EXPECT_EQ(1, CountCalculator::num_close_); + EXPECT_EQ(0, CountCalculator::num_destroyed_); +} + +TEST_F(CalculatorNodeTest, CleanupAfterRun) { + InitializeEnvironment(/*use_tags=*/false); + MEDIAPIPE_ASSERT_OK(PrepareNodeForRun()); + + SimulateParentOpenNode(); + MEDIAPIPE_EXPECT_OK(node_->OpenNode()); + OutputStreamShard stream_a_shard; + stream_a_shard.SetSpec(stream_a_manager_->Spec()); + stream_a_shard.Add(new int(1), Timestamp(1)); + stream_a_manager_->PropagateUpdatesToMirrors(Timestamp(2), &stream_a_shard); + EXPECT_TRUE(node_->TryToBeginScheduling()); + stream_a_manager_->Close(); + // The max parallelism is already reached. + EXPECT_FALSE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + // Call ProcessNode again for the node to see the end of the stream. + EXPECT_TRUE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + // The max parallelism is already reached. 
+ EXPECT_FALSE(node_->TryToBeginScheduling()); + node_->CleanupAfterRun(::mediapipe::OkStatus()); + + EXPECT_FALSE(node_->Prepared()); + EXPECT_FALSE(node_->Opened()); + EXPECT_FALSE(node_->Closed()); + + EXPECT_EQ(2, schedule_count_); + + EXPECT_EQ(1, CountCalculator::num_constructed_); + EXPECT_EQ(1, CountCalculator::num_fill_expectations_); + EXPECT_EQ(1, CountCalculator::num_open_); + EXPECT_EQ(1, CountCalculator::num_process_); + EXPECT_EQ(1, CountCalculator::num_close_); + EXPECT_EQ(1, CountCalculator::num_destroyed_); +} + +void CalculatorNodeTest::TestCleanupAfterRunTwice() { + MEDIAPIPE_ASSERT_OK(PrepareNodeForRun()); + + SimulateParentOpenNode(); + MEDIAPIPE_EXPECT_OK(node_->OpenNode()); + OutputStreamShard stream_a_shard; + stream_a_shard.SetSpec(stream_a_manager_->Spec()); + stream_a_shard.Add(new int(1), Timestamp(1)); + stream_a_manager_->PropagateUpdatesToMirrors(Timestamp(2), &stream_a_shard); + EXPECT_TRUE(node_->TryToBeginScheduling()); + stream_a_manager_->Close(); + // The max parallelism is already reached. + EXPECT_FALSE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + // We should get Timestamp::Done here. 
+ EXPECT_TRUE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + node_->CleanupAfterRun(::mediapipe::OkStatus()); + + stream_a_manager_->PrepareForRun(nullptr); + + MEDIAPIPE_ASSERT_OK(PrepareNodeForRun()); + + SimulateParentOpenNode(); + MEDIAPIPE_EXPECT_OK(node_->OpenNode()); + stream_a_manager_->ResetShard(&stream_a_shard); + stream_a_shard.Add(new int(2), Timestamp(4)); + stream_a_shard.Add(new int(3), Timestamp(8)); + stream_a_manager_->PropagateUpdatesToMirrors(Timestamp(9), &stream_a_shard); + EXPECT_TRUE(node_->TryToBeginScheduling()); + stream_a_manager_->Close(); + EXPECT_FALSE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + EXPECT_TRUE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + // We should get Timestamp::Done here. + EXPECT_TRUE(node_->TryToBeginScheduling()); + MEDIAPIPE_EXPECT_OK(node_->ProcessNode(cc_)); + node_->EndScheduling(); + // The max parallelism is already reached. 
+ EXPECT_FALSE(node_->TryToBeginScheduling()); + node_->CleanupAfterRun(::mediapipe::OkStatus()); + + EXPECT_FALSE(node_->Prepared()); + EXPECT_FALSE(node_->Opened()); + EXPECT_FALSE(node_->Closed()); + + EXPECT_EQ(5, schedule_count_); + + EXPECT_EQ(2, CountCalculator::num_constructed_); + EXPECT_EQ(1, CountCalculator::num_fill_expectations_); + EXPECT_EQ(2, CountCalculator::num_open_); + EXPECT_EQ(3, CountCalculator::num_process_); + EXPECT_EQ(2, CountCalculator::num_close_); + EXPECT_EQ(2, CountCalculator::num_destroyed_); +} + +TEST_F(CalculatorNodeTest, CleanupAfterRunTwice) { + InitializeEnvironment(/*use_tags=*/false); + TestCleanupAfterRunTwice(); +} + +TEST_F(CalculatorNodeTest, CleanupAfterRunTwiceWithTags) { + InitializeEnvironment(/*use_tags=*/true); + TestCleanupAfterRunTwice(); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_options.proto b/mediapipe/framework/calculator_options.proto new file mode 100644 index 000000000..5680dd9ed --- /dev/null +++ b/mediapipe/framework/calculator_options.proto @@ -0,0 +1,44 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from mediapipe/framework/calculator.proto. +// The forked proto must remain identical to the original proto and should be +// ONLY used by mediapipe open source project. 
+ +syntax = "proto2"; + +package mediapipe; + +option java_package = "com.google.mediapipe.proto"; +option java_outer_classname = "CalculatorOptionsProto"; + +// Options for Calculators. Each Calculator implementation should +// have its own options proto, which should look like this: +// +// message MyCalculatorOptions { +// extend CalculatorOptions { +// optional MyCalculatorOptions ext = ; +// } +// optional string field_needed_by_my_calculator = 1; +// optional int32 another_field = 2; +// // etc +// } +message CalculatorOptions { + // If true, this proto specifies a subset of field values, + // which should override corresponding field values. + // Deprecated in cl/228195782. + optional bool merge_fields = 1 [deprecated = true]; + + extensions 20000 to max; +} diff --git a/mediapipe/framework/calculator_parallel_execution_test.cc b/mediapipe/framework/calculator_parallel_execution_test.cc new file mode 100644 index 000000000..c97b066ee --- /dev/null +++ b/mediapipe/framework/calculator_parallel_execution_test.cc @@ -0,0 +1,157 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Verifies the correctness of parallel execution. +// $ bazel build -c opt \ +// mediapipe/framework/calculator_parallel_execution_test \ +// --runs_per_test=100 +// +// TODO: Add more tests to verify the correctness of parallel execution. 
+ +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +using RandomEngine = std::mt19937_64; + +inline void BusySleep(absl::Duration duration) { + absl::Time start_time = absl::Now(); + while (absl::Now() - start_time < duration) { + } +} + +class SlowPlusOneCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(mediapipe::TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + if (cc->InputTimestamp().Value() % 4 == 0) { + return ::mediapipe::OkStatus(); + } + + RandomEngine random(testing::UnitTest::GetInstance()->random_seed()); + std::uniform_int_distribution<> uniform_dist(0, 10); + BusySleep(absl::Milliseconds(90 + uniform_dist(random))); + cc->Outputs().Index(0).Add(new int(cc->Inputs().Index(0).Get() + 1), + cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; + +REGISTER_CALCULATOR(SlowPlusOneCalculator); + +class ParallelExecutionTest : public testing::Test { + public: + void AddThreadSafeVectorSink(const Packet& packet) { + absl::WriterMutexLock lock(&output_packets_mutex_); + output_packets_.push_back(packet); + } + + protected: + std::vector output_packets_ GUARDED_BY(output_packets_mutex_); + absl::Mutex output_packets_mutex_; 
+}; + +TEST_F(ParallelExecutionTest, SlowPlusOneCalculatorsTest) { + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input" + node { + calculator: "SlowPlusOneCalculator" + input_stream: "input" + output_stream: "first_calculator_output" + max_in_flight: 5 + } + node { + calculator: "SlowPlusOneCalculator" + input_stream: "first_calculator_output" + output_stream: "output" + max_in_flight: 5 + } + node { + calculator: "CallbackCalculator" + input_stream: "output" + input_side_packet: "CALLBACK:callback" + } + num_threads: 5 + )"); + + // Starts MediaPipe graph. + CalculatorGraph graph(graph_config); + // Runs the graph twice. + for (int i = 0; i < 2; ++i) { + MEDIAPIPE_ASSERT_OK(graph.StartRun( + {{"callback", MakePacket>(std::bind( + &ParallelExecutionTest::AddThreadSafeVectorSink, this, + std::placeholders::_1))}})); + const int kTotalNums = 100; + int fail_count = 0; + for (int i = 0; i < kTotalNums; ++i) { + ::mediapipe::Status status = graph.AddPacketToInputStream( + "input", Adopt(new int(i)).At(Timestamp(i))); + if (!status.ok()) { + ++fail_count; + } + } + + EXPECT_EQ(0, fail_count); + + // Doesn't wait but just close the input stream. + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("input")); + // Waits properly via the API until the graph is done. 
+ MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + + absl::ReaderMutexLock lock(&output_packets_mutex_); + ASSERT_EQ(kTotalNums - kTotalNums / 4, output_packets_.size()); + int index = 1; + for (const Packet& packet : output_packets_) { + MEDIAPIPE_ASSERT_OK(packet.ValidateAsType()); + EXPECT_EQ(index + 2, packet.Get()); + EXPECT_EQ(Timestamp(index), packet.Timestamp()); + if (++index % 4 == 0) { + ++index; + } + } + output_packets_.clear(); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_profile.proto b/mediapipe/framework/calculator_profile.proto new file mode 100644 index 000000000..d90d8a3df --- /dev/null +++ b/mediapipe/framework/calculator_profile.proto @@ -0,0 +1,191 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from mediapipe/framework/calculator_profile.proto. +// The forked proto must remain identical to the original proto and should be +// ONLY used by mediapipe open source project. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +option java_package = "com.google.mediapipe.proto"; +option java_outer_classname = "CalculatorProfileProto"; + +// Stores the profiling information. +// +// It is the responsibility of the user of this message to make sure the 'total' +// field and the interval information (num, size and count) are in a valid +// state and all get updated together. 
+// +// Each interval of the histogram is closed on the lower range and open on the +// higher end. An example histogram with interval_size=1000 and num_interval=3 +// will have the following intervals: +// - First interval = [0, 1000) +// - Second interval = [1000, 2000) +// - Third interval = [2000, +inf) +// +// IMPORTANT: If You add any new field, update CalculatorProfiler::Reset() +// accordingly. +message TimeHistogram { + // Total time (in microseconds). + optional int64 total = 1 [default = 0]; + + // Size of the runtimes histogram intervals (in microseconds) to generate the + // histogram of the Process() time. The last interval extends to +inf. + optional int64 interval_size_usec = 2 [default = 1000000 /* 1 sec */]; + + // Number of intervals to generate the histogram of the Process() runtime. + optional int64 num_intervals = 3 [default = 1]; + + // Number of calls in each interval. + repeated int64 count = 4; +} + +// Stores the profiling information of a stream. +message StreamProfile { + // Stream name. + optional string name = 1; + + // If true, than this is a back edge input stream and won't be profiled. + optional bool back_edge = 2 [default = false]; + + // Total and histogram of the time that this stream took. + optional TimeHistogram latency = 3; +} + +// Stores the profiling information for a calculator node. +// All the times are in microseconds. +message CalculatorProfile { + // The calculator name. + optional string name = 1; + + // Total time the calculator spent on Open (in microseconds). + optional int64 open_runtime = 2 [default = 0]; + + // Total time the calculator spent on Close (in microseconds). + optional int64 close_runtime = 3 [default = 0]; + + // Total and histogram of the time that the calculator spent on the Process() + // (in microseconds). + optional TimeHistogram process_runtime = 4; + + // Total and histogram of the time that the input latency, ie. difference + // between input timestamp and process call time. 
+ // (in microseconds). + optional TimeHistogram process_input_latency = 5; + + // Total and histogram of the time that the output latency, ie. difference + // between input timestamp and process finished time. + optional TimeHistogram process_output_latency = 6; + + // Total and histogram of the time that input streams of this calculator took. + repeated StreamProfile input_stream_profiles = 7; +} + +// Latency timing for recent mediapipe packets. +message GraphTrace { + // The timing for one packet across one packet stream. + message StreamTrace { + // The time at which the packet entered the stream. + optional int64 start_time = 1; + + // The time at which the packet exited the stream. + optional int64 finish_time = 2; + + // The identifying timetamp of the packet. + optional int64 packet_timestamp = 3; + + // The index of the stream in the stream_name list. + optional int32 stream_id = 4; + + // The address of the packet contents. + optional int64 packet_id = 5; + } + + // The kind of event recorded. + enum EventType { + UNKNOWN = 0; + OPEN = 1; + PROCESS = 2; + CLOSE = 3; + NOT_READY = 4; + READY_FOR_PROCESS = 5; + READY_FOR_CLOSE = 6; + THROTTLED = 7; + UNTHROTTLED = 8; + CPU_TASK_USER = 9; + CPU_TASK_SYSTEM = 10; + GPU_TASK = 11; + DSP_TASK = 12; + TPU_TASK = 13; + GPU_CALIBRATION = 14; + } + + // The timing for one packet set being processed at one caclulator node. + message CalculatorTrace { + // The index of the calculator node in the calculator_name list. + optional int32 node_id = 1; + + // The input timestamp during Open, Process, or Close. + optional int64 input_timestamp = 2; + + // The kind of event, 1=Open, 2=Process, 3=Close, etc. + optional EventType event_type = 3; + + // The time at which the packets entered the caclulator node. + optional int64 start_time = 4; + + // The time at which the packets exited the caclulator node. + optional int64 finish_time = 5; + + // The timing data for each input packet. 
+ repeated StreamTrace input_trace = 6; + + // The identifying timetamp and stream_id for each output packet. + repeated StreamTrace output_trace = 7; + + // An identifier for the current process thread. + optional int32 thread_id = 8; + } + + // The time represented as 0 in the trace. + optional int64 base_time = 1; + + // The timestamp represented as 0 in the trace. + optional int64 base_timestamp = 2; + + // The list of calculator node names indexed by node id. + repeated string calculator_name = 3; + + // The list of stream names indexed by stream id. + repeated string stream_name = 4; + + // Recent packet timing informtion about each calculator node and stream. + repeated CalculatorTrace calculator_trace = 5; +} + +// Latency events and summaries for recent mediapipe packets. +message GraphProfile { + // Recent packet timing informtion about each calculator node and stream. + repeated GraphTrace graph_trace = 1; + + // Aggregated latency information about each calculator node. + repeated CalculatorProfile calculator_profiles = 2; + + // The canonicalized calculator graph that is traced. + optional CalculatorGraphConfig config = 3; +} diff --git a/mediapipe/framework/calculator_registry.h b/mediapipe/framework/calculator_registry.h new file mode 100644 index 000000000..f3d88fa18 --- /dev/null +++ b/mediapipe/framework/calculator_registry.h @@ -0,0 +1,32 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// Calculator registration. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_H_ + +#include "mediapipe/framework/calculator_base.h" + +#define REGISTER_CALCULATOR(name) \ + REGISTER_FACTORY_FUNCTION_QUALIFIED(::mediapipe::CalculatorBaseRegistry, \ + calculator_registration, name, \ + absl::make_unique); \ + REGISTER_FACTORY_FUNCTION_QUALIFIED( \ + ::mediapipe::internal::StaticAccessToCalculatorBaseRegistry, \ + access_registration, name, \ + absl::make_unique< \ + ::mediapipe::internal::StaticAccessToCalculatorBaseTyped>) + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_H_ diff --git a/mediapipe/framework/calculator_registry_util.cc b/mediapipe/framework/calculator_registry_util.cc new file mode 100644 index 000000000..b1c2110b5 --- /dev/null +++ b/mediapipe/framework/calculator_registry_util.cc @@ -0,0 +1,61 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_registry_util.h" + +#include +#include + +#include "mediapipe/framework/collection.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { + +bool IsLegacyCalculator(const std::string& package_name, + const std::string& node_class) { + return false; +} + +::mediapipe::Status VerifyCalculatorWithContract( + const std::string& package_name, const std::string& node_class, + CalculatorContract* contract) { + // A number of calculators use the non-CC methods on GlCalculatorHelper + // even though they are CalculatorBase-based. + ASSIGN_OR_RETURN( + auto static_access_to_calculator_base, + internal::StaticAccessToCalculatorBaseRegistry::CreateByNameInNamespace( + package_name, node_class), + _ << "Unable to find Calculator \"" << node_class << "\""); + RETURN_IF_ERROR(static_access_to_calculator_base->GetContract(contract)) + .SetPrepend() + << node_class << ": "; + return ::mediapipe::OkStatus(); +} + +::mediapipe::StatusOr> CreateCalculator( + const std::shared_ptr& input_tag_map, + const std::shared_ptr& output_tag_map, + const std::string& package_name, CalculatorState* calculator_state, + CalculatorContext* calculator_context) { + std::unique_ptr calculator; + ASSIGN_OR_RETURN(calculator, + CalculatorBaseRegistry::CreateByNameInNamespace( + package_name, calculator_state->CalculatorType())); + return std::move(calculator); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_registry_util.h b/mediapipe/framework/calculator_registry_util.h new file mode 100644 index 000000000..6d218c1c0 --- /dev/null +++ b/mediapipe/framework/calculator_registry_util.h @@ -0,0 +1,46 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_ + +#include + +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_state.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/framework/tool/tag_map.h" + +// Calculator registry util functions that supports both legacy Calculator API +// and CalculatorBase. +namespace mediapipe { + +bool IsLegacyCalculator(const std::string& package_name, + const std::string& node_class); + +::mediapipe::Status VerifyCalculatorWithContract( + const std::string& package_name, const std::string& node_class, + CalculatorContract* contract); + +::mediapipe::StatusOr> CreateCalculator( + const std::shared_ptr& input_tag_map, + const std::shared_ptr& output_tag_map, + const std::string& package_name, CalculatorState* calculator_state, + CalculatorContext* calculator_context); + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_ diff --git a/mediapipe/framework/calculator_runner.cc b/mediapipe/framework/calculator_runner.cc new file mode 100644 index 000000000..157afaacf --- /dev/null +++ b/mediapipe/framework/calculator_runner.cc @@ -0,0 +1,354 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Definitions for CalculatorRunner. + +#include "mediapipe/framework/calculator_runner.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +const char CalculatorRunner::kSourcePrefix[] = "source_for_"; +const char CalculatorRunner::kSinkPrefix[] = "sink_for_"; + +namespace { + +// Calculator generating a stream with the given contents. +// Inputs: none +// Outputs: 1, with the contents provided via the input side packet. +// Input side packets: 1, pointing to CalculatorRunner::StreamContents. +class CalculatorRunnerSourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets() + .Index(0) + .Set(); + cc->Outputs().Index(0).SetAny(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + const auto* contents = cc->InputSidePackets() + .Index(0) + .Get(); + // Set the header and packets of the output stream. 
+ cc->Outputs().Index(0).SetHeader(contents->header); + for (const Packet& packet : contents->packets) { + cc->Outputs().Index(0).AddPacket(packet); + } + return ::mediapipe::OkStatus(); + } + ::mediapipe::Status Process(CalculatorContext* cc) override { + return tool::StatusStop(); + } +}; +REGISTER_CALCULATOR(CalculatorRunnerSourceCalculator); + +// Calculator recording the contents of a stream. +// Inputs: 1, with the contents written to the input side packet. +// Outputs: none +// Input side packets: 1, pointing to CalculatorRunner::StreamContents. +class CalculatorRunnerSinkCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->InputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + contents_ = cc->InputSidePackets() + .Index(0) + .Get(); + contents_->header = cc->Inputs().Index(0).Header(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + contents_->packets.push_back(cc->Inputs().Index(0).Value()); + return ::mediapipe::OkStatus(); + } + + private: + CalculatorRunner::StreamContents* contents_ = nullptr; +}; +REGISTER_CALCULATOR(CalculatorRunnerSinkCalculator); + +} // namespace + +CalculatorRunner::CalculatorRunner( + const CalculatorGraphConfig::Node& node_config) { + MEDIAPIPE_CHECK_OK(InitializeFromNodeConfig(node_config)); +} + +::mediapipe::Status CalculatorRunner::InitializeFromNodeConfig( + const CalculatorGraphConfig::Node& node_config) { + node_config_ = node_config; + + if (node_config_.external_input_size() > 0) { + RET_CHECK_EQ(0, node_config_.input_side_packet_size()) + << "Only one of input_side_packet or (deprecated) external_input can " + "be set."; + node_config_.mutable_external_input()->Swap( + node_config_.mutable_input_side_packet()); + } + + ASSIGN_OR_RETURN(auto input_map, + 
tool::TagMap::Create(node_config_.input_stream())); + inputs_ = absl::make_unique(input_map); + + ASSIGN_OR_RETURN(auto output_map, + tool::TagMap::Create(node_config_.output_stream())); + outputs_ = absl::make_unique(output_map); + + ASSIGN_OR_RETURN(auto input_side_map, + tool::TagMap::Create(node_config_.input_side_packet())); + input_side_packets_ = absl::make_unique(input_side_map); + + ASSIGN_OR_RETURN(auto output_side_map, + tool::TagMap::Create(node_config_.output_side_packet())); + output_side_packets_ = absl::make_unique(output_side_map); + + return ::mediapipe::OkStatus(); +} + +CalculatorRunner::CalculatorRunner(const std::string& calculator_type, + const CalculatorOptions& options) { + node_config_.set_calculator(calculator_type); + *node_config_.mutable_options() = options; + log_calculator_proto_ = true; +} + +#if !defined(MEDIAPIPE_PROTO_LITE) +CalculatorRunner::CalculatorRunner(const std::string& node_config_string) { + CalculatorGraphConfig::Node node_config; + CHECK( + proto_ns::TextFormat::ParseFromString(node_config_string, &node_config)); + MEDIAPIPE_CHECK_OK(InitializeFromNodeConfig(node_config)); +} + +CalculatorRunner::CalculatorRunner(const std::string& calculator_type, + const std::string& options_string, + int num_inputs, int num_outputs, + int num_side_packets) { + node_config_.set_calculator(calculator_type); + CHECK(proto_ns::TextFormat::ParseFromString(options_string, + node_config_.mutable_options())); + SetNumInputs(num_inputs); + SetNumOutputs(num_outputs); + SetNumInputSidePackets(num_side_packets); + // Reset log_calculator_proto to false, since it was set to true by + // SetNum*() calls above. This constructor is not deprecated but is + // currently implemented in terms of deprecated functions. 
+ log_calculator_proto_ = false; +} +#endif + +CalculatorRunner::~CalculatorRunner() {} + +void CalculatorRunner::SetNumInputs(int n) { + tool::TagAndNameInfo info; + for (int i = 0; i < n; ++i) { + info.names.push_back(absl::StrCat("input_", i)); + } + InitializeInputs(info); +} + +void CalculatorRunner::SetNumOutputs(int n) { + tool::TagAndNameInfo info; + for (int i = 0; i < n; ++i) { + info.names.push_back(absl::StrCat("output_", i)); + } + InitializeOutputs(info); +} + +void CalculatorRunner::SetNumInputSidePackets(int n) { + tool::TagAndNameInfo info; + for (int i = 0; i < n; ++i) { + info.names.push_back(absl::StrCat("side_packet_", i)); + } + InitializeInputSidePackets(info); +} + +void CalculatorRunner::InitializeInputs(const tool::TagAndNameInfo& info) { + CHECK(graph_ == nullptr); + MEDIAPIPE_CHECK_OK( + tool::SetFromTagAndNameInfo(info, node_config_.mutable_input_stream())); + inputs_.reset(new StreamContentsSet(info)); + log_calculator_proto_ = true; +} + +void CalculatorRunner::InitializeOutputs(const tool::TagAndNameInfo& info) { + CHECK(graph_ == nullptr); + MEDIAPIPE_CHECK_OK( + tool::SetFromTagAndNameInfo(info, node_config_.mutable_output_stream())); + outputs_.reset(new StreamContentsSet(info)); + log_calculator_proto_ = true; +} + +void CalculatorRunner::InitializeInputSidePackets( + const tool::TagAndNameInfo& info) { + CHECK(graph_ == nullptr); + MEDIAPIPE_CHECK_OK(tool::SetFromTagAndNameInfo( + info, node_config_.mutable_input_side_packet())); + input_side_packets_.reset(new PacketSet(info)); + log_calculator_proto_ = true; +} + +mediapipe::Counter* CalculatorRunner::GetCounter(const std::string& name) { + return graph_->GetCounterFactory()->GetCounter(name); +} + +::mediapipe::Status CalculatorRunner::BuildGraph() { + if (graph_ != nullptr) { + // The graph was already built. 
+ return ::mediapipe::OkStatus(); + } + RET_CHECK(inputs_) << "The inputs were not initialized."; + RET_CHECK(outputs_) << "The outputs were not initialized."; + RET_CHECK(input_side_packets_) + << "The input side packets were not initialized."; + + CalculatorGraphConfig config; + // Add the calculator node. + *(config.add_node()) = node_config_; + + for (int i = 0; i < node_config_.input_stream_size(); ++i) { + std::string name; + std::string tag; + int index; + RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.input_stream(i), &tag, + &index, &name)); + // Add a source for each input stream. + auto* node = config.add_node(); + node->set_calculator("CalculatorRunnerSourceCalculator"); + node->add_output_stream(name); + node->add_input_side_packet(absl::StrCat(kSourcePrefix, name)); + } + for (int i = 0; i < node_config_.output_stream_size(); ++i) { + std::string name; + std::string tag; + int index; + RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.output_stream(i), &tag, + &index, &name)); + // Add a sink for each output stream. 
+ auto* node = config.add_node(); + node->set_calculator("CalculatorRunnerSinkCalculator"); + node->add_input_stream(name); + node->add_input_side_packet(absl::StrCat(kSinkPrefix, name)); + } + config.set_num_threads(1); + + if (log_calculator_proto_) { +#if defined(MEDIAPIPE_PROTO_LITE) + LOG(INFO) << "Please initialize CalculatorRunner using the recommended " + "constructor:\n CalculatorRunner runner(node_config);"; +#else + std::string config_string; + proto_ns::TextFormat::Printer printer; + printer.SetInitialIndentLevel(4); + printer.PrintToString(node_config_, &config_string); + LOG(INFO) << "Please initialize CalculatorRunner using the recommended " + "constructor:\n CalculatorRunner runner(R\"(\n" + << config_string << "\n )\");"; +#endif + } + + graph_ = absl::make_unique(); + RETURN_IF_ERROR(graph_->Initialize(config)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CalculatorRunner::Run() { + RETURN_IF_ERROR(BuildGraph()); + // Set the input side packets for the sources. + std::map input_side_packets; + int positional_index = -1; + for (int i = 0; i < node_config_.input_stream_size(); ++i) { + std::string name; + std::string tag; + int index; + RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.input_stream(i), &tag, + &index, &name)); + const CalculatorRunner::StreamContents* contents; + if (index == -1) { + // positional_index considers the case when the tag is empty, which is + // always the case when index == -1. If we ever support indices for + // non-empty tags ("ABC:input1" and "ABC:input2" with automatic indices), + // this should be changed to use a map insted. + contents = &inputs_->Get(tag, ++positional_index); + } else { + contents = &inputs_->Get(tag, index); + } + input_side_packets.emplace(absl::StrCat(kSourcePrefix, name), + Adopt(new auto(contents))); + } + // Set the input side packets for the calculator. 
+ positional_index = -1; + for (int i = 0; i < node_config_.input_side_packet_size(); ++i) { + std::string name; + std::string tag; + int index; + RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.input_side_packet(i), + &tag, &index, &name)); + const Packet* packet; + if (index == -1) { + packet = &input_side_packets_->Get(tag, ++positional_index); + } else { + packet = &input_side_packets_->Get(tag, index); + } + input_side_packets.emplace(name, *packet); + } + // Set the input side packets for the sinks. + positional_index = -1; + for (int i = 0; i < node_config_.output_stream_size(); ++i) { + std::string name; + std::string tag; + int index; + RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.output_stream(i), &tag, + &index, &name)); + CalculatorRunner::StreamContents* contents; + if (index == -1) { + contents = &outputs_->Get(tag, ++positional_index); + } else { + contents = &outputs_->Get(tag, index); + } + // Clear |contents| because Run() may be called multiple times. + *contents = CalculatorRunner::StreamContents(); + input_side_packets.emplace(absl::StrCat(kSinkPrefix, name), + Adopt(new auto(contents))); + } + RETURN_IF_ERROR(graph_->Run(input_side_packets)); + + positional_index = -1; + for (int i = 0; i < node_config_.output_side_packet_size(); ++i) { + std::string name; + std::string tag; + int index; + RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.output_side_packet(i), + &tag, &index, &name)); + Packet& contents = output_side_packets_->Get( + tag, (index == -1) ? ++positional_index : index); + ASSIGN_OR_RETURN(contents, graph_->GetOutputSidePacket(name)); + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_runner.h b/mediapipe/framework/calculator_runner.h new file mode 100644 index 000000000..59b20c548 --- /dev/null +++ b/mediapipe/framework/calculator_runner.h @@ -0,0 +1,157 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Defines CalculatorRunner which can be used to run a Calculator in +// isolation. This is useful for testing. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_RUNNER_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_RUNNER_H_ + +#include +#include +#include + +#include "absl/base/macros.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +class CalculatorGraph; + +// The class for running the Calculator with given inputs and examining outputs. +class CalculatorRunner { + public: + // A representation of input or output stream contents. + struct StreamContents { + // The Packets in the stream. + std::vector packets; + // Stream header. + Packet header; + }; + // A collection of StreamContents by either index or tag. + typedef internal::Collection StreamContentsSet; + + // Preferred constructor. + // All the needed information comes from the node config. 
+ // Example: + // CalculatorRunner runner(R"( + // calculator: "ScaleImageCalculator" + // input_stream: "ycbcr_frames" + // output_stream: "FRAMES:srgb_frames" + // output_stream: "VIDEO_HEADER:srgb_frames_header" + // options { + // [mediapipe.ScaleImageCalculatorOptions.ext] { + // target_height: 10 + // preserve_aspect_ratio: true + // output_format: SRGB + // algorithm: AREA + // } + // } + // )"); + explicit CalculatorRunner(const CalculatorGraphConfig::Node& node_config); +#if !defined(MEDIAPIPE_PROTO_LITE) + // Convenience constructor which takes a node_config std::string directly. + explicit CalculatorRunner(const std::string& node_config_string); + // Convenience constructor to initialize a calculator which uses indexes + // (not tags) for all its fields. + // NOTE: This constructor calls proto_ns::TextFormat::ParseFromString(), which + // is not available when using lite protos. + CalculatorRunner(const std::string& calculator_type, + const std::string& options_string, int num_inputs, + int num_outputs, int num_side_packets); +#endif + // Minimal constructor which requires additional calls to define inputs, + // outputs, and input side packets. Prefer using another constructor. + ABSL_DEPRECATED("Initialize CalculatorRunner with a proto instead.") + CalculatorRunner(const std::string& calculator_type, + const CalculatorOptions& options); + + CalculatorRunner(const CalculatorRunner&) = delete; + CalculatorRunner& operator=(const CalculatorRunner&) = delete; + + ~CalculatorRunner(); + + // Sets the number of input streams, output streams, or input side packets, + // respectively. May not be called after Run() has been called. 
+ ABSL_DEPRECATED("Initialize CalculatorRunner with a proto instead.") + void SetNumInputs(int n); + ABSL_DEPRECATED("Initialize CalculatorRunner with a proto instead.") + void SetNumOutputs(int n); + ABSL_DEPRECATED("Initialize CalculatorRunner with a proto instead.") + void SetNumInputSidePackets(int n); + + // Initializes the inputs, outputs, or side packets using a + // TagAndNameInfo. This sets the corresponding section of node_config_. + // May not be called after Run() has been called. + ABSL_DEPRECATED("Initialize CalculatorRunner with a proto instead.") + void InitializeInputs(const tool::TagAndNameInfo& info); + ABSL_DEPRECATED("Initialize CalculatorRunner with a proto instead.") + void InitializeOutputs(const tool::TagAndNameInfo& info); + ABSL_DEPRECATED("Initialize CalculatorRunner with a proto instead.") + void InitializeInputSidePackets(const tool::TagAndNameInfo& info); + + // Returns mutable access to the input stream contents. + StreamContentsSet* MutableInputs() { return inputs_.get(); } + // Returns mutable access to the input side packets. + PacketSet* MutableSidePackets() { return input_side_packets_.get(); } + + // Runs the calculator, by calling Open(), Process() with the + // inputs provided via mutable_inputs(), and Close(). Returns the + // ::mediapipe::Status from CalculatorGraph::Run(). Internally, Run() + // constructs a CalculatorGraph in the first call, and calls + // CalculatorGraph::Run(). A single instance of CalculatorRunner + // uses the same instance of CalculatorGraph for all runs. + ::mediapipe::Status Run(); + + // Returns the vector of contents of the output streams. The .header + // field contains the stream header and the .packets field contains + // the Packets from the stream, unless SetOutputPacketCallback() + // has been called with non-nullptr, in which case .packets will be empty. + const StreamContentsSet& Outputs() const { return *outputs_; } + + // Returns the access to the output side packets. 
+ const PacketSet& OutputSidePackets() { return *output_side_packets_.get(); } + + // Returns a graph counter. + mediapipe::Counter* GetCounter(const std::string& name); + + private: + static const char kSourcePrefix[]; + static const char kSinkPrefix[]; + + // Initialize using a node config (does the constructor's work). + ::mediapipe::Status InitializeFromNodeConfig( + const CalculatorGraphConfig::Node& node_config); + + // Builds the graph if one does not already exist. + ::mediapipe::Status BuildGraph(); + + CalculatorGraphConfig::Node node_config_; + + // Log the calculator proto after it is created from the provided + // parameters. This aids users in migrating to the recommended + // constructor. + bool log_calculator_proto_ = false; + + std::unique_ptr inputs_; + std::unique_ptr outputs_; + std::unique_ptr input_side_packets_; + std::unique_ptr output_side_packets_; + std::unique_ptr graph_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_RUNNER_H_ diff --git a/mediapipe/framework/calculator_runner_test.cc b/mediapipe/framework/calculator_runner_test.cc new file mode 100644 index 000000000..692a9a554 --- /dev/null +++ b/mediapipe/framework/calculator_runner_test.cc @@ -0,0 +1,239 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Tests CalculatorRunner. 
+ +#include "mediapipe/framework/calculator_runner.h" + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_registry.h" +#include "mediapipe/framework/input_stream.h" +#include "mediapipe/framework/output_stream.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { +namespace { + +// Inputs: 2 streams with ints. Headers are strings. +// Input side packets: 1. +// Outputs: 3 streams with ints. #0 and #1 will contain the negated values from +// corresponding input streams, #2 will contain replicas of the input side +// packet +// at InputTimestamp. The headers are strings. +class CalculatorRunnerTestCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Inputs().Index(1).Set(); + cc->Outputs().Index(0).Set(); + cc->Outputs().Index(1).Set(); + cc->Outputs().Index(2).SetSameAs(&cc->InputSidePackets().Index(0)); + cc->InputSidePackets().Index(0).SetAny(); + cc->OutputSidePackets() + .Tag("SIDE_OUTPUT") + .SetSameAs(&cc->InputSidePackets().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + std::string input_header_string = + absl::StrCat(cc->Inputs().Index(0).Header().Get(), + cc->Inputs().Index(1).Header().Get()); + for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { + // Set the header to the concatenation of the input headers and + // the index of the output stream. 
+ cc->Outputs().Index(i).SetHeader( + Adopt(new std::string(absl::StrCat(input_header_string, i)))); + } + cc->OutputSidePackets() + .Tag("SIDE_OUTPUT") + .Set(cc->InputSidePackets().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + for (int index = 0; index < 2; ++index) { + cc->Outputs().Index(index).Add( + new int(-cc->Inputs().Index(index).Get()), cc->InputTimestamp()); + } + cc->Outputs().Index(2).AddPacket( + cc->InputSidePackets().Index(0).At(cc->InputTimestamp())); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(CalculatorRunnerTestCalculator); + +// Inputs: Any number of streams of integer, with any tags. +// Outputs: For each tag name (possibly including the empty tag), outputs a +// a single stream with the sum of the integers belonging to streams +// with the same tag name (and any index). +class CalculatorRunnerMultiTagTestCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (const std::string& tag : cc->Inputs().GetTags()) { + for (CollectionItemId item_id = cc->Inputs().BeginId(tag); + item_id < cc->Inputs().EndId(tag); ++item_id) { + cc->Inputs().Get(item_id).Set(); + } + cc->Outputs().Get(tag, 0).Set(); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + for (const std::string& tag : cc->Inputs().GetTags()) { + auto sum = absl::make_unique(0); + for (CollectionItemId item_id = cc->Inputs().BeginId(tag); + item_id < cc->Inputs().EndId(tag); ++item_id) { + if (!cc->Inputs().Get(item_id).IsEmpty()) { + *sum += cc->Inputs().Get(item_id).Get(); + } + } + cc->Outputs().Get(tag, 0).Add(sum.release(), cc->InputTimestamp()); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(CalculatorRunnerMultiTagTestCalculator); + +TEST(CalculatorRunner, RunsCalculator) { + CalculatorRunner runner(R"( + calculator: "CalculatorRunnerTestCalculator" + 
input_stream: "input_0" + input_stream: "input_1" + output_stream: "output_0" + output_stream: "output_1" + output_stream: "output_2" + input_side_packet: "input_side_packet_0" + output_side_packet: "SIDE_OUTPUT:output_side_packet_0" + options { + } + )"); + + // Run CalculatorRunner::Run() several times, with different inputs. This + // tests that a CalculatorRunner instance can be reused. + for (int iter = 0; iter < 3; ++iter) { + LOG(INFO) << "iter: " << iter; + const int length = iter; + // Generate the inputs at timestamps 0 ... length-1, at timestamp t having + // values t and t*2 for the two streams, respectively. + const std::string kHeaderPrefix = "header"; + for (int index = 0; index < 2; ++index) { + runner.MutableInputs()->Index(index).packets.clear(); + for (int t = 0; t < length; ++t) { + runner.MutableInputs()->Index(index).packets.push_back( + Adopt(new int(t * (index + 1))).At(Timestamp(t))); + } + // Set the header to the concatenation of kHeaderPrefix and the index of + // the input stream. + runner.MutableInputs()->Index(index).header = + Adopt(new std::string(absl::StrCat(kHeaderPrefix, index))); + } + const int input_side_packet_content = 10 + iter; + runner.MutableSidePackets()->Index(0) = + Adopt(new int(input_side_packet_content)); + MEDIAPIPE_ASSERT_OK(runner.Run()); + EXPECT_EQ(input_side_packet_content, + runner.OutputSidePackets().Tag("SIDE_OUTPUT").Get()); + const auto& outputs = runner.Outputs(); + ASSERT_EQ(3, outputs.NumEntries()); + + // Check the output headers and the number of Packets. + for (int index = 0; index < outputs.NumEntries(); ++index) { + // The header should be the concatenation of the input headers + // and the index of the output stream. + EXPECT_EQ(absl::StrCat(kHeaderPrefix, 0, kHeaderPrefix, 1, index), + outputs.Index(index).header.Get()); + // Check the packets. 
+ const std::vector& packets = outputs.Index(index).packets; + EXPECT_EQ(length, packets.size()); + for (int t = 0; t < length; ++t) { + EXPECT_EQ(Timestamp(t), packets[t].Timestamp()); + // The first two output streams are negations of the inputs, the last + // contains copies of the input side packet. + if (index < 2) { + EXPECT_EQ(-t * (index + 1), packets[t].Get()); + } else { + EXPECT_EQ(input_side_packet_content, packets[t].Get()); + } + } + } + } +} + +TEST(CalculatorRunner, MultiTagTestCalculatorOk) { + CalculatorRunner runner(R"( + calculator: "CalculatorRunnerMultiTagTestCalculator" + input_stream: "A:0:full_0" + input_stream: "A:1:full_1" + input_stream: "A:2:full_2" + input_stream: "B:no_index_0" + input_stream: "no_tag_or_index_0" + input_stream: "no_tag_or_index_1" + output_stream: "A:output_a" + output_stream: "B:output_b" + output_stream: "output_c" + )"); + + for (int ts = 0; ts < 5; ++ts) { + for (int i = 0; i < 3; ++i) { + runner.MutableInputs()->Get("A", i).packets.push_back( + Adopt(new int(10 * ts + i)).At(Timestamp(ts))); + } + runner.MutableInputs()->Get("B", 0).packets.push_back( + Adopt(new int(100)).At(Timestamp(ts))); + runner.MutableInputs() + ->Get("", ts % 2) + .packets.push_back(Adopt(new int(ts)).At(Timestamp(ts))); + } + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const auto& outputs = runner.Outputs(); + ASSERT_EQ(3, outputs.NumEntries()); + for (int ts = 0; ts < 5; ++ts) { + const std::vector& a_packets = outputs.Tag("A").packets; + const std::vector& b_packets = outputs.Tag("B").packets; + const std::vector& c_packets = outputs.Tag("").packets; + EXPECT_EQ(Timestamp(ts), a_packets[ts].Timestamp()); + EXPECT_EQ(Timestamp(ts), b_packets[ts].Timestamp()); + EXPECT_EQ(Timestamp(ts), c_packets[ts].Timestamp()); + + EXPECT_EQ(10 * 3 * ts + 3, a_packets[ts].Get()); + EXPECT_EQ(100, b_packets[ts].Get()); + EXPECT_EQ(ts, c_packets[ts].Get()); + } +} + +TEST(CalculatorRunner, MultiTagTestInvalidStreamTagCrashes) { + const std::string 
graph_config = R"( + calculator: "CalculatorRunnerMultiTagTestCalculator" + input_stream: "A:0:a_0" + input_stream: "A:a_1" + input_stream: "A:2:a_2" + output_stream: "A:output_a" + )"; + EXPECT_DEATH(CalculatorRunner runner(graph_config), + ".*tag \"A\" index 0 already had a name " + "\"a_0\" but is being reassigned a name \"a_1\""); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_state.cc b/mediapipe/framework/calculator_state.cc new file mode 100644 index 000000000..6b4088872 --- /dev/null +++ b/mediapipe/framework/calculator_state.cc @@ -0,0 +1,82 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definitions for CalculatorNode. 
+ +#include "mediapipe/framework/calculator_state.h" + +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +CalculatorState::CalculatorState( + const std::string& node_name, int node_id, + const std::string& calculator_type, + const CalculatorGraphConfig::Node& node_config, + std::shared_ptr profiling_context) + : node_name_(node_name), + node_id_(node_id), + calculator_type_(calculator_type), + node_config_(node_config), + profiling_context_(profiling_context), + input_streams_(nullptr), + output_streams_(nullptr), + counter_factory_(nullptr) { + options_.Initialize(node_config); + ResetBetweenRuns(); +} + +CalculatorState::~CalculatorState() {} + +void CalculatorState::SetInputStreamSet(InputStreamSet* input_stream_set) { + CHECK(input_stream_set); + input_streams_ = input_stream_set; +} + +void CalculatorState::SetOutputStreamSet(OutputStreamSet* output_stream_set) { + CHECK(output_stream_set); + output_streams_ = output_stream_set; +} + +void CalculatorState::ResetBetweenRuns() { + input_side_packets_ = nullptr; + input_streams_ = nullptr; + output_streams_ = nullptr; + counter_factory_ = nullptr; +} + +void CalculatorState::SetInputSidePackets(const PacketSet* input_side_packets) { + CHECK(input_side_packets); + input_side_packets_ = input_side_packets; +} + +void CalculatorState::SetOutputSidePackets( + OutputSidePacketSet* output_side_packets) { + CHECK(output_side_packets); + output_side_packets_ = output_side_packets; +} + +Counter* CalculatorState::GetCounter(const std::string& name) { + CHECK(counter_factory_); + return counter_factory_->GetCounter(absl::StrCat(NodeName(), "-", name)); +} + +void CalculatorState::SetServicePacket(const std::string& key, Packet packet) { + service_packets_[key] = std::move(packet); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_state.h b/mediapipe/framework/calculator_state.h new file mode 100644 index 000000000..b50362170 --- 
/dev/null +++ b/mediapipe/framework/calculator_state.h @@ -0,0 +1,160 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Defines CalculatorState. + +#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_STATE_H_ +#define MEDIAPIPE_FRAMEWORK_CALCULATOR_STATE_H_ + +#include +#include +#include + +// TODO: Move protos in another CL after the C++ code migration. +#include "absl/base/macros.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/counter.h" +#include "mediapipe/framework/counter_factory.h" +#include "mediapipe/framework/graph_service.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/any_proto.h" +#include "mediapipe/framework/tool/options_util.h" + +namespace mediapipe { + +class ProfilingContext; +// Holds data that the Calculator needs access to. This data is not +// stored in Calculator directly since Calculator will be destroyed after +// every CalculatorGraph::Run() . It is not stored in CalculatorNode +// because Calculator should not depend on CalculatorNode. All +// information conveyed in this class is flowing from the CalculatorNode +// to the Calculator. 
+class CalculatorState { + public: + CalculatorState(const std::string& node_name, int node_id, + const std::string& calculator_type, + const CalculatorGraphConfig::Node& node_config, + std::shared_ptr profiling_context); + CalculatorState(const CalculatorState&) = delete; + CalculatorState& operator=(const CalculatorState&) = delete; + ~CalculatorState(); + + // Sets the pointer to the InputStreamSet. The function is invoked by + // CalculatorNode::PrepareForRun. + void SetInputStreamSet(InputStreamSet* input_stream_set); + + // Sets the pointer to the OutputStreamSet. The function is invoked by + // CalculatorNode::PrepareForRun. + void SetOutputStreamSet(OutputStreamSet* output_stream_set); + + // Called before every call to Calculator::Open() (during the PrepareForRun + // phase). + void ResetBetweenRuns(); + + const std::string& CalculatorType() const { return calculator_type_; } + const CalculatorOptions& Options() const { return node_config_.options(); } + // Returns the options given to this calculator. Template argument T must + // be the type of the protobuf extension message or the protobuf::Any + // message containing the options. + template + const T& Options() const { + return options_.Get(); + } + const std::string& NodeName() const { return node_name_; } + const int& NodeId() const { return node_id_; } + + //////////////////////////////////////// + // Interface for Calculator. + //////////////////////////////////////// + const InputStreamSet& InputStreams() const { return *input_streams_; } + const OutputStreamSet& OutputStreams() const { return *output_streams_; } + const PacketSet& InputSidePackets() const { return *input_side_packets_; } + OutputSidePacketSet& OutputSidePackets() { return *output_side_packets_; } + + // Returns a counter using the graph's counter factory. The counter's + // name is the passed-in name, prefixed by the calculator NodeName. 
+ Counter* GetCounter(const std::string& name); + + std::shared_ptr GetSharedProfilingContext() const { + return profiling_context_; + } + + //////////////////////////////////////// + // Interface for CalculatorNode. + //////////////////////////////////////// + // Sets the input side packets. + void SetInputSidePackets(const PacketSet* input_side_packets); + // Sets the output side packets. + void SetOutputSidePackets(OutputSidePacketSet* output_side_packets); + // Sets the counter factory. + void SetCounterFactory(CounterFactory* counter_factory) { + counter_factory_ = counter_factory; + } + + void SetServicePacket(const std::string& key, Packet packet); + + bool IsServiceAvailable(const GraphServiceBase& service) { + return ContainsKey(service_packets_, service.key); + } + + template + T& GetServiceObject(const GraphService& service) { + auto it = service_packets_.find(service.key); + CHECK(it != service_packets_.end()); + return *it->second.template Get>(); + } + + private: + //////////////////////////////////////// + // Persistent variables that are not cleared by ResetBetweenRuns(). + //////////////////////////////////////// + // The name associated with this calculator's node. + const std::string node_name_; + // The ID associated with this calculator's node. + const int node_id_; + // The registered type name of the Calculator. + const std::string calculator_type_; + // The Node protobuf containing the options for the calculator. + const CalculatorGraphConfig::Node node_config_; + // The unpacked protobuf options for the calculator. + tool::OptionsMap options_; + // The graph tracing and profiling interface. + std::shared_ptr profiling_context_; + + std::map service_packets_; + + //////////////////////////////////////// + // Variables which ARE cleared by ResetBetweenRuns(). + //////////////////////////////////////// + // The InputStreamSet object is owned by the CalculatorNode. + // CalculatorState obtains its pointer in CalculatorNode::PrepareForRun. 
+ InputStreamSet* input_streams_; + // The OutputStreamSet object is owned by the CalculatorNode. + // CalculatorState obtains its pointer in CalculatorNode::PrepareForRun. + OutputStreamSet* output_streams_; + // The set of input side packets set by CalculatorNode::PrepareForRun(). + // ResetBetweenRuns() clears this PacketSet pointer. + const PacketSet* input_side_packets_; + // The OutputSidePacketSet object is owned by the CalculatorNode. + // CalculatorState obtains its pointer in CalculatorNode::PrepareForRun. + OutputSidePacketSet* output_side_packets_; + + CounterFactory* counter_factory_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_STATE_H_ diff --git a/mediapipe/framework/camera_intrinsics.h b/mediapipe/framework/camera_intrinsics.h new file mode 100644 index 000000000..69a30881f --- /dev/null +++ b/mediapipe/framework/camera_intrinsics.h @@ -0,0 +1,53 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_FRAMEWORK_CAMERA_INTRINSICS_H_ +#define MEDIAPIPE_FRAMEWORK_CAMERA_INTRINSICS_H_ + +class CameraIntrinsics { + public: + CameraIntrinsics(float fx, float fy, float cx, float cy, float width, + float height) + : fx_(fx), fy_(fy), cx_(cx), cy_(cy), width_(width), height_(height) {} + CameraIntrinsics(float fx, float fy, float cx, float cy) + : CameraIntrinsics(fx, fy, cx, cy, -1, -1) {} + + float fx() const { return fx_; } + float fy() const { return fy_; } + float cx() const { return cx_; } + float cy() const { return cy_; } + float width() const { return width_; } + float height() const { return height_; } + + private: + // Lens focal length along the x-axis, in pixels. + const float fx_; + + // Lens focal length along the y-axis, in pixels. + const float fy_; + + // Principal point, x-coordinate on the image, in pixels. + const float cx_; + + // Principal point, y-coordinate on the image, in pixels. + const float cy_; + + // Image width, in pixels. + const float width_; + + // Image height, in pixels. + const float height_; +}; + +#endif // MEDIAPIPE_FRAMEWORK_CAMERA_INTRINSICS_H_ diff --git a/mediapipe/framework/collection.h b/mediapipe/framework/collection.h new file mode 100644 index 000000000..b3f972b0a --- /dev/null +++ b/mediapipe/framework/collection.h @@ -0,0 +1,563 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_FRAMEWORK_COLLECTION_H_ +#define MEDIAPIPE_FRAMEWORK_COLLECTION_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/tool/tag_map.h" +#include "mediapipe/framework/tool/validate_name.h" +#include "mediapipe/framework/type_map.h" + +namespace mediapipe { +namespace internal { + +// A class to handle errors that occur in Collection. For most +// collections, these errors should be fatal. However, for a collection +// more like PacketTypeSet, the errors should be deferred and handled +// later. +// +// This class is thread compatible. +template +struct CollectionErrorHandlerFatal { + // An error occurred during object lookup for the provided tag and + // index. The returned object reference will be provided instead. + // + // Since there isn't any state and we're not returning anything, we + // get away with only one version of this function (which is const + // but returns a non-const reference). + T& GetFallback(const std::string& tag, int index) const { + LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index; + std::abort(); + } +}; + +enum class CollectionStorage { kStoreValue = 0, kStorePointer }; + +// A collection of objects of type T. +// +// If storage == kStorePointer then T* will be stored instead of T, but +// the accessor functions will still return T types. The T objects must +// be owned elsewhere and remain alive as long as the collection is used. +// To set the pointers use the GetPtr() function. +// +// The ErrorHandler object allows errors to be deferred to a later time. +// +// This class is thread compatible as long as the ErrorHandler object is also +// thread compatible. 
+template > +class Collection { + private: + template + class DoubleDerefIterator; + + public: + using value_type = T; + + // The iterator is over value_type, requiring a double dereference if + // storage == kStorePointer. + using iterator = + typename std::conditional, + value_type*>::type; + using const_iterator = + typename std::conditional, + const value_type*>::type; + using difference_type = ptrdiff_t; + using size_type = size_t; + using pointer = value_type*; + using reference = value_type&; + + // The type that is stored by data_; + using stored_type = + typename std::conditional::type; + + // Collection must be initialized on construction. + Collection() = delete; + Collection(const Collection&) = delete; + Collection& operator=(const Collection&) = delete; + // Makes a Collection using the given TagMap (which should be shared + // between collections). + // Refer to mediapipe::tool::CreateTagMap for examples of how to construct a + // collection from a vector of "TAG::name" strings, or from an integer + // number of indexes, etc. + explicit Collection(std::shared_ptr tag_map); + // Makes a Collection using the information in the TagAndNameInfo. + ABSL_DEPRECATED("Use Collection(tool::TagMap)") + explicit Collection(const tool::TagAndNameInfo& info); + // Convenience constructor which initializes a collection to use + // indexes and have num_entries inputs. + ABSL_DEPRECATED("Use Collection(tool::TagMap)") + explicit Collection(int num_entries); + // Convenience constructor which initializes a collection to use tags + // with the given names. + // Note: initializer_list constructor should not be marked explicit. + ABSL_DEPRECATED("Use Collection(tool::TagMap)") + Collection(const std::initializer_list& tag_names); + + // Access the data at a given CollectionItemId. This is the most efficient + // way to access data within the collection. + // + // Do not assume that Index(2) == Get(collection.TagMap()->BeginId() + 2). 
+ value_type& Get(CollectionItemId id); + const value_type& Get(CollectionItemId id) const; + + // Convenience functions. + value_type& Get(const std::string& tag, int index); + const value_type& Get(const std::string& tag, int index) const; + + // Equivalent to Get("", index); + value_type& Index(int index); + const value_type& Index(int index) const; + + // Equivalent to Get(tag, 0); + value_type& Tag(const std::string& tag); + const value_type& Tag(const std::string& tag) const; + + // These functions only exist for collections with storage == + // kStorePointer. GetPtr returns the stored ptr value rather than + // the value_type. The non-const version returns a reference so that + // the pointer can be set. + value_type*& GetPtr(CollectionItemId id); + // Const version returns a pointer to a const value (a const-ref to + // a pointer wouldn't be useful in this context). + const value_type* GetPtr(CollectionItemId id) const; + + // Returns true if the collection has a tag other than "". + // TODO Deprecate and remove this function. + bool UsesTags() const; + + // Returns a description of the collection. + std::string DebugString() const; + + // Return the tag_map. + const std::shared_ptr& TagMap() const; + + // Iteration functions for use of the collection in a range based + // for loop. The items are provided in sorted tag order with indexes + // sequential within tags. + iterator begin(); + iterator end(); + const_iterator begin() const; + const_iterator end() const; + + // Returns the error handler object. + const ErrorHandler& GetErrorHandler() const { return error_handler_; } + + //////////////////////////////////////// + // The remaining public functions directly call their equivalent + // in tool::TagMap. They are guaranteed to be equivalent for any + // Collection initialized using an equivalent tool::TagMap. + //////////////////////////////////////// + + // Returns true if the provided tag is available (not necessarily set yet). 
+ bool HasTag(const std::string& tag) const { return tag_map_->HasTag(tag); } + + // Returns the number of entries in this collection. + int NumEntries() const { return tag_map_->NumEntries(); } + + // Returns the number of entries with the provided tag. + int NumEntries(const std::string& tag) const { + return tag_map_->NumEntries(tag); + } + + // Get the id for the tag and index. This id is guaranteed valid for + // any Collection which was initialized with an equivalent tool::TagMap. + // If the tag or index are invalid then an invalid CollectionItemId + // is returned (with id.IsValid() == false). + // + // The id for indexes within the same tag are guaranteed to + // be sequential. Meaning, if tag "BLAH" has 3 indexes, then + // ++GetId("BLAH", 1) == GetId("BLAH", 2) + // However, be careful in using this fact, as it circumvents the + // validity checks in GetId() (i.e. ++GetId("BLAH", 2) looks like it + // is valid, while GetId("BLAH", 3) is not valid). + CollectionItemId GetId(const std::string& tag, int index) const { + return tag_map_->GetId(tag, index); + } + + // Returns the names of the tags in this collection. + std::set GetTags() const { return tag_map_->GetTags(); } + + // Get a tag and index for the specified id. If the id is not valid, + // then {"", -1} will be returned. + std::pair TagAndIndexFromId(CollectionItemId id) const { + return tag_map_->TagAndIndexFromId(id); + } + + // The CollectionItemId corresponding to the first element in the collection. + // Looping over all elements can be done as follows. + // for (CollectionItemId id = collection.BeginId(); + // id < collection.EndId(); ++id) { + // } + // However, if only one collection is involved, prefer using a range + // based for loop. + // for (Packet packet : Inputs()) { + // } + CollectionItemId BeginId() const { return tag_map_->BeginId(); } + // The CollectionItemId corresponding to an element immediately after + // the last element of the collection. 
+ CollectionItemId EndId() const { return tag_map_->EndId(); } + + // Same as BeginId()/EndId() but for only one tag. If the tag doesn't + // exist then an invalid CollectionItemId is returned. It is guaranteed + // that a loop constructed in this way will successfully not be entered + // for invalid tags. + // for (CollectionItemId id = collection.BeginId(tag); + // id < collection.EndId(tag); ++id) { + // } + CollectionItemId BeginId(const std::string& tag) const { + return tag_map_->BeginId(tag); + } + CollectionItemId EndId(const std::string& tag) const { + return tag_map_->EndId(tag); + } + + private: + // An iterator which is identical to ItType** except that the + // dereference operator (operator*) does a double dereference and + // returns an ItType. + // + // This class is thread compatible. + template + class DoubleDerefIterator { + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = ItType; + using difference_type = std::ptrdiff_t; + using pointer = ItType*; + using reference = ItType&; + + DoubleDerefIterator() : ptr_(nullptr) {} + + reference operator*() { return **ptr_; } + + pointer operator->() { return *ptr_; } + + reference operator[](difference_type d) { return **(ptr_ + d); } + + // Member operators. + DoubleDerefIterator& operator++() { + ++ptr_; + return *this; + } + DoubleDerefIterator operator++(int) { + DoubleDerefIterator output(ptr_); + ++ptr_; + return output; + } + DoubleDerefIterator& operator--() { + --ptr_; + return *this; + } + DoubleDerefIterator operator--(int) { + DoubleDerefIterator output(ptr_); + --ptr_; + return output; + } + DoubleDerefIterator& operator+=(difference_type d) { + ptr_ += d; + return *this; + } + DoubleDerefIterator& operator-=(difference_type d) { + ptr_ -= d; + return *this; + } + + // Non-member binary operators. 
+ friend bool operator==(DoubleDerefIterator lhs, DoubleDerefIterator rhs) { + return lhs.ptr_ == rhs.ptr_; + } + friend bool operator!=(DoubleDerefIterator lhs, DoubleDerefIterator rhs) { + return lhs.ptr_ != rhs.ptr_; + } + friend bool operator<(DoubleDerefIterator lhs, DoubleDerefIterator rhs) { + return lhs.ptr_ < rhs.ptr_; + } + friend bool operator<=(DoubleDerefIterator lhs, DoubleDerefIterator rhs) { + return lhs.ptr_ <= rhs.ptr_; + } + friend bool operator>(DoubleDerefIterator lhs, DoubleDerefIterator rhs) { + return lhs.ptr_ > rhs.ptr_; + } + friend bool operator>=(DoubleDerefIterator lhs, DoubleDerefIterator rhs) { + return lhs.ptr_ >= rhs.ptr_; + } + + friend DoubleDerefIterator operator+(DoubleDerefIterator lhs, + difference_type d) { + return lhs.ptr_ + d; + } + friend DoubleDerefIterator operator+(difference_type d, + DoubleDerefIterator rhs) { + return rhs.ptr_ + d; + } + friend DoubleDerefIterator& operator-(DoubleDerefIterator lhs, + difference_type d) { + return lhs.ptr_ - d; + } + friend difference_type operator-(DoubleDerefIterator lhs, + DoubleDerefIterator rhs) { + return lhs.ptr_ - rhs.ptr_; + } + + private: + explicit DoubleDerefIterator(ItType* const* data) : ptr_(data) {} + + ItType* const* ptr_; + + friend class Collection; + }; + + // TagMap for the collection. + std::shared_ptr tag_map_; + + // Indexed by Id. Use an array directly so that the type does not + // have to be copy constructable. The array has tag_map_->NumEntries() + // elements. + std::unique_ptr data_; + + // A class which allows errors to be reported flexibly. The default + // instantiation performs a LOG(FATAL) and does not have any member + // variables (zero size). + ErrorHandler error_handler_; +}; + +// Definitions of templated functions for Collection. 
+ +template +Collection::Collection( + std::shared_ptr tag_map) + : tag_map_(std::move(tag_map)) { + if (tag_map_->NumEntries() != 0) { + data_ = absl::make_unique(tag_map_->NumEntries()); + } +} + +template +Collection::Collection( + const tool::TagAndNameInfo& info) { + tag_map_ = std::move(tool::TagMap::Create(info).ValueOrDie()); + if (tag_map_->NumEntries() != 0) { + data_ = absl::make_unique(tag_map_->NumEntries()); + } +} + +template +Collection::Collection(const int num_entries) { + proto_ns::RepeatedPtrField fields; + for (int i = 0; i < num_entries; ++i) { + *fields.Add() = absl::StrCat("name", i); + } + tag_map_ = std::move(tool::TagMap::Create(fields).ValueOrDie()); + if (tag_map_->NumEntries() != 0) { + data_ = absl::make_unique(tag_map_->NumEntries()); + } +} + +template +Collection::Collection( + const std::initializer_list& tag_names) { + proto_ns::RepeatedPtrField fields; + int i = 0; + for (const std::string& name : tag_names) { + *fields.Add() = absl::StrCat(name, ":name", i); + ++i; + } + tag_map_ = std::move(tool::TagMap::Create(fields).ValueOrDie()); + if (tag_map_->NumEntries() != 0) { + data_ = absl::make_unique(tag_map_->NumEntries()); + } +} + +template +bool Collection::UsesTags() const { + auto& mapping = tag_map_->Mapping(); + if (mapping.size() > 1) { + // At least one tag is not "". + return true; + } + if (mapping.empty()) { + // The mapping is empty, it doesn't use tags. + return false; + } + // If the one tag present is non-empty then we are using tags. 
+ return mapping.begin()->first != ""; +} + +template +typename Collection::value_type& +Collection::Get(CollectionItemId id) { + CHECK_LE(BeginId(), id); + CHECK_LT(id, EndId()); + return begin()[id.value()]; +} + +template +const typename Collection::value_type& +Collection::Get(CollectionItemId id) const { + CHECK_LE(BeginId(), id); + CHECK_LT(id, EndId()); + return begin()[id.value()]; +} + +template +typename Collection::value_type*& +Collection::GetPtr(CollectionItemId id) { + static_assert(storage == CollectionStorage::kStorePointer, + "::mediapipe::internal::Collection::GetPtr() is only " + "available for collections that were defined with template " + "argument storage == CollectionStorage::kStorePointer."); + CHECK_LE(BeginId(), id); + CHECK_LT(id, EndId()); + return data_[id.value()]; +} + +template +const typename Collection::value_type* +Collection::GetPtr(CollectionItemId id) const { + static_assert(storage == CollectionStorage::kStorePointer, + "::mediapipe::internal::Collection::GetPtr() is only " + "available for collections that were defined with template " + "argument storage == CollectionStorage::kStorePointer."); + CHECK_LE(BeginId(), id); + CHECK_LT(id, EndId()); + return data_[id.value()]; +} + +template +typename Collection::value_type& +Collection::Get(const std::string& tag, int index) { + CollectionItemId id = GetId(tag, index); + if (!id.IsValid()) { + return error_handler_.GetFallback(tag, index); + } + return begin()[id.value()]; +} + +template +const typename Collection::value_type& +Collection::Get(const std::string& tag, + int index) const { + CollectionItemId id = GetId(tag, index); + if (!id.IsValid()) { + return error_handler_.GetFallback(tag, index); + } + return begin()[id.value()]; +} + +template +typename Collection::value_type& +Collection::Index(int index) { + return Get("", index); +} + +template +const typename Collection::value_type& +Collection::Index(int index) const { + return Get("", index); +} + +template +typename 
Collection::value_type& +Collection::Tag(const std::string& tag) { + return Get(tag, 0); +} + +template +const typename Collection::value_type& +Collection::Tag(const std::string& tag) const { + return Get(tag, 0); +} + +template +std::string Collection::DebugString() const { + std::string output = + absl::StrCat("Collection of \"", MediaPipeTypeStringOrDemangled(), + "\" with\n", tag_map_->DebugString()); + return output; +} + +template +const std::shared_ptr& +Collection::TagMap() const { + return tag_map_; +} + +template +typename Collection::iterator +Collection::begin() { + return iterator(data_.get()); +} + +template +typename Collection::iterator +Collection::end() { + return iterator(data_.get() + tag_map_->NumEntries()); +} + +template +typename Collection::const_iterator +Collection::begin() const { + return const_iterator(data_.get()); +} + +template +typename Collection::const_iterator +Collection::end() const { + return const_iterator(data_.get() + tag_map_->NumEntries()); +} + +} // namespace internal + +// Returns c.HasTag(tag) && !Tag(tag)->IsEmpty() (just for convenience). +// This version is used with Calculator. +template +bool HasTagValue(const internal::Collection& c, const std::string& tag) { + return c.HasTag(tag) && !c.Tag(tag)->IsEmpty(); +} + +// Returns c.HasTag(tag) && !Tag(tag).IsEmpty() (just for convenience). +// This version is used with CalculatorBase. +template +bool HasTagValue(const internal::Collection& c, const std::string& tag) { + return c.HasTag(tag) && !c.Tag(tag).IsEmpty(); +} + +// Returns c.HasTag(tag) && !Tag(tag).IsEmpty() (just for convenience). +// This version is used with Calculator or CalculatorBase. 
+template +bool HasTagValue(const C& c, const std::string& tag) { + return HasTagValue(c->Inputs(), tag); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_COLLECTION_H_ diff --git a/mediapipe/framework/collection_item_id.cc b/mediapipe/framework/collection_item_id.cc new file mode 100644 index 000000000..a1603eede --- /dev/null +++ b/mediapipe/framework/collection_item_id.cc @@ -0,0 +1,27 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/collection_item_id.h" + +namespace mediapipe { + +std::ostream& operator<<(std::ostream& os, CollectionItemId arg) { + return os << arg.value(); +} + +CollectionItemId operator+(int lhs, CollectionItemId rhs) { return rhs + lhs; } +CollectionItemId operator-(int lhs, CollectionItemId rhs) { return -rhs + lhs; } +CollectionItemId operator*(int lhs, CollectionItemId rhs) { return rhs * lhs; } + +} // namespace mediapipe diff --git a/mediapipe/framework/collection_item_id.h b/mediapipe/framework/collection_item_id.h new file mode 100644 index 000000000..4d87eb060 --- /dev/null +++ b/mediapipe/framework/collection_item_id.h @@ -0,0 +1,177 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_FRAMEWORK_COLLECTION_ITEM_ID_H_
#define MEDIAPIPE_FRAMEWORK_COLLECTION_ITEM_ID_H_

// NOTE(review): added for the std::ostream used by operator<< below.
#include <ostream>
// NOTE(review): the original also included
// "mediapipe/framework/deps/strong_int.h", but nothing in this header uses
// StrongInt — removed per include-what-you-use; restore if another header
// depended on it transitively.

namespace mediapipe {

namespace tool {
class TagMap;
}  // namespace tool

// TagMap allows access to a collection using a tag and index value.
// The underlying data in the collection is stored in a flat array.
// CollectionItemId is the index into that array. Although this type is
// conceptually an int we don't allow implicit type conversion so as to
// avoid confusion where a user accidentally forgets to query the TagMap
// to get an actual CollectionItemId.
// For example, accidentally using Inputs().Get(2) when Inputs().Index(2)
// was meant will cause a type error.
//
// NOTE(review): the "template <typename ArgType>" headers on the mixed-type
// operators below were stripped during extraction and are reconstructed.
class CollectionItemId {
 public:
  // Static function to return an invalid id.
  static const CollectionItemId GetInvalid() { return CollectionItemId(); }

  // Construct an invalid CollectionItemId (value -1).
  constexpr CollectionItemId() : value_(-1) {}

  // Use the default copy constructor, assignment, and destructor.
  CollectionItemId(const CollectionItemId&) = default;
  ~CollectionItemId() = default;
  CollectionItemId& operator=(const CollectionItemId&) = default;

  // An id is valid iff its value is non-negative.
  bool IsValid() const { return value_ >= 0; }
  // Accesses the raw value.
  constexpr int value() const { return value_; }

  // Unary operators.  Note: operator! tests for value 0 (first slot), not
  // for invalidity — use IsValid() for that.
  bool operator!() const { return value_ == 0; }
  const CollectionItemId operator+() const { return CollectionItemId(value_); }
  const CollectionItemId operator-() const { return CollectionItemId(-value_); }

  // Increment and decrement operators.
  CollectionItemId& operator++() {  // ++x
    ++value_;
    return *this;
  }
  const CollectionItemId operator++(int) {  // x++
    CollectionItemId temp(*this);
    ++value_;
    return temp;
  }
  CollectionItemId& operator--() {  // --x
    --value_;
    return *this;
  }
  const CollectionItemId operator--(int) {  // x--
    CollectionItemId temp(*this);
    --value_;
    return temp;
  }

  // Action-Assignment and arithmetic operators.  Each arithmetic operator
  // has an id/id form plus a templated id/int form.
  CollectionItemId& operator+=(CollectionItemId arg) {
    value_ += arg.value_;
    return *this;
  }
  CollectionItemId operator+(CollectionItemId arg) const {
    return CollectionItemId(value_ + arg.value_);
  }
  template <typename ArgType>
  CollectionItemId operator+(ArgType arg) const {
    return CollectionItemId(value_ + arg);
  }

  CollectionItemId& operator-=(CollectionItemId arg) {
    value_ -= arg.value_;
    return *this;
  }
  CollectionItemId operator-(CollectionItemId arg) const {
    return CollectionItemId(value_ - arg.value_);
  }
  template <typename ArgType>
  CollectionItemId operator-(ArgType arg) const {
    return CollectionItemId(value_ - arg);
  }

  template <typename ArgType>
  CollectionItemId& operator*=(ArgType arg) {
    value_ *= arg;
    return *this;
  }
  CollectionItemId operator*(CollectionItemId arg) const {
    return CollectionItemId(value_ * arg.value_);
  }
  template <typename ArgType>
  CollectionItemId operator*(ArgType arg) const {
    return CollectionItemId(value_ * arg);
  }

  template <typename ArgType>
  CollectionItemId& operator/=(ArgType arg) {
    value_ /= arg;
    return *this;
  }
  CollectionItemId operator/(CollectionItemId arg) const {
    return CollectionItemId(value_ / arg.value_);
  }
  template <typename ArgType>
  CollectionItemId operator/(ArgType arg) const {
    return CollectionItemId(value_ / arg);
  }

  template <typename ArgType>
  CollectionItemId& operator%=(ArgType arg) {
    value_ %= arg;
    return *this;
  }
  CollectionItemId operator%(CollectionItemId arg) const {
    return CollectionItemId(value_ % arg.value_);
  }
  template <typename ArgType>
  CollectionItemId operator%(ArgType arg) const {
    return CollectionItemId(value_ % arg);
  }

  // Comparison operators, by raw value.
  inline bool operator>(CollectionItemId rhs) const {
    return value_ > rhs.value_;
  }
  inline bool operator>=(CollectionItemId rhs) const {
    return value_ >= rhs.value_;
  }
  inline bool operator<(CollectionItemId rhs) const {
    return value_ < rhs.value_;
  }
  inline bool operator<=(CollectionItemId rhs) const {
    return value_ <= rhs.value_;
  }
  inline bool operator==(CollectionItemId rhs) const {
    return value_ == rhs.value_;
  }
  inline bool operator!=(CollectionItemId rhs) const {
    return value_ != rhs.value_;
  }

 private:
  // Only TagMap may mint ids from raw integers.
  friend class ::mediapipe::tool::TagMap;

  // Initialization from a value.
  explicit constexpr CollectionItemId(int init_value) : value_(init_value) {}

  // The integer value of type int.
  int value_;
};

std::ostream& operator<<(std::ostream& os, CollectionItemId arg);

CollectionItemId operator+(int lhs, CollectionItemId rhs);
CollectionItemId operator-(int lhs, CollectionItemId rhs);
CollectionItemId operator*(int lhs, CollectionItemId rhs);

}  // namespace mediapipe

#endif  // MEDIAPIPE_FRAMEWORK_COLLECTION_ITEM_ID_H_
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/collection.h" + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/tag_map_helper.h" + +namespace mediapipe { +namespace { + +TEST(CollectionTest, BasicByIndex) { + tool::TagAndNameInfo info; + info.names.push_back("name_1"); + info.names.push_back("name_0"); + info.names.push_back("name_2"); + internal::Collection collection(info); + collection.Index(1) = 101; + collection.Index(0) = 100; + collection.Index(2) = 102; + + // Test the stored values. + EXPECT_EQ(100, collection.Index(0)); + EXPECT_EQ(101, collection.Index(1)); + EXPECT_EQ(102, collection.Index(2)); + // Test access using a range based for. + int i = 0; + for (int num : collection) { + EXPECT_EQ(100 + i, num); + ++i; + } +} + +TEST(CollectionTest, BasicByTag) { + tool::TagAndNameInfo info; + info.names.push_back("name_1"); + info.tags.push_back("TAG_1"); + info.names.push_back("name_0"); + info.tags.push_back("TAG_0"); + info.names.push_back("name_2"); + info.tags.push_back("TAG_2"); + internal::Collection collection(info); + collection.Tag("TAG_1") = 101; + collection.Tag("TAG_0") = 100; + collection.Tag("TAG_2") = 102; + + // Test the stored values. + EXPECT_EQ(100, collection.Tag("TAG_0")); + EXPECT_EQ(101, collection.Tag("TAG_1")); + EXPECT_EQ(102, collection.Tag("TAG_2")); + // Test access using a range based for. 
+ int i = 0; + for (int num : collection) { + // Numbers are in sorted order by tag. + EXPECT_EQ(100 + i, num); + ++i; + } +} + +TEST(CollectionTest, MixedTagAndIndexUsage) { + auto tags_statusor = + tool::CreateTagMap({"TAG_A:a", "TAG_B:1:b", "TAG_A:2:c", "TAG_B:d", + "TAG_C:0:e", "TAG_A:1:f"}); + MEDIAPIPE_ASSERT_OK(tags_statusor); + + internal::Collection collection1(std::move(tags_statusor.ValueOrDie())); + collection1.Get("TAG_A", 0) = 100; + collection1.Get("TAG_A", 1) = 101; + collection1.Get("TAG_A", 2) = 102; + collection1.Get("TAG_B", 0) = 103; + collection1.Get("TAG_B", 1) = 104; + collection1.Get("TAG_C", 0) = 105; + + // Test access using a range based for. + int i = 0; + for (int num : collection1) { + // Numbers are in sorted order by tag and then index. + EXPECT_EQ(100 + i, num); + ++i; + } + EXPECT_EQ(6, i); + // Initialize the values of another collection while iterating through + // the entries of the first. This is testing that two collections + // can be looped through in lock step. + internal::Collection collection2(collection1.TagMap()); + i = 0; + for (CollectionItemId id = collection1.BeginId(); id < collection1.EndId(); + ++id) { + // Numbers are in sorted order by tag and then index. + EXPECT_EQ(100 + i, collection1.Get(id)); + // Initialize the entries of the second collection. + collection2.Get(id) = 'a' + i; + ++i; + } + EXPECT_EQ(6, i); + + // Check the second collection. + EXPECT_EQ(6, collection2.NumEntries()); + EXPECT_EQ('a', collection2.Get("TAG_A", 0)); + EXPECT_EQ('b', collection2.Get("TAG_A", 1)); + EXPECT_EQ('c', collection2.Get("TAG_A", 2)); + EXPECT_EQ('d', collection2.Get("TAG_B", 0)); + EXPECT_EQ('e', collection2.Get("TAG_B", 1)); + EXPECT_EQ('f', collection2.Get("TAG_C", 0)); + // And check it again with a loop. + i = 0; + for (int num : collection2) { + EXPECT_EQ('a' + i, num); + ++i; + } + EXPECT_EQ(6, i); + + // Initialize the values of another collection by iterating over + // each tag. 
+ internal::Collection collection3(collection1.TagMap()); + i = 0; + for (const std::string& tag : collection1.GetTags()) { + int index_in_tag = 0; + for (CollectionItemId id = collection1.BeginId(tag); + id < collection1.EndId(tag); ++id) { + VLOG(1) << "tag: " << tag << " index_in_tag: " << index_in_tag + << " collection index: " << i; + // Numbers are in sorted order by tag and then index. + EXPECT_EQ(100 + i, collection1.Get(id)); + // Initialize the entries of the second collection. + collection3.Get(id) = absl::StrCat(i, " ", tag, " ", index_in_tag); + ++i; + ++index_in_tag; + } + } + EXPECT_EQ(6, i); + + for (CollectionItemId id = collection1.BeginId("TAG_D"); + id < collection1.EndId("TAG_D"); ++id) { + EXPECT_FALSE(true) << "iteration through non-existent tag found element."; + } + + // Check the second collection. + EXPECT_EQ(6, collection3.NumEntries()); + EXPECT_EQ("0 TAG_A 0", collection3.Get("TAG_A", 0)); + EXPECT_EQ("1 TAG_A 1", collection3.Get("TAG_A", 1)); + EXPECT_EQ("2 TAG_A 2", collection3.Get("TAG_A", 2)); + EXPECT_EQ("3 TAG_B 0", collection3.Get("TAG_B", 0)); + EXPECT_EQ("4 TAG_B 1", collection3.Get("TAG_B", 1)); + EXPECT_EQ("5 TAG_C 0", collection3.Get("TAG_C", 0)); +} + +TEST(CollectionTest, StaticEmptyCollectionHeapCheck) { + // Ensure that static collections play nicely with the heap checker. + // "new T[0]" returns a non-null pointer which the heap checker has + // issues in tracking. Additionally, allocating of empty arrays is + // also inefficient as it invokes heap management routines. + static auto* collection1 = new PacketSet(tool::CreateTagMap({}).ValueOrDie()); + // Heap check issues are most triggered when zero length and non-zero + // length allocations are interleaved. Additionally, this heap check + // wasn't triggered by "char", so a more complex type (Packet) is used. 
+ static auto* collection2 = + new PacketSet(tool::CreateTagMap({"TAG:name"}).ValueOrDie()); + static auto* collection3 = new PacketSet(tool::CreateTagMap({}).ValueOrDie()); + static auto* collection4 = + new PacketSet(tool::CreateTagMap({"TAG:name"}).ValueOrDie()); + static auto* collection5 = new PacketSet(tool::CreateTagMap({}).ValueOrDie()); + EXPECT_EQ(0, collection1->NumEntries()); + EXPECT_EQ(1, collection2->NumEntries()); + EXPECT_EQ(0, collection3->NumEntries()); + EXPECT_EQ(1, collection4->NumEntries()); + EXPECT_EQ(0, collection5->NumEntries()); +} + +template +::mediapipe::Status TestCollectionWithPointers( + const std::vector& original_values, const T& inject1, const T& inject2) { + std::shared_ptr tag_map = + tool::CreateTagMap({"TAG_A:a", "TAG_B:1:b", "TAG_A:2:c", "TAG_B:d", + "TAG_C:0:e", "TAG_A:1:f"}) + .ValueOrDie(); + + { + // Test a regular collection. + std::vector values = original_values; + internal::Collection collection(tag_map); + collection.Get("TAG_A", 0) = values[0]; + collection.Get("TAG_A", 1) = values[1]; + collection.Get("TAG_A", 2) = values[2]; + collection.Get("TAG_B", 0) = values[3]; + collection.Get("TAG_B", 1) = values[4]; + collection.Get("TAG_C", 0) = values[5]; + + const auto* collection_ptr = &collection; + + EXPECT_EQ(values[0], collection.Get("TAG_A", 0)); + EXPECT_EQ(values[1], collection.Get("TAG_A", 1)); + EXPECT_EQ(values[2], collection.Get("TAG_A", 2)); + EXPECT_EQ(values[3], collection.Get("TAG_B", 0)); + EXPECT_EQ(values[4], collection.Get("TAG_B", 1)); + EXPECT_EQ(values[5], collection.Get("TAG_C", 0)); + + EXPECT_EQ(values[0], collection_ptr->Get("TAG_A", 0)); + EXPECT_EQ(values[1], collection_ptr->Get("TAG_A", 1)); + EXPECT_EQ(values[2], collection_ptr->Get("TAG_A", 2)); + EXPECT_EQ(values[3], collection_ptr->Get("TAG_B", 0)); + EXPECT_EQ(values[4], collection_ptr->Get("TAG_B", 1)); + EXPECT_EQ(values[5], collection_ptr->Get("TAG_C", 0)); + + // Test const-ness. 
+ EXPECT_EQ(false, std::is_const::type>::value); + EXPECT_EQ(true, std::is_constGet("TAG_A", 0))>::type>::value); + + // Test access using a range based for. + int i = 0; + for (auto& value : *collection_ptr) { + EXPECT_EQ(values[i], value); + EXPECT_EQ( + true, + std::is_const< + typename std::remove_reference::type>::value); + ++i; + } + i = 0; + for (auto& value : collection) { + EXPECT_EQ(values[i], value); + EXPECT_EQ( + false, + std::is_const< + typename std::remove_reference::type>::value); + ++i; + } + // Test the random access operator in the iterator. + // the operator[] should not generally be used. + EXPECT_EQ(values[2], collection_ptr->begin()[2]); + collection.begin()[2] = inject2; + EXPECT_EQ(inject2, collection_ptr->Get("TAG_A", 2)); + } + + { + // Pointer Collection type with dereference_content set to true. + std::vector values = original_values; + internal::Collection + collection(tag_map); + collection.GetPtr(collection.GetId("TAG_A", 0)) = &values[0]; + collection.GetPtr(collection.GetId("TAG_A", 1)) = &values[1]; + collection.GetPtr(collection.GetId("TAG_A", 2)) = &values[2]; + collection.GetPtr(collection.GetId("TAG_B", 0)) = &values[3]; + collection.GetPtr(collection.GetId("TAG_B", 1)) = &values[4]; + collection.GetPtr(collection.GetId("TAG_C", 0)) = &values[5]; + + const auto* collection_ptr = &collection; + + EXPECT_EQ(values[0], collection.Get("TAG_A", 0)); + EXPECT_EQ(values[1], collection.Get("TAG_A", 1)); + EXPECT_EQ(values[2], collection.Get("TAG_A", 2)); + EXPECT_EQ(values[3], collection.Get("TAG_B", 0)); + EXPECT_EQ(values[4], collection.Get("TAG_B", 1)); + EXPECT_EQ(values[5], collection.Get("TAG_C", 0)); + + EXPECT_EQ(values[0], collection_ptr->Get("TAG_A", 0)); + EXPECT_EQ(values[1], collection_ptr->Get("TAG_A", 1)); + EXPECT_EQ(values[2], collection_ptr->Get("TAG_A", 2)); + EXPECT_EQ(values[3], collection_ptr->Get("TAG_B", 0)); + EXPECT_EQ(values[4], collection_ptr->Get("TAG_B", 1)); + EXPECT_EQ(values[5], 
collection_ptr->Get("TAG_C", 0)); + + // Test const-ness. + EXPECT_EQ(false, std::is_const::type>::value); + EXPECT_EQ(true, std::is_constGet("TAG_A", 0))>::type>::value); + + // Test access using a range based for. + int i = 0; + for (auto& value : *collection_ptr) { + EXPECT_EQ(values[i], value); + EXPECT_EQ( + true, + std::is_const< + typename std::remove_reference::type>::value); + ++i; + } + i = 0; + for (auto& value : collection) { + EXPECT_EQ(values[i], value); + EXPECT_EQ( + false, + std::is_const< + typename std::remove_reference::type>::value); + ++i; + } + i = 0; + for (CollectionItemId id = collection_ptr->BeginId(); + id < collection_ptr->EndId(); ++id) { + // TODO Test that GetPtr() does not exist for + // storage == kStoreValue. + EXPECT_EQ(&values[i], collection_ptr->GetPtr(id)); + EXPECT_EQ(values[i], *collection_ptr->GetPtr(id)); + EXPECT_EQ(false, std::is_const::type>::value); + EXPECT_EQ(true, std::is_constGetPtr(id))>::type>::value); + ++i; + } + + T injected = inject1; + collection.GetPtr(collection_ptr->GetId("TAG_A", 2)) = &injected; + EXPECT_EQ(&injected, + collection_ptr->GetPtr(collection_ptr->GetId("TAG_A", 2))); + EXPECT_EQ(injected, + *collection_ptr->GetPtr(collection_ptr->GetId("TAG_A", 2))); + EXPECT_EQ(injected, collection_ptr->Get("TAG_A", 2)); + // Test the random access operator in the iterator. + // the operator[] should not generally be used. + EXPECT_EQ( + injected, + collection_ptr->begin()[collection_ptr->GetId("TAG_A", 2).value()]); + collection.begin()[collection_ptr->GetId("TAG_A", 2).value()] = inject2; + EXPECT_EQ(inject2, injected); + + // Test access using a range based for. + i = 0; + for (const T& value : *collection_ptr) { + if (i != collection_ptr->GetId("TAG_A", 2).value()) { + EXPECT_EQ(values[i], value); + } else { + EXPECT_EQ(injected, value); + } + ++i; + } + } + + { + // Pointer Collection type with dereference_content set to false. 
+ std::vector values = original_values; + internal::Collection + collection(tag_map); + collection.Get("TAG_A", 0) = &values[0]; + collection.Get("TAG_A", 1) = &values[1]; + collection.Get("TAG_A", 2) = &values[2]; + collection.Get("TAG_B", 0) = &values[3]; + collection.Get("TAG_B", 1) = &values[4]; + collection.Get("TAG_C", 0) = &values[5]; + + const auto* collection_ptr = &collection; + + EXPECT_EQ(values[0], *collection.Get("TAG_A", 0)); + EXPECT_EQ(values[1], *collection.Get("TAG_A", 1)); + EXPECT_EQ(values[2], *collection.Get("TAG_A", 2)); + EXPECT_EQ(values[3], *collection.Get("TAG_B", 0)); + EXPECT_EQ(values[4], *collection.Get("TAG_B", 1)); + EXPECT_EQ(values[5], *collection.Get("TAG_C", 0)); + + EXPECT_EQ(&values[0], collection.Get("TAG_A", 0)); + EXPECT_EQ(&values[1], collection.Get("TAG_A", 1)); + EXPECT_EQ(&values[2], collection.Get("TAG_A", 2)); + EXPECT_EQ(&values[3], collection.Get("TAG_B", 0)); + EXPECT_EQ(&values[4], collection.Get("TAG_B", 1)); + EXPECT_EQ(&values[5], collection.Get("TAG_C", 0)); + + EXPECT_EQ(values[0], *collection_ptr->Get("TAG_A", 0)); + EXPECT_EQ(values[1], *collection_ptr->Get("TAG_A", 1)); + EXPECT_EQ(values[2], *collection_ptr->Get("TAG_A", 2)); + EXPECT_EQ(values[3], *collection_ptr->Get("TAG_B", 0)); + EXPECT_EQ(values[4], *collection_ptr->Get("TAG_B", 1)); + EXPECT_EQ(values[5], *collection_ptr->Get("TAG_C", 0)); + + EXPECT_EQ(&values[0], collection_ptr->Get("TAG_A", 0)); + EXPECT_EQ(&values[1], collection_ptr->Get("TAG_A", 1)); + EXPECT_EQ(&values[2], collection_ptr->Get("TAG_A", 2)); + EXPECT_EQ(&values[3], collection_ptr->Get("TAG_B", 0)); + EXPECT_EQ(&values[4], collection_ptr->Get("TAG_B", 1)); + EXPECT_EQ(&values[5], collection_ptr->Get("TAG_C", 0)); + + // Test const-ness. + EXPECT_EQ(false, std::is_const::type>::value); + EXPECT_EQ(true, std::is_constGet("TAG_A", 0))>::type>::value); + + // Test access using a range based for. 
+ int i = 0; + for (auto& value : *collection_ptr) { + EXPECT_EQ(&values[i], value); + EXPECT_EQ(values[i], *value); + EXPECT_EQ( + true, + std::is_const< + typename std::remove_reference::type>::value); + // In const collections of pointers it's just the (stored) pointer + // which is const, not the underlying data. + EXPECT_EQ( + false, + std::is_const< + typename std::remove_reference::type>::value); + ++i; + } + i = 0; + for (auto& value : collection) { + EXPECT_EQ(&values[i], value); + EXPECT_EQ(values[i], *value); + EXPECT_EQ( + false, + std::is_const< + typename std::remove_reference::type>::value); + EXPECT_EQ( + false, + std::is_const< + typename std::remove_reference::type>::value); + ++i; + } + + T injected = inject1; + collection.Get("TAG_A", 2) = &injected; + EXPECT_EQ(&injected, collection_ptr->Get("TAG_A", 2)); + EXPECT_EQ(injected, *collection_ptr->Get("TAG_A", 2)); + // Test the random access operator in the iterator. + // the operator[] should not generally be used. + EXPECT_EQ( + &injected, + collection_ptr->begin()[collection_ptr->GetId("TAG_A", 2).value()]); + *collection.begin()[collection_ptr->GetId("TAG_A", 2).value()] = inject2; + EXPECT_EQ(inject2, injected); + + // Test access using a range based for. 
+ i = 0; + for (const T* value : *collection_ptr) { + if (i != collection_ptr->GetId("TAG_A", 2).value()) { + EXPECT_EQ(&values[i], value); + EXPECT_EQ(values[i], *value); + } else { + EXPECT_EQ(&injected, value); + EXPECT_EQ(injected, *value); + } + ++i; + } + } + return ::mediapipe::OkStatus(); +} + +TEST(CollectionTest, TestCollectionWithPointersIntAndString) { + MEDIAPIPE_ASSERT_OK( + TestCollectionWithPointers({3, 7, -2, 0, 4, -3}, 17, 10)); + MEDIAPIPE_ASSERT_OK(TestCollectionWithPointers( + {"a0", "a1", "a2", "b0", "b1", "c0"}, "inject1", "inject2")); +} + +TEST(CollectionTest, TestIteratorFunctions) { + std::shared_ptr tag_map = + tool::CreateTagMap({"TAG_A:a", "TAG_B:1:b", "TAG_A:2:c", "TAG_B:d", + "TAG_C:0:e", "TAG_A:1:f"}) + .ValueOrDie(); + + std::vector values = {"a0", "a1", "a2", "b0", "b1", "c0"}; + internal::Collection + collection(tag_map); + collection.GetPtr(collection.GetId("TAG_A", 0)) = &values[0]; + collection.GetPtr(collection.GetId("TAG_A", 1)) = &values[1]; + collection.GetPtr(collection.GetId("TAG_A", 2)) = &values[2]; + collection.GetPtr(collection.GetId("TAG_B", 0)) = &values[3]; + collection.GetPtr(collection.GetId("TAG_B", 1)) = &values[4]; + collection.GetPtr(collection.GetId("TAG_C", 0)) = &values[5]; + + EXPECT_EQ(false, std::is_const::type>::value); + EXPECT_EQ(values[0], *collection.begin()); + EXPECT_EQ(false, collection.begin()->empty()); + EXPECT_EQ(false, (*collection.begin()).empty()); + collection.begin()->assign("inject3"); + EXPECT_EQ(values[0], "inject3"); + + const auto* collection_ptr = &collection; + + EXPECT_EQ(true, std::is_constbegin())>::type>::value); + EXPECT_EQ(values[0], *collection_ptr->begin()); + EXPECT_EQ(false, collection_ptr->begin()->empty()); + EXPECT_EQ(false, (*collection_ptr->begin()).empty()); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/counter.h b/mediapipe/framework/counter.h new file mode 100644 index 000000000..2fce5b5af --- /dev/null +++ 
b/mediapipe/framework/counter.h @@ -0,0 +1,36 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// The abstract class of counter. + +#ifndef MEDIAPIPE_FRAMEWORK_COUNTER_H_ +#define MEDIAPIPE_FRAMEWORK_COUNTER_H_ + +#include "mediapipe/framework/port/integral_types.h" + +namespace mediapipe { + +class Counter { + public: + Counter() {} + virtual ~Counter() {} + + virtual void Increment() = 0; + virtual void IncrementBy(int amount) = 0; + virtual int64 Get() = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_COUNTER_H_ diff --git a/mediapipe/framework/counter_factory.cc b/mediapipe/framework/counter_factory.cc new file mode 100644 index 000000000..0eac61777 --- /dev/null +++ b/mediapipe/framework/counter_factory.cc @@ -0,0 +1,80 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/counter_factory.h" + +#include + +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" + +namespace mediapipe { +namespace { + +// Counter implementation when we're not using Flume. +// TODO: Consider using Dax atomic counters instead of this. +// This class is thread safe. +class BasicCounter : public Counter { + public: + explicit BasicCounter(const std::string& name) : value_(0) {} + + void Increment() LOCKS_EXCLUDED(mu_) override { + absl::WriterMutexLock lock(&mu_); + ++value_; + } + + void IncrementBy(int amount) LOCKS_EXCLUDED(mu_) override { + absl::WriterMutexLock lock(&mu_); + value_ += amount; + } + + int64 Get() LOCKS_EXCLUDED(mu_) override { + absl::ReaderMutexLock lock(&mu_); + return value_; + } + + private: + absl::Mutex mu_; + int64 value_ GUARDED_BY(mu_); +}; + +} // namespace + +CounterSet::CounterSet() {} + +CounterSet::~CounterSet() LOCKS_EXCLUDED(mu_) { PublishCounters(); } + +void CounterSet::PublishCounters() LOCKS_EXCLUDED(mu_) {} + +void CounterSet::PrintCounters() LOCKS_EXCLUDED(mu_) { + absl::ReaderMutexLock lock(&mu_); + LOG_IF(INFO, !counters_.empty()) << "MediaPipe Counters:"; + for (const auto& counter : counters_) { + LOG(INFO) << counter.first << ": " << counter.second->Get(); + } +} + +Counter* CounterSet::Get(const std::string& name) LOCKS_EXCLUDED(mu_) { + absl::ReaderMutexLock lock(&mu_); + if (!::mediapipe::ContainsKey(counters_, name)) { + return nullptr; + } + return counters_[name].get(); +} + +Counter* BasicCounterFactory::GetCounter(const std::string& name) { + return counter_set_.Emplace(name, name); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/counter_factory.h b/mediapipe/framework/counter_factory.h new file mode 100644 index 000000000..d1d09df60 --- /dev/null +++ b/mediapipe/framework/counter_factory.h @@ -0,0 +1,93 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_COUNTER_FACTORY_H_ +#define MEDIAPIPE_FRAMEWORK_COUNTER_FACTORY_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "mediapipe/framework/counter.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/map_util.h" + +namespace mediapipe { + +// Holds a map of counter names to counter unique_ptrs. +// This class is thread safe. +class CounterSet { + public: + CounterSet(); + + // In builds with streamz export enabled, this will synchronously export + // the final counter values. + ~CounterSet(); + // Prints the values of all the counters. + // A call to PublishCounters will reset all counters. + void PrintCounters(); + // Publishes the vales of all the counters for monitoring and resets + // all internal counters. + void PublishCounters(); + + // Adds a counter of the given type by constructing the counter in place. + // Returns a pointer to the new counter or if the counter already exists + // to the existing pointer. + template + Counter* Emplace(const std::string& name, Args&&... 
args) + LOCKS_EXCLUDED(mu_) { + absl::WriterMutexLock lock(&mu_); + std::unique_ptr* existing_counter = FindOrNull(counters_, name); + if (existing_counter) { + return existing_counter->get(); + } + Counter* counter = new CounterType(std::forward(args)...); + counters_[name].reset(counter); + return counter; + } + // Retrieves the counter with the given name; return nullptr if it doesn't + // exist. + Counter* Get(const std::string& name); + + private: + absl::Mutex mu_; + std::map> counters_ GUARDED_BY(mu_); +}; + +// Generic counter factory +class CounterFactory { + public: + virtual ~CounterFactory() {} + virtual Counter* GetCounter(const std::string& name) = 0; + CounterSet* GetCounterSet() { return &counter_set_; } + + protected: + CounterSet counter_set_; +}; + +// Counter factory that makes the counters be our own basic counters. +class BasicCounterFactory : public CounterFactory { + public: + ~BasicCounterFactory() override {} + Counter* GetCounter(const std::string& name) override; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_COUNTER_FACTORY_H_ diff --git a/mediapipe/framework/delegating_executor.cc b/mediapipe/framework/delegating_executor.cc new file mode 100644 index 000000000..627d51751 --- /dev/null +++ b/mediapipe/framework/delegating_executor.cc @@ -0,0 +1,27 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/delegating_executor.h" + +#include + +namespace mediapipe { +namespace internal { + +void DelegatingExecutor::Schedule(std::function task) { + callback_(std::move(task)); +} + +} // namespace internal +} // namespace mediapipe diff --git a/mediapipe/framework/delegating_executor.h b/mediapipe/framework/delegating_executor.h new file mode 100644 index 000000000..e0e4b4025 --- /dev/null +++ b/mediapipe/framework/delegating_executor.h @@ -0,0 +1,38 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_DELEGATING_EXECUTOR_H_ +#define MEDIAPIPE_FRAMEWORK_DELEGATING_EXECUTOR_H_ + +#include "mediapipe/framework/executor.h" + +namespace mediapipe { +namespace internal { + +// An executor that delegates the running of tasks using a callback. +class DelegatingExecutor : public Executor { + public: + explicit DelegatingExecutor( + std::function)> callback) + : callback_(std::move(callback)) {} + void Schedule(std::function task) override; + + private: + std::function)> callback_; +}; + +} // namespace internal +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_DELEGATING_EXECUTOR_H_ diff --git a/mediapipe/framework/demangle.h b/mediapipe/framework/demangle.h new file mode 100644 index 000000000..e9624c5ac --- /dev/null +++ b/mediapipe/framework/demangle.h @@ -0,0 +1,83 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_DEMANGLE_H_ +#define MEDIAPIPE_FRAMEWORK_DEMANGLE_H_ + +// We only support some compilers that support __cxa_demangle. +// TODO: Checks if Android NDK has fixed this issue or not. +#if defined(__ANDROID__) && (defined(__i386__) || defined(__x86_64__)) +#define HAS_CXA_DEMANGLE 0 +#elif (__GNUC__ >= 4 || (__GNUC__ >= 3 && __GNUC_MINOR__ >= 4)) && \ + !defined(__mips__) +#define HAS_CXA_DEMANGLE 1 +#elif defined(__clang__) && !defined(_MSC_VER) +#define HAS_CXA_DEMANGLE 1 +#else +#define HAS_CXA_DEMANGLE 0 +#endif + +#include + +#include +#if HAS_CXA_DEMANGLE +#include +#endif + +namespace mediapipe { + +// Demangle a mangled symbol name and return the demangled name. +// If 'mangled' isn't mangled in the first place, this function +// simply returns 'mangled' as is. +// +// This function is used for demangling mangled symbol names such as +// '_Z3bazifdPv'. It uses abi::__cxa_demangle() if your compiler has +// the API. Otherwise, this function simply returns 'mangled' as is. +// +// Currently, we support only GCC 3.4.x or later for the following +// reasons. +// +// - GCC 2.95.3 doesn't have cxxabi.h +// - GCC 3.3.5 and ICC 9.0 have a bug. Their abi::__cxa_demangle() +// returns junk values for non-mangled symbol names (ex. function +// names in C linkage). 
// For example,
//   abi::__cxa_demangle("main", 0, 0, &status)
// returns "unsigned long" and the status code is 0 (successful).
//
// Also,
//
// - MIPS is not supported because abi::__cxa_demangle() is not defined.
// - Android x86 is not supported because STLs don't define __cxa_demangle
//
// Prefer using MediaPipeTypeStringOrDemangled() when possible (defined
// in type_map.h).
#if 1  // NOTE(review): balances the include-guard opened earlier in the
       // header, outside this reconstructed span.
namespace mediapipe {  // reopened: the opener sits outside this span.

// Returns the demangled form of |mangled| when abi::__cxa_demangle is
// available and succeeds; otherwise returns |mangled| unchanged.
inline std::string Demangle(const char* mangled) {
  int status = 0;
  char* demangled = nullptr;
#if HAS_CXA_DEMANGLE
  demangled = abi::__cxa_demangle(mangled, nullptr, nullptr, &status);
#endif
  if (status != 0 || demangled == nullptr) {
    // Demangling failed or is unsupported on this toolchain; pass the
    // input through untouched.
    return std::string(mangled);
  }
  std::string out(demangled);
  // __cxa_demangle returns a malloc()'d buffer that the caller must free.
  free(demangled);
  return out;
}

}  // namespace mediapipe

#endif  // MEDIAPIPE_FRAMEWORK_DEMANGLE_H_

/* ---- Non-C++ residue preserved verbatim from the surrounding diff dump:
diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD
new file mode 100644
index 000000000..f3f66eaa8
--- /dev/null
+++ b/mediapipe/framework/deps/BUILD
@@ -0,0 +1,450 @@
+# Copyright 2019 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# The dependencies of mediapipe.
---- end residue ---- */
+ +licenses(["notice"]) # Apache 2.0 + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_py_proto_library") + +package(default_visibility = ["//visibility:private"]) + +proto_library( + name = "proto_descriptor_proto", + srcs = ["proto_descriptor.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], +) + +mediapipe_cc_proto_library( + name = "proto_descriptor_cc_proto", + srcs = ["proto_descriptor.proto"], + visibility = ["//visibility:public"], + deps = [":proto_descriptor_proto"], +) + +cc_library( + name = "aligned_malloc_and_free", + hdrs = ["aligned_malloc_and_free.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cleanup", + hdrs = ["cleanup.h"], + visibility = ["//visibility:public"], + deps = ["@com_google_absl//absl/base:core_headers"], +) + +cc_library( + name = "clock", + srcs = [ + "clock.cc", + "monotonic_clock.cc", + ], + hdrs = [ + "clock.h", + "monotonic_clock.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "message_matchers", + testonly = True, + hdrs = ["message_matchers.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_library( + name = "file_path", + srcs = ["file_path.cc"], + hdrs = ["file_path.h"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "file_helpers", + srcs = ["file_helpers.cc"], + hdrs = ["file_helpers.h"], + visibility = ["//visibility:public"], + deps = [ + ":file_path", + ":status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "intops", + hdrs = [ + "safe_int.h", + "strong_int.h", + ], + visibility 
= ["//visibility:public"], + deps = [ + "//mediapipe/framework/port", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "image_resizer", + hdrs = ["image_resizer.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:opencv_imgproc", + ], +) + +cc_library( + name = "map_util", + hdrs = ["map_util.h"], + # Use this library through "mediapipe/framework/port:map_util". + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = ["//mediapipe/framework/port:logging"], +) + +cc_library( + name = "mathutil", + hdrs = ["mathutil.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + ], +) + +cc_library( + name = "numbers", + hdrs = ["numbers.h"], + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + "//mediapipe/framework/port:integral_types", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "no_destructor", + hdrs = ["no_destructor.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "point", + hdrs = ["point2.h"], + # Use this library through "mediapipe/framework/port:point". + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + ":mathutil", + ":vector", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + ], +) + +cc_library( + name = "random", + hdrs = ["random_base.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "rectangle", + hdrs = ["rectangle.h"], + # Use this library through "mediapipe/framework/port:rectangle". 
+ visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + ":point", + ":vector", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + ], +) + +cc_library( + name = "registration_token", + srcs = ["registration_token.cc"], + hdrs = ["registration_token.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "registration", + srcs = ["registration.cc"], + hdrs = ["registration.h"], + visibility = ["//visibility:public"], + deps = [ + ":registration_token", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "singleton", + hdrs = ["singleton.h"], + # Use this library through "mediapipe/framework/port:singleton". + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "source_location", + hdrs = ["source_location.h"], + # Use this library through "mediapipe/framework/port:source_location". + visibility = ["//mediapipe/framework/port:__pkg__"], +) + +cc_library( + name = "status", + srcs = [ + "status.cc", + "status_builder.cc", + ], + hdrs = [ + "canonical_errors.h", + "status.h", + "status_builder.h", + "status_macros.h", + ], + # Use this library through "mediapipe/framework/port:status". + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + ":source_location", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "statusor", + srcs = ["statusor.cc"], + hdrs = [ + "statusor.h", + "statusor_internals.h", + ], + # Use this library through "mediapipe/framework/port:statusor". 
+ visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + ":status", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "status_matchers", + testonly = 1, + hdrs = ["status_matchers.h"], + # Use this library through "mediapipe/framework/port:gtest_main". + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + ":status", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "ret_check", + srcs = ["ret_check.cc"], + hdrs = ["ret_check.h"], + # Use this library through "mediapipe/framework/port:ret_check". + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + ":status", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "thread_options", + hdrs = ["thread_options.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "threadpool", + srcs = ["threadpool.cc"], + hdrs = ["threadpool.h"], + # Use this library through "mediapipe/framework/port:threadpool". + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + ":thread_options", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "topologicalsorter", + srcs = ["topologicalsorter.cc"], + hdrs = ["topologicalsorter.h"], + # Use this library through "mediapipe/framework/port:topologicalsorter". + visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + "//mediapipe/framework/port:logging", + ], +) + +cc_library( + name = "vector", + hdrs = ["vector.h"], + # Use this library through "mediapipe/framework/port:vector". 
+ visibility = ["//mediapipe/framework/port:__pkg__"], + deps = [ + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/utility", + ], +) + +cc_test( + name = "mathutil_unittest", + srcs = ["mathutil_unittest.cc"], + visibility = ["//visibility:public"], + deps = [ + ":mathutil", + "//mediapipe/framework/port:benchmark", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "registration_token_test", + srcs = ["registration_token_test.cc"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":registration_token", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "safe_int_test", + size = "small", + timeout = "long", + srcs = ["safe_int_test.cc"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":intops", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_test( + name = "monotonic_clock_test", + srcs = ["monotonic_clock_test.cc"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":clock", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:threadpool", + "//mediapipe/framework/tool:simulation_clock", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "status_builder_test", + size = "small", + srcs = ["status_builder_test.cc"], + linkstatic = 1, + deps = [ + ":status", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "status_test", + size = "small", + srcs = ["status_test.cc"], + linkstatic = 1, + deps = [ + ":status", + ":status_matchers", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "statusor_test", + size = "small", + srcs = ["statusor_test.cc"], + linkstatic = 1, 
+ deps = [ + ":status", + ":statusor", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "topologicalsorter_test", + srcs = ["topologicalsorter_test.cc"], + linkstatic = 1, + deps = [ + ":topologicalsorter", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "threadpool_test", + srcs = ["threadpool_test.cc"], + linkstatic = 1, + deps = [ + ":threadpool", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/synchronization", + ], +) diff --git a/mediapipe/framework/deps/aligned_malloc_and_free.h b/mediapipe/framework/deps/aligned_malloc_and_free.h new file mode 100644 index 000000000..94fd60a2e --- /dev/null +++ b/mediapipe/framework/deps/aligned_malloc_and_free.h @@ -0,0 +1,43 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_ALIGNED_MALLOC_AND_FREE_H_ +#define MEDIAPIPE_DEPS_ALIGNED_MALLOC_AND_FREE_H_ + +#include // for free(), aligned_alloc(), + +#if defined(__ANDROID__) +#include // for memalign() +#endif + +inline void *aligned_malloc(size_t size, int minimum_alignment) { +#if defined(__ANDROID__) || defined(OS_ANDROID) + return memalign(minimum_alignment, size); +#else // !__ANDROID__ && !OS_ANDROID + void *ptr = nullptr; + // posix_memalign requires that the requested alignment be at least + // sizeof(void*). In this case, fall back on malloc which should return memory + // aligned to at least the size of a pointer. 
+ const int required_alignment = sizeof(void *); + if (minimum_alignment < required_alignment) return malloc(size); + if (posix_memalign(&ptr, static_cast(minimum_alignment), size) != 0) + return nullptr; + else + return ptr; +#endif +} + +inline void aligned_free(void *aligned_memory) { free(aligned_memory); } + +#endif // MEDIAPIPE_DEPS_ALIGNED_MALLOC_AND_FREE_H_ diff --git a/mediapipe/framework/deps/canonical_errors.h b/mediapipe/framework/deps/canonical_errors.h new file mode 100644 index 000000000..b5a956da2 --- /dev/null +++ b/mediapipe/framework/deps/canonical_errors.h @@ -0,0 +1,86 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_CANONICAL_ERRORS_H_ +#define MEDIAPIPE_DEPS_CANONICAL_ERRORS_H_ + +#include "mediapipe/framework/deps/status.h" + +namespace mediapipe { + +// Each of the functions below creates a canonical error with the given +// message. The error code of the returned status object matches the name of +// the function. 
+inline ::mediapipe::Status AlreadyExistsError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kAlreadyExists, message); +} + +inline ::mediapipe::Status CancelledError() { + return ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, ""); +} + +inline ::mediapipe::Status CancelledError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, message); +} + +inline ::mediapipe::Status InternalError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kInternal, message); +} + +inline ::mediapipe::Status InvalidArgumentError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kInvalidArgument, + message); +} + +inline ::mediapipe::Status FailedPreconditionError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kFailedPrecondition, + message); +} + +inline ::mediapipe::Status NotFoundError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kNotFound, message); +} + +inline ::mediapipe::Status OutOfRangeError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kOutOfRange, message); +} + +inline ::mediapipe::Status PermissionDeniedError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kPermissionDenied, + message); +} + +inline ::mediapipe::Status UnimplementedError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kUnimplemented, message); +} + +inline ::mediapipe::Status UnknownError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kUnknown, message); +} + +inline ::mediapipe::Status UnavailableError(absl::string_view message) { + return ::mediapipe::Status(::mediapipe::StatusCode::kUnavailable, message); +} + +inline bool IsCancelled(const ::mediapipe::Status& status) { + return status.code() == ::mediapipe::StatusCode::kCancelled; +} + 
+inline bool IsNotFound(const ::mediapipe::Status& status) { + return status.code() == ::mediapipe::StatusCode::kNotFound; +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_CANONICAL_ERRORS_H_ diff --git a/mediapipe/framework/deps/cleanup.h b/mediapipe/framework/deps/cleanup.h new file mode 100644 index 000000000..141e71c6c --- /dev/null +++ b/mediapipe/framework/deps/cleanup.h @@ -0,0 +1,105 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its +// destructor. The easiest way to use MakeCleanup is with a lambda argument, +// capturing the return value in an 'auto' local variable. Most users will not +// need more sophisticated syntax than that. +// +// Example: +// void func() {} +// FILE* fp = fopen("data.txt", "r"); +// if (fp == nullptr) return; +// auto fp_cleaner = ::mediapipe::MakeCleanup([fp] { fclose(fp); }); +// // No matter what, fclose(fp) will happen. 
+// DataObject d; +// while (ReadDataObject(fp, &d)) { +// if (d.IsBad()) { +// LOG(ERROR) << "Bad Data"; +// return; +// } +// PushGoodData(d); +// } +// } + +#ifndef MEDIAPIPE_DEPS_CLEANUP_H_ +#define MEDIAPIPE_DEPS_CLEANUP_H_ + +#include +#include + +#include "absl/base/attributes.h" + +namespace mediapipe { + +template +class Cleanup { + public: + Cleanup() : released_(true), f_() {} + + template + explicit Cleanup(G&& f) // NOLINT + : f_(std::forward(f)) {} // NOLINT(build/c++11) + + Cleanup(Cleanup&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) {} + + // Implicitly move-constructible from any compatible Cleanup. + // The source will be released as if src.release() were called. + // A moved-from Cleanup can be safely destroyed or reassigned. + template + Cleanup(Cleanup&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) {} + + // Assignment to a Cleanup object behaves like destroying it + // and making a new one in its place, analogous to unique_ptr + // semantics. + Cleanup& operator=(Cleanup&& src) { // NOLINT + if (!released_) f_(); + released_ = src.released_; + f_ = src.release(); + return *this; + } + + ~Cleanup() { + if (!released_) f_(); + } + + // Releases the cleanup function instead of running it. + // Hint: use c.release()() to run early. 
+ F release() { + released_ = true; + return std::move(f_); + } + + bool is_released() const { return released_; } + + private: + static_assert(!std::is_reference::value, "F must not be a reference"); + + bool released_ = false; + F f_; +}; + +template ::type> +ABSL_MUST_USE_RESULT Cleanup MakeCleanup(F&& f) { + static_assert(sizeof...(ExplicitParameterBarrier) == 0, + "No explicit template arguments."); + return Cleanup(std::forward(f)); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_CLEANUP_H_ diff --git a/mediapipe/framework/deps/clock.cc b/mediapipe/framework/deps/clock.cc new file mode 100644 index 000000000..f68143862 --- /dev/null +++ b/mediapipe/framework/deps/clock.cc @@ -0,0 +1,55 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/clock.h" + +#include "absl/time/clock.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +namespace { + +// ----------------------------------------------------------------- +// RealTimeClock +// +// This class is thread-safe. 
+class RealTimeClock : public Clock { + public: + virtual ~RealTimeClock() { + LOG(FATAL) << "RealTimeClock should never be destroyed"; + } + + absl::Time TimeNow() override { return absl::Now(); } + + void Sleep(absl::Duration d) override { absl::SleepFor(d); } + + void SleepUntil(absl::Time wakeup_time) override { + absl::Duration d = wakeup_time - TimeNow(); + if (d > absl::ZeroDuration()) { + Sleep(d); + } + } +}; + +} // namespace + +Clock::~Clock() {} + +Clock* Clock::RealClock() { + static RealTimeClock* rtclock = new RealTimeClock; + return rtclock; +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/clock.h b/mediapipe/framework/deps/clock.h new file mode 100644 index 000000000..28d37b4df --- /dev/null +++ b/mediapipe/framework/deps/clock.h @@ -0,0 +1,68 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_CLOCK_H_ +#define MEDIAPIPE_DEPS_CLOCK_H_ + +#include "absl/time/time.h" + +namespace mediapipe { + +// An abstract interface representing a Clock, which is an object that can +// tell you the current time, and sleep. +// +// This interface allows decoupling code that uses time from the code that +// creates a point in time. You can use this to your advantage by injecting +// Clocks into interfaces rather than having implementations call absl::Now() +// directly. 
+// +// The Clock::RealClock() function returns a pointer (that you do not own) +// to the global realtime clock. +// +// Example: +// +// bool IsWeekend(Clock* clock) { +// absl::Time now = clock->TimeNow(); +// // ... code to check if 'now' is a weekend. +// } +// +// // Production code. +// IsWeekend(Clock::RealClock()); +// +// // Test code: +// MyTestClock test_clock(SATURDAY); +// IsWeekend(&test_clock); +// +class Clock { + public: + // Returns a pointer to the global realtime clock. The caller does not + // own the returned pointer and should not delete it. The returned clock + // is thread-safe. + static Clock* RealClock(); + + virtual ~Clock(); + + // Returns the current time. + virtual absl::Time TimeNow() = 0; + + // Sleeps for the specified duration. + virtual void Sleep(absl::Duration d) = 0; + + // Sleeps until the specified time. + virtual void SleepUntil(absl::Time wakeup_time) = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_CLOCK_H_ diff --git a/mediapipe/framework/deps/file_helpers.cc b/mediapipe/framework/deps/file_helpers.cc new file mode 100644 index 000000000..3c38182a5 --- /dev/null +++ b/mediapipe/framework/deps/file_helpers.cc @@ -0,0 +1,122 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/deps/file_helpers.h" + +#include +#include +#include + +#include + +#include "mediapipe/framework/deps/canonical_errors.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/deps/status_builder.h" + +namespace mediapipe { +namespace file { +::mediapipe::Status GetContents(absl::string_view file_name, + std::string* output) { + FILE* fp = fopen(file_name.data(), "r"); + if (fp == NULL) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Can't find file: " << file_name; + } + + output->clear(); + while (!feof(fp)) { + char buf[4096]; + size_t ret = fread(buf, 1, 4096, fp); + if (ret == 0 && ferror(fp)) { + return ::mediapipe::InternalErrorBuilder(MEDIAPIPE_LOC) + << "Error while reading file: " << file_name; + } + output->append(std::string(buf, ret)); + } + fclose(fp); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SetContents(absl::string_view file_name, + absl::string_view content) { + FILE* fp = fopen(file_name.data(), "w"); + if (fp == NULL) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Can't open file: " << file_name; + } + + fwrite(content.data(), sizeof(char), content.size(), fp); + size_t ret = fclose(fp); + if (ret == 0 && ferror(fp)) { + return ::mediapipe::InternalErrorBuilder(MEDIAPIPE_LOC) + << "Error while writing file: " << file_name; + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MatchInTopSubdirectories( + const std::string& parent_directory, const std::string& file_name, + std::vector* results) { + DIR* dir = opendir(parent_directory.c_str()); + CHECK(dir); + // Iterates through the parent direcotry. + while (true) { + struct dirent* dir_ent = readdir(dir); + if (dir_ent == nullptr) { + break; + } + if (std::string(dir_ent->d_name) == "." 
|| + std::string(dir_ent->d_name) == "..") { + continue; + } + std::string subpath = + JoinPath(parent_directory, std::string(dir_ent->d_name)); + DIR* sub_dir = opendir(subpath.c_str()); + // Iterates through the subdirecotry to find file matches. + while (true) { + struct dirent* dir_ent_2 = readdir(sub_dir); + if (dir_ent_2 == nullptr) { + break; + } + if (std::string(dir_ent_2->d_name) == "." || + std::string(dir_ent_2->d_name) == "..") { + continue; + } + if (absl::EndsWith(std::string(dir_ent_2->d_name), file_name)) { + results->push_back(JoinPath(subpath, std::string(dir_ent_2->d_name))); + } + } + closedir(sub_dir); + } + closedir(dir); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status Exists(absl::string_view file_name) { + struct stat buffer; + int status; + status = stat(file_name.data(), &buffer); + if (status == 0) { + return ::mediapipe::OkStatus(); + } + switch (errno) { + case EACCES: + return ::mediapipe::PermissionDeniedError("Insufficient permissions."); + default: + return ::mediapipe::NotFoundError("The path does not exist."); + } +} + +} // namespace file +} // namespace mediapipe diff --git a/mediapipe/framework/deps/file_helpers.h b/mediapipe/framework/deps/file_helpers.h new file mode 100644 index 000000000..516447d6b --- /dev/null +++ b/mediapipe/framework/deps/file_helpers.h @@ -0,0 +1,38 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_DEPS_FILE_HELPERS_H_ +#define MEDIAPIPE_DEPS_FILE_HELPERS_H_ + +#include "absl/strings/match.h" +#include "mediapipe/framework/deps/status.h" + +namespace mediapipe { +namespace file { +::mediapipe::Status GetContents(absl::string_view file_name, + std::string* output); + +::mediapipe::Status SetContents(absl::string_view file_name, + absl::string_view content); + +::mediapipe::Status MatchInTopSubdirectories( + const std::string& parent_directory, const std::string& file_name, + std::vector* results); + +::mediapipe::Status Exists(absl::string_view file_name); + +} // namespace file +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_FILE_HELPERS_H_ diff --git a/mediapipe/framework/deps/file_path.cc b/mediapipe/framework/deps/file_path.cc new file mode 100644 index 000000000..19ebbe500 --- /dev/null +++ b/mediapipe/framework/deps/file_path.cc @@ -0,0 +1,120 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/file_path.h" + +#include "absl/strings/str_cat.h" + +namespace mediapipe { +namespace file { + +// 40% of the time in JoinPath() is from calls with 2 arguments, so we +// specialize that case. 
+std::string JoinPath(absl::string_view path1, absl::string_view path2) { + if (path1.empty()) return std::string(path2); + if (path2.empty()) return std::string(path1); + if (path1.back() == '/') { + if (path2.front() == '/') + return absl::StrCat(path1, absl::ClippedSubstr(path2, 1)); + } else { + if (path2.front() != '/') return absl::StrCat(path1, "/", path2); + } + return absl::StrCat(path1, path2); +} + +namespace internal { + +// Given a collection of file paths, append them all together, +// ensuring that the proper path separators are inserted between them. +std::string JoinPathImpl(bool honor_abs, + std::initializer_list paths) { + std::string result; + + if (paths.size() != 0) { + // This size calculation is worst-case: it assumes one extra "/" for every + // path other than the first. + size_t total_size = paths.size() - 1; + for (const absl::string_view path : paths) total_size += path.size(); + result.resize(total_size); + + auto begin = result.begin(); + auto out = begin; + bool trailing_slash = false; + for (absl::string_view path : paths) { + if (path.empty()) continue; + if (path.front() == '/') { + if (honor_abs) { + out = begin; // wipe out whatever we've built up so far. + } else if (trailing_slash) { + path.remove_prefix(1); + } + } else { + if (!trailing_slash && out != begin) *out++ = '/'; + } + const size_t this_size = path.size(); + memcpy(&*out, path.data(), this_size); + out += this_size; + trailing_slash = out[-1] == '/'; + } + result.erase(out - begin); + } + return result; +} + +// Return the parts of the basename of path, split on the final ".". +// If there is no "." in the basename or "." is the final character in the +// basename, the second value will be empty. 
+std::pair SplitBasename( + absl::string_view path) { + path = Basename(path); + + absl::string_view::size_type pos = path.find_last_of('.'); + if (pos == absl::string_view::npos) + return std::make_pair(path, absl::ClippedSubstr(path, path.size(), 0)); + return std::make_pair(path.substr(0, pos), + absl::ClippedSubstr(path, pos + 1)); +} + +} // namespace internal + +absl::string_view Dirname(absl::string_view path) { + return SplitPath(path).first; +} + +absl::string_view Basename(absl::string_view path) { + return SplitPath(path).second; +} + +std::pair SplitPath( + absl::string_view path) { + absl::string_view::size_type pos = path.find_last_of('/'); + + // Handle the case with no '/' in 'path'. + if (pos == absl::string_view::npos) + return std::make_pair(path.substr(0, 0), path); + + // Handle the case with a single leading '/' in 'path'. + if (pos == 0) + return std::make_pair(path.substr(0, 1), absl::ClippedSubstr(path, 1)); + + return std::make_pair(path.substr(0, pos), + absl::ClippedSubstr(path, pos + 1)); +} + +absl::string_view Extension(absl::string_view path) { + return internal::SplitBasename(path).second; +} + +} // namespace file +} // namespace mediapipe diff --git a/mediapipe/framework/deps/file_path.h b/mediapipe/framework/deps/file_path.h new file mode 100644 index 000000000..4c1c15153 --- /dev/null +++ b/mediapipe/framework/deps/file_path.h @@ -0,0 +1,96 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_FILE_PATH_H_ +#define MEDIAPIPE_DEPS_FILE_PATH_H_ + +#include +#include + +#include "absl/strings/string_view.h" + +// A set of file pathname manipulation routines. +namespace mediapipe { +namespace file { +namespace internal { + +// Not part of the public API. +std::string JoinPathImpl(bool honor_abs, + std::initializer_list paths); + +} // namespace internal + +// Join multiple paths together. +// JoinPath and JoinPathRespectAbsolute have slightly different semantics. +// JoinPath unconditionally joins all paths together, whereas +// JoinPathRespectAbsolute ignores any segments prior to the last absolute +// path. For example: +// +// Arguments | JoinPath | JoinPathRespectAbsolute +// ---------------------------+---------------------+----------------------- +// '/foo', 'bar' | /foo/bar | /foo/bar +// '/foo/', 'bar' | /foo/bar | /foo/bar +// '/foo', '/bar' | /foo/bar | /bar +// '/foo', '/bar', '/baz' | /foo/bar/baz | /baz +// +// All paths will be treated as relative paths, regardless of whether or not +// they start with a leading '/'. That is, all paths will be concatenated +// together, with the appropriate path separator inserted in between. +// Arguments must be convertible to absl::string_view. +// +// Usage: +// std::string path = file::JoinPath("/cns", dirname, filename); +// std::string path = file::JoinPath("./", filename); +// +// 0, 1, 2-path specializations exist to optimize common cases. +inline std::string JoinPath() { return std::string(); } +inline std::string JoinPath(absl::string_view path) { + return std::string(path.data(), path.size()); +} +std::string JoinPath(absl::string_view path1, absl::string_view path2); +template +inline std::string JoinPath(absl::string_view path1, absl::string_view path2, + absl::string_view path3, const T&... 
args) { + return internal::JoinPathImpl(false, {path1, path2, path3, args...}); +} + +// Returns the part of the path before the final "/", EXCEPT: +// * If there is a single leading "/" in the path, the result will be the +// leading "/". +// * If there is no "/" in the path, the result is the empty prefix of the +// input std::string. +absl::string_view Dirname(absl::string_view path); + +// Return the parts of the path, split on the final "/". If there is no +// "/" in the path, the first part of the output is empty and the second +// is the input. If the only "/" in the path is the first character, it is +// the first part of the output. +std::pair SplitPath( + absl::string_view path); + +// Returns the part of the path after the final "/". If there is no +// "/" in the path, the result is the same as the input. +// Note that this function's behavior differs from the Unix basename +// command if path ends with "/". For such paths, this function returns the +// empty std::string. +absl::string_view Basename(absl::string_view path); + +// Returns the part of the basename of path after the final ".". If +// there is no "." in the basename, the result is empty. +absl::string_view Extension(absl::string_view path); + +} // namespace file +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_FILE_PATH_H_ diff --git a/mediapipe/framework/deps/image_resizer.h b/mediapipe/framework/deps/image_resizer.h new file mode 100644 index 000000000..6e1215a69 --- /dev/null +++ b/mediapipe/framework/deps/image_resizer.h @@ -0,0 +1,36 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_IMAGE_RESIZER_H_ +#define MEDIAPIPE_DEPS_IMAGE_RESIZER_H_ + +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" + +namespace mediapipe { + +class ImageResizer { + public: + ImageResizer(double sharpen_coeff) {} + + bool Resize(const cv::Mat& input_mat, cv::Mat* output_mat) { + cv::resize(input_mat, *output_mat, output_mat->size(), 0, 0, + cv::INTER_AREA); + return true; + } +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_IMAGE_RESIZER_H_ diff --git a/mediapipe/framework/deps/map_util.h b/mediapipe/framework/deps/map_util.h new file mode 100644 index 000000000..05d47b7e7 --- /dev/null +++ b/mediapipe/framework/deps/map_util.h @@ -0,0 +1,152 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file provides utility functions for use with STL map-like data +// structures, such as std::map and hash_map. Some functions will also work with +// sets, such as ContainsKey(). 
+ +#ifndef MEDIAPIPE_DEPS_MAP_UTIL_H_ +#define MEDIAPIPE_DEPS_MAP_UTIL_H_ + +#include + +#include +#include +#include +#include +#include + +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +// A note on terminology: `m` and `M` represent a map and its type. +// +// Returns a const reference to the value associated with the given key if it +// exists. Crashes otherwise. +// +// This is intended as a replacement for operator[] as an rvalue (for reading) +// when the key is guaranteed to exist. +// +// operator[] for lookup is discouraged for several reasons (note that these +// reasons may apply to only some map types): +// * It has a side-effect of inserting missing keys +// * It is not thread-safe (even when it is not inserting, it can still +// choose to resize the underlying storage) +// * It invalidates iterators (when it chooses to resize) +// * It default constructs a value object even if it doesn't need to +// +// This version assumes the key is printable, and includes it in the fatal log +// message. +template +const typename M::value_type::second_type& FindOrDie( + const M& m, const typename M::value_type::first_type& key) { + auto it = m.find(key); + CHECK(it != m.end()) << "Map key not found: " << key; + return it->second; +} + +// Same as above, but returns a non-const reference. +template +typename M::value_type::second_type& FindOrDie( + M& m, // NOLINT + const typename M::value_type::first_type& key) { + auto it = m.find(key); + CHECK(it != m.end()) << "Map key not found: " << key; + return it->second; +} + +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to the provided default value. +// +// WARNING: If a temporary object is passed as the default "value," +// this function will return a reference to that temporary object, +// which will be destroyed at the end of the statement. 
// A common
// example: if you have a map with std::string values, and you pass a char*
// as the default "value," either use the returned value immediately
// or store it in a std::string (not std::string&).
template <typename M>
const typename M::value_type::second_type& FindWithDefault(
    const M& m, const typename M::value_type::first_type& key,
    const typename M::value_type::second_type& value) {
  auto it = m.find(key);
  if (it != m.end()) {
    return it->second;
  }
  return value;
}

// Returns a pointer to the const value associated with the given key if it
// exists, or null otherwise.
template <typename M>
const typename M::value_type::second_type* FindOrNull(
    const M& m, const typename M::value_type::first_type& key) {
  auto it = m.find(key);
  if (it == m.end()) {
    return nullptr;
  }
  return &it->second;
}

// Returns a pointer to the non-const value associated with the given key if
// it exists, or null otherwise.
template <typename M>
typename M::value_type::second_type* FindOrNull(
    M& m,  // NOLINT
    const typename M::value_type::first_type& key) {
  auto it = m.find(key);
  if (it == m.end()) {
    return nullptr;
  }
  return &it->second;
}

// Returns true if and only if the given m contains the given key.
template <typename M, typename Key>
bool ContainsKey(const M& m, const Key& key) {
  return m.find(key) != m.end();
}

// Inserts the given key and value into the given m if and only if the
// given key did NOT already exist in the m. If the key previously
// existed in the m, the value is not changed. Returns true if the
// key-value pair was inserted; returns false if the key was already present.
template <typename M>
bool InsertIfNotPresent(M* m, const typename M::value_type& vt) {
  return m->insert(vt).second;
}

// Same as above except the key and value are passed separately.
+template +bool InsertIfNotPresent(M* m, const typename M::value_type::first_type& key, + const typename M::value_type::second_type& value) { + return InsertIfNotPresent(m, {key, value}); +} + +// Saves the reverse mapping into reverse. Returns true if values could all be +// inserted. +template +bool ReverseMap(const M& m, ReverseM* reverse) { + CHECK(reverse != nullptr); + for (const auto& kv : m) { + if (!InsertIfNotPresent(reverse, kv.second, kv.first)) { + return false; + } + } + return true; +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_MAP_UTIL_H_ diff --git a/mediapipe/framework/deps/mathutil.h b/mediapipe/framework/deps/mathutil.h new file mode 100644 index 000000000..315b78c42 --- /dev/null +++ b/mediapipe/framework/deps/mathutil.h @@ -0,0 +1,406 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This class is intended to contain a collection of useful (static) +// mathematical functions, properly coded (by consulting numerical +// recipes or another authoritative source first). 
+ +#ifndef MEDIAPIPE_DEPS_MATHUTIL_H_ +#define MEDIAPIPE_DEPS_MATHUTIL_H_ + +#include +#include +#include + +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +// ========================================================================= // + +class MathUtil { + public: + // -------------------------------------------------------------------- + // Round + // This function rounds a floating-point number to an integer. It + // works for positive or negative numbers. + // + // Values that are halfway between two integers may be rounded up or + // down, for example Round(0.5) == 0 and Round(1.5) == 2. + // This allows the function to be implemented efficiently on multiple + // hardware platforms (see the template specializations at the bottom + // of this file). You should not use this function if you care about which + // way such half-integers are rounded. + // + // Example usage: + // double y, z; + // int x = Round(y + 3.7); + // int64 b = Round(0.3 * z); + // + // Note that the floating-point template parameter is typically inferred + // from the argument type, i.e. there is no need to specify it explicitly. + // -------------------------------------------------------------------- + template + static IntOut Round(FloatIn x) { + static_assert(!std::numeric_limits::is_integer, + "FloatIn is integer"); + static_assert(std::numeric_limits::is_integer, + "IntOut is not integer"); + + // We don't use sgn(x) below because there is no need to distinguish the + // (x == 0) case. Also note that there are specialized faster versions + // of this function for Intel, ARM and PPC processors at the bottom + // of this file. + if (x > -0.5 && x < 0.5) { + // This case is special, because for largest floating point number + // below 0.5, the addition of 0.5 yields 1 and this would lead + // to incorrect result. + return static_cast(0); + } + return static_cast(x < 0 ? 
(x - 0.5) : (x + 0.5)); + } + + // Convert a floating-point number to an integer. For all inputs x where + // static_cast(x) is legal according to the C++ standard, the result + // is identical to that cast (i.e. the result is x with its fractional part + // truncated whenever that is representable as IntOut). + // + // static_cast would cause undefined behavior for the following cases, which + // have well-defined behavior for this function: + // + // 1. If x is NaN, the result is zero. + // + // 2. If the truncated form of x is above the representable range of IntOut, + // the result is std::numeric_limits::max(). + // + // 3. If the truncated form of x is below the representable range of IntOut, + // the result is std::numeric_limits::min(). + // + // Note that cases #2 and #3 cover infinities as well as finite numbers. + // + // The range of FloatIn must include the range of IntOut, otherwise + // the results are undefined. + template + static IntOut SafeCast(FloatIn x) { + static_assert(!std::numeric_limits::is_integer, + "FloatIn is integer"); + static_assert(std::numeric_limits::is_integer, + "IntOut is not integer"); + static_assert(std::numeric_limits::radix == 2, "IntOut is base 2"); + + // Special case NaN, for which the logic below doesn't work. + if (std::isnan(x)) { + return 0; + } + + // Negative values all clip to zero for unsigned results. + if (!std::numeric_limits::is_signed && x < 0) { + return 0; + } + + // Handle infinities. + if (std::isinf(x)) { + return x < 0 ? std::numeric_limits::min() + : std::numeric_limits::max(); + } + + // Set exp such that x == f * 2^exp for some f with |f| in [0.5, 1.0), + // unless x is zero in which case exp == 0. Note that this implies that the + // magnitude of x is strictly less than 2^exp. + int exp = 0; + std::frexp(x, &exp); + + // Let N be the number of non-sign bits in the representation of IntOut. 
If + // the magnitude of x is strictly less than 2^N, the truncated version of x + // is representable as IntOut. The only representable integer for which this + // is not the case is std::numeric_limits::min() for signed types (i.e. + // -2^N), but that is covered by the fall-through below. + if (exp <= std::numeric_limits::digits) { + return x; + } + + // Handle numbers with magnitude >= 2^N. + return x < 0 ? std::numeric_limits::min() + : std::numeric_limits::max(); + } + + // -------------------------------------------------------------------- + // SafeRound + // These functions round a floating-point number to an integer. + // Results are identical to Round, except in cases where + // the argument is NaN, or when the rounded value would overflow the + // return type. In those cases, Round has undefined + // behavior. SafeRound returns 0 when the argument is + // NaN, and returns the closest possible integer value otherwise (i.e. + // std::numeric_limits::max() for large positive values, and + // std::numeric_limits::min() for large negative values). + // The range of FloatIn must include the range of IntOut, otherwise + // the results are undefined. + // -------------------------------------------------------------------- + template + static IntOut SafeRound(FloatIn x) { + static_assert(!std::numeric_limits::is_integer, + "FloatIn is integer"); + static_assert(std::numeric_limits::is_integer, + "IntOut is not integer"); + + if (std::isnan(x)) { + return 0; + } else { + return SafeCast((x < 0.) ? (x - 0.5) : (x + 0.5)); + } + } + + // -------------------------------------------------------------------- + // FastIntRound, FastInt64Round + // Fast routines for converting floating-point numbers to integers. + // + // These routines are approximately 6 times faster than the default + // implementation of Round on Intel processors (12 times faster on + // the Pentium 3). 
They are also more than 5 times faster than simply + // casting a "double" to an "int" using static_cast. This is + // because casts are defined to truncate towards zero, which on Intel + // processors requires changing the rounding mode and flushing the + // floating-point pipeline (unless programs are compiled specifically + // for the Pentium 4, which has a new instruction to avoid this). + // + // Numbers that are halfway between two integers may be rounded up or + // down. This is because the conversion is done using the default + // rounding mode, which rounds towards the closest even number in case + // of ties. So for example, FastIntRound(0.5) == 0, but + // FastIntRound(1.5) == 2. These functions should only be used with + // applications that don't care about which way such half-integers are + // rounded. + // + // There are template specializations of Round() which call these + // functions (for "int" and "int64" only), but it's safer to call them + // directly. + // -------------------------------------------------------------------- + + static int32 FastIntRound(double x) { +#if defined __GNUC__ && (defined __i386__ || defined __SSE2__ || \ + defined __aarch64__ || defined __powerpc64__) +#if defined __AVX__ + // AVX. + int32 result; + __asm__ __volatile__( + "vcvtsd2si %1, %0" + : "=r"(result) // Output operand is a register + : "xm"(x)); // Input operand is an xmm register or memory + return result; +#elif defined __SSE2__ + // SSE2. + int32 result; + __asm__ __volatile__( + "cvtsd2si %1, %0" + : "=r"(result) // Output operand is a register + : "xm"(x)); // Input operand is an xmm register or memory + return result; +#elif defined __i386__ + // FPU stack. Adapted from /usr/include/bits/mathinline.h. 
+ int32 result; + __asm__ __volatile__("fistpl %0" + : "=m"(result) // Output operand is a memory location + : "t"(x) // Input operand is top of FP stack + : "st"); // Clobbers (pops) top of FP stack + return result; +#elif defined __aarch64__ + int64 result; + __asm__ __volatile__("fcvtns %d0, %d1" + : "=w"(result) // Vector floating point register + : "w"(x) // Vector floating point register + : /* No clobbers */); + return static_cast(result); +#elif defined __powerpc64__ + int64 result; + __asm__ __volatile__("fctid %0, %1" + : "=d"(result) + : "d"(x) + : /* No clobbers */); + return result; +#endif // defined __powerpc64__ +#else + return Round(x); +#endif // if defined __GNUC__ && ... + } + + static int32 FastIntRound(float x) { +#if defined __GNUC__ && (defined __i386__ || defined __SSE2__ || \ + defined __aarch64__ || defined __powerpc64__) +#if defined __AVX__ + // AVX. + int32 result; + __asm__ __volatile__( + "vcvtss2si %1, %0" + : "=r"(result) // Output operand is a register + : "xm"(x)); // Input operand is an xmm register or memory + return result; +#elif defined __SSE2__ + // SSE2. + int32 result; + __asm__ __volatile__( + "cvtss2si %1, %0" + : "=r"(result) // Output operand is a register + : "xm"(x)); // Input operand is an xmm register or memory + return result; +#elif defined __i386__ + // FPU stack. Adapted from /usr/include/bits/mathinline.h. 
+ int32 result; + __asm__ __volatile__("fistpl %0" + : "=m"(result) // Output operand is a memory location + : "t"(x) // Input operand is top of FP stack + : "st"); // Clobbers (pops) top of FP stack + return result; +#elif defined __aarch64__ + int64 result; + __asm__ __volatile__("fcvtns %s0, %s1" + : "=w"(result) // Vector floating point register + : "w"(x) // Vector floating point register + : /* No clobbers */); + return static_cast(result); +#elif defined __powerpc64__ + uint64 output; + __asm__ __volatile__("fctiw %0, %1" + : "=d"(output) + : "f"(x) + : /* No clobbers */); + return bit_cast(static_cast(output >> 32)); +#endif // defined __powerpc64__ +#else + return Round(x); +#endif // if defined __GNUC__ && ... + } + + static int64 FastInt64Round(double x) { +#if defined __GNUC__ && (defined __i386__ || defined __x86_64__ || \ + defined __aarch64__ || defined __powerpc64__) +#if defined __AVX__ + // AVX. + int64 result; + __asm__ __volatile__( + "vcvtsd2si %1, %0" + : "=r"(result) // Output operand is a register + : "xm"(x)); // Input operand is an xmm register or memory + return result; +#elif defined __x86_64__ + // SSE2. + int64 result; + __asm__ __volatile__( + "cvtsd2si %1, %0" + : "=r"(result) // Output operand is a register + : "xm"(x)); // Input operand is an xmm register or memory + return result; +#elif defined __i386__ + // There is no CVTSD2SI in i386 to produce a 64 bit int, even with SSE2. + // FPU stack. Adapted from /usr/include/bits/mathinline.h. + int64 result; + __asm__ __volatile__("fistpll %0" + : "=m"(result) // Output operand is a memory location + : "t"(x) // Input operand is top of FP stack + : "st"); // Clobbers (pops) top of FP stack + return result; +#elif defined __aarch64__ + // Floating-point convert to signed integer, + // rounding to nearest with ties to even. 
+ int64 result; + __asm__ __volatile__("fcvtns %d0, %d1" + : "=w"(result) + : "w"(x) + : /* No clobbers */); + return result; +#elif defined __powerpc64__ + int64 result; + __asm__ __volatile__("fctid %0, %1" + : "=d"(result) + : "d"(x) + : /* No clobbers */); + return result; +#endif // if defined __powerpc64__ +#else + return Round(x); +#endif // if defined __GNUC__ && ... + } + + static int64 FastInt64Round(float x) { + return FastInt64Round(static_cast(x)); + } + + static int32 FastIntRound(long double x) { return Round(x); } + + static int64 FastInt64Round(long double x) { return Round(x); } + + // Absolute value of the difference between two numbers. + // Works correctly for signed types and special floating point values. + template + static typename std::make_unsigned::type AbsDiff(const T x, const T y) { + // Carries out arithmetic as unsigned to avoid overflow. + typedef typename std::make_unsigned::type R; + return x > y ? R(x) - R(y) : R(y) - R(x); + } + + // Clamps value to the range [low, high]. Requires low <= high. + template // T models LessThanComparable. + static const T& Clamp(const T& low, const T& high, const T& value) { + // Prevents errors in ordering the arguments. + DCHECK(!(high < low)); + if (high < value) return high; + if (value < low) return low; + return value; + } + + // If two (usually floating point) numbers are within a certain + // absolute margin of error. + template + static bool WithinMargin(const T x, const T y, const T margin) { + DCHECK_GE(margin, 0); + return (std::abs(x) <= std::abs(y) + margin) && + (std::abs(x) >= std::abs(y) - margin); + } +}; + +// ========================================================================= // + +#if defined __GNUC__ && (defined __i386__ || defined __x86_64__ || \ + defined __aarch64__ || defined __powerpc64__) + +// We define template specializations of Round() to get the more efficient +// Intel versions when possible. 
Note that gcc does not currently support +// partial specialization of templatized functions. + +template <> +inline int32 MathUtil::Round(float x) { + return FastIntRound(x); +} + +template <> +inline int32 MathUtil::Round(double x) { + return FastIntRound(x); +} + +template <> +inline int64 MathUtil::Round(float x) { + return FastInt64Round(x); +} + +template <> +inline int64 MathUtil::Round(double x) { + return FastInt64Round(x); +} + +#endif + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_MATHUTIL_H_ diff --git a/mediapipe/framework/deps/mathutil_unittest.cc b/mediapipe/framework/deps/mathutil_unittest.cc new file mode 100644 index 000000000..640e75c6e --- /dev/null +++ b/mediapipe/framework/deps/mathutil_unittest.cc @@ -0,0 +1,879 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test functions in MathUtil. 
+ +#include "mediapipe/framework/deps/mathutil.h" + +#include + +#include +#include +#include +#include + +#include "mediapipe/framework/port/benchmark.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace { + +TEST(MathUtil, Round) { + // test float rounding + EXPECT_EQ(mediapipe::MathUtil::FastIntRound(0.7f), 1); + EXPECT_EQ(mediapipe::MathUtil::FastIntRound(5.7f), 6); + EXPECT_EQ(mediapipe::MathUtil::FastIntRound(6.3f), 6); + EXPECT_EQ(mediapipe::MathUtil::FastIntRound(1000000.7f), 1000001); + + // test that largest representable number below 0.5 rounds to zero. + // this is important because naive implementation of round: + // static_cast(r + 0.5f) is 1 due to implicit rounding in operator+ + float rf = std::nextafter(0.5f, .0f); + EXPECT_LT(rf, 0.5f); + EXPECT_EQ(mediapipe::MathUtil::Round(rf), 0); + + // same test for double + double rd = std::nextafter(0.5, 0.0); + EXPECT_LT(rd, 0.5); + EXPECT_EQ(mediapipe::MathUtil::Round(rd), 0); + + // same test for long double + long double rl = std::nextafter(0.5l, 0.0l); + EXPECT_LT(rl, 0.5l); + EXPECT_EQ(mediapipe::MathUtil::Round(rl), 0); +} + +static void BM_IntCast(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += static_cast(x); + x += 0.1; + sum += static_cast(x); + x += 0.1; + sum += static_cast(x); + x += 0.1; + sum += static_cast(x); + x += 0.1; + sum += static_cast(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. +} +BENCHMARK(BM_IntCast); + +static void BM_Int64Cast(benchmark::State& state) { + double x = 0.1; + int64 sum = 0; + for (auto _ : state) { + sum += static_cast(x); + x += 0.1; + sum += static_cast(x); + x += 0.1; + sum += static_cast(x); + x += 0.1; + sum += static_cast(x); + x += 0.1; + sum += static_cast(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. 
+} +BENCHMARK(BM_Int64Cast); + +static void BM_IntRound(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. +} +BENCHMARK(BM_IntRound); + +static void BM_FastIntRound(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += mediapipe::MathUtil::FastIntRound(x); + x += 0.1; + sum += mediapipe::MathUtil::FastIntRound(x); + x += 0.1; + sum += mediapipe::MathUtil::FastIntRound(x); + x += 0.1; + sum += mediapipe::MathUtil::FastIntRound(x); + x += 0.1; + sum += mediapipe::MathUtil::FastIntRound(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. +} +BENCHMARK(BM_FastIntRound); + +static void BM_Int64Round(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. +} +BENCHMARK(BM_Int64Round); + +static void BM_UintRound(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + sum += mediapipe::MathUtil::Round(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. 
+} +BENCHMARK(BM_UintRound); + +static void BM_SafeIntCast(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. +} +BENCHMARK(BM_SafeIntCast); + +static void BM_SafeInt64Cast(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeCast(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. +} +BENCHMARK(BM_SafeInt64Cast); + +static void BM_SafeIntRound(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. +} +BENCHMARK(BM_SafeIntRound); + +static void BM_SafeInt64Round(benchmark::State& state) { + double x = 0.1; + int sum = 0; + for (auto _ : state) { + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + sum += mediapipe::MathUtil::SafeRound(x); + x += 0.1; + } + EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. 
+} +BENCHMARK(BM_SafeInt64Round); + +TEST(MathUtil, IntRound) { + EXPECT_EQ(mediapipe::MathUtil::Round(0.0), 0); + EXPECT_EQ(mediapipe::MathUtil::Round(0.49), 0); + EXPECT_EQ(mediapipe::MathUtil::Round(1.49), 1); + EXPECT_EQ(mediapipe::MathUtil::Round(-0.49), 0); + EXPECT_EQ(mediapipe::MathUtil::Round(-1.49), -1); + + // Either adjacent integer is an acceptable result. + EXPECT_EQ(fabs(mediapipe::MathUtil::Round(0.5) - 0.5), 0.5); + EXPECT_EQ(fabs(mediapipe::MathUtil::Round(1.5) - 1.5), 0.5); + EXPECT_EQ(fabs(mediapipe::MathUtil::Round(-0.5) + 0.5), 0.5); + EXPECT_EQ(fabs(mediapipe::MathUtil::Round(-1.5) + 1.5), 0.5); + + EXPECT_EQ(mediapipe::MathUtil::Round(static_cast(0x76543210)), + 0x76543210); + + // A double-precision number has a 53-bit mantissa (52 fraction bits), + // so the following value can be represented exactly. + int64 value64 = GG_ULONGLONG(0x1234567890abcd00); + EXPECT_EQ(mediapipe::MathUtil::Round(static_cast(value64)), + value64); +} + +template +F NextAfter(F x, F y); + +template <> +float NextAfter(float x, float y) { + return nextafterf(x, y); +} + +template <> +double NextAfter(double x, double y) { + return nextafter(x, y); +} + +template +class SafeCastTester { + public: + static void Run() { + const IntOut imax = std::numeric_limits::max(); + EXPECT_GT(imax, 0); + const IntOut imin = std::numeric_limits::min(); + const bool s = std::numeric_limits::is_signed; + if (s) { + EXPECT_LT(imin, 0); + } else { + EXPECT_EQ(0, imin); + } + + // Some basic tests. 
+ EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(0.0)), + 0); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(-0.0)), + 0); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(0.99)), + 0); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(1.0)), + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(1.01)), + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(1.99)), + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(2.0)), + 2); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(2.01)), + 2); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(-0.99)), 0); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(-1.0)), + s ? -1 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(-1.01)), + s ? -1 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(-1.99)), + s ? -1 : 0); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(-2.0)), + s ? -2 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(-2.01)), + s ? -2 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(117.9)), + 117); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(118.0)), + 118); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(118.1)), + 118); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(-117.9)), + s ? -117 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(-118.0)), + s ? -118 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(-118.1)), + s ? -118 : 0); + + // Some edge cases. 
+ EXPECT_EQ(mediapipe::MathUtil::SafeCast( + std::numeric_limits::max()), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + -std::numeric_limits::max()), + imin); + const FloatIn inf_val = std::numeric_limits::infinity(); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(inf_val), imax); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(-inf_val), imin); + const FloatIn nan_val = inf_val - inf_val; + EXPECT_TRUE(std::isnan(nan_val)); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(nan_val), 0); + + // Some larger numbers. + if (sizeof(IntOut) >= 32) { + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(0x76543210)), + 0x76543210); + } + + if (sizeof(FloatIn) >= 64) { + // A double-precision number has a 53-bit mantissa (52 fraction bits), + // so the following value can be represented exactly by a double. + int64 value64 = GG_ULONGLONG(0x1234567890abcd00); + const IntOut expected = + (sizeof(IntOut) >= 64) ? static_cast(value64) : imax; + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(value64)), + expected); + } + + // Check values near imin and imax + static const int kLoopCount = 10; + + { + // Values greater than or equal to imax should convert to imax + FloatIn v = static_cast(imax); + for (int i = 0; i < kLoopCount; i++) { + EXPECT_EQ(mediapipe::MathUtil::SafeCast(v), imax); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(v + 10000.)), + imax); + v = NextAfter(v, std::numeric_limits::max()); + } + } + + { + // Values less than or equal to imin should convert to imin + FloatIn v = static_cast(imin); + for (int i = 0; i < kLoopCount; i++) { + EXPECT_EQ(mediapipe::MathUtil::SafeCast(v), imin); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(v - 10000.)), + imin); + v = NextAfter(v, -std::numeric_limits::max()); + } + } + + { + // Values slightly less than imax which can be exactly represented as a + // FloatIn should convert exactly to themselves. 
+ IntOut v = imax; + for (int i = 0; i < kLoopCount; i++) { + v = std::min(v - 1, + NextAfter(static_cast(v), + -std::numeric_limits::max())); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(v)), v); + } + } + + { + // Values slightly greater than imin which can be exactly represented as a + // FloatIn should convert exactly to themselves. + IntOut v = imin; + for (int i = 0; i < kLoopCount; i++) { + v = std::max(v + 1, + NextAfter(static_cast(v), + std::numeric_limits::max())); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(v)), v); + } + } + + // When FloatIn is wider than IntOut, we can test that fractional conversion + // near imax works as expected. + if (sizeof(FloatIn) > sizeof(IntOut)) { + { + // Values slightly less than imax should convert to imax - 1 + FloatIn v = static_cast(imax); + for (int i = 0; i < kLoopCount; i++) { + v = NextAfter(static_cast(v), + -std::numeric_limits::max()); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(v)), + imax - 1); + } + } + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) + 0.1)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) + 0.99)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) + 1.0)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) + 1.99)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) + 2.0)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) - 0.1)), + imax - 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) - 0.99)), + imax - 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) - 1.0)), + imax - 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) - 1.01)), + imax - 2); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) - 1.99)), + imax - 2); + 
EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) - 2.0)), + imax - 2); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imax) - 2.01)), + imax - 3); + } + // When FloatIn is wider than IntOut, and IntOut is signed, we can test + // that fractional conversion near imin works as expected. + if (s && (sizeof(FloatIn) > sizeof(IntOut))) { + { + // Values just over imin should convert to imin + 1 + FloatIn v = static_cast(imin); + for (int i = 0; i < kLoopCount; i++) { + v = NextAfter(static_cast(v), + std::numeric_limits::max()); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(v)), + imin + 1); + } + } + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) - 0.1)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) - 0.99)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) - 1.0)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) - 0.99)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) - 2.0)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) + 0.1)), + imin + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) + 0.99)), + imin + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) + 1.0)), + imin + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) + 1.01)), + imin + 2); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) + 1.99)), + imin + 2); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) + 2.0)), + imin + 2); + EXPECT_EQ(mediapipe::MathUtil::SafeCast( + static_cast(static_cast(imin) + 2.01)), + imin + 3); + } + } +}; + +TEST(MathUtil, SafeCast) { + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + 
SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + + // Spot-check SafeCast + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(12345.678)), + 12345); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(12345.4321)), + 12345); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(-12345.678)), + -12345); + EXPECT_EQ( + mediapipe::MathUtil::SafeCast(static_cast(-12345.4321)), + -12345); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(1E47), 2147483647); + EXPECT_EQ(mediapipe::MathUtil::SafeCast(-1E47), + GG_LONGLONG(-2147483648)); +} + +template +class SafeRoundTester { + public: + static void Run() { + const IntOut imax = std::numeric_limits::max(); + EXPECT_GT(imax, 0); + const IntOut imin = std::numeric_limits::min(); + const bool s = std::numeric_limits::is_signed; + if (s) { + EXPECT_LT(imin, 0); + } else { + EXPECT_EQ(0, imin); + } + + // Some basic tests. + EXPECT_EQ(mediapipe::MathUtil::SafeRound(static_cast(0.0)), + 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-0.0)), 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(0.49)), 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(0.51)), 1); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(1.49)), 1); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(1.51)), 2); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-0.49)), 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-0.51)), + s ? -1 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-1.49)), + s ? -1 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-1.51)), + s ? 
-2 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(117.4)), + 117); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(117.6)), + 118); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-117.4)), + s ? -117 : 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-117.6)), + s ? -118 : 0); + + // At the midpoint between ints, either adjacent int is an acceptable + // result. + EXPECT_EQ( + fabs(mediapipe::MathUtil::SafeRound(static_cast(0.5)) - + 0.5), + 0.5); + EXPECT_EQ( + fabs(mediapipe::MathUtil::SafeRound(static_cast(1.5)) - + 1.5), + 0.5); + EXPECT_EQ(fabs(mediapipe::MathUtil::SafeRound( + static_cast(117.5)) - + 117.5), + 0.5); + if (s) { + EXPECT_EQ(fabs(mediapipe::MathUtil::SafeRound( + static_cast(-0.5)) + + 0.5), + 0.5); + EXPECT_EQ(fabs(mediapipe::MathUtil::SafeRound( + static_cast(-1.5)) + + 1.5), + 0.5); + EXPECT_EQ(fabs(mediapipe::MathUtil::SafeRound( + static_cast(-117.5)) + + 117.5), + 0.5); + } else { + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-0.5)), + 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-1.5)), + 0); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-117.5)), + 0); + } + + // Some edge cases. + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + std::numeric_limits::max()), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + -std::numeric_limits::max()), + imin); + const FloatIn inf_val = std::numeric_limits::infinity(); + EXPECT_EQ(mediapipe::MathUtil::SafeRound(inf_val), imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound(-inf_val), imin); + const FloatIn nan_val = inf_val - inf_val; + EXPECT_TRUE(std::isnan(nan_val)); + EXPECT_EQ(mediapipe::MathUtil::SafeRound(nan_val), 0); + + // Some larger numbers. 
+ if (sizeof(IntOut) >= 32) { + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(0x76543210)), + 0x76543210); + } + + if (sizeof(FloatIn) >= 64) { + // A double-precision number has a 53-bit mantissa (52 fraction bits), + // so the following value can be represented exactly by a double. + int64 value64 = GG_ULONGLONG(0x1234567890abcd00); + const IntOut expected = + (sizeof(IntOut) >= 64) ? static_cast(value64) : imax; + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(value64)), + expected); + } + + // Check values near imin and imax + static const int kLoopCount = 10; + + { + // Values greater than or equal to imax should round to imax + FloatIn v = static_cast(imax); + for (int i = 0; i < kLoopCount; i++) { + EXPECT_EQ(mediapipe::MathUtil::SafeRound(v), imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(v + 10000.)), + imax); + v = NextAfter(v, std::numeric_limits::max()); + } + } + + { + // Values less than or equal to imin should round to imin + FloatIn v = static_cast(imin); + for (int i = 0; i < kLoopCount; i++) { + EXPECT_EQ(mediapipe::MathUtil::SafeRound(v), imin); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(v - 10000.)), + imin); + v = NextAfter(v, -std::numeric_limits::max()); + } + } + + { + // Values slightly less than imax which can be exactly represented as a + // FloatIn should round exactly to themselves. + IntOut v = imax; + for (int i = 0; i < kLoopCount; i++) { + v = std::min(v - 1, + NextAfter(static_cast(v), + -std::numeric_limits::max())); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(v)), v); + } + } + + { + // Values slightly greater than imin which can be exactly represented as a + // FloatIn should round exactly to themselves. 
+ IntOut v = imin; + for (int i = 0; i < kLoopCount; i++) { + v = std::max(v + 1, + NextAfter(static_cast(v), + std::numeric_limits::max())); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(v)), v); + } + } + + // When FloatIn is wider than IntOut, we can test that fractional rounding + // near imax works as expected. + if (sizeof(FloatIn) > sizeof(IntOut)) { + { + // Values slightly less than imax should round to imax + FloatIn v = static_cast(imax); + for (int i = 0; i < kLoopCount; i++) { + v = NextAfter(static_cast(v), + -std::numeric_limits::max()); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(v)), + imax); + } + } + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) + 0.1)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) + 0.49)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) + 0.5)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) + 0.51)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) + 0.99)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) - 0.1)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) - 0.49)), + imax); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) - 0.51)), + imax - 1); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) - 0.99)), + imax - 1); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) - 1.49)), + imax - 1); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imax) - 1.51)), + imax - 2); + } + // When FloatIn is wider than IntOut, or if IntOut is unsigned, we can test + // that fractional rounding near imin works as expected. 
+ if (!s || (sizeof(FloatIn) > sizeof(IntOut))) { + { + // Values slightly greater than imin should round to imin + FloatIn v = static_cast(imin); + for (int i = 0; i < kLoopCount; i++) { + v = NextAfter(static_cast(v), + std::numeric_limits::max()); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(v)), + imin); + } + } + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) - 0.1)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) - 0.49)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) - 0.5)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) - 0.51)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) - 0.99)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) + 0.1)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) + 0.49)), + imin); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) + 0.51)), + imin + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) + 0.99)), + imin + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) + 1.49)), + imin + 1); + EXPECT_EQ(mediapipe::MathUtil::SafeRound( + static_cast(static_cast(imin) + 1.51)), + imin + 2); + } + } +}; + +TEST(MathUtil, SafeRound) { + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + + // Spot-check SafeRound + EXPECT_EQ(mediapipe::MathUtil::SafeRound(static_cast(12345.678)), + 12346); + 
EXPECT_EQ(mediapipe::MathUtil::SafeRound(static_cast(12345.4321)), + 12345); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-12345.678)), + -12346); + EXPECT_EQ( + mediapipe::MathUtil::SafeRound(static_cast(-12345.4321)), + -12345); + EXPECT_EQ(mediapipe::MathUtil::SafeRound(1E47), 2147483647); + EXPECT_EQ(mediapipe::MathUtil::SafeRound(-1E47), + GG_LONGLONG(-2147483648)); +} + +} // namespace diff --git a/mediapipe/framework/deps/message_matchers.h b/mediapipe/framework/deps/message_matchers.h new file mode 100644 index 000000000..7ffcbca1d --- /dev/null +++ b/mediapipe/framework/deps/message_matchers.h @@ -0,0 +1,64 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_ +#define MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_ + +#include "mediapipe/framework/port/core_proto_inc.h" +#include "mediapipe/framework/port/gmock.h" + +namespace mediapipe { + +namespace internal { +bool EqualsMessage(const proto_ns::MessageLite& m_1, + const proto_ns::MessageLite& m_2) { + std::string s_1, s_2; + m_1.SerializeToString(&s_1); + m_2.SerializeToString(&s_2); + return s_1 == s_2; +} +} // namespace internal + +template +class ProtoMatcher : public testing::MatcherInterface { + using MatchResultListener = testing::MatchResultListener; + + public: + explicit ProtoMatcher(const MessageType& message) : message_(message) {} + virtual bool MatchAndExplain(MessageType m, MatchResultListener*) const { + return internal::EqualsMessage(message_, m); + } + + virtual void DescribeTo(::std::ostream* os) const { +#if defined(MEDIAPIPE_PROTO_LITE) + *os << "Protobuf messages have identical serializations."; +#else + *os << message_.DebugString(); +#endif + } + + private: + const MessageType message_; +}; + +template +inline testing::PolymorphicMatcher> EqualsProto( + const MessageType& message) { + return testing::PolymorphicMatcher>( + ProtoMatcher(message)); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_ diff --git a/mediapipe/framework/deps/monotonic_clock.cc b/mediapipe/framework/deps/monotonic_clock.cc new file mode 100644 index 000000000..0d7a9b7da --- /dev/null +++ b/mediapipe/framework/deps/monotonic_clock.cc @@ -0,0 +1,229 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/monotonic_clock.h" + +#include "absl/base/macros.h" +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +// This state, which contains the "guts" of MonotonicClockImpl, is separate +// from the class instance so that it can be shared to implement a +// SynchronizedMonotonicClock. (The per-instance state of MonotonicClock is +// just for frills like the correction metrics and callback.) It lives in this +// private namespace so that test code can use it without exposing it to the +// world. +struct MonotonicClock::State { + // The clock whose time is being corrected. + Clock* raw_clock; + absl::Mutex lock; + // The largest time ever returned by Now(). + absl::Time max_time GUARDED_BY(lock); + explicit State(Clock* clock) + : raw_clock(clock), max_time(absl::UnixEpoch()) {} +}; + +using State = MonotonicClock::State; + +class MonotonicClockImpl : public MonotonicClock { + public: + // By default, MonotonicClockImpl owns the state_. ReleaseState(), below, + // can be used to prevent the MCI destructor from deleting a shared state_. 
+ explicit MonotonicClockImpl(State* state) + : state_(state), + state_owned_(true), + last_raw_time_(absl::UnixEpoch()), + correction_count_(0), + max_correction_(absl::ZeroDuration()) {} + + MonotonicClockImpl(const MonotonicClockImpl&) = delete; + MonotonicClockImpl& operator=(const MonotonicClockImpl&) = delete; + + virtual ~MonotonicClockImpl() { + if (state_owned_) delete state_; + } + + // Absolve this object of responsibility for state_. + void ReleaseState() { + CHECK(state_owned_); + state_owned_ = false; + } + + // + // The Clock interface (see util/time/clock.h). + // + + // The logic in TimeNow() is based on GFS_NowMS(). + virtual absl::Time TimeNow() { + // These variables save some state from the critical section below. + absl::Time raw_time; + absl::Time local_max_time; + absl::Time local_last_raw_time; + + // As there are several early exits from this function, use absl::MutexLock. + { + absl::MutexLock m(&state_->lock); + + // Check consistency of internal data with state_. + CHECK_LE(last_raw_time_, state_->max_time) + << "non-monotonic behavior: last_raw_time_=" << last_raw_time_ + << ", max_time=" << state_->max_time; + + raw_time = state_->raw_clock->TimeNow(); + + // Normal case: time is advancing. Update state and return the raw time. + if (raw_time >= state_->max_time) { + last_raw_time_ = raw_time; + state_->max_time = raw_time; + return raw_time; + } + + // Exceptional case: Raw time is within a window of a previous backward + // jump. We do not run any callbacks or update metrics here since we + // already did that when the backward jump was detected. + if (raw_time >= last_raw_time_) { + last_raw_time_ = raw_time; + return state_->max_time; + } + + // Exceptional case: Raw time jumped backward. Remainder of function + // handles this case. + // + // First, update correction metrics. 
+ ++correction_count_; + absl::Duration delta = state_->max_time - raw_time; + CHECK_LT(absl::ZeroDuration(), delta); + if (delta > max_correction_) { + max_correction_ = delta; + } + + // Copy state into local vars before updating last_raw_time_ and leaving + // the critical section. + local_max_time = state_->max_time; + local_last_raw_time = last_raw_time_; + last_raw_time_ = raw_time; + } // absl::MutexLock + + // Return the saved maximum time. + return local_max_time; + } + + // The strategy of Sleep and SleepUntil is K.I.S.S.: set an alarm on the + // raw_clock for the desired wakeup_time, and then snooze the alarm if we wake + // up too soon. This guarantees that the caller won't wake up too soon (which + // would require us to advance monotonic time simply by the act of waking up), + // however the caller may sleep for much longer (in monotonic time) if + // monotonic time jumps far into the future. Whether or not this happens + // depends on the behavior of the raw clock. + virtual void Sleep(absl::Duration d) { + absl::Time wakeup_time = TimeNow() + d; + SleepUntil(wakeup_time); + } + + virtual void SleepUntil(absl::Time wakeup_time) { + while (TimeNow() < wakeup_time) { + state_->raw_clock->SleepUntil(wakeup_time); + } + } + + // + // End of Clock interface. + // + + private: + // Get metrics about time corrections. + virtual void GetCorrectionMetrics(int* correction_count, + double* max_correction) { + absl::MutexLock l(&state_->lock); + if (correction_count != nullptr) *correction_count = correction_count_; + if (max_correction != nullptr) + *max_correction = absl::FDivDuration(max_correction_, absl::Seconds(1)); + } + + // Reset values returned by GetCorrectionMetrics(). + virtual void ResetCorrectionMetrics() { + absl::MutexLock l(&state_->lock); + correction_count_ = 0; + max_correction_ = absl::ZeroDuration(); + } + + // The guts of the monotonic clock. Caution: this may point to a static + // object. 
+ State* state_; + // If true, this object owns state_ and is responsible for deallocating it. + bool state_owned_; + + // last_raw_time_ remembers the last value obtained from raw_clock_. + // It prevents spurious calls to ReportCorrection when time moves + // forward by a smaller amount than a prior backward jump. + absl::Time last_raw_time_ GUARDED_BY(state_->lock); + + // Variables that keep track of time corrections made by this instance of + // MonotonicClock. (All such metrics are instance-local for reasons + // described earlier.) + int correction_count_ GUARDED_BY(state_->lock); + absl::Duration max_correction_ GUARDED_BY(state_->lock); +}; + +// Factory methods. +MonotonicClock* MonotonicClock::CreateMonotonicClock(Clock* clock) { + State* state = new State(clock); + // MonotonicClockImpl takes ownership of state. + return new MonotonicClockImpl(state); +} + +namespace { +State* GlobalSyncState() { + static State* sync_state = new State(Clock::RealClock()); + return sync_state; +} +} // namespace + +// The reason that SynchronizedMonotonicClock is not implemented as a singleton +// is so that different code bases can handle clock corrections their own way. +MonotonicClock* MonotonicClock::CreateSynchronizedMonotonicClock() { + MonotonicClockImpl* clock = new MonotonicClockImpl(GlobalSyncState()); + // Release ownership of sync_state. + clock->ReleaseState(); + return clock; +} + +// Test access methods. 
+void MonotonicClockAccess::SynchronizedMonotonicClockReset() { + LOG(INFO) << "Resetting SynchronizedMonotonicClock"; + State* sync_state = GlobalSyncState(); + absl::MutexLock m(&sync_state->lock); + sync_state->max_time = absl::UnixEpoch(); +} + +State* MonotonicClockAccess::CreateMonotonicClockState(Clock* raw_clock) { + return new State(raw_clock); +} + +void MonotonicClockAccess::DeleteMonotonicClockState(State* state) { + delete state; +} + +MonotonicClock* MonotonicClockAccess::CreateMonotonicClock(State* state) { + MonotonicClockImpl* clock = new MonotonicClockImpl(state); + // Release ownership of sync_state. + clock->ReleaseState(); + return clock; +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/monotonic_clock.h b/mediapipe/framework/deps/monotonic_clock.h new file mode 100644 index 000000000..586a75cd1 --- /dev/null +++ b/mediapipe/framework/deps/monotonic_clock.h @@ -0,0 +1,100 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_MONOTONIC_CLOCK_H_ +#define MEDIAPIPE_DEPS_MONOTONIC_CLOCK_H_ + +#include "absl/time/time.h" +#include "mediapipe/framework/deps/clock.h" + +namespace mediapipe { + +// MonotonicClock is an interface for a Clock that never goes backward. 
// Successive returned values from Now() are guaranteed to be monotonically
// non-decreasing, although they may not be monotonic with respect to values
// returned from other instances of MonotonicClock.
//
// You can wrap any Clock object in a MonotonicClock using the
// CreateMonotonicClock() factory method, including Clock::RealClock().
// However, if you want a monotonic version of real time, it is strongly
// recommended that you use the CreateSynchronizedMonotonicClock() factory
// method, which wraps Clock::RealClock() and guarantees that values returned
// from Now() are monotonic ACROSS instances of the class that are created by
// CreateSynchronizedMonotonicClock().
//
// All methods support concurrent access.
class MonotonicClock : public Clock {
 public:
  // The MonotonicClock state, which may be shared between MonotonicClocks.
  // Definition is private to the implementation (.cc file).
  struct State;

  ~MonotonicClock() override {}

  // The Clock interface (see util/time/clock.h).
  //
  // Return a monotonically non-decreasing time.
  absl::Time TimeNow() override = 0;
  // Sleep and SleepUntil guarantee only that the caller will sleep for at
  // least as long as specified in monotonic time. The caller may sleep for
  // much longer (in monotonic time) if monotonic time jumps far into the
  // future. Whether or not this happens depends on the behavior of the raw
  // clock.
  void Sleep(absl::Duration d) override = 0;
  void SleepUntil(absl::Time wakeup_time) override = 0;

  // Get metrics about time corrections: the number of corrections applied so
  // far and the largest single correction (in seconds).
  virtual void GetCorrectionMetrics(int* correction_count,
                                    double* max_correction) = 0;
  // Reset values returned by GetCorrectionMetrics().
  virtual void ResetCorrectionMetrics() = 0;

  // Factory methods.
  //
  // Create a MonotonicClock based on the given raw_clock. This clock will
  // return monotonically non-decreasing values from Now(), but may not behave
  // monotonically with respect to other instances created by this function,
  // even if they are based on the same raw_clock. Caller owns raw_clock.
  static MonotonicClock* CreateMonotonicClock(Clock* raw_clock);

  // Create an instance of MonotonicClock that is based on Clock::RealClock().
  // All such instances are synced with each other such that return values from
  // Now() are monotonic across instances. This allows independently developed
  // code bases to have private instances of the synchronized MonotonicClock
  // and know that they will never see time anomalies when calling from one
  // code base to another. Each instance can have its own correction callback.
  // Unlike Clock::RealClock(), caller owns this object and should delete it
  // when no longer needed.
  static MonotonicClock* CreateSynchronizedMonotonicClock();
};

class MonotonicClockTest;

// Provides access to MonotonicClock::State for unit-testing. All members are
// private; only the MonotonicClockTest fixture (a friend) may call them.
class MonotonicClockAccess {
 private:
  using State = MonotonicClock::State;

  // Reset internal global state. Should only be called by test code.
  static void SynchronizedMonotonicClockReset();
  static State* CreateMonotonicClockState(Clock* raw_clock);
  static void DeleteMonotonicClockState(State* state);
  // Create a monotonic clock based on the given state. Caller owns state
  // so that multiple such clocks can be created from the same state.
  static MonotonicClock* CreateMonotonicClock(State* state);
  friend class ::mediapipe::MonotonicClockTest;
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_DEPS_MONOTONIC_CLOCK_H_
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/monotonic_clock.h" + +#include + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/threadpool.h" +#include "mediapipe/framework/tool/simulation_clock.h" + +namespace mediapipe { + +using RandomEngine = std::mt19937_64; +using State = MonotonicClock::State; + +// absl::Now() recomputes clock drift approx. every 2 seconds, so run real +// clock tests for at least that long. 
+static const absl::Duration kDefaultRealTest = absl::Seconds(2.5); + +class MonotonicClockTest : public testing::Test { + protected: + MonotonicClockTest() {} + virtual ~MonotonicClockTest() {} + + void SetUp() override { + MonotonicClockAccess::SynchronizedMonotonicClockReset(); + } + + void VerifyCorrectionMetrics(MonotonicClock* clock, + int num_corrections_expect, + double max_correction_expect) { + int clock_num_corrections; + double clock_max_correction; + clock->GetCorrectionMetrics(&clock_num_corrections, &clock_max_correction); + ASSERT_EQ(num_corrections_expect, clock_num_corrections); + ASSERT_DOUBLE_EQ(max_correction_expect, clock_max_correction); + } + + // This test produces no time corrections. + void TestSimulatedForwardTime(SimulationClock* sim_clock, + MonotonicClock* mono_clock) { + absl::Time base_time = sim_clock->TimeNow(); + ASSERT_EQ(base_time, mono_clock->TimeNow()); + sim_clock->Sleep(absl::Seconds(10)); + ASSERT_EQ(base_time + absl::Seconds(10), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(10), mono_clock->TimeNow()); + sim_clock->Sleep(absl::Seconds(10)); + ASSERT_EQ(base_time + absl::Seconds(20), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(20), mono_clock->TimeNow()); + sim_clock->Sleep(absl::Seconds(5)); + ASSERT_EQ(base_time + absl::Seconds(25), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(25), mono_clock->TimeNow()); + VerifyCorrectionMetrics(mono_clock, 0, 0.0); + } + + // This test produces three corrections: one with arguments + // (50, 100, 100), one with (80, 90, 100), and one with (60, 105, 105). 
+ void TestSimulatedBackwardTime(SimulationClock* sim_clock, + MonotonicClock* mono_clock) { + absl::Time base_time = sim_clock->TimeNow(); + sim_clock->Sleep(absl::Seconds(100)); + ASSERT_EQ(base_time + absl::Seconds(100), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(100), mono_clock->TimeNow()); + VerifyCorrectionMetrics(mono_clock, 0, 0.0); + // Time moves backward -- expect a correction. + sim_clock->Sleep(absl::Seconds(-50)); + ASSERT_EQ(base_time + absl::Seconds(50), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(100), // correction + mono_clock->TimeNow()); + VerifyCorrectionMetrics(mono_clock, 1, 50.0); + // Time moves forward, but not enough to exceed the last value returned by + // TimeNow(). No correction in this case. + sim_clock->Sleep(absl::Seconds(20)); + ASSERT_EQ(base_time + absl::Seconds(70), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(100), mono_clock->TimeNow()); + VerifyCorrectionMetrics(mono_clock, 1, 50.0); + sim_clock->Sleep(absl::Seconds(20)); + ASSERT_EQ(base_time + absl::Seconds(90), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(100), mono_clock->TimeNow()); + VerifyCorrectionMetrics(mono_clock, 1, 50.0); + // Time moves backwards again -- expect a correction. + sim_clock->Sleep(absl::Seconds(-10)); + ASSERT_EQ(base_time + absl::Seconds(80), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(100), // correction + mono_clock->TimeNow()); + VerifyCorrectionMetrics(mono_clock, 2, 50.0); + // Time moves forward enough to advance monotonic time. + sim_clock->Sleep(absl::Seconds(25)); + ASSERT_EQ(base_time + absl::Seconds(105), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(105), mono_clock->TimeNow()); + VerifyCorrectionMetrics(mono_clock, 2, 50.0); + // Time moves backward again. 
+ sim_clock->Sleep(absl::Seconds(-45)); + ASSERT_EQ(base_time + absl::Seconds(60), sim_clock->TimeNow()); + ASSERT_EQ(base_time + absl::Seconds(105), // correction + mono_clock->TimeNow()); + VerifyCorrectionMetrics(mono_clock, 3, 50.0); + + // Reset metrics and re-verify. + mono_clock->ResetCorrectionMetrics(); + VerifyCorrectionMetrics(mono_clock, 0, 0.0); + } + + // Test that the Sleep/SleepUntil calls do not return until monotonic time + // passes the requested wakeup time. + void TestRandomSleep(MonotonicClock* mono_clock) { + RandomEngine random(testing::UnitTest::GetInstance()->random_seed()); + const int kNumSamples = 5; + + // Sleep. + for (int i = 0; i < kNumSamples; i++) { + absl::Duration sleep_time = absl::Seconds( + std::uniform_real_distribution(0.0f, 0.2f)(random)); + absl::Time before = mono_clock->TimeNow(); + absl::Time wakeup_time = before + sleep_time; + mono_clock->Sleep(sleep_time); + absl::Time after = mono_clock->TimeNow(); + ASSERT_LE(wakeup_time, after); + } + + // SleepUntil. + for (int i = 0; i < kNumSamples; i++) { + absl::Duration sleep_time = absl::Seconds( + std::uniform_real_distribution(0.0f, 0.2f)(random)); + absl::Time before = mono_clock->TimeNow(); + absl::Time wakeup_time = before + sleep_time; + mono_clock->SleepUntil(wakeup_time); + absl::Time after = mono_clock->TimeNow(); + ASSERT_LE(wakeup_time, after); + } + } + + static State* CreateMonotonicClockState(Clock* raw_clock) { + return MonotonicClockAccess::CreateMonotonicClockState(raw_clock); + } + + static MonotonicClock* CreateMonotonicClock(State* state) { + return MonotonicClockAccess::CreateMonotonicClock(state); + } + + static void DeleteMonotonicClockState(State* state) { + MonotonicClockAccess::DeleteMonotonicClockState(state); + } +}; + +// Time moves forward only -- there should be no time corrections. 
+TEST_F(MonotonicClockTest, SimulatedForwardTime) { + SimulationClock sim_clock; + sim_clock.ThreadStart(); + MonotonicClock* mono_clock = MonotonicClock::CreateMonotonicClock(&sim_clock); + TestSimulatedForwardTime(&sim_clock, mono_clock); + sim_clock.ThreadFinish(); + delete mono_clock; +} + +// Time moves forward and backward. +TEST_F(MonotonicClockTest, SimulatedBackwardTime) { + SimulationClock sim_clock; + sim_clock.ThreadStart(); + MonotonicClock* mono_clock = MonotonicClock::CreateMonotonicClock(&sim_clock); + TestSimulatedBackwardTime(&sim_clock, mono_clock); + sim_clock.ThreadFinish(); + delete mono_clock; +} + +// Time moves forward and backward. +TEST_F(MonotonicClockTest, SimulatedTime) { + SimulationClock sim_clock; + sim_clock.ThreadStart(); + MonotonicClock* mono_clock = MonotonicClock::CreateMonotonicClock(&sim_clock); + TestSimulatedBackwardTime(&sim_clock, mono_clock); + absl::Time mono_time = mono_clock->TimeNow(); + sim_clock.Sleep(absl::Seconds(-1)); + ASSERT_EQ(mono_time, mono_clock->TimeNow()); + sim_clock.ThreadFinish(); + delete mono_clock; +} + +// Take a random walk through time. +TEST_F(MonotonicClockTest, SimulatedRandomWalk) { + SimulationClock sim_clock; + sim_clock.ThreadStart(); + MonotonicClock* mono_clock = MonotonicClock::CreateMonotonicClock(&sim_clock); + sim_clock.Sleep(absl::Now() - sim_clock.TimeNow()); + ASSERT_EQ(sim_clock.TimeNow(), mono_clock->TimeNow()); + + // Generate kNumSamples random clock adjustments. + const int kNumSamples = 5; + RandomEngine random(testing::UnitTest::GetInstance()->random_seed()); + // Keep track of maximum time on clock and corrections. 
+ absl::Time max_time = sim_clock.TimeNow(); + int num_corrections = 0; + absl::Duration max_correction = absl::ZeroDuration(); + for (int i = 0; i < kNumSamples; i++) { + absl::Duration jump = + absl::Seconds(std::uniform_real_distribution(-0.5, 0.5)(random)); + sim_clock.Sleep(jump); + absl::Time sim_time = sim_clock.TimeNow(); + if (jump < absl::ZeroDuration()) { + ASSERT_LT(sim_time, max_time); + absl::Duration correction = max_time - sim_time; + if (correction > max_correction) { + max_correction = correction; + } + ++num_corrections; + } + if (sim_clock.TimeNow() > max_time) { + max_time = sim_clock.TimeNow(); + } + ASSERT_EQ(max_time, mono_clock->TimeNow()); + } + VerifyCorrectionMetrics(mono_clock, num_corrections, + absl::FDivDuration(max_correction, absl::Seconds(1))); + sim_clock.ThreadFinish(); + delete mono_clock; +} + +TEST_F(MonotonicClockTest, RealTime) { + MonotonicClock* mono_clock = + MonotonicClock::CreateMonotonicClock(Clock::RealClock()); + // Call mono_clock->Now() continuously for FLAGS_real_test_secs seconds. + absl::Time start = absl::Now(); + absl::Time time = start; + int64 num_calls = 0; + do { + absl::Time last_time = time; + time = mono_clock->TimeNow(); + ASSERT_LE(last_time, time); + ++num_calls; + } while (time - start < kDefaultRealTest); + // Just out of curiousity -- did real clock go backwards? + int clock_num_corrections; + mono_clock->GetCorrectionMetrics(&clock_num_corrections, NULL); + LOG(INFO) << clock_num_corrections << " corrections in " << num_calls + << " calls to mono_clock->Now()"; + delete mono_clock; +} + +// Test the Sleep interface using a MonotonicClock. +TEST_F(MonotonicClockTest, RandomSleep) { + MonotonicClock* mono_clock = + MonotonicClock::CreateMonotonicClock(Clock::RealClock()); + TestRandomSleep(mono_clock); + delete mono_clock; +} + +// Test the Sleep interface using a SynchronizedMonotonicClock. 
// Test the Sleep interface using a SynchronizedMonotonicClock.
TEST_F(MonotonicClockTest, RandomSleepSynced) {
  MonotonicClock* mono_clock =
      MonotonicClock::CreateSynchronizedMonotonicClock();
  TestRandomSleep(mono_clock);
  delete mono_clock;
}

// Test that SleepUntil has no effect if monotonic time has passed the
// requested wakeup time.
TEST_F(MonotonicClockTest, SimulatedInsomnia) {
  SimulationClock sim_clock;
  sim_clock.ThreadStart();
  MonotonicClock* mono_clock = MonotonicClock::CreateMonotonicClock(&sim_clock);
  // Align the simulated clock with real time before regressing it.
  sim_clock.Sleep(absl::Now() - sim_clock.TimeNow());
  ASSERT_EQ(sim_clock.TimeNow(), mono_clock->TimeNow());

  sim_clock.Sleep(absl::Seconds(-3.14159));
  // Even though sim_clock will never advance, this call will not sleep
  // because monotonic_time has already advanced beyond the wakeup time.
  mono_clock->SleepUntil(sim_clock.TimeNow() + absl::Seconds(1));
  // Note that the same test can't be performed with Sleep because the argument
  // to sleep is an offset from monotonic time, not raw time.
  sim_clock.ThreadFinish();
  delete mono_clock;
}

// Two monotonic clocks, clock1 and clock2, each synced to the same
// raw clock. Advance simulated time, read one clock, regress simulated
// time, and read the other clock. The values should be the same.
TEST_F(MonotonicClockTest, SyncedPair) {
  SimulationClock sim_clock;
  sim_clock.ThreadStart();
  State* state = CreateMonotonicClockState(&sim_clock);
  MonotonicClock* clock1 = CreateMonotonicClock(state);
  MonotonicClock* clock2 = CreateMonotonicClock(state);
  sim_clock.Sleep(absl::Seconds(1000));
  ASSERT_EQ(sim_clock.TimeNow(), clock1->TimeNow());
  ASSERT_EQ(sim_clock.TimeNow(), clock2->TimeNow());

  absl::Time time1, time2;
  sim_clock.Sleep(absl::Seconds(2));
  time1 = clock1->TimeNow();
  ASSERT_EQ(sim_clock.TimeNow(), time1);
  sim_clock.Sleep(absl::Seconds(-5));
  time2 = clock2->TimeNow();
  ASSERT_EQ(time1, time2);
  // Only clock2 observed the raw clock go backward, so only it records the
  // 5-second correction.
  VerifyCorrectionMetrics(clock1, 0, 0.0);
  VerifyCorrectionMetrics(clock2, 1, 5.0);

  clock1->ResetCorrectionMetrics();
  clock2->ResetCorrectionMetrics();
  VerifyCorrectionMetrics(clock1, 0, 0.0);
  VerifyCorrectionMetrics(clock2, 0, 0.0);

  // In this example, time on clock1 goes forward by a greater amount than
  // time goes backward on clock2. Although clock2 still reports the global
  // monotonic time, it does not report a correction because it never
  // observed a raw clock reading that went backward.
  sim_clock.Sleep(absl::Seconds(10));
  time1 = clock1->TimeNow();
  ASSERT_EQ(sim_clock.TimeNow(), time1);
  sim_clock.Sleep(absl::Seconds(-1));
  time2 = clock2->TimeNow();
  ASSERT_EQ(time1, time2);
  VerifyCorrectionMetrics(clock1, 0, 0.0);
  VerifyCorrectionMetrics(clock2, 0, 0.0);

  sim_clock.ThreadFinish();
  delete clock1;
  delete clock2;
  DeleteMonotonicClockState(state);
}

// Test that a globally-synchronized MonotonicClock is unaffected by clock
// behavior of a vanilla MonotonicClock.
TEST_F(MonotonicClockTest, UnsyncedPair) {
  SimulationClock sim_clock;
  sim_clock.ThreadStart();
  MonotonicClock* sync_clock =
      MonotonicClock::CreateSynchronizedMonotonicClock();
  MonotonicClock* mono_clock = MonotonicClock::CreateMonotonicClock(&sim_clock);
  absl::Time before = sync_clock->TimeNow();
  sim_clock.Sleep(before - sim_clock.TimeNow());
  ASSERT_EQ(before, mono_clock->TimeNow());
  sim_clock.Sleep(absl::Seconds(61));
  ASSERT_LT(sync_clock->TimeNow(), mono_clock->TimeNow());
  sim_clock.ThreadFinish();
  delete sync_clock;
  delete mono_clock;
}

// The factory method CreateSynchronizedMonotonicClock should return a
// MonotonicClock based on real time. Since time waits for no unit test,
// we can't test equality of the time read from the factory-produced clock
// and the time read from a real clock. But we can verify that, as long
// as the real clock moves forward, the time read from the factory-produced
// clock is bounded by consecutive readings of the real clock.
TEST_F(MonotonicClockTest, CreateSynchronizedMonotonicClock) {
  Clock* real_clock = Clock::RealClock();
  MonotonicClock* mono_clock =
      MonotonicClock::CreateSynchronizedMonotonicClock();
  const int kNumSamples = 100;
  for (int i = 0; i < kNumSamples; ++i) {
    absl::Time before = real_clock->TimeNow();
    absl::Time now = mono_clock->TimeNow();
    absl::Time after = real_clock->TimeNow();
    if (after < before) {
      // Real clock moved backward -- test is invalid.
      continue;
    }
    ASSERT_LE(before, now);
    ASSERT_LE(now, after);
  }
  delete mono_clock;
}
+class ClockFrenzy { + public: + ClockFrenzy() + : real_clock_(Clock::RealClock()), + random_( + new RandomEngine(testing::UnitTest::GetInstance()->random_seed())) { + } + + void AddSimulationClock(SimulationClock* clock) { + sim_clocks_.push_back(clock); + } + + void AddMonotonicClock(MonotonicClock* clock) { + mono_clocks_.push_back(clock); + } + + void Feed() { + while (Running()) { + // 40% of the time, advance a simulated clock. + // 50% of the time, read a monotonic clock. + const int32 u = UniformRandom(100); + if (u < 40) { + // Pick a simulated clock and advance it. + const int nclocks = sim_clocks_.size(); + if (nclocks == 0) continue; + SimulationClock* sim_clock = sim_clocks_[UniformRandom(nclocks)]; + // Bias the clock towards forward movement. + sim_clock->Sleep(absl::Seconds(RndFloatRandom() - 0.2)); + } else if (u < 90) { + // Pick a monotonic clock and read it. + const int nclocks = mono_clocks_.size(); + if (nclocks == 0) continue; + MonotonicClock* mono_clock = mono_clocks_[UniformRandom(nclocks)]; + mono_clock->TimeNow(); + } + } + } + + // Start Feed-ing threads. + void Start(int nthreads) { + absl::MutexLock l(&lock_); + running_ = true; + threads_ = absl::make_unique<::mediapipe::ThreadPool>("Frenzy", nthreads); + threads_->StartWorkers(); + for (int i = 0; i < nthreads; ++i) { + threads_->Schedule([&]() { Feed(); }); + } + } + + void Stop() { + absl::MutexLock l(&lock_); + running_ = false; + } + + bool Running() { + absl::MutexLock l(&lock_); + return running_; + } + + // Wait for all threads to finish. + void Wait() { threads_.reset(); } + + private: + Clock* real_clock_; + std::vector sim_clocks_; + std::vector mono_clocks_; + std::unique_ptr<::mediapipe::ThreadPool> threads_; + + // Provide a lock to avoid race conditions in non-threadsafe ACMRandom. + mutable absl::Mutex lock_; + std::unique_ptr random_ GUARDED_BY(lock_); + + // The stopping notification. 
+ bool running_; + + // Thread-safe random number generation functions for use by other class + // member functions. + int32 UniformRandom(int32 n) { + absl::MutexLock l(&lock_); + return std::uniform_int_distribution(0, n - 1)(*random_); + } + + float RndFloatRandom() { + absl::MutexLock l(&lock_); + return std::uniform_real_distribution(0.0f, 1.0f)(*random_); + } +}; + +TEST_F(MonotonicClockTest, SimulatedFrenzy) { + ClockFrenzy f; + SimulationClock s1, s2; + s1.ThreadStart(); + s2.ThreadStart(); + f.AddSimulationClock(&s1); + f.AddSimulationClock(&s2); + MonotonicClock* m11 = MonotonicClock::CreateMonotonicClock(&s1); + State* state = CreateMonotonicClockState(&s1); + MonotonicClock* m12 = CreateMonotonicClock(state); + MonotonicClock* m13 = CreateMonotonicClock(state); + MonotonicClock* m21 = MonotonicClock::CreateMonotonicClock(&s2); + MonotonicClock* m22 = MonotonicClock::CreateMonotonicClock(&s2); + f.AddMonotonicClock(m11); + f.AddMonotonicClock(m12); + f.AddMonotonicClock(m13); + f.AddMonotonicClock(m21); + f.AddMonotonicClock(m22); + f.Start(10); + Clock::RealClock()->Sleep(absl::Seconds(1)); + f.Stop(); + f.Wait(); + s2.ThreadFinish(); + s1.ThreadFinish(); + delete m11; + delete m12; + delete m13; + delete m21; + delete m22; + DeleteMonotonicClockState(state); +} + +// Just for completeness, a frenzy with only real-time +// SynchronizedMonotonicClock instances. +TEST_F(MonotonicClockTest, RealFrenzy) { + ClockFrenzy f; + MonotonicClock* m1 = MonotonicClock::CreateSynchronizedMonotonicClock(); + MonotonicClock* m2 = MonotonicClock::CreateSynchronizedMonotonicClock(); + MonotonicClock* m3 = MonotonicClock::CreateSynchronizedMonotonicClock(); + f.AddMonotonicClock(m1); + f.AddMonotonicClock(m2); + f.AddMonotonicClock(m3); + f.Start(10); + Clock::RealClock()->Sleep(kDefaultRealTest); + f.Stop(); + f.Wait(); + // Just out of curiousity -- did real clock go backwards? 
+ int clock_num_corrections; + m1->GetCorrectionMetrics(&clock_num_corrections, NULL); + LOG_IF(INFO, clock_num_corrections > 0) + << clock_num_corrections << " corrections"; + m2->GetCorrectionMetrics(&clock_num_corrections, NULL); + LOG_IF(INFO, clock_num_corrections > 0) + << clock_num_corrections << " corrections"; + m3->GetCorrectionMetrics(&clock_num_corrections, NULL); + LOG_IF(INFO, clock_num_corrections > 0) + << clock_num_corrections << " corrections"; + delete m1; + delete m2; + delete m3; +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/no_destructor.h b/mediapipe/framework/deps/no_destructor.h new file mode 100644 index 000000000..e617e5dc8 --- /dev/null +++ b/mediapipe/framework/deps/no_destructor.h @@ -0,0 +1,115 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_NO_DESTRUCTOR_H_ +#define MEDIAPIPE_DEPS_NO_DESTRUCTOR_H_ + +#include +#include + +namespace mediapipe { + +// NoDestructor is a wrapper around an object of type T that +// * stores the object of type T inline inside NoDestructor +// * eagerly forwards constructor arguments to it (i.e. 
acts like T in terms +// of construction) +// * provides access to the object of type T like a pointer via ->, *, and get() +// (note that const NoDestructor works like a pointer to const T) +// * never calls T's destructor for the object +// (hence NoDestructor objects created on the stack or as member variables +// will lead to memory and/or resource leaks) +// +// One key use case of NoDestructor (which in itself is not lazy) is optimizing +// the following pattern of safe on-demand construction of an object with +// non-trivial constructor in static storage without destruction ever happening: +// const std::string& MyString() { +// static std::string* x = new std::string("foo"); // note the "static" +// return *x; +// } +// By using NoDestructor we do not need to involve heap allocation and +// corresponding pointer following (and hence extra CPU cache usage/needs) +// on each access: +// const std::string& MyString() { +// static NoDestructor x("foo"); +// return *x; +// } +// Since C++11 this static-in-a-function pattern results in exactly-once, +// thread-safe, on-demand construction of an object, and very fast access +// thereafter (the cost is a few extra cycles). +// NoDestructor makes accesses even faster by storing the object inline in +// static storage. +// +// Note that: +// * Since destructor is never called, the object lives on during program exit +// and can be safely accessed by any threads that have not been joined. +// * This static-in-a-function NoDestructor usage pattern should be preferred +// to uses of gtl::LazyStaticPtr in new code. +// +// Also note that +// static NoDestructor ptr(whatever); +// can safely replace +// static NonPOD* ptr = new NonPOD(whatever); +// or +// static NonPOD obj(whatever); +// at file-level scope when the safe static-in-a-function pattern is infeasible +// to use for some good reason. 
// NoDestructor<T> stores a T inline, forwards constructor arguments to it,
// provides pointer-like access via ->, *, and get(), and never runs T's
// destructor. Template parameters and angle-bracket contents below were lost
// in transit and have been reconstructed from the surrounding comments and
// usage.
template <typename T>
class NoDestructor {
 public:
  typedef T element_type;

  // Forwards arguments to the T's constructor: calls T(args...).
  // The enable_if guard keeps this overload from hijacking copy/move
  // construction of NoDestructor itself.
  template <typename... Ts,
            typename std::enable_if<
                !std::is_same<void(typename std::decay<Ts>::type...),
                              void(NoDestructor)>::value,
                int>::type = 0>
  explicit NoDestructor(Ts&&... args) {
    new (&space_) T(std::forward<Ts>(args)...);
  }

  // Forwards copy and move construction for T. Enables usage like this:
  //   static NoDestructor<std::array<std::string, 3>> x{{{"1", "2", "3"}}};
  //   static NoDestructor<std::vector<int>> x{{1, 2, 3}};
  explicit NoDestructor(const T& x) { new (&space_) T(x); }
  explicit NoDestructor(T&& x) { new (&space_) T(std::move(x)); }

  // No copying.
  NoDestructor(const NoDestructor&) = delete;
  NoDestructor& operator=(const NoDestructor&) = delete;

  // Pretend to be a smart pointer to T with deep constness.
  // Never returns a null pointer.
  T& operator*() { return *get(); }
  T* operator->() { return get(); }
  T* get() { return reinterpret_cast<T*>(&space_); }
  const T& operator*() const { return *get(); }
  const T* operator->() const { return get(); }
  const T* get() const { return reinterpret_cast<const T*>(&space_); }

 private:
  // Raw, suitably aligned storage for the T; T is constructed here with
  // placement new and intentionally never destroyed.
  typename std::aligned_storage<sizeof(T), alignof(T)>::type space_;
};
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_NUMBERS_H_ +#define MEDIAPIPE_DEPS_NUMBERS_H_ + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/port/integral_types.h" + +namespace mediapipe { +ABSL_MUST_USE_RESULT inline std::string SimpleDtoa(double d) { + if (static_cast(static_cast(d)) == d) { + return absl::StrCat(static_cast(d)); + } else { + return absl::StrCat(d); + } +} +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_NUMBERS_H_ diff --git a/mediapipe/framework/deps/point2.h b/mediapipe/framework/deps/point2.h new file mode 100644 index 000000000..1a38f0bd4 --- /dev/null +++ b/mediapipe/framework/deps/point2.h @@ -0,0 +1,137 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Class to handle two-dimensional points. 
+// +// The aim of this class is to be able to do sensible geometric operations +// with points and vectors, which are distinct mathematical concepts. +// Operators +, -, =, ==, <, etc. are overloaded with the proper semantics +// (e.g. Point = Point + constant * vector or Vector = Point - Point). +// For more about Point expressions, see Goldman, Ronald N., "Illicit +// Expressions in Vector Algebra," ACM Transactions on Graphics, 4(3), +// pp. 223-243, July 1985 (http://portal.acm.org/citation.cfm?id=282969). +// +// Please be careful about overflows when using points with integer types +// The calculations are carried with the same type as the vector's components +// type, e.g. if you are using uint8 as the base type, all values will be modulo +// 256. This feature is necessary to use the class in a more general framework +// where T != plain old data type. + +#ifndef MEDIAPIPE_DEPS_POINT2_H_ +#define MEDIAPIPE_DEPS_POINT2_H_ + +#include +#include +#include + +#include "mediapipe/framework/deps/mathutil.h" +#include "mediapipe/framework/deps/vector.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" + +// Template class for 2D points +template +class Point2 { + public: + typedef T ElementType; + typedef Vector2 Coords; + + Point2() {} + Point2(const T& x, const T& y) : c_(x, y) {} + explicit Point2(const Coords& v) : c_(v) {} + + Coords ToVector() const { return c_; } + + void Set(const T& x, const T& y) { *this = Point2(x, y); } + + T* Data() { return c_.Data(); } + const T* Data() const { return c_.Data(); } + + void Clear() { *this = Point2(); } + + Point2& operator+=(const Coords& v) { + c_ += v; + return *this; + } + Point2& operator-=(const Coords& v) { + c_ -= v; + return *this; + } + + const T& operator[](std::size_t b) const { return Data()[b]; } + T& operator[](std::size_t b) { return Data()[b]; } + + const T& x() const { return (*this)[0]; } + const T& y() const { return (*this)[1]; } + void 
set_x(const T& x) { (*this)[0] = x; } + void set_y(const T& y) { (*this)[1] = y; } + + // Compares two points, returns true if all their components are within + // a difference of a tolerance. + bool aequal(const Point2& p, double tolerance) const { + using std::abs; + return (abs(c_[0] - p.c_[0]) <= tolerance) && + (abs(c_[1] - p.c_[1]) <= tolerance); + } + + private: + // Friend arithmetic operators. + friend Point2 operator+(const Point2& p, const Coords& v) { + return Point2(p.c_ + v); + } + friend Point2 operator+(const Coords& v, const Point2& p) { + return Point2(v + p.c_); + } + friend Point2 operator-(const Point2& p, const Coords& v) { + return Point2(p.c_ - v); + } + friend Coords operator-(const Point2& p1, const Point2& p2) { + return p1.c_ - p2.c_; + } + + // Friend relational nonmember operators. + friend bool operator==(const Point2& a, const Point2& b) { + return a.c_ == b.c_; + } + friend bool operator!=(const Point2& a, const Point2& b) { + return a.c_ != b.c_; + } + friend bool operator<(const Point2& a, const Point2& b) { + return a.c_ < b.c_; + } + friend bool operator>(const Point2& a, const Point2& b) { + return a.c_ > b.c_; + } + friend bool operator<=(const Point2& a, const Point2& b) { + return a.c_ <= b.c_; + } + friend bool operator>=(const Point2& a, const Point2& b) { + return a.c_ >= b.c_; + } + + // Streaming operator. 
+ friend std::ostream& operator<<(std::ostream& out, const Point2& p) { + return out << "Point with coordinates: (" << p.c_[0] << ", " << p.c_[1] + << ")"; + } + + Coords c_; // coordinates +}; + +typedef Point2 Point2_b; +typedef Point2 Point2_i; +typedef Point2 Point2_f; +typedef Point2 Point2_d; + +#endif // MEDIAPIPE_DEPS_POINT2_H_ diff --git a/mediapipe/framework/deps/proto_descriptor.proto b/mediapipe/framework/deps/proto_descriptor.proto new file mode 100644 index 000000000..77762dfd0 --- /dev/null +++ b/mediapipe/framework/deps/proto_descriptor.proto @@ -0,0 +1,40 @@ +syntax = "proto2"; + +package mediapipe; + +// Describes a field within a message. +message FieldDescriptorProto { + enum Type { + // 0 is reserved for errors. + TYPE_INVALID = 0; + // Order is weird for historical reasons. + TYPE_DOUBLE = 1; + TYPE_FLOAT = 2; + // Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT64 if + // negative values are likely. + TYPE_INT64 = 3; + TYPE_UINT64 = 4; + // Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT32 if + // negative values are likely. + TYPE_INT32 = 5; + TYPE_FIXED64 = 6; + TYPE_FIXED32 = 7; + TYPE_BOOL = 8; + TYPE_STRING = 9; + // Tag-delimited aggregate. + // Group type is deprecated and not supported in proto3. However, Proto3 + // implementations should still be able to parse the group wire format and + // treat group fields as unknown fields. + TYPE_GROUP = 10; + TYPE_MESSAGE = 11; // Length-delimited aggregate. + + // New in version 2. + TYPE_BYTES = 12; + TYPE_UINT32 = 13; + TYPE_ENUM = 14; + TYPE_SFIXED32 = 15; + TYPE_SFIXED64 = 16; + TYPE_SINT32 = 17; // Uses ZigZag encoding. + TYPE_SINT64 = 18; // Uses ZigZag encoding. + } +} diff --git a/mediapipe/framework/deps/random_base.h b/mediapipe/framework/deps/random_base.h new file mode 100644 index 000000000..69f6949ba --- /dev/null +++ b/mediapipe/framework/deps/random_base.h @@ -0,0 +1,27 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_RANDOM_BASE_H_ +#define MEDIAPIPE_DEPS_RANDOM_BASE_H_ + +class RandomBase { + public: + // constructors. Don't do too much. + RandomBase() {} + virtual ~RandomBase(); + + virtual float RandFloat() { return 0; } +}; + +#endif // MEDIAPIPE_DEPS_RANDOM_BASE_H_ diff --git a/mediapipe/framework/deps/rectangle.h b/mediapipe/framework/deps/rectangle.h new file mode 100644 index 000000000..9ca9d7ad1 --- /dev/null +++ b/mediapipe/framework/deps/rectangle.h @@ -0,0 +1,328 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Class for axis-aligned rectangles represented as two corner points +// (min_x, min_y) and (max_x, max_y). The methods such as Contain, Intersect +// and IsEmpty() assume that the points in region include the 4 boundary edges. +// The default box is initialized so that IsEmpty() is true. 
Note that the +// use of corner points supports both right-handed (Cartesian) and left- +// handed (image) coordinate systems. + +#ifndef MEDIAPIPE_DEPS_RECTANGLE_H_ +#define MEDIAPIPE_DEPS_RECTANGLE_H_ + +#include +#include +#include +#include +#include + +#include "mediapipe/framework/deps/point2.h" +#include "mediapipe/framework/port/integral_types.h" + +template +class Rectangle; + +template +std::ostream& operator<<(std::ostream&, const Rectangle&); + +template +class Rectangle { + public: + typedef Rectangle Self; + + // Default constructed rectangle which is empty. + Rectangle() { SetEmpty(); } + + // Creates a rectangle from the minimum point and the dimensions. + Rectangle(const T& x, const T& y, const T& width, const T& height); + + // Creates a rectangle given two points. The resulting rectangle will + // have non-negative width and height. + Rectangle(const Point2& p0, const Point2& p1); + + // Same as above but using vectors as input. + Rectangle(const Vector2& p0, const Vector2& p1); + + // Sets min to be very large numbers and max to be very large negative numbers + // so that points can be used to correctly extend the rectangle. + void SetEmpty(); + + // A rectangle is empty if there are no points inside of it. A degenerate + // rectangle where the corners are coincident has zero area but is not empty. + bool IsEmpty() const { return min_.x() > max_.x() || min_.y() > max_.y(); } + + bool operator==(const Rectangle&) const; + bool operator!=(const Rectangle&) const; + + // Width and height are both max - min, which may be negative if SetEmpty() + // was called or the user explicity set the min and max points. + T Width() const { return max_.x() - min_.x(); } + T Height() const { return max_.y() - min_.y(); } + + // Computes the area, which is negative if the width xor height is negative. + // The value is undefined if SetEmpty() is called. + // Watch out for large integer rectangles because the area may overflow. 
+ T Area() const { return Width() * Height(); } + + // Accessors are provided for both points and sides. + const T& xmin() const { return min_.x(); } + const T& xmax() const { return max_.x(); } + const T& ymin() const { return min_.y(); } + const T& ymax() const { return max_.y(); } + + // Returns the min and max corner points. + const Point2& min_xy() const { return min_; } + const Point2& max_xy() const { return max_; } + + // Sets the geometry of the rectangle given two points. + // The resulting rectangle will have non-negative width and height. + void Set(const Point2& p0, const Point2& p1); + + // Same as above using vectors as input. + void Set(const Vector2& p0, const Vector2& p1); + + // Sets the geometry of the rectangle given a minimum point and dimensions. + void Set(const T& x, const T& y, const T& width, const T& height); + + // Sets the min and max values, and min greater than max is allowable, + // but the user has to be aware of the consequences such as negative width + // and height. Both point and side accessors are provided. + void set_xmin(const T& x) { min_.set_x(x); } + void set_xmax(const T& x) { max_.set_x(x); } + void set_ymin(const T& y) { min_.set_y(y); } + void set_ymax(const T& y) { max_.set_y(y); } + + void set_min_xy(const Point2& p) { min_.Set(p.x(), p.y()); } + void set_max_xy(const Point2& p) { max_.Set(p.x(), p.y()); } + + // Expands a rectangle to contain a point or vector. + void Expand(const T& x, const T& y); + void Expand(const Point2& p); + void Expand(const Vector2& p); + + // Expands a rectangle to contain another rectangle. + void Expand(const Rectangle& other); + + // Returns the union of this rectangle with another rectangle, which + // is the smallest rectangle that contains both rectangles. + Rectangle Union(const Rectangle& other) const; + + // Returns the intersection of this rectangle with another rectangle. + // If the intersection is empty, returns a rectangle initialized by + // SetEmpty(). 
+ Rectangle Intersect(const Rectangle& other) const; + + // Tests if this rectangle has a non-empty intersection with another rectangle + // including the boundary. + bool Intersects(const Rectangle& other) const; + + // Tests if a point is inside or on any of the 4 edges of the rectangle. + bool Contains(const T& x, const T& y) const; + bool Contains(const Point2& pt) const; + bool Contains(const Vector2& pt) const; + + // Tests if a rectangle is inside or on any of the 4 edges of the rectangle. + bool Contains(const Rectangle& other) const; + + // Translates this rectangle by a vector. + void Translate(const Vector2& v); + + // Adds a border around the rectangle by subtracting the border size from the + // min point and adding it to the max point. The border size can be + // negative. + void AddBorder(const T& border_size); + + // Debug printing. + friend std::ostream& operator<<(std::ostream&, const Rectangle&); + + private: + Point2 min_; + Point2 max_; +}; + +// +// Inline method definitions. These are not placed in the definition of the +// class to keep the class interface more readable. +// + +template +Rectangle::Rectangle(const Point2& p0, const Point2& p1) { + Set(p0, p1); +} + +template +Rectangle::Rectangle(const Vector2& p0, const Vector2& p1) { + Set(p0, p1); +} + +template +Rectangle::Rectangle(const T& x, const T& y, const T& width, + const T& height) { + Set(x, y, width, height); +} + +// The general version works only when T models Integer (there are more +// integer classes than float classes). 
+template +void Rectangle::SetEmpty() { + T min_value = std::numeric_limits::min(); + T max_value = std::numeric_limits::max(); + min_.Set(max_value, max_value); + max_.Set(min_value, min_value); +} + +template <> +inline void Rectangle::SetEmpty() { + float max_value = std::numeric_limits::max(); + min_.Set(max_value, max_value); + max_.Set(-max_value, -max_value); +} + +template <> +inline void Rectangle::SetEmpty() { + double max_value = std::numeric_limits::max(); + min_.Set(max_value, max_value); + max_.Set(-max_value, -max_value); +} + +template +bool Rectangle::operator==(const Rectangle& other) const { + return min_ == other.min_ && max_ == other.max_; +} + +template +bool Rectangle::operator!=(const Rectangle& other) const { + return !(*this == other); +} + +template +void Rectangle::Set(const Vector2& p0, const Vector2& p1) { + if (p0[0] <= p1[0]) + min_.set_x(p0[0]), max_.set_x(p1[0]); + else + max_.set_x(p0[0]), min_.set_x(p1[0]); + + if (p0[1] <= p1[1]) + min_.set_y(p0[1]), max_.set_y(p1[1]); + else + max_.set_y(p0[1]), min_.set_y(p1[1]); +} + +template +void Rectangle::Set(const Point2& p0, const Point2& p1) { + Set(p0.ToVector(), p1.ToVector()); +} + +template +void Rectangle::Set(const T& x, const T& y, const T& width, + const T& height) { + min_.Set(x, y); + max_.Set(x + width, y + height); +} + +template +void Rectangle::Expand(const T& x, const T& y) { + min_.Set(std::min(x, xmin()), std::min(y, ymin())); + max_.Set(std::max(x, xmax()), std::max(y, ymax())); +} + +template +void Rectangle::Expand(const Point2& p) { + Expand(p.x(), p.y()); +} + +template +void Rectangle::Expand(const Vector2& v) { + Expand(v[0], v[1]); +} + +template +void Rectangle::Expand(const Rectangle& other) { + Expand(other.min_); + Expand(other.max_); +} + +template +void Rectangle::Translate(const Vector2& v) { + min_ += v; + max_ += v; +} + +template +bool Rectangle::Contains(const T& x, const T& y) const { + return x >= xmin() && x <= xmax() && y >= ymin() && y <= 
ymax(); +} + +template +bool Rectangle::Contains(const Point2& p) const { + return Contains(p.x(), p.y()); +} + +template +bool Rectangle::Contains(const Vector2& v) const { + return Contains(v[0], v[1]); +} + +template +bool Rectangle::Contains(const Rectangle& r) const { + return Contains(r.min_) && Contains(r.max_); +} + +template +Rectangle Rectangle::Union(const Rectangle& r) const { + return Rectangle( + Point2(std::min(xmin(), r.xmin()), std::min(ymin(), r.ymin())), + Point2(std::max(xmax(), r.xmax()), std::max(ymax(), r.ymax()))); +} + +template +Rectangle Rectangle::Intersect(const Rectangle& r) const { + Point2 pmin(std::max(xmin(), r.xmin()), std::max(ymin(), r.ymin())); + Point2 pmax(std::min(xmax(), r.xmax()), std::min(ymax(), r.ymax())); + + if (pmin.x() > pmax.x() || pmin.y() > pmax.y()) + return Rectangle(); + else + return Rectangle(pmin, pmax); +} + +template +bool Rectangle::Intersects(const Rectangle& r) const { + return !(IsEmpty() || r.IsEmpty() || r.xmax() < xmin() || xmax() < r.xmin() || + r.ymax() < ymin() || ymax() < r.ymin()); +} + +template +void Rectangle::AddBorder(const T& border_size) { + min_.Set(xmin() - border_size, ymin() - border_size); + max_.Set(xmax() + border_size, ymax() + border_size); +} + +template +std::ostream& operator<<(std::ostream& out, const Rectangle& r) { + out << "[(" << r.xmin() << ", " << r.ymin() << "), (" << r.xmax() << ", " + << r.ymax() << ")]"; + return out; +} + +template +class Rectangle; + +typedef Rectangle Rectangle_b; +typedef Rectangle Rectangle_i; +typedef Rectangle Rectangle_f; +typedef Rectangle Rectangle_d; + +#endif // MEDIAPIPE_DEPS_RECTANGLE_H_ diff --git a/mediapipe/framework/deps/registration.cc b/mediapipe/framework/deps/registration.cc new file mode 100644 index 000000000..f12a3834f --- /dev/null +++ b/mediapipe/framework/deps/registration.cc @@ -0,0 +1,40 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/registration.h" + +namespace mediapipe { + +namespace { + +constexpr char const* kTopNamespaces[] = { + "mediapipe", +}; + +template +inline size_t array_size(T (&arr)[SIZE]) { + return SIZE; +} + +} // namespace + +/*static*/ +const std::unordered_set& NamespaceWhitelist::TopNamespaces() { + static std::unordered_set* result = + new std::unordered_set( + kTopNamespaces, kTopNamespaces + array_size(kTopNamespaces)); + return *result; +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h new file mode 100644 index 000000000..58ada0130 --- /dev/null +++ b/mediapipe/framework/deps/registration.h @@ -0,0 +1,387 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_DEPS_REGISTRATION_H_ +#define MEDIAPIPE_DEPS_REGISTRATION_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/thread_annotations.h" +#include "absl/meta/type_traits.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/deps/registration_token.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +// Usage: +// +// === Defining a registry ================================================ +// +// class Widget {}; +// +// using WidgetRegistry = +// GlobalFactoryRegistry, // return +// unique_ptr, const Thing*> // args +// +// === Registering an implementation ======================================= +// +// class MyWidget : public Widget { +// static unique_ptr Create(unique_ptr arg, +// const Thing* thing) { +// return MakeUnique(std::move(arg), thing); +// } +// ... +// }; +// +// REGISTER_FACTORY_FUNCTION_QUALIFIED( +// WidgetRegistry, widget_registration, +// ::my_ns::MyWidget, MyWidget::Create); +// +// === Using std::function ================================================= +// +// class Client {}; +// +// using ClientRegistry = +// GlobalFactoryRegistry<::mediapipe::StatusOr>; +// +// class MyClient : public Client { +// public: +// MyClient(unique_ptr backend) +// : backend_(std::move(backend)) {} +// private: +// const std::unique_ptr backend_; +// }; +// +// // Any std::function that returns a Client is valid to pass here. Below, +// // we use a lambda. 
+// REGISTER_FACTORY_FUNCTION_QUALIFIED( +// ClientRegistry, client_registration, +// ::my_ns::MyClient, +// []() { +// auto backend = absl::make_unique("/path/to/backend"); +// const ::mediapipe::Status status = backend->Init(); +// if (!status.ok()) { +// return status; +// } +// std::unique_ptr client +// = absl::make_unique(std::move(backend)); +// return client; +// }); +// +// === Using the registry to create instances ============================== +// +// // Registry will return ::mediapipe::StatusOr +// ::mediapipe::StatusOr> s_or_widget = +// WidgetRegistry::CreateByName( +// "my_ns.MyWidget", std::move(gadget), thing); +// // Registry will return NOT_FOUND if the name is unknown. +// if (!s_or_widget.ok()) ... // handle error +// DoStuffWithWidget(std::move(s_or_widget).ValueOrDie()); +// +// // It's also possible to find an instance by name within a source namespace. +// auto s_or_widget = WidgetRegistry::CreateByNameInNamespace( +// "my_ns.sub_namespace", "MyWidget"); +// +// // It's also possible to just check if a name is registered without creating +// // an instance. +// bool registered = WidgetRegistry::IsRegistered("my_ns::MyWidget"); +// +// // It's also possible to iterate through all registered function names. +// // This might be useful if clients outside of your codebase are registering +// // plugins. +// for (const auto& name : WidgetRegistry::GetRegisteredNames()) { +// ::mediapipe::StatusOr> s_or_widget = +// WidgetRegistry::CreateByName(name, std::move(gadget), thing); +// ... +// } +// +// === Injecting instances for testing ===================================== +// +// Unregister unregisterer(WidgetRegistry::Register( +// "MockWidget", +// [](unique_ptr arg, const Thing* thing) { +// ... 
+// })); + +namespace registration_internal { +constexpr char kCxxSep[] = "::"; +constexpr char kNameSep[] = "."; + +template +struct WrapStatusOr { + using type = ::mediapipe::StatusOr; +}; + +// Specialization to avoid double-wrapping types that are already StatusOrs. +template +struct WrapStatusOr<::mediapipe::StatusOr> { + using type = ::mediapipe::StatusOr; +}; +} // namespace registration_internal + +class NamespaceWhitelist { + public: + static const std::unordered_set& TopNamespaces(); +}; + +template +class FunctionRegistry { + public: + using Function = std::function; + using ReturnType = typename registration_internal::WrapStatusOr::type; + + FunctionRegistry() {} + FunctionRegistry(const FunctionRegistry&) = delete; + FunctionRegistry& operator=(const FunctionRegistry&) = delete; + + RegistrationToken Register(const std::string& name, Function func) + LOCKS_EXCLUDED(lock_) { + std::string normalized_name = GetNormalizedName(name); + absl::WriterMutexLock lock(&lock_); + std::string adjusted_name = GetAdjustedName(normalized_name); + if (adjusted_name != normalized_name) { + functions_.insert(std::make_pair(adjusted_name, func)); + } + if (functions_.insert(std::make_pair(normalized_name, std::move(func))) + .second) { + return RegistrationToken( + [this, normalized_name]() { Unregister(normalized_name); }); + } + LOG(FATAL) << "Function with name " << name << " already registered."; + return RegistrationToken([]() {}); + } + + // Force 'args' to be deduced by templating the function, instead of just + // accepting Args. This is necessary to make 'args' a forwarding reference as + // opposed to a plain rvalue reference. + // https://isocpp.org/blog/2012/11/universal-references-in-c11-scott-meyers + // + // The absl::enable_if_t is used to disable this method if Args2 are not + // convertible to Args. This will allow the compiler to identify the offending + // line (i.e. 
the line where the method is called) in the first error message, + // rather than nesting it multiple levels down the error stack. + template , + std::tuple>::value, + int> = 0> + ReturnType Invoke(const std::string& name, Args2&&... args) + LOCKS_EXCLUDED(lock_) { + Function function; + { + absl::ReaderMutexLock lock(&lock_); + auto it = functions_.find(name); + if (it == functions_.end()) { + return ::mediapipe::NotFoundError("No registered object with name: " + + name); + } + function = it->second; + } + return function(std::forward(args)...); + } + + // Invokes the specified factory function and returns the result. + // Namespaces in |name| and |ns| are separated by kNameSep. + template + ReturnType Invoke(const std::string& ns, const std::string& name, + Args2&&... args) LOCKS_EXCLUDED(lock_) { + return Invoke(GetQualifiedName(ns, name), args...); + } + + // Note that it's possible for registered implementations to be subsequently + // unregistered, though this will never happen with registrations made via + // MEDIAPIPE_REGISTER_FACTORY_FUNCTION. + bool IsRegistered(const std::string& name) const LOCKS_EXCLUDED(lock_) { + absl::ReaderMutexLock lock(&lock_); + return functions_.count(name) != 0; + } + + // Returns true if the specified factory function is available. + // Namespaces in |name| and |ns| are separated by kNameSep. + bool IsRegistered(const std::string& ns, const std::string& name) const + LOCKS_EXCLUDED(lock_) { + return IsRegistered(GetQualifiedName(ns, name)); + } + + // Returns a vector of all registered function names. + // Note that it's possible for registered implementations to be subsequently + // unregistered, though this will never happen with registrations made via + // MEDIAPIPE_REGISTER_FACTORY_FUNCTION. 
+ std::unordered_set GetRegisteredNames() const + LOCKS_EXCLUDED(lock_) { + absl::ReaderMutexLock lock(&lock_); + std::unordered_set names; + std::for_each(functions_.cbegin(), functions_.cend(), + [&names](const std::pair& pair) { + names.insert(pair.first); + }); + return names; + } + + // Normalizes a C++ qualified name. Validates the name qualification. + // The name must be either unqualified or fully qualified with a leading "::". + // The leading "::" in a fully qualified name is stripped. + std::string GetNormalizedName(const std::string& name) { + constexpr auto kCxxSep = registration_internal::kCxxSep; + std::vector names = absl::StrSplit(name, kCxxSep); + if (names[0].empty()) { + names.erase(names.begin()); + } else { + CHECK_EQ(1, names.size()) + << "A registered class name must be either fully qualified " + << "with a leading :: or unqualified, got: " << name << "."; + } + return absl::StrJoin(names, kCxxSep); + } + + // Returns the registry key for a name specified within a namespace. + // Namespaces are separated by kNameSep. 
+ std::string GetQualifiedName(const std::string& ns, + const std::string& name) const { + constexpr auto kCxxSep = registration_internal::kCxxSep; + constexpr auto kNameSep = registration_internal::kNameSep; + std::vector names = absl::StrSplit(name, kNameSep); + if (names[0].empty()) { + names.erase(names.begin()); + return absl::StrJoin(names, kCxxSep); + } + std::string cxx_name = absl::StrJoin(names, kCxxSep); + if (ns.empty()) { + return cxx_name; + } + std::vector spaces = absl::StrSplit(ns, kNameSep); + absl::ReaderMutexLock lock(&lock_); + while (!spaces.empty()) { + std::string cxx_ns = absl::StrJoin(spaces, kCxxSep); + std::string qualified_name = absl::StrCat(cxx_ns, kCxxSep, cxx_name); + if (functions_.count(qualified_name)) { + return qualified_name; + } + spaces.pop_back(); + } + return cxx_name; + } + + private: + mutable absl::Mutex lock_; + std::unordered_map functions_ GUARDED_BY(lock_); + + // For names included in NamespaceWhitelist, strips the namespace. + std::string GetAdjustedName(const std::string& name) { + constexpr auto kCxxSep = registration_internal::kCxxSep; + std::vector names = absl::StrSplit(name, kCxxSep); + std::string base_name = names.back(); + names.pop_back(); + std::string ns = absl::StrJoin(names, kCxxSep); + if (NamespaceWhitelist::TopNamespaces().count(ns)) { + return base_name; + } + return name; + } + + void Unregister(const std::string& name) { + absl::WriterMutexLock lock(&lock_); + std::string adjusted_name = GetAdjustedName(name); + if (adjusted_name != name) { + functions_.erase(adjusted_name); + } + functions_.erase(name); + } +}; + +template +class GlobalFactoryRegistry { + using Functions = FunctionRegistry; + + public: + static RegistrationToken Register(const std::string& name, + typename Functions::Function func) { + return functions()->Register(name, std::move(func)); + } + + // Same as CreateByNameInNamespace but without a namespace. 
+ template + static typename Functions::ReturnType CreateByName(const std::string& name, + Args2&&... args) { + return CreateByNameInNamespace("", name, std::forward(args)...); + } + + // Same as IsRegistered(ns, name) but without a namespace. + static bool IsRegistered(const std::string& name) { + return functions()->IsRegistered("", name); + } + + static std::unordered_set GetRegisteredNames() { + return functions()->GetRegisteredNames(); + } + + // Invokes the specified factory function and returns the result. + // Namespaces in |name| and |ns| are separated by kNameSep. + // See comments re: use of Args2 and absl::enable_if_t on Invoke. + template , + std::tuple>::value, + int> = 0> + static typename Functions::ReturnType CreateByNameInNamespace( + const std::string& ns, const std::string& name, Args2&&... args) { + return functions()->Invoke(ns, name, std::forward(args)...); + } + + // Returns true if the specified factory function is available. + // Namespaces in |name| and |ns| are separated by kNameSep. + static bool IsRegistered(const std::string& ns, const std::string& name) { + return functions()->IsRegistered(ns, name); + } + + // Returns the factory function registry singleton. + static Functions* functions() { + static auto* functions = new Functions(); + return functions; + } + + private: + GlobalFactoryRegistry() = delete; +}; + +// Two levels of macros are required to convert __LINE__ into a std::string +// containing the line number. +#define REGISTRY_STATIC_VAR_INNER(var_name, line) var_name##_##line##__ +#define REGISTRY_STATIC_VAR(var_name, line) \ + REGISTRY_STATIC_VAR_INNER(var_name, line) + +#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION(RegistryType, name, ...) \ + static auto* REGISTRY_STATIC_VAR(registration_##name, __LINE__) = \ + new ::mediapipe::RegistrationToken( \ + RegistryType::Register(#name, __VA_ARGS__)) + +#define REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, name, ...) 
\ + static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) = \ + new ::mediapipe::RegistrationToken( \ + RegistryType::Register(#name, __VA_ARGS__)) + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_REGISTRATION_H_ diff --git a/mediapipe/framework/deps/registration_token.cc b/mediapipe/framework/deps/registration_token.cc new file mode 100644 index 000000000..04855480f --- /dev/null +++ b/mediapipe/framework/deps/registration_token.cc @@ -0,0 +1,89 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/deps/registration_token.h" + +#include + +namespace mediapipe { +RegistrationToken::RegistrationToken(std::function unregisterer) + : unregister_function_(std::move(unregisterer)) {} + +RegistrationToken::RegistrationToken(RegistrationToken&& rhs) + : unregister_function_(std::move(rhs.unregister_function_)) { + rhs.unregister_function_ = nullptr; +} + +RegistrationToken& RegistrationToken::operator=(RegistrationToken&& rhs) { + if (&rhs != this) { + unregister_function_ = std::move(rhs.unregister_function_); + rhs.unregister_function_ = nullptr; + } + return *this; +} + +void RegistrationToken::Unregister() { + if (unregister_function_ != nullptr) { + unregister_function_(); + unregister_function_ = nullptr; + } +} + +namespace { +struct CombinedToken { + void operator()() { + for (auto& f : functions) { + f(); + } + } + std::vector> functions; +}; +} // anonymous namespace + +// static +RegistrationToken RegistrationToken::Combine( + std::vector tokens) { + CombinedToken combined; + + // When vector grows, it only moves elements if the move constructor is marked + // noexcept (or if the element isn't copyable). In related news, function's + // move constructor is not marked noexcept. By reserving the correct amount of + // space up front, we remove the need for the vector to grow, and thus + // eliminate copies. 
+ combined.functions.reserve(tokens.size()); + for (RegistrationToken& token : tokens) { + combined.functions.push_back(std::move(token.unregister_function_)); + } + return RegistrationToken(std::move(combined)); +} + +Unregister::Unregister(RegistrationToken token) : token_(std::move(token)) {} + +Unregister::~Unregister() { token_.Unregister(); } + +Unregister::Unregister(Unregister&& rhs) : token_(std::move(rhs.token_)) {} +Unregister& Unregister::operator=(Unregister&& rhs) { + if (&rhs != this) { + token_.Unregister(); + token_ = std::move(rhs.token_); + } + return *this; +} + +void Unregister::Reset(RegistrationToken token) { + token_.Unregister(); + token_ = std::move(token); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/registration_token.h b/mediapipe/framework/deps/registration_token.h new file mode 100644 index 000000000..03597d230 --- /dev/null +++ b/mediapipe/framework/deps/registration_token.h @@ -0,0 +1,117 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_REGISTRATION_TOKEN_H_ +#define MEDIAPIPE_DEPS_REGISTRATION_TOKEN_H_ + +#include +#include + +namespace mediapipe { +// RegistrationToken is a generic class that represents a registration that +// can be later undone, via a call to Unregister(). 
+// +// It is generally a good idea for registration methods, such as +// RegisterListener(X) to return ways to undo the registration (for instance if +// X goes out of scope). +// RegistrationToken is a good candidate as a return value for those methods. +// +// Example usage: +// +// RegistrationToken token = MyCancellableRegisterListener(foo); +// ... +// do something +// +// token.Unregister(); +// +// +// There is also a Unregister RAII helper below that automatically unregisters +// a token when it goes out of scope: +// +// { +// Unregister unregisterer(MyCancellableRegisterListener(foo)); +// ... +// do something +// +// } // unregisterer goes out of scope, we are unregistered. +// +// +// Implementation: tokens are generic, they just accept a std::function +// that does the actual unregistration. It is up to each registration system to +// pass the function that corresponds to their own implementation for +// unregistering things. +// +// In that regard, tokens are basically a glorified unique_ptr. +// The main advantage is that they guarantee the function can be called only +// once, and naming is also much clearer (Unregister versus operator()). +// +// Tokens are not copyable but they are movable, which reflects the fact that +// there should only ever be one token in charge of a particular registration +// at any time (else there could be confusion, who is in charge of +// unregistering). +// +// This class is thread compatible. +class RegistrationToken { + public: + explicit RegistrationToken(std::function unregisterer); + + // It is useful to have an empty constructor for when we want to declare a + // token, and assign it later. 
+ RegistrationToken() {} + + RegistrationToken(const RegistrationToken&) = delete; + RegistrationToken& operator=(const RegistrationToken&) = delete; + + RegistrationToken(RegistrationToken&& rhs); + RegistrationToken& operator=(RegistrationToken&& rhs); + + // Unregisters the registration for which this token is in charge, and voids + // the token. It is safe to call this more than once, but further calls are + // guaranteed to be noop. + void Unregister(); + + // Returns a token whose Unregister() will Unregister() all . + static RegistrationToken Combine(std::vector tokens); + + private: + std::function unregister_function_ = nullptr; +}; + +// RAII class for registration tokens: it calls Unregister() when it goes out +// of scope. +class Unregister { + public: + // Useful to have an empty constructor for when we want to assign it later. + // The default is an empty token that does nothing. + Unregister() : token_() {} + explicit Unregister(RegistrationToken token); + ~Unregister(); + + Unregister(const Unregister&) = delete; + Unregister& operator=(const Unregister&) = delete; + + Unregister(Unregister&& rhs); + Unregister& operator=(Unregister&& rhs); + + // Similar to unique_ptr.reset() and the likes: this will unregister the + // current token if any, and then assume registration ownership of this new + // . + void Reset(RegistrationToken token = RegistrationToken()); + + private: + RegistrationToken token_; +}; +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_REGISTRATION_TOKEN_H_ diff --git a/mediapipe/framework/deps/registration_token_test.cc b/mediapipe/framework/deps/registration_token_test.cc new file mode 100644 index 000000000..51d5477bc --- /dev/null +++ b/mediapipe/framework/deps/registration_token_test.cc @@ -0,0 +1,126 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/framework/deps/registration_token.h"
+
+#include <functional>
+#include <utility>
+
+#include "mediapipe/framework/port/gtest.h"
+
+namespace mediapipe {
+namespace {
+class RegistrationTokenTest : public testing::Test {
+ public:
+  void CallFirst() { ++called_1_; }
+
+  void CallSecond() { ++called_2_; }
+
+  void CallThird() { ++called_3_; }
+
+ protected:
+  int called_1_{0};
+  int called_2_{0};
+  int called_3_{0};
+};
+
+// Trivial unregistration test.
+TEST_F(RegistrationTokenTest, TestUnregister) {
+  std::function<void()> caller = [this]() {
+    RegistrationTokenTest::CallFirst();
+  };
+  RegistrationToken token(caller);
+  ASSERT_EQ(0, called_1_);
+  token.Unregister();
+  ASSERT_EQ(1, called_1_);
+
+  // Check that further calls have no effect.
+  token.Unregister();
+  token.Unregister();
+  ASSERT_EQ(1, called_1_);
+
+  // Test the RAII class.
+  ASSERT_EQ(0, called_2_);
+  RegistrationToken token2([this]() { RegistrationTokenTest::CallSecond(); });
+  {
+    Unregister t(std::move(token2));
+    ASSERT_EQ(0, called_2_);
+  }
+
+  // It was called since the Unregister() went out of scope.
+  ASSERT_EQ(1, called_2_);
+}
+
+// Tests that the result of a Combine() token does unregister all combined
+// tokens.
+TEST_F(RegistrationTokenTest, TestCombine) {
+  std::function<void()> caller_1 = [this]() {
+    RegistrationTokenTest::CallFirst();
+  };
+  std::function<void()> caller_2 = [this]() {
+    RegistrationTokenTest::CallSecond();
+  };
+  std::function<void()> caller_3 = [this]() {
+    RegistrationTokenTest::CallThird();
+  };
+
+  RegistrationToken token_1(caller_1);
+  RegistrationToken token_2(caller_2);
+  RegistrationToken token_3(caller_3);
+
+  ASSERT_EQ(0, called_1_);
+  ASSERT_EQ(0, called_2_);
+  ASSERT_EQ(0, called_3_);
+
+  std::vector<RegistrationToken> tokens;
+  tokens.emplace_back(std::move(token_1));
+  tokens.emplace_back(std::move(token_2));
+  tokens.emplace_back(std::move(token_3));
+
+  RegistrationToken combined = RegistrationToken::Combine(std::move(tokens));
+  combined.Unregister();
+
+  ASSERT_EQ(1, called_1_);
+  ASSERT_EQ(1, called_2_);
+  ASSERT_EQ(1, called_3_);
+
+  // Check that the original tokens were invalidated by their move and do
+  // nothing.
+  token_1.Unregister();
+  token_2.Unregister();
+  token_3.Unregister();
+
+  ASSERT_EQ(1, called_1_);
+  ASSERT_EQ(1, called_2_);
+  ASSERT_EQ(1, called_3_);
+}
+
+TEST_F(RegistrationTokenTest, TestMove) {
+  RegistrationToken token([this] { CallFirst(); });
+  token = RegistrationToken([this] { CallFirst(); });
+  EXPECT_EQ(0, called_1_);
+
+  Unregister unreg;
+  unreg = Unregister(std::move(token));
+  EXPECT_EQ(0, called_1_);
+  unreg = Unregister(RegistrationToken([this] { CallFirst(); }));
+  EXPECT_EQ(1, called_1_);
+  unreg = Unregister(RegistrationToken([this] { CallFirst(); }));
+  EXPECT_EQ(2, called_1_);
+  unreg.Reset();
+  EXPECT_EQ(3, called_1_);
+}
+
+}  // namespace
+}  // namespace mediapipe
diff --git a/mediapipe/framework/deps/ret_check.cc b/mediapipe/framework/deps/ret_check.cc
new file mode 100644
index 000000000..65fdcc033
--- /dev/null
+++ b/mediapipe/framework/deps/ret_check.cc
@@ -0,0 +1,39 @@
+// Copyright 2019 The MediaPipe Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/ret_check.h" + +namespace mediapipe { + +::mediapipe::StatusBuilder RetCheckFailSlowPath( + ::mediapipe::source_location location) { + // TODO Implement LogWithStackTrace(). + return ::mediapipe::InternalErrorBuilder(location) + << "RET_CHECK failure (" << location.file_name() << ":" + << location.line() << ") "; +} + +::mediapipe::StatusBuilder RetCheckFailSlowPath( + ::mediapipe::source_location location, const char* condition) { + return ::mediapipe::RetCheckFailSlowPath(location) << condition; +} + +::mediapipe::StatusBuilder RetCheckFailSlowPath( + ::mediapipe::source_location location, const char* condition, + const ::mediapipe::Status& status) { + return ::mediapipe::RetCheckFailSlowPath(location) + << condition << " returned " << status << " "; +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/ret_check.h b/mediapipe/framework/deps/ret_check.h new file mode 100644 index 000000000..ceecd3818 --- /dev/null +++ b/mediapipe/framework/deps/ret_check.h @@ -0,0 +1,65 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_RET_CHECK_H_ +#define MEDIAPIPE_DEPS_RET_CHECK_H_ + +#include "absl/base/optimization.h" +#include "mediapipe/framework/deps/status_builder.h" +#include "mediapipe/framework/deps/status_macros.h" + +namespace mediapipe { +// Returns a StatusBuilder that corresponds to a `RET_CHECK` failure. +::mediapipe::StatusBuilder RetCheckFailSlowPath( + ::mediapipe::source_location location); + +// Returns a StatusBuilder that corresponds to a `RET_CHECK` failure. +::mediapipe::StatusBuilder RetCheckFailSlowPath( + ::mediapipe::source_location location, const char* condition); + +// Returns a StatusBuilder that corresponds to a `RET_CHECK` failure. 
+::mediapipe::StatusBuilder RetCheckFailSlowPath( + ::mediapipe::source_location location, const char* condition, + const ::mediapipe::Status& status); + +inline StatusBuilder RetCheckImpl(const ::mediapipe::Status& status, + const char* condition, + ::mediapipe::source_location location) { + if (ABSL_PREDICT_TRUE(status.ok())) + return ::mediapipe::StatusBuilder(OkStatus(), location); + return RetCheckFailSlowPath(location, condition, status); +} + +} // namespace mediapipe + +#define RET_CHECK(cond) \ + while (ABSL_PREDICT_FALSE(!(cond))) \ + return ::mediapipe::RetCheckFailSlowPath(MEDIAPIPE_LOC, #cond) + +#define RET_CHECK_OK(status) \ + RETURN_IF_ERROR(::mediapipe::RetCheckImpl((status), #status, MEDIAPIPE_LOC)) + +#define RET_CHECK_FAIL() return ::mediapipe::RetCheckFailSlowPath(MEDIAPIPE_LOC) + +#define MEDIAPIPE_INTERNAL_RET_CHECK_OP(name, op, lhs, rhs) \ + RET_CHECK((lhs)op(rhs)) + +#define RET_CHECK_EQ(lhs, rhs) MEDIAPIPE_INTERNAL_RET_CHECK_OP(EQ, ==, lhs, rhs) +#define RET_CHECK_NE(lhs, rhs) MEDIAPIPE_INTERNAL_RET_CHECK_OP(NE, !=, lhs, rhs) +#define RET_CHECK_LE(lhs, rhs) MEDIAPIPE_INTERNAL_RET_CHECK_OP(LE, <=, lhs, rhs) +#define RET_CHECK_LT(lhs, rhs) MEDIAPIPE_INTERNAL_RET_CHECK_OP(LT, <, lhs, rhs) +#define RET_CHECK_GE(lhs, rhs) MEDIAPIPE_INTERNAL_RET_CHECK_OP(GE, >=, lhs, rhs) +#define RET_CHECK_GT(lhs, rhs) MEDIAPIPE_INTERNAL_RET_CHECK_OP(GT, >, lhs, rhs) + +#endif // MEDIAPIPE_DEPS_RET_CHECK_H_ diff --git a/mediapipe/framework/deps/safe_int.h b/mediapipe/framework/deps/safe_int.h new file mode 100644 index 000000000..94aaeb8b1 --- /dev/null +++ b/mediapipe/framework/deps/safe_int.h @@ -0,0 +1,310 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// A "safe int" is a StrongInt which does additional validation of the +// various arithmetic and logical operations, and reacts to overflows and +// underflow and invalid operations. You can define the "safe int" types +// to react to errors in pre-defined ways or you can define your own policy +// classes. +// +// Usage: +// MEDIAPIPE_DEFINE_SAFE_INT_TYPE(Name, NativeType, PolicyType); +// +// Defines a new StrongInt type named 'Name' in the current namespace with +// underflow/overflow checking on all operations, with configurable error +// policy. +// +// Name: The desired name for the new StrongInt typedef. Must be unique +// within the current namespace. +// NativeType: The primitive integral type this StrongInt will hold, as +// defined by std::is_integral (see ). +// PolicyType: The type of policy used by this StrongInt type. A few +// pre-built policy types are provided here, but the caller can +// define any custom policy they desire. +// +// PolicyTypes: +// LogFatalOnError: LOG(FATAL) when a error occurs. + +#ifndef MEDIAPIPE_DEPS_SAFE_INT_H_ +#define MEDIAPIPE_DEPS_SAFE_INT_H_ + +#include + +#include +#include + +#include "mediapipe/framework/deps/strong_int.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { +namespace intops { + +// A StrongInt validator class for "safe" type enforcement. For signed types, +// this checks for overflows and underflows as well as undefined- or +// implementation-defined behaviors. 
For unsigned type, this further disallows +// operations that would take advantage of unsigned wrap-around behavior and +// operations which would discard data unexpectedly. This assumes two's +// complement representations, and that division truncates towards zero. +// +// For some more on overflow safety, see: +// https://www.securecoding.cert.org/confluence/display/seccode/INT32-C.+Ensure+that+operations+on+signed+integers+do+not+result+in+overflow?showComments=false +template +class SafeIntStrongIntValidator { + private: + template + static void SanityCheck() { + // Check that the underlying integral type provides a range that is + // compatible with two's complement. + if (std::numeric_limits::is_signed) { + CHECK_EQ(-1, + std::numeric_limits::min() + std::numeric_limits::max()) + << "unexpected integral bounds"; + } + + // Check that division truncates towards 0 (implementation defined in + // C++'03, but standard in C++'11). + CHECK_EQ(12, 127 / 10) << "division does not truncate towards 0"; + CHECK_EQ(-12, -127 / 10) << "division does not truncate towards 0"; + CHECK_EQ(-12, 127 / -10) << "division does not truncate towards 0"; + CHECK_EQ(12, -127 / -10) << "division does not truncate towards 0"; + } + + public: + template + static void ValidateInit(U arg) { + // Do some sanity checks before proceeding. + SanityCheck(); + + // If the argument is floating point, we can do a simple check to make + // sure the value is in range. It is undefined behavior to convert to int + // from a float that is out of range. + if (std::is_floating_point::value) { + if (arg < std::numeric_limits::min() || + arg > std::numeric_limits::max()) { + ErrorType::Error("SafeInt: init from out of bounds float", arg, "="); + } + } else { + // If the initial value (type U) is changed by being converted to and from + // the native type (type T), then it must be out of bounds for type T. 
+ // + // If T is unsigned and the argument is negative, then it is clearly out + // of bounds for type T. + // + // If the initial value is greater than the max value for type T, then it + // is clearly out of bounds for type T. Before we check that, though, we + // must ensure that the initial value is positive, or else we could get + // unwanted promotion to unsigned, making the test wrong. If the initial + // value is negative, it can't be larger than the max value for type T. + if ((static_cast(static_cast(arg)) != arg) || + (!std::numeric_limits::is_signed && arg < 0) || + (arg > 0 && arg > std::numeric_limits::max())) { + ErrorType::Error("SafeInt: init from out of bounds value", arg, "="); + } + } + } + template + static void ValidateNegate( // Signed types only. + typename std::enable_if::is_signed, T>::type + value) { + if (value == std::numeric_limits::min()) { + ErrorType::Error("SafeInt: overflow", value, -1, "*"); + } + } + template + static void ValidateBitNot( // Unsigned types only. + typename std::enable_if::is_signed, T>::type + value) { + // Do nothing. + } + template + static void ValidateAdd(T lhs, T rhs) { + // The same logic applies to signed and unsigned types. + if ((rhs > 0) && (lhs > (std::numeric_limits::max() - rhs))) { + ErrorType::Error("SafeInt: overflow", lhs, rhs, "+"); + } else if ((rhs < 0) && (lhs < (std::numeric_limits::min() - rhs))) { + ErrorType::Error("SafeInt: underflow", lhs, rhs, "+"); + } + } + template + static void ValidateSubtract(T lhs, T rhs) { + // The same logic applies to signed and unsigned types. + if ((rhs > 0) && (lhs < (std::numeric_limits::min() + rhs))) { + ErrorType::Error("SafeInt: underflow", lhs, rhs, "-"); + } else if ((rhs < 0) && (lhs > (std::numeric_limits::max() + rhs))) { + ErrorType::Error("SafeInt: overflow", lhs, rhs, "-"); + } + } + template + static void ValidateMultiply(T lhs, U rhs) { + if (!std::numeric_limits::is_signed) { + // Unsigned types only. 
+ if (rhs < 0) { + ErrorType::Error("SafeInt: negation of unsigned type", lhs, rhs, "*"); + } + } + // Multiplication by 0 can never overflow/underflow, but handling 0 makes + // the below code more complex. + if (lhs == 0 || rhs == 0) { + return; + } + // The remaining logic applies to signed and unsigned types. Note that + // while multiplication is commutative, the underlying StrongInt class + // always calls this with T as StrongInt::ValueType. + if (lhs > 0) { + if (rhs > 0) { + if (lhs > (std::numeric_limits::max() / rhs)) { + ErrorType::Error("SafeInt: overflow", lhs, rhs, "*"); + } + } else { + if (rhs < (std::numeric_limits::min() / lhs)) { + ErrorType::Error("SafeInt: underflow", lhs, rhs, "*"); + } + } + } else { + if (rhs > 0) { + // Underflow could be tested by lhs < min / rhs, but that does not + // work if rhs is an unsigned type. Intead we test rhs > min / lhs. + // There is a special case for lhs = -1, which would overflow min / lhs. + if ((lhs == -1 && rhs - 1 > std::numeric_limits::max()) || + (lhs < -1 && rhs > std::numeric_limits::min() / lhs)) { + ErrorType::Error("SafeInt: underflow", lhs, rhs, "*"); + } + } else { + if ((lhs != 0) && (rhs < (std::numeric_limits::max() / lhs))) { + ErrorType::Error("SafeInt: overflow", lhs, rhs, "*"); + } + } + } + } + template + static void ValidateDivide(T lhs, U rhs) { + // This applies to signed and unsigned types. + if (rhs == 0) { + ErrorType::Error("SafeInt: divide by zero", lhs, rhs, "/"); + } + if (std::numeric_limits::is_signed) { + // Signed types only. + if ((lhs == std::numeric_limits::min()) && (rhs == -1)) { + ErrorType::Error("SafeInt: overflow", lhs, rhs, "/"); + } + } else { + // Unsigned types only. + if (rhs < 0) { + ErrorType::Error("SafeInt: negation of unsigned type", lhs, rhs, "/"); + } + } + } + template + static void ValidateModulo(T lhs, U rhs) { + // This applies to signed and unsigned types. 
+ if (rhs == 0) { + ErrorType::Error("SafeInt: divide by zero", lhs, rhs, "%"); + } + if (std::numeric_limits::is_signed) { + // Signed types only. + if ((lhs == std::numeric_limits::min()) && (rhs == -1)) { + ErrorType::Error("SafeInt: overflow", lhs, rhs, "%"); + } + } else { + // Unsigned types only. + if (rhs < 0) { + ErrorType::Error("SafeInt: negation of unsigned type", lhs, rhs, "%"); + } + } + } + template + static void ValidateLeftShift(T lhs, int64 rhs) { + if (std::numeric_limits::is_signed) { + // Signed types only. + if (lhs < 0) { + ErrorType::Error("SafeInt: shift of negative value", lhs, rhs, "<<"); + } + } + // The remaining logic applies to signed and unsigned types. + if (rhs < 0) { + ErrorType::Error("SafeInt: shift by negative arg", lhs, rhs, "<<"); + } + if (rhs >= (sizeof(T) * CHAR_BIT)) { + ErrorType::Error("SafeInt: shift by large arg", lhs, rhs, "<<"); + } + if (lhs > (std::numeric_limits::max() >> rhs)) { + ErrorType::Error("SafeInt: overflow", lhs, rhs, "<<"); + } + } + template + static void ValidateRightShift(T lhs, int64 rhs) { + if (std::numeric_limits::is_signed) { + // Signed types only. + if (lhs < 0) { + ErrorType::Error("SafeInt: shift of negative value", lhs, rhs, ">>"); + } + } + // The remaining logic applies to signed and unsigned types. + if (rhs < 0) { + ErrorType::Error("SafeInt: shift by negative arg", lhs, rhs, ">>"); + } + if (rhs >= (sizeof(T) * CHAR_BIT)) { + ErrorType::Error("SafeInt: shift by large arg", lhs, rhs, ">>"); + } + } + template + static void ValidateBitAnd( // Unsigned types only. + typename std::enable_if::is_signed, T>::type lhs, + typename std::enable_if::is_signed, T>::type + rhs) { + // Do nothing. + } + template + static void ValidateBitOr( // Unsigned types only. + typename std::enable_if::is_signed, T>::type lhs, + typename std::enable_if::is_signed, T>::type + rhs) { + // Do nothing. + } + template + static void ValidateBitXor( // Unsigned types only. 
+      typename std::enable_if<!std::numeric_limits<T>::is_signed, T>::type lhs,
+      typename std::enable_if<!std::numeric_limits<T>::is_signed, T>::type
+          rhs) {
+    // Do nothing.
+  }
+};
+
+// A SafeIntStrongIntValidator policy class to LOG(FATAL) on errors.
+struct LogFatalOnError {
+  template <typename Tlhs, typename Trhs>
+  static void Error(const char *error, Tlhs lhs, Trhs rhs, const char *op) {
+    LOG(FATAL) << error << ": (" << lhs << " " << op << " " << rhs << ")";
+  }
+  template <typename Tval>
+  static void Error(const char *error, Tval val, const char *op) {
+    LOG(FATAL) << error << ": (" << op << val << ")";
+  }
+};
+
+}  // namespace intops
+}  // namespace mediapipe
+
+// Defines the StrongInt using value_type and typedefs it to type_name, with
+// strong checking of under/overflow conditions.
+// The struct int_type_name ## _tag_ trickery is needed to ensure that a new
+// type is created per type_name.
+#define MEDIAPIPE_DEFINE_SAFE_INT_TYPE(type_name, value_type, policy_type) \
+  struct type_name##_safe_tag_ {};                                         \
+  typedef ::mediapipe::intops::StrongInt<                                  \
+      type_name##_safe_tag_, value_type,                                   \
+      ::mediapipe::intops::SafeIntStrongIntValidator<policy_type>>         \
+      type_name;
+
+#endif  // MEDIAPIPE_DEPS_SAFE_INT_H_
diff --git a/mediapipe/framework/deps/safe_int_test.cc b/mediapipe/framework/deps/safe_int_test.cc
new file mode 100644
index 000000000..2619837f7
--- /dev/null
+++ b/mediapipe/framework/deps/safe_int_test.cc
@@ -0,0 +1,771 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +// Unit test cases for SafeInt. Some of this overlaps with the testing for +// StrongInt, but it's important to test not only that SafeInt fails when +// expected, but that it passes when expected. + +#include "mediapipe/framework/deps/safe_int.h" + +#include "mediapipe/framework/port/gtest.h" + +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt8, int8, + ::mediapipe::intops::LogFatalOnError); +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt8, uint8, + ::mediapipe::intops::LogFatalOnError); +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt16, int16, + ::mediapipe::intops::LogFatalOnError); +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt16, uint16, + ::mediapipe::intops::LogFatalOnError); +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt32, int32, + ::mediapipe::intops::LogFatalOnError); +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt64, int64, + ::mediapipe::intops::LogFatalOnError); +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt32, uint32, + ::mediapipe::intops::LogFatalOnError); +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt64, uint64, + ::mediapipe::intops::LogFatalOnError); + +namespace mediapipe { +namespace intops { + +// +// Test cases that apply to signed and unsigned types equally. +// + +template +class SignNeutralSafeIntTest : public ::testing::Test { + public: + typedef T SafeIntTypeUnderTest; +}; + +typedef ::testing::Types + AllSafeIntTypes; + +TYPED_TEST_SUITE(SignNeutralSafeIntTest, AllSafeIntTypes); + +TYPED_TEST(SignNeutralSafeIntTest, TestCtors) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test default construction. + T x; + EXPECT_EQ(V(), x.value()); + } + + { // Test construction from a value. + T x(93); + EXPECT_EQ(V(93), x.value()); + } + + { // Test copy construction. + T x(76); + T y(x); + EXPECT_EQ(V(76), y.value()); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestUnaryOperators) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test unary plus of positive values. 
+ T x(123); + EXPECT_EQ(V(123), (+x).value()); + } + { // Test logical not of positive values. + T x(123); + EXPECT_EQ(false, !x); + EXPECT_EQ(true, !!x); + } + { // Test logical not of zero. + T x(0); + EXPECT_EQ(true, !x); + EXPECT_EQ(false, !!x); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestCtorFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test out-of-bounds construction. + if (std::numeric_limits::is_signed || sizeof(V) < sizeof(uint64)) { + EXPECT_DEATH((T(std::numeric_limits::max())), "bounds"); + } + } + { // Test out-of-bounds construction from float. + EXPECT_DEATH((T(std::numeric_limits::max())), "bounds"); + EXPECT_DEATH((T(-std::numeric_limits::max())), "bounds"); + } + { // Test out-of-bounds construction from double. + EXPECT_DEATH((T(std::numeric_limits::max())), "bounds"); + EXPECT_DEATH((T(-std::numeric_limits::max())), "bounds"); + } + { // Test out-of-bounds construction from long double. + EXPECT_DEATH((T(std::numeric_limits::max())), "bounds"); + EXPECT_DEATH((T(-std::numeric_limits::max())), "bounds"); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestIncrementDecrement) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test simple increments and decrements. + T x(0); + EXPECT_EQ(V(0), x.value()); + EXPECT_EQ(V(0), (x++).value()); + EXPECT_EQ(V(1), x.value()); + EXPECT_EQ(V(2), (++x).value()); + EXPECT_EQ(V(2), x.value()); + EXPECT_EQ(V(2), (x--).value()); + EXPECT_EQ(V(1), x.value()); + EXPECT_EQ(V(0), (--x).value()); + EXPECT_EQ(V(0), x.value()); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestIncrementDecrementFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test overflowing increment. 
+ T x(std::numeric_limits::max() - 1); + EXPECT_EQ(std::numeric_limits::max(), (++x).value()); + EXPECT_DEATH(x++, "overflow"); + EXPECT_DEATH(++x, "overflow"); + } + { // Test underflowing decrement. + T x(std::numeric_limits::min() + 1); + EXPECT_EQ(std::numeric_limits::min(), (--x).value()); + EXPECT_DEATH(x--, "underflow"); + EXPECT_DEATH(--x, "underflow"); + } +} + +#define TEST_T_OP_T(xval, op, yval) \ + { \ + T x(xval); \ + T y(yval); \ + V expected = x.value() op y.value(); \ + EXPECT_EQ(expected, (x op y).value()); \ + EXPECT_EQ(expected, (x op## = y).value()); \ + EXPECT_EQ(expected, x.value()); \ + } + +TYPED_TEST(SignNeutralSafeIntTest, TestAdd) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test positive vs. positive addition. + TEST_T_OP_T(9, +, 3) + // Test addition by zero. + TEST_T_OP_T(93, +, 0); +} + +TYPED_TEST(SignNeutralSafeIntTest, TestAddFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test overflowing addition. + T x(std::numeric_limits::max()); + EXPECT_DEATH(x + T(1), "overflow"); + EXPECT_DEATH(x += T(1), "overflow"); + } + { // Test overflowing addition. + T x(std::numeric_limits::max()); + EXPECT_DEATH(x + T(std::numeric_limits::max()), "overflow"); + EXPECT_DEATH(x += T(std::numeric_limits::max()), "overflow"); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestSubtract) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test positive vs. positive subtraction. + TEST_T_OP_T(9, -, 3) + // Test subtraction of zero. + TEST_T_OP_T(93, -, 0); +} + +TYPED_TEST(SignNeutralSafeIntTest, TestSubtractFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test underflowing subtraction. 
+ T x(std::numeric_limits::min()); + EXPECT_DEATH(x - T(1), "underflow"); + EXPECT_DEATH(x -= T(1), "underflow"); + } + { // Test underflowing subtraction. + T x(std::numeric_limits::min()); + EXPECT_DEATH(x - T(std::numeric_limits::max()), "underflow"); + EXPECT_DEATH(x -= T(std::numeric_limits::max()), "underflow"); + } +} + +#define TEST_T_OP_NUM(xval, op, numtype, yval) \ + { \ + T x(xval); \ + numtype y = yval; \ + V expected = x.value() op y; \ + EXPECT_EQ(expected, (x op y).value()); \ + EXPECT_EQ(expected, (x op## = y).value()); \ + EXPECT_EQ(expected, x.value()); \ + } + +TYPED_TEST(SignNeutralSafeIntTest, TestMultiply) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test positive vs. positive multiplication across types. + TEST_T_OP_NUM(9, *, int32, 3); + TEST_T_OP_NUM(9, *, uint32, 3); + TEST_T_OP_NUM(9, *, float, 3); + TEST_T_OP_NUM(9, *, double, 3); + + // Test positive vs. zero multiplication commutatively across types. This + // was a real bug. + TEST_T_OP_NUM(93, *, int32, 0); + TEST_T_OP_NUM(93, *, uint32, 0); + TEST_T_OP_NUM(93, *, float, 0); + TEST_T_OP_NUM(93, *, double, 0); + + TEST_T_OP_NUM(0, *, int32, 76); + TEST_T_OP_NUM(0, *, uint32, 76); + TEST_T_OP_NUM(0, *, float, 76); + TEST_T_OP_NUM(0, *, double, 76); + + // Test positive vs. epsilon multiplication. + TEST_T_OP_NUM(93, *, float, std::numeric_limits::epsilon()); + TEST_T_OP_NUM(93, *, double, std::numeric_limits::epsilon()); + + { // Test multiplication by float. + // Multiplication is the only operator that takes one numeric type and + // one StrongInt type *and* is commutative. This was a real bug. + T x(0); + EXPECT_EQ(0, (x * static_cast(1.1)).value()); + EXPECT_EQ(0, (static_cast(1.1) * x).value()); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestMultiplyFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test overflowing multiplication. 
+ T x(std::numeric_limits::max()); + EXPECT_DEATH(x * 2, "overflow"); + EXPECT_DEATH(x *= 2, "overflow"); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestDivide) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test positive vs. positive division across types. + TEST_T_OP_NUM(9, /, int32, 3); + TEST_T_OP_NUM(9, /, uint32, 3); + TEST_T_OP_NUM(9, /, float, 3); + TEST_T_OP_NUM(9, /, double, 3); + + // Test zero vs. positive division across types. + TEST_T_OP_NUM(0, /, int32, 76); + TEST_T_OP_NUM(0, /, uint32, 76); + TEST_T_OP_NUM(0, /, float, 76); + TEST_T_OP_NUM(0, /, double, 76); +} + +TYPED_TEST(SignNeutralSafeIntTest, TestDivideFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test divide by zero. + T x(93); + EXPECT_DEATH(x / 0, "divide by zero"); + EXPECT_DEATH(x /= 0, "divide by zero"); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestModulo) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test positive vs. positive modulo across signedness. + TEST_T_OP_NUM(7, %, int32, 6); + TEST_T_OP_NUM(7, %, uint32, 6); + + // Test zero vs. positive modulo across signedness. + TEST_T_OP_NUM(0, %, int32, 6); + TEST_T_OP_NUM(0, %, uint32, 6); +} + +TYPED_TEST(SignNeutralSafeIntTest, TestModuloFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test modulo by zero. + T x(93); + EXPECT_DEATH(x % 0, "divide by zero"); + EXPECT_DEATH(x %= 0, "divide by zero"); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestLeftShift) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test basic shift. + TEST_T_OP_NUM(0x09, <<, int, 3); + // Test shift by zero. 
+ TEST_T_OP_NUM(0x09, <<, int, 0); +} + +TYPED_TEST(SignNeutralSafeIntTest, TestLeftShiftFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test shift by a negative. + T x(9); + EXPECT_DEATH(x << -1, "shift by negative"); + EXPECT_DEATH(x <<= -1, "shift by negative"); + } + { // Test shift by a too-large. + T x(9); + EXPECT_DEATH(x << sizeof(T) * CHAR_BIT, "shift by large"); + EXPECT_DEATH(x <<= sizeof(T) * CHAR_BIT, "shift by large"); + EXPECT_DEATH(x <<= 0x100000001ULL, "shift by large"); + } + { // Test overflowing shift. + T x(std::numeric_limits::max()); + EXPECT_DEATH(x << 1, "overflow"); + EXPECT_DEATH(x <<= 1, "overflow"); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestRightShift) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test basic shift. + TEST_T_OP_NUM(0x09, >>, int, 3); + // Test shift by zero. + TEST_T_OP_NUM(0x09, >>, int, 0); +} + +TYPED_TEST(SignNeutralSafeIntTest, TestRightShiftFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test shift by a negative. + T x(9); + EXPECT_DEATH(x >> -1, "shift by negative"); + EXPECT_DEATH(x >>= -1, "shift by negative"); + } + { // Test shift by a too-large. + T x(9); + EXPECT_DEATH(x >> sizeof(T) * CHAR_BIT, "shift by large"); + EXPECT_DEATH(x >>= sizeof(T) * CHAR_BIT, "shift by large"); + } +} + +TYPED_TEST(SignNeutralSafeIntTest, TestFloatToIntTruncation) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test construction from float. + { + float f = 93.123; + T x(f); + EXPECT_EQ(93, x.value()); + } + { + float f = 93.76; + T x(f); + EXPECT_EQ(93, x.value()); + } + // Test construction from double. + { + double f = 93.123; + T x(f); + EXPECT_EQ(93, x.value()); + } + { + double f = 93.76; + T x(f); + EXPECT_EQ(93, x.value()); + } + // Test construction from long double. 
+ { + long double f = 93.123; + T x(f); + EXPECT_EQ(93, x.value()); + } + { + long double f = 93.76; + T x(f); + EXPECT_EQ(93, x.value()); + } +} + +// +// Test cases that apply only to signed types. +// + +template +class SignedSafeIntTest : public ::testing::Test { + public: + typedef T SafeIntTypeUnderTest; +}; + +typedef ::testing::Types + SignedSafeIntTypes; + +TYPED_TEST_SUITE(SignedSafeIntTest, SignedSafeIntTypes); + +TYPED_TEST(SignedSafeIntTest, TestCtors) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test construction from a negative value. + T x(-1); + EXPECT_EQ(V(-1), x.value()); + } +} + +TYPED_TEST(SignedSafeIntTest, TestUnaryOperators) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test unary plus and minus of positive values. + T x(123); + EXPECT_EQ(V(123), (+x).value()); + EXPECT_EQ(V(-123), (-x).value()); + } + { // Test unary plus and minus of negative values. + T x(-123); + EXPECT_EQ(V(-123), (+x).value()); + EXPECT_EQ(V(123), (-x).value()); + } + { // Test logical not of negative values. + T x(-123); + EXPECT_EQ(false, !x); + EXPECT_EQ(true, !!x); + } +} + +TYPED_TEST(SignedSafeIntTest, TestUnaryOperatorsFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test unary minus of negative values. + T y(std::numeric_limits::min()); + EXPECT_DEATH(-y, "overflow"); + } +} + +TYPED_TEST(SignedSafeIntTest, TestAdd) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test negative vs. positive addition. + TEST_T_OP_T(-9, +, 3) + // Test positive vs. negative addition. + TEST_T_OP_T(9, +, -3) + // Test negative vs. negative addition. 
+ TEST_T_OP_T(-9, +, -3) +} + +TYPED_TEST(SignedSafeIntTest, TestAddFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test underflow by addition of a negative. + T x(std::numeric_limits::min()); + EXPECT_DEATH(x + T(-1), "underflow"); + EXPECT_DEATH(x += T(-1), "underflow"); + } +} + +TYPED_TEST(SignedSafeIntTest, TestSubtract) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test negative vs. positive subtraction. + TEST_T_OP_T(-9, -, 3) + // Test positive vs. negative subtraction. + TEST_T_OP_T(9, -, -3) + // Test negative vs. negative subtraction. + TEST_T_OP_T(-9, -, -3) + // Test positive vs. positive subtraction resulting in negative. + TEST_T_OP_T(3, -, 9); + // Test subtraction from zero. + TEST_T_OP_T(0, -, 93); +} + +TYPED_TEST(SignedSafeIntTest, TestSubtractFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test overflow by subtraction of a negative. + T x(std::numeric_limits::max()); + EXPECT_DEATH(x - T(-1), "overflow"); + EXPECT_DEATH(x -= T(-1), "overflow"); + } +} + +TYPED_TEST(SignedSafeIntTest, TestMultiply) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test negative vs. positive multiplication across types. + TEST_T_OP_NUM(-9, *, int32, 3); + TEST_T_OP_NUM(-9, *, uint32, 3); + TEST_T_OP_NUM(-9, *, float, 3); + TEST_T_OP_NUM(-9, *, double, 3); + // Test positive vs. negative multiplication across types. + TEST_T_OP_NUM(9, *, int32, -3); + // Don't cover unsigneds that are initialized from negative values. + TEST_T_OP_NUM(9, *, float, -3); + TEST_T_OP_NUM(9, *, double, -3); + // Test negative vs. negative multiplication across types. + TEST_T_OP_NUM(-9, *, int32, -3); + // Don't cover unsigneds that are initialized from negative values. 
+ TEST_T_OP_NUM(-9, *, float, -3); + TEST_T_OP_NUM(-9, *, double, -3); + + // Test negative vs. zero multiplication commutatively across types. + TEST_T_OP_NUM(-93, *, int32, 0); + TEST_T_OP_NUM(-93, *, uint32, 0); + TEST_T_OP_NUM(-93, *, float, 0); + TEST_T_OP_NUM(-93, *, double, 0); + TEST_T_OP_NUM(0, *, int32, -76); + TEST_T_OP_NUM(0, *, uint32, -76); + TEST_T_OP_NUM(0, *, float, -76); + TEST_T_OP_NUM(0, *, double, -76); + + // Test negative vs. epsilon multiplication. + TEST_T_OP_NUM(-93, *, float, std::numeric_limits::epsilon()); + TEST_T_OP_NUM(-93, *, double, std::numeric_limits::epsilon()); +} + +TYPED_TEST(SignedSafeIntTest, TestMultiplyFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test underflowing multiplication. + T x(std::numeric_limits::min()); + EXPECT_DEATH(x * 2, "underflow"); + EXPECT_DEATH(x *= 2, "underflow"); + } + { // Test underflowing multiplication. + T x(std::numeric_limits::max()); + EXPECT_DEATH(x * -2, "underflow"); + EXPECT_DEATH(x *= -2, "underflow"); + } + { // Test overflowing multiplication. + T x(std::numeric_limits::min()); + EXPECT_DEATH(x * -2, "overflow"); + EXPECT_DEATH(x *= -2, "overflow"); + } + { // Test overflowing multiplication. + T x(std::numeric_limits::min()); + EXPECT_DEATH(x * -1, "overflow"); + EXPECT_DEATH(x *= -1, "overflow"); + } + { // Test underflowing multiplication where rhs type is uint64. + T x(-2); + EXPECT_DEATH(x * kuint64max, "underflow"); + EXPECT_DEATH(x *= kuint64max, "underflow"); + } +} + +TYPED_TEST(SignedSafeIntTest, TestDivide) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test negative vs. positive division across types. + TEST_T_OP_NUM(-9, /, int32, 3); + TEST_T_OP_NUM(-9, /, uint32, 3); + TEST_T_OP_NUM(-9, /, float, 3); + TEST_T_OP_NUM(-9, /, double, 3); + // Test positive vs. negative division across types. 
+ TEST_T_OP_NUM(9, /, int32, -3); + TEST_T_OP_NUM(9, /, uint32, -3); + TEST_T_OP_NUM(9, /, float, -3); + TEST_T_OP_NUM(9, /, double, -3); + // Test negative vs. negative division across types. + TEST_T_OP_NUM(-9, /, int32, -3); + TEST_T_OP_NUM(-9, /, uint32, -3); + TEST_T_OP_NUM(-9, /, float, -3); + TEST_T_OP_NUM(-9, /, double, -3); + + // Test zero vs. negative division across types. + TEST_T_OP_NUM(0, /, int32, -76); + TEST_T_OP_NUM(0, /, uint32, -76); + TEST_T_OP_NUM(0, /, float, -76); + TEST_T_OP_NUM(0, /, double, -76); +} + +TYPED_TEST(SignedSafeIntTest, TestDivideFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test overflowing division. + T x(std::numeric_limits::min()); + EXPECT_DEATH(x / -1, "overflow"); + EXPECT_DEATH(x /= -1, "overflow"); + } +} + +TYPED_TEST(SignedSafeIntTest, TestModulo) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + // Test negative vs. positive modulo across signedness. + TEST_T_OP_NUM(-7, %, int32, 6); + TEST_T_OP_NUM(-7, %, uint32, 6); + // Test positive vs. negative modulo across signedness. + TEST_T_OP_NUM(7, %, int32, -6); + TEST_T_OP_NUM(7, %, uint32, -6); + // Test negative vs. negative modulo across signedness. + TEST_T_OP_NUM(-7, %, int32, -6); + TEST_T_OP_NUM(-7, %, uint32, -6); + + // Test zero vs. negative modulo across signedness. + TEST_T_OP_NUM(0, %, int32, -6); + TEST_T_OP_NUM(0, %, uint32, -6); +} + +TYPED_TEST(SignedSafeIntTest, TestModuloFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test overflowing modulo. + T x(std::numeric_limits::min()); + EXPECT_DEATH(x % -1, "overflow"); + EXPECT_DEATH(x %= -1, "overflow"); + } +} + +TYPED_TEST(SignedSafeIntTest, TestLeftShiftFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test shift of a negative. 
+ T x(-9); + EXPECT_DEATH(x << 1, "shift of negative"); + EXPECT_DEATH(x <<= 1, "shift of negative"); + } +} + +TYPED_TEST(SignedSafeIntTest, TestRightShiftFailures) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test shift of a negative. + T x(-9); + EXPECT_DEATH(x >> 1, "shift of negative"); + EXPECT_DEATH(x >>= 1, "shift of negative"); + } +} + +// +// Test cases that apply only to unsigned types. +// + +template +class UnsignedSafeIntTest : public ::testing::Test { + public: + typedef T SafeIntTypeUnderTest; +}; + +typedef ::testing::Types + UnsignedSafeIntTypes; + +TYPED_TEST_SUITE(UnsignedSafeIntTest, UnsignedSafeIntTypes); + +TYPED_TEST(UnsignedSafeIntTest, TestCtors) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test out-of-bounds construction. + EXPECT_DEATH(T(-1), "bounds"); + } + { // Test out-of-bounds construction from float. + EXPECT_DEATH((T(static_cast(-1))), "bounds"); + } + { // Test out-of-bounds construction from double. + EXPECT_DEATH((T(static_cast(-1))), "bounds"); + } + { // Test out-of-bounds construction from long double. + EXPECT_DEATH((T(static_cast(-1))), "bounds"); + } +} + +TYPED_TEST(UnsignedSafeIntTest, TestUnaryOperators) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test bitwise not of positive values. + T x(123); + EXPECT_EQ(V(~(x.value())), (~x).value()); + EXPECT_EQ(x.value(), (~~x).value()); + } + { // Test bitwise not of zero. + T x(0x00); + EXPECT_EQ(V(~(x.value())), (~x).value()); + EXPECT_EQ(x.value(), (~~x).value()); + } +} + +TYPED_TEST(UnsignedSafeIntTest, TestMultiply) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test multiplication by a negative. 
+ T x(93); + EXPECT_DEATH(x * -1, "negation"); + EXPECT_DEATH(x *= -1, "negation"); + } +} + +TYPED_TEST(UnsignedSafeIntTest, TestDivide) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test division by a negative. + T x(93); + EXPECT_DEATH(x / -1, "negation"); + EXPECT_DEATH(x /= -1, "negation"); + } +} + +TYPED_TEST(UnsignedSafeIntTest, TestModulo) { + typedef typename TestFixture::SafeIntTypeUnderTest T; + typedef typename T::ValueType V; + + { // Test modulo by a negative. + T x(93); + EXPECT_DEATH(x % -5, "negation"); + EXPECT_DEATH(x %= -5, "negation"); + } +} + +} // namespace intops +} // namespace mediapipe diff --git a/mediapipe/framework/deps/singleton.h b/mediapipe/framework/deps/singleton.h new file mode 100644 index 000000000..86dcd78df --- /dev/null +++ b/mediapipe/framework/deps/singleton.h @@ -0,0 +1,72 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_SINGLETON_H_ +#define MEDIAPIPE_DEPS_SINGLETON_H_ + +#include "absl/synchronization/mutex.h" + +// The Singleton template class creates a single instance of template parameter +// |T| when needed in a thread-safe fashion. A pointer to this single instance +// may be retrieved through a call to get(). +template +class Singleton { + public: + // Returns the pointer to the singleton of type |T|. + // This method is thread-safe. 
+ static T *get() LOCKS_EXCLUDED(mu_) { + absl::MutexLock lock(&mu_); + if (instance_) { + return instance_; + } + + if (destroyed_) { + return nullptr; + } + if (instance_) { + return instance_; + } + instance_ = new T(); + return instance_; + } + + // Destroys the singleton . This method is only partially thread-safe. + // It ensures that instance_ gets destroyed only once, and once destroyed, it + // cannot be recreated. However, the callers of this method responsible for + // making sure that no other threads are accessing (or plan to access) the + // singleton any longer. + static void Destruct() LOCKS_EXCLUDED(mu_) { + absl::MutexLock lock(&mu_); + T *tmp_ptr = instance_; + instance_ = nullptr; + delete tmp_ptr; + destroyed_ = true; + } + + private: + static T *instance_ GUARDED_BY(mu_); + static bool destroyed_ GUARDED_BY(mu_); + static absl::Mutex mu_; +}; + +template +T *Singleton::instance_ = nullptr; + +template +bool Singleton::destroyed_ = false; + +template +absl::Mutex Singleton::mu_; + +#endif // MEDIAPIPE_DEPS_SINGLETON_H_ diff --git a/mediapipe/framework/deps/source_location.h b/mediapipe/framework/deps/source_location.h new file mode 100644 index 000000000..7f7af9f37 --- /dev/null +++ b/mediapipe/framework/deps/source_location.h @@ -0,0 +1,64 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_DEPS_SOURCE_LOCATION_H_ +#define MEDIAPIPE_DEPS_SOURCE_LOCATION_H_ + +#include + +namespace mediapipe { + +// Class representing a specific location in the source code of a program. +// source_location is copyable. +class source_location { + public: + // Avoid this constructor; it populates the object with dummy values. + constexpr source_location() : line_(0), file_name_(nullptr) {} + + // Wrapper to invoke the private constructor below. This should only be + // used by the MEDIAPIPE_LOC macro, hence the name. + static constexpr source_location DoNotInvokeDirectly(std::uint_least32_t line, + const char* file_name) { + return source_location(line, file_name); + } + + // The line number of the captured source location. + constexpr std::uint_least32_t line() const { return line_; } + + // The file name of the captured source location. + constexpr const char* file_name() const { return file_name_; } + + // column() and function_name() are omitted because we don't have a + // way to support them. + + private: + // Do not invoke this constructor directly. Instead, use the + // MEDIAPIPE_LOC macro below. + // + // file_name must outlive all copies of the source_location + // object, so in practice it should be a std::string literal. + constexpr source_location(std::uint_least32_t line, const char* file_name) + : line_(line), file_name_(file_name) {} + + std::uint_least32_t line_; + const char* file_name_; +}; + +} // namespace mediapipe + +// If a function takes a source_location parameter, pass this as the argument. +#define MEDIAPIPE_LOC \ + ::mediapipe::source_location::DoNotInvokeDirectly(__LINE__, __FILE__) + +#endif // MEDIAPIPE_DEPS_SOURCE_LOCATION_H_ diff --git a/mediapipe/framework/deps/status.cc b/mediapipe/framework/deps/status.cc new file mode 100644 index 000000000..c0a02ee3d --- /dev/null +++ b/mediapipe/framework/deps/status.cc @@ -0,0 +1,133 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/status.h" + +#include + +namespace mediapipe { + +Status::Status(::mediapipe::StatusCode code, absl::string_view msg) { + state_ = std::unique_ptr(new State); + state_->code = code; + state_->msg = std::string(msg); +} + +void Status::Update(const Status& new_status) { + if (ok()) { + *this = new_status; + } +} + +void Status::SlowCopyFrom(const State* src) { + if (src == nullptr) { + state_ = nullptr; + } else { + state_ = std::unique_ptr(new State(*src)); + } +} + +const std::string& Status::empty_string() { + static std::string* empty = new std::string; + return *empty; +} + +std::string Status::ToString() const { + if (state_ == nullptr) { + return "OK"; + } else { + char tmp[30]; + const char* type; + switch (code()) { + case ::mediapipe::StatusCode::kCancelled: + type = "Cancelled"; + break; + case ::mediapipe::StatusCode::kUnknown: + type = "Unknown"; + break; + case ::mediapipe::StatusCode::kInvalidArgument: + type = "Invalid argument"; + break; + case ::mediapipe::StatusCode::kDeadlineExceeded: + type = "Deadline exceeded"; + break; + case ::mediapipe::StatusCode::kNotFound: + type = "Not found"; + break; + case ::mediapipe::StatusCode::kAlreadyExists: + type = "Already exists"; + break; + case ::mediapipe::StatusCode::kPermissionDenied: + type = "Permission denied"; + break; + case ::mediapipe::StatusCode::kUnauthenticated: + type = "Unauthenticated"; + break; + 
case ::mediapipe::StatusCode::kResourceExhausted: + type = "Resource exhausted"; + break; + case ::mediapipe::StatusCode::kFailedPrecondition: + type = "Failed precondition"; + break; + case ::mediapipe::StatusCode::kAborted: + type = "Aborted"; + break; + case ::mediapipe::StatusCode::kOutOfRange: + type = "Out of range"; + break; + case ::mediapipe::StatusCode::kUnimplemented: + type = "Unimplemented"; + break; + case ::mediapipe::StatusCode::kInternal: + type = "Internal"; + break; + case ::mediapipe::StatusCode::kUnavailable: + type = "Unavailable"; + break; + case ::mediapipe::StatusCode::kDataLoss: + type = "Data loss"; + break; + default: + snprintf(tmp, sizeof(tmp), "Unknown code(%d)", + static_cast(code())); + type = tmp; + break; + } + std::string result(type); + result += ": "; + result += state_->msg; + return result; + } +} + +void Status::IgnoreError() const { + // no-op +} + +std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + +std::string* MediaPipeCheckOpHelperOutOfLine(const ::mediapipe::Status& v, + const char* msg) { + std::string r("Non-OK-status: "); + r += msg; + r += " status: "; + r += v.ToString(); + // Leaks std::string but this is only to be used in a fatal error message + return new std::string(r); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/status.h b/mediapipe/framework/deps/status.h new file mode 100644 index 000000000..80f4055ce --- /dev/null +++ b/mediapipe/framework/deps/status.h @@ -0,0 +1,172 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_STATUS_H_ +#define MEDIAPIPE_DEPS_STATUS_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +enum class StatusCode { + kOk = 0, + kCancelled = 1, + kUnknown = 2, + kInvalidArgument = 3, + kDeadlineExceeded = 4, + kNotFound = 5, + kAlreadyExists = 6, + kPermissionDenied = 7, + kResourceExhausted = 8, + kFailedPrecondition = 9, + kAborted = 10, + kOutOfRange = 11, + kUnimplemented = 12, + kInternal = 13, + kUnavailable = 14, + kDataLoss = 15, + kUnauthenticated = 16, + kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 +}; + +#if defined(__clang__) +// Only clang supports warn_unused_result as a type annotation. +class ABSL_MUST_USE_RESULT Status; +#endif + +// Denotes success or failure of a call in MediaPipe. +class Status { + public: + // Creates a success status. + Status() {} + + // Creates a status with the specified error code and msg as a + // human-readable std::string containing more detailed information. + Status(::mediapipe::StatusCode code, absl::string_view msg); + + // Copies the specified status. + Status(const Status& s); + void operator=(const Status& s); + + // Returns true iff the status indicates success. + bool ok() const { + return (state_ == NULL) || (state_->code == ::mediapipe::StatusCode::kOk); + } + + ::mediapipe::StatusCode code() const { + return ok() ? ::mediapipe::StatusCode::kOk : state_->code; + } + + const std::string& error_message() const { + return ok() ? 
empty_string() : state_->msg; + } + + absl::string_view message() const { + return absl::string_view(error_message()); + } + + bool operator==(const Status& x) const; + bool operator!=(const Status& x) const; + + // If `ok()`, stores `new_status` into `*this`. If `!ok()`, + // preserves the current status, but may augment with additional + // information about `new_status`. + // + // Convenient way of keeping track of the first error encountered. + // Instead of: + // `if (overall_status.ok()) overall_status = new_status` + // Use: + // `overall_status.Update(new_status);` + void Update(const Status& new_status); + + // Returns a std::string representation of this status suitable for + // printing. Returns the std::string `"OK"` for success. + std::string ToString() const; + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; + + private: + static const std::string& empty_string(); + struct State { + ::mediapipe::StatusCode code; + std::string msg; + }; + // OK status has a `NULL` state_. Otherwise, `state_` points to + // a `State` structure containing the error code and message(s) + std::unique_ptr state_; + + void SlowCopyFrom(const State* src); +}; + +inline Status::Status(const Status& s) + : state_((s.state_ == NULL) ? NULL : new State(*s.state_)) {} + +inline void Status::operator=(const Status& s) { + // The following condition catches both aliasing (when this == &s), + // and the common case where both s and *this are ok. 
+ if (state_ != s.state_) { + SlowCopyFrom(s.state_.get()); + } +} + +inline bool Status::operator==(const Status& x) const { + return (this->state_ == x.state_) || (ToString() == x.ToString()); +} + +inline bool Status::operator!=(const Status& x) const { return !(*this == x); } + +inline Status OkStatus() { return Status(); } + +std::ostream& operator<<(std::ostream& os, const Status& x); + +typedef std::function StatusCallback; + +extern std::string* MediaPipeCheckOpHelperOutOfLine( + const ::mediapipe::Status& v, const char* msg); + +inline std::string* MediaPipeCheckOpHelper(::mediapipe::Status v, + const char* msg) { + if (v.ok()) return nullptr; + return MediaPipeCheckOpHelperOutOfLine(v, msg); +} + +#define MEDIAPIPE_DO_CHECK_OK(val, level) \ + while (auto _result = ::mediapipe::MediaPipeCheckOpHelper(val, #val)) \ + LOG(level) << *(_result) + +// To be consistent with MEDIAPIPE_EXPECT_OK, we add prefix MEDIAPIPE_ to +// CHECK_OK, QCHECK_OK, and DCHECK_OK. We prefer to use the marcos with +// MEDIAPIPE_ prefix in mediapipe's codebase. +#define MEDIAPIPE_CHECK_OK(val) MEDIAPIPE_DO_CHECK_OK(val, FATAL) +#define MEDIAPIPE_QCHECK_OK(val) MEDIAPIPE_DO_CHECK_OK(val, QFATAL) + +#ifndef NDEBUG +#define MEDIAPIPE_DCHECK_OK(val) MEDIAPIPE_CHECK_OK(val) +#else +#define MEDIAPIPE_DCHECK_OK(val) \ + while (false && (::mediapipe::OkStatus() == (val))) LOG(FATAL) +#endif + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_STATUS_H_ diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc new file mode 100644 index 000000000..2e66af296 --- /dev/null +++ b/mediapipe/framework/deps/status_builder.cc @@ -0,0 +1,85 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/status_builder.h" + +#include "absl/memory/memory.h" + +namespace mediapipe { + +StatusBuilder::StatusBuilder(const StatusBuilder& sb) { + status_ = sb.status_; + file_ = sb.file_; + line_ = sb.line_; + no_logging_ = sb.no_logging_; + stream_ = absl::make_unique(sb.stream_->str()); + join_style_ = sb.join_style_; +} + +StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) { + status_ = sb.status_; + file_ = sb.file_; + line_ = sb.line_; + no_logging_ = sb.no_logging_; + stream_ = absl::make_unique(sb.stream_->str()); + join_style_ = sb.join_style_; + return *this; +} + +StatusBuilder& StatusBuilder::SetAppend() { + if (status_.ok()) return *this; + join_style_ = MessageJoinStyle::kAppend; + return *this; +} + +StatusBuilder& StatusBuilder::SetPrepend() { + if (status_.ok()) return *this; + join_style_ = MessageJoinStyle::kPrepend; + return *this; +} + +StatusBuilder& StatusBuilder::SetNoLogging() { + no_logging_ = true; + return *this; +} + +StatusBuilder::operator Status() const& { + if (stream_->str().empty() || no_logging_) { + return status_; + } + return StatusBuilder(*this).JoinMessageToStatus(); +} + +StatusBuilder::operator Status() && { + if (stream_->str().empty() || no_logging_) { + return status_; + } + return JoinMessageToStatus(); +} + +::mediapipe::Status StatusBuilder::JoinMessageToStatus() { + std::string message; + if (join_style_ == MessageJoinStyle::kAnnotate) { + if (!status_.ok()) { + message = absl::StrCat(status_.error_message(), "; ", stream_->str()); + } + } else { + message 
= join_style_ == MessageJoinStyle::kPrepend + ? absl::StrCat(stream_->str(), status_.error_message()) + : absl::StrCat(status_.error_message(), stream_->str()); + } + return Status(status_.code(), message); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h new file mode 100644 index 000000000..c89a4d4c7 --- /dev/null +++ b/mediapipe/framework/deps/status_builder.h @@ -0,0 +1,148 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_STATUS_BUILDER_H_ +#define MEDIAPIPE_DEPS_STATUS_BUILDER_H_ + +#include "absl/base/attributes.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/source_location.h" +#include "mediapipe/framework/deps/status.h" + +namespace mediapipe { + +class ABSL_MUST_USE_RESULT StatusBuilder { + public: + StatusBuilder(const StatusBuilder& sb); + StatusBuilder& operator=(const StatusBuilder& sb); + // Creates a `StatusBuilder` based on an original status. If logging is + // enabled, it will use `location` as the location from which the log message + // occurs. A typical user will call this with `MEDIAPIPE_LOC`. 
+ StatusBuilder(const ::mediapipe::Status& original_status, + ::mediapipe::source_location location) + : status_(original_status), + line_(location.line()), + file_(location.file_name()), + stream_(new std::ostringstream) {} + + StatusBuilder(::mediapipe::Status&& original_status, + ::mediapipe::source_location location) + : status_(std::move(original_status)), + line_(location.line()), + file_(location.file_name()), + stream_(new std::ostringstream) {} + + // Creates a `StatusBuilder` from a mediapipe status code. If logging is + // enabled, it will use `location` as the location from which the log message + // occurs. A typical user will call this with `MEDIAPIPE_LOC`. + StatusBuilder(::mediapipe::StatusCode code, + ::mediapipe::source_location location) + : status_(code, ""), + line_(location.line()), + file_(location.file_name()), + stream_(new std::ostringstream) {} + + StatusBuilder(const ::mediapipe::Status& original_status, const char* file, + int line) + : status_(original_status), + line_(line), + file_(file), + stream_(new std::ostringstream) {} + + bool ok() const { return status_.ok(); } + + StatusBuilder& SetAppend(); + + StatusBuilder& SetPrepend(); + + StatusBuilder& SetNoLogging(); + + template + StatusBuilder& operator<<(const T& msg) { + if (status_.ok()) return *this; + *stream_ << msg; + return *this; + } + + operator Status() const&; + operator Status() &&; + + ::mediapipe::Status JoinMessageToStatus(); + + private: + // Specifies how to join the error message in the original status and any + // additional message that has been streamed into the builder. + enum class MessageJoinStyle { + kAnnotate, + kAppend, + kPrepend, + }; + + // The status that the result will be based on. + ::mediapipe::Status status_; + // The line to record if this file is logged. + int line_; + // Not-owned: The file to record if this status is logged. + const char* file_; + bool no_logging_ = false; + // The additional messages added with `<<`. 
+ std::unique_ptr stream_; + // Specifies how to join the message in `status_` and `stream_`. + MessageJoinStyle join_style_ = MessageJoinStyle::kAnnotate; +}; + +inline StatusBuilder AlreadyExistsErrorBuilder( + ::mediapipe::source_location location) { + return StatusBuilder(::mediapipe::StatusCode::kAlreadyExists, location); +} + +inline StatusBuilder FailedPreconditionErrorBuilder( + ::mediapipe::source_location location) { + return StatusBuilder(::mediapipe::StatusCode::kFailedPrecondition, location); +} + +inline StatusBuilder InternalErrorBuilder( + ::mediapipe::source_location location) { + return StatusBuilder(::mediapipe::StatusCode::kInternal, location); +} + +inline StatusBuilder InvalidArgumentErrorBuilder( + ::mediapipe::source_location location) { + return StatusBuilder(::mediapipe::StatusCode::kInvalidArgument, location); +} + +inline StatusBuilder NotFoundErrorBuilder( + ::mediapipe::source_location location) { + return StatusBuilder(::mediapipe::StatusCode::kNotFound, location); +} + +inline StatusBuilder UnavailableErrorBuilder( + ::mediapipe::source_location location) { + return StatusBuilder(::mediapipe::StatusCode::kUnavailable, location); +} + +inline StatusBuilder UnimplementedErrorBuilder( + ::mediapipe::source_location location) { + return StatusBuilder(::mediapipe::StatusCode::kUnimplemented, location); +} + +inline StatusBuilder UnknownErrorBuilder( + ::mediapipe::source_location location) { + return StatusBuilder(::mediapipe::StatusCode::kUnknown, location); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_STATUS_BUILDER_H_ diff --git a/mediapipe/framework/deps/status_builder_test.cc b/mediapipe/framework/deps/status_builder_test.cc new file mode 100644 index 000000000..76c92eed5 --- /dev/null +++ b/mediapipe/framework/deps/status_builder_test.cc @@ -0,0 +1,75 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/status_builder.h" + +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe { + +TEST(StatusBuilder, AnnotateMode) { + ::mediapipe::Status status = + StatusBuilder(::mediapipe::Status(::mediapipe::StatusCode::kNotFound, + "original message"), + MEDIAPIPE_LOC) + << "annotated message1 " + << "annotated message2"; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kNotFound); + EXPECT_EQ(status.error_message(), + "original message; annotated message1 annotated message2"); +} + +TEST(StatusBuilder, PrependMode) { + ::mediapipe::Status status = + StatusBuilder( + ::mediapipe::Status(::mediapipe::StatusCode::kInvalidArgument, + "original message"), + MEDIAPIPE_LOC) + .SetPrepend() + << "prepended message1 " + << "prepended message2 "; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_EQ(status.error_message(), + "prepended message1 prepended message2 original message"); +} + +TEST(StatusBuilder, AppendMode) { + ::mediapipe::Status status = + StatusBuilder(::mediapipe::Status(::mediapipe::StatusCode::kInternal, + "original message"), + MEDIAPIPE_LOC) + .SetAppend() + << " extra message1" + << " extra message2"; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInternal); + EXPECT_EQ(status.error_message(), + "original message extra message1 extra 
message2"); +} + +TEST(StatusBuilder, NoLoggingMode) { + ::mediapipe::Status status = + StatusBuilder(::mediapipe::Status(::mediapipe::StatusCode::kUnavailable, + "original message"), + MEDIAPIPE_LOC) + .SetNoLogging() + << " extra message"; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnavailable); + EXPECT_EQ(status.error_message(), "original message"); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/status_macros.h b/mediapipe/framework/deps/status_macros.h new file mode 100644 index 000000000..90b23cc22 --- /dev/null +++ b/mediapipe/framework/deps/status_macros.h @@ -0,0 +1,221 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Helper macros and methods to return and propagate errors with +// `::mediapipe::Status`. +// +// The owners of mediapipe do not endorse use of these macros as a good +// programming practice, and would prefer that you write the equivalent C++ +// directly. The macros are provided and supported for those that disagree, +// with the goal of having a single, consistent, and robust implementation. + +#ifndef MEDIAPIPE_DEPS_STATUS_MACROS_H_ +#define MEDIAPIPE_DEPS_STATUS_MACROS_H_ + +#include "mediapipe/framework/deps/status.h" +#include "mediapipe/framework/deps/status_builder.h" + +// Evaluates an expression that produces a `::mediapipe::Status`. If the status +// is not ok, returns it from the current function. 
+// +// For example: +// ::mediapipe::Status MultiStepFunction() { +// RETURN_IF_ERROR(Function(args...)); +// RETURN_IF_ERROR(foo.Method(args...)); +// return ::mediapipe::OkStatus(); +// } +// +// The macro ends with a `::mediapipe::StatusBuilder` which allows the returned +// status to be extended with more details. Any chained expressions after the +// macro will not be evaluated unless there is an error. +// +// For example: +// ::mediapipe::Status MultiStepFunction() { +// RETURN_IF_ERROR(Function(args...)) << "in MultiStepFunction"; +// RETURN_IF_ERROR(foo.Method(args...)).Log(base_logging::ERROR) +// << "while processing query: " << query.DebugString(); +// return ::mediapipe::OkStatus(); +// } +// +// `::mediapipe::StatusBuilder` supports adapting the builder chain using a +// `With` method and a functor. This allows for powerful extensions to the +// macro. +// +// For example, teams can define local policies to use across their code: +// +// StatusBuilder TeamPolicy(StatusBuilder builder) { +// return std::move(builder.Log(base_logging::WARNING).Attach(...)); +// } +// +// RETURN_IF_ERROR(foo()).With(TeamPolicy); +// RETURN_IF_ERROR(bar()).With(TeamPolicy); +// +// Changing the return type allows the macro to be used with Task and Rpc +// interfaces. See `::mediapipe::TaskReturn` and `rpc::RpcSetStatus` for +// details. +// +// void Read(StringPiece name, ::mediapipe::Task* task) { +// int64 id; +// RETURN_IF_ERROR(GetIdForName(name, &id)).With(TaskReturn(task)); +// RETURN_IF_ERROR(ReadForId(id)).With(TaskReturn(task)); +// task->Return(); +// } +// +// If using this macro inside a lambda, you need to annotate the return type +// to avoid confusion between a `::mediapipe::StatusBuilder` and a +// `::mediapipe::Status` type. E.g. 
+// +// []() -> ::mediapipe::Status { +// RETURN_IF_ERROR(Function(args...)); +// RETURN_IF_ERROR(foo.Method(args...)); +// return ::mediapipe::OkStatus(); +// } +#define RETURN_IF_ERROR(expr) \ + STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (::mediapipe::status_macro_internal::StatusAdaptorForMacros \ + status_macro_internal_adaptor = {(expr), __FILE__, __LINE__}) { \ + } else /* NOLINT */ \ + return status_macro_internal_adaptor.Consume() + +// Executes an expression `rexpr` that returns a `::mediapipe::StatusOr`. On +// OK, extracts its value into the variable defined by `lhs`, otherwise returns +// from the current function. By default the error status is returned +// unchanged, but it may be modified by an `error_expression`. If there is an +// error, `lhs` is not evaluated; thus any side effects that `lhs` may have +// only occur in the success case. +// +// Interface: +// +// ASSIGN_OR_RETURN(lhs, rexpr) +// ASSIGN_OR_RETURN(lhs, rexpr, error_expression); +// +// WARNING: expands into multiple statements; it cannot be used in a single +// statement (e.g. as the body of an if statement without {})! +// +// Example: Declaring and initializing a new variable (ValueType can be anything +// that can be initialized with assignment, including references): +// ASSIGN_OR_RETURN(ValueType value, MaybeGetValue(arg)); +// +// Example: Assigning to an existing variable: +// ValueType value; +// ASSIGN_OR_RETURN(value, MaybeGetValue(arg)); +// +// Example: Assigning to an expression with side effects: +// MyProto data; +// ASSIGN_OR_RETURN(*data.mutable_str(), MaybeGetValue(arg)); +// // No field "str" is added on error. +// +// Example: Assigning to a std::unique_ptr. +// ASSIGN_OR_RETURN(std::unique_ptr ptr, MaybeGetPtr(arg)); +// +// If passed, the `error_expression` is evaluated to produce the return +// value. 
The expression may reference any variable visible in scope, as +// well as a `::mediapipe::StatusBuilder` object populated with the error and +// named by a single underscore `_`. The expression typically uses the +// builder to modify the status and is returned directly in manner similar +// to RETURN_IF_ERROR. The expression may, however, evaluate to any type +// returnable by the function, including (void). For example: +// +// Example: Adjusting the error message. +// ASSIGN_OR_RETURN(ValueType value, MaybeGetValue(query), +// _ << "while processing query " << query.DebugString()); +// +// Example: Logging the error on failure. +// ASSIGN_OR_RETURN(ValueType value, MaybeGetValue(query), _.LogError()); +// +#define ASSIGN_OR_RETURN(...) \ + STATUS_MACROS_IMPL_GET_VARIADIC_(__VA_ARGS__, \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_) \ + (__VA_ARGS__) + +// ================================================================= +// == Implementation details, do not rely on anything below here. == +// ================================================================= + +#define STATUS_MACROS_IMPL_GET_VARIADIC_(_1, _2, _3, NAME, ...) NAME + +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_)) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ + STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ + error_expression) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ + error_expression) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + ::mediapipe::StatusBuilder _(std::move(statusor).status(), __FILE__, \ + __LINE__); \ + (void)_; /* error_expression is allowed to not use this variable */ \ + return (error_expression); \ + } \ + lhs = std::move(statusor).ValueOrDie() + +// Internal helper for concatenating macro values. 
+#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y +#define STATUS_MACROS_IMPL_CONCAT_(x, y) STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) + +// The GNU compiler emits a warning for code like: +// +// if (foo) +// if (bar) { } else baz; +// +// because it thinks you might want the else to bind to the first if. This +// leads to problems with code like: +// +// if (do_expr) RETURN_IF_ERROR(expr) << "Some message"; +// +// The "switch (0) case 0:" idiom is used to suppress this. +#define STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + switch (0) \ + case 0: \ + default: // NOLINT + +namespace mediapipe { +namespace status_macro_internal { + +// Provides a conversion to bool so that it can be used inside an if statement +// that declares a variable. +class StatusAdaptorForMacros { + public: + StatusAdaptorForMacros(const Status& status, const char* file, int line) + : builder_(status, file, line) {} + + StatusAdaptorForMacros(Status&& status, const char* file, int line) + : builder_(std::move(status), file, line) {} + + StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */, + int /* line */) + : builder_(builder) {} + + StatusAdaptorForMacros(StatusBuilder&& builder, const char* /* file */, + int /* line */) + : builder_(std::move(builder)) {} + + StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; + StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete; + + explicit operator bool() const { return ABSL_PREDICT_TRUE(builder_.ok()); } + + StatusBuilder&& Consume() { return std::move(builder_); } + + private: + StatusBuilder builder_; +}; + +} // namespace status_macro_internal +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_STATUS_MACROS_H_ diff --git a/mediapipe/framework/deps/status_matchers.h b/mediapipe/framework/deps/status_matchers.h new file mode 100644 index 000000000..01255e231 --- /dev/null +++ b/mediapipe/framework/deps/status_matchers.h @@ -0,0 +1,28 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_STATUS_MATCHERS_H_ +#define MEDIAPIPE_DEPS_STATUS_MATCHERS_H_ + +#include "gtest/gtest.h" +#include "mediapipe/framework/deps/status.h" + +// EXPECT_OK marco is already defined in our external dependency library +// protobuf. To be consistent with MEDIAPIPE_EXPECT_OK, we also add prefix +// MEDIAPIPE_ to ASSERT_OK. We prefer to use the marcos with MEDIAPIPE_ prefix +// in mediapipe's codebase. +#define MEDIAPIPE_EXPECT_OK(statement) EXPECT_TRUE((statement).ok()) +#define MEDIAPIPE_ASSERT_OK(statement) ASSERT_TRUE((statement).ok()) + +#endif // MEDIAPIPE_DEPS_STATUS_MATCHERS_H_ diff --git a/mediapipe/framework/deps/status_test.cc b/mediapipe/framework/deps/status_test.cc new file mode 100644 index 000000000..59eeaa4e9 --- /dev/null +++ b/mediapipe/framework/deps/status_test.cc @@ -0,0 +1,98 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/status.h" + +#include "mediapipe/framework/deps/status_matchers.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe { + +TEST(Status, OK) { + EXPECT_EQ(OkStatus().code(), ::mediapipe::StatusCode::kOk); + EXPECT_EQ(OkStatus().error_message(), ""); + MEDIAPIPE_EXPECT_OK(OkStatus()); + MEDIAPIPE_ASSERT_OK(OkStatus()); + EXPECT_EQ(OkStatus(), Status()); + Status s; + EXPECT_TRUE(s.ok()); +} + +TEST(DeathStatus, CheckOK) { + Status status(::mediapipe::StatusCode::kInvalidArgument, "Invalid"); + ASSERT_DEATH(MEDIAPIPE_CHECK_OK(status), "Invalid"); +} + +TEST(Status, Set) { + Status status; + status = Status(::mediapipe::StatusCode::kCancelled, "Error message"); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kCancelled); + EXPECT_EQ(status.error_message(), "Error message"); +} + +TEST(Status, Copy) { + Status a(::mediapipe::StatusCode::kInvalidArgument, "Invalid"); + Status b(a); + ASSERT_EQ(a.ToString(), b.ToString()); +} + +TEST(Status, Assign) { + Status a(::mediapipe::StatusCode::kInvalidArgument, "Invalid"); + Status b; + b = a; + ASSERT_EQ(a.ToString(), b.ToString()); +} + +TEST(Status, Update) { + Status s; + s.Update(OkStatus()); + ASSERT_TRUE(s.ok()); + Status a(::mediapipe::StatusCode::kInvalidArgument, "Invalid"); + s.Update(a); + ASSERT_EQ(s.ToString(), a.ToString()); + Status b(::mediapipe::StatusCode::kInternal, "Invalid"); + s.Update(b); + ASSERT_EQ(s.ToString(), a.ToString()); + s.Update(OkStatus()); + ASSERT_EQ(s.ToString(), a.ToString()); + ASSERT_FALSE(s.ok()); +} + +TEST(Status, EqualsOK) { ASSERT_EQ(OkStatus(), Status()); } + +TEST(Status, EqualsSame) { + Status a(::mediapipe::StatusCode::kInvalidArgument, "Invalid"); + Status b(::mediapipe::StatusCode::kInvalidArgument, "Invalid"); + ASSERT_EQ(a, b); +} + +TEST(Status, EqualsCopy) { + const Status 
a(::mediapipe::StatusCode::kInvalidArgument, "Invalid"); + const Status b = a; + ASSERT_EQ(a, b); +} + +TEST(Status, EqualsDifferentCode) { + const Status a(::mediapipe::StatusCode::kInvalidArgument, "Invalid"); + const Status b(::mediapipe::StatusCode::kInternal, "Internal"); + ASSERT_NE(a, b); +} + +TEST(Status, EqualsDifferentMessage) { + const Status a(::mediapipe::StatusCode::kInvalidArgument, "message"); + const Status b(::mediapipe::StatusCode::kInvalidArgument, "another"); + ASSERT_NE(a, b); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/statusor.cc b/mediapipe/framework/deps/statusor.cc new file mode 100644 index 000000000..fe63c133f --- /dev/null +++ b/mediapipe/framework/deps/statusor.cc @@ -0,0 +1,38 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/deps/statusor.h" + +#include "absl/base/attributes.h" +#include "mediapipe/framework/deps/canonical_errors.h" +#include "mediapipe/framework/deps/status.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { +namespace internal_statusor { + +void Helper::HandleInvalidStatusCtorArg(::mediapipe::Status* status) { + const char* kMessage = + "An OK status is not a valid constructor argument to StatusOr"; + LOG(ERROR) << kMessage; + *status = ::mediapipe::InternalError(kMessage); +} + +void Helper::Crash(const ::mediapipe::Status& status) { + LOG(FATAL) << "Attempting to fetch value instead of handling error " + << status; +} + +} // namespace internal_statusor +} // namespace mediapipe diff --git a/mediapipe/framework/deps/statusor.h b/mediapipe/framework/deps/statusor.h new file mode 100644 index 000000000..a33c9382c --- /dev/null +++ b/mediapipe/framework/deps/statusor.h @@ -0,0 +1,331 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// StatusOr is the union of a Status object and a T +// object. StatusOr models the concept of an object that is either a +// usable value, or an error Status explaining why such a value is +// not present. To this end, StatusOr does not allow its Status +// value to be Status::OK. Furthermore, the value of a StatusOr +// must not be null. 
This is enforced by a debug check in most cases, +// but even when it is not, clients must not set the value to null. +// +// The primary use-case for StatusOr is as the return value of a +// function which may fail. +// +// Example client usage for a StatusOr, where T is not a pointer: +// +// ::mediapipe::StatusOr result = DoBigCalculationThatCouldFail(); +// if (result.ok()) { +// float answer = result.ValueOrDie(); +// printf("Big calculation yielded: %f", answer); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr: +// +// ::mediapipe::StatusOr result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr>: +// +// ::mediapipe::StatusOr> result = +// FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo = std::move(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example factory implementation returning StatusOr: +// +// ::mediapipe::StatusOr FooFactory::MakeNewFoo(int arg) { +// if (arg <= 0) { +// return ::mediapipe::InvalidArgumentError("Arg must be positive"); +// } else { +// return new Foo(arg); +// } +// } +// +// Note that the assignment operators require that destroying the currently +// stored value cannot invalidate the argument; in other words, the argument +// cannot be an alias for the current value, or anything owned by the current +// value. + +#ifndef MEDIAPIPE_DEPS_DEFAULT_STATUSOR_H_ +#define MEDIAPIPE_DEPS_DEFAULT_STATUSOR_H_ + +#include "absl/base/attributes.h" +#include "mediapipe/framework/deps/status.h" +#include "mediapipe/framework/deps/status_builder.h" +#include "mediapipe/framework/deps/statusor_internals.h" + +namespace mediapipe { + +#if defined(__clang__) +// Only clang supports warn_unused_result as a type annotation. 
+template +class ABSL_MUST_USE_RESULT StatusOr; +#endif + +template +class StatusOr : private internal_statusor::StatusOrData, + private internal_statusor::TraitsBase< + std::is_copy_constructible::value, + std::is_move_constructible::value> { + template + friend class StatusOr; + + typedef internal_statusor::StatusOrData Base; + + public: + typedef T element_type; + + // Constructs a new StatusOr with Status::UNKNOWN status. This is marked + // 'explicit' to try to catch cases like 'return {};', where people think + // StatusOr> will be initialized with an empty vector, + // instead of a Status::UNKNOWN status. + explicit StatusOr(); + + // StatusOr will be copy constructible/assignable if T is copy + // constructible. + StatusOr(const StatusOr&) = default; + StatusOr& operator=(const StatusOr&) = default; + + // StatusOr will be move constructible/assignable if T is move + // constructible. + StatusOr(StatusOr&&) = default; + StatusOr& operator=(StatusOr&&) = default; + + // Conversion copy/move constructor, T must be convertible from U. + // TODO: These should not participate in overload resolution if U + // is not convertible to T. + template + StatusOr(const StatusOr& other); + template + StatusOr(StatusOr&& other); + + // Conversion copy/move assignment operator, T must be convertible from U. + template + StatusOr& operator=(const StatusOr& other); + template + StatusOr& operator=(StatusOr&& other); + + // Constructs a new StatusOr with the given value. After calling this + // constructor, calls to ValueOrDie() will succeed, and calls to status() will + // return OK. + // + // NOTE: Not explicit - we want to use StatusOr as a return type + // so it is convenient and sensible to be able to do 'return T()' + // when the return type is StatusOr. + // + // REQUIRES: T is copy constructible. + StatusOr(const T& value); + + // Constructs a new StatusOr with the given non-ok status. After calling + // this constructor, calls to ValueOrDie() will CHECK-fail. 
+ // + // NOTE: Not explicit - we want to use StatusOr as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr. + // + // REQUIRES: !status.ok(). This requirement is DCHECKed. + // In optimized builds, passing Status::OK() here will have the effect + // of passing ::mediapipe::StatusCode::kInternal as a fallback. + StatusOr(const ::mediapipe::Status& status); + StatusOr& operator=(const ::mediapipe::Status& status); + StatusOr(const ::mediapipe::StatusBuilder& builder); + StatusOr& operator=(const ::mediapipe::StatusBuilder& builder); + + // TODO: Add operator=(T) overloads. + + // Similar to the `const T&` overload. + // + // REQUIRES: T is move constructible. + StatusOr(T&& value); + + // RValue versions of the operations declared above. + StatusOr(::mediapipe::Status&& status); + StatusOr& operator=(::mediapipe::Status&& status); + StatusOr(::mediapipe::StatusBuilder&& builder); + StatusOr& operator=(::mediapipe::StatusBuilder&& builder); + + // Returns this->status().ok() + bool ok() const { return this->status_.ok(); } + + // Returns a reference to mediapipe status. If this contains a T, then + // returns Status::OK(). + const ::mediapipe::Status& status() const&; + ::mediapipe::Status status() &&; + + // Returns a reference to our current value, or CHECK-fails if !this->ok(). + // + // Note: for value types that are cheap to copy, prefer simple code: + // + // T value = statusor.ValueOrDie(); + // + // Otherwise, if the value type is expensive to copy, but can be left + // in the StatusOr, simply assign to a reference: + // + // T& value = statusor.ValueOrDie(); // or `const T&` + // + // Otherwise, if the value type supports an efficient move, it can be + // used as follows: + // + // T value = std::move(statusor).ValueOrDie(); + // + // The std::move on statusor instead of on the whole expression enables + // warnings about possible uses of the statusor object after the move. 
+ // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 + // See go/ref-qualifiers for more details on such overloads. + const T& ValueOrDie() const&; + T& ValueOrDie() &; + const T&& ValueOrDie() const&&; + T&& ValueOrDie() &&; + + T ConsumeValueOrDie() { return std::move(ValueOrDie()); } + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr + +template +StatusOr::StatusOr() + : Base(::mediapipe::Status(::mediapipe::StatusCode::kUnknown, "")) {} + +template +StatusOr::StatusOr(const T& value) : Base(value) {} + +template +StatusOr::StatusOr(const ::mediapipe::Status& status) : Base(status) {} + +template +StatusOr::StatusOr(const ::mediapipe::StatusBuilder& builder) + : Base(builder) {} + +template +StatusOr& StatusOr::operator=(const ::mediapipe::Status& status) { + this->Assign(status); + return *this; +} + +template +StatusOr& StatusOr::operator=(const ::mediapipe::StatusBuilder& builder) { + return *this = static_cast<::mediapipe::Status>(builder); +} + +template +StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} + +template +StatusOr::StatusOr(::mediapipe::Status&& status) : Base(std::move(status)) {} + +template +StatusOr::StatusOr(::mediapipe::StatusBuilder&& builder) + : Base(std::move(builder)) {} + +template +StatusOr& StatusOr::operator=(::mediapipe::Status&& status) { + this->Assign(std::move(status)); + return *this; +} + +template +StatusOr& StatusOr::operator=(::mediapipe::StatusBuilder&& builder) { + return *this = static_cast<::mediapipe::Status>(std::move(builder)); +} + +template +template +inline StatusOr::StatusOr(const StatusOr& other) + : Base(static_cast::Base&>(other)) {} + +template +template +inline StatusOr& 
StatusOr::operator=(const StatusOr& other) { + if (other.ok()) + this->Assign(other.ValueOrDie()); + else + this->Assign(other.status()); + return *this; +} + +template +template +inline StatusOr::StatusOr(StatusOr&& other) + : Base(static_cast::Base&&>(other)) {} + +template +template +inline StatusOr& StatusOr::operator=(StatusOr&& other) { + if (other.ok()) { + this->Assign(std::move(other).ValueOrDie()); + } else { + this->Assign(std::move(other).status()); + } + return *this; +} + +template +const ::mediapipe::Status& StatusOr::status() const& { + return this->status_; +} +template +::mediapipe::Status StatusOr::status() && { + return ok() ? ::mediapipe::OkStatus() : std::move(this->status_); +} + +template +const T& StatusOr::ValueOrDie() const& { + this->EnsureOk(); + return this->data_; +} + +template +T& StatusOr::ValueOrDie() & { + this->EnsureOk(); + return this->data_; +} + +template +const T&& StatusOr::ValueOrDie() const&& { + this->EnsureOk(); + return std::move(this->data_); +} + +template +T&& StatusOr::ValueOrDie() && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +void StatusOr::IgnoreError() const { + // no-op +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_DEFAULT_STATUSOR_H_ diff --git a/mediapipe/framework/deps/statusor_internals.h b/mediapipe/framework/deps/statusor_internals.h new file mode 100644 index 000000000..b42d206a2 --- /dev/null +++ b/mediapipe/framework/deps/statusor_internals.h @@ -0,0 +1,245 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_STATUSOR_INTERNALS_H_ +#define MEDIAPIPE_DEPS_STATUSOR_INTERNALS_H_ + +#include "absl/base/attributes.h" +#include "mediapipe/framework/deps/status.h" + +namespace mediapipe { +namespace internal_statusor { + +class Helper { + public: + // Move type-agnostic error handling to the .cc. + static void HandleInvalidStatusCtorArg(::mediapipe::Status*); + ABSL_ATTRIBUTE_NORETURN static void Crash(const ::mediapipe::Status& status); +}; + +// Construct an instance of T in `p` through placement new, passing Args... to +// the constructor. +// This abstraction is here mostly for the gcc performance fix. +template +void PlacementNew(void* p, Args&&... args) { +#if defined(__GNUC__) && !defined(__clang__) + // Teach gcc that 'p' cannot be null, fixing code size issues. + if (p == nullptr) __builtin_unreachable(); +#endif + new (p) T(std::forward(args)...); +} + +// Helper base class to hold the data and all operations. +// We move all this to a base class to allow mixing with the appropriate +// TraitsBase specialization. 
+template +class StatusOrData { + template + friend class StatusOrData; + + public: + StatusOrData() = delete; + + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + StatusOrData(StatusOrData&& other) noexcept { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + template + StatusOrData(StatusOrData&& other) { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } + explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } + + explicit StatusOrData(const ::mediapipe::Status& status) : status_(status) { + EnsureNotOk(); + } + explicit StatusOrData(::mediapipe::Status&& status) + : status_(std::move(status)) { + EnsureNotOk(); + } + + StatusOrData& operator=(const StatusOrData& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(other.data_); + else + Assign(other.status_); + return *this; + } + + StatusOrData& operator=(StatusOrData&& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(std::move(other.data_)); + else + Assign(std::move(other.status_)); + return *this; + } + + ~StatusOrData() { + if (ok()) { + status_.~Status(); + data_.~T(); + } else { + status_.~Status(); + } + } + + void Assign(const T& value) { + if (ok()) { + data_.~T(); + MakeValue(value); + } else { + MakeValue(value); + status_ = ::mediapipe::OkStatus(); + } + } + + void Assign(T&& value) { + if (ok()) { + data_.~T(); + MakeValue(std::move(value)); + } else { + MakeValue(std::move(value)); + status_ = ::mediapipe::OkStatus(); + } + } + 
+ void Assign(const ::mediapipe::Status& status) { + Clear(); + status_ = status; + EnsureNotOk(); + } + + void Assign(::mediapipe::Status&& status) { + Clear(); + status_ = std::move(status); + EnsureNotOk(); + } + + bool ok() const { return status_.ok(); } + + protected: + // status_ will always be active after the constructor. + // We make it a union to be able to initialize exactly how we need without + // waste. + // Eg. in the copy constructor we use the default constructor of Status in + // the ok() path to avoid an extra Ref call. + union { + ::mediapipe::Status status_; + }; + + // data_ is active iff status_.ok()==true + struct Dummy {}; + union { + // When T is const, we need some non-const object we can cast to void* for + // the placement new. dummy_ is that object. + Dummy dummy_; + T data_; + }; + + void Clear() { + if (ok()) data_.~T(); + } + + void EnsureOk() const { + if (!ok()) Helper::Crash(status_); + } + + void EnsureNotOk() { + if (ok()) Helper::HandleInvalidStatusCtorArg(&status_); + } + + // Construct the value (ie. data_) through placement new with the passed + // argument. + template + void MakeValue(Arg&& arg) { + internal_statusor::PlacementNew(&dummy_, std::forward(arg)); + } + + // Construct the status (ie. status_) through placement new with the passed + // argument. + template + void MakeStatus(Args&&... args) { + internal_statusor::PlacementNew<::mediapipe::Status>( + &status_, std::forward(args)...); + } +}; + +// Helper base class to allow implicitly deleted constructors and assignment +// operations in StatusOr. +// TraitsBase will explicitly delete what it can't support and StatusOr will +// inherit that behavior implicitly. 
+template +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = default; + TraitsBase(TraitsBase&&) = default; + TraitsBase& operator=(const TraitsBase&) = default; + TraitsBase& operator=(TraitsBase&&) = default; +}; + +template <> +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = delete; + TraitsBase(TraitsBase&&) = default; + TraitsBase& operator=(const TraitsBase&) = delete; + TraitsBase& operator=(TraitsBase&&) = default; +}; + +template <> +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = delete; + TraitsBase(TraitsBase&&) = delete; + TraitsBase& operator=(const TraitsBase&) = delete; + TraitsBase& operator=(TraitsBase&&) = delete; +}; + +} // namespace internal_statusor +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_STATUSOR_INTERNALS_H_ diff --git a/mediapipe/framework/deps/statusor_test.cc b/mediapipe/framework/deps/statusor_test.cc new file mode 100644 index 000000000..dde69a424 --- /dev/null +++ b/mediapipe/framework/deps/statusor_test.cc @@ -0,0 +1,438 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Unit tests for StatusOr + +#include "mediapipe/framework/deps/statusor.h" + +#include +#include + +#include "mediapipe/framework/deps/canonical_errors.h" +#include "mediapipe/framework/deps/status.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe { +namespace { + +class Base1 { + public: + virtual ~Base1() {} + int pad_; +}; + +class Base2 { + public: + virtual ~Base2() {} + int yetotherpad_; +}; + +class Derived : public Base1, public Base2 { + public: + ~Derived() override {} + int evenmorepad_; +}; + +class CopyNoAssign { + public: + explicit CopyNoAssign(int value) : foo_(value) {} + CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} + int foo_; + + private: + const CopyNoAssign& operator=(const CopyNoAssign&); +}; + +class NoDefaultConstructor { + public: + explicit NoDefaultConstructor(int foo); +}; + +static_assert(!std::is_default_constructible(), + "Should not be default-constructible."); + +StatusOr> ReturnUniquePtr() { + // Uses implicit constructor from T&& + return std::unique_ptr(new int(0)); +} + +TEST(StatusOr, ElementType) { + static_assert(std::is_same::element_type, int>(), ""); + static_assert(std::is_same::element_type, char>(), ""); +} + +TEST(StatusOr, TestNoDefaultConstructorInitialization) { + // Explicitly initialize it with an error code. + ::mediapipe::StatusOr statusor( + ::mediapipe::CancelledError("")); + EXPECT_FALSE(statusor.ok()); + EXPECT_EQ(statusor.status().code(), ::mediapipe::StatusCode::kCancelled); + + // Default construction of StatusOr initializes it with an UNKNOWN error code. 
+ ::mediapipe::StatusOr statusor2; + EXPECT_FALSE(statusor2.ok()); + EXPECT_EQ(statusor2.status().code(), ::mediapipe::StatusCode::kUnknown); +} + +TEST(StatusOr, TestMoveOnlyInitialization) { + ::mediapipe::StatusOr> thing(ReturnUniquePtr()); + ASSERT_TRUE(thing.ok()); + EXPECT_EQ(0, *thing.ValueOrDie()); + int* previous = thing.ValueOrDie().get(); + + thing = ReturnUniquePtr(); + EXPECT_TRUE(thing.ok()); + EXPECT_EQ(0, *thing.ValueOrDie()); + EXPECT_NE(previous, thing.ValueOrDie().get()); +} + +TEST(StatusOr, TestMoveOnlyStatusCtr) { + ::mediapipe::StatusOr> thing( + ::mediapipe::CancelledError("")); + ASSERT_FALSE(thing.ok()); +} + +TEST(StatusOr, TestMoveOnlyValueExtraction) { + ::mediapipe::StatusOr> thing(ReturnUniquePtr()); + ASSERT_TRUE(thing.ok()); + std::unique_ptr ptr = thing.ConsumeValueOrDie(); + EXPECT_EQ(0, *ptr); + + thing = std::move(ptr); + ptr = std::move(thing.ValueOrDie()); + EXPECT_EQ(0, *ptr); +} + +TEST(StatusOr, TestMoveOnlyConversion) { + ::mediapipe::StatusOr> const_thing( + ReturnUniquePtr()); + EXPECT_TRUE(const_thing.ok()); + EXPECT_EQ(0, *const_thing.ValueOrDie()); + + // Test rvalue converting assignment + const int* const_previous = const_thing.ValueOrDie().get(); + const_thing = ReturnUniquePtr(); + EXPECT_TRUE(const_thing.ok()); + EXPECT_EQ(0, *const_thing.ValueOrDie()); + EXPECT_NE(const_previous, const_thing.ValueOrDie().get()); +} + +TEST(StatusOr, TestMoveOnlyVector) { + // Sanity check that ::mediapipe::StatusOr works in vector. 
+ std::vector<::mediapipe::StatusOr>> vec; + vec.push_back(ReturnUniquePtr()); + vec.resize(2); + auto another_vec = std::move(vec); + EXPECT_EQ(0, *another_vec[0].ValueOrDie()); + EXPECT_EQ(::mediapipe::StatusCode::kUnknown, another_vec[1].status().code()); +} + +TEST(StatusOr, TestMoveWithValuesAndErrors) { + ::mediapipe::StatusOr status_or(std::string(1000, '0')); + ::mediapipe::StatusOr value1(std::string(1000, '1')); + ::mediapipe::StatusOr value2(std::string(1000, '2')); + ::mediapipe::StatusOr error1( + Status(::mediapipe::StatusCode::kUnknown, "error1")); + ::mediapipe::StatusOr error2( + Status(::mediapipe::StatusCode::kUnknown, "error2")); + + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie()); + + // Overwrite the value in status_or with another value. + status_or = std::move(value1); + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie()); + + // Overwrite the value in status_or with an error. + status_or = std::move(error1); + ASSERT_FALSE(status_or.ok()); + EXPECT_EQ("error1", status_or.status().error_message()); + + // Overwrite the error in status_or with another error. + status_or = std::move(error2); + ASSERT_FALSE(status_or.ok()); + EXPECT_EQ("error2", status_or.status().error_message()); + + // Overwrite the error with a value. 
+ status_or = std::move(value2); + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie()); +} + +TEST(StatusOr, TestCopyWithValuesAndErrors) { + ::mediapipe::StatusOr status_or(std::string(1000, '0')); + ::mediapipe::StatusOr value1(std::string(1000, '1')); + ::mediapipe::StatusOr value2(std::string(1000, '2')); + ::mediapipe::StatusOr error1( + Status(::mediapipe::StatusCode::kUnknown, "error1")); + ::mediapipe::StatusOr error2( + Status(::mediapipe::StatusCode::kUnknown, "error2")); + + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie()); + + // Overwrite the value in status_or with another value. + status_or = value1; + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie()); + + // Overwrite the value in status_or with an error. + status_or = error1; + ASSERT_FALSE(status_or.ok()); + EXPECT_EQ("error1", status_or.status().error_message()); + + // Overwrite the error in status_or with another error. + status_or = error2; + ASSERT_FALSE(status_or.ok()); + EXPECT_EQ("error2", status_or.status().error_message()); + + // Overwrite the error with a value. + status_or = value2; + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie()); + + // Verify original values unchanged. 
+ EXPECT_EQ(std::string(1000, '1'), value1.ValueOrDie()); + EXPECT_EQ("error1", error1.status().error_message()); + EXPECT_EQ("error2", error2.status().error_message()); + EXPECT_EQ(std::string(1000, '2'), value2.ValueOrDie()); +} + +TEST(StatusOr, TestDefaultCtor) { + ::mediapipe::StatusOr thing; + EXPECT_FALSE(thing.ok()); + EXPECT_EQ(thing.status().code(), ::mediapipe::StatusCode::kUnknown); +} + +TEST(StatusOrDeathTest, TestDefaultCtorValue) { + ::mediapipe::StatusOr thing; + EXPECT_DEATH(thing.ValueOrDie(), ""); + + const ::mediapipe::StatusOr thing2; + EXPECT_DEATH(thing.ValueOrDie(), ""); +} + +TEST(StatusOr, TestStatusCtor) { + ::mediapipe::StatusOr thing( + ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, "")); + EXPECT_FALSE(thing.ok()); + EXPECT_EQ(thing.status().code(), ::mediapipe::StatusCode::kCancelled); +} + +TEST(StatusOr, TestValueCtor) { + const int kI = 4; + const ::mediapipe::StatusOr thing(kI); + EXPECT_TRUE(thing.ok()); + EXPECT_EQ(kI, thing.ValueOrDie()); +} + +TEST(StatusOr, TestCopyCtorStatusOk) { + const int kI = 4; + const ::mediapipe::StatusOr original(kI); + const ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_EQ(original.ValueOrDie(), copy.ValueOrDie()); +} + +TEST(StatusOr, TestCopyCtorStatusNotOk) { + ::mediapipe::StatusOr original( + Status(::mediapipe::StatusCode::kCancelled, "")); + ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); +} + +TEST(StatusOr, TestCopyCtorNonAssignable) { + const int kI = 4; + CopyNoAssign value(kI); + ::mediapipe::StatusOr original(value); + ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_); +} + +TEST(StatusOr, TestCopyCtorStatusOKConverting) { + const int kI = 4; + ::mediapipe::StatusOr original(kI); + ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + 
EXPECT_DOUBLE_EQ(original.ValueOrDie(), copy.ValueOrDie()); +} + +TEST(StatusOr, TestCopyCtorStatusNotOkConverting) { + ::mediapipe::StatusOr original( + Status(::mediapipe::StatusCode::kCancelled, "")); + ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); +} + +TEST(StatusOr, TestAssignmentStatusOk) { + const int kI = 4; + ::mediapipe::StatusOr source(kI); + ::mediapipe::StatusOr target; + target = source; + EXPECT_EQ(target.status(), source.status()); + EXPECT_EQ(source.ValueOrDie(), target.ValueOrDie()); +} + +TEST(StatusOr, TestAssignmentStatusNotOk) { + ::mediapipe::StatusOr source( + Status(::mediapipe::StatusCode::kCancelled, "")); + ::mediapipe::StatusOr target; + target = source; + EXPECT_EQ(target.status(), source.status()); +} + +TEST(StatusOr, TestStatus) { + ::mediapipe::StatusOr good(4); + EXPECT_TRUE(good.ok()); + ::mediapipe::StatusOr bad( + Status(::mediapipe::StatusCode::kCancelled, "")); + EXPECT_FALSE(bad.ok()); + EXPECT_EQ(bad.status(), Status(::mediapipe::StatusCode::kCancelled, "")); +} + +TEST(StatusOr, TestValue) { + const int kI = 4; + ::mediapipe::StatusOr thing(kI); + EXPECT_EQ(kI, thing.ValueOrDie()); +} + +TEST(StatusOr, TestValueConst) { + const int kI = 4; + const ::mediapipe::StatusOr thing(kI); + EXPECT_EQ(kI, thing.ValueOrDie()); +} + +TEST(StatusOrDeathTest, TestValueNotOk) { + ::mediapipe::StatusOr thing( + ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, "cancelled")); + EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); +} + +TEST(StatusOrDeathTest, TestValueNotOkConst) { + const ::mediapipe::StatusOr thing( + ::mediapipe::Status(::mediapipe::StatusCode::kUnknown, "")); + EXPECT_DEATH(thing.ValueOrDie(), ""); +} + +TEST(StatusOr, TestPointerDefaultCtor) { + ::mediapipe::StatusOr thing; + EXPECT_FALSE(thing.ok()); + EXPECT_EQ(thing.status().code(), ::mediapipe::StatusCode::kUnknown); +} + +TEST(StatusOrDeathTest, TestPointerDefaultCtorValue) { + ::mediapipe::StatusOr thing; + 
EXPECT_DEATH(thing.ValueOrDie(), ""); +} + +TEST(StatusOr, TestPointerStatusCtor) { + ::mediapipe::StatusOr thing( + Status(::mediapipe::StatusCode::kCancelled, "")); + EXPECT_FALSE(thing.ok()); + EXPECT_EQ(thing.status(), Status(::mediapipe::StatusCode::kCancelled, "")); +} + +TEST(StatusOr, TestPointerValueCtor) { + const int kI = 4; + ::mediapipe::StatusOr thing(&kI); + EXPECT_TRUE(thing.ok()); + EXPECT_EQ(&kI, thing.ValueOrDie()); +} + +TEST(StatusOr, TestPointerCopyCtorStatusOk) { + const int kI = 0; + ::mediapipe::StatusOr original(&kI); + ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_EQ(original.ValueOrDie(), copy.ValueOrDie()); +} + +TEST(StatusOr, TestPointerCopyCtorStatusNotOk) { + ::mediapipe::StatusOr original( + Status(::mediapipe::StatusCode::kCancelled, "")); + ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); +} + +TEST(StatusOr, TestPointerCopyCtorStatusOKConverting) { + Derived derived; + ::mediapipe::StatusOr original(&derived); + ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_EQ(static_cast(original.ValueOrDie()), + copy.ValueOrDie()); +} + +TEST(StatusOr, TestPointerCopyCtorStatusNotOkConverting) { + ::mediapipe::StatusOr original( + ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, "")); + ::mediapipe::StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); +} + +TEST(StatusOr, TestPointerAssignmentStatusOk) { + const int kI = 0; + ::mediapipe::StatusOr source(&kI); + ::mediapipe::StatusOr target; + target = source; + EXPECT_EQ(target.status(), source.status()); + EXPECT_EQ(source.ValueOrDie(), target.ValueOrDie()); +} + +TEST(StatusOr, TestPointerAssignmentStatusNotOk) { + ::mediapipe::StatusOr source( + ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, "")); + ::mediapipe::StatusOr target; + target = source; + EXPECT_EQ(target.status(), source.status()); +} + +TEST(StatusOr, 
TestPointerStatus) { + const int kI = 0; + ::mediapipe::StatusOr good(&kI); + EXPECT_TRUE(good.ok()); + ::mediapipe::StatusOr bad( + ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, "")); + EXPECT_EQ(bad.status(), + ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, "")); +} + +TEST(StatusOr, TestPointerValue) { + const int kI = 0; + ::mediapipe::StatusOr thing(&kI); + EXPECT_EQ(&kI, thing.ValueOrDie()); +} + +TEST(StatusOr, TestPointerValueConst) { + const int kI = 0; + const ::mediapipe::StatusOr thing(&kI); + EXPECT_EQ(&kI, thing.ValueOrDie()); +} + +TEST(StatusOrDeathTest, TestPointerValueNotOk) { + ::mediapipe::StatusOr thing( + ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, "cancelled")); + EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); +} + +TEST(StatusOrDeathTest, TestPointerValueNotOkConst) { + const ::mediapipe::StatusOr thing( + ::mediapipe::Status(::mediapipe::StatusCode::kCancelled, "cancelled")); + EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/deps/strong_int.h b/mediapipe/framework/deps/strong_int.h new file mode 100644 index 000000000..6f102238f --- /dev/null +++ b/mediapipe/framework/deps/strong_int.h @@ -0,0 +1,461 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// StrongInt is a simple template class mechanism for defining "logical" +// integer-like class types that support almost all of the same functionality +// as native integer types, but which prevents assignment, construction, and +// other operations from other integer-like types. In other words, you cannot +// assign from raw integer types or other StrongInt<> types, nor can you do +// most arithmetic or logical operations. This provides a simple form of +// dimensionality in that you can add two instances of StrongInt, producing +// a StrongInt, but you can not add a StrongInt and a raw T nor can you +// add a StrongInt and a StrongInt. Details on supported operations are +// below. +// +// In addition to type strength, StrongInt provides a way to inject (optional) +// validation of the various operations. This allows you to define StrongInt +// types that check for overflow conditions and react in standard or custom +// ways. +// +// A StrongInt with a NullStrongIntValidator should compile away to a raw T +// in optimized mode. What this means is that the generated assembly for: +// +// int64 foo = 123; +// int64 bar = 456; +// int64 baz = foo + bar; +// constexpr int64 fubar = 789; +// +// ...should be identical to the generated assembly for: +// +// DEFINE_STRONG_INT_TYPE(MyStrongInt, int64); +// MyStrongInt foo(123); +// MyStrongInt bar(456); +// MyStrongInt baz = foo + bar; +// constexpr MyStrongInt fubar(789); +// +// Since the methods are all inline and non-virtual and the class has just +// one data member, the compiler can erase the StrongInt class entirely in its +// code-generation phase. This also means that you can pass StrongInt +// around by value just as you would a raw T. +// +// It is important to note that StrongInt does NOT generate compile time +// warnings or errors for overflows on implicit constant conversions. +// +// Usage: +// StrongInt +// +// Creates a new StrongInt instance directly. 
+// +// TagType: The unique type which discriminates this StrongInt from +// other StrongInt types. +// NativeType: The primitive integral type this StrongInt will hold, as +// defined by std::is_integral (see ). +// ValidatorType: The type of validation used by this StrongInt type. A +// few pre-built validator types are provided here, but the caller can +// define any custom validator they desire. +// +// Supported operations: +// StrongInt = StrongInt +// !StrongInt => bool +// ~StrongInt => StrongInt +// -StrongInt => StrongInt +// +StrongInt => StrongInt +// ++StrongInt => StrongInt +// StrongInt++ => StrongInt +// --StrongInt => StrongInt +// StrongInt-- => StrongInt +// StrongInt + StrongInt => StrongInt +// StrongInt - StrongInt => StrongInt +// StrongInt * (numeric type) => StrongInt +// StrongInt / (numeric type) => StrongInt +// StrongInt % (numeric type) => StrongInt +// StrongInt << (numeric type) => StrongInt +// StrongInt >> (numeric type) => StrongInt +// StrongInt & StrongInt => StrongInt +// StrongInt | StrongInt => StrongInt +// StrongInt ^ StrongInt => StrongInt +// +// For binary operations, the equivalent op-equal (eg += vs. +) operations are +// also supported. Other operator combinations should cause compile-time +// errors. +// +// Validators: +// NullStrongIntValidator: Do no validation. This should be entirely +// optimized away by the compiler. + +#ifndef MEDIAPIPE_DEPS_STRONG_INT_H_ +#define MEDIAPIPE_DEPS_STRONG_INT_H_ + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/port.h" + +namespace mediapipe { +namespace intops { + +// Define the validators which can be plugged-in to make StrongInt resilient to +// things like overflows. This is a do-nothing implementation of the +// compile-time interface. 
+// +// NOTE: For all validation functions that operate on an existing StrongInt, +// the type argument 'T' *must* be StrongInt::ValueType (the int type being +// strengthened). +struct NullStrongIntValidator { + // Verify initialization of StrongInt from arg, type U. + // + // Note that this templated default implementation has an arbitrary bool + // return value for the sole purpose of conforming to c++11 constexpr. + // + // Custom validator implementations can choose to return void or use a similar + // return value constexpr construct if constexpr initialization is desirable. + // + // The StrongInt class does not care about or use the returned value. Any + // returned value is solely there to allow the constexpr declaration; custom + // validators can only fail / abort when detecting an invalid value. + // + // For example, other than the constexpr behavior, the below 2 custom + // validator implementations are logically equivalent: + // + // template + // static void ValidateInit(U arg) { + // if (arg < 0) LOG(FATAL) << "arg < 0"; + // } + // + // template + // static constexpr bool ValidateInit(U arg) { + // return (arg < 0) ? (LOG(FATAL) << "arg < 0", false) : false; + // } + // + // A constexpr ValidateInit implementation has the added advantage that the + // validation can take place (fail) at compile time. + template + static constexpr bool ValidateInit(U arg) { + return true; + } + // Verify -value. + template + static void ValidateNegate(T value) { /* do nothing */ + } + // Verify ~value; + template + static void ValidateBitNot(T value) { /* do nothing */ + } + // Verify lhs + rhs. + template + static void ValidateAdd(T lhs, T rhs) { /* do nothing */ + } + // Verify lhs - rhs. + template + static void ValidateSubtract(T lhs, T rhs) { /* do nothing */ + } + // Verify lhs * rhs. + template + static void ValidateMultiply(T lhs, U rhs) { /* do nothing */ + } + // Verify lhs / rhs. 
+ template + static void ValidateDivide(T lhs, U rhs) { /* do nothing */ + } + // Verify lhs % rhs. + template + static void ValidateModulo(T lhs, U rhs) { /* do nothing */ + } + // Verify lhs << rhs. + template + static void ValidateLeftShift(T lhs, int64 rhs) { /* do nothing */ + } + // Verify lhs >> rhs. + template + static void ValidateRightShift(T lhs, int64 rhs) { /* do nothing */ + } + // Verify lhs & rhs. + template + static void ValidateBitAnd(T lhs, T rhs) { /* do nothing */ + } + // Verify lhs | rhs. + template + static void ValidateBitOr(T lhs, T rhs) { /* do nothing */ + } + // Verify lhs ^ rhs. + template + static void ValidateBitXor(T lhs, T rhs) { /* do nothing */ + } +}; + +// Holds an integer value (of type NativeType) and behaves as a NativeType by +// exposing assignment, unary, comparison, and arithmetic operators. +// +// This class is NOT thread-safe. +template +class StrongInt { + public: + typedef NativeType ValueType; + + // Default value initialization. + constexpr StrongInt() + : value_((ValidatorType::template ValidateInit(NativeType()), + NativeType())) {} + + // Explicit initialization from another StrongInt type that has an + // implementation of: + // + // ToType StrongIntConvert(FromType source, ToType*); + // + // This uses Argument Dependent Lookup (ADL) to find which function to + // call. + // + // Example: Assume you have two StrongInt types. + // + // DEFINE_STRONG_INT_TYPE(Bytes, int64); + // DEFINE_STRONG_INT_TYPE(Megabytes, int64); + // + // If you want to be able to (explicitly) construct an instance of Bytes from + // an instance of Megabytes, simply define a converter function in the same + // namespace as either Bytes or Megabytes (or both): + // + // Megabytes StrongIntConvert(Bytes arg, Megabytes* /* unused */) { + // return Megabytes((arg >> 20).value()); + // }; + // + // The second argument is needed to differentiate conversions, and it always + // passed as NULL. 
+ template + explicit StrongInt( + StrongInt arg) { + // We have to pass both the "from" type and the "to" type as args for the + // conversions to be differentiated. The converter can not be a template + // because explicit template call syntax defeats ADL. + StrongInt *dummy = NULL; + StrongInt converted = StrongIntConvert(arg, dummy); + value_ = converted.value(); + } + + // Explicit initialization from a numeric primitive. + template ::value>::type> + explicit constexpr StrongInt(T init_value) + : value_((ValidatorType::template ValidateInit(init_value), + static_cast(init_value))) {} + + // Use the default copy constructor, assignment, and destructor. + + // Accesses the raw value. + constexpr ValueType value() const { return value_; } + + // Accesses the raw value, with cast. + // Primarily for compatibility with int-type.h + template + constexpr ValType value() const { + return static_cast(value_); + } + + // Metadata functions. + static ValueType Max() { return std::numeric_limits::max(); } + static ValueType Min() { return std::numeric_limits::min(); } + + // Unary operators. + bool operator!() const { return value_ == 0; } + const StrongInt operator+() const { return StrongInt(value_); } + const StrongInt operator-() const { + ValidatorType::template ValidateNegate(value_); + return StrongInt(-value_); + } + const StrongInt operator~() const { + ValidatorType::template ValidateBitNot(value_); + return StrongInt(ValueType(~value_)); + } + + // Increment and decrement operators. 
+ StrongInt &operator++() { // ++x + ValidatorType::template ValidateAdd(value_, ValueType(1)); + ++value_; + return *this; + } + const StrongInt operator++(int postfix_flag) { // x++ + ValidatorType::template ValidateAdd(value_, ValueType(1)); + StrongInt temp(*this); + ++value_; + return temp; + } + StrongInt &operator--() { // --x + ValidatorType::template ValidateSubtract(value_, ValueType(1)); + --value_; + return *this; + } + const StrongInt operator--(int postfix_flag) { // x-- + ValidatorType::template ValidateSubtract(value_, ValueType(1)); + StrongInt temp(*this); + --value_; + return temp; + } + + // Action-Assignment operators. + StrongInt &operator+=(StrongInt arg) { + ValidatorType::template ValidateAdd(value_, arg.value()); + value_ += arg.value(); + return *this; + } + StrongInt &operator-=(StrongInt arg) { + ValidatorType::template ValidateSubtract(value_, arg.value()); + value_ -= arg.value(); + return *this; + } + template + StrongInt &operator*=(ArgType arg) { + ValidatorType::template ValidateMultiply(value_, arg); + value_ *= arg; + return *this; + } + template + StrongInt &operator/=(ArgType arg) { + ValidatorType::template ValidateDivide(value_, arg); + value_ /= arg; + return *this; + } + template + StrongInt &operator%=(ArgType arg) { + ValidatorType::template ValidateModulo(value_, arg); + value_ %= arg; + return *this; + } + StrongInt &operator<<=(int64 arg) { // NOLINT(whitespace/operators) + ValidatorType::template ValidateLeftShift(value_, arg); + value_ <<= arg; + return *this; + } + StrongInt &operator>>=(int64 arg) { // NOLINT(whitespace/operators) + ValidatorType::template ValidateRightShift(value_, arg); + value_ >>= arg; + return *this; + } + StrongInt &operator&=(StrongInt arg) { + ValidatorType::template ValidateBitAnd(value_, arg.value()); + value_ &= arg.value(); + return *this; + } + StrongInt &operator|=(StrongInt arg) { + ValidatorType::template ValidateBitOr(value_, arg.value()); + value_ |= arg.value(); + return *this; 
+ } + StrongInt &operator^=(StrongInt arg) { + ValidatorType::template ValidateBitXor(value_, arg.value()); + value_ ^= arg.value(); + return *this; + } + + private: + // The integer value of type ValueType. + ValueType value_; + + static_assert(std::is_integral::value, + "invalid integer type for strong int"); +}; + +// Provide the << operator, primarily for logging purposes. +template +std::ostream &operator<<(std::ostream &os, + StrongInt arg) { + return os << arg.value(); +} + +// Provide the << operator, primarily for logging purposes. Specialized for int8 +// so that an integer and not a character is printed. +template +std::ostream &operator<<(std::ostream &os, + StrongInt arg) { + return os << static_cast(arg.value()); +} + +// Provide the << operator, primarily for logging purposes. Specialized for +// uint8 so that an integer and not a character is printed. +template +std::ostream &operator<<(std::ostream &os, + StrongInt arg) { + return os << static_cast(arg.value()); +} + +// Define operators that take two StrongInt arguments. These operators are +// defined in terms of their op-equal member function cousins. +#define STRONG_INT_VS_STRONG_INT_BINARY_OP(op) \ + template \ + inline StrongInt operator op( \ + StrongInt lhs, \ + StrongInt rhs) { \ + lhs op## = rhs; \ + return lhs; \ + } +STRONG_INT_VS_STRONG_INT_BINARY_OP(+); +STRONG_INT_VS_STRONG_INT_BINARY_OP(-); +STRONG_INT_VS_STRONG_INT_BINARY_OP(&); +STRONG_INT_VS_STRONG_INT_BINARY_OP(|); +STRONG_INT_VS_STRONG_INT_BINARY_OP(^); +#undef STRONG_INT_VS_STRONG_INT_BINARY_OP + +// Define operators that take one StrongInt and one native integer argument. +// These operators are defined in terms of their op-equal member function +// cousins, mostly. 
+#define STRONG_INT_VS_NUMERIC_BINARY_OP(op) \ + template \ + inline StrongInt operator op( \ + StrongInt lhs, NumType rhs) { \ + lhs op## = rhs; \ + return lhs; \ + } +// This is used for commutative operators between one StrongInt and one native +// integer argument. That is a long way of saying "multiplication". +#define NUMERIC_VS_STRONG_INT_BINARY_OP(op) \ + template \ + inline StrongInt operator op( \ + NumType lhs, StrongInt rhs) { \ + rhs op## = lhs; \ + return rhs; \ + } +STRONG_INT_VS_NUMERIC_BINARY_OP(*); +NUMERIC_VS_STRONG_INT_BINARY_OP(*); +STRONG_INT_VS_NUMERIC_BINARY_OP(/); +STRONG_INT_VS_NUMERIC_BINARY_OP(%); +STRONG_INT_VS_NUMERIC_BINARY_OP(<<); // NOLINT(whitespace/operators) +STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators) +#undef STRONG_INT_VS_NUMERIC_BINARY_OP +#undef NUMERIC_VS_STRONG_INT_BINARY_OP + +// Define comparison operators. We allow all comparison operators. +#define STRONG_INT_COMPARISON_OP(op) \ + template \ + inline bool operator op(StrongInt lhs, \ + StrongInt rhs) { \ + return lhs.value() op rhs.value(); \ + } +STRONG_INT_COMPARISON_OP(==); // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(!=); // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(<); // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(<=); // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(>); // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(>=); // NOLINT(whitespace/operators) +#undef STRONG_INT_COMPARISON_OP + +} // namespace intops +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_STRONG_INT_H_ diff --git a/mediapipe/framework/deps/thread_options.h b/mediapipe/framework/deps/thread_options.h new file mode 100644 index 000000000..562b5f9b5 --- /dev/null +++ b/mediapipe/framework/deps/thread_options.h @@ -0,0 +1,70 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_THREAD_OPTIONS_H_ +#define MEDIAPIPE_DEPS_THREAD_OPTIONS_H_ + +#include + +#include +#include + +namespace mediapipe { + +// Options to configure a thread. Default values are listed in +// the field descriptions. +class ThreadOptions { + public: + ThreadOptions() : stack_size_(0), nice_priority_level_(0) {} + + // Set the thread stack size (in bytes). Passing stack_size==0 resets + // the stack size to the default value for the system. The system default + // is also the default for this class. + ThreadOptions& set_stack_size(size_t stack_size) { + stack_size_ = stack_size; + return *this; + } + + ThreadOptions& set_nice_priority_level(int nice_priority_level) { + nice_priority_level_ = nice_priority_level; + return *this; + } + + ThreadOptions& set_cpu_set(const std::set& cpu_set) { + cpu_set_ = cpu_set; + return *this; + } + + ThreadOptions& set_name_prefix(const std::string& name_prefix) { + name_prefix_ = name_prefix; + return *this; + } + + size_t stack_size() const { return stack_size_; } + + int nice_priority_level() const { return nice_priority_level_; } + + const std::set& cpu_set() const { return cpu_set_; } + + std::string name_prefix() const { return name_prefix_; } + + private: + size_t stack_size_; // Size of thread stack + int nice_priority_level_; // Nice priority level of the workers + std::set cpu_set_; // CPU set for affinity setting + std::string name_prefix_; // Name of the thread +}; + +} // namespace mediapipe +#endif // MEDIAPIPE_DEPS_THREAD_OPTIONS_H_ diff --git 
a/mediapipe/framework/deps/threadpool.cc b/mediapipe/framework/deps/threadpool.cc new file mode 100644 index 000000000..4cd222264 --- /dev/null +++ b/mediapipe/framework/deps/threadpool.cc @@ -0,0 +1,193 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/deps/threadpool.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +class ThreadPool::WorkerThread { + public: + // Creates and starts a thread that runs pool->RunWorker(). + WorkerThread(ThreadPool* pool, const std::string& name_prefix); + + // REQUIRES: Join() must have been called. + ~WorkerThread(); + + // Joins with the running thread. 
+ void Join(); + + private: + static void* ThreadBody(void* arg); + + ThreadPool* pool_; + std::string name_prefix_; + pthread_t thread_; +}; + +ThreadPool::WorkerThread::WorkerThread(ThreadPool* pool, + const std::string& name_prefix) + : pool_(pool), name_prefix_(name_prefix) { + pthread_create(&thread_, nullptr, ThreadBody, this); +} + +ThreadPool::WorkerThread::~WorkerThread() {} + +void ThreadPool::WorkerThread::Join() { pthread_join(thread_, nullptr); } + +void* ThreadPool::WorkerThread::ThreadBody(void* arg) { + auto thread = reinterpret_cast(arg); + int nice_priority_level = + thread->pool_->thread_options().nice_priority_level(); + const std::set selected_cpus = thread->pool_->thread_options().cpu_set(); + const std::string name = + internal::CreateThreadName(thread->name_prefix_, syscall(SYS_gettid)); +#if defined(__linux__) + if (nice_priority_level != 0) { + if (nice(nice_priority_level) != -1 || errno == 0) { + VLOG(1) << "Changed the nice priority level by " << nice_priority_level; + } else { + LOG(ERROR) << "Error : " << strerror(errno) << std::endl + << "Could not change the nice priority level by " + << nice_priority_level; + } + } + if (!selected_cpus.empty()) { + cpu_set_t cpu_set; + CPU_ZERO(&cpu_set); + for (const int cpu : selected_cpus) { + CPU_SET(cpu, &cpu_set); + } + if (sched_setaffinity(syscall(SYS_gettid), sizeof(cpu_set_t), &cpu_set) != + -1 || + errno == 0) { + VLOG(1) << "Pinned the thread pool executor to processor " + << absl::StrJoin(selected_cpus, ", processor ") << "."; + } else { + LOG(ERROR) << "Error : " << strerror(errno) << std::endl + << "Failed to set processor affinity. 
Ignore processor " + "affinity setting for now."; + } + } + int error = pthread_setname_np(pthread_self(), name.c_str()); + if (error != 0) { + LOG(ERROR) << "Error : " << strerror(error) << std::endl + << "Failed to set name for thread: " << name; + } +#else + if (nice_priority_level != 0 || !selected_cpus.empty()) { + LOG(ERROR) << "Thread priority and processor affinity feature aren't " + "supported on the current platform."; + } + int error = pthread_setname_np(name.c_str()); + if (error != 0) { + LOG(ERROR) << "Error : " << strerror(error) << std::endl + << "Failed to set name for thread: " << name; + } +#endif + thread->pool_->RunWorker(); + return nullptr; +} + +ThreadPool::ThreadPool(int num_threads) { + num_threads_ = (num_threads == 0) ? 1 : num_threads; +} + +ThreadPool::ThreadPool(const std::string& name_prefix, int num_threads) + : name_prefix_(name_prefix) { + num_threads_ = (num_threads == 0) ? 1 : num_threads; +} + +ThreadPool::ThreadPool(const ThreadOptions& thread_options, + const std::string& name_prefix, int num_threads) + : name_prefix_(name_prefix), thread_options_(thread_options) { + num_threads_ = (num_threads == 0) ? 
1 : num_threads; +} + +ThreadPool::~ThreadPool() { + mutex_.Lock(); + stopped_ = true; + condition_.SignalAll(); + mutex_.Unlock(); + + for (int i = 0; i < threads_.size(); ++i) { + threads_[i]->Join(); + delete threads_[i]; + } + + threads_.clear(); +} + +void ThreadPool::StartWorkers() { + for (int i = 0; i < num_threads_; ++i) { + threads_.push_back(new WorkerThread(this, name_prefix_)); + } +} + +void ThreadPool::Schedule(std::function callback) { + mutex_.Lock(); + tasks_.push_back(std::move(callback)); + condition_.Signal(); + mutex_.Unlock(); +} + +int ThreadPool::num_threads() const { return num_threads_; } + +void ThreadPool::RunWorker() { + mutex_.Lock(); + while (true) { + if (!tasks_.empty()) { + std::function task = std::move(tasks_.front()); + tasks_.pop_front(); + mutex_.Unlock(); + task(); + mutex_.Lock(); + } else { + if (stopped_) { + break; + } else { + condition_.Wait(&mutex_); + } + } + } + mutex_.Unlock(); +} + +const ThreadOptions& ThreadPool::thread_options() const { + return thread_options_; +} + +namespace internal { + +std::string CreateThreadName(const std::string& prefix, int thread_id) { + std::string name = absl::StrCat(prefix, "/", thread_id); + // 16 is the limit allowed by `pthread_setname_np`, including + // the terminating null byte ('\0') + constexpr size_t kMaxThreadNameLength = 15; + name.resize(std::min(name.length(), kMaxThreadNameLength)); + return name; +} + +} // namespace internal + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/threadpool.h b/mediapipe/framework/deps/threadpool.h new file mode 100644 index 000000000..78c175e3f --- /dev/null +++ b/mediapipe/framework/deps/threadpool.h @@ -0,0 +1,117 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_THREADPOOL_H_ +#define MEDIAPIPE_DEPS_THREADPOOL_H_ + +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/deps/thread_options.h" + +namespace mediapipe { + +// A thread pool consists of a set of threads that sit around waiting +// for callbacks to appear on a queue. When that happens, one of the +// threads pulls a callback off the queue and runs it. +// +// The thread pool is shut down when the pool is destroyed. +// +// Sample usage: +// +// { +// ThreadPool pool("testpool", num_workers); +// pool.StartWorkers(); +// for (int i = 0; i < N; ++i) { +// pool.Schedule([i]() { DoWork(i); }); +// } +// } +// +class ThreadPool { + public: + // Create a thread pool that provides a concurrency of "num_threads" + // threads. I.e., if "num_threads" items are added, they are all + // guaranteed to run concurrently without excessive delay. + // It has an effectively infinite maximum queue length. + // If num_threads is 1, the callbacks are run in FIFO order. + explicit ThreadPool(int num_threads); + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator=(const ThreadPool&) = delete; + + // Like the ThreadPool(int num_threads) constructor, except that + // it also associates "name_prefix" with each of the threads + // in the thread pool. + ThreadPool(const std::string& name_prefix, int num_threads); + + // Create a thread pool that creates and can use up to "num_threads" + // threads. 
Any standard thread options, such as stack size, should + // be passed via "thread_options". "name_prefix" specifies the + // thread name prefix. + ThreadPool(const ThreadOptions& thread_options, + const std::string& name_prefix, int num_threads); + + // Waits for closures (if any) to complete. May be called without + // having called StartWorkers(). + ~ThreadPool(); + + // REQUIRES: StartWorkers has not been called + // Actually start the worker threads. + void StartWorkers(); + + // REQUIRES: StartWorkers has been called + // Add specified callback to queue of pending callbacks. Eventually a + // thread will pull this callback off the queue and execute it. + void Schedule(std::function callback); + + // Provided for debugging and testing only. + int num_threads() const; + + // Standard thread options. Use this accessor to get them. + const ThreadOptions& thread_options() const; + + private: + class WorkerThread; + void RunWorker(); + + std::string name_prefix_; + std::vector threads_; + int num_threads_; + + absl::Mutex mutex_; + absl::CondVar condition_; + bool stopped_ GUARDED_BY(mutex_) = false; + std::deque> tasks_ GUARDED_BY(mutex_); + + ThreadOptions thread_options_; +}; + +namespace internal { + +// Creates name for thread in a thread pool based on provided prefix and +// thread id. Length of the resulting name is guaranteed to be less or equal +// to 15. 
Name or thread id can be truncated to achieve that, see truncation +// samples below: +// name_prefix, 1234 -> name_prefix/123 +// name_prefix, 1234567 -> name_prefix/123 +// name_prefix_long, 1234 -> name_prefix_lon +std::string CreateThreadName(const std::string& prefix, int thread_id); + +} // namespace internal + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_THREADPOOL_H_ diff --git a/mediapipe/framework/deps/threadpool_test.cc b/mediapipe/framework/deps/threadpool_test.cc new file mode 100644 index 000000000..c6ab54ec0 --- /dev/null +++ b/mediapipe/framework/deps/threadpool_test.cc @@ -0,0 +1,118 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/deps/threadpool.h" + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe { + +TEST(ThreadPoolTest, DestroyWithoutStart) { + ThreadPool thread_pool("testpool", 10); +} + +TEST(ThreadPoolTest, EmptyThread) { + ThreadPool thread_pool("testpool", 0); + ASSERT_EQ(1, thread_pool.num_threads()); + thread_pool.StartWorkers(); +} + +TEST(ThreadPoolTest, SingleThread) { + absl::Mutex mu; + int n = 100; + { + ThreadPool thread_pool("testpool", 1); + ASSERT_EQ(1, thread_pool.num_threads()); + thread_pool.StartWorkers(); + + for (int i = 0; i < 100; ++i) { + thread_pool.Schedule([&n, &mu]() mutable { + absl::MutexLock l(&mu); + --n; + }); + } + } + + EXPECT_EQ(0, n); +} + +TEST(ThreadPoolTest, MultiThreads) { + absl::Mutex mu; + int n = 100; + { + ThreadPool thread_pool("testpool", 10); + ASSERT_EQ(10, thread_pool.num_threads()); + thread_pool.StartWorkers(); + + for (int i = 0; i < 100; ++i) { + thread_pool.Schedule([&n, &mu]() mutable { + absl::MutexLock l(&mu); + --n; + }); + } + } + + EXPECT_EQ(0, n); +} + +TEST(ThreadPoolTest, CreateWithThreadOptions) { + ThreadPool thread_pool(ThreadOptions(), "testpool", 10); + ASSERT_EQ(10, thread_pool.num_threads()); + thread_pool.StartWorkers(); +} + +TEST(ThreadPoolTest, CreateWithThreadPriority) { + ThreadOptions thread_options = ThreadOptions().set_nice_priority_level(-10); + ThreadPool thread_pool(thread_options, "testpool", 10); + ASSERT_EQ(10, thread_pool.num_threads()); + ASSERT_EQ(-10, thread_pool.thread_options().nice_priority_level()); + thread_pool.StartWorkers(); +} + +TEST(ThreadPoolTest, CreateWithCPUAffinity) { + ThreadOptions thread_options = ThreadOptions().set_cpu_set({0}); + ThreadPool thread_pool(thread_options, "testpool", 10); + ASSERT_EQ(10, thread_pool.num_threads()); + ASSERT_EQ(1, thread_pool.thread_options().cpu_set().size()); + thread_pool.StartWorkers(); +} + +TEST(ThreadPoolTest, CreateThreadName) { + 
ASSERT_EQ("name_prefix/123", internal::CreateThreadName("name_prefix", 1234)); + ASSERT_EQ("name_prefix/123", + internal::CreateThreadName("name_prefix", 12345)); + ASSERT_EQ("name_prefix/123", + internal::CreateThreadName("name_prefix", 123456)); + ASSERT_EQ("name_prefix/123", + internal::CreateThreadName("name_prefix", 1234567)); + ASSERT_EQ("name_prefix/123", + internal::CreateThreadName("name_prefix", 1234567891)); + ASSERT_EQ("name_prefix_/12", + internal::CreateThreadName("name_prefix_", 1234)); + ASSERT_EQ("name_pre/123456", + internal::CreateThreadName("name_pre", 1234567891)); + ASSERT_EQ("n/1", internal::CreateThreadName("n", 1)); + ASSERT_EQ("name_p/12345678", + internal::CreateThreadName("name_p", 1234567891)); + ASSERT_EQ("/1", internal::CreateThreadName("", 1)); + ASSERT_EQ("name_prefix_lon", + internal::CreateThreadName("name_prefix_long", 1234)); + ASSERT_EQ("name_prefix_lon", + internal::CreateThreadName("name_prefix_lon", 1234)); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/topologicalsorter.cc b/mediapipe/framework/deps/topologicalsorter.cc new file mode 100644 index 000000000..67fc6adc4 --- /dev/null +++ b/mediapipe/framework/deps/topologicalsorter.cc @@ -0,0 +1,153 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/deps/topologicalsorter.h" + +#include + +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +TopologicalSorter::TopologicalSorter(int num_nodes) : num_nodes_(num_nodes) { + CHECK_GE(num_nodes_, 0); + adjacency_lists_.resize(num_nodes_); +} + +void TopologicalSorter::AddEdge(int from, int to) { + CHECK(!traversal_started_ && from < num_nodes_ && to < num_nodes_ && + from >= 0 && to >= 0); + adjacency_lists_[from].push_back(to); +} + +bool TopologicalSorter::GetNext(int* node_index, bool* cyclic, + std::vector* output_cycle_nodes) { + if (!traversal_started_) { + // Iterates over all adjacency lists, and fills the indegree_ vector. + indegree_.assign(num_nodes_, 0); + for (int from = 0; from < num_nodes_; ++from) { + std::vector& adjacency_list = adjacency_lists_[from]; + // Eliminates duplicate edges. + std::sort(adjacency_list.begin(), adjacency_list.end()); + adjacency_list.erase( + std::unique(adjacency_list.begin(), adjacency_list.end()), + adjacency_list.end()); + for (int to : adjacency_list) { + ++indegree_[to]; + } + } + + // Fills the nodes_with_zero_indegree_ vector. + for (int i = 0; i < num_nodes_; ++i) { + if (indegree_[i] == 0) { + nodes_with_zero_indegree_.push(i); + } + } + num_nodes_left_ = num_nodes_; + traversal_started_ = true; + } + + *cyclic = false; + if (num_nodes_left_ == 0) { + // Done the traversal. + return false; + } + if (nodes_with_zero_indegree_.empty()) { + *cyclic = true; + FindCycle(output_cycle_nodes); + return false; + } + + // Gets the least node. + --num_nodes_left_; + *node_index = nodes_with_zero_indegree_.top(); + nodes_with_zero_indegree_.pop(); + // Swap out the adjacency list, since we won't need it afterwards, + // to decrease memory usage. + std::vector adjacency_list; + adjacency_list.swap(adjacency_lists_[*node_index]); + + // Updates the indegree_ vector and nodes_with_zero_indegree_ queue. 
+ for (int i = 0; i < adjacency_list.size(); ++i) { + if (--indegree_[adjacency_list[i]] == 0) { + nodes_with_zero_indegree_.push(adjacency_list[i]); + } + } + return true; +} + +void TopologicalSorter::FindCycle(std::vector* cycle_nodes) { + cycle_nodes->clear(); + // To find a cycle, we start a DFS from each yet-unvisited node and + // try to find a cycle, if we don't find it then we know for sure that + // no cycle is reachable from any of the explored nodes (so, we don't + // explore them in later DFSs). + std::vector no_cycle_reachable_from(num_nodes_, false); + // The DFS stack will contain a chain of nodes, from the root of the + // DFS to the current leaf. + struct DfsState { + int node; + // Points at the first child node that we did *not* yet look at. + int adjacency_list_index; + explicit DfsState(int _node) : node(_node), adjacency_list_index(0) {} + }; + std::vector dfs_stack; + std::vector in_cur_stack(num_nodes_, false); + + for (int start_node = 0; start_node < num_nodes_; ++start_node) { + if (no_cycle_reachable_from[start_node]) { + continue; + } + // Starts the DFS. + dfs_stack.push_back(DfsState(start_node)); + in_cur_stack[start_node] = true; + while (!dfs_stack.empty()) { + DfsState* cur_state = &dfs_stack.back(); + if (cur_state->adjacency_list_index >= + adjacency_lists_[cur_state->node].size()) { + no_cycle_reachable_from[cur_state->node] = true; + in_cur_stack[cur_state->node] = false; + dfs_stack.pop_back(); + continue; + } + // Looks at the current child, and increases the current state's + // adjacency_list_index. + const int child = + adjacency_lists_[cur_state->node][cur_state->adjacency_list_index]; + ++(cur_state->adjacency_list_index); + if (no_cycle_reachable_from[child]) { + continue; + } + if (in_cur_stack[child]) { + // We detected a cycle! Fills it and return. 
+ for (;;) { + cycle_nodes->push_back(dfs_stack.back().node); + if (dfs_stack.back().node == child) { + std::reverse(cycle_nodes->begin(), cycle_nodes->end()); + return; + } + dfs_stack.pop_back(); + } + } + // Pushs the child onto the stack. + dfs_stack.push_back(DfsState(child)); + in_cur_stack[child] = true; + } + } + // If we're here, then all the DFS stopped, and they never encountered + // a cycle (otherwise, we would have returned). Just exit; the output + // vector has been cleared already. +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/topologicalsorter.h b/mediapipe/framework/deps/topologicalsorter.h new file mode 100644 index 000000000..d5027477c --- /dev/null +++ b/mediapipe/framework/deps/topologicalsorter.h @@ -0,0 +1,83 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_DEPS_TOPOLOGICALSORTER_H_ +#define MEDIAPIPE_DEPS_TOPOLOGICALSORTER_H_ + +#include +#include +#include + +namespace mediapipe { + +// TopologicalSorter provides topologically sorted traversal of the nodes of a +// directed acyclic graph (DAG) with up to INT_MAX nodes. The sorter requires +// that all nodes and edges be added before traversing the nodes, otherwise it +// will die with a fatal error. If a cycle is detected during the traversal, +// the sorter will stop the traversal, and set the cycle_nodes vector. 
+// +// Sample usage: +// TopologicalSorter sorter(num_nodes); +// sorter.AddEdge(ObjToIndex(obj_a), ObjToIndex(obj_b)); +// sorter.AddEdge(ObjToIndex(obj_a), ObjToIndex(obj_c)); +// ... +// sorter.AddEdge(ObjToIndex(obj_b), ObjToIndex(obj_c)); +// int idx; +// bool cyclic = false; +// std::vector cycle_nodes; +// while (sorter.GetNext(&idx, &cyclic, &cycle_nodes)) { +// if (cyclic) { +// PrintCycleNodes(cycle_nodes); +// } else { +// LOG(INFO) << idx; +// } +// } +class TopologicalSorter { + public: + explicit TopologicalSorter(int num_nodes); + TopologicalSorter(const TopologicalSorter&) = delete; + TopologicalSorter& operator=(const TopologicalSorter&) = delete; + + // Adds a directed edge with the given endpoints to the graph. + void AddEdge(int from, int to); + + // Visits the least node in topological order over the current set of + // nodes and edges, and marks that node as visited. + // The repeated calls to GetNext() will visit all nodes in order. Writes the + // newly visited node into *node_index and returns true with *cyclic set to + // false (assuming the graph has not yet been discovered to be cyclic). + // Returns false if all nodes have been visited, or if the graph is + // discovered to be cyclic, in which case *cyclic is also set to true. + bool GetNext(int* node_index, bool* cyclic, + std::vector* output_cycle_nodes); + + private: + // Finds the cycle. + void FindCycle(std::vector* cycle_nodes); + + const int num_nodes_; + // Outoging adjacency lists. + std::vector> adjacency_lists_; + + // If true, no more AddEdge() can be called. 
+ bool traversal_started_ = false; + int num_nodes_left_; + std::priority_queue, std::greater> + nodes_with_zero_indegree_; + std::vector indegree_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_DEPS_TOPOLOGICALSORTER_H_ diff --git a/mediapipe/framework/deps/topologicalsorter_test.cc b/mediapipe/framework/deps/topologicalsorter_test.cc new file mode 100644 index 000000000..8af729161 --- /dev/null +++ b/mediapipe/framework/deps/topologicalsorter_test.cc @@ -0,0 +1,110 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "mediapipe/framework/deps/topologicalsorter.h"
+
+#include "mediapipe/framework/port/gtest.h"
+
+namespace mediapipe {
+
+// With no edges, nodes come out in increasing index order.
+TEST(TopologicalSorterTest, NoConnection) {
+  TopologicalSorter sorter(3);
+  std::vector<int> expected_result({0, 1, 2});
+
+  int visited = 0;
+  int node_index;
+  bool cyclic;
+  std::vector<int> cycle_nodes;
+  while (sorter.GetNext(&node_index, &cyclic, &cycle_nodes)) {
+    EXPECT_EQ(expected_result[visited], node_index);
+    ++visited;
+  }
+  ASSERT_FALSE(cyclic);
+  EXPECT_EQ(3, visited);
+}
+
+TEST(TopologicalSorterTest, SimpleDAG) {
+  TopologicalSorter sorter(5);
+  sorter.AddEdge(4, 0);
+  sorter.AddEdge(4, 1);
+  sorter.AddEdge(4, 2);
+  sorter.AddEdge(0, 3);
+  sorter.AddEdge(1, 3);
+  sorter.AddEdge(3, 2);
+  std::vector<int> expected_result({4, 0, 1, 3, 2});
+
+  int visited = 0;
+  int node_index;
+  bool cyclic;
+  std::vector<int> cycle_nodes;
+  while (sorter.GetNext(&node_index, &cyclic, &cycle_nodes)) {
+    EXPECT_EQ(expected_result[visited], node_index);
+    ++visited;
+  }
+  ASSERT_FALSE(cyclic);
+  EXPECT_EQ(5, visited);
+}
+
+// Duplicate AddEdge() calls are deduplicated before the traversal, so the
+// result matches SimpleDAG.
+TEST(TopologicalSorterTest, DuplicatedEdges) {
+  TopologicalSorter sorter(5);
+  sorter.AddEdge(3, 2);
+  sorter.AddEdge(4, 0);
+  sorter.AddEdge(4, 2);
+  sorter.AddEdge(4, 1);
+  sorter.AddEdge(3, 2);
+  sorter.AddEdge(4, 2);
+  sorter.AddEdge(1, 3);
+  sorter.AddEdge(0, 3);
+  sorter.AddEdge(1, 3);
+  sorter.AddEdge(3, 2);
+  std::vector<int> expected_result({4, 0, 1, 3, 2});
+
+  int visited = 0;
+  int node_index;
+  bool cyclic;
+  std::vector<int> cycle_nodes;
+  while (sorter.GetNext(&node_index, &cyclic, &cycle_nodes)) {
+    EXPECT_EQ(expected_result[visited], node_index);
+    ++visited;
+  }
+  ASSERT_FALSE(cyclic);
+  EXPECT_EQ(5, visited);
+}
+
+TEST(TopologicalSorterTest, Cycle) {
+  // Cycle: 1->3->2->1
+  TopologicalSorter sorter(5);
+  sorter.AddEdge(4, 0);
+  sorter.AddEdge(4, 1);
+  sorter.AddEdge(4, 2);
+  sorter.AddEdge(0, 3);
+  sorter.AddEdge(1, 3);
+  sorter.AddEdge(3, 2);
+  sorter.AddEdge(2, 1);
+
+  int node_index;
+  bool cyclic;
+  std::vector<int>
cycle_nodes; + while (sorter.GetNext(&node_index, &cyclic, &cycle_nodes)) { + } + + EXPECT_TRUE(cyclic); + std::vector expected_cycle({1, 3, 2}); + ASSERT_EQ(3, cycle_nodes.size()); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(expected_cycle[i], cycle_nodes[i]); + } +} + +} // namespace mediapipe diff --git a/mediapipe/framework/deps/vector.h b/mediapipe/framework/deps/vector.h new file mode 100644 index 000000000..24f2480cd --- /dev/null +++ b/mediapipe/framework/deps/vector.h @@ -0,0 +1,560 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Simple classes to handle vectors in 2D, 3D, and 4D. +#ifndef MEDIAPIPE_DEPS_VECTOR_H_ +#define MEDIAPIPE_DEPS_VECTOR_H_ + +#include +#include +#include +#include +#include // NOLINT(readability/streams) +#include +#include + +#include "absl/utility/utility.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" + +template +class Vector2; +template +class Vector3; +template +class Vector4; + +namespace mediapipe { +namespace deps { +namespace internal_vector { + +// CRTP base class for all Vector templates. +template