Project import generated by Copybara.

PiperOrigin-RevId: 253489161
MediaPipe Team 2019-06-16 16:03:25 -07:00 committed by jqtang
commit d68f5e4169
844 changed files with 134997 additions and 0 deletions

34
.bazelrc Normal file

@@ -0,0 +1,34 @@
# The bazelrc file for MediaPipe OSS.
# Basic build settings
build --jobs 128
build --define='absl=1'
build --cxxopt='-std=c++11'
build --copt='-Wno-sign-compare'
build --copt='-Wno-unused-function'
build --copt='-Wno-uninitialized'
build --copt='-Wno-unused-result'
build --copt='-Wno-comment'
build --copt='-Wno-return-type'
build --copt='-Wno-unused-local-typedefs'
build --copt='-Wno-ignored-attributes'
# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
# Android configs.
build:android --crosstool_top=//external:android/crosstool
build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:android --linkopt=-landroid
build:android --linkopt=-ldl
build:android --linkopt=-llog
build:android --linkopt=-lm
build:android --linkopt=-Wl,--gc-sections
build:android_arm --config=android
build:android_arm --cpu=armeabi-v7a
build:android_arm --fat_apk_cpu=armeabi-v7a
build:android_arm64 --config=android
build:android_arm64 --cpu=arm64-v8a
build:android_arm64 --fat_apk_cpu=arm64-v8a
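The `--config=android_arm` and `--config=android_arm64` groups above layer on top of the shared `android` config, so a single flag selects the whole toolchain setup. A minimal sketch of invoking them (the target path is a placeholder, not a real target from this commit):

```shell
# Build for 32-bit ARM Android; pulls in all build:android flags plus armeabi-v7a.
bazel build --config=android_arm //your/package:your_target

# Same, but for 64-bit ARM devices.
bazel build --config=android_arm64 //your/package:your_target
```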

2
.dockerignore Normal file

@@ -0,0 +1,2 @@
.git
Dockerfile

17
BUILD Normal file

@@ -0,0 +1,17 @@
# Copyright 2019 The MediaPipeOSS Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])

127
CONTRIBUTING.md Normal file

@@ -0,0 +1,127 @@
# Contributing guidelines
## Pull Request Checklist
Before sending your pull requests, make sure you have followed this list.
- Read the [contributing guidelines](CONTRIBUTING.md).
- Read the [Code of Conduct](CODE_OF_CONDUCT.md).
- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/).
- Check that your changes are consistent with the [guidelines](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution).
- Ensure your changes are consistent with the [Coding Style](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#c-coding-style).
- Run the [Unit Tests](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md#running-unit-tests).
## How to become a contributor and submit your own code
### Contributor License Agreements
We'd love to accept your patches! Before we can take them, we have to jump a couple of legal hurdles.
Please fill out either the individual or corporate Contributor License Agreement (CLA).
* If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html).
* If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html).
Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository.
### Contributing code
If you have improvements to MediaPipe, send us your pull requests! For those
just getting started, GitHub has a [howto](https://help.github.com/articles/using-pull-requests/).
MediaPipe team members will be assigned to review your pull requests. Once the
pull requests are approved and pass continuous integration checks, a MediaPipe
team member will apply the `ready to pull` label to your change. This means we are
working on getting your pull request submitted to our internal repository. After
the change has been submitted internally, your pull request will be merged
automatically on GitHub.
If you want to contribute but you're not sure where to start, take a look at the
[issues with the "contributions welcome" label](https://github.com/google/mediapipe/labels/stat%3Acontributions%20welcome).
These are issues that we believe are particularly well suited for outside
contributions, often because we probably won't get to them right now. If you
decide to start on an issue, leave a comment so that other people know that
you're working on it. If you want to help but would rather not work alone, use the issue
comment thread to coordinate.
### Contribution guidelines and standards
Before sending your pull request for
[review](https://github.com/google/mediapipe/pulls),
make sure your changes are consistent with the guidelines and follow the
MediaPipe coding style.
#### General guidelines and philosophy for contribution
* Include unit tests when you contribute new features, as they help to a)
prove that your code works correctly, and b) guard against future breaking
changes to lower the maintenance cost.
* Bug fixes also generally require unit tests, because the presence of bugs
usually indicates insufficient test coverage.
* Keep API compatibility in mind when you change code in the MediaPipe framework,
e.g., code in
[mediapipe/framework](https://github.com/google/mediapipe/tree/master/mediapipe/framework)
and
[mediapipe/calculators](https://github.com/google/mediapipe/tree/master/mediapipe/calculators).
Once MediaPipe has reached version 1, we will not make
non-backward-compatible API changes without a major release. Reviewers of
your pull request will comment on any API compatibility issues.
* When you contribute a new feature to MediaPipe, the maintenance burden is
(by default) transferred to the MediaPipe team. This means that the benefit of
the contribution must be weighed against the cost of maintaining the
feature.
* Full new features (e.g., a new op implementing a cutting-edge algorithm)
typically will live in
[mediapipe/addons](https://github.com/google/mediapipe/addons) to get some
airtime before a decision is made regarding whether they should be migrated to
the core.
#### License
Include a license at the top of new files.
* [C/C++ license example](https://github.com/google/mediapipe/blob/master/mediapipe/framework/calculator_base.cc#L1)
* [Java license example](https://github.com/google/mediapipe/blob/master/mediapipe/java/com/google/mediapipe/components/CameraHelper.java)
Bazel BUILD files also need to include a license section, e.g.,
[BUILD example](https://github.com/google/mediapipe/blob/master/mediapipe/framework/BUILD#L61).
#### C++ coding style
Changes to MediaPipe C++ code should conform to
[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
Use `clang-format` to check the formatting of your C/C++ changes. To install `clang-format` on Ubuntu 16.04:
```bash
apt-get install -y clang-format
```
You can check a C/C++ file by doing:
```bash
clang-format <my_cc_file> --style=google > /tmp/my_cc_file.cc
diff <my_cc_file> /tmp/my_cc_file.cc
```
#### Coding style for other languages
* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html)
* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html)
* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml)
* [Google Objective-C Style Guide](https://google.github.io/styleguide/objcguide.html)
#### Running sanity check
If you have Docker installed on your system, you can perform a sanity check on
your changes by running the command:
```bash
mediapipe/tools/ci_build/ci_build.sh CPU mediapipe/tools/ci_build/ci_sanity.sh
```
This will catch most license, Python coding style and BUILD file issues that
may exist in your changes.

52
Dockerfile Normal file

@@ -0,0 +1,52 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM ubuntu:latest
MAINTAINER <mediapipe@google.com>
WORKDIR /io
WORKDIR /mediapipe
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
git \
wget \
unzip \
python \
libopencv-core-dev \
libopencv-highgui-dev \
libopencv-imgproc-dev \
libopencv-video-dev \
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Install bazel
ARG BAZEL_VERSION=0.26.1
RUN mkdir /bazel && \
wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\
azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
wget --no-check-certificate -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
chmod +x /bazel/installer.sh && \
/bazel/installer.sh && \
rm -f /bazel/installer.sh
COPY . /mediapipe/
# If we want the docker image to contain the pre-built object_detection_offline_demo binary, do the following
# RUN bazel build -c opt --define 'MEDIAPIPE_DISABLE_GPU=1' mediapipe/examples/desktop/demo:object_detection_tensorflow_demo
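To use this Dockerfile, build the image from the repository root (the directory copied by `COPY . /mediapipe/`) and start a container. The image tag below is an arbitrary choice, not one mandated by the repository:

```shell
# Build the image; run from the repository root where this Dockerfile lives.
docker build --tag=mediapipe .

# Open an interactive shell inside a throwaway container for building examples.
docker run -it --rm mediapipe bash
```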

201
LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

29
README.md Normal file

@@ -0,0 +1,29 @@
![MediaPipe](mediapipe/docs/images/mediapipe_small.png?raw=true "MediaPipe logo")
=======================================================================
#### We will be [presenting at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17-20 in Long Beach, CA. Come join us!
[MediaPipe](http://g.co/mediapipe) is a framework for building multimodal (e.g., video, audio, or any time series data) applied ML pipelines. With MediaPipe, a perception pipeline can be built as a graph of modular components, including, for instance, inference models (e.g., TensorFlow, TFLite) and media processing functions.
![Real-time Face Detection](mediapipe/docs/images/mobile/face_detection_android_gpu_small.gif)
## Installation
Follow these [instructions](mediapipe/docs/install.md).
## Getting started
See mobile and desktop [examples](mediapipe/docs/examples.md).
## Documentation
On [MediaPipe Read-the-Docs](https://mediapipe.readthedocs.io/).
## Visualizing MediaPipe graphs
A web-based visualizer is hosted on [MediaPipe Visualizer](https://mediapipe-viz.appspot.com/). Please also see instructions [here](mediapipe/docs/visualizer.md).
## Publications
* [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/) on [arXiv](https://arxiv.org/).
* [MediaPipe: A Framework for Perceiving and Augmenting Reality](http://mixedreality.cs.cornell.edu/s/22_crv2_MediaPipe_CVPR_CV4ARVR_Workshop_2019_v2.pdf), extended abstract for [Third Workshop on Computer Vision for AR/VR](http://mixedreality.cs.cornell.edu/workshop/program).
## Contributing
We welcome contributions. Please follow these [guidelines](./CONTRIBUTING.md).
We use GitHub issues for tracking requests and bugs. Please post questions to Stack Overflow with the 'mediapipe' tag.

181
WORKSPACE Normal file

@@ -0,0 +1,181 @@
workspace(name = "mediapipe")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "bazel_skylib",
sha256 = "bbccf674aa441c266df9894182d80de104cabd19be98be002f6d478aaa31574d",
strip_prefix = "bazel-skylib-2169ae1c374aab4a09aa90e65efe1a3aad4e279b",
urls = ["https://github.com/bazelbuild/bazel-skylib/archive/2169ae1c374aab4a09aa90e65efe1a3aad4e279b.tar.gz"],
)
load("@bazel_skylib//lib:versions.bzl", "versions")
versions.check(minimum_bazel_version = "0.23.0")
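Because the WORKSPACE enforces a minimum Bazel version of 0.23.0 via `versions.check`, it can save a confusing mid-fetch failure to verify the installed version up front:

```shell
# Prints the installed Bazel release; it must be >= 0.23.0 for this workspace.
bazel version
```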
# ABSL cpp library.
http_archive(
name = "com_google_absl",
# Head commit on 2019-04-12.
# TODO: Switch to the latest absl version when the problem gets
# fixed.
urls = [
"https://github.com/abseil/abseil-cpp/archive/a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a.tar.gz",
],
sha256 = "d437920d1434c766d22e85773b899c77c672b8b4865d5dc2cd61a29fdff3cf03",
strip_prefix = "abseil-cpp-a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a",
)
# GoogleTest/GoogleMock framework. Used by most unit-tests.
http_archive(
name = "com_google_googletest",
urls = ["https://github.com/google/googletest/archive/master.zip"],
strip_prefix = "googletest-master",
)
# Google Benchmark library.
http_archive(
name = "com_google_benchmark",
urls = ["https://github.com/google/benchmark/archive/master.zip"],
strip_prefix = "benchmark-master",
build_file = "@//third_party:benchmark.BUILD",
)
# gflags needed by glog
http_archive(
name = "com_github_gflags_gflags",
sha256 = "6e16c8bc91b1310a44f3965e616383dbda48f83e8c1eaa2370a215057b00cabe",
strip_prefix = "gflags-77592648e3f3be87d6c7123eb81cbad75f9aef5a",
urls = [
"https://mirror.bazel.build/github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz",
"https://github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz",
],
)
# glog
http_archive(
name = "com_google_glog",
url = "https://github.com/google/glog/archive/v0.3.5.zip",
sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8",
strip_prefix = "glog-0.3.5",
build_file = "@//third_party:glog.BUILD",
)
# libyuv
http_archive(
name = "libyuv",
urls = ["https://chromium.googlesource.com/libyuv/libyuv/+archive/refs/heads/master.tar.gz"],
build_file = "@//third_party:libyuv.BUILD",
)
http_archive(
name = "com_google_protobuf_javalite",
sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc",
strip_prefix = "protobuf-384989534b2246d413dbcd750744faab2607b516",
urls = ["https://github.com/google/protobuf/archive/384989534b2246d413dbcd750744faab2607b516.zip"],
)
# Needed by TensorFlow
http_archive(
name = "io_bazel_rules_closure",
sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
urls = [
"http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04
],
)
# TensorFlow r1.14-rc0
http_archive(
name = "org_tensorflow",
strip_prefix = "tensorflow-1.14.0-rc0",
sha256 = "76404a6157a45e8d7a07e4f5690275256260130145924c2a7c73f6eda2a3de10",
urls = ["https://github.com/tensorflow/tensorflow/archive/v1.14.0-rc0.zip"],
)
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace(tf_repo_name = "org_tensorflow")
# Please run $ sudo apt-get install libopencv-dev
new_local_repository(
name = "linux_opencv",
build_file = "@//third_party:opencv_linux.BUILD",
path = "/usr",
)
# Please run $ brew install opencv
new_local_repository(
name = "macos_opencv",
build_file = "@//third_party:opencv_macos.BUILD",
path = "/usr",
)
http_archive(
name = "android_opencv",
sha256="cd7e5d5ec76eeddadf36a1cfe5197129328e80287d4d198c169e090421f838ba",
build_file = "@//third_party:opencv_android.BUILD",
strip_prefix = "OpenCV-android-sdk",
type = "zip",
url = "https://sourceforge.net/projects/opencvlibrary/files/4.0.1/opencv-4.0.1-android-sdk.zip/download"
)
# Google Maven Repository
GMAVEN_TAG = "20181212-2"
http_archive(
name = "gmaven_rules",
strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG,
url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG,
)
load("@gmaven_rules//:gmaven.bzl", "gmaven_rules")
gmaven_rules()
maven_server(
name = "google_server",
url = "https://maven.google.com",
)
maven_jar(
name = "androidx_lifecycle",
artifact = "androidx.lifecycle:lifecycle-common:2.0.0",
server = "google_server",
)
maven_jar(
name = "androidx_concurrent_futures",
artifact = "androidx.concurrent:concurrent-futures:1.0.0-alpha03",
server = "google_server",
)
maven_jar(
name = "com_google_guava_android",
artifact = "com.google.guava:guava:27.0.1-android",
sha1 = "b7e1c37f66ef193796ccd7ea6e80c2b05426182d",
)
maven_jar(
name = "com_google_common_flogger",
artifact = "com.google.flogger:flogger:0.3.1",
sha1 = "585030fe1ec709760cbef997a459729fb965df0e",
)
maven_jar(
name = "com_google_common_flogger_system_backend",
artifact = "com.google.flogger:flogger-system-backend:0.3.1",
sha1 = "287b569d76abcd82f9de87fe41829fbc7ebd8ac9",
)
maven_jar(
name = "com_google_code_findbugs",
artifact = "com.google.code.findbugs:jsr305:3.0.2",
)
# You may run setup_android.sh to install Android SDK and NDK.
android_ndk_repository(
name = "androidndk",
)
android_sdk_repository(
name = "androidsdk",
)
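Since `android_sdk_repository` and `android_ndk_repository` above are declared without a `path` attribute, Bazel falls back to environment variables to locate the SDK and NDK. One way to point it at an existing installation (the paths below are examples; substitute wherever your SDK/NDK actually live):

```shell
# Example install locations; adjust to your machine.
export ANDROID_HOME=$HOME/Android/Sdk
export ANDROID_NDK_HOME=$HOME/Android/Sdk/ndk-bundle
```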

75
mediapipe/BUILD Normal file

@@ -0,0 +1,75 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
config_setting(
name = "android",
values = {"crosstool_top": "//external:android/crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_x86",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "x86",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_x86_64",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "x86_64",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_armeabi",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "armeabi",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_arm",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "armeabi-v7a",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_arm64",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "arm64-v8a",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "macos",
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)

14
mediapipe/__init__.py Normal file

@@ -0,0 +1,14 @@
"""Copyright 2019 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


@@ -0,0 +1,270 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "packet_resampler_calculator_proto",
srcs = ["packet_resampler_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "packet_resampler_calculator_cc_proto",
srcs = ["packet_resampler_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":packet_resampler_calculator_proto"],
)
cc_library(
name = "counting_source_calculator",
srcs = ["counting_source_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_library(
name = "make_pair_calculator",
srcs = ["make_pair_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "mux_calculator",
srcs = ["mux_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:mux_input_stream_handler",
],
alwayslink = 1,
)
cc_library(
name = "packet_cloner_calculator",
srcs = ["packet_cloner_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_library(
name = "pass_through_calculator",
srcs = ["pass_through_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "round_robin_demux_calculator",
srcs = ["round_robin_demux_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
],
alwayslink = 1,
)
cc_library(
name = "immediate_mux_calculator",
srcs = ["immediate_mux_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "previous_loopback_calculator",
srcs = ["previous_loopback_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
],
alwayslink = 1,
)
cc_library(
name = "real_time_flow_limiter_calculator",
srcs = ["real_time_flow_limiter_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/util:header_util",
],
alwayslink = 1,
)
cc_test(
name = "immediate_mux_calculator_test",
srcs = ["immediate_mux_calculator_test.cc"],
deps = [
":immediate_mux_calculator",
":round_robin_demux_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:test_calculators",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:threadpool",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "packet_resampler_calculator",
srcs = ["packet_resampler_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/tool:options_util",
"@com_google_absl//absl/strings",
"//mediapipe/framework/deps:mathutil",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:integral_types",
] + select({
"//conditions:default": [
"//mediapipe/framework/deps:random",
],
}),
alwayslink = 1,
)
cc_test(
name = "packet_resampler_calculator_test",
timeout = "short",
srcs = ["packet_resampler_calculator_test.cc"],
deps = [
":packet_resampler_calculator",
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "previous_loopback_calculator_test",
srcs = ["previous_loopback_calculator_test.cc"],
deps = [
":previous_loopback_calculator",
"//mediapipe/calculators/core:make_pair_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/time",
],
)
cc_test(
name = "real_time_flow_limiter_calculator_test",
srcs = ["real_time_flow_limiter_calculator_test.cc"],
deps = [
":real_time_flow_limiter_calculator",
"//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:test_calculators",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/time",
],
)


@ -0,0 +1,114 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of
// sequential numbers from INITIAL_VALUE (default 0) with a common
// difference of INCREMENT (default 1) between successive numbers (with
// timestamps corresponding to the sequence numbers). The packets are
// produced in BATCH_SIZE sized batches with each call to Process(). An
// error will be returned after ERROR_COUNT batches. An error will be
// produced in Open() if ERROR_ON_OPEN is true. Either MAX_COUNT or
// ERROR_COUNT must be provided and non-negative. If BATCH_SIZE is not
// provided, then batches are of size 1.
class CountingSourceCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Outputs().Index(0).Set<int>();
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) {
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>();
}
RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") ||
cc->InputSidePackets().HasTag("ERROR_COUNT"));
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
cc->InputSidePackets().Tag("MAX_COUNT").Set<int>();
}
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
cc->InputSidePackets().Tag("ERROR_COUNT").Set<int>();
}
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
cc->InputSidePackets().Tag("BATCH_SIZE").Set<int>();
}
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
cc->InputSidePackets().Tag("INITIAL_VALUE").Set<int>();
}
if (cc->InputSidePackets().HasTag("INCREMENT")) {
cc->InputSidePackets().Tag("INCREMENT").Set<int>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") &&
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) {
return ::mediapipe::NotFoundError("expected error");
}
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get<int>();
RET_CHECK_LE(0, error_count_);
}
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get<int>();
RET_CHECK_LE(0, max_count_);
}
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get<int>();
RET_CHECK_LT(0, batch_size_);
}
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get<int>();
}
if (cc->InputSidePackets().HasTag("INCREMENT")) {
increment_ = cc->InputSidePackets().Tag("INCREMENT").Get<int>();
RET_CHECK_LT(0, increment_);
}
RET_CHECK(error_count_ >= 0 || max_count_ >= 0);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
if (error_count_ >= 0 && batch_counter_ >= error_count_) {
return ::mediapipe::InternalError("expected error");
}
if (max_count_ >= 0 && batch_counter_ >= max_count_) {
return tool::StatusStop();
}
for (int i = 0; i < batch_size_; ++i) {
cc->Outputs().Index(0).Add(new int(counter_), Timestamp(counter_));
counter_ += increment_;
}
++batch_counter_;
return ::mediapipe::OkStatus();
}
private:
int max_count_ = -1;
int error_count_ = -1;
int batch_size_ = 1;
int batch_counter_ = 0;
int counter_ = 0;
int increment_ = 1;
};
REGISTER_CALCULATOR(CountingSourceCalculator);
} // namespace mediapipe


@ -0,0 +1,90 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// This Calculator multiplexes several input streams into a single
// output stream, dropping input packets with timestamps older than the
// last output packet. In case two packets arrive with the same timestamp,
// the packet with the lower stream index will be output and the rest will
// be dropped.
//
// This Calculator optionally produces a finish indicator as its second
// output stream. One indicator packet is produced for each input packet
// received.
//
// This Calculator can be used with an ImmediateInputStreamHandler or with the
// default ISH. Note that currently ImmediateInputStreamHandler seems to
// interfere with timestamp bound propagation, so it is better to use the
// default unless the immediate one is needed. (b/118387598)
//
// This Calculator is designed to work with a Demux calculator such as
// the RoundRobinDemuxCalculator. Therefore, packets from different
// input streams are normally not expected to have the same timestamp.
//
class ImmediateMuxCalculator : public CalculatorBase {
public:
// This calculator combines any set of input streams into a single
// output stream. All input stream types must match the output stream type.
static ::mediapipe::Status GetContract(CalculatorContract* cc);
// Passes any input packet to the output stream immediately, unless the
// packet timestamp is lower than a previously passed packet.
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Open(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(ImmediateMuxCalculator);
::mediapipe::Status ImmediateMuxCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Outputs().NumEntries() >= 1 && cc->Outputs().NumEntries() <= 2)
<< "This calculator produces only one or two output streams.";
cc->Outputs().Index(0).SetAny();
if (cc->Outputs().NumEntries() >= 2) {
cc->Outputs().Index(1).Set<bool>();
}
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).SetSameAs(&cc->Outputs().Index(0));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImmediateMuxCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) {
// Pass along the first packet, unless it has been superseded.
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
const Packet& packet = cc->Inputs().Index(i).Value();
if (!packet.IsEmpty()) {
if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) {
cc->Outputs().Index(0).AddPacket(packet);
}
if (cc->Outputs().NumEntries() >= 2) {
Timestamp output_timestamp = std::max(
cc->InputTimestamp(), cc->Outputs().Index(1).NextTimestampBound());
cc->Outputs().Index(1).Add(new bool(true), output_timestamp);
}
}
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe


@ -0,0 +1,373 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/port/threadpool.h"
#include "mediapipe/framework/tool/sink.h"
// Tests for ImmediateMuxCalculator. These tests show how parallel output
// packets are handled when they arrive in various orders.
using testing::ElementsAre;
namespace mediapipe {
namespace {
// A simple Semaphore for synchronizing test threads.
class AtomicSemaphore {
public:
AtomicSemaphore(int64_t supply) : supply_(supply) {}
void Acquire(int64_t amount) {
while (supply_.fetch_sub(amount) - amount < 0) {
Release(amount);
}
}
void Release(int64_t amount) { supply_.fetch_add(amount); }
private:
std::atomic<int64_t> supply_;
};
// A mediapipe::Executor that signals the start and finish of each task.
// Provides 4 worker threads.
class CountingExecutor : public Executor {
public:
CountingExecutor(std::function<void()> start_callback,
std::function<void()> finish_callback)
: thread_pool_(4),
start_callback_(std::move(start_callback)),
finish_callback_(std::move(finish_callback)) {
thread_pool_.StartWorkers();
}
void Schedule(std::function<void()> task) override {
start_callback_();
thread_pool_.Schedule([this, task] {
task();
finish_callback_();
});
}
private:
ThreadPool thread_pool_;
std::function<void()> start_callback_;
std::function<void()> finish_callback_;
};
// Returns a new mediapipe::Executor with 4 worker threads.
std::shared_ptr<Executor> MakeExecutor(std::function<void()> start_callback,
std::function<void()> finish_callback) {
return std::make_shared<CountingExecutor>(start_callback, finish_callback);
}
// Tests showing ImmediateMuxCalculator dropping packets in various sequences.
class ImmediateMuxCalculatorTest : public ::testing::Test {
protected:
void SetUpMuxGraph() {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
input_stream: "input_packets_1"
node {
calculator: "ImmediateMuxCalculator"
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
input_stream: "input_packets_0"
input_stream: "input_packets_1"
output_stream: "output_packets_0"
}
)",
&graph_config_));
}
void SetUpDemuxGraph() {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
node {
calculator: "RoundRobinDemuxCalculator"
input_stream: "input_packets_0"
output_stream: "OUTPUT:0:input_0"
output_stream: "OUTPUT:1:input_1"
}
node {
calculator: "LambdaCalculator"
input_side_packet: 'callback_0'
input_stream: "input_0"
output_stream: "output_0"
}
node {
calculator: "LambdaCalculator"
input_side_packet: 'callback_1'
input_stream: "input_1"
output_stream: "output_1"
}
node {
calculator: "ImmediateMuxCalculator"
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
input_stream: "output_0"
input_stream: "output_1"
output_stream: "output_packets_0"
}
)",
&graph_config_));
}
void SetUpDemuxInFlightGraph() {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
node {
calculator: 'RealTimeFlowLimiterCalculator'
input_stream_handler {
input_stream_handler: 'ImmediateInputStreamHandler'
}
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
input_stream: 'input_packets_0'
input_stream: 'FINISHED:finish_indicator'
input_stream_info: {
tag_index: 'FINISHED'
back_edge: true
}
output_stream: 'input_0_sampled'
}
node {
calculator: "RoundRobinDemuxCalculator"
input_stream: "input_0_sampled"
output_stream: "OUTPUT:0:input_0"
output_stream: "OUTPUT:1:input_1"
}
node {
calculator: "LambdaCalculator"
input_side_packet: 'callback_0'
input_stream: "input_0"
output_stream: "output_0"
}
node {
calculator: "LambdaCalculator"
input_side_packet: 'callback_1'
input_stream: "input_1"
output_stream: "output_1"
}
node {
calculator: "ImmediateMuxCalculator"
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
input_stream: "output_0"
input_stream: "output_1"
output_stream: 'output_packets_0'
output_stream: 'finish_indicator'
}
)",
&graph_config_));
}
static Packet PacketAt(int64 ts) {
return Adopt(new int64(999)).At(Timestamp(ts));
}
static Packet None() { return Packet().At(Timestamp::OneOverPostStream()); }
static bool IsNone(const Packet& packet) {
return packet.Timestamp() == Timestamp::OneOverPostStream();
}
// Return the values of the timestamps of a vector of Packets.
static std::vector<int64> TimestampValues(
const std::vector<Packet>& packets) {
std::vector<int64> result;
for (const Packet& p : packets) {
result.push_back(p.Timestamp().Value());
}
return result;
}
// Runs a CalculatorGraph with a series of packet sets.
// Returns a vector of packets from each graph output stream.
void RunGraph(const std::vector<std::vector<Packet>>& input_sets,
std::vector<Packet>* output_packets) {
// Register output packet observers.
tool::AddVectorSink("output_packets_0", &graph_config_, output_packets);
// Start running the graph.
CalculatorGraph graph;
MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config_));
MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
// Send each packet to the graph in the specified order.
for (int t = 0; t < input_sets.size(); t++) {
const std::vector<Packet>& input_set = input_sets[t];
MEDIAPIPE_EXPECT_OK(graph.WaitUntilIdle());
for (int i = 0; i < input_set.size(); i++) {
const Packet& packet = input_set[i];
if (!IsNone(packet)) {
MEDIAPIPE_EXPECT_OK(graph.AddPacketToInputStream(
absl::StrCat("input_packets_", i), packet));
}
}
}
MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams());
MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
}
CalculatorGraphConfig graph_config_;
};
TEST_F(ImmediateMuxCalculatorTest, IncreasingTimestamps) {
// Run the graph with a series of packet sets.
std::vector<std::vector<Packet>> input_sets = {
{PacketAt(10000), None()}, //
{PacketAt(20000), None()}, //
{None(), PacketAt(30000)}, //
{None(), PacketAt(40000)},
};
SetUpMuxGraph();
std::vector<Packet> output_packets;
RunGraph(input_sets, &output_packets);
// Validate the output packets.
EXPECT_THAT(TimestampValues(output_packets),
ElementsAre(10000, 20000, 30000, 40000));
}
TEST_F(ImmediateMuxCalculatorTest, SupersededTimestamp) {
// Run the graph with a series of packet sets.
std::vector<std::vector<Packet>> input_sets = {
{PacketAt(10000), None()}, //
{PacketAt(30000), None()}, //
{None(), PacketAt(20000)}, //
{None(), PacketAt(40000)},
};
SetUpMuxGraph();
std::vector<Packet> output_packets;
RunGraph(input_sets, &output_packets);
// Output packet 20000 is superseded and dropped.
EXPECT_THAT(TimestampValues(output_packets),
ElementsAre(10000, 30000, 40000));
}
TEST_F(ImmediateMuxCalculatorTest, SimultaneousTimestamps) {
// Run the graph with a series of packet sets.
std::vector<std::vector<Packet>> input_sets = {
{PacketAt(10000), None()}, //
{PacketAt(40000), PacketAt(20000)}, //
{None(), PacketAt(30000)},
};
SetUpMuxGraph();
std::vector<Packet> output_packets;
RunGraph(input_sets, &output_packets);
// Output packet 20000 is superseded and dropped.
EXPECT_THAT(TimestampValues(output_packets), ElementsAre(10000, 40000));
}
// A Calculator::Process callback function.
typedef std::function<::mediapipe::Status(const InputStreamShardSet&,
OutputStreamShardSet*)>
ProcessFunction;
// A testing callback function that passes through all packets.
::mediapipe::Status PassThrough(const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
for (int i = 0; i < inputs.NumEntries(); ++i) {
if (!inputs.Index(i).Value().IsEmpty()) {
outputs->Index(i).AddPacket(inputs.Index(i).Value());
}
}
return ::mediapipe::OkStatus();
}
TEST_F(ImmediateMuxCalculatorTest, Demux) {
// Semaphores to sequence the parallel Process outputs.
AtomicSemaphore semaphore_0(0);
AtomicSemaphore semaphore_1(0);
ProcessFunction wait_0 = [&semaphore_0](const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
semaphore_0.Acquire(1);
return PassThrough(inputs, outputs);
};
ProcessFunction wait_1 = [&semaphore_1](const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
semaphore_1.Acquire(1);
return PassThrough(inputs, outputs);
};
// A callback to await and capture output packets.
std::vector<Packet> out_packets;
absl::Mutex out_mutex;
auto out_cb = [&](const Packet& p) {
absl::MutexLock lock(&out_mutex);
out_packets.push_back(p);
return ::mediapipe::OkStatus();
};
auto wait_for = [&](std::function<bool()> cond) {
absl::MutexLock lock(&out_mutex);
out_mutex.Await(absl::Condition(&cond));
};
SetUpDemuxGraph();
// Start the graph and add five input packets.
CalculatorGraph graph;
MEDIAPIPE_ASSERT_OK(graph.Initialize(
graph_config_, {
{"callback_0", Adopt(new auto(wait_0))},
{"callback_1", Adopt(new auto(wait_1))},
}));
MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream("output_packets_0", out_cb));
MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
MEDIAPIPE_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(10000)));
MEDIAPIPE_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(20000)));
MEDIAPIPE_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(30000)));
MEDIAPIPE_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(40000)));
MEDIAPIPE_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(50000)));
// Release the outputs in order 20000, 10000, 30000, 50000, 40000.
semaphore_1.Release(1); // 20000
wait_for([&] { return !out_packets.empty(); });
semaphore_0.Release(1); // 10000
semaphore_0.Release(1); // 30000
wait_for([&] { return out_packets.size() >= 2; });
semaphore_0.Release(1); // 50000
wait_for([&] { return out_packets.size() >= 3; });
semaphore_1.Release(1); // 40000
MEDIAPIPE_ASSERT_OK(graph.CloseAllInputStreams());
MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
// Output packets 10000 and 40000 are superseded and dropped.
EXPECT_THAT(TimestampValues(out_packets), ElementsAre(20000, 30000, 50000));
}
} // namespace
} // namespace mediapipe


@ -0,0 +1,61 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Given two input streams (A, B), output a single stream containing a pair<A,
// B>.
//
// Example config:
// node {
// calculator: "MakePairCalculator"
// input_stream: "packet_a"
// input_stream: "packet_b"
// output_stream: "output_pair_a_b"
// }
class MakePairCalculator : public CalculatorBase {
public:
MakePairCalculator() {}
~MakePairCalculator() override {}
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Inputs().Index(1).SetAny();
cc->Outputs().Index(0).Set<std::pair<Packet, Packet>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
cc->Outputs().Index(0).Add(
new std::pair<Packet, Packet>(cc->Inputs().Index(0).Value(),
cc->Inputs().Index(1).Value()),
cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(MakePairCalculator);
} // namespace mediapipe


@ -0,0 +1,77 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
// A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ...,
// using the integer value (0, 1, ...) in the packet on the "SELECT" input
// stream, and passes the packet on the selected input stream to the "OUTPUT"
// output stream.
//
// Note that this calculator uses MuxInputStreamHandler by default, which is
// required for correct operation of this calculator.
class MuxCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("SELECT").Set<int>();
CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT");
PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
data_input0->SetAny();
++data_input_id;
for (; data_input_id < cc->Inputs().EndId("INPUT"); ++data_input_id) {
cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
}
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0);
// Assign this calculator's default InputStreamHandler.
cc->SetInputStreamHandler("MuxInputStreamHandler");
MediaPipeOptions options;
cc->SetInputStreamHandlerOptions(options);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
select_input_ = cc->Inputs().GetId("SELECT", 0);
data_input_base_ = cc->Inputs().GetId("INPUT", 0);
num_data_inputs_ = cc->Inputs().NumEntries("INPUT");
output_ = cc->Outputs().GetId("OUTPUT", 0);
cc->SetOffset(mediapipe::TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
int select = cc->Inputs().Get(select_input_).Get<int>();
RET_CHECK(0 <= select && select < num_data_inputs_);
if (!cc->Inputs().Get(data_input_base_ + select).IsEmpty()) {
cc->Outputs().Get(output_).AddPacket(
cc->Inputs().Get(data_input_base_ + select).Value());
}
return ::mediapipe::OkStatus();
}
private:
CollectionItemId select_input_;
CollectionItemId data_input_base_;
int num_data_inputs_ = 0;
CollectionItemId output_;
};
REGISTER_CALCULATOR(MuxCalculator);
} // namespace mediapipe


@ -0,0 +1,98 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This takes packets from N+1 streams, A_1, A_2, ..., A_N, B.
// For every packet that appears in B, outputs the most recent packet from each
// of the A_i on a separate stream.
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
namespace mediapipe {
// For every packet received on the last stream, output the latest packet
// obtained on all other streams. Therefore, if packets arrive on the last
// stream at a higher rate than on the others, this effectively clones the
// packets from the other streams to match the last.
//
// Example config:
// node {
// calculator: "PacketClonerCalculator"
// input_stream: "first_base_signal"
// input_stream: "second_base_signal"
// input_stream: "tick_signal"
// output_stream: "cloned_first_base_signal"
// output_stream: "cloned_second_base_signal"
// }
//
// Related:
// merge_input_streams_calculator.cc: One output stream.
// packet_inner_join_calculator.cc: Don't output unless all inputs are new.
class PacketClonerCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
const int tick_signal_index = cc->Inputs().NumEntries() - 1;
for (int i = 0; i < tick_signal_index; ++i) {
cc->Inputs().Index(i).SetAny();
cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
}
cc->Inputs().Index(tick_signal_index).SetAny();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
tick_signal_index_ = cc->Inputs().NumEntries() - 1;
current_.resize(tick_signal_index_);
// Pass along the header for each stream if present.
for (int i = 0; i < tick_signal_index_; ++i) {
if (!cc->Inputs().Index(i).Header().IsEmpty()) {
cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header());
}
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
// Store input signals.
for (int i = 0; i < tick_signal_index_; ++i) {
if (!cc->Inputs().Index(i).Value().IsEmpty()) {
current_[i] = cc->Inputs().Index(i).Value();
}
}
// Output if the tick signal is non-empty.
if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) {
for (int i = 0; i < tick_signal_index_; ++i) {
if (!current_[i].IsEmpty()) {
cc->Outputs().Index(i).AddPacket(
current_[i].At(cc->InputTimestamp()));
} else {
cc->Outputs().Index(i).SetNextTimestampBound(
cc->InputTimestamp().NextAllowedInStream());
}
}
}
return ::mediapipe::OkStatus();
}
private:
std::vector<Packet> current_;
int tick_signal_index_;
};
REGISTER_CALCULATOR(PacketClonerCalculator);
} // namespace mediapipe


@ -0,0 +1,434 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdlib>
#include <memory>
#include <string>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/deps/mathutil.h"
#include "mediapipe/framework/deps/random_base.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/options_util.h"
namespace {
// Creates a secure random number generator for use in ProcessWithJitter.
// If no secure random number generator can be constructed, the jitter
// option is disabled in order to maintain consistent security and
// consistent random seeding.
std::unique_ptr<RandomBase> CreateSecureRandom(const std::string& seed) {
RandomBase* result = nullptr;
return std::unique_ptr<RandomBase>(result);
}
} // namespace
namespace mediapipe {
// This calculator is used to normalize the frequency of the packets
// out of a stream. Given a desired frame rate, packets are dropped or
// duplicated as needed to achieve it.
//
// The jitter feature is disabled by default. To enable it, you need to
// implement CreateSecureRandom(const std::string&).
//
// The data stream may be either specified as the only stream (by index)
// or as the stream with tag "DATA".
//
// The input and output streams may be accompanied by a VIDEO_HEADER
// stream. This stream includes a VideoHeader at Timestamp::PreStream().
// The VideoHeader received on the input VIDEO_HEADER stream is always
// updated with the resampler frame rate, regardless of the output_header
// option, before being emitted on the output VIDEO_HEADER stream.
// If the input VideoHeader is not available, then only the frame rate
// value will be set in the output.
//
// Related:
// packet_downsampler_calculator.cc: skips packets regardless of timestamps.
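//
// Example config (an illustrative sketch based on the usage exercised in the
// unit tests; the stream names are assumptions, not taken from this file):
//
// node {
//   calculator: "PacketResamplerCalculator"
//   input_stream: "DATA:in_data"
//   output_stream: "DATA:out_data"
//   options {
//     [mediapipe.PacketResamplerCalculatorOptions.ext] { frame_rate: 30 }
//   }
// }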
class PacketResamplerCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
// Logic for Process() when jitter_ != 0.0.
::mediapipe::Status ProcessWithJitter(CalculatorContext* cc);
// Logic for Process() when jitter_ == 0.0.
::mediapipe::Status ProcessWithoutJitter(CalculatorContext* cc);
// Given the current count of periods that have passed, this returns
// the next valid timestamp of the middle point of the next period:
// if count is 0, it returns the first_timestamp_.
// if count is 1, it returns the first_timestamp_ + period (corresponding
// to the first tick using exact fps)
// e.g. for frame_rate=30 and first_timestamp_=0:
// 0: 0
// 1: 33333
// 2: 66667
// 3: 100000
//
// Can only be used if jitter_ equals zero.
Timestamp PeriodIndexToTimestamp(int64 index) const;
// Given a Timestamp, finds the closest sync Timestamp based on
// first_timestamp_ and the desired fps.
//
// Can only be used if jitter_ equals zero.
int64 TimestampToPeriodIndex(Timestamp timestamp) const;
// Outputs a packet if it is in range (start_time_, end_time_).
void OutputWithinLimits(CalculatorContext* cc, const Packet& packet) const;
// The timestamp of the first packet received.
Timestamp first_timestamp_;
// Number of frames per second (desired output frequency).
double frame_rate_;
// Inverse of frame_rate_.
int64 frame_time_usec_;
// Number of periods that have passed (= #packets sent to the output).
//
// Can only be used if jitter_ equals zero.
int64 period_count_;
// The last packet that was received.
Packet last_packet_;
VideoHeader video_header_;
// The "DATA" input stream.
CollectionItemId input_data_id_;
// The "DATA" output stream.
CollectionItemId output_data_id_;
// Indicates whether to flush the last packet even if its timestamp is
// greater than the final stream timestamp. Set to false when jitter_ is
// non-zero.
bool flush_last_packet_;
// Jitter-related variables.
std::unique_ptr<RandomBase> random_;
double jitter_ = 0.0;
Timestamp next_output_timestamp_;
// If specified, output timestamps are aligned with base_timestamp.
// Otherwise, they are aligned with the first input timestamp.
Timestamp base_timestamp_;
// If specified, only outputs at/after start_time are included.
Timestamp start_time_;
// If specified, only outputs before end_time are included.
Timestamp end_time_;
// If set, the output timestamps nearest to start_time and end_time
// are included in the output, even if the nearest timestamp is not
// between start_time and end_time.
bool round_limits_;
};
REGISTER_CALCULATOR(PacketResamplerCalculator);
namespace {
// Returns a TimestampDiff (assuming microseconds) corresponding to the
// given time in seconds.
TimestampDiff TimestampDiffFromSeconds(double seconds) {
return TimestampDiff(MathUtil::SafeRound<int64, double>(
seconds * Timestamp::kTimestampUnitsPerSecond));
}
} // namespace
::mediapipe::Status PacketResamplerCalculator::GetContract(
CalculatorContract* cc) {
const auto& resampler_options =
cc->Options<PacketResamplerCalculatorOptions>();
if (cc->InputSidePackets().HasTag("OPTIONS")) {
cc->InputSidePackets().Tag("OPTIONS").Set<CalculatorOptions>();
}
CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0);
if (!input_data_id.IsValid()) {
input_data_id = cc->Inputs().GetId("", 0);
}
cc->Inputs().Get(input_data_id).SetAny();
if (cc->Inputs().HasTag("VIDEO_HEADER")) {
cc->Inputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
}
CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0);
if (!output_data_id.IsValid()) {
output_data_id = cc->Outputs().GetId("", 0);
}
cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id));
if (cc->Outputs().HasTag("VIDEO_HEADER")) {
cc->Outputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
}
if (resampler_options.jitter() != 0.0) {
RET_CHECK_GT(resampler_options.jitter(), 0.0);
RET_CHECK_LE(resampler_options.jitter(), 1.0);
RET_CHECK(cc->InputSidePackets().HasTag("SEED"));
cc->InputSidePackets().Tag("SEED").Set<std::string>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
const auto resampler_options =
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
cc->InputSidePackets(), "OPTIONS");
flush_last_packet_ = resampler_options.flush_last_packet();
jitter_ = resampler_options.jitter();
input_data_id_ = cc->Inputs().GetId("DATA", 0);
if (!input_data_id_.IsValid()) {
input_data_id_ = cc->Inputs().GetId("", 0);
}
output_data_id_ = cc->Outputs().GetId("DATA", 0);
if (!output_data_id_.IsValid()) {
output_data_id_ = cc->Outputs().GetId("", 0);
}
period_count_ = 0;
frame_rate_ = resampler_options.frame_rate();
base_timestamp_ = resampler_options.has_base_timestamp()
? Timestamp(resampler_options.base_timestamp())
: Timestamp::Unset();
start_time_ = resampler_options.has_start_time()
? Timestamp(resampler_options.start_time())
: Timestamp::Min();
end_time_ = resampler_options.has_end_time()
? Timestamp(resampler_options.end_time())
: Timestamp::Max();
round_limits_ = resampler_options.round_limits();
// The frame_rate has a default value of -1.0, so the user must set it!
RET_CHECK_LT(0, frame_rate_)
<< "The output frame rate must be greater than zero";
RET_CHECK_LE(frame_rate_, Timestamp::kTimestampUnitsPerSecond)
<< "The output frame rate must be smaller than "
<< Timestamp::kTimestampUnitsPerSecond;
frame_time_usec_ = static_cast<int64>(1000000.0 / frame_rate_);
video_header_.frame_rate = frame_rate_;
if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE &&
!cc->Inputs().Get(input_data_id_).Header().IsEmpty()) {
if (resampler_options.output_header() ==
PacketResamplerCalculatorOptions::UPDATE_VIDEO_HEADER) {
video_header_ =
cc->Inputs().Get(input_data_id_).Header().Get<VideoHeader>();
video_header_.frame_rate = frame_rate_;
cc->Outputs()
.Get(output_data_id_)
.SetHeader(Adopt(new VideoHeader(video_header_)));
} else {
cc->Outputs()
.Get(output_data_id_)
.SetHeader(cc->Inputs().Get(input_data_id_).Header());
}
}
if (jitter_ != 0.0) {
if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE) {
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
"the actual value.";
}
if (flush_last_packet_) {
flush_last_packet_ = false;
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
"ignored, because we are adding jitter.";
}
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
random_ = CreateSecureRandom(seed);
if (random_ == nullptr) {
return ::mediapipe::Status(
::mediapipe::StatusCode::kInvalidArgument,
"SecureRandom is not available. With \"jitter\" specified, "
"PacketResamplerCalculator processing cannot proceed.");
}
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
if (cc->InputTimestamp() == Timestamp::PreStream() &&
cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") &&
!cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) {
video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get<VideoHeader>();
video_header_.frame_rate = frame_rate_;
if (cc->Inputs().Get(input_data_id_).IsEmpty()) {
return ::mediapipe::OkStatus();
}
}
if (jitter_ != 0.0 && random_ != nullptr) {
RETURN_IF_ERROR(ProcessWithJitter(cc));
} else {
RETURN_IF_ERROR(ProcessWithoutJitter(cc));
}
last_packet_ = cc->Inputs().Get(input_data_id_).Value();
return ::mediapipe::OkStatus();
}
::mediapipe::Status PacketResamplerCalculator::ProcessWithJitter(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
RET_CHECK_NE(jitter_, 0.0);
if (first_timestamp_ == Timestamp::Unset()) {
first_timestamp_ = cc->InputTimestamp();
next_output_timestamp_ =
first_timestamp_ + frame_time_usec_ * random_->RandFloat();
return ::mediapipe::OkStatus();
}
LOG_IF(WARNING, frame_time_usec_ <
(cc->InputTimestamp() - last_packet_.Timestamp()).Value())
<< "Adding jitter is meaningless when upsampling.";
const int64 curr_diff =
(next_output_timestamp_ - cc->InputTimestamp()).Value();
const int64 last_diff =
(next_output_timestamp_ - last_packet_.Timestamp()).Value();
if (curr_diff * last_diff > 0) {
return ::mediapipe::OkStatus();
}
OutputWithinLimits(cc, (std::abs(curr_diff) > std::abs(last_diff)
? last_packet_
: cc->Inputs().Get(input_data_id_).Value())
.At(next_output_timestamp_));
next_output_timestamp_ +=
frame_time_usec_ *
((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat());
return ::mediapipe::OkStatus();
}
::mediapipe::Status PacketResamplerCalculator::ProcessWithoutJitter(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
RET_CHECK_EQ(jitter_, 0.0);
if (first_timestamp_ == Timestamp::Unset()) {
// This is the first packet, initialize the first_timestamp_.
if (base_timestamp_ == Timestamp::Unset()) {
// Initialize first_timestamp_ with exactly the first packet timestamp.
first_timestamp_ = cc->InputTimestamp();
} else {
// Initialize first_timestamp_ with the first packet timestamp
// aligned to the base_timestamp_.
int64 first_index = MathUtil::SafeRound<int64, double>(
(cc->InputTimestamp() - base_timestamp_).Seconds() * frame_rate_);
first_timestamp_ =
base_timestamp_ + TimestampDiffFromSeconds(first_index / frame_rate_);
}
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
cc->Outputs()
.Tag("VIDEO_HEADER")
.Add(new VideoHeader(video_header_), Timestamp::PreStream());
}
}
const Timestamp received_timestamp = cc->InputTimestamp();
const int64 received_timestamp_idx =
TimestampToPeriodIndex(received_timestamp);
// Only consider the received packet if it belongs to the current period
// (== period_count_) or to a newer one (> period_count_).
if (received_timestamp_idx >= period_count_) {
// Fill the empty periods until we are in the same index as the received
// packet.
while (received_timestamp_idx > period_count_) {
OutputWithinLimits(
cc, last_packet_.At(PeriodIndexToTimestamp(period_count_)));
++period_count_;
}
// Now, if the received packet has a timestamp larger than the middle of
// the current period, we can send a packet without waiting. We send the
// one closer to the middle.
Timestamp target_timestamp = PeriodIndexToTimestamp(period_count_);
if (received_timestamp >= target_timestamp) {
bool have_last_packet = (last_packet_.Timestamp() != Timestamp::Unset());
bool send_current =
!have_last_packet || (received_timestamp - target_timestamp <=
target_timestamp - last_packet_.Timestamp());
if (send_current) {
OutputWithinLimits(
cc, cc->Inputs().Get(input_data_id_).Value().At(target_timestamp));
} else {
OutputWithinLimits(cc, last_packet_.At(target_timestamp));
}
++period_count_;
}
// TODO: Add a mechanism to the framework to allow these packets
// to be output earlier (without waiting for a much later packet to
// arrive).
// Update the bound for the next packet.
cc->Outputs()
.Get(output_data_id_)
.SetNextTimestampBound(PeriodIndexToTimestamp(period_count_));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status PacketResamplerCalculator::Close(CalculatorContext* cc) {
if (!cc->GraphStatus().ok()) {
return ::mediapipe::OkStatus();
}
// Emit the last packet received if we have at least one packet, but
// haven't sent anything for its period.
if (first_timestamp_ != Timestamp::Unset() && flush_last_packet_ &&
TimestampToPeriodIndex(last_packet_.Timestamp()) == period_count_) {
OutputWithinLimits(cc,
last_packet_.At(PeriodIndexToTimestamp(period_count_)));
}
return ::mediapipe::OkStatus();
}
Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const {
CHECK_EQ(jitter_, 0.0);
CHECK_NE(first_timestamp_, Timestamp::Unset());
return first_timestamp_ + TimestampDiffFromSeconds(index / frame_rate_);
}
int64 PacketResamplerCalculator::TimestampToPeriodIndex(
Timestamp timestamp) const {
CHECK_EQ(jitter_, 0.0);
CHECK_NE(first_timestamp_, Timestamp::Unset());
return MathUtil::SafeRound<int64, double>(
(timestamp - first_timestamp_).Seconds() * frame_rate_);
}
void PacketResamplerCalculator::OutputWithinLimits(CalculatorContext* cc,
const Packet& packet) const {
TimestampDiff margin((round_limits_) ? frame_time_usec_ / 2 : 0);
if (packet.Timestamp() >= start_time_ - margin &&
packet.Timestamp() < end_time_ + margin) {
cc->Outputs().Get(output_data_id_).AddPacket(packet);
}
}
} // namespace mediapipe

// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message PacketResamplerCalculatorOptions {
extend CalculatorOptions {
optional PacketResamplerCalculatorOptions ext = 95743844;
}
// The output frame rate measured in frames per second.
//
// The closest packet in time in each period will be chosen. If there
// is no packet in the period then the most recent packet will be chosen
// (not the closest in time).
optional double frame_rate = 1 [default = -1.0];
enum OutputHeader {
// Do not output a header, even if the input contained one.
NONE = 0;
// Pass the header, if the input contained one.
PASS_HEADER = 1;
// Update the frame rate in the header, which must be of type VideoHeader.
UPDATE_VIDEO_HEADER = 2;
}
// Whether and what kind of header to place on the output stream.
// Note, this is about the actual header, not the VIDEO_HEADER stream.
// If this option is set to UPDATE_VIDEO_HEADER then the header will
// also be parsed (updated) and passed along to the VIDEO_HEADER stream.
optional OutputHeader output_header = 2 [default = NONE];
// Flush last packet even if its timestamp is greater than the final stream
// timestamp.
optional bool flush_last_packet = 3 [default = true];
// Adds jitter to resampling if set, so that Google's sampling is not
// externally deterministic.
//
// When set, the randomizer will be initialized with a seed. Then, the first
// sample is chosen randomly (uniform distribution) among frames that
// correspond to timestamps [0, 1/frame_rate). Let the chosen frame
// correspond to timestamp t. The next frame is chosen randomly (uniform
// distribution) among frames that correspond to [t+(1-jitter)/frame_rate,
// t+(1+jitter)/frame_rate]. t is updated and the process is repeated.
//
// Valid values are in the range of [0.0, 1.0] with the default being 0.0 (no
// jitter). A typical value would be a value in the range of 0.1-0.25.
//
// Note that this does NOT guarantee the desired frame rate, but if the
// pseudo-random number generator does its job and the number of frames is
// sufficiently large, the average frame rate will be close to this value.
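//
// As a worked example (an illustrative computation, not taken from this
// file): with frame_rate = 30 (period of about 33333 microseconds) and
// jitter = 0.25, each output timestamp after a previous output at t is drawn
// uniformly from roughly [t + 25000, t + 41667] microseconds, i.e.
// [(1 - 0.25)/30, (1 + 0.25)/30] seconds after t.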
optional double jitter = 4;
// If specified, output timestamps are aligned with base_timestamp.
// Otherwise, they are aligned with the first input timestamp.
//
// In order to ensure that the output timestamps are reproducible,
// with round_limits = false, the bounds for input timestamps must include:
// [start_time - period / 2, end_time + period / 2],
// with round_limits = true, the bounds for input timestamps must include:
// [start_time - period, end_time + period],
// where period = 1 / frame_rate.
//
// For example, in PacketResamplerCalculatorOptions specify
// "start_time: 3000000", and in MediaDecoderOptions specify
// "start_time: 2999950".
optional int64 base_timestamp = 5;
// If specified, only outputs at/after start_time are included.
optional int64 start_time = 6;
// If specified, only outputs before end_time are included.
optional int64 end_time = 7;
// If set, the output timestamps nearest to start_time and end_time
// are included in the output, even if the nearest timestamp is not
// between start_time and end_time.
optional bool round_limits = 8 [default = false];
}

// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// A simple version of CalculatorRunner with built-in convenience
// methods for setting inputs from a vector and checking outputs
// against expected outputs (both timestamps and contents).
class SimpleRunner : public CalculatorRunner {
public:
explicit SimpleRunner(const std::string& options_string)
: CalculatorRunner("PacketResamplerCalculator", options_string, 1, 1, 0) {
}
explicit SimpleRunner(const CalculatorGraphConfig::Node& node_config)
: CalculatorRunner(node_config) {}
virtual ~SimpleRunner() {}
void SetInput(const std::vector<int64>& timestamp_list) {
MutableInputs()->Index(0).packets.clear();
for (const int64 ts : timestamp_list) {
MutableInputs()->Index(0).packets.push_back(
Adopt(new std::string(absl::StrCat("Frame #", ts)))
.At(Timestamp(ts)));
}
}
void SetVideoHeader(const double frame_rate) {
video_header_.width = static_count_;
video_header_.height = static_count_ * 10;
video_header_.frame_rate = frame_rate;
video_header_.duration = static_count_ * 100.0;
video_header_.format = static_cast<ImageFormat::Format>(
static_count_ % ImageFormat::Format_ARRAYSIZE);
MutableInputs()->Index(0).header = Adopt(new VideoHeader(video_header_));
++static_count_;
}
void CheckOutputTimestamps(
const std::vector<int64>& expected_frames,
const std::vector<int64>& expected_timestamps) const {
EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size());
EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size());
int count = 0;
for (const Packet& packet : Outputs().Index(0).packets) {
EXPECT_EQ(Timestamp(expected_timestamps[count]), packet.Timestamp());
const std::string& packet_contents = packet.Get<std::string>();
EXPECT_EQ(std::string(absl::StrCat("Frame #", expected_frames[count])),
packet_contents);
++count;
}
}
void CheckVideoHeader(const double expected_frame_rate) const {
ASSERT_FALSE(Outputs().Index(0).header.IsEmpty());
const VideoHeader& header = Outputs().Index(0).header.Get<VideoHeader>();
const double frame_rate = header.frame_rate;
EXPECT_EQ(video_header_.width, header.width);
EXPECT_EQ(video_header_.height, header.height);
EXPECT_DOUBLE_EQ(expected_frame_rate, frame_rate);
EXPECT_FLOAT_EQ(video_header_.duration, header.duration);
EXPECT_EQ(video_header_.format, header.format);
}
private:
VideoHeader video_header_;
static int static_count_;
};
int SimpleRunner::static_count_ = 0;
TEST(PacketResamplerCalculatorTest, NoPacketsInStream) {
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({});
MEDIAPIPE_ASSERT_OK(runner.Run());
}
}
TEST(PacketResamplerCalculatorTest, SinglePacketInStream) {
// Stream with 1 packet / 1 period.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0}, {0});
}
// Stream with 1 packet / 1 period (0 < packet timestamp < first limit).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({1000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({1000}, {1000});
}
// Stream with 1 packet / 1 period (packet timestamp > first limit).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({16668});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({16668}, {16668});
}
}
TEST(PacketResamplerCalculatorTest, TwoPacketsInStream) {
// Stream with 2 packets / 1 period.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 16666});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0}, {0});
}
// Stream with 2 packets / 2 periods (left extreme for second period).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 16667});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 16667}, {0, 33333});
}
// Stream with 2 packets / 2 periods (right extreme for second period).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 49999});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 49999}, {0, 33333});
}
// Stream with 2 packets / 3 periods (filling 1 in the middle).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 50000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 0, 50000}, {0, 33333, 66667});
}
// Stream with 2 packets / 4 periods (filling 2 in the middle).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({2000, 118666});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({2000, 2000, 2000, 118666},
{2000, 35333, 68667, 102000});
}
}
TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepoints) {
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 33333, 66667, 100000, 133333, 166667, 200000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps(
{0, 33333, 66667, 100000, 133333, 166667, 200000},
{0, 33333, 66667, 100000, 133333, 166667, 200000});
}
// When there are several candidates for a period, the one closer to the center
// should be sent to the output.
TEST(PacketResamplerCalculatorTest, MultiplePacketsForPeriods) {
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 16666, 16667, 20000, 33300, 49999, 50000, 66600});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 33300, 66600}, {0, 33333, 66667});
}
// When a period must be filled, we use the latest packet received (not
// necessarily the same as the one stored for the best in the previous period).
TEST(PacketResamplerCalculatorTest, FillPeriodsWithLatestPacket) {
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 5000, 16666, 83334});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 16666, 16666, 83334},
{0, 33333, 66667, 100000});
}
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 16666, 16667, 25000, 33000, 35000, 135000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 33000, 35000, 35000, 135000},
{0, 33333, 66667, 100000, 133333});
}
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({0, 15000, 32000, 49999, 150000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 32000, 49999, 49999, 49999, 150000},
{0, 33333, 66667, 100000, 133333, 166667});
}
}
TEST(PacketResamplerCalculatorTest, SuperHighFrameRate) {
// frame rate == 500000 (a packet will have to be sent every 2 ticks).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:500000}");
runner.SetInput({0, 10, 13});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 0, 0, 0, 0, 10, 10, 13},
{0, 2, 4, 6, 8, 10, 12, 14});
}
// frame rate == 1000000 (a packet will have to be sent in each tick).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:1000000}");
runner.SetInput({0, 10, 13});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps(
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 13},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13});
}
}
TEST(PacketResamplerCalculatorTest, NegativeTimestampTest) {
// Stream with negative timestamps / 1 period.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({-200, -20, 16466});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-200}, {-200});
}
// Stream with negative timestamps / 2 periods.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({-200, -20, 16467});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-200, 16467}, {-200, 33133});
}
// Stream with negative timestamps and filling an empty period.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({-500, 66667});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-500, -500, 66667}, {-500, 32833, 66167});
}
// Stream with negative timestamps and initial packet < -period.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({-50000, -33334, 33334});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-50000, -33334, -33334, 33334},
{-50000, -16667, 16667, 50000});
}
}
TEST(PacketResamplerCalculatorTest, ExactFramesPerSecond) {
// Using frame_rate=50 makes a period of 20000 microseconds (exact).
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:50}");
runner.SetInput({0, 9999, 29999});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 29999}, {0, 20000});
}
// Test filling empty periods.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:50}");
runner.SetInput({0, 10000, 50000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 10000, 10000, 50000},
{0, 20000, 40000, 60000});
}
}
TEST(PacketResamplerCalculatorTest, FrameRateTest) {
// Test changing Frame Rate to the same initial value.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:50, output_header:UPDATE_VIDEO_HEADER}");
runner.SetInput({0, 10000, 30000, 50000, 60000});
runner.SetVideoHeader(50.0);
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 10000, 30000, 60000},
{0, 20000, 40000, 60000});
runner.CheckVideoHeader(50.0);
}
// Test changing Frame Rate to new value.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:50, output_header:UPDATE_VIDEO_HEADER}");
runner.SetInput({0, 5000, 10010, 15001, 19990});
runner.SetVideoHeader(200.0);
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 19990}, {0, 20000});
runner.CheckVideoHeader(50.0);
}
// Test that the frame rate is not changing if update_video_header = false.
{
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:50, output_header:PASS_HEADER}");
runner.SetInput({0, 5000, 10010, 15001, 19990});
runner.SetVideoHeader(200.0);
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({0, 19990}, {0, 20000});
runner.CheckVideoHeader(200.0);
}
}
TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "PacketResamplerCalculator"
input_stream: "DATA:in_data"
input_stream: "VIDEO_HEADER:in_video_header"
output_stream: "DATA:out_data"
output_stream: "VIDEO_HEADER:out_video_header"
options {
[mediapipe.PacketResamplerCalculatorOptions.ext] { frame_rate: 50.0 }
}
)"));
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
runner.MutableInputs()->Tag("DATA").packets.push_back(
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
}
VideoHeader video_header_in;
video_header_in.width = 10;
video_header_in.height = 100;
video_header_in.frame_rate = 1.0;
video_header_in.duration = 1.0;
video_header_in.format = ImageFormat::SRGB;
runner.MutableInputs()
->Tag("VIDEO_HEADER")
.packets.push_back(
Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream()));
MEDIAPIPE_ASSERT_OK(runner.Run());
ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size());
EXPECT_EQ(Timestamp::PreStream(),
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp());
const VideoHeader& video_header_out =
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get<VideoHeader>();
EXPECT_EQ(video_header_in.width, video_header_out.width);
EXPECT_EQ(video_header_in.height, video_header_out.height);
EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate);
EXPECT_FLOAT_EQ(video_header_in.duration, video_header_out.duration);
EXPECT_EQ(video_header_in.format, video_header_out.format);
}
TEST(PacketResamplerCalculatorTest, FlushLastPacketWithoutRound) {
SimpleRunner runner(R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 1
})");
runner.SetInput({0, 333333, 666667, 1000000, 1333333});
MEDIAPIPE_ASSERT_OK(runner.Run());
// 1333333 is not emitted as 2000000, because it does not round to 2000000.
runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000});
}
TEST(PacketResamplerCalculatorTest, FlushLastPacketWithRound) {
SimpleRunner runner(R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 1
})");
runner.SetInput({0, 333333, 666667, 1000000, 1333333, 1666667});
MEDIAPIPE_ASSERT_OK(runner.Run());
// 1666667 is emitted as 2000000, because it rounds to 2000000.
runner.CheckOutputTimestamps({0, 1000000, 1666667}, {0, 1000000, 2000000});
}
TEST(PacketResamplerCalculatorTest, DoNotFlushLastPacketWithoutRound) {
SimpleRunner runner(R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 1
flush_last_packet: false
})");
runner.SetInput({0, 333333, 666667, 1000000, 1333333});
MEDIAPIPE_ASSERT_OK(runner.Run());
// 1333333 is not emitted no matter what; see FlushLastPacketWithoutRound.
runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000});
}
TEST(PacketResamplerCalculatorTest, DoNotFlushLastPacketWithRound) {
SimpleRunner runner(R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 1
flush_last_packet: false
})");
runner.SetInput({0, 333333, 666667, 1000000, 1333333, 1666667});
MEDIAPIPE_ASSERT_OK(runner.Run());
// 1666667 is not emitted due to flush_last_packet: false.
runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000});
}
// When base_timestamp is specified, output timestamps are aligned with it.
TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepointsAligned) {
{
    // Without base_timestamp, outputs are aligned with the first input
    // timestamp, 33111 (= 33333 - 222).
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({33111, 66667, 100000, 133333, 166667, 200000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({33111, 66667, 100000, 133333, 166667, 200000},
{33111, 66444, 99778, 133111, 166444, 199778});
}
{
// With base_timestamp, outputs are aligned with base_timestamp, 0.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0}");
runner.SetInput({33111, 66667, 100000, 133333, 166667, 200000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps(
{33111, 66667, 100000, 133333, 166667, 200000},
{33333, 66666, 100000, 133333, 166666, 200000});
}
}
// When base_timestamp is specified, output timestamps are aligned with it.
TEST(PacketResamplerCalculatorTest, MultiplePacketsForPeriodsAligned) {
{
// Without base_timestamp, outputs are aligned with the first input, -222.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({-222, 16666, 16667, 20000, 33300, 49999, 50000, 66600});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-222, 33300, 66600}, {-222, 33111, 66445});
}
{
// With base_timestamp, outputs are aligned with base_timestamp, 900011.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:900011}");
runner.SetInput({-222, 16666, 16667, 20000, 33300, 49999, 50000, 66600});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-222, 33300, 66600}, {11, 33344, 66678});
}
{
// With base_timestamp, outputs still approximate input timestamps,
// while aligned to base_timestamp, 11.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:11}");
runner.SetInput(
{899888, 916666, 916667, 920000, 933300, 949999, 950000, 966600});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({899888, 933300, 966600},
{900011, 933344, 966678});
}
}
// When a period must be filled, we use the latest packet received.
// When base_timestamp is specified, output timestamps are aligned with it.
TEST(PacketResamplerCalculatorTest, FillPeriodsWithLatestPacketAligned) {
{
// Without base_timestamp, outputs are aligned with the first input, -222.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30}");
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000},
{-222, 33111, 66445, 99778, 133111, 166445});
}
{
// With base_timestamp, outputs are aligned with base_timestamp, 0.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0}");
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000},
{0, 33333, 66667, 100000, 133333, 166667});
}
}
// When base_timestamp is specified, output timestamps are aligned with it.
// The first packet is included, because we assume that the input includes the
// whole first sampling interval.
TEST(PacketResamplerCalculatorTest, FirstInputAfterMiddlepointAligned) {
{
// Packet 100020 is omitted from the output sequence because
// packet 99990 is closer to the period midpoint.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0}");
runner.SetInput({66667, 100020, 133333, 166667});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({66667, 100020, 133333, 166667},
{66667, 100000, 133334, 166667});
}
{
// If we seek to packet 100020, packet 100020 is included in
// the output sequence, because we assume that the input includes the
// whole first sampling interval.
//
// We assume that the input includes whole sampling intervals
// in order to produce "reproducible timestamps", which are timestamps
// from the series of timestamps starting at 0.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0}");
runner.SetInput({100020, 133333, 166667});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({100020, 133333, 166667},
{100000, 133333, 166667});
}
}
TEST(PacketResamplerCalculatorTest, OutputTimestampRangeAligned) {
{
// With base_timestamp, outputs are aligned with base_timestamp, 0.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0}");
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000},
{0, 33333, 66667, 100000, 133333, 166667});
}
{
// With start_time, end_time, outputs are filtered.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0 "
"start_time:40000 "
"end_time:160000}");
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({49999, 49999, 49999},
{66667, 100000, 133333});
}
{
// With start_time, end_time, round_limits, outputs are filtered,
// rounding to the nearest limit.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0 "
"start_time:40000 "
"end_time:160000 "
"round_limits:true}");
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MEDIAPIPE_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({32000, 49999, 49999, 49999, 150000},
{33333, 66667, 100000, 133333, 166667});
}
}
TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "PacketResamplerCalculator"
input_side_packet: "OPTIONS:options"
input_stream: "input"
output_stream: "output"
options {
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 60
base_timestamp: 0
}
})");
{
SimpleRunner runner(node_config);
auto options =
new CalculatorOptions(ParseTextProtoOrDie<CalculatorOptions>(
R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 30
})"));
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MEDIAPIPE_ASSERT_OK(runner.Run());
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
}
{
SimpleRunner runner(node_config);
auto options =
new CalculatorOptions(ParseTextProtoOrDie<CalculatorOptions>(R"(
merge_fields: false
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 30
base_timestamp: 0
})"));
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MEDIAPIPE_ASSERT_OK(runner.Run());
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
}
}
} // namespace
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
namespace mediapipe {
// A Calculator that simply passes its input Packets and header through,
// unchanged. The inputs may be specified by tag or index. The outputs
// must match the inputs exactly. Any number of input side packets may
// also be specified. If output side packets are specified, they must
// match the input side packets exactly and the Calculator passes its
// input side packets through, unchanged. Otherwise, the input side
// packets will be ignored (allowing PassThroughCalculator to be used to
// test internal behavior). Any options may be specified and will be
// ignored.
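//
// A hypothetical minimal config (the stream names "in" and "out" are
// illustrative, not taken from any existing graph):
//
// node {
//   calculator: "PassThroughCalculator"
//   input_stream: "in"
//   output_stream: "out"
// }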
class PassThroughCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
if (!cc->Inputs().TagMap()->SameAs(*cc->Outputs().TagMap())) {
return ::mediapipe::InvalidArgumentError(
"Input and output streams to PassThroughCalculator must use "
"matching tags and indexes.");
}
for (CollectionItemId id = cc->Inputs().BeginId();
id < cc->Inputs().EndId(); ++id) {
cc->Inputs().Get(id).SetAny();
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Get(id));
}
for (CollectionItemId id = cc->InputSidePackets().BeginId();
id < cc->InputSidePackets().EndId(); ++id) {
cc->InputSidePackets().Get(id).SetAny();
}
if (cc->OutputSidePackets().NumEntries() != 0) {
if (!cc->InputSidePackets().TagMap()->SameAs(
*cc->OutputSidePackets().TagMap())) {
return ::mediapipe::InvalidArgumentError(
"Input and output side packets to PassThroughCalculator must use "
"matching tags and indexes.");
}
for (CollectionItemId id = cc->InputSidePackets().BeginId();
id < cc->InputSidePackets().EndId(); ++id) {
cc->OutputSidePackets().Get(id).SetSameAs(
&cc->InputSidePackets().Get(id));
}
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
for (CollectionItemId id = cc->Inputs().BeginId();
id < cc->Inputs().EndId(); ++id) {
if (!cc->Inputs().Get(id).Header().IsEmpty()) {
cc->Outputs().Get(id).SetHeader(cc->Inputs().Get(id).Header());
}
}
if (cc->OutputSidePackets().NumEntries() != 0) {
for (CollectionItemId id = cc->InputSidePackets().BeginId();
id < cc->InputSidePackets().EndId(); ++id) {
cc->OutputSidePackets().Get(id).Set(cc->InputSidePackets().Get(id));
}
}
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
cc->GetCounter("PassThrough")->Increment();
if (cc->Inputs().NumEntries() == 0) {
return tool::StatusStop();
}
for (CollectionItemId id = cc->Inputs().BeginId();
id < cc->Inputs().EndId(); ++id) {
if (!cc->Inputs().Get(id).IsEmpty()) {
VLOG(3) << "Passing " << cc->Inputs().Get(id).Name() << " to "
<< cc->Outputs().Get(id).Name() << " at "
<< cc->InputTimestamp().DebugString();
cc->Outputs().Get(id).AddPacket(cc->Inputs().Get(id).Value());
}
}
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(PassThroughCalculator);
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <deque>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// PreviousLoopbackCalculator is useful when a graph needs to process an input
// together with some previous output.
//
// For the first packet that arrives on the MAIN input, the timestamp bound is
// advanced on the output. Downstream calculators will see this as an empty
// packet. This way they are not kept waiting for the previous output, which
// for the first iteration does not exist.
//
// Thereafter, each packet received on MAIN is matched with a packet received
// on LOOP; the LOOP packet's timestamp is changed to that of the MAIN packet,
// and it is output on PREV_LOOP.
//
// Example config:
// node {
// calculator: "PreviousLoopbackCalculator"
// input_stream: "MAIN:input"
// input_stream: "LOOP:output"
// input_stream_info: { tag_index: 'LOOP' back_edge: true }
// output_stream: "PREV_LOOP:prev_output"
// }
// node {
// calculator: "FaceTracker"
// input_stream: "VIDEO:input"
// input_stream: "PREV_TRACK:prev_output"
// output_stream: "TRACK:output"
// }
class PreviousLoopbackCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Get("MAIN", 0).SetAny();
cc->Inputs().Get("LOOP", 0).SetAny();
cc->Outputs().Get("PREV_LOOP", 0).SetSameAs(&(cc->Inputs().Get("LOOP", 0)));
// TODO: an optional PREV_TIMESTAMP output could be added to
// carry the original timestamp of the packet on PREV_LOOP.
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
main_id_ = cc->Inputs().GetId("MAIN", 0);
loop_id_ = cc->Inputs().GetId("LOOP", 0);
loop_out_id_ = cc->Outputs().GetId("PREV_LOOP", 0);
cc->Outputs()
.Get(loop_out_id_)
.SetHeader(cc->Inputs().Get(loop_id_).Header());
// Use an empty packet for the first round, since there is no previous
// output.
loopback_packets_.push_back({});
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
Packet& main_packet = cc->Inputs().Get(main_id_).Value();
if (!main_packet.IsEmpty()) {
main_ts_.push_back(main_packet.Timestamp());
}
Packet& loopback_packet = cc->Inputs().Get(loop_id_).Value();
if (!loopback_packet.IsEmpty()) {
loopback_packets_.push_back(loopback_packet);
while (!main_ts_.empty() &&
main_ts_.front() <= loopback_packets_.front().Timestamp()) {
main_ts_.pop_front();
}
}
while (!main_ts_.empty() && !loopback_packets_.empty()) {
Timestamp main_timestamp = main_ts_.front();
main_ts_.pop_front();
Packet previous_loopback = loopback_packets_.front().At(main_timestamp);
loopback_packets_.pop_front();
if (previous_loopback.IsEmpty()) {
// TODO: SetCompleteTimestampBound would be more useful.
cc->Outputs()
.Get(loop_out_id_)
.SetNextTimestampBound(main_timestamp + 1);
} else {
cc->Outputs().Get(loop_out_id_).AddPacket(std::move(previous_loopback));
}
}
return ::mediapipe::OkStatus();
}
private:
CollectionItemId main_id_;
CollectionItemId loop_id_;
CollectionItemId loop_out_id_;
std::deque<Timestamp> main_ts_;
std::deque<Packet> loopback_packets_;
};
REGISTER_CALCULATOR(PreviousLoopbackCalculator);
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/sink.h"
namespace mediapipe {
namespace {
// Returns the timestamp values for a vector of Packets.
// TODO: put this kind of test util in a common place.
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
std::vector<int64> result;
for (const Packet& packet : packets) {
result.push_back(packet.Timestamp().Value());
}
return result;
}
TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
std::vector<Packet> in_prev;
CalculatorGraphConfig graph_config_ =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
# This calculator synchronizes its inputs as normal, so it is used
# to check that both "in" and "previous" are ready.
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'out'
input_stream: 'previous2'
output_stream: 'pair'
}
)");
tool::AddVectorSink("pair", &graph_config_, &in_prev);
CalculatorGraph graph_;
MEDIAPIPE_ASSERT_OK(graph_.Initialize(graph_config_, {}));
MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
auto send_packet = [&graph_](const std::string& input_name, int n) {
MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int>(n).At(Timestamp(n))));
};
auto pair_values = [](const Packet& packet) {
auto pair = packet.Get<std::pair<Packet, Packet>>();
int first = pair.first.IsEmpty() ? -1 : pair.first.Get<int>();
int second = pair.second.IsEmpty() ? -1 : pair.second.Get<int>();
return std::make_pair(first, second);
};
send_packet("in", 1);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1}));
EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(1, -1));
send_packet("in", 5);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 5}));
EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(5, 1));
send_packet("in", 15);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 5, 15}));
EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(15, 5));
MEDIAPIPE_EXPECT_OK(graph_.CloseAllInputStreams());
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone());
}
} // namespace
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <utility>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/header_util.h"
namespace mediapipe {
// RealTimeFlowLimiterCalculator is used to limit the number of pipelined
// processing operations in a section of the graph.
//
// Typical topology:
//
// in ->-[RTFLC]-[foo]-...-[bar]-+->- out
// ^____________________|
// FINISHED
//
// By connecting the output of the graph section to this calculator's FINISHED
// input with a backwards edge, this allows RTFLC to keep track of how many
// timestamps are currently being processed.
//
// The limit defaults to 1, and can be overridden with the MAX_IN_FLIGHT side
// packet.
//
// As long as the number of timestamps being processed ("in flight") is below
// the limit, RTFLC allows input to pass through. When the limit is reached,
// RTFLC starts dropping input packets, keeping only the most recent. When the
// processing count decreases again, as signaled by the receipt of a packet on
// FINISHED, RTFLC allows packets to flow again, releasing the most recently
// queued packet, if any.
//
// If there are multiple input streams, packet dropping is synchronized.
//
// IMPORTANT: for each timestamp where RTFLC forwards a packet (or a set of
// packets, if using multiple data streams), a packet must eventually arrive on
// the FINISHED stream. Dropping packets in the section between RTFLC and
// FINISHED will make the in-flight count incorrect.
//
// TODO: Remove this comment when graph-level ISH has been removed.
// NOTE: this calculator should always use the ImmediateInputStreamHandler,
// and it does so by default. However, if the graph specifies a graph-level
// InputStreamHandler, that setting must be overridden by specifying the
// InputStreamHandler explicitly, as shown below.
//
// Example config:
// node {
// calculator: "RealTimeFlowLimiterCalculator"
// input_stream: "raw_frames"
// input_stream: "FINISHED:finished"
// input_stream_info: {
// tag_index: 'FINISHED'
// back_edge: true
// }
// input_stream_handler {
// input_stream_handler: 'ImmediateInputStreamHandler'
// }
// output_stream: "gated_frames"
// }
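//
// To allow more than one timestamp in flight, the limit can be supplied via
// the MAX_IN_FLIGHT side packet (sketch; the side packet name
// "max_in_flight" is illustrative):
//
// node {
//   calculator: "RealTimeFlowLimiterCalculator"
//   input_stream: "raw_frames"
//   input_stream: "FINISHED:finished"
//   input_stream_info: { tag_index: 'FINISHED' back_edge: true }
//   input_side_packet: "MAX_IN_FLIGHT:max_in_flight"
//   output_stream: "gated_frames"
// }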
class RealTimeFlowLimiterCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
int num_data_streams = cc->Inputs().NumEntries("");
RET_CHECK_GE(num_data_streams, 1);
    RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams)
        << "Output streams must correspond to input streams, except for the "
           "FINISHED indicator input stream.";
for (int i = 0; i < num_data_streams; ++i) {
cc->Inputs().Get("", i).SetAny();
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
}
cc->Inputs().Get("FINISHED", 0).SetAny();
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>();
}
if (cc->Outputs().HasTag("ALLOW")) {
cc->Outputs().Tag("ALLOW").Set<bool>();
}
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
finished_id_ = cc->Inputs().GetId("FINISHED", 0);
max_in_flight_ = 1;
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>();
}
RET_CHECK_GE(max_in_flight_, 1);
num_in_flight_ = 0;
allowed_id_ = cc->Outputs().GetId("ALLOW", 0);
allow_ctr_ts_ = Timestamp(0);
num_data_streams_ = cc->Inputs().NumEntries("");
data_stream_bound_ts_.resize(num_data_streams_);
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
return ::mediapipe::OkStatus();
}
bool Allow() { return num_in_flight_ < max_in_flight_; }
::mediapipe::Status Process(CalculatorContext* cc) final {
bool old_allow = Allow();
Timestamp lowest_incomplete_ts = Timestamp::Done();
// Process FINISHED stream.
if (!cc->Inputs().Get(finished_id_).Value().IsEmpty()) {
RET_CHECK_GT(num_in_flight_, 0)
<< "Received a FINISHED packet, but we had none in flight.";
--num_in_flight_;
}
// Process data streams.
for (int i = 0; i < num_data_streams_; ++i) {
auto& stream = cc->Inputs().Get("", i);
auto& out = cc->Outputs().Get("", i);
Packet& packet = stream.Value();
auto ts = packet.Timestamp();
if (ts.IsRangeValue() && data_stream_bound_ts_[i] <= ts) {
data_stream_bound_ts_[i] = ts + 1;
// Note: it's ok to update the output bound here, before sending the
// packet, because updates are batched during the Process function.
out.SetNextTimestampBound(data_stream_bound_ts_[i]);
}
lowest_incomplete_ts =
std::min(lowest_incomplete_ts, data_stream_bound_ts_[i]);
if (packet.IsEmpty()) {
// If the input stream is closed, close the corresponding output.
if (stream.IsDone() && !out.IsClosed()) {
out.Close();
}
// TODO: if the packet is empty, the ts is unset, and we
// cannot read the timestamp bound, even though we'd like to propagate
// it.
} else if (mediapipe::ContainsKey(pending_ts_, ts)) {
// If we have already sent this timestamp (on another stream), send it
// on this stream too.
out.AddPacket(std::move(packet));
} else if (Allow() && (ts > last_dropped_ts_)) {
// If the in-flight is under the limit, and if we have not already
// dropped this or a later timestamp on another stream, then send
// the packet and add an in-flight timestamp.
out.AddPacket(std::move(packet));
pending_ts_.insert(ts);
++num_in_flight_;
} else {
// Otherwise, we'll drop the packet.
last_dropped_ts_ = std::max(last_dropped_ts_, ts);
}
}
// Remove old pending_ts_ entries.
auto it = std::lower_bound(pending_ts_.begin(), pending_ts_.end(),
lowest_incomplete_ts);
pending_ts_.erase(pending_ts_.begin(), it);
// Update ALLOW signal.
if ((old_allow != Allow()) && allowed_id_.IsValid()) {
cc->Outputs()
.Get(allowed_id_)
.AddPacket(MakePacket<bool>(Allow()).At(++allow_ctr_ts_));
}
return ::mediapipe::OkStatus();
}
private:
std::set<Timestamp> pending_ts_;
Timestamp last_dropped_ts_;
int num_data_streams_;
int num_in_flight_;
int max_in_flight_;
CollectionItemId finished_id_;
CollectionItemId allowed_id_;
Timestamp allow_ctr_ts_;
std::vector<Timestamp> data_stream_bound_ts_;
};
REGISTER_CALCULATOR(RealTimeFlowLimiterCalculator);
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/sink.h"
namespace mediapipe {
namespace {
// A simple semaphore for synchronizing test threads.
class AtomicSemaphore {
public:
AtomicSemaphore(int64_t supply) : supply_(supply) {}
  // Blocks (by busy-waiting) until `amount` units of supply are available.
  void Acquire(int64_t amount) {
    while (supply_.fetch_sub(amount) - amount < 0) {
      // Not enough supply; return what was taken and retry.
      Release(amount);
    }
  }
  // Returns `amount` units to the supply.
  void Release(int64_t amount) { supply_.fetch_add(amount); }
private:
std::atomic<int64_t> supply_;
};
// Returns the timestamp values for a vector of Packets.
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
std::vector<int64> result;
for (const Packet& packet : packets) {
result.push_back(packet.Timestamp().Value());
}
return result;
}
// Returns the packet values for a vector of Packets.
template <typename T>
std::vector<T> PacketValues(const std::vector<Packet>& packets) {
std::vector<T> result;
for (const Packet& packet : packets) {
result.push_back(packet.Get<T>());
}
return result;
}
constexpr int kNumImageFrames = 5;
constexpr int kNumFinished = 3;
CalculatorGraphConfig::Node GetDefaultNode() {
return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "RealTimeFlowLimiterCalculator"
input_stream: "raw_frames"
input_stream: "FINISHED:finished"
input_stream_info: { tag_index: "FINISHED" back_edge: true }
output_stream: "gated_frames"
)");
}
// Simple test to make sure that the RealTimeFlowLimiterCalculator outputs
// just one packet when MAX_IN_FLIGHT is 1.
TEST(RealTimeFlowLimiterCalculator, OneOutputTest) {
// Setup the calculator runner and add only ImageFrame packets.
CalculatorRunner runner(GetDefaultNode());
for (int i = 0; i < kNumImageFrames; ++i) {
Timestamp timestamp = Timestamp(i * Timestamp::kTimestampUnitsPerSecond);
runner.MutableInputs()->Index(0).packets.push_back(
MakePacket<ImageFrame>().At(timestamp));
}
// Run the calculator.
MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& frame_output_packets =
runner.Outputs().Index(0).packets;
EXPECT_EQ(frame_output_packets.size(), 1);
}
// Simple test to make sure that the RealTimeFlowLimiterCalculator waits for all
// input streams to have at least one packet available before publishing.
TEST(RealTimeFlowLimiterCalculator, BasicTest) {
// Setup the calculator runner and add both ImageFrame and finish packets.
CalculatorRunner runner(GetDefaultNode());
for (int i = 0; i < kNumImageFrames; ++i) {
Timestamp timestamp = Timestamp(i * Timestamp::kTimestampUnitsPerSecond);
runner.MutableInputs()->Index(0).packets.push_back(
MakePacket<ImageFrame>().At(timestamp));
}
for (int i = 0; i < kNumFinished; ++i) {
Timestamp timestamp =
Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond);
runner.MutableInputs()
->Tag("FINISHED")
.packets.push_back(MakePacket<bool>(true).At(timestamp));
}
// Run the calculator.
MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& frame_output_packets =
runner.Outputs().Index(0).packets;
// Only outputs packets if both input streams are available.
int expected_num_packets = std::min(kNumImageFrames, kNumFinished + 1);
EXPECT_EQ(frame_output_packets.size(), expected_num_packets);
}
// A Calculator::Process callback function.
typedef std::function<::mediapipe::Status(const InputStreamShardSet&,
OutputStreamShardSet*)>
ProcessFunction;
// A testing callback function that passes through all packets.
::mediapipe::Status PassthroughFunction(const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
for (int i = 0; i < inputs.NumEntries(); ++i) {
if (!inputs.Index(i).Value().IsEmpty()) {
outputs->Index(i).AddPacket(inputs.Index(i).Value());
}
}
return ::mediapipe::OkStatus();
}
// A Calculator that runs a testing callback function in Close.
class CloseCallbackCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
for (CollectionItemId id = cc->Inputs().BeginId();
id < cc->Inputs().EndId(); ++id) {
cc->Inputs().Get(id).SetAny();
}
for (CollectionItemId id = cc->Outputs().BeginId();
id < cc->Outputs().EndId(); ++id) {
cc->Outputs().Get(id).SetAny();
}
cc->InputSidePackets().Index(0).Set<std::function<::mediapipe::Status()>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return PassthroughFunction(cc->Inputs(), &(cc->Outputs()));
}
::mediapipe::Status Close(CalculatorContext* cc) override {
const auto& callback = cc->InputSidePackets()
.Index(0)
.Get<std::function<::mediapipe::Status()>>();
return callback();
}
};
REGISTER_CALCULATOR(CloseCallbackCalculator);
// Tests demonstrating a RealTimeFlowLimiterCalculator operating in a cyclic
// graph.
// TODO: clean up these tests.
class RealTimeFlowLimiterCalculatorTest : public testing::Test {
public:
RealTimeFlowLimiterCalculatorTest()
: enter_semaphore_(0), exit_semaphore_(0) {}
void SetUp() override {
graph_config_ = InflightGraphConfig();
tool::AddVectorSink("out_1", &graph_config_, &out_1_packets_);
tool::AddVectorSink("out_2", &graph_config_, &out_2_packets_);
}
void InitializeGraph(int max_in_flight) {
ProcessFunction semaphore_0_func = [&](const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
enter_semaphore_.Release(1);
return PassthroughFunction(inputs, outputs);
};
ProcessFunction semaphore_1_func = [&](const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
exit_semaphore_.Acquire(1);
return PassthroughFunction(inputs, outputs);
};
std::function<::mediapipe::Status()> close_func = [this]() {
close_count_++;
return ::mediapipe::OkStatus();
};
MEDIAPIPE_ASSERT_OK(graph_.Initialize(
graph_config_, {
{"max_in_flight", MakePacket<int>(max_in_flight)},
{"callback_0", Adopt(new auto(semaphore_0_func))},
{"callback_1", Adopt(new auto(semaphore_1_func))},
{"callback_2", Adopt(new auto(close_func))},
}));
}
// Adds a packet to a graph input stream.
void AddPacket(const std::string& input_name, int value) {
MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int>(value).At(Timestamp(value))));
}
// A calculator graph starting with a RealTimeFlowLimiterCalculator and
// ending with a CloseCallbackCalculator.
// The back edge "FINISHED" limits processing to one frame in flight.
// The two LambdaCalculators are used to keep certain packet sets in flight.
CalculatorGraphConfig InflightGraphConfig() {
return ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in_1'
input_stream: 'in_2'
node {
calculator: 'RealTimeFlowLimiterCalculator'
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
input_stream: 'in_1'
input_stream: 'in_2'
input_stream: 'FINISHED:out_1'
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
output_stream: 'in_1_sampled'
output_stream: 'in_2_sampled'
}
node {
calculator: 'LambdaCalculator'
input_side_packet: 'callback_0'
input_stream: 'in_1_sampled'
input_stream: 'in_2_sampled'
output_stream: 'queue_1'
output_stream: 'queue_2'
}
node {
calculator: 'LambdaCalculator'
input_side_packet: 'callback_1'
input_stream: 'queue_1'
input_stream: 'queue_2'
output_stream: 'close_1'
output_stream: 'close_2'
}
node {
calculator: 'CloseCallbackCalculator'
input_side_packet: 'callback_2'
input_stream: 'close_1'
input_stream: 'close_2'
output_stream: 'out_1'
output_stream: 'out_2'
}
)");
}
protected:
CalculatorGraphConfig graph_config_;
CalculatorGraph graph_;
AtomicSemaphore enter_semaphore_;
AtomicSemaphore exit_semaphore_;
std::vector<Packet> out_1_packets_;
std::vector<Packet> out_2_packets_;
int close_count_ = 0;
};
// A test demonstrating a RealTimeFlowLimiterCalculator operating in a cyclic
// graph. This test shows that:
//
// (1) Timestamps are passed through unaltered.
// (2) All output streams including the back_edge stream are closed when
// the first input stream is closed.
//
TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) {
InitializeGraph(1);
MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
auto send_packet = [this](const std::string& input_name, int64 n) {
MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int64>(n).At(Timestamp(n))));
};
for (int i = 0; i < 10; i++) {
send_packet("in_1", i * 10);
// This next input should be dropped.
send_packet("in_1", i * 10 + 5);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
send_packet("in_2", i * 10);
exit_semaphore_.Release(1);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
}
MEDIAPIPE_EXPECT_OK(graph_.CloseInputStream("in_1"));
MEDIAPIPE_EXPECT_OK(graph_.CloseInputStream("in_2"));
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
// Once streams "in_1" and "in_2" are closed, all output streams are closed
// and all output packets are delivered.
EXPECT_EQ(10, out_1_packets_.size());
EXPECT_EQ(10, out_2_packets_.size());
// Timestamps are passed through unaltered.
EXPECT_EQ(PacketValues<int64>(out_1_packets_),
TimestampValues(out_1_packets_));
EXPECT_EQ(PacketValues<int64>(out_2_packets_),
TimestampValues(out_2_packets_));
// Extra inputs on stream "in_1" have been dropped.
EXPECT_EQ(TimestampValues(out_1_packets_),
(std::vector<int64>{0, 10, 20, 30, 40, 50, 60, 70, 80, 90}));
EXPECT_EQ(TimestampValues(out_1_packets_), TimestampValues(out_2_packets_));
// The closing of the stream has been propagated.
EXPECT_EQ(1, close_count_);
}
// A test demonstrating that all output streams are closed when all
// input streams are closed after the last input packet has been processed.
TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) {
InitializeGraph(1);
MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
exit_semaphore_.Release(10);
for (int i = 0; i < 10; i++) {
AddPacket("in_1", i);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
AddPacket("in_2", i);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
}
MEDIAPIPE_EXPECT_OK(graph_.CloseAllInputStreams());
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(out_1_packets_), TimestampValues(out_2_packets_));
EXPECT_EQ(TimestampValues(out_1_packets_),
(std::vector<int64>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
EXPECT_EQ(1, close_count_);
}
TEST(RealTimeFlowLimiterCalculator, TwoStreams) {
std::vector<Packet> a_passed;
std::vector<Packet> b_passed;
CalculatorGraphConfig graph_config_ =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in_a'
input_stream: 'in_b'
input_stream: 'finished'
node {
name: 'input_dropper'
calculator: 'RealTimeFlowLimiterCalculator'
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
input_stream: 'in_a'
input_stream: 'in_b'
input_stream: 'FINISHED:finished'
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
output_stream: 'in_a_sampled'
output_stream: 'in_b_sampled'
output_stream: 'ALLOW:allow'
}
)");
std::string allow_cb_name;
tool::AddVectorSink("in_a_sampled", &graph_config_, &a_passed);
tool::AddVectorSink("in_b_sampled", &graph_config_, &b_passed);
tool::AddCallbackCalculator("allow", &graph_config_, &allow_cb_name, true);
bool allow = true;
auto allow_cb = [&allow](const Packet& packet) {
allow = packet.Get<bool>();
};
CalculatorGraph graph_;
MEDIAPIPE_EXPECT_OK(graph_.Initialize(
graph_config_,
{
{"max_in_flight", MakePacket<int>(1)},
{allow_cb_name,
MakePacket<std::function<void(const Packet&)>>(allow_cb)},
}));
MEDIAPIPE_EXPECT_OK(graph_.StartRun({}));
auto send_packet = [&graph_](const std::string& input_name, int n) {
MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int>(n).At(Timestamp(n))));
};
send_packet("in_a", 1);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(allow, false);
EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{}));
send_packet("in_a", 2);
send_packet("in_b", 1);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1}));
EXPECT_EQ(allow, false);
send_packet("finished", 1);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1}));
EXPECT_EQ(allow, true);
send_packet("in_b", 2);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1}));
EXPECT_EQ(allow, true);
send_packet("in_b", 3);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1, 3}));
EXPECT_EQ(allow, false);
send_packet("in_b", 4);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1}));
EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1, 3}));
EXPECT_EQ(allow, false);
send_packet("in_a", 3);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1, 3}));
EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1, 3}));
EXPECT_EQ(allow, false);
send_packet("finished", 3);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(a_passed), (std::vector<int64>{1, 3}));
EXPECT_EQ(TimestampValues(b_passed), (std::vector<int64>{1, 3}));
EXPECT_EQ(allow, true);
MEDIAPIPE_EXPECT_OK(graph_.CloseAllInputStreams());
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone());
}
TEST(RealTimeFlowLimiterCalculator, CanConsume) {
std::vector<Packet> in_sampled_packets_;
CalculatorGraphConfig graph_config_ =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in'
input_stream: 'finished'
node {
name: 'input_dropper'
calculator: 'RealTimeFlowLimiterCalculator'
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
input_stream: 'in'
input_stream: 'FINISHED:finished'
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
output_stream: 'in_sampled'
output_stream: 'ALLOW:allow'
}
)");
std::string allow_cb_name;
tool::AddVectorSink("in_sampled", &graph_config_, &in_sampled_packets_);
tool::AddCallbackCalculator("allow", &graph_config_, &allow_cb_name, true);
bool allow = true;
auto allow_cb = [&allow](const Packet& packet) {
allow = packet.Get<bool>();
};
CalculatorGraph graph_;
MEDIAPIPE_EXPECT_OK(graph_.Initialize(
graph_config_,
{
{"max_in_flight", MakePacket<int>(1)},
{allow_cb_name,
MakePacket<std::function<void(const Packet&)>>(allow_cb)},
}));
MEDIAPIPE_EXPECT_OK(graph_.StartRun({}));
auto send_packet = [&graph_](const std::string& input_name, int n) {
MEDIAPIPE_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int>(n).At(Timestamp(n))));
};
send_packet("in", 1);
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(allow, false);
EXPECT_EQ(TimestampValues(in_sampled_packets_), (std::vector<int64>{1}));
MEDIAPIPE_EXPECT_OK(in_sampled_packets_[0].Consume<int>());
MEDIAPIPE_EXPECT_OK(graph_.CloseAllInputStreams());
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone());
}
} // anonymous namespace
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
// Forwards the input packet to one of the n output streams "OUTPUT:0",
// "OUTPUT:1", ..., in round robin fashion. The index of the selected output
// stream is emitted to the output stream "SELECT". If not needed, the
// output stream "SELECT" may be omitted.
//
// Designed to run graph bottlenecks in parallel, reducing overall graph
// processing latency.
//
// A simple example config is:
//
// node {
// calculator: "RoundRobinDemuxCalculator"
// input_stream: "signal"
// output_stream: "OUTPUT:0:signal0"
// output_stream: "OUTPUT:1:signal1"
// output_stream: "SELECT:select"
// }
//
// node {
// calculator: "SlowCalculator"
// input_stream: "signal0"
// output_stream: "output0"
// }
//
// node {
// calculator: "SlowCalculator"
// input_stream: "signal1"
// output_stream: "output1"
// }
//
// node {
// calculator: "MuxCalculator"
// input_stream: "INPUT:0:output0"
// input_stream: "INPUT:1:output1"
// input_stream: "SELECT:select"
// output_stream: "OUTPUT:output"
// input_stream_handler {
// input_stream_handler: "MuxInputStreamHandler"
// }
// }
//
// which is essentially running the following configuration in parallel with a
// concurrency level of two:
//
// node {
// calculator: "SlowCalculator"
// input_stream: "signal"
// output_stream: "output"
// }
//
// If SlowCalculator has more than one output stream, the user can group the
// output with MakePairCalculator, MakeVectorCalculator, or a similar variant to
// use it with MuxCalculator and later unpack, or can create new variants of
// MuxCalculator/MuxInputStreamHandler.
class RoundRobinDemuxCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1);
cc->Inputs().Index(0).SetAny();
if (cc->Outputs().HasTag("SELECT")) {
cc->Outputs().Tag("SELECT").Set<int>();
}
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
id < cc->Outputs().EndId("OUTPUT"); ++id) {
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Index(0));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
select_output_ = cc->Outputs().GetId("SELECT", 0);
output_data_stream_index_ = 0;
output_data_stream_base_ = cc->Outputs().GetId("OUTPUT", 0);
num_output_data_streams_ = cc->Outputs().NumEntries("OUTPUT");
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
cc->Outputs()
.Get(output_data_stream_base_ + output_data_stream_index_)
.AddPacket(cc->Inputs().Index(0).Value());
if (select_output_.IsValid()) {
cc->Outputs()
.Get(select_output_)
.Add(new int(output_data_stream_index_), cc->InputTimestamp());
}
output_data_stream_index_ =
(output_data_stream_index_ + 1) % num_output_data_streams_;
return ::mediapipe::OkStatus();
}
private:
CollectionItemId select_output_;
CollectionItemId output_data_stream_base_;
int num_output_data_streams_;
int output_data_stream_index_;
};
REGISTER_CALCULATOR(RoundRobinDemuxCalculator);
} // namespace mediapipe

# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"])  # Apache 2.0
package(default_visibility = ["//visibility:private"])
exports_files(["LICENSE"])
proto_library(
name = "opencv_image_encoder_calculator_proto",
srcs = ["opencv_image_encoder_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "scale_image_calculator_proto",
srcs = ["scale_image_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
"//mediapipe/framework/formats:image_format_proto",
],
)
proto_library(
name = "set_alpha_calculator_proto",
srcs = ["set_alpha_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
name = "recolor_calculator_proto",
srcs = ["recolor_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:color_proto",
],
)
mediapipe_cc_proto_library(
name = "opencv_image_encoder_calculator_cc_proto",
srcs = ["opencv_image_encoder_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":opencv_image_encoder_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "mask_overlay_calculator_cc_proto",
srcs = ["mask_overlay_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":mask_overlay_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "scale_image_calculator_cc_proto",
srcs = ["scale_image_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/formats:image_format_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":scale_image_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "set_alpha_calculator_cc_proto",
srcs = ["set_alpha_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":set_alpha_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "recolor_calculator_cc_proto",
srcs = ["recolor_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/util:color_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":recolor_calculator_proto"],
)
cc_library(
name = "color_convert_calculator",
srcs = ["color_convert_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "opencv_encoded_image_to_image_frame_calculator",
srcs = ["opencv_encoded_image_to_image_frame_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "opencv_image_encoder_calculator",
srcs = ["opencv_image_encoder_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
":opencv_image_encoder_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "opencv_put_text_calculator",
srcs = ["opencv_put_text_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "set_alpha_calculator",
srcs = ["set_alpha_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
":set_alpha_calculator_cc_proto",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector",
] + select({
"//mediapipe:android": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}),
alwayslink = 1,
)
proto_library(
name = "image_transformation_calculator_proto",
srcs = ["image_transformation_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:scale_mode_proto",
],
)
mediapipe_cc_proto_library(
name = "image_transformation_calculator_cc_proto",
srcs = ["image_transformation_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/gpu:scale_mode_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":image_transformation_calculator_proto"],
)
cc_library(
name = "image_transformation_calculator",
srcs = ["image_transformation_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":image_transformation_calculator_cc_proto",
"//mediapipe/gpu:scale_mode_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
] + select({
"//mediapipe:android": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}),
alwayslink = 1,
)
cc_library(
name = "luminance_calculator",
srcs = ["luminance_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_simple_calculator",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:shader_util",
],
alwayslink = 1,
)
cc_library(
name = "sobel_edges_calculator",
srcs = ["sobel_edges_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_simple_calculator",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:shader_util",
],
alwayslink = 1,
)
cc_library(
name = "recolor_calculator",
srcs = ["recolor_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
":recolor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util:color_cc_proto",
] + select({
"//mediapipe:android": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}),
alwayslink = 1,
)
cc_library(
name = "scale_image_utils",
srcs = ["scale_image_utils.cc"],
hdrs = ["scale_image_utils.h"],
visibility = [
"//mediapipe:__subpackages__",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "scale_image_calculator",
srcs = ["scale_image_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
":scale_image_utils",
"//mediapipe/calculators/image:scale_image_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/formats:yuv_image",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:image_resizer",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util:image_frame_util",
"@com_google_absl//absl/strings",
"@libyuv",
],
alwayslink = 1,
)
cc_test(
name = "opencv_encoded_image_to_image_frame_calculator_test",
srcs = ["opencv_encoded_image_to_image_frame_calculator_test.cc"],
data = ["//mediapipe/calculators/image/testdata:test_images"],
deps = [
":opencv_encoded_image_to_image_frame_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)
cc_test(
name = "opencv_image_encoder_calculator_test",
srcs = ["opencv_image_encoder_calculator_test.cc"],
data = ["//mediapipe/calculators/image/testdata:test_images"],
deps = [
":opencv_image_encoder_calculator",
":opencv_image_encoder_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)
cc_test(
name = "scale_image_utils_test",
srcs = ["scale_image_utils_test.cc"],
deps = [
":scale_image_utils",
"//mediapipe/framework/port:gtest_main",
],
)
proto_library(
name = "mask_overlay_calculator_proto",
srcs = ["mask_overlay_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
cc_library(
name = "mask_overlay_calculator",
srcs = ["mask_overlay_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":mask_overlay_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:shader_util",
],
alwayslink = 1,
)

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status_builder.h"
namespace mediapipe {
namespace {
void SetColorChannel(int channel, uint8 value, cv::Mat* mat) {
CHECK_EQ(mat->depth(), CV_8U);
CHECK_LT(channel, mat->channels());
const int step = mat->channels();
for (int r = 0; r < mat->rows; ++r) {
uint8* row_ptr = mat->ptr<uint8>(r);
for (int offset = channel; offset < mat->cols * step; offset += step) {
row_ptr[offset] = value;
}
}
}
constexpr char kRgbaInTag[] = "RGBA_IN";
constexpr char kRgbInTag[] = "RGB_IN";
constexpr char kGrayInTag[] = "GRAY_IN";
constexpr char kRgbaOutTag[] = "RGBA_OUT";
constexpr char kRgbOutTag[] = "RGB_OUT";
constexpr char kGrayOutTag[] = "GRAY_OUT";
} // namespace
// A portable color conversion calculator.
//
// The following conversions are currently supported, but it's fairly easy to
// add new ones if this doesn't meet your needs. Don't forget to add a test to
// color_convert_calculator_test.cc if you do!
// RGBA -> RGB
// GRAY -> RGB
// RGB -> GRAY
// RGB -> RGBA
//
// This calculator only supports a single input stream and output stream at a
// time. If more than one input stream or output stream is present, the
// calculator will fail at GetContract.
// TODO: Remove this requirement by replacing the typed input streams
// with a single generic input and allow multiple simultaneous outputs.
//
// Input streams:
// RGBA_IN: The input video stream (ImageFrame, SRGBA).
// RGB_IN: The input video stream (ImageFrame, SRGB).
// GRAY_IN: The input video stream (ImageFrame, GRAY8).
//
// Output streams:
// RGBA_OUT: The output video stream (ImageFrame, SRGBA).
// RGB_OUT: The output video stream (ImageFrame, SRGB).
// GRAY_OUT: The output video stream (ImageFrame, GRAY8).
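//
// A minimal example node config (the stream names "rgb_frames" and
// "gray_frames" are hypothetical), converting RGB frames to grayscale:
//
// node {
//   calculator: "ColorConvertCalculator"
//   input_stream: "RGB_IN:rgb_frames"
//   output_stream: "GRAY_OUT:gray_frames"
// }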
class ColorConvertCalculator : public CalculatorBase {
public:
~ColorConvertCalculator() override = default;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
// Wrangles the appropriate inputs and outputs to perform the color
// conversion. The ImageFrame on input_tag is converted using the
// open_cv_convert_code provided and then output on the output_tag stream.
// Note that the output_format must match the destination conversion code.
::mediapipe::Status ConvertAndOutput(const std::string& input_tag,
const std::string& output_tag,
ImageFormat::Format output_format,
int open_cv_convert_code,
CalculatorContext* cc);
};
REGISTER_CALCULATOR(ColorConvertCalculator);
::mediapipe::Status ColorConvertCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is allowed.";
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is allowed.";
if (cc->Inputs().HasTag(kRgbaInTag)) {
cc->Inputs().Tag(kRgbaInTag).Set<ImageFrame>();
}
if (cc->Inputs().HasTag(kGrayInTag)) {
cc->Inputs().Tag(kGrayInTag).Set<ImageFrame>();
}
if (cc->Inputs().HasTag(kRgbInTag)) {
cc->Inputs().Tag(kRgbInTag).Set<ImageFrame>();
}
if (cc->Outputs().HasTag(kRgbOutTag)) {
cc->Outputs().Tag(kRgbOutTag).Set<ImageFrame>();
}
if (cc->Outputs().HasTag(kGrayOutTag)) {
cc->Outputs().Tag(kGrayOutTag).Set<ImageFrame>();
}
if (cc->Outputs().HasTag(kRgbaOutTag)) {
cc->Outputs().Tag(kRgbaOutTag).Set<ImageFrame>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ColorConvertCalculator::ConvertAndOutput(
const std::string& input_tag, const std::string& output_tag,
ImageFormat::Format output_format, int open_cv_convert_code,
CalculatorContext* cc) {
const cv::Mat& input_mat =
formats::MatView(&cc->Inputs().Tag(input_tag).Get<ImageFrame>());
std::unique_ptr<ImageFrame> output_frame(
new ImageFrame(output_format, input_mat.cols, input_mat.rows));
cv::Mat output_mat = formats::MatView(output_frame.get());
cv::cvtColor(input_mat, output_mat, open_cv_convert_code);
// cv::cvtColor will leave the alpha channel set to 0, which is a bizarre
// design choice. Instead, let's set alpha to 255.
if (open_cv_convert_code == cv::COLOR_RGB2RGBA) {
SetColorChannel(3, 255, &output_mat);
}
cc->Outputs()
.Tag(output_tag)
.Add(output_frame.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status ColorConvertCalculator::Process(CalculatorContext* cc) {
// RGBA -> RGB
if (cc->Inputs().HasTag(kRgbaInTag) && cc->Outputs().HasTag(kRgbOutTag)) {
return ConvertAndOutput(kRgbaInTag, kRgbOutTag, ImageFormat::SRGB,
cv::COLOR_RGBA2RGB, cc);
}
// GRAY -> RGB
if (cc->Inputs().HasTag(kGrayInTag) && cc->Outputs().HasTag(kRgbOutTag)) {
return ConvertAndOutput(kGrayInTag, kRgbOutTag, ImageFormat::SRGB,
cv::COLOR_GRAY2RGB, cc);
}
// RGB -> GRAY
if (cc->Inputs().HasTag(kRgbInTag) && cc->Outputs().HasTag(kGrayOutTag)) {
return ConvertAndOutput(kRgbInTag, kGrayOutTag, ImageFormat::GRAY8,
cv::COLOR_RGB2GRAY, cc);
}
// RGB -> RGBA
if (cc->Inputs().HasTag(kRgbInTag) && cc->Outputs().HasTag(kRgbaOutTag)) {
return ConvertAndOutput(kRgbInTag, kRgbaOutTag, ImageFormat::SRGBA,
cv::COLOR_RGB2RGBA, cc);
}
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Unsupported image format conversion.";
}
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/scale_mode.pb.h"
#if defined(__ANDROID__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_quad_renderer.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__
#if defined(__ANDROID__)
// The size of Java arrays is dynamic, which makes it difficult to
// generate the right packet type with a fixed size. Therefore, we
// are using unsized arrays on Android.
typedef int DimensionsPacketType[];
#else
typedef int DimensionsPacketType[2];
#endif // __ANDROID__
#define DEFAULT_SCALE_MODE mediapipe::ScaleMode_Mode_STRETCH
namespace mediapipe {
#if defined(__ANDROID__)
#endif // __ANDROID__
namespace {
int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) {
switch (rotation) {
case mediapipe::RotationMode_Mode_UNKNOWN:
case mediapipe::RotationMode_Mode_ROTATION_0:
return 0;
case mediapipe::RotationMode_Mode_ROTATION_90:
return 90;
case mediapipe::RotationMode_Mode_ROTATION_180:
return 180;
case mediapipe::RotationMode_Mode_ROTATION_270:
return 270;
}
// All enum values are handled above; return 0 so control cannot reach the
// end of this non-void function.
return 0;
}
mediapipe::RotationMode_Mode DegreesToRotationMode(int degrees) {
switch (degrees) {
case 0:
return mediapipe::RotationMode_Mode_ROTATION_0;
case 90:
return mediapipe::RotationMode_Mode_ROTATION_90;
case 180:
return mediapipe::RotationMode_Mode_ROTATION_180;
case 270:
return mediapipe::RotationMode_Mode_ROTATION_270;
default:
return mediapipe::RotationMode_Mode_UNKNOWN;
}
}
mediapipe::ScaleMode_Mode ParseScaleMode(
mediapipe::ScaleMode_Mode scale_mode,
mediapipe::ScaleMode_Mode default_mode) {
switch (scale_mode) {
case mediapipe::ScaleMode_Mode_DEFAULT:
return default_mode;
case mediapipe::ScaleMode_Mode_STRETCH:
return scale_mode;
case mediapipe::ScaleMode_Mode_FIT:
return scale_mode;
case mediapipe::ScaleMode_Mode_FILL_AND_CROP:
return scale_mode;
default:
return default_mode;
}
}
} // namespace
// Scales, rotates, and flips images horizontally or vertically.
//
// Input:
// One of the following two tags:
// IMAGE: ImageFrame representing the input image.
// IMAGE_GPU: GpuBuffer representing the input image.
//
// ROTATION_DEGREES (optional): The counterclockwise rotation angle in
// degrees. This allows different rotation angles for different frames. It has
// to be a multiple of 90 degrees. If provided, it overrides the
// ROTATION_DEGREES input side packet.
//
// Output:
// One of the following two tags:
// IMAGE - ImageFrame representing the output image.
// IMAGE_GPU - GpuBuffer representing the output image.
//
// LETTERBOX_PADDING (optional): An std::array<float, 4> representing the
// letterbox padding from the 4 sides ([left, top, right, bottom]) of the
// output image, normalized to [0.f, 1.f] by the output dimensions. The
// padding values are non-zero only when the scale mode specified in the
// calculator options is FIT. For instance, when the input image is 10x10
// (width x height) and the output dimensions specified in the calculator
// option are 20x40 and scale mode is FIT, the calculator scales the input
// image to 20x20 and places it in the middle of the output image with an
// equal padding of 10 pixels at the top and the bottom. The resulting array
// is therefore [0.f, 0.25f, 0.f, 0.25f] (10/40 = 0.25f).
//
// Input side packet:
// OUTPUT_DIMENSIONS (optional): The output width and height in pixels as the
// first two elements in an integer array. It overrides the corresponding
// field in the calculator options.
//
// ROTATION_DEGREES (optional): The counterclockwise rotation angle in
// degrees. It has to be a multiple of 90 degrees. It overrides the
// corresponding field in the calculator options.
//
// Calculator options (see image_transformation_calculator.proto):
// output_width, output_height - (optional) Desired scaled image size.
// rotation_mode - (optional) Rotation in multiples of 90 degrees.
// flip_vertically, flip_horizontally - (optional) flip about x or y axis.
// scale_mode - (optional) Stretch, Fit, or Fill and Crop
//
// Note: To enable horizontal or vertical flipping, specify them in the
// calculator options. Flipping is applied after rotation.
//
// Note: On CPU, only the STRETCH scale mode is currently supported, and
// flipping is not yet supported.
//
class ImageTransformationCalculator : public CalculatorBase {
public:
ImageTransformationCalculator() = default;
~ImageTransformationCalculator() override = default;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::Status RenderCpu(CalculatorContext* cc);
::mediapipe::Status RenderGpu(CalculatorContext* cc);
::mediapipe::Status GlSetup();
void ComputeOutputDimensions(int input_width, int input_height,
int* output_width, int* output_height);
void ComputeOutputLetterboxPadding(int input_width, int input_height,
int output_width, int output_height,
std::array<float, 4>* padding);
ImageTransformationCalculatorOptions options_;
int output_width_ = 0;
int output_height_ = 0;
mediapipe::RotationMode_Mode rotation_;
mediapipe::ScaleMode_Mode scale_mode_;
bool use_gpu_ = false;
#if defined(__ANDROID__)
GlCalculatorHelper helper_;
std::unique_ptr<QuadRenderer> rgb_renderer_;
std::unique_ptr<QuadRenderer> ext_rgb_renderer_;
#endif // __ANDROID__
};
REGISTER_CALCULATOR(ImageTransformationCalculator);
// static
::mediapipe::Status ImageTransformationCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU"));
RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU"));
if (cc->Inputs().HasTag("IMAGE")) {
RET_CHECK(cc->Outputs().HasTag("IMAGE"));
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
}
#if defined(__ANDROID__)
if (cc->Inputs().HasTag("IMAGE_GPU")) {
RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU"));
cc->Inputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
}
#endif // __ANDROID__
if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
}
if (cc->InputSidePackets().HasTag("OUTPUT_DIMENSIONS")) {
cc->InputSidePackets().Tag("OUTPUT_DIMENSIONS").Set<DimensionsPacketType>();
}
if (cc->InputSidePackets().HasTag("ROTATION_DEGREES")) {
cc->InputSidePackets().Tag("ROTATION_DEGREES").Set<int>();
}
if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
cc->Outputs().Tag("LETTERBOX_PADDING").Set<std::array<float, 4>>();
}
#if defined(__ANDROID__)
RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImageTransformationCalculator::Open(CalculatorContext* cc) {
// Inform the framework that we always output at the same timestamp
// as we receive a packet at.
cc->SetOffset(mediapipe::TimestampDiff(0));
options_ = cc->Options<ImageTransformationCalculatorOptions>();
if (cc->Inputs().HasTag("IMAGE_GPU")) {
use_gpu_ = true;
}
if (cc->InputSidePackets().HasTag("OUTPUT_DIMENSIONS")) {
const auto& dimensions = cc->InputSidePackets()
.Tag("OUTPUT_DIMENSIONS")
.Get<DimensionsPacketType>();
output_width_ = dimensions[0];
output_height_ = dimensions[1];
} else {
output_width_ = options_.output_width();
output_height_ = options_.output_height();
}
if (cc->InputSidePackets().HasTag("ROTATION_DEGREES")) {
rotation_ = DegreesToRotationMode(
cc->InputSidePackets().Tag("ROTATION_DEGREES").Get<int>());
} else {
rotation_ = DegreesToRotationMode(options_.rotation_mode());
}
scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE);
if (use_gpu_) {
#if defined(__ANDROID__)
// Let the helper access the GL context information.
RETURN_IF_ERROR(helper_.Open(cc));
#else
RET_CHECK_FAIL() << "GPU processing for non-Android not supported yet.";
#endif // __ANDROID__
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImageTransformationCalculator::Process(
CalculatorContext* cc) {
if (use_gpu_) {
#if defined(__ANDROID__)
return helper_.RunInGlContext(
[this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); });
#endif // __ANDROID__
} else {
return RenderCpu(cc);
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImageTransformationCalculator::Close(
CalculatorContext* cc) {
if (use_gpu_) {
#if defined(__ANDROID__)
QuadRenderer* rgb_renderer = rgb_renderer_.release();
QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release();
helper_.RunInGlContext([rgb_renderer, ext_rgb_renderer] {
if (rgb_renderer) {
rgb_renderer->GlTeardown();
delete rgb_renderer;
}
if (ext_rgb_renderer) {
ext_rgb_renderer->GlTeardown();
delete ext_rgb_renderer;
}
});
#endif // __ANDROID__
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImageTransformationCalculator::RenderCpu(
CalculatorContext* cc) {
int input_width = cc->Inputs().Tag("IMAGE").Get<ImageFrame>().Width();
int input_height = cc->Inputs().Tag("IMAGE").Get<ImageFrame>().Height();
int output_width;
int output_height;
ComputeOutputDimensions(input_width, input_height, &output_width,
&output_height);
if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
auto padding = absl::make_unique<std::array<float, 4>>();
ComputeOutputLetterboxPadding(input_width, input_height, output_width,
output_height, padding.get());
cc->Outputs()
.Tag("LETTERBOX_PADDING")
.Add(padding.release(), cc->InputTimestamp());
}
if (cc->Inputs().HasTag("ROTATION_DEGREES") &&
    !cc->Inputs().Tag("ROTATION_DEGREES").IsEmpty()) {
  rotation_ = DegreesToRotationMode(
      cc->Inputs().Tag("ROTATION_DEGREES").Get<int>());
}
const auto& input_img = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
std::unique_ptr<ImageFrame> output_frame(
new ImageFrame(input_img.Format(), output_width, output_height));
cv::Mat input_mat = formats::MatView(&input_img);
cv::Mat output_mat = formats::MatView(output_frame.get());
cv::Mat scaled_mat;
if (scale_mode_ != mediapipe::ScaleMode_Mode_STRETCH) {
// TODO finish CPU version features.
return ::mediapipe::UnimplementedError(
"Only STRETCH scale mode currently supported.");
}
cv::resize(input_mat, scaled_mat, cv::Size(output_width, output_height));
cv::Mat rotated_mat;
const int angle = RotationModeToDegrees(rotation_);
cv::Point2f src_center(scaled_mat.cols / 2.0, scaled_mat.rows / 2.0);
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
cv::warpAffine(scaled_mat, rotated_mat, rotation_mat, scaled_mat.size());
rotated_mat.copyTo(output_mat);
cc->Outputs().Tag("IMAGE").Add(output_frame.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImageTransformationCalculator::RenderGpu(
CalculatorContext* cc) {
#if defined(__ANDROID__)
int input_width = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().width();
int input_height = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().height();
int output_width;
int output_height;
ComputeOutputDimensions(input_width, input_height, &output_width,
&output_height);
if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
auto padding = absl::make_unique<std::array<float, 4>>();
ComputeOutputLetterboxPadding(input_width, input_height, output_width,
output_height, padding.get());
cc->Outputs()
.Tag("LETTERBOX_PADDING")
.Add(padding.release(), cc->InputTimestamp());
}
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>();
QuadRenderer* renderer = nullptr;
GlTexture src1;
{
src1 = helper_.CreateSourceTexture(input);
if (src1.target() == GL_TEXTURE_EXTERNAL_OES) {
if (!ext_rgb_renderer_) {
ext_rgb_renderer_ = absl::make_unique<QuadRenderer>();
RETURN_IF_ERROR(ext_rgb_renderer_->GlSetup(
::mediapipe::kBasicTexturedFragmentShaderOES, {"video_frame"}));
}
renderer = ext_rgb_renderer_.get();
} else {
if (!rgb_renderer_) {
rgb_renderer_ = absl::make_unique<QuadRenderer>();
RETURN_IF_ERROR(rgb_renderer_->GlSetup());
}
renderer = rgb_renderer_.get();
}
}
RET_CHECK(renderer) << "Unsupported input texture type";
if (cc->Inputs().HasTag("ROTATION_DEGREES") &&
    !cc->Inputs().Tag("ROTATION_DEGREES").IsEmpty()) {
  rotation_ = DegreesToRotationMode(
      cc->Inputs().Tag("ROTATION_DEGREES").Get<int>());
}
mediapipe::FrameScaleMode scale_mode = mediapipe::FrameScaleModeFromProto(
    scale_mode_, mediapipe::FrameScaleMode::kStretch);
mediapipe::FrameRotation rotation =
mediapipe::FrameRotationFromDegrees(RotationModeToDegrees(rotation_));
auto dst = helper_.CreateDestinationTexture(output_width, output_height,
input.format());
helper_.BindFramebuffer(dst); // GL_TEXTURE0
glActiveTexture(GL_TEXTURE1);
glBindTexture(src1.target(), src1.name());
RETURN_IF_ERROR(renderer->GlRender(
src1.width(), src1.height(), dst.width(), dst.height(), scale_mode,
rotation, options_.flip_horizontally(), options_.flip_vertically(),
/*flip_texture=*/false));
glActiveTexture(GL_TEXTURE1);
glBindTexture(src1.target(), 0);
// Execute GL commands, before getting result.
glFlush();
auto output = dst.GetFrame<GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp());
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
void ImageTransformationCalculator::ComputeOutputDimensions(
int input_width, int input_height, int* output_width, int* output_height) {
if (output_width_ > 0 && output_height_ > 0) {
*output_width = output_width_;
*output_height = output_height_;
} else if (rotation_ == mediapipe::RotationMode_Mode_ROTATION_90 ||
rotation_ == mediapipe::RotationMode_Mode_ROTATION_270) {
*output_width = input_height;
*output_height = input_width;
} else {
*output_width = input_width;
*output_height = input_height;
}
}
void ImageTransformationCalculator::ComputeOutputLetterboxPadding(
int input_width, int input_height, int output_width, int output_height,
std::array<float, 4>* padding) {
if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) {
if (rotation_ == mediapipe::RotationMode_Mode_ROTATION_90 ||
rotation_ == mediapipe::RotationMode_Mode_ROTATION_270) {
std::swap(input_width, input_height);
}
const float input_aspect_ratio =
static_cast<float>(input_width) / input_height;
const float output_aspect_ratio =
static_cast<float>(output_width) / output_height;
if (input_aspect_ratio < output_aspect_ratio) {
// Compute left and right padding.
(*padding)[0] = (1.f - input_aspect_ratio / output_aspect_ratio) / 2.f;
(*padding)[2] = (*padding)[0];
} else if (output_aspect_ratio < input_aspect_ratio) {
// Compute top and bottom padding.
(*padding)[1] = (1.f - output_aspect_ratio / input_aspect_ratio) / 2.f;
(*padding)[3] = (*padding)[1];
}
}
}
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto";
// Counterclockwise rotation.
message RotationMode {
enum Mode {
UNKNOWN = 0;
ROTATION_0 = 1;
ROTATION_90 = 2;
ROTATION_180 = 3;
ROTATION_270 = 4;
}
}
message ImageTransformationCalculatorOptions {
extend CalculatorOptions {
optional ImageTransformationCalculatorOptions ext = 251952830;
}
// Output dimensions. Set to 0 if they should be the same as the input.
optional int32 output_width = 1 [default = 0];
optional int32 output_height = 2 [default = 0];
// Counterclockwise rotation mode.
optional RotationMode.Mode rotation_mode = 3;
// Vertical flipping, applied after rotation.
optional bool flip_vertically = 4 [default = false];
// Horizontal flipping, applied after rotation.
optional bool flip_horizontally = 5 [default = false];
// Scale mode.
optional ScaleMode.Mode scale_mode = 6;
}

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/gl_simple_calculator.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
namespace mediapipe {
// Converts RGB images into luminance images, still stored in RGB format.
// See GlSimpleCalculatorBase for inputs, outputs and input side packets.
class LuminanceCalculator : public GlSimpleCalculator {
public:
::mediapipe::Status GlSetup() override;
::mediapipe::Status GlRender(const GlTexture& src,
const GlTexture& dst) override;
::mediapipe::Status GlTeardown() override;
private:
GLuint program_ = 0;
GLint frame_;
};
REGISTER_CALCULATOR(LuminanceCalculator);
::mediapipe::Status LuminanceCalculator::GlSetup() {
// Load vertex and fragment shaders
const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION,
};
const GLchar* attr_name[NUM_ATTRIBUTES] = {
"position",
"texture_coordinate",
};
const GLchar* frag_src = GLES_VERSION_COMPAT
R"(
#if __VERSION__ < 130
#define in varying
#endif // __VERSION__ < 130
#ifdef GL_ES
#define fragColor gl_FragColor
precision highp float;
#else
#define lowp
#define mediump
#define highp
#define texture2D texture
out vec4 fragColor;
#endif // defined(GL_ES)
in vec2 sample_coordinate;
uniform sampler2D video_frame;
const highp vec3 W = vec3(0.2125, 0.7154, 0.0721);
void main() {
vec4 color = texture2D(video_frame, sample_coordinate);
float luminance = dot(color.rgb, W);
fragColor.rgb = vec3(luminance);
fragColor.a = color.a;
}
)";
// shader program
GlhCreateProgram(kBasicVertexShader, frag_src, NUM_ATTRIBUTES,
(const GLchar**)&attr_name[0], attr_location, &program_);
RET_CHECK(program_) << "Problem initializing the program.";
frame_ = glGetUniformLocation(program_, "video_frame");
return ::mediapipe::OkStatus();
}
::mediapipe::Status LuminanceCalculator::GlRender(const GlTexture& src,
const GlTexture& dst) {
static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right
-1.0f, 1.0f, // top left
1.0f, 1.0f, // top right
};
static const GLfloat texture_vertices[] = {
0.0f, 0.0f, // bottom left
1.0f, 0.0f, // bottom right
0.0f, 1.0f, // top left
1.0f, 1.0f, // top right
};
// program
glUseProgram(program_);
glUniform1i(frame_, 1);
// vertex storage
GLuint vbo[2];
glGenBuffers(2, vbo);
GLuint vao;
glGenVertexArrays(1, &vao);
glBindVertexArray(vao);
// vbo 0
glBindBuffer(GL_ARRAY_BUFFER, vbo[0]);
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices,
GL_STATIC_DRAW);
glEnableVertexAttribArray(ATTRIB_VERTEX);
glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);
// vbo 1
glBindBuffer(GL_ARRAY_BUFFER, vbo[1]);
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices,
GL_STATIC_DRAW);
glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);
// draw
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
// cleanup
glDisableVertexAttribArray(ATTRIB_VERTEX);
glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glBindBuffer(GL_ARRAY_BUFFER, 0);
glBindVertexArray(0);
glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo);
return ::mediapipe::OkStatus();
}
::mediapipe::Status LuminanceCalculator::GlTeardown() {
if (program_) {
glDeleteProgram(program_);
program_ = 0;
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/image/mask_overlay_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
namespace mediapipe {
using ::mediapipe::MaskOverlayCalculatorOptions_MaskChannel_ALPHA;
using ::mediapipe::MaskOverlayCalculatorOptions_MaskChannel_RED;
using ::mediapipe::MaskOverlayCalculatorOptions_MaskChannel_UNKNOWN;
// Mixes two frames using a third mask frame or constant value.
//
// Inputs:
// VIDEO:[0,1] (GpuBuffer):
// Two inputs should be provided.
// MASK (GpuBuffer):
// Optional.
// Where the mask is 0, VIDEO:0 will be used. Where it is 1, VIDEO:1.
// Intermediate values will blend.
// If not specified, CONST_MASK float must be present.
// CONST_MASK (float):
// Optional.
// If not specified, MASK GpuBuffer must be present.
// Similar to MASK GpuBuffer, but applied globally to every pixel.
//
// Outputs:
// OUTPUT (GpuBuffer):
// The mix.
class MaskOverlayCalculator : public CalculatorBase {
public:
MaskOverlayCalculator() {}
~MaskOverlayCalculator() override;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status GlSetup(
const MaskOverlayCalculatorOptions::MaskChannel mask_channel);
::mediapipe::Status GlRender(const float mask_const);
private:
GlCalculatorHelper helper_;
bool initialized_ = false;
bool use_mask_tex_ = false; // Otherwise, use constant float value.
GLuint program_ = 0;
GLint unif_frame1_;
GLint unif_frame2_;
GLint unif_mask_;
};
REGISTER_CALCULATOR(MaskOverlayCalculator);
// static
::mediapipe::Status MaskOverlayCalculator::GetContract(CalculatorContract* cc) {
RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
cc->Inputs().Get("VIDEO", 0).Set<GpuBuffer>();
cc->Inputs().Get("VIDEO", 1).Set<GpuBuffer>();
if (cc->Inputs().HasTag("MASK"))
cc->Inputs().Tag("MASK").Set<GpuBuffer>();
else if (cc->Inputs().HasTag("CONST_MASK"))
cc->Inputs().Tag("CONST_MASK").Set<float>();
else
return ::mediapipe::Status(
::mediapipe::StatusCode::kNotFound,
"At least one mask input stream must be present.");
cc->Outputs().Tag("OUTPUT").Set<GpuBuffer>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status MaskOverlayCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
if (cc->Inputs().HasTag("MASK")) {
use_mask_tex_ = true;
}
return helper_.Open(cc);
}
::mediapipe::Status MaskOverlayCalculator::Process(CalculatorContext* cc) {
return helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
if (!initialized_) {
const auto& options = cc->Options<MaskOverlayCalculatorOptions>();
const auto mask_channel = options.mask_channel();
RETURN_IF_ERROR(GlSetup(mask_channel));
initialized_ = true;
}
glDisable(GL_BLEND);
const Packet& input1_packet = cc->Inputs().Get("VIDEO", 1).Value();
const Packet& mask_packet = use_mask_tex_
? cc->Inputs().Tag("MASK").Value()
: cc->Inputs().Tag("CONST_MASK").Value();
if (mask_packet.IsEmpty()) {
cc->Outputs().Tag("OUTPUT").AddPacket(input1_packet);
return ::mediapipe::OkStatus();
}
const auto& input0_buffer = cc->Inputs().Get("VIDEO", 0).Get<GpuBuffer>();
const auto& input1_buffer = input1_packet.Get<GpuBuffer>();
auto src1 = helper_.CreateSourceTexture(input0_buffer);
auto src2 = helper_.CreateSourceTexture(input1_buffer);
GlTexture mask_tex;
if (use_mask_tex_) {
const auto& mask_buffer = mask_packet.Get<GpuBuffer>();
mask_tex = helper_.CreateSourceTexture(mask_buffer);
}
auto dst = helper_.CreateDestinationTexture(src1.width(), src1.height());
helper_.BindFramebuffer(dst);
glActiveTexture(GL_TEXTURE1);
glBindTexture(src1.target(), src1.name());
glActiveTexture(GL_TEXTURE2);
glBindTexture(src2.target(), src2.name());
if (use_mask_tex_) {
const float mask_const = -1;
glActiveTexture(GL_TEXTURE3);
glBindTexture(mask_tex.target(), mask_tex.name());
RETURN_IF_ERROR(GlRender(mask_const));
glActiveTexture(GL_TEXTURE3);
glBindTexture(mask_tex.target(), 0);
} else {
const float mask_const = mask_packet.Get<float>();
RETURN_IF_ERROR(GlRender(mask_const));
}
glActiveTexture(GL_TEXTURE2);
glBindTexture(src2.target(), 0);
glActiveTexture(GL_TEXTURE1);
glBindTexture(src1.target(), 0);
glFlush();
auto output = dst.GetFrame<GpuBuffer>();
src1.Release();
src2.Release();
if (use_mask_tex_) mask_tex.Release();
dst.Release();
cc->Outputs().Tag("OUTPUT").Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
});
}
::mediapipe::Status MaskOverlayCalculator::GlSetup(
const MaskOverlayCalculatorOptions::MaskChannel mask_channel) {
// Load vertex and fragment shaders
const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION,
};
const GLchar* attr_name[NUM_ATTRIBUTES] = {
"position",
"texture_coordinate",
};
std::string mask_component;
switch (mask_channel) {
case MaskOverlayCalculatorOptions_MaskChannel_UNKNOWN:
case MaskOverlayCalculatorOptions_MaskChannel_RED:
mask_component = "r";
break;
case MaskOverlayCalculatorOptions_MaskChannel_ALPHA:
mask_component = "a";
break;
}
const std::string frag_src_tex =
std::string(kMediaPipeFragmentShaderPreamble) +
R"(
DEFAULT_PRECISION(highp, float)
in vec2 sample_coordinate;
uniform sampler2D frame1;
uniform sampler2D frame2;
uniform sampler2D mask;
void main() {
vec4 color1 = texture2D(frame1, sample_coordinate);
vec4 color2 = texture2D(frame2, sample_coordinate);
vec4 weight = texture2D(mask, sample_coordinate);
#define MASK_COMPONENT )" +
mask_component +
R"(
gl_FragColor = mix(color1, color2, weight.MASK_COMPONENT);
}
)";
const GLchar* frag_src_const = R"(
precision highp float;
varying vec2 sample_coordinate;
uniform sampler2D frame1;
uniform sampler2D frame2;
uniform float mask;
void main() {
vec4 color1 = texture2D(frame1, sample_coordinate);
vec4 color2 = texture2D(frame2, sample_coordinate);
float weight = mask;
gl_FragColor = mix(color1, color2, weight);
}
)";
// shader program
GlhCreateProgram(kBasicVertexShader,
use_mask_tex_ ? frag_src_tex.c_str() : frag_src_const,
NUM_ATTRIBUTES, &attr_name[0], attr_location, &program_);
RET_CHECK(program_) << "Problem initializing the program.";
unif_frame1_ = glGetUniformLocation(program_, "frame1");
unif_frame2_ = glGetUniformLocation(program_, "frame2");
unif_mask_ = glGetUniformLocation(program_, "mask");
return ::mediapipe::OkStatus();
}
::mediapipe::Status MaskOverlayCalculator::GlRender(const float mask_const) {
glUseProgram(program_);
glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, kBasicSquareVertices);
glEnableVertexAttribArray(ATTRIB_VERTEX);
glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0,
kBasicTextureVertices);
glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glUniform1i(unif_frame1_, 1);
glUniform1i(unif_frame2_, 2);
if (use_mask_tex_)
glUniform1i(unif_mask_, 3);
else
glUniform1f(unif_mask_, mask_const);
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
return ::mediapipe::OkStatus();
}
MaskOverlayCalculator::~MaskOverlayCalculator() {
helper_.RunInGlContext([this] {
if (program_) {
glDeleteProgram(program_);
program_ = 0;
}
});
}
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message MaskOverlayCalculatorOptions {
extend CalculatorOptions {
optional MaskOverlayCalculatorOptions ext = 252129282;
}
enum MaskChannel {
UNKNOWN = 0;
RED = 1;
ALPHA = 2;
}
// Selects which channel of the MASK input to use for masking.
optional MaskChannel mask_channel = 1 [default = RED];
}

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
namespace mediapipe {
// Takes in an encoded image std::string, decodes it with OpenCV, and converts
// it to an ImageFrame. Note that this calculator currently supports only
// grayscale and RGB images.
//
// Example config:
// node {
// calculator: "OpenCvEncodedImageToImageFrameCalculator"
// input_stream: "encoded_image"
// output_stream: "image_frame"
// }
class OpenCvEncodedImageToImageFrameCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Process(CalculatorContext* cc) override;
};
::mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::GetContract(
CalculatorContract* cc) {
cc->Inputs().Index(0).Set<std::string>();
cc->Outputs().Index(0).Set<ImageFrame>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::Process(
CalculatorContext* cc) {
const std::string& contents = cc->Inputs().Index(0).Get<std::string>();
const std::vector<char> contents_vector(contents.begin(), contents.end());
cv::Mat decoded_mat =
cv::imdecode(contents_vector, -1 /* return the loaded image as-is */);
ImageFormat::Format image_format = ImageFormat::UNKNOWN;
cv::Mat output_mat;
switch (decoded_mat.channels()) {
case 1:
image_format = ImageFormat::GRAY8;
output_mat = decoded_mat;
break;
case 3:
image_format = ImageFormat::SRGB;
cv::cvtColor(decoded_mat, output_mat, cv::COLOR_BGR2RGB);
break;
case 4:
return ::mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC)
<< "4-channel image isn't supported yet";
default:
return ::mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC)
<< "Unsupported number of channels: " << decoded_mat.channels();
}
std::unique_ptr<ImageFrame> output_frame = absl::make_unique<ImageFrame>(
image_format, decoded_mat.size().width, decoded_mat.size().height);
output_mat.copyTo(formats::MatView(output_frame.get()));
cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
REGISTER_CALCULATOR(OpenCvEncodedImageToImageFrameCalculator);
} // namespace mediapipe
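The COLOR_BGR2RGB conversion used above is just a swap of the first and third channel of every interleaved pixel. A minimal stdlib-only sketch of that swap (a hypothetical helper, not part of MediaPipe or OpenCV):

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Swaps channels 0 and 2 of an interleaved 3-channel byte buffer in place,
// which is what cv::cvtColor(src, dst, cv::COLOR_BGR2RGB) does for CV_8UC3.
void BgrToRgbInPlace(std::vector<uint8_t>& pixels) {
  for (std::size_t i = 0; i + 2 < pixels.size(); i += 3) {
    std::swap(pixels[i], pixels[i + 2]);
  }
}
```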


@ -0,0 +1,106 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
TEST(OpenCvEncodedImageToImageFrameCalculatorTest, TestRgbJpeg) {
std::string contents;
MEDIAPIPE_ASSERT_OK(file::GetContents(
file::JoinPath("./", "/mediapipe/calculators/image/testdata/dino.jpg"),
&contents));
Packet input_packet = MakePacket<std::string>(contents);
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "OpenCvEncodedImageToImageFrameCalculator"
input_stream: "encoded_image"
output_stream: "image_frame"
)");
CalculatorRunner runner(node_config);
runner.MutableInputs()->Index(0).packets.push_back(
input_packet.At(Timestamp(0)));
MEDIAPIPE_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
ASSERT_EQ(1, outputs.NumEntries());
const std::vector<Packet>& packets = outputs.Index(0).packets;
ASSERT_EQ(1, packets.size());
const ImageFrame& output_frame = packets[0].Get<ImageFrame>();
cv::Mat input_mat = cv::imread(
file::JoinPath("./", "/mediapipe/calculators/image/testdata/dino.jpg"));
cv::Mat output_mat;
cv::cvtColor(formats::MatView(&output_frame), output_mat, cv::COLOR_RGB2BGR);
cv::Mat diff;
cv::absdiff(input_mat, output_mat, diff);
double max_val;
cv::minMaxLoc(diff, nullptr, &max_val);
// Expects that the maximum absolute pixel-by-pixel difference is less
// than 10.
EXPECT_LE(max_val, 10);
}
TEST(OpenCvEncodedImageToImageFrameCalculatorTest, TestGrayscaleJpeg) {
cv::Mat input_mat;
cv::cvtColor(cv::imread(file::JoinPath("./",
"/mediapipe/calculators/"
"image/testdata/dino.jpg")),
input_mat, cv::COLOR_RGB2GRAY);
std::vector<uchar> encode_buffer;
std::vector<int> parameters;
parameters.push_back(cv::IMWRITE_JPEG_QUALITY);
parameters.push_back(100);
cv::imencode(".jpg", input_mat, encode_buffer, parameters);
Packet input_packet = MakePacket<std::string>(std::string(absl::string_view(
reinterpret_cast<const char*>(&encode_buffer[0]), encode_buffer.size())));
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "OpenCvEncodedImageToImageFrameCalculator"
input_stream: "encoded_image"
output_stream: "image_frame"
)");
CalculatorRunner runner(node_config);
runner.MutableInputs()->Index(0).packets.push_back(
input_packet.At(Timestamp(0)));
MEDIAPIPE_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
ASSERT_EQ(1, outputs.NumEntries());
const std::vector<Packet>& packets = outputs.Index(0).packets;
ASSERT_EQ(1, packets.size());
const ImageFrame& output_frame = packets[0].Get<ImageFrame>();
cv::Mat diff;
cv::absdiff(input_mat, formats::MatView(&output_frame), diff);
double max_val;
cv::minMaxLoc(diff, nullptr, &max_val);
// Expects that the maximum absolute pixel-by-pixel difference is less
// than 10.
EXPECT_LE(max_val, 10);
}
} // namespace
} // namespace mediapipe
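Both tests above bound the reconstruction error with cv::absdiff plus cv::minMaxLoc. On flat byte buffers the same check is one loop; a stdlib-only sketch (hypothetical helper, not MediaPipe API):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Returns the maximum absolute per-element difference between two equal-size
// byte buffers, the scalar equivalent of cv::absdiff + cv::minMaxLoc.
int MaxAbsDiff(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) {
  int max_val = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    max_val = std::max(max_val,
                       std::abs(static_cast<int>(a[i]) - static_cast<int>(b[i])));
  }
  return max_val;
}
```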


@ -0,0 +1,121 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
namespace mediapipe {
// Calculator to encode raw image frames. This will result in considerable space
// savings if the frames need to be stored on disk.
//
// Example config:
// node {
// calculator: "OpenCvImageEncoderCalculator"
// input_stream: "image"
// output_stream: "encoded_image"
// node_options {
// [type.googleapis.com/mediapipe.OpenCvImageEncoderCalculatorOptions]: {
// quality: 80
// }
// }
// }
class OpenCvImageEncoderCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
int encoding_quality_;
};
::mediapipe::Status OpenCvImageEncoderCalculator::GetContract(
CalculatorContract* cc) {
cc->Inputs().Index(0).Set<ImageFrame>();
cc->Outputs().Index(0).Set<OpenCvImageEncoderCalculatorResults>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status OpenCvImageEncoderCalculator::Open(CalculatorContext* cc) {
auto options = cc->Options<OpenCvImageEncoderCalculatorOptions>();
encoding_quality_ = options.quality();
return ::mediapipe::OkStatus();
}
::mediapipe::Status OpenCvImageEncoderCalculator::Process(
CalculatorContext* cc) {
const ImageFrame& image_frame = cc->Inputs().Index(0).Get<ImageFrame>();
CHECK_EQ(1, image_frame.ByteDepth());
std::unique_ptr<OpenCvImageEncoderCalculatorResults> encoded_result =
absl::make_unique<OpenCvImageEncoderCalculatorResults>();
encoded_result->set_width(image_frame.Width());
encoded_result->set_height(image_frame.Height());
cv::Mat original_mat = formats::MatView(&image_frame);
cv::Mat input_mat;
switch (original_mat.channels()) {
case 1:
input_mat = original_mat;
encoded_result->set_colorspace(
OpenCvImageEncoderCalculatorResults::GRAYSCALE);
break;
case 3:
      // OpenCV assumes images are in BGR channel order, so do a color
      // conversion before calling imencode().
cv::cvtColor(original_mat, input_mat, cv::COLOR_RGB2BGR);
encoded_result->set_colorspace(OpenCvImageEncoderCalculatorResults::RGB);
break;
case 4:
return ::mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC)
<< "4-channel image isn't supported yet";
default:
return ::mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC)
<< "Unsupported number of channels: " << original_mat.channels();
}
std::vector<int> parameters;
parameters.push_back(cv::IMWRITE_JPEG_QUALITY);
parameters.push_back(encoding_quality_);
std::vector<uchar> encode_buffer;
// Note that imencode() will store the data in RGB order.
// Check its JpegEncoder::write() in "imgcodecs/src/grfmt_jpeg.cpp" for more
// info.
if (!cv::imencode(".jpg", input_mat, encode_buffer, parameters)) {
return ::mediapipe::InternalErrorBuilder(MEDIAPIPE_LOC)
           << "Failed to encode the image in JPEG format.";
}
encoded_result->set_encoded_image(std::string(absl::string_view(
reinterpret_cast<const char*>(&encode_buffer[0]), encode_buffer.size())));
cc->Outputs().Index(0).Add(encoded_result.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status OpenCvImageEncoderCalculator::Close(CalculatorContext* cc) {
return ::mediapipe::OkStatus();
}
REGISTER_CALCULATOR(OpenCvImageEncoderCalculator);
} // namespace mediapipe
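The encoded_image field above is filled by reinterpreting the uchar encode buffer as chars. The same byte-preserving copy without absl, as a sketch (hypothetical helper):

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Copies a raw byte buffer into a std::string without interpreting the
// bytes, matching the absl::string_view construction in the calculator.
std::string BytesToString(const std::vector<uint8_t>& buf) {
  return std::string(reinterpret_cast<const char*>(buf.data()), buf.size());
}
```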


@ -0,0 +1,47 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message OpenCvImageEncoderCalculatorOptions {
extend CalculatorOptions {
optional OpenCvImageEncoderCalculatorOptions ext = 227563646;
}
  // Quality of the encoding. An integer in the range (0, 100].
optional int32 quality = 1;
}
// TODO: Consider renaming it to EncodedImage.
message OpenCvImageEncoderCalculatorResults {
// Encoded image
optional string encoded_image = 1;
// Dimensions of the encoded image
optional int32 height = 2;
optional int32 width = 3;
enum ColorSpace {
UNKNOWN = 0;
GRAYSCALE = 1;
RGB = 2;
}
// Color space used.
optional ColorSpace colorspace = 4;
}


@ -0,0 +1,87 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
TEST(OpenCvImageEncoderCalculatorTest, TestJpegWithQualities) {
cv::Mat input_mat;
cv::cvtColor(cv::imread(file::JoinPath("./",
"/mediapipe/calculators/"
"image/testdata/dino.jpg")),
input_mat, cv::COLOR_BGR2RGB);
Packet input_packet = MakePacket<ImageFrame>(
ImageFormat::SRGB, input_mat.size().width, input_mat.size().height);
input_mat.copyTo(formats::MatView(&(input_packet.Get<ImageFrame>())));
std::vector<int> qualities = {50, 80};
for (int quality : qualities) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
absl::Substitute(R"(
calculator: "OpenCvImageEncoderCalculator"
input_stream: "image_frames"
output_stream: "encoded_images"
node_options {
[type.googleapis.com/mediapipe.OpenCvImageEncoderCalculatorOptions]: {
quality: $0
}
})",
quality));
CalculatorRunner runner(node_config);
runner.MutableInputs()->Index(0).packets.push_back(
input_packet.At(Timestamp(0)));
MEDIAPIPE_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
ASSERT_EQ(1, outputs.NumEntries());
const std::vector<Packet>& packets = outputs.Index(0).packets;
ASSERT_EQ(1, packets.size());
const auto& result = packets[0].Get<OpenCvImageEncoderCalculatorResults>();
ASSERT_EQ(input_mat.size().height, result.height());
ASSERT_EQ(input_mat.size().width, result.width());
ASSERT_EQ(OpenCvImageEncoderCalculatorResults::RGB, result.colorspace());
cv::Mat expected_output = cv::imread(
file::JoinPath("./", absl::Substitute("/mediapipe/calculators/image/"
"testdata/dino_quality_$0.jpg",
quality)));
const std::vector<char> contents_vector(result.encoded_image().begin(),
result.encoded_image().end());
cv::Mat decoded_output =
cv::imdecode(contents_vector, -1 /* return the loaded image as-is */);
cv::Mat diff;
cv::absdiff(expected_output, decoded_output, diff);
double max_val;
cv::minMaxLoc(diff, nullptr, &max_val);
// Expects that the maximum absolute pixel-by-pixel difference is less
// than 10.
EXPECT_LE(max_val, 10);
}
}
} // namespace
} // namespace mediapipe


@ -0,0 +1,60 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
namespace mediapipe {
// Takes in a std::string, renders it with cv::putText(), and outputs an
// ImageFrame.
//
// Example config:
// node {
// calculator: "OpenCvPutTextCalculator"
// input_stream: "text_to_put"
// output_stream: "out_image_frames"
// }
// TODO: Generalize the calculator for other text use cases.
class OpenCvPutTextCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Process(CalculatorContext* cc) override;
};
::mediapipe::Status OpenCvPutTextCalculator::GetContract(
CalculatorContract* cc) {
cc->Inputs().Index(0).Set<std::string>();
cc->Outputs().Index(0).Set<ImageFrame>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) {
const std::string& text_content = cc->Inputs().Index(0).Get<std::string>();
cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC3);
cv::putText(mat, text_content, cv::Point(15, 70), cv::FONT_HERSHEY_PLAIN, 3,
cv::Scalar(255, 255, 0), 4);
std::unique_ptr<ImageFrame> output_frame = absl::make_unique<ImageFrame>(
ImageFormat::SRGB, mat.size().width, mat.size().height);
mat.copyTo(formats::MatView(output_frame.get()));
cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
REGISTER_CALCULATOR(OpenCvPutTextCalculator);
} // namespace mediapipe


@ -0,0 +1,377 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/calculators/image/recolor_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/color.pb.h"
#if defined(__ANDROID__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__
namespace {
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
} // namespace
namespace mediapipe {
// A calculator to recolor a masked area of an image to a specified color.
//
// A mask image is used to specify where to overlay a user-defined color.
// The luminance of the input image is used to adjust the blending weight,
// to help preserve image textures.
//
// TODO implement cpu support.
//
// Inputs:
// One of the following IMAGE tags:
// IMAGE: An ImageFrame input image, RGB or RGBA.
// IMAGE_GPU: A GpuBuffer input image, RGBA.
// One of the following MASK tags:
// MASK: An ImageFrame input mask, Gray, RGB or RGBA.
// MASK_GPU: A GpuBuffer input mask, RGBA.
// Output:
// One of the following IMAGE tags:
// IMAGE: An ImageFrame output image.
// IMAGE_GPU: A GpuBuffer output image.
//
// Options:
//   color (required): The RGB color to blend in, with values in [0, 255].
//   mask_channel (optional): Which channel of the mask image is used [RED or
//       ALPHA].
//
// Usage example:
// node {
// calculator: "RecolorCalculator"
// input_stream: "IMAGE_GPU:input_image"
// input_stream: "MASK_GPU:input_mask"
// output_stream: "IMAGE_GPU:output_image"
// node_options: {
// [mediapipe.RecolorCalculatorOptions] {
// color { r: 0 g: 0 b: 255 }
// mask_channel: RED
// }
// }
// }
//
class RecolorCalculator : public CalculatorBase {
public:
RecolorCalculator() = default;
~RecolorCalculator() override = default;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::Status LoadOptions(CalculatorContext* cc);
::mediapipe::Status InitGpu(CalculatorContext* cc);
::mediapipe::Status RenderGpu(CalculatorContext* cc);
::mediapipe::Status RenderCpu(CalculatorContext* cc);
void GlRender();
bool initialized_ = false;
std::vector<float> color_;
mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_;
bool use_gpu_ = false;
#if defined(__ANDROID__)
mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0;
#endif // __ANDROID__
};
REGISTER_CALCULATOR(RecolorCalculator);
// static
::mediapipe::Status RecolorCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty());
#if defined(__ANDROID__)
if (cc->Inputs().HasTag("IMAGE_GPU")) {
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
if (cc->Inputs().HasTag("IMAGE")) {
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
}
#if defined(__ANDROID__)
if (cc->Inputs().HasTag("MASK_GPU")) {
cc->Inputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
if (cc->Inputs().HasTag("MASK")) {
cc->Inputs().Tag("MASK").Set<ImageFrame>();
}
#if defined(__ANDROID__)
if (cc->Outputs().HasTag("IMAGE_GPU")) {
cc->Outputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
if (cc->Outputs().HasTag("IMAGE")) {
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
}
#if defined(__ANDROID__)
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
::mediapipe::Status RecolorCalculator::Open(CalculatorContext* cc) {
if (cc->Inputs().HasTag("IMAGE_GPU")) {
use_gpu_ = true;
#if defined(__ANDROID__)
RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif // __ANDROID__
}
RETURN_IF_ERROR(LoadOptions(cc));
return ::mediapipe::OkStatus();
}
::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) {
#if defined(__ANDROID__)
RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
if (!initialized_) {
RETURN_IF_ERROR(InitGpu(cc));
initialized_ = true;
}
RETURN_IF_ERROR(RenderGpu(cc));
return ::mediapipe::OkStatus();
}));
#endif // __ANDROID__
} else {
RETURN_IF_ERROR(RenderCpu(cc));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__)
gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_);
program_ = 0;
});
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
::mediapipe::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
return ::mediapipe::UnimplementedError("CPU support is not implemented yet.");
}
::mediapipe::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
if (cc->Inputs().Tag("MASK_GPU").IsEmpty()) {
return ::mediapipe::OkStatus();
}
#if defined(__ANDROID__)
// Get inputs and setup output.
const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value();
const Packet& mask_packet = cc->Inputs().Tag("MASK_GPU").Value();
const auto& input_buffer = input_packet.Get<mediapipe::GpuBuffer>();
const auto& mask_buffer = mask_packet.Get<mediapipe::GpuBuffer>();
auto img_tex = gpu_helper_.CreateSourceTexture(input_buffer);
auto mask_tex = gpu_helper_.CreateSourceTexture(mask_buffer);
auto dst_tex =
gpu_helper_.CreateDestinationTexture(img_tex.width(), img_tex.height());
// Run recolor shader on GPU.
{
gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0
glActiveTexture(GL_TEXTURE1);
glBindTexture(img_tex.target(), img_tex.name());
glActiveTexture(GL_TEXTURE2);
glBindTexture(mask_tex.target(), mask_tex.name());
GlRender();
glBindTexture(GL_TEXTURE_2D, 0);
glFlush();
}
// Send result image in GPU packet.
auto output = dst_tex.GetFrame<mediapipe::GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp());
// Cleanup
img_tex.Release();
mask_tex.Release();
dst_tex.Release();
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
void RecolorCalculator::GlRender() {
#if defined(__ANDROID__)
static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right
-1.0f, 1.0f, // top left
1.0f, 1.0f, // top right
};
static const GLfloat texture_vertices[] = {
0.0f, 0.0f, // bottom left
1.0f, 0.0f, // bottom right
0.0f, 1.0f, // top left
1.0f, 1.0f, // top right
};
// program
glUseProgram(program_);
// vertex storage
GLuint vbo[2];
glGenBuffers(2, vbo);
GLuint vao;
glGenVertexArrays(1, &vao);
glBindVertexArray(vao);
// vbo 0
glBindBuffer(GL_ARRAY_BUFFER, vbo[0]);
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices,
GL_STATIC_DRAW);
glEnableVertexAttribArray(ATTRIB_VERTEX);
glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);
// vbo 1
glBindBuffer(GL_ARRAY_BUFFER, vbo[1]);
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices,
GL_STATIC_DRAW);
glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);
// draw
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
// cleanup
glDisableVertexAttribArray(ATTRIB_VERTEX);
glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glBindBuffer(GL_ARRAY_BUFFER, 0);
glBindVertexArray(0);
glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo);
#endif // __ANDROID__
}
::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) {
const auto& options = cc->Options<mediapipe::RecolorCalculatorOptions>();
mask_channel_ = options.mask_channel();
if (!options.has_color()) RET_CHECK_FAIL() << "Missing color option.";
color_.push_back(options.color().r() / 255.0);
color_.push_back(options.color().g() / 255.0);
color_.push_back(options.color().b() / 255.0);
return ::mediapipe::OkStatus();
}
::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) {
#if defined(__ANDROID__)
const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION,
};
const GLchar* attr_name[NUM_ATTRIBUTES] = {
"position",
"texture_coordinate",
};
std::string mask_component;
switch (mask_channel_) {
case mediapipe::RecolorCalculatorOptions_MaskChannel_UNKNOWN:
case mediapipe::RecolorCalculatorOptions_MaskChannel_RED:
mask_component = "r";
break;
case mediapipe::RecolorCalculatorOptions_MaskChannel_ALPHA:
mask_component = "a";
break;
}
// A shader to blend a color onto an image where the mask > 0.
// The blending is based on the input image luminosity.
const std::string frag_src = R"(
#if __VERSION__ < 130
#define in varying
#endif // __VERSION__ < 130
#ifdef GL_ES
#define fragColor gl_FragColor
precision highp float;
#else
#define lowp
#define mediump
#define highp
#define texture2D texture
out vec4 fragColor;
#endif // defined(GL_ES)
#define MASK_COMPONENT )" + mask_component +
R"(
in vec2 sample_coordinate;
uniform sampler2D frame;
uniform sampler2D mask;
uniform vec3 recolor;
void main() {
vec4 weight = texture2D(mask, sample_coordinate);
vec4 color1 = texture2D(frame, sample_coordinate);
vec4 color2 = vec4(recolor, 1.0);
float luminance = dot(color1.rgb, vec3(0.299, 0.587, 0.114));
float mix_value = weight.MASK_COMPONENT * luminance;
fragColor = mix(color1, color2, mix_value);
}
)";
// shader program and params
mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src.c_str(),
NUM_ATTRIBUTES, &attr_name[0], attr_location,
&program_);
RET_CHECK(program_) << "Problem initializing the program.";
glUseProgram(program_);
glUniform1i(glGetUniformLocation(program_, "frame"), 1);
glUniform1i(glGetUniformLocation(program_, "mask"), 2);
glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1],
color_[2]);
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
} // namespace mediapipe
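The fragment shader's mix can be sanity-checked on the CPU. A stdlib-only sketch of the per-pixel math, using the same Rec. 601 luma weights as the shader (hypothetical helper; colors in [0, 1]):

```cpp
#include <array>

// Blends `recolor` over `src` with weight mask * luminance(src), mirroring
// the shader's fragColor = mix(color1, color2, weight * luminance).
std::array<float, 3> RecolorPixel(const std::array<float, 3>& src,
                                  const std::array<float, 3>& recolor,
                                  float mask) {
  const float luminance = 0.299f * src[0] + 0.587f * src[1] + 0.114f * src[2];
  const float m = mask * luminance;
  return {src[0] + (recolor[0] - src[0]) * m,
          src[1] + (recolor[1] - src[1]) * m,
          src[2] + (recolor[2] - src[2]) * m};
}
```

With mask = 0 the pixel is unchanged; a white pixel under a full mask is replaced by the recolor value.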


@ -0,0 +1,39 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";
message RecolorCalculatorOptions {
extend CalculatorOptions {
optional RecolorCalculatorOptions ext = 252527117;
}
enum MaskChannel {
UNKNOWN = 0;
RED = 1;
ALPHA = 2;
}
// Selects which channel of the MASK input to use for masking.
optional MaskChannel mask_channel = 1 [default = RED];
// Color to blend into input image where mask is > 0.
// The blending is based on the input image luminosity.
optional Color color = 2;
}


@ -0,0 +1,693 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This Calculator takes an ImageFrame and scales it appropriately.
#include <algorithm>
#include <memory>
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "libyuv/scale.h"
#include "mediapipe/calculators/image/scale_image_calculator.pb.h"
#include "mediapipe/calculators/image/scale_image_utils.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/formats/yuv_image.h"
#include "mediapipe/framework/port/image_resizer.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/image_frame_util.h"
namespace mediapipe {
namespace {
// Given an upscaling algorithm, determine which OpenCV interpolation algorithm
// to use.
::mediapipe::Status FindInterpolationAlgorithm(
ScaleImageCalculatorOptions::ScaleAlgorithm upscaling_algorithm,
int* interpolation_algorithm) {
switch (upscaling_algorithm) {
case ScaleImageCalculatorOptions::DEFAULT:
*interpolation_algorithm = cv::INTER_CUBIC;
break;
case ScaleImageCalculatorOptions::LINEAR:
*interpolation_algorithm = cv::INTER_LINEAR;
break;
case ScaleImageCalculatorOptions::CUBIC:
*interpolation_algorithm = cv::INTER_CUBIC;
break;
case ScaleImageCalculatorOptions::AREA:
*interpolation_algorithm = cv::INTER_AREA;
break;
case ScaleImageCalculatorOptions::LANCZOS:
*interpolation_algorithm = cv::INTER_LANCZOS4;
break;
case ScaleImageCalculatorOptions::DEFAULT_WITHOUT_UPSCALE:
*interpolation_algorithm = -1;
break;
default:
RET_CHECK_FAIL() << absl::Substitute("Unknown upscaling algorithm: $0",
upscaling_algorithm);
}
return ::mediapipe::OkStatus();
}
void CropImageFrame(const ImageFrame& original, int col_start, int row_start,
int crop_width, int crop_height, ImageFrame* cropped) {
const uint8* src = original.PixelData();
uint8* dst = cropped->MutablePixelData();
int des_y = 0;
for (int y = row_start; y < row_start + crop_height; ++y) {
const uint8* src_line = src + y * original.WidthStep();
const uint8* src_pixel = src_line + col_start *
original.NumberOfChannels() *
original.ByteDepth();
uint8* dst_line = dst + des_y * cropped->WidthStep();
std::memcpy(
dst_line, src_pixel,
crop_width * cropped->NumberOfChannels() * cropped->ByteDepth());
++des_y;
}
}
} // namespace
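CropImageFrame above copies one contiguous row segment per scanline. The same pattern for a single-channel, row-major byte image, as a stdlib-only sketch (hypothetical helper):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Crops a crop_w x crop_h window whose top-left corner is at (col, row) from
// a width-pixel-wide grayscale image, using one memcpy per output row.
std::vector<uint8_t> CropGray(const std::vector<uint8_t>& src, int width,
                              int col, int row, int crop_w, int crop_h) {
  std::vector<uint8_t> dst(crop_w * crop_h);
  for (int y = 0; y < crop_h; ++y) {
    std::memcpy(&dst[y * crop_w], &src[(row + y) * width + col], crop_w);
  }
  return dst;
}
```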
// Crops and scales an ImageFrame or YUVImage according to the options. The
// output can be a cropped and scaled ImageFrame in SRGB format. If the input
// is a YUVImage, the output can be a scaled YUVImage (the scaling is done
// using libyuv). Cropping is not yet supported for a YUVImage to a scaled
// YUVImage conversion.
//
// Example config:
// node {
// calculator: "ScaleImageCalculator"
// input_stream: "raw_frames"
// output_stream: "scaled_frames"
// node_options {
// [type.googleapis.com/mediapipe.ScaleImageCalculatorOptions] {
// target_width: 320
// target_height: 320
// preserve_aspect_ratio: true
// output_format: SRGB
// algorithm: DEFAULT
// }
// }
// }
//
// ScaleImageCalculator can also create or update a VideoHeader that is
// provided at Timestamp::PreStream on stream VIDEO_HEADER.
//
// Example config:
// node {
// calculator: "ScaleImageCalculator"
// input_stream: "FRAMES:ycbcr_frames"
// input_stream: "VIDEO_HEADER:ycbcr_frames_header" # Optional.
// output_stream: "FRAMES:srgb_frames"
//   output_stream: "VIDEO_HEADER:srgb_frames_header" # Optional; independent of the input header.
// node_options {
// [type.googleapis.com/mediapipe.ScaleImageCalculatorOptions] {
// target_width: 320
// target_height: 320
// preserve_aspect_ratio: true
// output_format: SRGB
// algorithm: DEFAULT
// }
// }
// }
//
// The calculator options can be overridden with an input stream
// "OVERRIDE_OPTIONS". If this stream is provided, its packet must be
// non-empty at PreStream; the calculator options proto is merged with the
// proto in that packet (set fields overwrite the original options), and
// initialization happens in Process at PreStream rather than in Open.
class ScaleImageCalculator : public CalculatorBase {
public:
ScaleImageCalculator();
~ScaleImageCalculator() override;
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
ScaleImageCalculatorOptions options =
cc->Options<ScaleImageCalculatorOptions>();
CollectionItemId input_data_id = cc->Inputs().GetId("FRAMES", 0);
if (!input_data_id.IsValid()) {
input_data_id = cc->Inputs().GetId("", 0);
}
CollectionItemId output_data_id = cc->Outputs().GetId("FRAMES", 0);
if (!output_data_id.IsValid()) {
output_data_id = cc->Outputs().GetId("", 0);
}
if (cc->Inputs().HasTag("VIDEO_HEADER")) {
cc->Inputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
}
if (options.has_input_format() &&
options.input_format() == ImageFormat::YCBCR420P) {
cc->Inputs().Get(input_data_id).Set<YUVImage>();
} else {
cc->Inputs().Get(input_data_id).Set<ImageFrame>();
}
if (cc->Outputs().HasTag("VIDEO_HEADER")) {
cc->Outputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
}
if (options.has_output_format() &&
options.output_format() == ImageFormat::YCBCR420P) {
RET_CHECK_EQ(ImageFormat::YCBCR420P, options.input_format());
cc->Outputs().Get(output_data_id).Set<YUVImage>();
} else {
cc->Outputs().Get(output_data_id).Set<ImageFrame>();
}
if (cc->Inputs().HasTag("OVERRIDE_OPTIONS")) {
cc->Inputs().Tag("OVERRIDE_OPTIONS").Set<ScaleImageCalculatorOptions>();
}
return ::mediapipe::OkStatus();
}
// From Calculator.
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
// Initialize some data members from options_. This can be called either from
// Open or Process depending on whether OVERRIDE_OPTIONS is used.
::mediapipe::Status InitializeFromOptions();
// Initialize crop and output parameters based on set member variable
// values. This function will also send the header information on
// the VIDEO_HEADER stream if it hasn't been done yet.
::mediapipe::Status InitializeFrameInfo(CalculatorContext* cc);
// Validate that input_format_ and output_format_ are supported image
// formats.
::mediapipe::Status ValidateImageFormats() const;
// Validate that the image frame has the proper format and dimensions.
// If the dimensions and format weren't initialized by the header,
// then the first frame on which this function is called is used
// to initialize.
::mediapipe::Status ValidateImageFrame(CalculatorContext* cc,
const ImageFrame& image_frame);
// Validate that the YUV image has the proper dimensions. If the
// dimensions weren't initialized by the header, then the first image
// on which this function is called is used to initialize.
::mediapipe::Status ValidateYUVImage(CalculatorContext* cc,
const YUVImage& yuv_image);
bool has_header_; // True if the input stream has a header.
int input_width_;
int input_height_;
int crop_width_;
int crop_height_;
int col_start_;
int row_start_;
int output_width_;
int output_height_;
ImageFormat::Format input_format_;
ImageFormat::Format output_format_;
int interpolation_algorithm_;
// The "FRAMES" input stream (or the first untagged input stream).
CollectionItemId input_data_id_;
// The "FRAMES" output stream (or the first untagged output stream).
CollectionItemId output_data_id_;
VideoHeader input_video_header_;
// Whether the header information was sent on the VIDEO_HEADER stream.
bool header_sent_ = false;
// The alignment boundary that newly created images should have.
int alignment_boundary_;
ScaleImageCalculatorOptions options_;
// Efficient image resizer with gamma correction and optional sharpening.
std::unique_ptr<ImageResizer> downscaler_;
};
REGISTER_CALCULATOR(ScaleImageCalculator);
ScaleImageCalculator::ScaleImageCalculator() {}
ScaleImageCalculator::~ScaleImageCalculator() {}
::mediapipe::Status ScaleImageCalculator::InitializeFrameInfo(
CalculatorContext* cc) {
RETURN_IF_ERROR(
scale_image::FindCropDimensions(input_width_, input_height_, //
options_.min_aspect_ratio(), //
options_.max_aspect_ratio(), //
&crop_width_, &crop_height_, //
&col_start_, &row_start_));
RETURN_IF_ERROR(
scale_image::FindOutputDimensions(crop_width_, crop_height_, //
options_.target_width(), //
options_.target_height(), //
options_.preserve_aspect_ratio(), //
options_.scale_to_multiple_of_two(), //
&output_width_, &output_height_));
RETURN_IF_ERROR(FindInterpolationAlgorithm(options_.algorithm(),
&interpolation_algorithm_));
if (interpolation_algorithm_ == -1 &&
(output_width_ > crop_width_ || output_height_ > crop_height_)) {
output_width_ = crop_width_;
output_height_ = crop_height_;
}
VLOG(1) << "Image scaling parameters:"
<< "\ninput_width_ " << input_width_ //
<< "\ninput_height_ " << input_height_ //
<< "\ninput_format_ " << input_format_ //
<< "\ncrop_width_ " << crop_width_ //
<< "\ncrop_height_ " << crop_height_ //
<< "\ncol_start_ " << col_start_ //
<< "\nrow_start_ " << row_start_ //
<< "\noutput_width_ " << output_width_ //
<< "\noutput_height_ " << output_height_ //
<< "\noutput_format_ " << output_format_ //
<< "\nOpenCV interpolation algorithm " << interpolation_algorithm_;
if (!header_sent_ && cc->Outputs().UsesTags() &&
cc->Outputs().HasTag("VIDEO_HEADER")) {
header_sent_ = true;
auto header = absl::make_unique<VideoHeader>();
*header = input_video_header_;
header->width = output_width_;
header->height = output_height_;
header->format = output_format_;
LOG(INFO) << "OUTPUTTING HEADER on stream";
cc->Outputs()
.Tag("VIDEO_HEADER")
.Add(header.release(), Timestamp::PreStream());
cc->Outputs().Tag("VIDEO_HEADER").Close();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ScaleImageCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<ScaleImageCalculatorOptions>();
input_data_id_ = cc->Inputs().GetId("FRAMES", 0);
if (!input_data_id_.IsValid()) {
input_data_id_ = cc->Inputs().GetId("", 0);
}
output_data_id_ = cc->Outputs().GetId("FRAMES", 0);
if (!output_data_id_.IsValid()) {
output_data_id_ = cc->Outputs().GetId("", 0);
}
// The output packets are at the same timestamp as the input.
cc->Outputs().Get(output_data_id_).SetOffset(mediapipe::TimestampDiff(0));
has_header_ = false;
input_width_ = 0;
input_height_ = 0;
crop_width_ = 0;
crop_height_ = 0;
output_width_ = 0;
output_height_ = 0;
bool has_override_options = cc->Inputs().HasTag("OVERRIDE_OPTIONS");
if (!has_override_options) {
RETURN_IF_ERROR(InitializeFromOptions());
}
if (!cc->Inputs().Get(input_data_id_).Header().IsEmpty()) {
// If the input stream has a header then our output stream also has a
// header.
if (has_override_options) {
// It's not possible to use OVERRIDE_OPTIONS when the main input stream
// has a header. At this point in the code, the ScaleImageCalculator
// config may be changed by the new options at PreStream, so the output
// header can't be determined.
return ::mediapipe::InvalidArgumentError(
"OVERRIDE_OPTIONS stream can't be used when the main input stream "
"has a header.");
}
input_video_header_ =
cc->Inputs().Get(input_data_id_).Header().Get<VideoHeader>();
input_format_ = input_video_header_.format;
if (options_.has_input_format()) {
RET_CHECK_EQ(input_format_, options_.input_format())
<< "The input header format does not match the input_format option.";
}
input_width_ = input_video_header_.width;
input_height_ = input_video_header_.height;
if (options_.has_output_format()) {
output_format_ = options_.output_format();
} else {
output_format_ = input_format_;
}
if (output_format_ == ImageFormat::YCBCR420P) {
RET_CHECK(options_.scale_to_multiple_of_two())
<< "ScaleImageCalculator always outputs width and height that are "
"divisible by 2 when output format is YCbCr420P. To scale to "
"width and height of odd numbers, the output format must be SRGB.";
} else if (options_.preserve_aspect_ratio()) {
RET_CHECK(options_.scale_to_multiple_of_two())
<< "ScaleImageCalculator always outputs width and height that are "
"divisible by 2 when preserving aspect ratio. To scale to width "
"and height of odd numbers, please set "
"preserve_aspect_ratio to false.";
}
if (input_width_ > 0 && input_height_ > 0 &&
input_format_ != ImageFormat::UNKNOWN &&
output_format_ != ImageFormat::UNKNOWN) {
RETURN_IF_ERROR(ValidateImageFormats());
RETURN_IF_ERROR(InitializeFrameInfo(cc));
std::unique_ptr<VideoHeader> output_header(new VideoHeader());
*output_header = input_video_header_;
output_header->format = output_format_;
output_header->width = output_width_;
output_header->height = output_height_;
cc->Outputs()
.Get(output_data_id_)
.SetHeader(Adopt(output_header.release()));
has_header_ = true;
} else {
LOG(WARNING) << "Stream had a VideoHeader which didn't have sufficient "
"information. "
"Dropping VideoHeader and trying to deduce needed "
"information.";
input_width_ = 0;
input_height_ = 0;
if (!options_.has_input_format()) {
input_format_ = ImageFormat::UNKNOWN;
}
output_format_ = ImageFormat::UNKNOWN;
}
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ScaleImageCalculator::InitializeFromOptions() {
if (options_.has_input_format()) {
input_format_ = options_.input_format();
} else {
input_format_ = ImageFormat::UNKNOWN;
}
alignment_boundary_ = 16;
if (options_.alignment_boundary() > 0) {
alignment_boundary_ = options_.alignment_boundary();
}
downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient()));
return ::mediapipe::OkStatus();
}
::mediapipe::Status ScaleImageCalculator::ValidateImageFormats() const {
RET_CHECK_NE(input_format_, ImageFormat::UNKNOWN)
<< "The input image format was UNKNOWN.";
RET_CHECK_NE(output_format_, ImageFormat::UNKNOWN)
<< "The output image format was set to UNKNOWN.";
// TODO Remove these conditions.
RET_CHECK(output_format_ == ImageFormat::SRGB ||
(input_format_ == output_format_ &&
output_format_ == ImageFormat::YCBCR420P))
<< "Outputting YCbCr420P images from SRGB input is not yet supported";
RET_CHECK(input_format_ == output_format_ ||
input_format_ == ImageFormat::YCBCR420P)
<< "Conversion of the color space (except from "
"YCbCr420P to SRGB) is not yet supported.";
return ::mediapipe::OkStatus();
}
::mediapipe::Status ScaleImageCalculator::ValidateImageFrame(
CalculatorContext* cc, const ImageFrame& image_frame) {
if (!has_header_) {
if (input_width_ != image_frame.Width() ||
input_height_ != image_frame.Height() ||
input_format_ != image_frame.Format()) {
// Set the dimensions based on the image frame. There was no header.
input_width_ = image_frame.Width();
input_height_ = image_frame.Height();
RET_CHECK(input_width_ > 0 && input_height_ > 0) << absl::StrCat(
"The input image did not have positive dimensions. dimensions: ",
input_width_, "x", input_height_);
input_format_ = image_frame.Format();
if (options_.has_input_format()) {
RET_CHECK_EQ(input_format_, options_.input_format())
<< "The input image format does not match the input_format option.";
}
if (options_.has_output_format()) {
output_format_ = options_.output_format();
} else {
output_format_ = input_format_;
}
RETURN_IF_ERROR(InitializeFrameInfo(cc));
}
RETURN_IF_ERROR(ValidateImageFormats());
} else {
if (input_width_ != image_frame.Width() ||
input_height_ != image_frame.Height()) {
return tool::StatusFail(absl::StrCat(
"If a header specifies a width and a height, then image frames on "
"the stream must have that size. Received frame of size ",
image_frame.Width(), "x", image_frame.Height(), " but expected ",
input_width_, "x", input_height_));
}
if (input_format_ != image_frame.Format()) {
const proto_ns::EnumDescriptor* desc = ImageFormat::Format_descriptor();
return tool::StatusFail(absl::StrCat(
"If a header specifies a format, then image frames on "
"the stream must have that format. Actual format ",
desc->FindValueByNumber(image_frame.Format())->DebugString(),
" but expected ",
desc->FindValueByNumber(input_format_)->DebugString()));
}
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ScaleImageCalculator::ValidateYUVImage(
CalculatorContext* cc, const YUVImage& yuv_image) {
CHECK_EQ(input_format_, ImageFormat::YCBCR420P);
if (!has_header_) {
if (input_width_ != yuv_image.width() ||
input_height_ != yuv_image.height()) {
// Set the dimensions based on the YUV image. There was no header.
input_width_ = yuv_image.width();
input_height_ = yuv_image.height();
RET_CHECK(input_width_ > 0 && input_height_ > 0) << absl::StrCat(
"The input image did not have positive dimensions. dimensions: ",
input_width_, "x", input_height_);
if (options_.has_output_format()) {
output_format_ = options_.output_format();
} else {
output_format_ = input_format_;
}
RETURN_IF_ERROR(InitializeFrameInfo(cc));
}
RETURN_IF_ERROR(ValidateImageFormats());
} else {
if (input_width_ != yuv_image.width() ||
input_height_ != yuv_image.height()) {
return tool::StatusFail(absl::StrCat(
"If a header specifies a width and a height, then YUV images on "
"the stream must have that size. Additionally, all YUV images in "
"a stream must have the same size. Received frame of size ",
yuv_image.width(), "x", yuv_image.height(), " but expected ",
input_width_, "x", input_height_));
}
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ScaleImageCalculator::Process(CalculatorContext* cc) {
if (cc->InputTimestamp() == Timestamp::PreStream()) {
if (cc->Inputs().HasTag("OVERRIDE_OPTIONS")) {
if (cc->Inputs().Tag("OVERRIDE_OPTIONS").IsEmpty()) {
return ::mediapipe::InvalidArgumentError(
"The OVERRIDE_OPTIONS input stream must be non-empty at PreStream "
"time if used.");
}
options_.MergeFrom(cc->Inputs()
.Tag("OVERRIDE_OPTIONS")
.Get<ScaleImageCalculatorOptions>());
RETURN_IF_ERROR(InitializeFromOptions());
}
if (cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") &&
!cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) {
input_video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get<VideoHeader>();
}
if (cc->Inputs().Get(input_data_id_).IsEmpty()) {
return ::mediapipe::OkStatus();
}
}
cc->GetCounter("Inputs")->Increment();
const ImageFrame* image_frame;
ImageFrame converted_image_frame;
if (input_format_ == ImageFormat::YCBCR420P) {
const YUVImage* yuv_image =
&cc->Inputs().Get(input_data_id_).Get<YUVImage>();
RETURN_IF_ERROR(ValidateYUVImage(cc, *yuv_image));
if (output_format_ == ImageFormat::SRGB) {
// TODO: For ease of implementation, YUVImage is converted to
// ImageFrame immediately, before cropping and scaling. Investigate how to
// make color space conversion more efficient when cropping or scaling is
// also needed.
image_frame_util::YUVImageToImageFrame(*yuv_image, &converted_image_frame,
options_.use_bt709());
image_frame = &converted_image_frame;
} else if (output_format_ == ImageFormat::YCBCR420P) {
RET_CHECK(row_start_ == 0 && col_start_ == 0 &&
crop_width_ == input_width_ && crop_height_ == input_height_)
<< "ScaleImageCalculator only supports scaling on YUVImages. To crop "
"images, the output format must be SRGB.";
// Scale the YUVImage and output without converting the color space.
const int y_size = output_width_ * output_height_;
const int uv_size = output_width_ * output_height_ / 4;
std::unique_ptr<uint8_t[]> yuv_data(new uint8_t[y_size + uv_size * 2]);
uint8* y = yuv_data.get();
uint8* u = y + y_size;
uint8* v = u + uv_size;
RET_CHECK_EQ(0, I420Scale(yuv_image->data(0), yuv_image->stride(0),
yuv_image->data(1), yuv_image->stride(1),
yuv_image->data(2), yuv_image->stride(2),
yuv_image->width(), yuv_image->height(), y,
output_width_, u, output_width_ / 2, v,
output_width_ / 2, output_width_,
output_height_, libyuv::kFilterBox));
auto output_image = absl::make_unique<YUVImage>(
libyuv::FOURCC_I420, std::move(yuv_data), y, output_width_, u,
output_width_ / 2, v, output_width_ / 2, output_width_,
output_height_);
cc->GetCounter("Outputs Scaled")->Increment();
if (yuv_image->width() >= output_width_ &&
yuv_image->height() >= output_height_) {
cc->GetCounter("Downscales")->Increment();
} else if (interpolation_algorithm_ != -1) {
cc->GetCounter("Upscales")->Increment();
}
cc->Outputs()
.Get(output_data_id_)
.Add(output_image.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
} else {
image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
RETURN_IF_ERROR(ValidateImageFrame(cc, *image_frame));
}
std::unique_ptr<ImageFrame> cropped_image;
if (crop_width_ < input_width_ || crop_height_ < input_height_) {
cc->GetCounter("Crops")->Increment();
// TODO Do the crop as a range restrict inside OpenCV code below.
cropped_image.reset(new ImageFrame(image_frame->Format(), crop_width_,
crop_height_, alignment_boundary_));
if (image_frame->ByteDepth() == 1 || image_frame->ByteDepth() == 2) {
CropImageFrame(*image_frame, col_start_, row_start_, crop_width_,
crop_height_, cropped_image.get());
} else {
return tool::StatusInvalid(
"Input format does not have ByteDepth of 1 or 2.");
}
// Update the image_frame to point to the cropped image. The
// unique_ptr will take care of deleting the cropped image when the
// function returns.
image_frame = cropped_image.get();
}
// Skip later operations if no scaling is necessary.
if (crop_width_ == output_width_ && crop_height_ == output_height_) {
// Efficiently use either the cropped image or the original image.
if (image_frame == cropped_image.get()) {
if (options_.set_alignment_padding()) {
cropped_image->SetAlignmentPaddingAreas();
}
cc->GetCounter("Outputs Cropped")->Increment();
cc->Outputs()
.Get(output_data_id_)
.Add(cropped_image.release(), cc->InputTimestamp());
} else {
if (options_.alignment_boundary() <= 0 &&
(!options_.set_alignment_padding() || image_frame->IsContiguous())) {
// Any alignment is acceptable and we don't need to clear the
// alignment padding (either because the user didn't request it
// or because the data is contiguous).
cc->GetCounter("Outputs Inputs")->Increment();
cc->Outputs()
.Get(output_data_id_)
.AddPacket(cc->Inputs().Get(input_data_id_).Value());
} else {
// Make a copy with the correct alignment.
std::unique_ptr<ImageFrame> output_frame(new ImageFrame());
output_frame->CopyFrom(*image_frame, alignment_boundary_);
if (options_.set_alignment_padding()) {
output_frame->SetAlignmentPaddingAreas();
}
cc->GetCounter("Outputs Aligned")->Increment();
cc->Outputs()
.Get(output_data_id_)
.Add(output_frame.release(), cc->InputTimestamp());
}
}
return ::mediapipe::OkStatus();
}
// Rescale the image frame.
std::unique_ptr<ImageFrame> output_frame(new ImageFrame());
if (image_frame->Width() >= output_width_ &&
image_frame->Height() >= output_height_) {
// Downscale.
cc->GetCounter("Downscales")->Increment();
cv::Mat input_mat = ::mediapipe::formats::MatView(image_frame);
output_frame->Reset(image_frame->Format(), output_width_, output_height_,
alignment_boundary_);
cv::Mat output_mat = ::mediapipe::formats::MatView(output_frame.get());
downscaler_->Resize(input_mat, &output_mat);
} else {
// Upscale. If upscaling is disallowed, output_width_ and output_height_ are
// the same as the input/crop width and height.
image_frame_util::RescaleImageFrame(
*image_frame, output_width_, output_height_, alignment_boundary_,
interpolation_algorithm_, output_frame.get());
if (interpolation_algorithm_ != -1) {
cc->GetCounter("Upscales")->Increment();
}
}
if (options_.set_alignment_padding()) {
cc->GetCounter("Pads")->Increment();
output_frame->SetAlignmentPaddingAreas();
}
cc->GetCounter("Outputs Scaled")->Increment();
cc->Outputs()
.Get(output_data_id_)
.Add(output_frame.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
} // namespace mediapipe
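The YCbCr420P branch in Process above packs all three I420 planes into one contiguous allocation, with the Y plane followed by quarter-resolution U and V planes. A minimal standalone sketch of that bookkeeping (hypothetical helper, mirroring the y_size/uv_size arithmetic; `I420Layout` is not a MediaPipe type):

```cpp
#include <cassert>

// Plane sizes and offsets for an I420 (YCbCr 4:2:0 planar) buffer where
// Y, U, and V are stored back to back in one allocation.
struct I420Layout {
  int y_size;      // bytes in the full-resolution Y plane
  int uv_size;     // bytes in each quarter-resolution chroma plane
  int u_offset;    // byte offset of the U plane within the buffer
  int v_offset;    // byte offset of the V plane within the buffer
  int total_size;  // total allocation size
};

I420Layout ComputeI420Layout(int width, int height) {
  // The calculator guarantees even dimensions for YCbCr420P output, so the
  // /4 chroma subsampling below is exact.
  I420Layout layout;
  layout.y_size = width * height;
  layout.uv_size = width * height / 4;
  layout.u_offset = layout.y_size;
  layout.v_offset = layout.y_size + layout.uv_size;
  layout.total_size = layout.y_size + 2 * layout.uv_size;
  return layout;
}
```

For a 320x240 frame this gives a 76800-byte Y plane, two 19200-byte chroma planes, and a 115200-byte allocation, matching the `new uint8_t[y_size + uv_size * 2]` above.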

// Options for ScaleImageCalculator.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/formats/image_format.proto";
// Order of operations.
// 1) Crop the image to fit within min_aspect_ratio and max_aspect_ratio.
// 2) Scale and convert the image to fit inside target_width x target_height
// using the specified scaling algorithm. (maintaining the aspect
// ratio if preserve_aspect_ratio is true).
// The output width and height will be divisible by 2. It is possible to output
// a width and height that are odd numbers when the output format is SRGB and
// the aspect ratio is not preserved. See the scale_to_multiple_of_two option
// for details.
message ScaleImageCalculatorOptions {
extend CalculatorOptions {
optional ScaleImageCalculatorOptions ext = 66237115;
}
// Target output width and height. The final output's size may vary
// depending on the other options below. If unset, use the same width
// or height as the input. If only one is set then determine the other
// from the aspect ratio (after cropping). The output width and height
// will be divisible by 2.
optional int32 target_width = 1;
optional int32 target_height = 2;
// If true, the image is scaled up or down proportionally so that it
// fits inside the box represented by target_width and target_height.
// Otherwise it is scaled to fit target_width and target_height
// completely. In any case, the aspect ratio that is preserved is
// that after cropping to the minimum/maximum aspect ratio.
optional bool preserve_aspect_ratio = 3 [default = true];
// If ratio is positive, crop the image to this minimum and maximum
// aspect ratio (preserving the center of the frame). This is done
// before scaling.
// For example, for a min_aspect_ratio of "9/16" and max of "16/9" the
// following cropping will occur:
// 1920x1080 (which is 16:9) is not cropped
// 640x1024 (which is 10:16) is not cropped
// 640x320 (which is 2:1) cropped to 568x320 (just under 16/9)
// 96x480 (which is 1:5), cropped to 96x170 (just over 9/16)
// The resultant frame will always be between (or at) the
// min_aspect_ratio and max_aspect_ratio.
optional string min_aspect_ratio = 4 [default = "9/16"];
optional string max_aspect_ratio = 5 [default = "16/9"];
// If unset, use the same format as the input.
// NOTE: in the current implementation, the output format (either specified
// in the output_format option or inherited from the input format) must be
// SRGB, or YCBCR420P when the input_format is also YCBCR420P.
optional ImageFormat.Format output_format = 6;
enum ScaleAlgorithm {
DEFAULT = 0;
LINEAR = 1;
CUBIC = 2;
AREA = 3;
LANCZOS = 4;
DEFAULT_WITHOUT_UPSCALE = 5; // Option to disallow upscaling.
}
// The upscaling algorithm to use. The default is to use CUBIC. Note that
// downscaling unconditionally uses DDA; see image_processing::
// AffineGammaResizer for documentation.
optional ScaleAlgorithm algorithm = 7 [default = DEFAULT];
// The output image will have this alignment. If set to zero, then
// any alignment could be used. If set to one, the output image will
// be stored contiguously.
optional int32 alignment_boundary = 8 [default = 16];
// Set the alignment padding area to deterministic values (as opposed
// to possibly leaving it as uninitialized memory). The padding is
// the space between the pixel values in a row and the end of the row
// (which may be different due to alignment requirements on the length
// of a row).
optional bool set_alignment_padding = 9 [default = true];
optional bool OBSOLETE_skip_linear_rgb_conversion = 10 [default = false];
// Applies sharpening for downscaled images as post-processing. See
// image_processing::AffineGammaResizer for documentation.
optional float post_sharpening_coefficient = 11 [default = 0.0];
// If input_format is YCBCR420P, input packets contain a YUVImage. If
// input_format is a format other than YCBCR420P or is unset, input packets
// contain an ImageFrame.
// NOTE: in the current implementation, the input format (either specified
// in the input_format option or inferred from the input packets) must be
// SRGB or YCBCR420P.
optional ImageFormat.Format input_format = 12;
// If true, the output width and height will be divisible by 2. Otherwise it
// will use the exact specified output width and height, which is only
// supported when the output format is SRGB and preserve_aspect_ratio option
// is set to false.
optional bool scale_to_multiple_of_two = 13 [default = true];
// If true, assume the input YUV is BT.709 (this is the HDTV standard, so most
// content is likely using it). If false use the previous assumption of BT.601
// (mid-80s standard). Ideally this information should be contained in the
// input YUV Frame, but as of 02/06/2019, it's not. Once this info is baked
// in, this flag becomes useless.
optional bool use_bt709 = 14 [default = false];
}
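The alignment_boundary option above implies padding at the end of each pixel row. As a sketch of the typical layout (an assumption for illustration; ImageFrame's exact storage policy is defined in the framework), each row's byte length is rounded up to a multiple of the alignment boundary:

```cpp
#include <cassert>

// Round a row's payload (width * channels * byte_depth) up to the next
// multiple of the alignment boundary; the difference is the padding that
// set_alignment_padding zeroes out.
int AlignedRowBytes(int width, int channels, int byte_depth, int alignment) {
  const int row_bytes = width * channels * byte_depth;
  return ((row_bytes + alignment - 1) / alignment) * alignment;
}
```

For example, a 50-pixel SRGB row is 150 payload bytes, stored as 160 bytes with a 16-byte boundary, while a 64-pixel SRGB row (192 bytes) needs no padding.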

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/image/scale_image_utils.h"
#include <math.h>
#include <string>
#include "absl/strings/str_split.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace scale_image {
namespace {
double ParseRational(const std::string& rational) {
const std::vector<std::string> v = absl::StrSplit(rational, '/');
const double numerator = std::strtod(v[0].c_str(), nullptr);
const double denominator = std::strtod(v[1].c_str(), nullptr);
return numerator / denominator;
}
} // namespace
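ParseRational above assumes a well-formed "numerator/denominator" string. A standalone sketch of the same arithmetic with a defensive fallback for malformed input (hypothetical `ParseRationalChecked`, not part of MediaPipe; callers treat 0.0 as "ignore this bound"):

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

// Parse "num/den" into a double ratio; return 0.0 when no '/' is present
// instead of reading past the split result as the unchecked version would.
double ParseRationalChecked(const std::string& rational) {
  const std::string::size_type slash = rational.find('/');
  if (slash == std::string::npos) return 0.0;
  const double numerator =
      std::strtod(rational.substr(0, slash).c_str(), nullptr);
  const double denominator =
      std::strtod(rational.substr(slash + 1).c_str(), nullptr);
  return numerator / denominator;
}
```

For the default option values, "9/16" parses to exactly 0.5625 and "16/9" to roughly 1.778.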
::mediapipe::Status FindCropDimensions(int input_width, int input_height, //
const std::string& min_aspect_ratio, //
const std::string& max_aspect_ratio, //
int* crop_width, int* crop_height, //
int* col_start, int* row_start) {
CHECK(crop_width);
CHECK(crop_height);
CHECK(col_start);
CHECK(row_start);
double min_aspect_ratio_q = 0.0;
double max_aspect_ratio_q = 0.0;
if (!min_aspect_ratio.empty()) {
min_aspect_ratio_q = ParseRational(min_aspect_ratio);
}
if (!max_aspect_ratio.empty()) {
max_aspect_ratio_q = ParseRational(max_aspect_ratio);
}
*crop_width = input_width;
*crop_height = input_height;
*col_start = 0;
*row_start = 0;
// Determine the current aspect ratio.
const double aspect_ratio =
static_cast<double>(input_width) / static_cast<double>(input_height);
if (!std::isinf(max_aspect_ratio_q) && !std::isinf(min_aspect_ratio_q)) {
if (max_aspect_ratio_q > 0 && aspect_ratio > max_aspect_ratio_q) {
// Determine the width based on the height multiplied by the max
// aspect ratio.
*crop_width = static_cast<int>(static_cast<double>(input_height) *
max_aspect_ratio_q);
*crop_width = (*crop_width / 2) * 2;
// The col_start should be half the difference between the input width
// and the crop width.
*col_start = (input_width - *crop_width) / 2;
} else if (min_aspect_ratio_q > 0 && aspect_ratio < min_aspect_ratio_q) {
// Determine the height based on the width divided by the min
// aspect ratio.
*crop_height = static_cast<int>(static_cast<double>(input_width) /
min_aspect_ratio_q);
*crop_height = (*crop_height / 2) * 2;
*row_start = (input_height - *crop_height) / 2;
}
}
CHECK_LE(*crop_width, input_width);
CHECK_LE(*crop_height, input_height);
return ::mediapipe::OkStatus();
}
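The max-aspect-ratio branch above can be checked against the 640x320 example in the options proto with a small standalone sketch (hypothetical helper, reproducing only that branch): the crop width is height times the maximum ratio, rounded down to an even number, and the crop is centered horizontally.

```cpp
#include <cassert>

// Centered crop for frames wider than the maximum aspect ratio, mirroring
// the max_aspect_ratio_q branch of FindCropDimensions.
void CropForMaxAspect(int input_width, int input_height, double max_ratio,
                      int* crop_width, int* col_start) {
  *crop_width =
      static_cast<int>(static_cast<double>(input_height) * max_ratio);
  *crop_width = (*crop_width / 2) * 2;  // keep the crop width even
  *col_start = (input_width - *crop_width) / 2;
}
```

For a 640x320 frame (2:1) with a 16/9 maximum, this yields a 568x320 crop starting at column 36, matching the "cropped to 568x320" example in the proto comments.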
::mediapipe::Status FindOutputDimensions(int input_width, //
int input_height, //
int target_width, //
int target_height, //
bool preserve_aspect_ratio, //
bool scale_to_multiple_of_two, //
int* output_width,
int* output_height) {
CHECK(output_width);
CHECK(output_height);
if (!preserve_aspect_ratio || (target_width <= 0 && target_height <= 0)) {
if (target_width <= 0) {
target_width = input_width;
}
if (target_height <= 0) {
target_height = input_height;
}
if (scale_to_multiple_of_two) {
*output_width = (target_width / 2) * 2;
*output_height = (target_height / 2) * 2;
} else {
*output_width = target_width;
*output_height = target_height;
}
return ::mediapipe::OkStatus();
}
if (target_width > 0) {
// Try setting the height based on the width and the aspect ratio.
int try_width = target_width;
int try_height = static_cast<int>(static_cast<double>(target_width) /
static_cast<double>(input_width) *
static_cast<double>(input_height));
try_width = (try_width / 2) * 2;
try_height = (try_height / 2) * 2;
if (target_height <= 0 || try_height <= target_height) {
// The resulting height based on the target width and aspect ratio
// fit within the target height, so use these dimensions.
*output_width = try_width;
*output_height = try_height;
return ::mediapipe::OkStatus();
}
}
if (target_height > 0) {
// Try setting the width based on the height and the aspect ratio.
int try_height = target_height;
int try_width = static_cast<int>(static_cast<double>(target_height) /
static_cast<double>(input_height) *
static_cast<double>(input_width));
try_width = (try_width / 2) * 2;
try_height = (try_height / 2) * 2;
if (target_width <= 0 || try_width <= target_width) {
// The resulting width based on the target height and aspect ratio
// fit within the target width, so use these dimensions.
*output_width = try_width;
*output_height = try_height;
return ::mediapipe::OkStatus();
}
}
RET_CHECK_FAIL()
<< "Unable to set output dimensions based on target dimensions.";
}
} // namespace scale_image
} // namespace mediapipe
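A minimal sketch of the target_width-driven branch of FindOutputDimensions above (hypothetical helper covering only the case where the derived height fits the target): derive the height from the input aspect ratio, then round both sides down to even numbers.

```cpp
#include <cassert>

// Fit the frame to a target width while preserving the input aspect ratio,
// rounding both output dimensions down to multiples of two.
void FitToTargetWidth(int input_width, int input_height, int target_width,
                      int* output_width, int* output_height) {
  *output_width = (target_width / 2) * 2;
  *output_height = static_cast<int>(static_cast<double>(target_width) /
                                    static_cast<double>(input_width) *
                                    static_cast<double>(input_height));
  *output_height = (*output_height / 2) * 2;
}
```

For a 1280x720 input scaled to a target width of 640, this produces 640x360, preserving the 16:9 aspect ratio with even dimensions.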

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Utilities for scaling operations defined by ScaleImageCalculatorOptions.
#ifndef MEDIAPIPE_IMAGE_SCALE_IMAGE_UTILS_H_
#define MEDIAPIPE_IMAGE_SCALE_IMAGE_UTILS_H_
#include <string>
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace scale_image {
// Given a width and height and min and max aspect ratios, determine the
// target width and height and column and row starts such that the target
// is a centered, cropped portion of the image that falls within the min
// and max aspect ratio. If either the min or max aspect ratio argument
// is empty or has a 0 in the numerator or denominator then it is ignored.
::mediapipe::Status FindCropDimensions(int input_width, int input_height, //
const std::string& min_aspect_ratio, //
const std::string& max_aspect_ratio, //
int* crop_width, int* crop_height, //
int* col_start, int* row_start);
// Given an input width and height, a target width and height, whether to
// preserve the aspect ratio, and whether to round down to a multiple of 2,
// determine the output width and height. If target_width or target_height is
// non-positive, then they will be set to the input_width and input_height
// respectively. The output_width and output_height will be reduced as necessary
// to preserve_aspect_ratio and to scale_to_multiple_of_two if these options are
// specified.
::mediapipe::Status FindOutputDimensions(int input_width, int input_height, //
int target_width,
int target_height, //
bool preserve_aspect_ratio, //
bool scale_to_multiple_of_two, //
int* output_width, int* output_height);
} // namespace scale_image
} // namespace mediapipe
#endif // MEDIAPIPE_IMAGE_SCALE_IMAGE_UTILS_H_
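As an illustration of the contract documented above, here is a rough standalone sketch of `FindOutputDimensions` (an assumption for illustration, not the MediaPipe implementation): a non-positive target leaves that dimension unconstrained, the most restrictive scale wins when preserving the aspect ratio, and dimensions are rounded down to even values when requested.

```cpp
#include <algorithm>
#include <limits>

// Standalone sketch (hypothetical, not the real MediaPipe code) of
// FindOutputDimensions: non-positive targets are treated as unconstrained.
void FindOutputDimensionsSketch(int input_width, int input_height,
                                int target_width, int target_height,
                                bool preserve_aspect_ratio,
                                bool scale_to_multiple_of_two,
                                int* output_width, int* output_height) {
  if (preserve_aspect_ratio) {
    // Use the most restrictive scale among the targets that are set.
    double scale = std::numeric_limits<double>::infinity();
    if (target_width > 0)
      scale = std::min(scale, static_cast<double>(target_width) / input_width);
    if (target_height > 0)
      scale = std::min(scale,
                       static_cast<double>(target_height) / input_height);
    if (scale == std::numeric_limits<double>::infinity()) scale = 1.0;
    *output_width = static_cast<int>(input_width * scale);
    *output_height = static_cast<int>(input_height * scale);
  } else {
    *output_width = target_width > 0 ? target_width : input_width;
    *output_height = target_height > 0 ? target_height : input_height;
  }
  if (scale_to_multiple_of_two) {
    // Round down to even dimensions, as some pixel formats require.
    *output_width -= *output_width % 2;
    *output_height -= *output_height % 2;
  }
}
```

This sketch reproduces cases such as fitting a 200x100 image into a 150x150 box, yielding 150x74; the real implementation additionally preserves the aspect ratio exactly for odd targets, which floating-point scaling does not.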


@@ -0,0 +1,153 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/image/scale_image_utils.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace scale_image {
namespace {
TEST(ScaleImageUtilsTest, FindCropDimensions) {
int crop_width;
int crop_height;
int col_start;
int row_start;
// No cropping because aspect ratios should be ignored.
MEDIAPIPE_ASSERT_OK(FindCropDimensions(50, 100, "0/1", "1/0", &crop_width,
&crop_height, &col_start, &row_start));
EXPECT_EQ(50, crop_width);
EXPECT_EQ(100, crop_height);
EXPECT_EQ(0, row_start);
EXPECT_EQ(0, col_start);
// Tests proto examples.
// 16:9 aspect ratio, should be unchanged.
MEDIAPIPE_ASSERT_OK(FindCropDimensions(1920, 1080, "9/16", "16/9",
&crop_width, &crop_height, &col_start,
&row_start));
EXPECT_EQ(0, col_start);
EXPECT_EQ(1920, crop_width);
EXPECT_EQ(0, row_start);
EXPECT_EQ(1080, crop_height);
// 10:16 aspect ratio, should be unchanged.
MEDIAPIPE_ASSERT_OK(FindCropDimensions(640, 1024, "9/16", "16/9", &crop_width,
&crop_height, &col_start, &row_start));
EXPECT_EQ(0, col_start);
EXPECT_EQ(640, crop_width);
EXPECT_EQ(0, row_start);
EXPECT_EQ(1024, crop_height);
// 2:1 aspect ratio, width is cropped.
MEDIAPIPE_ASSERT_OK(FindCropDimensions(640, 320, "9/16", "16/9", &crop_width,
&crop_height, &col_start, &row_start));
EXPECT_EQ(36, col_start);
EXPECT_EQ(568, crop_width);
EXPECT_EQ(0, row_start);
EXPECT_EQ(320, crop_height);
// 1:5 aspect ratio, height is cropped.
MEDIAPIPE_ASSERT_OK(FindCropDimensions(96, 480, "9/16", "16/9", &crop_width,
&crop_height, &col_start, &row_start));
EXPECT_EQ(0, col_start);
EXPECT_EQ(96, crop_width);
EXPECT_EQ(155, row_start);
EXPECT_EQ(170, crop_height);
// Tests min = max, crops width.
MEDIAPIPE_ASSERT_OK(FindCropDimensions(200, 100, "1/1", "1/1", &crop_width,
&crop_height, &col_start, &row_start));
EXPECT_EQ(50, col_start);
EXPECT_EQ(100, crop_width);
EXPECT_EQ(0, row_start);
EXPECT_EQ(100, crop_height);
}
TEST(ScaleImageUtilsTest, FindOutputDimensionsPreserveRatio) {
int output_width;
int output_height;
// No scaling.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, -1, -1, true, true,
&output_width, &output_height));
EXPECT_EQ(200, output_width);
EXPECT_EQ(100, output_height);
// No scaling with odd input size.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(201, 101, -1, -1, false, false,
&output_width, &output_height));
EXPECT_EQ(201, output_width);
EXPECT_EQ(101, output_height);
// Scale down by 1/2.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, true, true,
&output_width, &output_height));
EXPECT_EQ(100, output_width);
EXPECT_EQ(50, output_height);
// Scale up, doubling dimensions.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, true, true,
&output_width, &output_height));
EXPECT_EQ(400, output_width);
EXPECT_EQ(200, output_height);
// Fits a 2:1 image into a 150 x 150 box. Output dimensions are always
// divisible by 2.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 150, 150, true, true,
&output_width, &output_height));
EXPECT_EQ(150, output_width);
EXPECT_EQ(74, output_height);
// Fits a 2:1 image into a 400 x 50 box.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 400, 50, true, true,
&output_width, &output_height));
EXPECT_EQ(100, output_width);
EXPECT_EQ(50, output_height);
// Scale to a multiple of two with an odd target size.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, true, true,
&output_width, &output_height));
EXPECT_EQ(100, output_width);
EXPECT_EQ(50, output_height);
// Scale to a multiple of two with an odd target size.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 101, -1, true, false,
&output_width, &output_height));
EXPECT_EQ(100, output_width);
EXPECT_EQ(50, output_height);
// Scale to odd size.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 151, 101, false, false,
&output_width, &output_height));
EXPECT_EQ(151, output_width);
EXPECT_EQ(101, output_height);
}
// Tests scaling without keeping the aspect ratio fixed.
TEST(ScaleImageUtilsTest, FindOutputDimensionsNoAspectRatio) {
int output_width;
int output_height;
// Scale width only.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 100, -1, false, true,
&output_width, &output_height));
EXPECT_EQ(100, output_width);
EXPECT_EQ(100, output_height);
// Scale height only.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, -1, 200, false, true,
&output_width, &output_height));
EXPECT_EQ(200, output_width);
EXPECT_EQ(200, output_height);
// Scale both dimensions.
MEDIAPIPE_ASSERT_OK(FindOutputDimensions(200, 100, 150, 200, false, true,
&output_width, &output_height));
EXPECT_EQ(150, output_width);
EXPECT_EQ(200, output_height);
}
} // namespace
} // namespace scale_image
} // namespace mediapipe


@@ -0,0 +1,462 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h"
#if defined(__ANDROID__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__
namespace mediapipe {
namespace {
constexpr char kInputFrameTag[] = "IMAGE";
constexpr char kInputAlphaTag[] = "ALPHA";
constexpr char kOutputFrameTag[] = "IMAGE";
constexpr char kInputFrameTagGpu[] = "IMAGE_GPU";
constexpr char kInputAlphaTagGpu[] = "ALPHA_GPU";
constexpr char kOutputFrameTagGpu[] = "IMAGE_GPU";
constexpr int kNumChannelsRGBA = 4;
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
} // namespace
// A calculator for setting the alpha channel of an RGBA image.
//
// The alpha channel can be set to a single value, or come from an image mask.
// If the input image has an alpha channel, it will be updated.
// If the input image doesn't have an alpha channel, one will be added.
// Adding an alpha channel to a grayscale (single channel) input is not supported.
//
// Inputs:
// One of the following two IMAGE tags:
// IMAGE: ImageFrame containing input image - RGB or RGBA.
// IMAGE_GPU: GpuBuffer containing input image - RGB or RGBA.
//
// ALPHA (optional): ImageFrame alpha mask to apply,
// can be any # of channels, only first channel used,
// must be same format as input
// ALPHA_GPU (optional): GpuBuffer alpha mask to apply,
// can be any # of channels, only first channel used,
// must be same format as input
// If ALPHA* input tag is not set, the 'alpha_value' option must be used.
//
// Output:
// One of the following two tags:
// IMAGE: An ImageFrame with alpha channel set - RGBA only.
// IMAGE_GPU: A GpuBuffer with alpha channel set - RGBA only.
//
// Options:
// alpha_value (optional): The alpha value to set to input image, [0-255],
// takes precedence over input mask.
// If alpha_value is not set, the ALPHA* input tag must be used.
//
// Notes:
// Either alpha_value option or ALPHA (or ALPHA_GPU) must be set.
// All CPU inputs must have the same image dimensions and data type.
//
class SetAlphaCalculator : public CalculatorBase {
public:
SetAlphaCalculator() = default;
~SetAlphaCalculator() override = default;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
// From Calculator.
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::Status RenderGpu(CalculatorContext* cc);
::mediapipe::Status RenderCpu(CalculatorContext* cc);
::mediapipe::Status GlRender(CalculatorContext* cc);
::mediapipe::Status GlSetup(CalculatorContext* cc);
mediapipe::SetAlphaCalculatorOptions options_;
float alpha_value_ = -1.f;
bool use_gpu_ = false;
bool gpu_initialized_ = false;
#if defined(__ANDROID__)
mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0;
#endif // __ANDROID__
};
REGISTER_CALCULATOR(SetAlphaCalculator);
::mediapipe::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
if (cc->Inputs().HasTag(kInputFrameTag) &&
cc->Inputs().HasTag(kInputFrameTagGpu)) {
return ::mediapipe::InternalError("Cannot have multiple input images.");
}
if (cc->Inputs().HasTag(kInputFrameTagGpu) !=
cc->Outputs().HasTag(kOutputFrameTagGpu)) {
return ::mediapipe::InternalError("GPU output must have GPU input.");
}
// Input image to add/edit alpha channel.
#if defined(__ANDROID__)
if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
if (cc->Inputs().HasTag(kInputFrameTag)) {
cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
}
// Input alpha image mask (optional)
#if defined(__ANDROID__)
if (cc->Inputs().HasTag(kInputAlphaTagGpu)) {
cc->Inputs().Tag(kInputAlphaTagGpu).Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
if (cc->Inputs().HasTag(kInputAlphaTag)) {
cc->Inputs().Tag(kInputAlphaTag).Set<ImageFrame>();
}
// RGBA output image.
#if defined(__ANDROID__)
if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
if (cc->Outputs().HasTag(kOutputFrameTag)) {
cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
}
#if defined(__ANDROID__)
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<mediapipe::SetAlphaCalculatorOptions>();
if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
cc->Outputs().HasTag(kOutputFrameTagGpu)) {
#if defined(__ANDROID__)
use_gpu_ = true;
#else
RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet.";
#endif // __ANDROID__
}
// Get global value from options (-1 if not set).
alpha_value_ = options_.alpha_value();
if (use_gpu_) alpha_value_ /= 255.0;
const bool use_image_mask = cc->Inputs().HasTag(kInputAlphaTag) ||
cc->Inputs().HasTag(kInputAlphaTagGpu);
if (!((alpha_value_ >= 0) ^ use_image_mask))
RET_CHECK_FAIL() << "Must use either image mask or options alpha value.";
if (use_gpu_) {
#if defined(__ANDROID__)
RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) {
#if defined(__ANDROID__)
RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
if (!gpu_initialized_) {
RETURN_IF_ERROR(GlSetup(cc));
gpu_initialized_ = true;
}
RETURN_IF_ERROR(RenderGpu(cc));
return ::mediapipe::OkStatus();
}));
#endif // __ANDROID__
} else {
RETURN_IF_ERROR(RenderCpu(cc));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__)
gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_);
program_ = 0;
});
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
// Setup source image
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame);
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
}
// Setup destination image
auto output_frame = absl::make_unique<ImageFrame>(
ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
// Set up the alpha image and update the image on CPU.
if (use_alpha_mask) {
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask);
RET_CHECK_EQ(input_mat.rows, alpha_mat.rows);
RET_CHECK_EQ(input_mat.cols, alpha_mat.cols);
for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
uchar* alpha_ptr = alpha_mat.ptr<uchar>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
const int alpha_idx = j * alpha_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
}
}
} else {
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
out_ptr[out_idx + 3] = alpha_value; // use value from options
}
}
}
cc->Outputs()
.Tag(kOutputFrameTag)
.Add(output_frame.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) {
if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) {
return ::mediapipe::OkStatus();
}
#if defined(__ANDROID__)
// Setup source texture.
const auto& input_frame =
cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
if (!(input_frame.format() == mediapipe::GpuBufferFormat::kBGRA32 ||
input_frame.format() == mediapipe::GpuBufferFormat::kRGB24)) {
LOG(ERROR) << "Only RGB or RGBA input image supported";
}
auto input_texture = gpu_helper_.CreateSourceTexture(input_frame);
// Setup destination texture.
const int width = input_frame.width(), height = input_frame.height();
auto output_texture = gpu_helper_.CreateDestinationTexture(
width, height, mediapipe::GpuBufferFormat::kBGRA32);
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTagGpu) &&
!cc->Inputs().Tag(kInputAlphaTagGpu).IsEmpty();
// Set up the alpha texture and update the image in the GPU shader.
if (has_alpha_mask) {
const auto& alpha_mask =
cc->Inputs().Tag(kInputAlphaTagGpu).Get<mediapipe::GpuBuffer>();
auto alpha_texture = gpu_helper_.CreateSourceTexture(alpha_mask);
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D, input_texture.name());
glActiveTexture(GL_TEXTURE2);
glBindTexture(GL_TEXTURE_2D, alpha_texture.name());
GlRender(cc); // use channel 0 of mask
alpha_texture.Release();
} else {
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D, input_texture.name());
GlRender(cc); // use value from options
}
// Send out image as GPU packet.
auto output_frame = output_texture.GetFrame<mediapipe::GpuBuffer>();
cc->Outputs()
.Tag(kOutputFrameTagGpu)
.Add(output_frame.release(), cc->InputTimestamp());
// Cleanup
input_texture.Release();
output_texture.Release();
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::GlRender(CalculatorContext* cc) {
#if defined(__ANDROID__)
static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right
-1.0f, 1.0f, // top left
1.0f, 1.0f, // top right
};
static const GLfloat texture_vertices[] = {
0.0f, 0.0f, // bottom left
1.0f, 0.0f, // bottom right
0.0f, 1.0f, // top left
1.0f, 1.0f, // top right
};
// program
glUseProgram(program_);
// vertex storage
GLuint vbo[2];
glGenBuffers(2, vbo);
GLuint vao;
glGenVertexArrays(1, &vao);
glBindVertexArray(vao);
// vbo 0
glBindBuffer(GL_ARRAY_BUFFER, vbo[0]);
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices,
GL_STATIC_DRAW);
glEnableVertexAttribArray(ATTRIB_VERTEX);
glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);
// vbo 1
glBindBuffer(GL_ARRAY_BUFFER, vbo[1]);
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices,
GL_STATIC_DRAW);
glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);
// draw
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
// cleanup
glDisableVertexAttribArray(ATTRIB_VERTEX);
glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glBindBuffer(GL_ARRAY_BUFFER, 0);
glBindVertexArray(0);
glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo);
// execute command queue
glBindTexture(GL_TEXTURE_2D, 0);
glFlush();
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) {
#if defined(__ANDROID__)
const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION,
};
const GLchar* attr_name[NUM_ATTRIBUTES] = {
"position",
"texture_coordinate",
};
// Shader that copies the input RGB and sets alpha from the mask or a constant.
const GLchar* frag_src = GLES_VERSION_COMPAT
R"(
#if __VERSION__ < 130
#define in varying
#endif // __VERSION__ < 130
#ifdef GL_ES
#define fragColor gl_FragColor
precision highp float;
#else
#define lowp
#define mediump
#define highp
#define texture2D texture
out vec4 fragColor;
#endif // defined(GL_ES)
in vec2 sample_coordinate;
uniform sampler2D input_frame;
uniform sampler2D alpha_mask;
uniform float alpha_value;
void main() {
vec3 image_pix = texture2D(input_frame, sample_coordinate).rgb;
float alpha = alpha_value;
if (alpha_value < 0.0) alpha = texture2D(alpha_mask, sample_coordinate).r;
vec4 out_pix = vec4(image_pix, alpha);
fragColor = out_pix;
}
)";
// Create shader program and set parameters.
mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src,
NUM_ATTRIBUTES, (const GLchar**)&attr_name[0],
attr_location, &program_);
RET_CHECK(program_) << "Problem initializing the program.";
glUseProgram(program_);
glUniform1i(glGetUniformLocation(program_, "input_frame"), 1);
glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2);
glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_);
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
} // namespace mediapipe


@@ -0,0 +1,16 @@
// Options for SetAlphaCalculator
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message SetAlphaCalculatorOptions {
extend CalculatorOptions {
optional SetAlphaCalculatorOptions ext = 250949799;
}
// The value to set the alpha channel to (0-255).
// This option is ignored when set to -1 (use image mask instead).
optional sint32 alpha_value = 1 [default = -1];
}
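As a usage sketch of this option (the stream names here are hypothetical), a graph node that applies a constant alpha of 255 to every frame might be configured as:

```
node {
  calculator: "SetAlphaCalculator"
  input_stream: "IMAGE:input_video"
  output_stream: "IMAGE:output_video"
  options {
    [mediapipe.SetAlphaCalculatorOptions.ext] {
      alpha_value: 255
    }
  }
}
```

Omitting `alpha_value` (leaving the default of -1) would instead require wiring an `ALPHA` input stream, per the calculator's contract.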


@@ -0,0 +1,239 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/gl_simple_calculator.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
namespace mediapipe {
// Applies the Sobel filter to an image. Expects a grayscale image stored as
// RGB, like LuminanceCalculator outputs.
// See GlSimpleCalculatorBase for inputs, outputs and input side packets.
class SobelEdgesCalculator : public GlSimpleCalculator {
public:
::mediapipe::Status GlSetup() override;
::mediapipe::Status GlRender(const GlTexture& src,
const GlTexture& dst) override;
::mediapipe::Status GlTeardown() override;
private:
GLuint program_ = 0;
GLint frame_;
GLint pixel_w_;
GLint pixel_h_;
};
REGISTER_CALCULATOR(SobelEdgesCalculator);
::mediapipe::Status SobelEdgesCalculator::GlSetup() {
// Load vertex and fragment shaders
const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION,
};
const GLchar* attr_name[NUM_ATTRIBUTES] = {
"vertexPosition",
"vertexTextureCoordinate",
};
const GLchar* vert_src = GLES_VERSION_COMPAT
R"(
#if __VERSION__ < 130
#define in attribute
#define out varying
#endif // __VERSION__ < 130
in vec4 vertexPosition;
in vec4 vertexTextureCoordinate;
// width of a pixel in normalized texture coordinates (0..1)
uniform highp float pixelW;
// height of a pixel in normalized texture coordinates (0..1)
uniform highp float pixelH;
// Dependent texture reads (i.e. texture reads where texture coordinates
// are computed in the fragment shader) are slow on pre-ES 3.0 hardware.
// Avoid them by computing all texture coordinates in the vertex shader.
// iOS OGLES performance guide: https://developer.apple.com/library/ios/documentation/3DDrawing/Conceptual/OpenGLES_ProgrammingGuide/BestPracticesforShaders/BestPracticesforShaders.html
// Code for coordinates: u = up, d = down, l = left, r = right, c = center.
// Horizontal coordinate first, then vertical.
out vec2 luTexCoord;
out vec2 lcTexCoord;
out vec2 ldTexCoord;
out vec2 cuTexCoord;
// out vec2 ccTexCoord;
out vec2 cdTexCoord;
out vec2 ruTexCoord;
out vec2 rcTexCoord;
out vec2 rdTexCoord;
void main() {
gl_Position = vertexPosition;
vec2 right = vec2(pixelW, 0.0);
vec2 up = vec2(0.0, pixelH);
lcTexCoord = vertexTextureCoordinate.xy - right;
luTexCoord = lcTexCoord + up;
ldTexCoord = lcTexCoord - up;
vec2 ccTexCoord = vertexTextureCoordinate.xy;
cuTexCoord = ccTexCoord + up;
cdTexCoord = ccTexCoord - up;
rcTexCoord = vertexTextureCoordinate.xy + right;
ruTexCoord = rcTexCoord + up;
rdTexCoord = rcTexCoord - up;
}
)";
const GLchar* frag_src = GLES_VERSION_COMPAT
R"(
#if __VERSION__ < 130
#define in varying
#endif // __VERSION__ < 130
#ifdef GL_ES
#define fragColor gl_FragColor
precision highp float;
#else
#define lowp
#define mediump
#define highp
#define texture2D texture
out vec4 fragColor;
#endif // defined(GL_ES)
in vec2 luTexCoord;
in vec2 lcTexCoord;
in vec2 ldTexCoord;
in vec2 cuTexCoord;
// in vec2 ccTexCoord;
in vec2 cdTexCoord;
in vec2 ruTexCoord;
in vec2 rcTexCoord;
in vec2 rdTexCoord;
uniform sampler2D inputImage;
void main() {
float luPx = texture2D(inputImage, luTexCoord).r;
float lcPx = texture2D(inputImage, lcTexCoord).r;
float ldPx = texture2D(inputImage, ldTexCoord).r;
float cuPx = texture2D(inputImage, cuTexCoord).r;
// float ccPx = texture2D(inputImage, ccTexCoord).r;
float cdPx = texture2D(inputImage, cdTexCoord).r;
float ruPx = texture2D(inputImage, ruTexCoord).r;
float rcPx = texture2D(inputImage, rcTexCoord).r;
float rdPx = texture2D(inputImage, rdTexCoord).r;
float h = -luPx - 2.0 * lcPx - ldPx + ruPx + 2.0 * rcPx + rdPx;
float v = -luPx - 2.0 * cuPx - ruPx + ldPx + 2.0 * cdPx + rdPx;
float mag = length(vec2(h, v));
fragColor = vec4(vec3(mag), 1.0);
}
)";
// shader program
GlhCreateProgram(vert_src, frag_src, NUM_ATTRIBUTES,
(const GLchar**)&attr_name[0], attr_location, &program_);
RET_CHECK(program_) << "Problem initializing the program.";
frame_ = glGetUniformLocation(program_, "inputImage");
pixel_w_ = glGetUniformLocation(program_, "pixelW");
pixel_h_ = glGetUniformLocation(program_, "pixelH");
return ::mediapipe::OkStatus();
}
::mediapipe::Status SobelEdgesCalculator::GlRender(const GlTexture& src,
const GlTexture& dst) {
static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right
-1.0f, 1.0f, // top left
1.0f, 1.0f, // top right
};
static const float texture_vertices[] = {
0.0f, 0.0f, // bottom left
1.0f, 0.0f, // bottom right
0.0f, 1.0f, // top left
1.0f, 1.0f, // top right
};
// program
glUseProgram(program_);
// parameters
glUniform1i(frame_, 1);
glUniform1f(pixel_w_, 1.0 / src.width());
glUniform1f(pixel_h_, 1.0 / src.height());
// vertex storage
GLuint vbo[2];
glGenBuffers(2, vbo);
GLuint vao;
glGenVertexArrays(1, &vao);
glBindVertexArray(vao);
// vbo 0
glBindBuffer(GL_ARRAY_BUFFER, vbo[0]);
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices,
GL_STATIC_DRAW);
glEnableVertexAttribArray(ATTRIB_VERTEX);
glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);
// vbo 1
glBindBuffer(GL_ARRAY_BUFFER, vbo[1]);
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices,
GL_STATIC_DRAW);
glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);
// draw
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
// cleanup
glDisableVertexAttribArray(ATTRIB_VERTEX);
glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
glBindBuffer(GL_ARRAY_BUFFER, 0);
glBindVertexArray(0);
glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo);
return ::mediapipe::OkStatus();
}
::mediapipe::Status SobelEdgesCalculator::GlTeardown() {
if (program_) {
glDeleteProgram(program_);
program_ = 0;
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe
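The shader's per-pixel arithmetic can be checked on the CPU with a small sketch (the sample values are illustrative): `h` and `v` are the standard 3x3 Sobel responses, combined into a gradient magnitude just as `length(vec2(h, v))` does in the fragment shader.

```cpp
#include <cmath>

// CPU mirror of the fragment shader math above. p is a 3x3 luminance
// patch indexed [row][col].
float SobelMagnitude(const float p[3][3]) {
  // Shader naming: l/c/r = left/center/right, u/c/d = up/center/down.
  const float lu = p[0][0], lc = p[1][0], ld = p[2][0];
  const float cu = p[0][1], cd = p[2][1];
  const float ru = p[0][2], rc = p[1][2], rd = p[2][2];
  const float h = -lu - 2.0f * lc - ld + ru + 2.0f * rc + rd;
  const float v = -lu - 2.0f * cu - ru + ld + 2.0f * cd + rd;
  return std::sqrt(h * h + v * v);  // length(vec2(h, v))
}
```

For a hard vertical edge (a dark column next to a bright one), `h` is 4 and `v` is 0, so the magnitude is 4; a flat patch yields 0.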


@ -0,0 +1,26 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
licenses(["notice"]) # Apache 2.0
filegroup(
name = "test_images",
srcs = [
"dino.jpg",
"dino_quality_50.jpg",
"dino_quality_80.jpg",
],
visibility = ["//visibility:public"],
)

Three binary image files added (not shown in the diff; the dino test images referenced in the BUILD file above): 756 KiB, 424 KiB, and 693 KiB.


@@ -0,0 +1,47 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
package(default_visibility = ["//visibility:private"])
proto_library(
name = "callback_packet_calculator_proto",
srcs = ["callback_packet_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "callback_packet_calculator_cc_proto",
srcs = ["callback_packet_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [":callback_packet_calculator_proto"],
)
cc_library(
name = "callback_packet_calculator",
srcs = ["callback_packet_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":callback_packet_calculator_cc_proto",
"//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_registry",
"//mediapipe/framework:output_side_packet",
],
alwayslink = 1,
)


@@ -0,0 +1,103 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <string>
#include "mediapipe/calculators/internal/callback_packet_calculator.pb.h" // NOLINT
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_registry.h"
#include "mediapipe/framework/output_side_packet.h"
namespace mediapipe {
namespace {
// Callback function for writing a packet to a vector. The output is before the
// input since std::bind fills arguments from left to right (and only
// dumped_data is filled by std::bind).
void DumpToVector(std::vector<Packet>* dumped_data, const Packet& packet) {
dumped_data->push_back(packet);
}
// Callback function for saving the Timestamp::PostStream() packet.
// The output is before the input since std::bind fills arguments from left to
// right (and only post_stream_packet is filled by std::bind).
void DumpPostStreamPacket(Packet* post_stream_packet, const Packet& packet) {
if (packet.Timestamp() == Timestamp::PostStream()) {
*post_stream_packet = packet;
}
}
} // namespace
// Creates a callback which takes a packet and stores it either in a
// vector of packets or stores only the packet at PostStream timestamp.
// The kind of callback is controlled by an option. The callback is
// a std::function and is directly usable by CallbackCalculator.
// Since the options for this calculator include a serialized pointer
// value, the resulting callback is only valid on the original machine
// while that pointer is still alive.
class CallbackPacketCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
const auto& options = cc->Options<CallbackPacketCalculatorOptions>();
switch (options.type()) {
case CallbackPacketCalculatorOptions::VECTOR_PACKET:
case CallbackPacketCalculatorOptions::POST_STREAM_PACKET:
cc->OutputSidePackets()
.Index(0)
.Set<std::function<void(const Packet&)>>();
break;
default:
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Invalid type of callback to produce.";
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
const auto& options = cc->Options<CallbackPacketCalculatorOptions>();
void* ptr;
if (sscanf(options.pointer().c_str(), "%p", &ptr) != 1) {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Stored pointer value in options is invalid.";
}
switch (options.type()) {
case CallbackPacketCalculatorOptions::VECTOR_PACKET:
cc->OutputSidePackets().Index(0).Set(
MakePacket<std::function<void(const Packet&)>>(std::bind(
&DumpToVector, reinterpret_cast<std::vector<Packet>*>(ptr),
std::placeholders::_1)));
break;
case CallbackPacketCalculatorOptions::POST_STREAM_PACKET:
cc->OutputSidePackets().Index(0).Set(
MakePacket<std::function<void(const Packet&)>>(
std::bind(&DumpPostStreamPacket, reinterpret_cast<Packet*>(ptr),
std::placeholders::_1)));
break;
default:
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Invalid type to dump into.";
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(CallbackPacketCalculator);
} // namespace mediapipe
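The callback wiring above hinges on two tricks: placing the output parameter first so std::bind can pre-fill it while the packet remains the single visible argument, and round-tripping a raw pointer through a "%p" string. A minimal standalone sketch of both (FakePacket and the helper names are illustrative, not MediaPipe APIs):

```cpp
#include <cassert>
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

// Stand-in for mediapipe::Packet; any copyable type works for the sketch.
using FakePacket = int;

// Output parameter first, so std::bind can pre-fill it and leave the
// packet as the callback's only remaining argument.
void DumpToVector(std::vector<FakePacket>* dumped, const FakePacket& packet) {
  dumped->push_back(packet);
}

// Serialize a pointer with snprintf("%p") and parse it back with
// sscanf("%p"), mirroring the options encoding and the parse in
// CallbackPacketCalculator::Open(). Returns nullptr on parse failure.
void* RoundTripPointer(const void* ptr) {
  char address[32];
  snprintf(address, sizeof(address), "%p", ptr);
  void* parsed = nullptr;
  if (sscanf(address, "%p", &parsed) != 1) return nullptr;
  return parsed;
}
```

As in the calculator, the round trip only yields a usable pointer within the same process, while the pointed-to object is still alive.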


@ -0,0 +1,40 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message CallbackPacketCalculatorOptions {
extend CalculatorOptions {
optional CallbackPacketCalculatorOptions ext = 245965803;
}
enum PointerType {
UNKNOWN = 0;
VECTOR_PACKET = 1;
POST_STREAM_PACKET = 2;
}
// The type of the data pointer that the callback will put data into.
optional PointerType type = 1;
// The location of the data stored as a string printed with
// snprintf(address, sizeof(address), "%p", pointer).
// This calculator only produces a reasonable callback if it is
// constructed on the same machine where the original pointer was created
// and while that pointer is still alive.
optional bytes pointer = 2;
}
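In a graph config, these options are set through the extension field on the node, with the pointer string filled in programmatically at graph-construction time. A hedged sketch (the pointer value must come from snprintf("%p", ...) at runtime, so it is left empty here):

```
node {
  calculator: "CallbackPacketCalculator"
  output_side_packet: "callback"
  options {
    [mediapipe.CallbackPacketCalculatorOptions.ext] {
      type: VECTOR_PACKET
      pointer: ""  # filled in programmatically with snprintf("%p", ptr)
    }
  }
}
```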


@ -0,0 +1,975 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "graph_tensors_packet_generator_proto",
srcs = ["graph_tensors_packet_generator.proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [
"//mediapipe/framework:calculator_proto",
"//mediapipe/framework:packet_generator_proto",
],
)
proto_library(
name = "matrix_to_tensor_calculator_options_proto",
srcs = ["matrix_to_tensor_calculator_options.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "lapped_tensor_buffer_calculator_proto",
srcs = ["lapped_tensor_buffer_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "object_detection_tensors_to_detections_calculator_proto",
srcs = ["object_detection_tensors_to_detections_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "tensorflow_inference_calculator_proto",
srcs = ["tensorflow_inference_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "tensorflow_session_from_saved_model_generator_proto",
srcs = ["tensorflow_session_from_saved_model_generator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:packet_generator_proto"],
)
proto_library(
name = "tensorflow_session_from_saved_model_calculator_proto",
srcs = ["tensorflow_session_from_saved_model_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
name = "tensor_squeeze_dimensions_calculator_proto",
srcs = ["tensor_squeeze_dimensions_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "tensor_to_image_frame_calculator_proto",
srcs = ["tensor_to_image_frame_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "tensor_to_matrix_calculator_proto",
srcs = ["tensor_to_matrix_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
"//mediapipe/framework/formats:time_series_header_proto",
],
)
proto_library(
name = "tensor_to_vector_float_calculator_options_proto",
srcs = ["tensor_to_vector_float_calculator_options.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "vector_float_to_tensor_calculator_options_proto",
srcs = ["vector_float_to_tensor_calculator_options.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "graph_tensors_packet_generator_cc_proto",
srcs = ["graph_tensors_packet_generator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":graph_tensors_packet_generator_proto"],
)
mediapipe_cc_proto_library(
name = "image_frame_to_tensor_calculator_cc_proto",
srcs = ["image_frame_to_tensor_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":image_frame_to_tensor_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "matrix_to_tensor_calculator_options_cc_proto",
srcs = ["matrix_to_tensor_calculator_options.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":matrix_to_tensor_calculator_options_proto"],
)
mediapipe_cc_proto_library(
name = "lapped_tensor_buffer_calculator_cc_proto",
srcs = ["lapped_tensor_buffer_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":lapped_tensor_buffer_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "object_detection_tensors_to_detections_calculator_cc_proto",
srcs = ["object_detection_tensors_to_detections_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":object_detection_tensors_to_detections_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "pack_media_sequence_calculator_cc_proto",
srcs = ["pack_media_sequence_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":pack_media_sequence_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "tensorflow_inference_calculator_cc_proto",
srcs = ["tensorflow_inference_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tensorflow_inference_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "tensorflow_session_from_frozen_graph_generator_cc_proto",
srcs = ["tensorflow_session_from_frozen_graph_generator.proto"],
cc_deps = [
"//mediapipe/framework:packet_generator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tensorflow_session_from_frozen_graph_generator_proto"],
)
mediapipe_cc_proto_library(
name = "tensorflow_session_from_saved_model_generator_cc_proto",
srcs = ["tensorflow_session_from_saved_model_generator.proto"],
cc_deps = ["//mediapipe/framework:packet_generator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tensorflow_session_from_saved_model_generator_proto"],
)
mediapipe_cc_proto_library(
name = "tensorflow_session_from_saved_model_calculator_cc_proto",
srcs = ["tensorflow_session_from_saved_model_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tensorflow_session_from_saved_model_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "tensor_squeeze_dimensions_calculator_cc_proto",
srcs = ["tensor_squeeze_dimensions_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tensor_squeeze_dimensions_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "tensor_to_image_frame_calculator_cc_proto",
srcs = ["tensor_to_image_frame_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tensor_to_image_frame_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "tensor_to_matrix_calculator_cc_proto",
srcs = ["tensor_to_matrix_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/formats:time_series_header_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tensor_to_matrix_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "tensor_to_vector_float_calculator_options_cc_proto",
srcs = ["tensor_to_vector_float_calculator_options.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tensor_to_vector_float_calculator_options_proto"],
)
mediapipe_cc_proto_library(
name = "unpack_media_sequence_calculator_cc_proto",
srcs = ["unpack_media_sequence_calculator.proto"],
cc_deps = [
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":unpack_media_sequence_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "vector_float_to_tensor_calculator_options_cc_proto",
srcs = ["vector_float_to_tensor_calculator_options.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":vector_float_to_tensor_calculator_options_proto"],
)
cc_library(
name = "graph_tensors_packet_generator",
srcs = ["graph_tensors_packet_generator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:graph_tensors_packet_generator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:status_util",
"@org_tensorflow//tensorflow/core:framework",
],
alwayslink = 1,
)
cc_library(
name = "image_frame_to_tensor_calculator",
srcs = ["image_frame_to_tensor_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:image_frame_to_tensor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_lib_lite",
],
}),
alwayslink = 1,
)
cc_library(
name = "matrix_to_tensor_calculator",
srcs = ["matrix_to_tensor_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/calculators/tensorflow:matrix_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_lib_lite",
],
}),
alwayslink = 1,
)
cc_library(
name = "lapped_tensor_buffer_calculator",
srcs = ["lapped_tensor_buffer_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/profiler:circular_buffer",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:lib",
],
alwayslink = 1,
)
cc_library(
name = "object_detection_tensors_to_detections_calculator",
srcs = ["object_detection_tensors_to_detections_calculator.cc"],
features = [
# Layering check doesn't play nicely with portable proto wrappers.
"no_layering_check",
],
visibility = [
"//visibility:public",
],
deps = [
":object_detection_tensors_to_detections_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util:tensor_to_detection",
"@org_tensorflow//tensorflow/core:framework",
],
alwayslink = 1,
)
cc_library(
name = "pack_media_sequence_calculator",
srcs = ["pack_media_sequence_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
"//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util/sequence:media_sequence",
"//mediapipe/util/sequence:media_sequence_util",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
cc_library(
name = "string_to_sequence_example_calculator",
srcs = ["string_to_sequence_example_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
# On Android, this calculator is configured to run with lite protos. Therefore,
# compile your binary with the flag TENSORFLOW_PROTOS=lite.
cc_library(
name = "tensorflow_inference_calculator",
srcs = ["tensorflow_inference_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":tensorflow_session",
"//mediapipe/calculators/tensorflow:tensorflow_inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework/deps:clock",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
],
}),
alwayslink = 1,
)
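Per the note above, an Android build of this target would pass the lite-protos flag on the command line. A hedged example invocation (assuming the flag is consumed as a Bazel --define; the exact spelling may vary by TensorFlow version):

```shell
bazel build --config=android_arm64 \
    --define=TENSORFLOW_PROTOS=lite \
    //mediapipe/calculators/tensorflow:tensorflow_inference_calculator
```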
cc_library(
name = "tensorflow_session",
hdrs = [
"tensorflow_session.h",
],
features = ["no_layering_check"],
visibility = ["//visibility:public"],
deps = select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:core",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
],
}),
)
cc_library(
name = "tensorflow_session_from_frozen_graph_generator",
srcs = ["tensorflow_session_from_frozen_graph_generator.cc"],
features = ["no_layering_check"],
visibility = ["//visibility:public"],
deps = [
":tensorflow_session",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [
"//mediapipe/framework/port:file_helpers",
"@org_tensorflow//tensorflow/core:core",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
"//mediapipe/android/file/base",
],
}),
alwayslink = 1,
)
cc_library(
name = "tensorflow_session_from_saved_model_calculator",
srcs = ["tensorflow_session_from_saved_model_calculator.cc"],
defines = select({
"//mediapipe:android": ["__ANDROID__"],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tensorflow_session",
":tensorflow_session_from_saved_model_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/cc/saved_model:constants",
"@org_tensorflow//tensorflow/cc/saved_model:loader_lite",
"@org_tensorflow//tensorflow/cc/saved_model:tag_constants",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [
"//mediapipe/framework/port:file_helpers",
],
}),
alwayslink = 1,
)
cc_library(
name = "tensorflow_session_from_saved_model_generator",
srcs = ["tensorflow_session_from_saved_model_generator.cc"],
defines = select({
"//mediapipe:android": ["__ANDROID__"],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tensorflow_session",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_generator_cc_proto",
"//mediapipe/framework:packet_generator",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/cc/saved_model:constants",
"@org_tensorflow//tensorflow/cc/saved_model:loader_lite",
"@org_tensorflow//tensorflow/cc/saved_model:tag_constants",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [
"//mediapipe/framework/port:file_helpers",
],
}),
alwayslink = 1,
)
cc_library(
name = "tensor_squeeze_dimensions_calculator",
srcs = ["tensor_squeeze_dimensions_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:tensor_squeeze_dimensions_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:framework",
],
alwayslink = 1,
)
cc_library(
name = "tensor_to_image_frame_calculator",
srcs = ["tensor_to_image_frame_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:framework",
],
alwayslink = 1,
)
cc_library(
name = "tensor_to_matrix_calculator",
srcs = ["tensor_to_matrix_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/calculators/tensorflow:tensor_to_matrix_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_lib_lite",
],
}),
alwayslink = 1,
)
cc_library(
name = "tensor_to_vector_float_calculator",
srcs = ["tensor_to_vector_float_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
"//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator_options_cc_proto",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_lib_lite",
],
}),
alwayslink = 1,
)
cc_library(
name = "unpack_media_sequence_calculator",
srcs = ["unpack_media_sequence_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
cc_library(
name = "vector_float_to_tensor_calculator",
srcs = ["vector_float_to_tensor_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:framework",
],
alwayslink = 1,
)
cc_test(
name = "graph_tensors_packet_generator_test",
srcs = ["graph_tensors_packet_generator_test.cc"],
deps = [
":graph_tensors_packet_generator",
"//mediapipe/calculators/tensorflow:graph_tensors_packet_generator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:packet_set",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/tool:validate_type",
"@org_tensorflow//tensorflow/core:framework",
],
)
cc_test(
name = "image_frame_to_tensor_calculator_test",
size = "small",
srcs = ["image_frame_to_tensor_calculator_test.cc"],
deps = [
":image_frame_to_tensor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:framework",
],
)
cc_test(
name = "matrix_to_tensor_calculator_test",
size = "small",
srcs = ["matrix_to_tensor_calculator_test.cc"],
deps = [
":matrix_to_tensor_calculator",
"//mediapipe/calculators/tensorflow:matrix_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"@org_tensorflow//tensorflow/core:framework",
],
)
cc_test(
name = "lapped_tensor_buffer_calculator_test",
size = "small",
srcs = ["lapped_tensor_buffer_calculator_test.cc"],
deps = [
":lapped_tensor_buffer_calculator",
"//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "object_detection_tensors_to_detections_calculator_test",
srcs = ["object_detection_tensors_to_detections_calculator_test.cc"],
linkstatic = 1,
deps = [
":object_detection_tensors_to_detections_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:testlib",
],
)
cc_test(
name = "pack_media_sequence_calculator_test",
srcs = ["pack_media_sequence_calculator_test.cc"],
deps = [
":pack_media_sequence_calculator",
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
"//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:status",
"//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "tensorflow_session_from_frozen_graph_generator_test",
srcs = ["tensorflow_session_from_frozen_graph_generator_test.cc"],
data = [":test_frozen_graph"],
linkstatic = 1,
deps = [
":tensorflow_inference_calculator",
":tensorflow_session",
":tensorflow_session_from_frozen_graph_generator",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:testlib",
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
"@org_tensorflow//tensorflow/core/kernels:math",
],
)
cc_test(
name = "tensorflow_session_from_saved_model_generator_test",
srcs = ["tensorflow_session_from_saved_model_generator_test.cc"],
data = [":test_saved_model"],
linkstatic = 1,
deps = [
":tensorflow_inference_calculator",
":tensorflow_session",
":tensorflow_session_from_saved_model_generator",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_generator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core/kernels:array",
"@org_tensorflow//tensorflow/core/kernels:bitcast_op",
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
"@org_tensorflow//tensorflow/core/kernels:io",
"@org_tensorflow//tensorflow/core/kernels:state",
"@org_tensorflow//tensorflow/core/kernels:string",
],
)
cc_test(
name = "tensorflow_session_from_saved_model_calculator_test",
srcs = ["tensorflow_session_from_saved_model_calculator_test.cc"],
data = [":test_saved_model"],
linkstatic = 1,
deps = [
":tensorflow_inference_calculator",
":tensorflow_session",
":tensorflow_session_from_saved_model_calculator",
":tensorflow_session_from_saved_model_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core/kernels:array",
"@org_tensorflow//tensorflow/core/kernels:bitcast_op",
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
"@org_tensorflow//tensorflow/core/kernels:io",
"@org_tensorflow//tensorflow/core/kernels:state",
"@org_tensorflow//tensorflow/core/kernels:string",
],
)
cc_test(
name = "tensor_squeeze_dimensions_calculator_test",
srcs = ["tensor_squeeze_dimensions_calculator_test.cc"],
deps = [
":tensor_squeeze_dimensions_calculator",
"//mediapipe/calculators/tensorflow:tensor_squeeze_dimensions_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "tensor_to_image_frame_calculator_test",
size = "small",
srcs = ["tensor_to_image_frame_calculator_test.cc"],
deps = [
":tensor_to_image_frame_calculator",
"//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "tensor_to_matrix_calculator_test",
size = "small",
srcs = ["tensor_to_matrix_calculator_test.cc"],
deps = [
":tensor_to_matrix_calculator",
"//mediapipe/calculators/tensorflow:tensor_to_matrix_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "tensor_to_vector_float_calculator_test",
srcs = ["tensor_to_vector_float_calculator_test.cc"],
deps = [
":tensor_to_vector_float_calculator",
"//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "unpack_media_sequence_calculator_test",
srcs = ["unpack_media_sequence_calculator_test.cc"],
deps = [
":unpack_media_sequence_calculator",
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:rectangle",
"//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "vector_float_to_tensor_calculator_test",
srcs = ["vector_float_to_tensor_calculator_test.cc"],
deps = [
":vector_float_to_tensor_calculator",
"//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
test_suite(
name = "ios",
tags = ["ios"],
)
test_suite(
name = "android",
tags = ["android"],
)
cc_test(
name = "tensorflow_inference_calculator_test",
size = "medium",
srcs = ["tensorflow_inference_calculator_test.cc"],
data = [":test_frozen_graph"],
linkstatic = 1,
deps = [
":tensorflow_session",
":tensorflow_inference_calculator",
":tensorflow_session_from_frozen_graph_generator",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/tool:sink",
"//mediapipe/framework/tool:validate_type",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:testlib",
"@org_tensorflow//tensorflow/core/kernels:math",
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
"@org_tensorflow//tensorflow/core:direct_session",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_with_ops_lite_proto_no_rtti_lib",
],
}),
)
filegroup(
name = "test_frozen_graph",
srcs = [
"testdata/bundle/00000000/checkpoint",
"testdata/bundle/00000000/export.meta",
"testdata/bundle/00000000/export-00000-of-00001",
"testdata/frozen_graph_def.pb",
"testdata/model.chkpt-0",
"testdata/model.chkpt-0.meta",
"testdata/tf_graph_def.pb",
],
)
filegroup(
name = "test_saved_model",
srcs = [
"testdata/tensorflow_saved_model/00000000/saved_model.pb",
"testdata/tensorflow_saved_model/00000000/variables/variables.data-00000-of-00001",
"testdata/tensorflow_saved_model/00000000/variables/variables.index",
],
)


@@ -0,0 +1,73 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Generates row tensors of prescribed length that are initialized to zeros.
// The tensors are placed in an ordered map, which maps the tensors to the
// tensor tags, and emitted as a packet. This generator has been developed
// primarily to generate initialization states for LSTMs.
#include <map>
#include <string>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tensorflow/graph_tensors_packet_generator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/core/framework/tensor.h"
namespace mediapipe {
namespace tf = ::tensorflow;
class GraphTensorsPacketGenerator : public PacketGenerator {
public:
static ::mediapipe::Status FillExpectations(
const PacketGeneratorOptions& extendable_options,
PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
RET_CHECK(extendable_options.HasExtension(
GraphTensorsPacketGeneratorOptions::ext));
const auto& options = extendable_options.GetExtension( // NOLINT
GraphTensorsPacketGeneratorOptions::ext);
output_side_packets->Index(0)
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>(
/* "A map of tensor tags and tensors" */);
RET_CHECK_EQ(options.tensor_tag_size(), options.tensor_num_nodes_size());
RET_CHECK_GT(options.tensor_tag_size(), 0);
return ::mediapipe::OkStatus();
}
static ::mediapipe::Status Generate(
const PacketGeneratorOptions& packet_generator_options,
const PacketSet& input_side_packets, PacketSet* output_side_packets) {
const GraphTensorsPacketGeneratorOptions& options =
packet_generator_options.GetExtension(
GraphTensorsPacketGeneratorOptions::ext);
// Output bundle packet.
auto tensor_map = absl::make_unique<std::map<std::string, tf::Tensor>>();
for (int i = 0; i < options.tensor_tag_size(); ++i) {
const std::string& tensor_tag = options.tensor_tag(i);
const int32 tensor_num_nodes = options.tensor_num_nodes(i);
(*tensor_map)[tensor_tag] =
tf::Tensor(tf::DT_FLOAT, tf::TensorShape{1, tensor_num_nodes});
(*tensor_map)[tensor_tag].flat<float>().setZero();
}
output_side_packets->Index(0) = AdoptAsUniquePtr(tensor_map.release());
return ::mediapipe::OkStatus();
}
};
REGISTER_PACKET_GENERATOR(GraphTensorsPacketGenerator);
} // namespace mediapipe


@@ -0,0 +1,50 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/packet_generator.proto";
message GraphTensorsPacketGeneratorOptions {
extend mediapipe.PacketGeneratorOptions {
optional GraphTensorsPacketGeneratorOptions ext = 142721046;
}
// Names of tensor tags for each of the generated tensors.
// Examples are: "STATE_C_0" or "STATE_M_0".
repeated string tensor_tag = 1;
// Must be the same length as tensor_tag: each tag is paired with the number
// of nodes in its tensor.
// Tags must be capitalized, matching regex [A-Z0-9_]+. Examples: "STATE_C_0"
// and "STATE_M_0". Those tags can then be used as the MediaPipe tags of the
// tensors to be initialized in the TensorflowInferenceCalculator consuming
// the packet produced by this generator. For example, a MediaPipe graph
// can include the node:
// packet_generator {
//   packet_generator: "GraphTensorsPacketGenerator"
//   output_side_packet: "init_tensors"
//   options {
//     [mediapipe.GraphTensorsPacketGeneratorOptions.ext]: {
//       tensor_tag: "STATE_C_0"
//       tensor_num_nodes: 128
//       tensor_tag: "STATE_M_0"
//       tensor_num_nodes: 128
//     }
//   }
// }
repeated int32 tensor_num_nodes = 2;
}


@@ -0,0 +1,82 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <string>
#include "mediapipe/calculators/tensorflow/graph_tensors_packet_generator.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_generator.pb.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
#include "tensorflow/core/framework/tensor.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
// Helper function that creates a row tensor that is initialized to zeros.
tf::Tensor ZeroRowTensor(const int col_length) {
tf::Tensor tensor(tf::DT_FLOAT, tf::TensorShape{1, col_length});
tensor.flat<float>().setZero();
return tensor;
}
class GraphTensorsPacketGeneratorTest : public ::testing::Test {
protected:
void SetUp() override {
extendable_options_.Clear();
generator_options_ = extendable_options_.MutableExtension(
GraphTensorsPacketGeneratorOptions::ext);
generator_options_->add_tensor_tag("A");
generator_options_->add_tensor_num_nodes(3);
generator_options_->add_tensor_tag("B");
generator_options_->add_tensor_num_nodes(4);
}
void VerifyTensorMap(PacketSet* output_side_packets) {
const std::map<std::string, tf::Tensor>* tensor_map =
GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
output_side_packets->Index(0));
EXPECT_FALSE(tensor_map->find("A") == tensor_map->end());
EXPECT_FALSE(tensor_map->find("B") == tensor_map->end());
tf::Tensor expected_tensor = ZeroRowTensor(3);
EXPECT_EQ(expected_tensor.DebugString(),
tensor_map->find("A")->second.DebugString());
expected_tensor = ZeroRowTensor(4);
EXPECT_EQ(expected_tensor.DebugString(),
tensor_map->find("B")->second.DebugString());
}
PacketGeneratorOptions extendable_options_;
GraphTensorsPacketGeneratorOptions* generator_options_;
};
// Test that the tensors are of the right size and shape
TEST_F(GraphTensorsPacketGeneratorTest, VerifyTensorSizeShapeAndValue) {
PacketSet inputs({});
PacketSet outputs(1);
::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
"GraphTensorsPacketGenerator", extendable_options_, inputs, &outputs);
MEDIAPIPE_EXPECT_OK(run_status) << run_status.message();
VerifyTensorMap(&outputs);
}
} // namespace
} // namespace mediapipe


@@ -0,0 +1,180 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
// Converts the ImageFrame into a Tensor with floating-point values. Each
// value is normalized based on the given mean and stddev.
std::unique_ptr<tf::Tensor> ImageFrameToNormalizedTensor(
const ImageFrame& image_frame, float mean, float stddev) {
const int cols = image_frame.Width();
const int rows = image_frame.Height();
const int channels = image_frame.NumberOfChannels();
const uint8* pixel = image_frame.PixelData();
const int width_padding = image_frame.WidthStep() - cols * channels;
auto tensor = ::absl::make_unique<tf::Tensor>(
tf::DT_FLOAT, tf::TensorShape({rows, cols, channels}));
auto tensor_data = tensor->tensor<float, 3>();
for (int row = 0; row < rows; ++row) {
for (int col = 0; col < cols; ++col) {
for (int channel = 0; channel < channels; ++channel) {
tensor_data(row, col, channel) = (pixel[channel] - mean) / stddev;
}
pixel += channels;
}
pixel += width_padding;
}
return tensor;
}
} // namespace
// Converts ImageFrames to TensorFlow Tensors.
//
// The calculator expects one input (a packet containing an ImageFrame) and
// generates one output (a packet containing a tf::Tensor holding the same
// pixel data). The output tensor will be 3D with dimensions corresponding to
// height, width, and the number of channels (e.g. 3 for RGB or 1 for GRAY8).
//
// This calculator supports ImageFrame objects with any valid format (SRGB,
// SRGBA, GRAY8, GRAY16, and VEC32F1). It will generate a Tensor using DT_UINT8
// for the first three types, DT_UINT16 for GRAY16, and DT_FLOAT for VEC32F1.
//
// The ImageFrame data can be packed or padded. The pixel data will be copied
// to the Tensor in row-major order.
//
// Example config:
// node {
// calculator: "ImageFrameToTensorCalculator"
// input_stream: "scaled_frames"
// output_stream: "video_tensors"
// }
class ImageFrameToTensorCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
ImageFrameToTensorCalculatorOptions options_;
};
REGISTER_CALCULATOR(ImageFrameToTensorCalculator);
::mediapipe::Status ImageFrameToTensorCalculator::GetContract(
CalculatorContract* cc) {
// Exactly one input stream is supported.
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
cc->Inputs().Index(0).Set<ImageFrame>(
// ImageFrame frame.
);
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
cc->Outputs().Index(0).Set<tf::Tensor>(
// Output TensorFlow Tensor.
);
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImageFrameToTensorCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<ImageFrameToTensorCalculatorOptions>();
// Inform the framework that we always output at the same timestamp
// as the input packet.
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status ImageFrameToTensorCalculator::Process(
CalculatorContext* cc) {
const Packet& input_item = cc->Inputs().Index(0).Value();
RET_CHECK(!input_item.IsEmpty()) << "Input cannot be empty.";
// Extract the ImageFrame and metadata from the input packet.
const ImageFrame& video_frame = input_item.Get<ImageFrame>();
const int bytes_per_pixel = video_frame.ByteDepth();
std::unique_ptr<tf::Tensor> tensor;
if (options_.has_data_type()) {
RET_CHECK_EQ(bytes_per_pixel, 1) << "Unsupported image format ("
<< bytes_per_pixel << " bytes per pixel)";
const tf::DataType data_type = options_.data_type();
RET_CHECK_EQ(data_type, tf::DT_FLOAT)
<< "Unsupported data type " << data_type;
RET_CHECK_GT(options_.stddev(), 0.0f);
tensor = ImageFrameToNormalizedTensor(video_frame, options_.mean(),
options_.stddev());
} else {
const int height = video_frame.Height();
const int width = video_frame.Width();
const int num_channels = video_frame.NumberOfChannels();
const int num_components = width * height * num_channels;
tf::TensorShape tensor_shape({height, width, num_channels});
// Use uint8, uint16, or float as the TF type depending on the ImageFrame's
// bytes per pixel.
tf::DataType data_type;
if (bytes_per_pixel == 1) {
data_type = tf::DT_UINT8;
} else if (bytes_per_pixel == 2) {
data_type = tf::DT_UINT16;
} else if (bytes_per_pixel == 4) {
data_type = tf::DT_FLOAT;
} else {
return ::mediapipe::InvalidArgumentError(absl::StrCat(
"Unsupported image format (", bytes_per_pixel, " bytes per pixel)"));
}
// This failure should never trigger, but it protects the code against
// internal TF changes.
RET_CHECK(tf::DataTypeCanUseMemcpy(data_type))
<< "Tensor data type does not support memcpy (type=" << data_type
<< ")";
// Create the output tensor.
tensor = ::absl::make_unique<tf::Tensor>(data_type, tensor_shape);
// Copy pixel data from the ImageFrame to the tensor.
if (data_type == tf::DT_UINT8) {
uint8* dst = tensor->flat<uint8>().data();
video_frame.CopyToBuffer(dst, num_components);
} else if (data_type == tf::DT_UINT16) {
uint16* dst = tensor->flat<uint16>().data();
video_frame.CopyToBuffer(dst, num_components);
} else {
float* dst = tensor->flat<float>().data();
video_frame.CopyToBuffer(dst, num_components);
}
}
cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
} // namespace mediapipe


@@ -0,0 +1,37 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "tensorflow/core/framework/types.proto";
message ImageFrameToTensorCalculatorOptions {
extend CalculatorOptions {
optional ImageFrameToTensorCalculatorOptions ext = 120603667;
}
// If set, the output tensor will be of data type specified by this field.
// Otherwise, the output tensor data type is equal to that of the input image
// frame.
optional tensorflow.DataType data_type = 1;
// If set, the output tensor T is equal to (F - mean * J) / stddev, where F
// and J are the input image frame and the all-ones matrix of the same size,
// respectively. Otherwise, T is equal to F.
optional float mean = 2;
optional float stddev = 3;
}


@@ -0,0 +1,457 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include <cstring>
#include <limits>
#include <random>
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace tf = ::tensorflow;
using RandomEngine = std::mt19937_64;
const uint8 kGray8 = 42;
const uint16 kGray16 = 4242;
const float kFloat = 42.0;
const uint kRed = 255;
const uint kGreen = 36;
const uint kBlue = 156;
const uint kAlpha = 42;
const int kFixedNoiseWidth = 3;
const int kFixedNoiseHeight = 2;
const uint8 kFixedNoiseData[kFixedNoiseWidth * kFixedNoiseHeight * 3] = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 123, 213, 156, 9, 10, 11, 255, 0, 128};
class ImageFrameToTensorCalculatorTest : public ::testing::Test {
protected:
// Set image_frame to a constant per-channel pix_value.
template <class T>
void SetToColor(const T* pix_value, ImageFrame* image_frame) {
const int cols = image_frame->Width();
const int rows = image_frame->Height();
const int channels = image_frame->NumberOfChannels();
const int width_padding =
image_frame->WidthStep() / (sizeof(T)) - cols * channels;
T* pixel = reinterpret_cast<T*>(image_frame->MutablePixelData());
for (int row = 0; row < rows; ++row) {
for (int col = 0; col < cols; ++col) {
for (int channel = 0; channel < channels; ++channel) {
pixel[channel] = pix_value[channel];
}
pixel += channels;
}
pixel += width_padding;
}
}
// Adds a packet with a solid red 8-bit RGB ImageFrame.
void AddRGBFrame(int width, int height) {
auto image_frame =
::absl::make_unique<ImageFrame>(ImageFormat::SRGB, width, height);
const uint8 color[] = {kRed, kGreen, kBlue};
SetToColor<uint8>(color, image_frame.get());
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(image_frame.release()).At(Timestamp(0)));
}
// Adds a packet with a solid red 8-bit RGBA ImageFrame.
void AddRGBAFrame(int width, int height) {
auto image_frame =
::absl::make_unique<ImageFrame>(ImageFormat::SRGBA, width, height);
const uint8 color[] = {kRed, kGreen, kBlue, kAlpha};
SetToColor<uint8>(color, image_frame.get());
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(image_frame.release()).At(Timestamp(0)));
}
// Adds a packet with a solid GRAY8 ImageFrame.
void AddGray8Frame(int width, int height) {
auto image_frame =
::absl::make_unique<ImageFrame>(ImageFormat::GRAY8, width, height);
const uint8 gray[] = {kGray8};
SetToColor<uint8>(gray, image_frame.get());
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(image_frame.release()).At(Timestamp(0)));
}
// Adds a packet with a solid GRAY16 ImageFrame.
void AddGray16Frame(int width, int height) {
auto image_frame =
::absl::make_unique<ImageFrame>(ImageFormat::GRAY16, width, height, 1);
const uint16 gray[] = {kGray16};
SetToColor<uint16>(gray, image_frame.get());
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(image_frame.release()).At(Timestamp(0)));
}
// Adds a packet with a solid VEC32F1 ImageFrame.
void AddFloatFrame(int width, int height) {
auto image_frame =
::absl::make_unique<ImageFrame>(ImageFormat::VEC32F1, width, height, 1);
const float gray[] = {kFloat};
SetToColor<float>(gray, image_frame.get());
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(image_frame.release()).At(Timestamp(0)));
}
// Adds a packet with an 8-bit RGB ImageFrame containing pre-determined noise.
void AddFixedNoiseRGBFrame() {
auto image_frame = ::absl::make_unique<ImageFrame>(
ImageFormat::SRGB, kFixedNoiseWidth, kFixedNoiseHeight);
// Copy fixed noise data into the ImageFrame.
const uint8* src = kFixedNoiseData;
uint8* pixels = image_frame->MutablePixelData();
for (int y = 0; y < kFixedNoiseHeight; ++y) {
uint8* row = pixels + y * image_frame->WidthStep();
std::memcpy(row, src, kFixedNoiseWidth * 3);
src += kFixedNoiseWidth * 3;
}
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(image_frame.release()).At(Timestamp(0)));
}
// Adds a packet with an 8-bit RGB ImageFrame containing random noise.
void AddRandomRGBFrame(int width, int height, uint32 seed) {
RandomEngine random(seed);
std::uniform_int_distribution<int> uniform_dist{
0, std::numeric_limits<uint8_t>::max()};
auto image_frame =
::absl::make_unique<ImageFrame>(ImageFormat::SRGB, width, height);
// Copy "noisy data" into the ImageFrame.
const int num_components_per_row = width * image_frame->NumberOfChannels();
uint8* pixels = image_frame->MutablePixelData();
for (int y = 0; y < height; ++y) {
uint8* p = pixels + y * image_frame->WidthStep();
for (int i = 0; i < num_components_per_row; ++i) {
p[i] = uniform_dist(random);
}
}
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(image_frame.release()).At(Timestamp(0)));
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(ImageFrameToTensorCalculatorTest, SolidRedRGBFrame) {
// Check two widths to cover packed and padded ImageFrame.
const int num_widths = 2;
const int widths[num_widths] = {10, 24};
const int height = 5;
for (int width_index = 0; width_index < num_widths; ++width_index) {
const int width = widths[width_index];
const int num_pixels = width * height;
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>(
"ImageFrameToTensorCalculator", "", 1, 1, 0);
AddRGBFrame(width, height);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that tensor is 3-dimensional
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(3, tensor.dims());
ASSERT_EQ(tf::DT_UINT8, tensor.dtype());
// Verify that each dimension has the correct size / number of channels.
const tf::TensorShape& shape = tensor.shape();
ASSERT_EQ(height, shape.dim_size(0));
ASSERT_EQ(width, shape.dim_size(1));
ASSERT_EQ(3, shape.dim_size(2));
// Verify that the data in the tensor is correct.
const uint8* pixels =
reinterpret_cast<const uint8*>(tensor.tensor_data().data());
for (int i = 0; i < num_pixels; ++i) {
ASSERT_EQ(kRed, pixels[0]);
ASSERT_EQ(kGreen, pixels[1]);
ASSERT_EQ(kBlue, pixels[2]);
pixels += 3;
}
}
}
TEST_F(ImageFrameToTensorCalculatorTest, SolidRedRGBAFrame) {
// Check two widths to cover packed and padded ImageFrame.
const int num_widths = 2;
const int widths[num_widths] = {10, 24};
const int height = 5;
for (int width_index = 0; width_index < num_widths; ++width_index) {
const int width = widths[width_index];
const int num_pixels = width * height;
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>(
"ImageFrameToTensorCalculator", "", 1, 1, 0);
AddRGBAFrame(width, height);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that tensor is 3-dimensional
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(3, tensor.dims());
ASSERT_EQ(tf::DT_UINT8, tensor.dtype());
// Verify that each dimension has the correct size / number of channels.
const tf::TensorShape& shape = tensor.shape();
ASSERT_EQ(height, shape.dim_size(0));
ASSERT_EQ(width, shape.dim_size(1));
ASSERT_EQ(4, shape.dim_size(2));
// Verify that the data in the tensor is correct.
const uint8* pixels =
reinterpret_cast<const uint8*>(tensor.tensor_data().data());
for (int i = 0; i < num_pixels; ++i) {
ASSERT_EQ(kRed, pixels[0]);
ASSERT_EQ(kGreen, pixels[1]);
ASSERT_EQ(kBlue, pixels[2]);
ASSERT_EQ(kAlpha, pixels[3]);
pixels += 4;
}
}
}
TEST_F(ImageFrameToTensorCalculatorTest, SolidGray8Frame) {
// Check two widths to cover packed and padded ImageFrame.
const int num_widths = 2;
const int widths[num_widths] = {10, 24};
const int height = 5;
for (int width_index = 0; width_index < num_widths; ++width_index) {
const int width = widths[width_index];
const int num_pixels = width * height;
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>(
"ImageFrameToTensorCalculator", "", 1, 1, 0);
AddGray8Frame(width, height);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that tensor is 3-dimensional
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(3, tensor.dims());
ASSERT_EQ(tf::DT_UINT8, tensor.dtype());
// Verify that each dimension has the correct size / number of channels.
const tf::TensorShape& shape = tensor.shape();
ASSERT_EQ(height, shape.dim_size(0));
ASSERT_EQ(width, shape.dim_size(1));
ASSERT_EQ(1, shape.dim_size(2));
// Verify that the data in the tensor is correct.
const uint8* pixels =
reinterpret_cast<const uint8*>(tensor.tensor_data().data());
for (int i = 0; i < num_pixels; ++i) {
ASSERT_EQ(kGray8, pixels[0]);
++pixels;
}
}
}
TEST_F(ImageFrameToTensorCalculatorTest, SolidGray16Frame) {
// Check two widths to cover packed and padded ImageFrame.
const int num_widths = 2;
const int widths[num_widths] = {10, 24};
const int height = 5;
for (int width_index = 0; width_index < num_widths; ++width_index) {
const int width = widths[width_index];
const int num_pixels = width * height;
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>(
"ImageFrameToTensorCalculator", "", 1, 1, 0);
AddGray16Frame(width, height);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that tensor is 3-dimensional
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(3, tensor.dims());
ASSERT_EQ(tf::DT_UINT16, tensor.dtype());
// Verify that each dimension has the correct size / number of channels.
const tf::TensorShape& shape = tensor.shape();
ASSERT_EQ(height, shape.dim_size(0));
ASSERT_EQ(width, shape.dim_size(1));
ASSERT_EQ(1, shape.dim_size(2));
// Verify that the data in the tensor is correct.
const uint16* pixels =
reinterpret_cast<const uint16*>(tensor.tensor_data().data());
for (int i = 0; i < num_pixels; ++i) {
ASSERT_EQ(kGray16, pixels[0]);
++pixels;
}
}
}
TEST_F(ImageFrameToTensorCalculatorTest, SolidFloatFrame) {
// Check two widths to cover packed and padded ImageFrame.
const int num_widths = 2;
const int widths[num_widths] = {10, 24};
const int height = 5;
for (int width_index = 0; width_index < num_widths; ++width_index) {
const int width = widths[width_index];
const int num_pixels = width * height;
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>(
"ImageFrameToTensorCalculator", "", 1, 1, 0);
AddFloatFrame(width, height);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that tensor is 3-dimensional
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(3, tensor.dims());
ASSERT_EQ(tf::DT_FLOAT, tensor.dtype());
// Verify that each dimension has the correct size / number of channels.
const tf::TensorShape& shape = tensor.shape();
ASSERT_EQ(height, shape.dim_size(0));
ASSERT_EQ(width, shape.dim_size(1));
ASSERT_EQ(1, shape.dim_size(2));
// Verify that the data in the tensor is correct.
const float* pixels =
reinterpret_cast<const float*>(tensor.tensor_data().data());
for (int i = 0; i < num_pixels; ++i) {
ASSERT_EQ(kFloat, pixels[0]);
++pixels;
}
}
}
TEST_F(ImageFrameToTensorCalculatorTest, FixedNoiseRGBFrame) {
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>(
"ImageFrameToTensorCalculator", "", 1, 1, 0);
AddFixedNoiseRGBFrame();
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that tensor is 3-dimensional
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(3, tensor.dims());
ASSERT_EQ(tf::DT_UINT8, tensor.dtype());
// Verify that each dimension has the correct size / number of channels.
const tf::TensorShape& shape = tensor.shape();
ASSERT_EQ(kFixedNoiseHeight, shape.dim_size(0));
ASSERT_EQ(kFixedNoiseWidth, shape.dim_size(1));
ASSERT_EQ(3, shape.dim_size(2));
// Verify that the data in the tensor is correct.
const int num_pixels = kFixedNoiseWidth * kFixedNoiseHeight;
const uint8* pixels =
reinterpret_cast<const uint8*>(tensor.tensor_data().data());
for (int i = 0; i < num_pixels; ++i) {
ASSERT_EQ(kFixedNoiseData[i], pixels[i]);
}
}
TEST_F(ImageFrameToTensorCalculatorTest, RandomRGBFrame) {
// Run the calculator and verify that one output is generated.
const uint32 seed = 1234;
const int height = 2;
for (int width = 1; width <= 33; ++width) {
runner_ = ::absl::make_unique<CalculatorRunner>(
"ImageFrameToTensorCalculator", "", 1, 1, 0);
AddRandomRGBFrame(width, height, seed);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that tensor is 3-dimensional
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(3, tensor.dims());
ASSERT_EQ(tf::DT_UINT8, tensor.dtype());
// Verify that each dimension has the correct size / number of channels.
const tf::TensorShape& shape = tensor.shape();
ASSERT_EQ(height, shape.dim_size(0));
ASSERT_EQ(width, shape.dim_size(1));
ASSERT_EQ(3, shape.dim_size(2));
// Verify that the data in the tensor is correct.
RandomEngine random(seed);
std::uniform_int_distribution<int> uniform_dist{
0, std::numeric_limits<uint8_t>::max()};
const int num_pixels = width * height;
const uint8* pixels =
reinterpret_cast<const uint8*>(tensor.tensor_data().data());
for (int i = 0; i < num_pixels; ++i) {
const uint8 expected = uniform_dist(random);
ASSERT_EQ(expected, pixels[i]);
}
}
}
TEST_F(ImageFrameToTensorCalculatorTest, FixedRGBFrameWithMeanAndStddev) {
runner_ = ::absl::make_unique<CalculatorRunner>(
"ImageFrameToTensorCalculator",
"[mediapipe.ImageFrameToTensorCalculatorOptions.ext]"
"{data_type:DT_FLOAT mean:128.0 stddev:128.0}",
1, 1, 0);
// Create a single pixel image of fixed color #0080ff.
auto image_frame = ::absl::make_unique<ImageFrame>(ImageFormat::SRGB, 1, 1);
const uint8 color[] = {0, 128, 255};
SetToColor<uint8>(color, image_frame.get());
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(image_frame.release()).At(Timestamp(0)));
MEDIAPIPE_ASSERT_OK(runner_->Run());
const auto& tensor = runner_->Outputs().Index(0).packets[0].Get<tf::Tensor>();
EXPECT_EQ(tensor.dtype(), tf::DT_FLOAT);
ASSERT_EQ(tensor.dims(), 3);
EXPECT_EQ(tensor.shape().dim_size(0), 1);
EXPECT_EQ(tensor.shape().dim_size(1), 1);
EXPECT_EQ(tensor.shape().dim_size(2), 3);
const float* actual = tensor.flat<float>().data();
EXPECT_EQ(actual[0], -1.0f); // ( 0 - 128) / 128
EXPECT_EQ(actual[1], 0.0f); // (128 - 128) / 128
EXPECT_EQ(actual[2], 127.0f / 128.0f); // (255 - 128) / 128
}
} // namespace mediapipe

@ -0,0 +1,154 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/profiler/circular_buffer.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
namespace mediapipe {
namespace tf = tensorflow;
// Given an input stream of tensors, concatenates the tensors over timesteps.
// The concatenated output tensors can be specified to have overlap between
// output timesteps. The tensors are concatenated along the first dimension, and
// a flag controls whether a new first dimension is inserted before
// concatenation.
//
// Currently, the number of tensors output will be smaller than the number of
// input tensors because no padding is implemented and only full buffers are
// output (the first output only appears after buffer_size inputs arrive).
//
// By default, the timestamp of the output batch matches the timestamp of the
// first tensor in that batch (e.g. when buffer_size frames are added, the
// output tensor has the timestamp of the first input). This behavior can be
// adjusted with the timestamp_offset option.
//
// Example config:
// node {
// calculator: "LappedTensorBufferCalculator"
// input_stream: "input_tensor"
// output_stream: "output_tensor"
// options {
// [mediapipe.LappedTensorBufferCalculatorOptions.ext] {
// buffer_size: 2
// overlap: 1
// add_batch_dim_to_tensors: false
// }
// }
// }
class LappedTensorBufferCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
// Adds a batch dimension to the input tensor if specified in the calculator
// options.
::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor);
int steps_until_output_;
std::unique_ptr<CircularBuffer<Timestamp>> timestamp_buffer_;
std::unique_ptr<CircularBuffer<tf::Tensor>> buffer_;
LappedTensorBufferCalculatorOptions options_;
};
REGISTER_CALCULATOR(LappedTensorBufferCalculator);
::mediapipe::Status LappedTensorBufferCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
cc->Inputs().Index(0).Set<tf::Tensor>(
// tensorflow::Tensor stream.
);
  RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
      << "Only one output stream is supported.";
cc->Outputs().Index(0).Set<tf::Tensor>(
// Output tensorflow::Tensor stream with possibly overlapping steps.
);
return ::mediapipe::OkStatus();
}
::mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<LappedTensorBufferCalculatorOptions>();
RET_CHECK_LT(options_.overlap(), options_.buffer_size());
RET_CHECK_GE(options_.timestamp_offset(), 0)
<< "Negative timestamp_offset is not allowed.";
RET_CHECK_LT(options_.timestamp_offset(), options_.buffer_size())
      << "timestamp_offset has to be less than buffer_size.";
timestamp_buffer_ =
absl::make_unique<CircularBuffer<Timestamp>>(options_.buffer_size());
buffer_ =
absl::make_unique<CircularBuffer<tf::Tensor>>(options_.buffer_size());
steps_until_output_ = options_.buffer_size();
return ::mediapipe::OkStatus();
}
::mediapipe::Status LappedTensorBufferCalculator::Process(
CalculatorContext* cc) {
// These are cheap, shallow copies.
tensorflow::Tensor input_tensor(
cc->Inputs().Index(0).Get<tensorflow::Tensor>());
if (options_.add_batch_dim_to_tensors()) {
RET_CHECK_OK(AddBatchDimension(&input_tensor));
}
buffer_->push_back(input_tensor);
timestamp_buffer_->push_back(cc->InputTimestamp());
--steps_until_output_;
if (steps_until_output_ <= 0) {
auto concatenated = ::absl::make_unique<tf::Tensor>();
const tf::Status concat_status = tf::tensor::Concat(
std::vector<tf::Tensor>(buffer_->begin(), buffer_->end()),
concatenated.get());
RET_CHECK(concat_status.ok()) << concat_status.ToString();
cc->Outputs().Index(0).Add(
concatenated.release(),
timestamp_buffer_->Get(options_.timestamp_offset()));
steps_until_output_ = options_.buffer_size() - options_.overlap();
}
return ::mediapipe::OkStatus();
}
// Adds a batch dimension to the input tensor if specified in the calculator
// options.
::mediapipe::Status LappedTensorBufferCalculator::AddBatchDimension(
tf::Tensor* input_tensor) {
if (options_.add_batch_dim_to_tensors()) {
tf::TensorShape new_shape(input_tensor->shape());
new_shape.InsertDim(0, 1);
RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
<< "Could not add 0th dimension to tensor without changing its shape."
<< " Current shape: " << input_tensor->shape().DebugString();
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

@ -0,0 +1,48 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message LappedTensorBufferCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional LappedTensorBufferCalculatorOptions ext = 175222228;
}
// The output tensor will have this many tensors concatenated together along
// their first dimension.
optional int32 buffer_size = 1;
// The overlap determines how many input tensors are shared between frames.
// Because the input tensors may have a non-singleton first dimension, this
// is not necessarily the number of overlapping entries in the first
// dimension.
optional int32 overlap = 2;
// If true, inserts a singleton first dimension before concatenating the
// tensors together.
optional bool add_batch_dim_to_tensors = 3 [default = false];
// Timestamp offset for output batch. The valid range is [0, buffer_size).
// The timestamp of the output tensor will match the timestamp of the input
// corresponding to the offset. For example, setting it to 0 outputs at the
// timestamp matching the first input tensor, and setting it to
// int((N-1) / 2) outputs at the timestamp matching the middle input tensor.
// This is useful for aligning the timestamp to be centered on the input
// range.
optional int32 timestamp_offset = 4 [default = 0];
}

@ -0,0 +1,240 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
class LappedTensorBufferCalculatorTest : public ::testing::Test {
protected:
void SetUpCalculator(int buffer_size, int overlap, bool add_dim,
int timestamp_offset) {
CalculatorGraphConfig::Node config;
config.set_calculator("LappedTensorBufferCalculator");
config.add_input_stream("input_tensor");
config.add_output_stream("output_tensor");
auto options = config.mutable_options()->MutableExtension(
LappedTensorBufferCalculatorOptions::ext);
options->set_buffer_size(buffer_size);
options->set_overlap(overlap);
if (add_dim) {
options->set_add_batch_dim_to_tensors(true);
}
options->set_timestamp_offset(timestamp_offset);
runner_ = ::absl::make_unique<CalculatorRunner>(config);
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(LappedTensorBufferCalculatorTest, OneToOne) {
SetUpCalculator(1, 0, false, 0);
int num_timesteps = 3;
for (int i = 0; i < num_timesteps; ++i) {
auto input = ::absl::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
input->tensor<float, 1>()(0) = i;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(i)));
}
ASSERT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(num_timesteps, output_packets.size());
for (int i = 0; i < num_timesteps; ++i) {
float value = output_packets[i].Get<tf::Tensor>().tensor<float, 1>()(0);
ASSERT_NEAR(i, value, 0.0001);
}
}
TEST_F(LappedTensorBufferCalculatorTest, OneToTwo) {
int buffer_size = 2;
int overlap = 1;
bool add_dim = false;
SetUpCalculator(buffer_size, overlap, add_dim, 0);
int num_timesteps = 3;
for (int i = 0; i < num_timesteps; ++i) {
auto input = ::absl::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
input->tensor<float, 1>()(0) = i;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(i)));
}
ASSERT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size());
for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) {
for (int j = 0; j < buffer_size; ++j) {
float value = output_packets[i].Get<tf::Tensor>().tensor<float, 1>()(j);
ASSERT_NEAR(i + j, value, 0.0001);
}
}
}
TEST_F(LappedTensorBufferCalculatorTest, OneToThree) {
int buffer_size = 3;
int overlap = 2;
bool add_dim = false;
SetUpCalculator(buffer_size, overlap, add_dim, 0);
int num_timesteps = 3;
for (int i = 0; i < num_timesteps; ++i) {
auto input = ::absl::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
input->tensor<float, 1>()(0) = i;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(i)));
}
ASSERT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size());
for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) {
for (int j = 0; j < buffer_size; ++j) {
float value = output_packets[i].Get<tf::Tensor>().tensor<float, 1>()(j);
ASSERT_NEAR(i + j, value, 0.0001);
}
}
}
TEST_F(LappedTensorBufferCalculatorTest, OneToThreeSkip) {
int buffer_size = 3;
int overlap = 1;
bool add_dim = false;
SetUpCalculator(buffer_size, overlap, add_dim, 0);
int num_timesteps = 3;
for (int i = 0; i < num_timesteps; ++i) {
auto input = ::absl::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
input->tensor<float, 1>()(0) = i;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(i)));
}
ASSERT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size());
for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) {
for (int j = 0; j < buffer_size; ++j) {
float value = output_packets[i].Get<tf::Tensor>().tensor<float, 1>()(j);
ASSERT_NEAR((i * 2) + j, value, 0.0001);
}
}
}
TEST_F(LappedTensorBufferCalculatorTest, OneToThreeBatch) {
int buffer_size = 3;
int overlap = 2;
bool add_dim = true;
SetUpCalculator(buffer_size, overlap, add_dim, 0);
int num_timesteps = 3;
for (int i = 0; i < num_timesteps; ++i) {
auto input = ::absl::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
input->tensor<float, 1>()(0) = i;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(i)));
}
ASSERT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size());
for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) {
for (int j = 0; j < buffer_size; ++j) {
float value =
output_packets[i].Get<tf::Tensor>().tensor<float, 2>()(j, 0);
ASSERT_NEAR(i + j, value, 0.0001);
}
}
}
TEST_F(LappedTensorBufferCalculatorTest, NegativeTimestampOffsetFails) {
int buffer_size = 16;
int overlap = 15;
bool add_dim = true;
int timestamp_offset = -7;
SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset);
int num_timesteps = 20;
for (int i = 0; i < num_timesteps; ++i) {
auto input = ::absl::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
input->tensor<float, 1>()(0) = i;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(i)));
}
ASSERT_FALSE(runner_->Run().ok());
}
TEST_F(LappedTensorBufferCalculatorTest, OutOfRangeTimestampOffsetFails) {
int buffer_size = 16;
int overlap = 15;
bool add_dim = true;
int timestamp_offset = buffer_size;
SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset);
int num_timesteps = 20;
for (int i = 0; i < num_timesteps; ++i) {
auto input = ::absl::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
input->tensor<float, 1>()(0) = i;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(i)));
}
ASSERT_FALSE(runner_->Run().ok());
}
TEST_F(LappedTensorBufferCalculatorTest, OneToThreeBatchTimestampOffset) {
int buffer_size = 16;
int overlap = 15;
bool add_dim = true;
int timestamp_offset = 7;
SetUpCalculator(buffer_size, overlap, add_dim, timestamp_offset);
int num_timesteps = 20;
for (int i = 0; i < num_timesteps; ++i) {
auto input = ::absl::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
input->tensor<float, 1>()(0) = i;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(i)));
}
ASSERT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(num_timesteps - buffer_size + 1, output_packets.size());
for (int i = 0; i < num_timesteps - buffer_size + 1; ++i) {
for (int j = 0; j < buffer_size; ++j) {
int64 value = output_packets[i].Timestamp().Value();
ASSERT_EQ(i + timestamp_offset, value);
}
}
}
} // namespace
} // namespace mediapipe

@ -0,0 +1,157 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace {
::mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
TimeSeriesHeader* header) {
CHECK(header);
if (header_packet.IsEmpty()) {
return ::mediapipe::UnknownError("No header found.");
}
if (!header_packet.ValidateAsType<TimeSeriesHeader>().ok()) {
return ::mediapipe::UnknownError(
"Packet does not contain TimeSeriesHeader.");
}
*header = header_packet.Get<TimeSeriesHeader>();
if (header->has_sample_rate() && header->sample_rate() >= 0 &&
header->has_num_channels() && header->num_channels() >= 0) {
return ::mediapipe::OkStatus();
} else {
std::string error_message =
"TimeSeriesHeader is missing necessary fields: "
"sample_rate or num_channels, or one of their values is negative. ";
#ifndef MEDIAPIPE_MOBILE
absl::StrAppend(&error_message, "Got header:\n",
header->ShortDebugString());
#endif
return ::mediapipe::InvalidArgumentError(error_message);
}
}
} // namespace
namespace tf = tensorflow;
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
RowMajorMatrixXf;
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
ColMajorMatrixXf;
// Converts an input Matrix into a 2D or 3D tf::Tensor.
//
// The calculator expects one input (a packet containing a Matrix) and
// generates one output (a packet containing a tf::Tensor containing the same
// data). The output tensor will be 2D with dimensions corresponding to the
// input matrix, while it will be 3D if add_trailing_dimension is set to true.
// The option to make the tensor 3D is useful when combining audio and image
// features to train multimodal models, so that the numbers of tensor
// dimensions match. The tensor will hold DT_FLOAT values.
//
// Example config:
// node {
// calculator: "MatrixToTensorCalculator"
// input_stream: "matrix_features"
// output_stream: "tensor_features"
// }
class MatrixToTensorCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
MatrixToTensorCalculatorOptions options_;
};
REGISTER_CALCULATOR(MatrixToTensorCalculator);
::mediapipe::Status MatrixToTensorCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
cc->Inputs().Index(0).Set<Matrix>(
// Input Matrix stream with optional TimeSeriesHeader.
);
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
cc->Outputs().Index(0).Set<tf::Tensor>(
// Output stream with data as tf::Tensor and the same
// TimeSeriesHeader as the input (or no header if the input has no
// header).
);
return ::mediapipe::OkStatus();
}
::mediapipe::Status MatrixToTensorCalculator::Open(CalculatorContext* cc) {
// If the input is part of a time series, then preserve the header so that
// downstream consumers can access the sample rate if needed.
options_ = cc->Options<MatrixToTensorCalculatorOptions>();
auto input_header = ::absl::make_unique<TimeSeriesHeader>();
const ::mediapipe::Status header_status = FillTimeSeriesHeaderIfValid(
cc->Inputs().Index(0).Header(), input_header.get());
if (header_status.ok()) {
cc->Outputs().Index(0).SetHeader(Adopt(input_header.release()));
}
// Inform the framework that we always output at the same timestamp
// as we receive a packet at.
cc->SetOffset(mediapipe::TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status MatrixToTensorCalculator::Process(CalculatorContext* cc) {
const Matrix& matrix = cc->Inputs().Index(0).Get<Matrix>();
tf::TensorShape tensor_shape;
if (options_.transpose()) {
tensor_shape = tf::TensorShape({matrix.cols(), matrix.rows()});
} else {
tensor_shape = tf::TensorShape({matrix.rows(), matrix.cols()});
}
auto tensor = ::absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
float* tensor_data = tensor->flat<float>().data();
if (options_.transpose()) {
auto matrix_map =
Eigen::Map<ColMajorMatrixXf>(tensor_data, matrix.rows(), matrix.cols());
matrix_map = matrix;
} else {
auto matrix_map =
Eigen::Map<RowMajorMatrixXf>(tensor_data, matrix.rows(), matrix.cols());
matrix_map = matrix;
}
if (options_.add_trailing_dimension()) {
tf::TensorShape new_shape(tensor_shape);
new_shape.AddDim(1 /* size of dimension */);
RET_CHECK(tensor->CopyFrom(*tensor, new_shape))
<< "Could not add dimension to tensor without changing its shape."
<< " Current shape: " << tensor->shape().DebugString();
}
cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

@ -0,0 +1,29 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message MatrixToTensorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional MatrixToTensorCalculatorOptions ext = 130781699;
}
optional bool transpose = 1 [default = false];
// Adds a 3rd dimension of size 1 when this is set to true.
optional bool add_trailing_dimension = 2 [default = false];
}

@ -0,0 +1,165 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace {
constexpr char kTransposeOptionsString[] =
"[mediapipe.MatrixToTensorCalculatorOptions.ext]: {"
"transpose: True}";
constexpr char kAddDimensionOptionsString[] =
"[mediapipe.MatrixToTensorCalculatorOptions.ext]: {"
"add_trailing_dimension: True}";
} // namespace
namespace tf = tensorflow;
using RandomEngine = std::mt19937_64;
const uint32 kSeed = 1234;
const int kNumSizes = 8;
const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
{5, 3}, {7, 13}, {16, 32}, {101, 2}};
class MatrixToTensorCalculatorTest : public ::testing::Test {
protected:
// Adds a packet with a matrix filled with random values in [0,1].
void AddRandomMatrix(int num_rows, int num_columns, uint32 seed) {
    RandomEngine random(seed);
std::uniform_real_distribution<> uniform_dist(0, 1.0);
auto matrix = ::absl::make_unique<Matrix>();
matrix->resize(num_rows, num_columns);
for (int y = 0; y < num_rows; ++y) {
for (int x = 0; x < num_columns; ++x) {
(*matrix)(y, x) = uniform_dist(random);
}
}
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(matrix.release()).At(Timestamp(0)));
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(MatrixToTensorCalculatorTest, RandomMatrix) {
for (int size_index = 0; size_index < kNumSizes; ++size_index) {
const int num_rows = sizes[size_index][0];
const int num_columns = sizes[size_index][1];
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>("MatrixToTensorCalculator",
"", 1, 1, 0);
AddRandomMatrix(num_rows, num_columns, kSeed);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that the packet contains a 2D float tensor.
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(2, tensor.dims());
ASSERT_EQ(tf::DT_FLOAT, tensor.dtype());
// Verify that the data is correct.
RandomEngine random(kSeed);
std::uniform_real_distribution<> uniform_dist(0, 1.0);
const auto matrix = tensor.matrix<float>();
for (int y = 0; y < num_rows; ++y) {
for (int x = 0; x < num_columns; ++x) {
const float expected = uniform_dist(random);
ASSERT_EQ(expected, matrix(y, x));
}
}
}
}
TEST_F(MatrixToTensorCalculatorTest, RandomMatrixTranspose) {
for (int size_index = 0; size_index < kNumSizes; ++size_index) {
const int num_rows = sizes[size_index][0];
const int num_columns = sizes[size_index][1];
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>(
"MatrixToTensorCalculator", kTransposeOptionsString, 1, 1, 0);
AddRandomMatrix(num_rows, num_columns, kSeed);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that the packet contains a 2D float tensor.
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(2, tensor.dims());
ASSERT_EQ(tf::DT_FLOAT, tensor.dtype());
// Verify that the data is correct.
RandomEngine random(kSeed);
std::uniform_real_distribution<> uniform_dist(0, 1.0);
const auto matrix = tensor.matrix<float>();
for (int y = 0; y < num_rows; ++y) {
for (int x = 0; x < num_columns; ++x) {
const float expected = uniform_dist(random);
ASSERT_EQ(expected, matrix(x, y));
}
}
}
}
TEST_F(MatrixToTensorCalculatorTest, RandomMatrixAddDimension) {
for (int size_index = 0; size_index < kNumSizes; ++size_index) {
const int num_rows = sizes[size_index][0];
const int num_columns = sizes[size_index][1];
// Run the calculator and verify that one output is generated.
runner_ = ::absl::make_unique<CalculatorRunner>(
"MatrixToTensorCalculator", kAddDimensionOptionsString, 1, 1, 0);
AddRandomMatrix(num_rows, num_columns, kSeed);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
ASSERT_EQ(1, output_packets.size());
// Verify that the packet contains a 3D float tensor.
const tf::Tensor& tensor = output_packets[0].Get<tf::Tensor>();
ASSERT_EQ(3, tensor.dims());
ASSERT_EQ(tf::DT_FLOAT, tensor.dtype());
// Verify that the data is correct.
RandomEngine random(kSeed);
std::uniform_real_distribution<> uniform_dist(0, 1.0);
const float* tensor_data = tensor.flat<float>().data();
for (int y = 0; y < num_rows; ++y) {
for (int x = 0; x < num_columns; ++x) {
const float expected = uniform_dist(random);
ASSERT_EQ(expected, tensor_data[y * num_columns + x]);
}
}
}
}
} // namespace mediapipe

@ -0,0 +1,239 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/util/tensor_to_detection.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace mediapipe {
class CalculatorOptions;
} // namespace mediapipe
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
const char kNumDetections[] = "NUM_DETECTIONS";
const char kBoxes[] = "BOXES";
const char kScores[] = "SCORES";
const char kClasses[] = "CLASSES";
const char kDetections[] = "DETECTIONS";
const char kKeypoints[] = "KEYPOINTS";
const char kMasks[] = "MASKS";
const char kLabelMap[] = "LABELMAP";
const int kNumCoordsPerBox = 4;
} // namespace
// Takes object detection results and converts them into MediaPipe Detections.
//
// Inputs are assumed to be tensors of the form:
// `num_detections` : float32 scalar tensor indicating the number of valid
// detections.
// `detection_boxes` : float32 tensor of the form [num_boxes, 4]. Format for
// coordinates is {ymin, xmin, ymax, xmax}.
// `detection_scores` : float32 tensor of the form [num_boxes].
// `detection_classes` : float32 tensor of the form [num_boxes].
// `detection_keypoints`: float32 tensor of the form
// [num_boxes, num_keypoints, 2].
// `detection_masks` : float32 tensor of the form
// [num_boxes, height, width].
//
// These are generated according to the Vale object detector model exporter,
// which may be found in
// image/understanding/object_detection/export_inference_graph.py
//
// By default, the output Detections store label ids (integers) for each
// detection. Optionally, a label map (of the form std::map<int, std::string>
// mapping label ids to label names as strings) can be made available as an
// input side packet, in which case the output Detections store
// labels as their associated std::string provided by the label map.
//
// Usage example:
// node {
// calculator: "ObjectDetectionTensorsToDetectionsCalculator"
// input_stream: "BOXES:detection_boxes_tensor"
// input_stream: "SCORES:detection_scores_tensor"
// input_stream: "CLASSES:detection_classes_tensor"
// input_stream: "NUM_DETECTIONS:num_detections_tensor"
// output_stream: "DETECTIONS:detections"
// options: {
// [mediapipe.ObjectDetectionsTensorToDetectionsCalculatorOptions.ext]: {
// tensor_dim_to_squeeze: 0
// }
// }
// }
class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase {
public:
ObjectDetectionTensorsToDetectionsCalculator() = default;
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag(kBoxes).Set<tf::Tensor>();
cc->Inputs().Tag(kScores).Set<tf::Tensor>();
if (cc->Inputs().HasTag(kNumDetections)) {
cc->Inputs().Tag(kNumDetections).Set<tf::Tensor>();
}
if (cc->Inputs().HasTag(kClasses)) {
cc->Inputs().Tag(kClasses).Set<tf::Tensor>();
}
if (cc->Inputs().HasTag(kKeypoints)) {
cc->Inputs().Tag(kKeypoints).Set<tf::Tensor>();
}
if (cc->Inputs().HasTag(kMasks)) {
cc->Inputs().Tag(kMasks).Set<tf::Tensor>();
const auto& calculator_options =
cc->Options<ObjectDetectionsTensorToDetectionsCalculatorOptions>();
float mask_threshold = calculator_options.mask_threshold();
if (!(mask_threshold >= 0.0 && mask_threshold <= 1.0)) {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "mask_threshold must be in range [0.0, 1.0]";
}
}
cc->Outputs().Tag(kDetections).Set<std::vector<Detection>>();
if (cc->InputSidePackets().HasTag(kLabelMap)) {
cc->InputSidePackets()
.Tag(kLabelMap)
.Set<std::unique_ptr<std::map<int, std::string>>>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
if (cc->InputSidePackets().HasTag(kLabelMap)) {
label_map_ = GetFromUniquePtr<std::map<int, std::string>>(
cc->InputSidePackets().Tag(kLabelMap));
}
const auto& tensor_dim_to_squeeze_field =
cc->Options<ObjectDetectionsTensorToDetectionsCalculatorOptions>()
.tensor_dim_to_squeeze();
tensor_dims_to_squeeze_ = std::vector<int32>(
tensor_dim_to_squeeze_field.begin(), tensor_dim_to_squeeze_field.end());
std::sort(tensor_dims_to_squeeze_.rbegin(), tensor_dims_to_squeeze_.rend());
cc->SetOffset(0);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
const auto& options =
cc->Options<ObjectDetectionsTensorToDetectionsCalculatorOptions>();
tf::Tensor input_num_detections_tensor =
tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0}));
if (cc->Inputs().HasTag(kNumDetections)) {
ASSIGN_OR_RETURN(
input_num_detections_tensor,
MaybeSqueezeDims(kNumDetections,
cc->Inputs().Tag(kNumDetections).Get<tf::Tensor>()));
}
if (input_num_detections_tensor.dtype() != tf::DT_INT32) {
RET_CHECK_EQ(input_num_detections_tensor.dtype(), tf::DT_FLOAT);
}
ASSIGN_OR_RETURN(
auto input_boxes_tensor,
MaybeSqueezeDims(kBoxes, cc->Inputs().Tag(kBoxes).Get<tf::Tensor>()));
RET_CHECK_EQ(input_boxes_tensor.dtype(), tf::DT_FLOAT);
ASSIGN_OR_RETURN(
auto input_scores_tensor,
MaybeSqueezeDims(kScores, cc->Inputs().Tag(kScores).Get<tf::Tensor>()));
RET_CHECK_EQ(input_scores_tensor.dtype(), tf::DT_FLOAT);
tf::Tensor input_classes_tensor =
tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0}));
if (cc->Inputs().HasTag(kClasses)) {
ASSIGN_OR_RETURN(
input_classes_tensor,
MaybeSqueezeDims(kClasses,
cc->Inputs().Tag(kClasses).Get<tf::Tensor>()));
}
RET_CHECK_EQ(input_classes_tensor.dtype(), tf::DT_FLOAT);
auto output_detections = absl::make_unique<std::vector<Detection>>();
const tf::Tensor& input_keypoints_tensor =
cc->Inputs().HasTag(kKeypoints)
? cc->Inputs().Tag(kKeypoints).Get<tf::Tensor>()
: tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0, 0, 0}));
const tf::Tensor& input_masks_tensor =
cc->Inputs().HasTag(kMasks)
? cc->Inputs().Tag(kMasks).Get<tf::Tensor>()
: tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0, 0, 0}));
RET_CHECK_EQ(input_masks_tensor.dtype(), tf::DT_FLOAT);
const std::map<int, std::string> label_map =
(label_map_ == nullptr) ? std::map<int, std::string>{} : *label_map_;
RET_CHECK_OK(TensorsToDetections(
input_num_detections_tensor, input_boxes_tensor, input_scores_tensor,
input_classes_tensor, input_keypoints_tensor, input_masks_tensor,
options.mask_threshold(), label_map, output_detections.get()));
cc->Outputs()
.Tag(kDetections)
.Add(output_detections.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
private:
std::map<int, std::string>* label_map_;
std::vector<int32> tensor_dims_to_squeeze_;
::mediapipe::StatusOr<tf::Tensor> MaybeSqueezeDims(
const std::string& tensor_tag, const tf::Tensor& input_tensor) {
if (tensor_dims_to_squeeze_.empty()) {
return input_tensor;
}
tf::TensorShape tensor_shape = input_tensor.shape();
for (const int dim : tensor_dims_to_squeeze_) {
RET_CHECK_GT(tensor_shape.dims(), dim)
<< "Dimension " << dim << " does not exist in input tensor with "
<< input_tensor.dims() << " dimensions";
RET_CHECK_EQ(tensor_shape.dim_size(dim), 1)
<< "Cannot remove dimension " << dim << " with size "
<< tensor_shape.dim_size(dim);
tensor_shape.RemoveDim(dim);
}
tf::Tensor output_tensor;
RET_CHECK(output_tensor.CopyFrom(input_tensor, tensor_shape));
return std::move(output_tensor);
}
};
REGISTER_CALCULATOR(ObjectDetectionTensorsToDetectionsCalculator);
} // namespace mediapipe
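The `MaybeSqueezeDims` helper above can be sketched independently of TensorFlow. This hypothetical standalone version (the names are ours, not MediaPipe's) illustrates why `Open()` sorts `tensor_dims_to_squeeze_` in descending order: removing higher indices first keeps the remaining indices valid.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Standalone sketch of dimension squeezing: remove the requested
// singleton dimensions from a shape, iterating in descending order so
// that earlier removals do not shift the indices of later removals.
std::vector<int> SqueezeDims(std::vector<int> shape, std::vector<int> dims) {
  std::sort(dims.rbegin(), dims.rend());  // descending, as in Open()
  for (int d : dims) {
    assert(d < static_cast<int>(shape.size()));
    assert(shape[d] == 1);  // only size-1 dimensions may be removed
    shape.erase(shape.begin() + d);
  }
  return shape;
}
```

For example, `SqueezeDims({1, 3, 1, 4}, {0, 2})` yields `{3, 4}`: dimension 2 is removed before dimension 0, so both indices refer to the original shape.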

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The options proto for the ObjectDetectionsTensorToDetectionsCalculator.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message ObjectDetectionsTensorToDetectionsCalculatorOptions {
extend .mediapipe.CalculatorOptions {
optional ObjectDetectionsTensorToDetectionsCalculatorOptions ext =
192676232;
}
// The threshold used to compute the binary segmentation mask.
optional float mask_threshold = 1 [default = 0.0];
// The specific singleton dimensions to squeeze (remove). The calculator can
// only remove dimensions of size 1.
repeated int32 tensor_dim_to_squeeze = 2;
}
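As a rough illustration of what `mask_threshold` controls: each float value of the predicted mask is binarized against the threshold. The actual comparison happens inside `TensorsToDetections` (not shown in this dump), so treat the strict `>` below as an assumption of this sketch.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of mask binarization: float mask values (assumed
// to lie in [0, 1], as enforced by GetContract's range check on
// mask_threshold) become binary mask entries.
std::vector<bool> BinarizeMask(const std::vector<float>& mask,
                               float threshold) {
  std::vector<bool> out;
  out.reserve(mask.size());
  for (float v : mask) {
    out.push_back(v > threshold);  // assumed comparison; see lead-in
  }
  return out;
}
```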

// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
const char kNumDetections[] = "NUM_DETECTIONS";
const char kBoxes[] = "BOXES";
const char kScores[] = "SCORES";
const char kClasses[] = "CLASSES";
const char kKeypoints[] = "KEYPOINTS";
const char kDetections[] = "DETECTIONS";
const int kNumBoxes = 3;
const int kNumClasses = 4;
const int kNumCoordsPerBox = 4;
const int kNumKeypointsPerBox = 2;
const int kNumCoordsPerKeypoint = 2;
class ObjectDetectionTensorsToDetectionsCalculatorTest
: public ::testing::Test {
protected:
void SetUp() override { SetUpInputs(); }
void SetUpInputs() {
input_num_detections_ = tf::test::AsTensor<float>({kNumBoxes}, {1});
// {ymin, xmin, ymax, xmax} format.
input_boxes_ =
tf::test::AsTensor<float>({0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.2f, 0.3f,
0.4f, 0.1f, 0.2f, 0.3f, 0.4f},
{kNumBoxes, kNumCoordsPerBox});
input_scores_ = tf::test::AsTensor<float>({0.1f, 0.5f, 1.0f}, {kNumBoxes});
input_scores_for_all_classes_ =
tf::test::AsTensor<float>({0.0f, 0.1f, 0.05f, 0.02f, 0.0f, 0.1f, 0.5f,
0.2f, 0.0f, 0.5f, 0.8f, 1.0f},
{kNumBoxes, kNumClasses});
input_classes_ = tf::test::AsTensor<float>({1.0, 2.0, 3.0}, {kNumBoxes});
input_keypoints_ = tf::test::AsTensor<float>(
{0.6f, 0.5f, 0.6f, 0.5f, 0.6f, 0.5f, 0.6f, 0.5f, 0.6f, 0.5f, 0.6f,
0.5f},
{kNumBoxes, kNumKeypointsPerBox, kNumCoordsPerKeypoint});
}
void CreateNodeConfig(CalculatorGraphConfig::Node* node_config) const {
ASSERT_NE(nullptr, node_config);
*node_config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "ObjectDetectionTensorsToDetectionsCalculator"
input_stream: "NUM_DETECTIONS:num_detections"
input_stream: "BOXES:boxes"
input_stream: "SCORES:scores"
input_stream: "CLASSES:classes"
output_stream: "DETECTIONS:detections"
)");
}
void CreateNodeConfigRawTensors(
CalculatorGraphConfig::Node* node_config) const {
ASSERT_NE(nullptr, node_config);
*node_config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "ObjectDetectionTensorsToDetectionsCalculator"
input_stream: "BOXES:raw_detection_boxes"
input_stream: "SCORES:raw_detection_scores"
output_stream: "DETECTIONS:detections"
)");
}
void CreateNodeConfigWithKeypoints(
CalculatorGraphConfig::Node* node_config) const {
ASSERT_NE(nullptr, node_config);
*node_config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "ObjectDetectionTensorsToDetectionsCalculator"
input_stream: "NUM_DETECTIONS:num_detections"
input_stream: "BOXES:boxes"
input_stream: "SCORES:scores"
input_stream: "CLASSES:classes"
input_stream: "KEYPOINTS:keypoints"
output_stream: "DETECTIONS:detections"
)");
}
void SetUpCalculatorRunner() {
CalculatorGraphConfig::Node node_config;
CreateNodeConfig(&node_config);
runner_ = absl::make_unique<CalculatorRunner>(node_config);
}
void SetUpCalculatorRunnerRawTensors() {
CalculatorGraphConfig::Node node_config;
CreateNodeConfigRawTensors(&node_config);
runner_ = absl::make_unique<CalculatorRunner>(node_config);
}
void SetUpCalculatorRunnerWithKeypoints() {
CalculatorGraphConfig::Node node_config;
CreateNodeConfigWithKeypoints(&node_config);
runner_ = absl::make_unique<CalculatorRunner>(node_config);
}
void RunCalculator() {
SetUpCalculatorRunner();
runner_->MutableInputs()
->Tag(kNumDetections)
.packets.push_back(
PointToForeign(&input_num_detections_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kBoxes).packets.push_back(
PointToForeign(&input_boxes_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kScores).packets.push_back(
PointToForeign(&input_scores_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kClasses).packets.push_back(
PointToForeign(&input_classes_).At(Timestamp::PostStream()));
MEDIAPIPE_ASSERT_OK(runner_->Run());
ASSERT_EQ(1, runner_->Outputs().Tag(kDetections).packets.size());
}
void RunCalculatorRawTensors() {
SetUpCalculatorRunnerRawTensors();
runner_->MutableInputs()->Tag(kBoxes).packets.push_back(
PointToForeign(&input_boxes_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kScores).packets.push_back(
PointToForeign(&input_scores_for_all_classes_)
.At(Timestamp::PostStream()));
MEDIAPIPE_ASSERT_OK(runner_->Run());
ASSERT_EQ(1, runner_->Outputs().Tag(kDetections).packets.size());
}
void RunCalculatorWithKeypoints() {
SetUpCalculatorRunnerWithKeypoints();
runner_->MutableInputs()
->Tag(kNumDetections)
.packets.push_back(
PointToForeign(&input_num_detections_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kBoxes).packets.push_back(
PointToForeign(&input_boxes_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kScores).packets.push_back(
PointToForeign(&input_scores_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kClasses).packets.push_back(
PointToForeign(&input_classes_).At(Timestamp::PostStream()));
runner_->MutableInputs()
->Tag(kKeypoints)
.packets.push_back(
PointToForeign(&input_keypoints_).At(Timestamp::PostStream()));
MEDIAPIPE_ASSERT_OK(runner_->Run());
ASSERT_EQ(1, runner_->Outputs().Tag(kDetections).packets.size());
}
void RunCalculatorWithTensorDimensionSqueezing() {
InsertExtraSingltonDim(&input_num_detections_);
InsertExtraSingltonDim(&input_boxes_);
InsertExtraSingltonDim(&input_scores_);
InsertExtraSingltonDim(&input_classes_);
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "ObjectDetectionTensorsToDetectionsCalculator"
input_stream: "NUM_DETECTIONS:num_detections"
input_stream: "BOXES:boxes"
input_stream: "SCORES:scores"
input_stream: "CLASSES:classes"
output_stream: "DETECTIONS:detections"
options: {
[mediapipe.ObjectDetectionsTensorToDetectionsCalculatorOptions
.ext]: { tensor_dim_to_squeeze: 0 }
}
)");
runner_ = absl::make_unique<CalculatorRunner>(node_config);
runner_->MutableInputs()
->Tag(kNumDetections)
.packets.push_back(
PointToForeign(&input_num_detections_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kBoxes).packets.push_back(
PointToForeign(&input_boxes_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kScores).packets.push_back(
PointToForeign(&input_scores_).At(Timestamp::PostStream()));
runner_->MutableInputs()->Tag(kClasses).packets.push_back(
PointToForeign(&input_classes_).At(Timestamp::PostStream()));
MEDIAPIPE_ASSERT_OK(runner_->Run());
ASSERT_EQ(1, runner_->Outputs().Tag(kDetections).packets.size());
}
void InsertExtraSingltonDim(tf::Tensor* tensor) {
tf::TensorShape new_shape(tensor->shape());
new_shape.InsertDim(0, 1);
ASSERT_TRUE(tensor->CopyFrom(*tensor, new_shape));
}
std::unique_ptr<CalculatorRunner> runner_;
tf::Tensor input_num_detections_;
tf::Tensor input_boxes_;
tf::Tensor input_scores_;
tf::Tensor input_scores_for_all_classes_;
tf::Tensor input_classes_;
tf::Tensor input_keypoints_;
};
TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest, OutputsDetections) {
RunCalculator();
EXPECT_EQ(kNumBoxes, runner_->Outputs()
.Tag(kDetections)
.packets[0]
.Get<std::vector<Detection>>()
.size());
}
TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest,
OutputsDetectionsFromRawTensors) {
RunCalculatorRawTensors();
EXPECT_EQ(kNumBoxes, runner_->Outputs()
.Tag(kDetections)
.packets[0]
.Get<std::vector<Detection>>()
.size());
}
TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest,
OutputsDetectionsWithKeypoints) {
RunCalculatorWithKeypoints();
EXPECT_EQ(kNumBoxes, runner_->Outputs()
.Tag(kDetections)
.packets[0]
.Get<std::vector<Detection>>()
.size());
}
TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest,
OutputsDetectionsWithCorrectValues) {
RunCalculator();
const std::vector<Detection> detections = runner_->Outputs()
.Tag(kDetections)
.packets[0]
.Get<std::vector<Detection>>();
EXPECT_EQ(kNumBoxes, detections.size());
for (const auto& detection : detections) {
LocationData::RelativeBoundingBox relative_bbox =
detection.location_data().relative_bounding_box();
EXPECT_FLOAT_EQ(0.2, relative_bbox.xmin());
EXPECT_FLOAT_EQ(0.1, relative_bbox.ymin());
EXPECT_FLOAT_EQ(0.2, relative_bbox.width());
EXPECT_FLOAT_EQ(0.2, relative_bbox.height());
}
EXPECT_FLOAT_EQ(0.1f, detections[0].score(0));
EXPECT_FLOAT_EQ(0.5f, detections[1].score(0));
EXPECT_FLOAT_EQ(1.0f, detections[2].score(0));
EXPECT_EQ(1, detections[0].label_id(0));
EXPECT_EQ(2, detections[1].label_id(0));
EXPECT_EQ(3, detections[2].label_id(0));
}
TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest,
OutputsDetectionsFromRawTensorsWithCorrectValues) {
RunCalculatorRawTensors();
const std::vector<Detection> detections = runner_->Outputs()
.Tag(kDetections)
.packets[0]
.Get<std::vector<Detection>>();
EXPECT_EQ(kNumBoxes, detections.size());
for (const auto& detection : detections) {
LocationData::RelativeBoundingBox relative_bbox =
detection.location_data().relative_bounding_box();
EXPECT_FLOAT_EQ(0.2, relative_bbox.xmin());
EXPECT_FLOAT_EQ(0.1, relative_bbox.ymin());
EXPECT_FLOAT_EQ(0.2, relative_bbox.width());
EXPECT_FLOAT_EQ(0.2, relative_bbox.height());
}
EXPECT_FLOAT_EQ(0.1f, detections[0].score(0));
EXPECT_FLOAT_EQ(0.5f, detections[1].score(0));
EXPECT_FLOAT_EQ(1.0f, detections[2].score(0));
EXPECT_EQ(1, detections[0].label_id(0));
EXPECT_EQ(2, detections[1].label_id(0));
EXPECT_EQ(3, detections[2].label_id(0));
}
TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest,
OutputsDetectionsWithKeypointsAndCorrectValues) {
RunCalculatorWithKeypoints();
const std::vector<Detection> detections = runner_->Outputs()
.Tag(kDetections)
.packets[0]
.Get<std::vector<Detection>>();
EXPECT_EQ(kNumBoxes, detections.size());
for (const auto& detection : detections) {
LocationData::RelativeBoundingBox relative_bbox =
detection.location_data().relative_bounding_box();
EXPECT_FLOAT_EQ(0.2, relative_bbox.xmin());
EXPECT_FLOAT_EQ(0.1, relative_bbox.ymin());
EXPECT_FLOAT_EQ(0.2, relative_bbox.width());
EXPECT_FLOAT_EQ(0.2, relative_bbox.height());
for (const auto& relative_keypoint :
detection.location_data().relative_keypoints()) {
EXPECT_FLOAT_EQ(0.5, relative_keypoint.x());
EXPECT_FLOAT_EQ(0.6, relative_keypoint.y());
}
}
EXPECT_FLOAT_EQ(0.1f, detections[0].score(0));
EXPECT_FLOAT_EQ(0.5f, detections[1].score(0));
EXPECT_FLOAT_EQ(1.0f, detections[2].score(0));
EXPECT_EQ(1, detections[0].label_id(0));
EXPECT_EQ(2, detections[1].label_id(0));
EXPECT_EQ(3, detections[2].label_id(0));
}
TEST_F(ObjectDetectionTensorsToDetectionsCalculatorTest,
SqueezesInputTensorDimensionAndOutputsDetectionsWithCorrectValues) {
RunCalculatorWithTensorDimensionSqueezing();
const std::vector<Detection> detections = runner_->Outputs()
.Tag(kDetections)
.packets[0]
.Get<std::vector<Detection>>();
EXPECT_EQ(kNumBoxes, detections.size());
for (const auto& detection : detections) {
LocationData::RelativeBoundingBox relative_bbox =
detection.location_data().relative_bounding_box();
EXPECT_FLOAT_EQ(0.2, relative_bbox.xmin());
EXPECT_FLOAT_EQ(0.1, relative_bbox.ymin());
EXPECT_FLOAT_EQ(0.2, relative_bbox.width());
EXPECT_FLOAT_EQ(0.2, relative_bbox.height());
}
EXPECT_FLOAT_EQ(0.1f, detections[0].score(0));
EXPECT_FLOAT_EQ(0.5f, detections[1].score(0));
EXPECT_FLOAT_EQ(1.0f, detections[2].score(0));
EXPECT_EQ(1, detections[0].label_id(0));
EXPECT_EQ(2, detections[1].label_id(0));
EXPECT_EQ(3, detections[2].label_id(0));
}
} // namespace
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/sequence/media_sequence.h"
#include "mediapipe/util/sequence/media_sequence_util.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
namespace mediapipe {
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
const char kImageTag[] = "IMAGE";
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
const char kBBoxTag[] = "BBOX";
const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION";
namespace tf = ::tensorflow;
namespace mpms = ::mediapipe::mediasequence;
// Sink calculator to package streams into tf.SequenceExamples.
//
// The calculator takes a tf.SequenceExample as a side input and then adds
// the data from inputs to the SequenceExample with timestamps. Additional
// context features can be supplied verbatim in the calculator's options. The
// SequenceExample will conform to the description in media_sequence.h.
//
// The supported input stream tags are "IMAGE", which stores the encoded
// images from the OpenCVImageEncoderCalculator, "FORWARD_FLOW_ENCODED", which
// stores the encoded optical flow from the same calculator, "BBOX" which stores
// bounding boxes from vector<Detections>, and streams with the
// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector<float>'s
// associated with the name ${NAME}. Audio streams (i.e. Matrix with a
// TimeSeriesHeader) are given extra packing and unpacking support and are named
// similarly to floats with the pattern "AUDIO_${NAME}". "IMAGE_${NAME}" and
// "BBOX_${NAME}" will also store prefixed versions of each stream, which allows
// for multiple image streams to be included. However, the default names are
// supported by more tools. "ENCODED_MEDIA" stores a video encoding for the clip
// directly. The last packet on this stream is stored, and can be unpacked with
// the metadata generator. Because the media decoder always starts with
// timestamp zero, the "ENCODED_MEDIA_START_TIMESTAMP" should be recorded as
// well. Use the FirstTimestampCalculator to determine this value.
//
// Example config:
// node {
// calculator: "PackMediaSequenceCalculator"
// input_side_packet: "SEQUENCE_EXAMPLE:example_input_side_packet"
// input_stream: "IMAGE:frames"
// input_stream: "FLOAT_FEATURE_FDENSE:fdense_vf"
// output_stream: "SEQUENCE_EXAMPLE:example_output_stream"
// options {
// [mediapipe.PackMediaSequenceCalculatorOptions.ext]: {
// context_feature_map {
// feature {
// key: "image/frames_per_second"
// value {
// float_list {
// value: 30.0
// }
// }
// }
// }
// }
// }
// }
namespace {
uint8 ConvertFloatToByte(const float float_value) {
float clamped_value = MathUtil::Clamp(0.0f, 1.0f, float_value);
return static_cast<uint8>(clamped_value * 255.0 + .5f);
}
} // namespace
class PackMediaSequenceCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag));
cc->InputSidePackets().Tag(kSequenceExampleTag).Set<tf::SequenceExample>();
if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
cc->Inputs()
.Tag(kForwardFlowEncodedTag)
.Set<OpenCvImageEncoderCalculatorResults>();
}
if (cc->Inputs().HasTag(kSegmentationMaskTag)) {
cc->Inputs().Tag(kSegmentationMaskTag).Set<std::vector<Detection>>();
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kImageTag)) {
std::string key = "";
if (tag != kImageTag) {
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kImageTag)_?"
}
}
cc->Inputs().Tag(tag).Set<OpenCvImageEncoderCalculatorResults>();
}
if (absl::StartsWith(tag, kBBoxTag)) {
std::string key = "";
if (tag != kBBoxTag) {
int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kBBoxTag)_?"
}
}
cc->Inputs().Tag(tag).Set<std::vector<Detection>>();
}
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
cc->Inputs().Tag(tag).Set<std::vector<float>>();
}
}
CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
cc->OutputSidePackets().HasTag(kSequenceExampleTag))
<< "Neither the output stream nor the output side packet is set to "
"output the sequence example.";
if (cc->Outputs().HasTag(kSequenceExampleTag)) {
cc->Outputs().Tag(kSequenceExampleTag).Set<tf::SequenceExample>();
}
if (cc->OutputSidePackets().HasTag(kSequenceExampleTag)) {
cc->OutputSidePackets()
.Tag(kSequenceExampleTag)
.Set<tf::SequenceExample>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
sequence_ = ::absl::make_unique<tf::SequenceExample>(
cc->InputSidePackets()
.Tag(kSequenceExampleTag)
.Get<tf::SequenceExample>());
const auto& context_features =
cc->Options<PackMediaSequenceCalculatorOptions>().context_feature_map();
for (const auto& feature : context_features.feature()) {
*mpms::MutableContext(feature.first, sequence_.get()) = feature.second;
}
for (const auto& tag : cc->Inputs().GetTags()) {
features_present_[tag] = false;
}
if (cc->Options()
.GetExtension(PackMediaSequenceCalculatorOptions::ext)
.replace_data_instead_of_append()) {
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kImageTag)) {
std::string key = "";
if (tag != kImageTag) {
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kImageTag)_?"
}
}
mpms::ClearImageEncoded(key, sequence_.get());
mpms::ClearImageTimestamp(key, sequence_.get());
}
}
if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
mpms::ClearForwardFlowEncoded(sequence_.get());
mpms::ClearForwardFlowTimestamp(sequence_.get());
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
sizeof(*kFloatFeaturePrefixTag) -
1);
mpms::ClearFeatureFloats(key, sequence_.get());
mpms::ClearFeatureTimestamp(key, sequence_.get());
}
}
}
if (cc->Outputs().HasTag(kSequenceExampleTag)) {
cc->Outputs()
.Tag(kSequenceExampleTag)
.SetNextTimestampBound(Timestamp::Max());
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status VerifySequence() {
std::string error_msg = "Missing features - ";
bool all_present = true;
for (auto iter : features_present_) {
if (!iter.second) {
all_present = false;
absl::StrAppend(&error_msg, iter.first, ", ");
}
}
if (all_present) {
return ::mediapipe::OkStatus();
} else {
return ::mediapipe::NotFoundErrorBuilder(MEDIAPIPE_LOC) << error_msg;
}
}
::mediapipe::Status Close(CalculatorContext* cc) override {
auto& options =
cc->Options().GetExtension(PackMediaSequenceCalculatorOptions::ext);
if (options.reconcile_metadata()) {
RET_CHECK_OK(mpms::ReconcileMetadata(options.reconcile_bbox_annotations(),
sequence_.get()));
}
if (options.output_only_if_all_present()) {
::mediapipe::Status status = VerifySequence();
if (!status.ok()) {
cc->GetCounter(status.error_message())->Increment();
return status;
}
}
if (cc->OutputSidePackets().HasTag(kSequenceExampleTag)) {
cc->OutputSidePackets()
.Tag(kSequenceExampleTag)
.Set(MakePacket<tensorflow::SequenceExample>(*sequence_));
}
if (cc->Outputs().HasTag(kSequenceExampleTag)) {
cc->Outputs()
.Tag(kSequenceExampleTag)
.Add(sequence_.release(), Timestamp::PostStream());
}
sequence_.reset();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kImageTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = "";
if (tag != kImageTag) {
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kImageTag)_?"
}
}
const OpenCvImageEncoderCalculatorResults& image =
cc->Inputs().Tag(tag).Get<OpenCvImageEncoderCalculatorResults>();
if (!image.has_encoded_image()) {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "No encoded image";
}
mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(),
sequence_.get());
mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get());
}
}
if (cc->Inputs().HasTag(kForwardFlowEncodedTag) &&
!cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) {
const OpenCvImageEncoderCalculatorResults& forward_flow =
cc->Inputs()
.Tag(kForwardFlowEncodedTag)
.Get<OpenCvImageEncoderCalculatorResults>();
if (!forward_flow.has_encoded_image()) {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "No encoded forward flow";
}
mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(),
sequence_.get());
mpms::AddForwardFlowEncoded(forward_flow.encoded_image(),
sequence_.get());
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kFloatFeaturePrefixTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
sizeof(*kFloatFeaturePrefixTag) -
1);
mpms::AddFeatureTimestamp(key, cc->InputTimestamp().Value(),
sequence_.get());
mpms::AddFeatureFloats(key,
cc->Inputs().Tag(tag).Get<std::vector<float>>(),
sequence_.get());
}
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kBBoxTag) && !cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = "";
if (tag != kBBoxTag) {
int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kBBoxTag)_?"
}
}
std::vector<Location> predicted_locations;
std::vector<std::string> predicted_class_strings;
std::vector<int> predicted_label_ids;
for (auto& detection :
cc->Inputs().Tag(tag).Get<std::vector<Detection>>()) {
if (detection.location_data().format() ==
LocationData::BOUNDING_BOX ||
detection.location_data().format() ==
LocationData::RELATIVE_BOUNDING_BOX) {
int height = mpms::GetImageHeight(*sequence_);
int width = mpms::GetImageWidth(*sequence_);
Location relative_bbox = Location::CreateRelativeBBoxLocation(
Location(detection.location_data())
.ConvertToRelativeBBox(width, height));
predicted_locations.push_back(relative_bbox);
if (detection.label_size() > 0) {
predicted_class_strings.push_back(detection.label(0));
}
if (detection.label_id_size() > 0) {
predicted_label_ids.push_back(detection.label_id(0));
}
}
}
if (!predicted_locations.empty()) {
mpms::AddBBox(key, predicted_locations, sequence_.get());
mpms::AddBBoxTimestamp(key, cc->InputTimestamp().Value(),
sequence_.get());
if (!predicted_class_strings.empty()) {
mpms::AddBBoxClassString(key, predicted_class_strings,
sequence_.get());
}
if (!predicted_label_ids.empty()) {
mpms::AddBBoxClassIndex(key, predicted_label_ids, sequence_.get());
}
}
}
}
if (cc->Inputs().HasTag(kSegmentationMaskTag) &&
!cc->Inputs().Tag(kSegmentationMaskTag).IsEmpty()) {
bool already_has_mask = false;
for (auto& detection : cc->Inputs()
.Tag(kSegmentationMaskTag)
.Get<std::vector<Detection>>()) {
if (detection.location_data().format() == LocationData::MASK) {
RET_CHECK(!already_has_mask)
<< "We currently only support adding one mask per timestamp. "
<< sequence_->DebugString();
auto mask_mat_ptr = Location(detection.location_data()).GetCvMask();
std::vector<uchar> bytes;
RET_CHECK(cv::imencode(".png", *mask_mat_ptr, bytes, {}));
std::string encoded_mask(bytes.begin(), bytes.end());
mpms::AddClassSegmentationEncoded(encoded_mask, sequence_.get());
mpms::AddClassSegmentationTimestamp(cc->InputTimestamp().Value(),
sequence_.get());
// SegmentationClassLabelString is a context feature for the entire
// sequence. The values in the last detection will be saved.
mpms::SetClassSegmentationClassLabelString({detection.label(0)},
sequence_.get());
already_has_mask = true;
} else {
return ::mediapipe::UnimplementedError(
"Global detections and empty detections are not supported.");
}
}
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (!cc->Inputs().Tag(tag).IsEmpty()) {
features_present_[tag] = true;
}
}
return ::mediapipe::OkStatus();
}
std::unique_ptr<tf::SequenceExample> sequence_;
std::map<std::string, bool> features_present_;
};
REGISTER_CALCULATOR(PackMediaSequenceCalculator);
} // namespace mediapipe


@ -0,0 +1,53 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "tensorflow/core/example/feature.proto";
message PackMediaSequenceCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional PackMediaSequenceCalculatorOptions ext = 243132285;
}
// Provides a map from strings to tf.Feature values to merge into the
// SequenceExample's context. Use this to add new metadata.
optional tensorflow.Features context_feature_map = 1;
// If true, updates the context features of the SequenceExample
// (e.g., fills in the image height, width, and number of frames).
optional bool reconcile_metadata = 2 [default = true];
// If true, updates the metadata for sequences with bounding boxes. This will
// align each bounding box annotation with the nearest frame and insert empty
// annotations as needed to satisfy the frame rate.
// NOTE: IF YOU DOWNSAMPLE IN TIME YOU WILL LOSE ANNOTATIONS.
// If two or more annotations are closest to the same frame, then only
// the closest annotation is saved. This matches the behavior of
// downsampling images in time.
optional bool reconcile_bbox_annotations = 5 [default = true];
// If true, the SequenceExample is output only if all input streams are
// present.
optional bool output_only_if_all_present = 3 [default = false];
// If true, removes any existing data for a corresponding input stream from
// the sequence example before appending. E.g., if images are already present
// and an IMAGE stream is provided, the previous images and timestamps are
// removed before the new images are added.
optional bool replace_data_instead_of_append = 4 [default = true];
}
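For reference, a graph node using these options might look like the following text-format config. The stream names and the IMAGE/BBOX_PREDICTED tags here are illustrative choices (mirroring the unit tests), not required names:

```
node {
  calculator: "PackMediaSequenceCalculator"
  input_side_packet: "SEQUENCE_EXAMPLE:input_sequence"
  input_stream: "IMAGE:images"
  input_stream: "BBOX_PREDICTED:detections"
  output_stream: "SEQUENCE_EXAMPLE:output_sequence"
  options {
    [mediapipe.PackMediaSequenceCalculatorOptions.ext] {
      output_only_if_all_present: false
      replace_data_instead_of_append: true
    }
  }
}
```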


@ -0,0 +1,623 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
namespace mpms = ::mediapipe::mediasequence;
class PackMediaSequenceCalculatorTest : public ::testing::Test {
protected:
void SetUpCalculator(const std::vector<std::string>& input_streams,
const tf::Features& features,
bool output_only_if_all_present,
bool replace_instead_of_append) {
CalculatorGraphConfig::Node config;
config.set_calculator("PackMediaSequenceCalculator");
config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence");
config.add_output_stream("SEQUENCE_EXAMPLE:output_sequence");
for (const std::string& stream : input_streams) {
config.add_input_stream(stream);
}
auto options = config.mutable_options()->MutableExtension(
PackMediaSequenceCalculatorOptions::ext);
*options->mutable_context_feature_map() = features;
options->set_output_only_if_all_present(output_only_if_all_present);
options->set_replace_data_instead_of_append(replace_instead_of_append);
runner_ = ::absl::make_unique<CalculatorRunner>(config);
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
SetUpCalculator({"IMAGE:images"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
std::string test_image_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(test_image_string);
encoded_image.set_width(2);
encoded_image.set_height(1);
int num_images = 2;
for (int i = 0; i < num_images; ++i) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(num_images, mpms::GetImageTimestampSize(output_sequence));
ASSERT_EQ(num_images, mpms::GetImageEncodedSize(output_sequence));
for (int i = 0; i < num_images; ++i) {
ASSERT_EQ(i, mpms::GetImageTimestampAt(output_sequence, i));
ASSERT_EQ(test_image_string, mpms::GetImageEncodedAt(output_sequence, i));
}
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
std::string prefix = "PREFIX";
SetUpCalculator({"IMAGE_PREFIX:images"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
std::string test_image_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(test_image_string);
encoded_image.set_width(2);
encoded_image.set_height(1);
int num_images = 2;
for (int i = 0; i < num_images; ++i) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()
->Tag("IMAGE_PREFIX")
.packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(num_images, mpms::GetImageTimestampSize(prefix, output_sequence));
ASSERT_EQ(num_images, mpms::GetImageEncodedSize(prefix, output_sequence));
for (int i = 0; i < num_images; ++i) {
ASSERT_EQ(i, mpms::GetImageTimestampAt(prefix, output_sequence, i));
ASSERT_EQ(test_image_string,
mpms::GetImageEncodedAt(prefix, output_sequence, i));
}
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) {
SetUpCalculator({"FLOAT_FEATURE_TEST:test", "FLOAT_FEATURE_OTHER:test2"}, {},
false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
int num_timesteps = 2;
for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag("FLOAT_FEATURE_TEST")
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag("FLOAT_FEATURE_OTHER")
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(num_timesteps,
mpms::GetFeatureTimestampSize("TEST", output_sequence));
ASSERT_EQ(num_timesteps, mpms::GetFeatureFloatsSize("TEST", output_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetFeatureTimestampSize("OTHER", output_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetFeatureFloatsSize("OTHER", output_sequence));
for (int i = 0; i < num_timesteps; ++i) {
ASSERT_EQ(i, mpms::GetFeatureTimestampAt("TEST", output_sequence, i));
ASSERT_THAT(mpms::GetFeatureFloatsAt("TEST", output_sequence, i),
::testing::ElementsAreArray(std::vector<float>(2, 2 << i)));
ASSERT_EQ(i, mpms::GetFeatureTimestampAt("OTHER", output_sequence, i));
ASSERT_THAT(mpms::GetFeatureFloatsAt("OTHER", output_sequence, i),
::testing::ElementsAreArray(std::vector<float>(2, 2 << i)));
}
}
TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
tf::Features context;
(*context.mutable_feature())["TEST"].mutable_bytes_list()->add_value("YES");
(*context.mutable_feature())["OTHER"].mutable_bytes_list()->add_value("NO");
SetUpCalculator({"IMAGE:images"}, context, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
std::string test_image_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(test_image_string);
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(0)));
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_TRUE(mpms::HasContext(output_sequence, "TEST"));
ASSERT_TRUE(mpms::HasContext(output_sequence, "OTHER"));
ASSERT_EQ(mpms::GetContext(output_sequence, "TEST").bytes_list().value(0),
"YES");
ASSERT_EQ(mpms::GetContext(output_sequence, "OTHER").bytes_list().value(0),
"NO");
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
SetUpCalculator({"FORWARD_FLOW_ENCODED:flow"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);
encoded_flow.set_width(2);
encoded_flow.set_height(1);
int num_flows = 2;
for (int i = 0; i < num_flows; ++i) {
auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED")
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(num_flows, mpms::GetForwardFlowTimestampSize(output_sequence));
ASSERT_EQ(num_flows, mpms::GetForwardFlowEncodedSize(output_sequence));
for (int i = 0; i < num_flows; ++i) {
ASSERT_EQ(i, mpms::GetForwardFlowTimestampAt(output_sequence, i));
ASSERT_EQ(test_flow_string,
mpms::GetForwardFlowEncodedAt(output_sequence, i));
}
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
SetUpCalculator({"BBOX_PREDICTED:detections"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
int height = 480;
int width = 640;
mpms::SetImageHeight(height, input_sequence.get());
mpms::SetImageWidth(width, input_sequence.get());
int num_vectors = 2;
for (int i = 0; i < num_vectors; ++i) {
auto detections = ::absl::make_unique<::std::vector<Detection>>();
Detection detection;
detection.add_label("absolute bbox");
detection.add_label_id(0);
detection.add_score(0.5);
Location::CreateBBoxLocation(0, height / 2, width / 2, height / 2)
.ConvertToProto(detection.mutable_location_data());
detections->push_back(detection);
detection = Detection();
detection.add_label("relative bbox");
detection.add_label_id(1);
detection.add_score(0.75);
Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5)
.ConvertToProto(detection.mutable_location_data());
detections->push_back(detection);
// The mask detection should be ignored in the output.
detection = Detection();
detection.add_label("mask");
detection.add_score(1.0);
cv::Mat image(2, 3, CV_8UC1, cv::Scalar(0));
Location::CreateCvMaskLocation<uint8>(image).ConvertToProto(
detection.mutable_location_data());
detections->push_back(detection);
runner_->MutableInputs()
->Tag("BBOX_PREDICTED")
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(height, mpms::GetImageHeight(output_sequence));
ASSERT_EQ(width, mpms::GetImageWidth(output_sequence));
ASSERT_EQ(num_vectors, mpms::GetPredictedBBoxSize(output_sequence));
ASSERT_EQ(num_vectors, mpms::GetPredictedBBoxTimestampSize(output_sequence));
ASSERT_EQ(0, mpms::GetClassSegmentationEncodedSize(output_sequence));
ASSERT_EQ(0, mpms::GetClassSegmentationTimestampSize(output_sequence));
for (int i = 0; i < num_vectors; ++i) {
ASSERT_EQ(i, mpms::GetPredictedBBoxTimestampAt(output_sequence, i));
auto bboxes = mpms::GetPredictedBBoxAt(output_sequence, i);
ASSERT_EQ(2, bboxes.size());
for (int j = 0; j < bboxes.size(); ++j) {
auto rect = bboxes[j].GetRelativeBBox();
ASSERT_NEAR(0, rect.xmin(), 0.001);
ASSERT_NEAR(0.5, rect.ymin(), 0.001);
ASSERT_NEAR(0.5, rect.xmax(), 0.001);
ASSERT_NEAR(1.0, rect.ymax(), 0.001);
}
auto class_strings =
mpms::GetPredictedBBoxClassStringAt(output_sequence, i);
ASSERT_EQ("absolute bbox", class_strings[0]);
ASSERT_EQ("relative bbox", class_strings[1]);
auto class_indices = mpms::GetPredictedBBoxClassIndexAt(output_sequence, i);
ASSERT_EQ(0, class_indices[0]);
ASSERT_EQ(1, class_indices[1]);
}
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
SetUpCalculator({"CLASS_SEGMENTATION:detections"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
int height = 480;
int width = 640;
mpms::SetImageHeight(height, input_sequence.get());
mpms::SetImageWidth(width, input_sequence.get());
int num_vectors = 2;
for (int i = 0; i < num_vectors; ++i) {
auto detections = ::absl::make_unique<::std::vector<Detection>>();
Detection detection;
detection = Detection();
detection.add_label("mask");
detection.add_score(1.0);
cv::Mat image(2, 3, CV_8UC1, cv::Scalar(0));
Location::CreateCvMaskLocation<uint8>(image).ConvertToProto(
detection.mutable_location_data());
detections->push_back(detection);
runner_->MutableInputs()
->Tag("CLASS_SEGMENTATION")
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
LOG(INFO) << "output_sequence: \n" << output_sequence.DebugString();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(height, mpms::GetImageHeight(output_sequence));
ASSERT_EQ(width, mpms::GetImageWidth(output_sequence));
ASSERT_EQ(2, mpms::GetClassSegmentationEncodedSize(output_sequence));
ASSERT_EQ(2, mpms::GetClassSegmentationTimestampSize(output_sequence));
for (int i = 0; i < num_vectors; ++i) {
ASSERT_EQ(i, mpms::GetClassSegmentationTimestampAt(output_sequence, i));
}
ASSERT_THAT(mpms::GetClassSegmentationClassLabelString(output_sequence),
testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
}
TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
SetUpCalculator(
{"FORWARD_FLOW_ENCODED:flow", "FLOAT_FEATURE_I3D_FLOW:feature"}, {},
false, false);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);
encoded_flow.set_width(2);
encoded_flow.set_height(1);
int num_flows = 2;
for (int i = 0; i < num_flows; ++i) {
auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED")
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(num_flows, mpms::GetForwardFlowTimestampSize(output_sequence));
ASSERT_EQ(num_flows, mpms::GetForwardFlowEncodedSize(output_sequence));
for (int i = 0; i < num_flows; ++i) {
ASSERT_EQ(i, mpms::GetForwardFlowTimestampAt(output_sequence, i));
ASSERT_EQ(test_flow_string,
mpms::GetForwardFlowEncodedAt(output_sequence, i));
}
}
TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
SetUpCalculator(
{"FORWARD_FLOW_ENCODED:flow", "FLOAT_FEATURE_I3D_FLOW:feature"}, {}, true,
false);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);
encoded_flow.set_width(2);
encoded_flow.set_height(1);
int num_flows = 2;
for (int i = 0; i < num_flows; ++i) {
auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED")
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
::mediapipe::Status status = runner_->Run();
EXPECT_FALSE(status.ok());
}
TEST_F(PackMediaSequenceCalculatorTest, TestReplacingImages) {
SetUpCalculator({"IMAGE:images"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
mpms::AddImageEncoded("one", input_sequence.get());
mpms::AddImageEncoded("two", input_sequence.get());
mpms::AddImageTimestamp(1, input_sequence.get());
mpms::AddImageTimestamp(2, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(0, mpms::GetImageTimestampSize(output_sequence));
ASSERT_EQ(0, mpms::GetImageEncodedSize(output_sequence));
}
TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFlowImages) {
SetUpCalculator({"FORWARD_FLOW_ENCODED:images"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
mpms::AddForwardFlowEncoded("one", input_sequence.get());
mpms::AddForwardFlowEncoded("two", input_sequence.get());
mpms::AddForwardFlowTimestamp(1, input_sequence.get());
mpms::AddForwardFlowTimestamp(2, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(0, mpms::GetForwardFlowTimestampSize(output_sequence));
ASSERT_EQ(0, mpms::GetForwardFlowEncodedSize(output_sequence));
}
TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFloatVectors) {
SetUpCalculator({"FLOAT_FEATURE_TEST:test", "FLOAT_FEATURE_OTHER:test2"}, {},
false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
int num_timesteps = 2;
for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
mpms::AddFeatureFloats("TEST", *vf_ptr, input_sequence.get());
mpms::AddFeatureTimestamp("TEST", i, input_sequence.get());
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
mpms::AddFeatureFloats("OTHER", *vf_ptr, input_sequence.get());
mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get());
}
ASSERT_EQ(num_timesteps,
mpms::GetFeatureTimestampSize("TEST", *input_sequence));
ASSERT_EQ(num_timesteps, mpms::GetFeatureFloatsSize("TEST", *input_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetFeatureTimestampSize("OTHER", *input_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetFeatureFloatsSize("OTHER", *input_sequence));
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(0, mpms::GetFeatureTimestampSize("TEST", output_sequence));
ASSERT_EQ(0, mpms::GetFeatureFloatsSize("TEST", output_sequence));
ASSERT_EQ(0, mpms::GetFeatureTimestampSize("OTHER", output_sequence));
ASSERT_EQ(0, mpms::GetFeatureFloatsSize("OTHER", output_sequence));
}
TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
SetUpCalculator({"IMAGE:images"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
std::string test_image_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(test_image_string);
encoded_image.set_width(2);
encoded_image.set_height(1);
int num_images = 5; // Timestamps: 10, 20, 30, 40, 50
for (int i = 0; i < num_images; ++i) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10)));
}
mpms::AddBBoxTimestamp(9, input_sequence.get());
mpms::AddBBoxTimestamp(21, input_sequence.get());
mpms::AddBBoxTimestamp(22, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(mpms::GetBBoxTimestampSize(output_sequence), 5);
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 0), 10);
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 1), 20);
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 2), 30);
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 3), 40);
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 4), 50);
}
} // namespace
} // namespace mediapipe


@ -0,0 +1,99 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/core/example/example.pb.h"
// A calculator to serialize/deserialize tensorflow::SequenceExample protos
// to and from strings.
//
// Example converting to SequenceExample in Open():
// node {
// calculator: "StringToSequenceExampleCalculator"
// input_side_packet: "STRING:serialized_sequence_example"
// output_side_packet: "SEQUENCE_EXAMPLE:sequence_example"
// }
//
// Example converting to std::string in Close():
// node {
// calculator: "StringToSequenceExampleCalculator"
// input_side_packet: "SEQUENCE_EXAMPLE:sequence_example"
// output_side_packet: "STRING:serialized_sequence_example"
// }
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kString[] = "STRING";
constexpr char kSequenceExample[] = "SEQUENCE_EXAMPLE";
} // namespace
class StringToSequenceExampleCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(StringToSequenceExampleCalculator);
::mediapipe::Status StringToSequenceExampleCalculator::GetContract(
CalculatorContract* cc) {
if (cc->InputSidePackets().HasTag(kString)) {
cc->InputSidePackets().Tag(kString).Set<std::string>();
cc->OutputSidePackets().Tag(kSequenceExample).Set<tf::SequenceExample>();
}
if (cc->InputSidePackets().HasTag(kSequenceExample)) {
cc->InputSidePackets().Tag(kSequenceExample).Set<tf::SequenceExample>();
cc->OutputSidePackets().Tag(kString).Set<std::string>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status StringToSequenceExampleCalculator::Open(
CalculatorContext* cc) {
if (cc->InputSidePackets().HasTag(kString)) {
auto string_value = cc->InputSidePackets().Tag(kString).Get<std::string>();
auto example = absl::make_unique<tf::SequenceExample>();
example->ParseFromString(string_value);
cc->OutputSidePackets()
.Tag(kSequenceExample)
.Set(::mediapipe::Adopt(example.release()));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status StringToSequenceExampleCalculator::Process(
CalculatorContext* cc) {
return ::mediapipe::OkStatus();
}
::mediapipe::Status StringToSequenceExampleCalculator::Close(
CalculatorContext* cc) {
if (cc->InputSidePackets().HasTag(kSequenceExample)) {
const auto& example =
cc->InputSidePackets().Tag(kSequenceExample).Get<tf::SequenceExample>();
auto string_value = absl::make_unique<std::string>();
example.SerializeToString(string_value.get());
cc->OutputSidePackets().Tag(kString).Set(
::mediapipe::Adopt(string_value.release()));
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe


@ -0,0 +1,111 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"
namespace mediapipe {
namespace tf = ::tensorflow;
// Given an input Tensor (example dimensions [1, 1024, 1, 5]), it squeezes all
// dimensions with size 1, or dimensions at specific indices, producing a tensor
// containing identical data (example output dimensions [1024, 5]).
class TensorSqueezeDimensionsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Need one input";
cc->Inputs().Index(0).Set<tf::Tensor>(
// Input Tensor
);
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) << "Need one output";
cc->Outputs().Index(0).Set<tf::Tensor>(
// Output Tensor Reduced Dimensions
);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<TensorSqueezeDimensionsCalculatorOptions>();
RET_CHECK(options_.squeeze_all_single_dims() ^ (options_.dim_size() > 0))
<< "Must specify dimensions to remove, or set squeeze_all_single_dims, "
"but not both. Received options: "
<< options_.DebugString();
if (options_.dim_size() > 0) {
remove_dims_ =
std::vector<int32>(options_.dim().begin(), options_.dim().end());
std::sort(remove_dims_.rbegin(), remove_dims_.rend());
remove_dims_initialized_ = true;
}
cc->SetOffset(0);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
const tf::Tensor& input_tensor = cc->Inputs().Index(0).Get<tf::Tensor>();
tf::TensorShape tensor_shape = input_tensor.shape();
if (!remove_dims_initialized_) {
// Happens iff options.squeeze_all_single_dims is set.
// Initialize remove_dims_ to all dimensions with size 1.
InitializeToRemoveAllSingletonDimensions(tensor_shape);
remove_dims_initialized_ = true;
}
for (const int dim : remove_dims_) {
RET_CHECK_GT(tensor_shape.dims(), dim)
<< "Dimension " << dim
<< " does not exist in input tensor with num dimensions "
<< input_tensor.dims();
RET_CHECK_EQ(tensor_shape.dim_size(dim), 1)
<< "Cannot remove dimension " << dim << " with size "
<< tensor_shape.dim_size(dim);
tensor_shape.RemoveDim(dim);
}
std::unique_ptr<tf::Tensor> output_tensor(new tf::Tensor);
RET_CHECK(output_tensor->CopyFrom(input_tensor, tensor_shape));
cc->Outputs().Index(0).Add(output_tensor.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status Close(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
private:
TensorSqueezeDimensionsCalculatorOptions options_;
std::vector<int32> remove_dims_;
bool remove_dims_initialized_;
void InitializeToRemoveAllSingletonDimensions(
const tf::TensorShape& tensor_shape) {
const int dims = tensor_shape.dims();
for (int i = dims - 1; i >= 0; --i) {
if (tensor_shape.dim_size(i) == 1) {
remove_dims_.push_back(i);
}
}
if (remove_dims_.empty()) {
LOG(ERROR) << "TensorSqueezeDimensionsCalculator is squeezing input with "
"no singleton dimensions. Calculator will be a no-op.";
LOG(ERROR) << "Input to TensorSqueezeDimensionsCalculator has shape "
<< tensor_shape.DebugString();
}
}
};
REGISTER_CALCULATOR(TensorSqueezeDimensionsCalculator);
} // namespace mediapipe
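The reverse iteration in `InitializeToRemoveAllSingletonDimensions` is what keeps the dimension indices valid: removing dimension `i` shifts every later index down by one, so the calculator collects and removes indices from back to front. A minimal standalone sketch of that logic (the function name and use of `std::vector` are illustrative, not MediaPipe API):

```cpp
#include <cstdint>
#include <vector>

// Returns `shape` with every singleton (size-1) dimension removed.
// Iterating from the back mirrors InitializeToRemoveAllSingletonDimensions
// above: erasing a dimension never shifts the indices of dimensions we have
// yet to visit.
std::vector<int64_t> SqueezeAllSingletonDims(std::vector<int64_t> shape) {
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    if (shape[i] == 1) {
      shape.erase(shape.begin() + i);
    }
  }
  return shape;
}
```

For example, the shape `[1, 1024, 1, 5]` from the class comment above squeezes to `[1024, 5]`.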


@ -0,0 +1,33 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Specifies options to TensorSqueezeDimensionsCalculator. Use this proto to
// configure which dimensions to squeeze (remove). It is only possible to remove
// dimensions of size 1.
// The fields 'squeeze_all_single_dims' and 'dim' are mutually exclusive.
message TensorSqueezeDimensionsCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorSqueezeDimensionsCalculatorOptions ext = 115619805;
}
// Remove all dimensions with size 1.
optional bool squeeze_all_single_dims = 1 [default = false];
// Remove specific singleton dimensions.
repeated int32 dim = 2;
}
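A graph node using these options might look like the following (a sketch in the style of the Example Config comments used elsewhere in this repository; stream names are placeholders):

```
node {
  calculator: "TensorSqueezeDimensionsCalculator"
  input_stream: "input_tensor"
  output_stream: "squeezed_tensor"
  options {
    [mediapipe.TensorSqueezeDimensionsCalculatorOptions.ext] {
      dim: 0
      dim: 2
    }
  }
}
```

Because `squeeze_all_single_dims` and `dim` are mutually exclusive, set exactly one of them; each listed `dim` must have size 1 at runtime or `Process()` fails.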


@ -0,0 +1,120 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
class TensorSqueezeDimensionsCalculatorTest : public ::testing::Test {
protected:
TensorSqueezeDimensionsCalculatorTest() {
// Initialize tensor_ with deterministic values.
tensor_shape_ = tf::TensorShape(std::vector<tf::int64>({1, 3, 1, 3, 1}));
tensor_ = tf::Tensor(tf::DT_INT32, tensor_shape_);
auto tensor_values = tensor_.tensor<int32, 5>();
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
tensor_values(0, i, 0, j, 0) = i * (j + 1);
}
}
}
std::unique_ptr<CalculatorRunner> runner_;
tf::TensorShape tensor_shape_;
tf::Tensor tensor_;
};
TEST_F(TensorSqueezeDimensionsCalculatorTest, CanSqueezeAllSingleDimensions) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorSqueezeDimensionsCalculator");
config.add_input_stream("input_tensor");
config.add_output_stream("output_tensor");
CalculatorOptions options;
TensorSqueezeDimensionsCalculatorOptions* squeeze_options =
options.MutableExtension(TensorSqueezeDimensionsCalculatorOptions::ext);
squeeze_options->set_squeeze_all_single_dims(true);
*config.mutable_options() = options;
runner_.reset(new CalculatorRunner(config));
std::unique_ptr<tf::Tensor> tensor_copy(new tf::Tensor);
EXPECT_TRUE(tensor_copy->CopyFrom(tensor_, tensor_shape_));
const tf::int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor_copy.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
const tf::TensorShape expected_shape(std::vector<tf::int64>({3, 3}));
EXPECT_EQ(expected_shape.DebugString(), output_tensor.shape().DebugString());
const auto tensor_values = output_tensor.tensor<int32, 2>();
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
const int expected_value = i * (j + 1);
EXPECT_EQ(expected_value, tensor_values(i, j));
}
}
}
TEST_F(TensorSqueezeDimensionsCalculatorTest, CanSqueezeSpecifiedDimensions) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorSqueezeDimensionsCalculator");
config.add_input_stream("input_tensor");
config.add_output_stream("output_tensor");
CalculatorOptions options;
TensorSqueezeDimensionsCalculatorOptions* squeeze_options =
options.MutableExtension(TensorSqueezeDimensionsCalculatorOptions::ext);
squeeze_options->add_dim(0);
squeeze_options->add_dim(4);
*config.mutable_options() = options;
runner_.reset(new CalculatorRunner(config));
std::unique_ptr<tf::Tensor> tensor_copy(new tf::Tensor);
EXPECT_TRUE(tensor_copy->CopyFrom(tensor_, tensor_shape_));
const tf::int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor_copy.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
const tf::TensorShape expected_shape(std::vector<tf::int64>({3, 1, 3}));
EXPECT_EQ(expected_shape.DebugString(), output_tensor.shape().DebugString());
const auto tensor_values = output_tensor.tensor<int32, 3>();
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
const int expected_value = i * (j + 1);
EXPECT_EQ(expected_value, tensor_values(i, 0, j));
}
}
}
} // namespace
} // namespace mediapipe


@ -0,0 +1,124 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kImage[] = "IMAGE";
constexpr char kTensor[] = "TENSOR";
} // namespace
// Input:
// Tensor of type DT_FLOAT, with values between 0-255 (SRGB or GRAY8). The
// shape can be HxWx{3,1} or simply HxW.
//
// Optionally supports a scale factor that can scale 0-1 value ranges to 0-255.
//
// Output:
// ImageFrame containing the values of the tensor cast as uint8 (SRGB or GRAY8)
//
// Possible extensions: support other input ranges, maybe 4D tensors.
class TensorToImageFrameCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
float scale_factor_;
};
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
::mediapipe::Status TensorToImageFrameCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Exactly one input stream must be provided.";
RET_CHECK(cc->Inputs().HasTag(kTensor))
<< "An input stream for tag: " << kTensor << " must be provided.";
cc->Inputs().Tag(kTensor).Set<tf::Tensor>(
// Input Tensor.
);
cc->Outputs().Tag(kImage).Set<ImageFrame>(
// Output ImageFrame.
);
return ::mediapipe::OkStatus();
}
::mediapipe::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
scale_factor_ =
cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status TensorToImageFrameCalculator::Process(
CalculatorContext* cc) {
const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
int32 depth = 1;
if (input_tensor.dims() != 2) { // Depth is 1 for 2D tensors.
CHECK(3 == input_tensor.dims())
<< "Only 2 or 3-D Tensors can be converted to frames. Instead got: "
<< input_tensor.dims();
depth = input_tensor.dim_size(2);
if (depth != 1) {
RET_CHECK_EQ(depth, 3) << "Output tensor depth must be 3 or 1.";
}
}
const int32 total_size =
input_tensor.dim_size(0) * input_tensor.dim_size(1) * depth;
std::unique_ptr<uint8[]> buffer(new uint8[total_size]);
auto data = input_tensor.flat<float>().data();
for (int i = 0; i < total_size; ++i) {
float d = scale_factor_ * data[i];
if (d < 0) d = 0;
if (d > 255) d = 255;
buffer[i] = d;
}
::std::unique_ptr<ImageFrame> output;
if (depth == 3) {
output = ::absl::make_unique<ImageFrame>(
ImageFormat::SRGB, input_tensor.dim_size(1), input_tensor.dim_size(0),
input_tensor.dim_size(1) * 3, buffer.release());
} else if (depth == 1) {
output = ::absl::make_unique<ImageFrame>(
ImageFormat::GRAY8, input_tensor.dim_size(1), input_tensor.dim_size(0),
input_tensor.dim_size(1), buffer.release());
} else {
return ::mediapipe::InvalidArgumentError("Unrecognized image depth.");
}
cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
} // namespace mediapipe
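The conversion loop in `Process()` clamps each scaled float to [0, 255] before the uint8 cast, so out-of-range values saturate instead of wrapping. A standalone sketch of just that step (function name and `std::vector` buffers are illustrative, not MediaPipe API):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scales each float by `scale_factor`, clamps to [0, 255], and truncates to
// uint8, mirroring the pixel-conversion loop in Process() above.
std::vector<uint8_t> ScaleAndClampToUint8(const std::vector<float>& data,
                                          float scale_factor) {
  std::vector<uint8_t> out(data.size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    float d = scale_factor * data[i];
    if (d < 0) d = 0;
    if (d > 255) d = 255;
    out[i] = static_cast<uint8_t>(d);
  }
  return out;
}
```

With `scale_factor` of 255, a tensor in [0, 1] maps onto the full uint8 range; values outside that range clamp to 0 or 255.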


@ -0,0 +1,29 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorToImageFrameCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorToImageFrameCalculatorOptions ext = 142032475;
}
// Multiplies floating point tensor outputs by this value before converting
// to uint8. This is useful for converting from range [0, 1] to [0, 255].
optional float scale_factor = 1 [default = 1.0];
}
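A node configuration using `scale_factor` to convert a [0, 1] float tensor into a full-range image might look like this (a sketch in the Example Config style used elsewhere in this repository; stream names are placeholders):

```
node {
  calculator: "TensorToImageFrameCalculator"
  input_stream: "TENSOR:float_tensor"
  output_stream: "IMAGE:image_frame"
  options {
    [mediapipe.TensorToImageFrameCalculatorOptions.ext] {
      scale_factor: 255.0
    }
  }
}
```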


@ -0,0 +1,146 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kTensor[] = "TENSOR";
constexpr char kImage[] = "IMAGE";
} // namespace
class TensorToImageFrameCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner() {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorToImageFrameCalculator");
config.add_input_stream("TENSOR:input_tensor");
config.add_output_stream("IMAGE:output_image");
runner_ = absl::make_unique<CalculatorRunner>(config);
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
SetUpRunner();
constexpr int kWidth = 16;
constexpr int kHeight = 8;
const tf::TensorShape tensor_shape(
std::vector<tf::int64>{kHeight, kWidth, 3});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto tensor_vec = tensor->flat<float>().data();
// Write a sequence of integers as floats; we expect to read back the same
// values after conversion.
for (int i = 0; i < kWidth * kHeight * 3; ++i) {
tensor_vec[i] = i % 255;
}
const int64 time = 1234;
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kImage).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
EXPECT_EQ(kWidth, output_image.Width());
EXPECT_EQ(kHeight, output_image.Height());
for (int i = 0; i < kWidth * kHeight * 3; ++i) {
const uint8 pixel_value = output_image.PixelData()[i];
EXPECT_EQ(i % 255, pixel_value);
}
}
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
SetUpRunner();
constexpr int kWidth = 16;
constexpr int kHeight = 8;
const tf::TensorShape tensor_shape(
std::vector<tf::int64>{kHeight, kWidth, 1});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto tensor_vec = tensor->flat<float>().data();
// Write a sequence of integers as floats; we expect to read back the same
// values after conversion.
for (int i = 0; i < kWidth * kHeight; ++i) {
tensor_vec[i] = i % 255;
}
const int64 time = 1234;
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kImage).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
EXPECT_EQ(kWidth, output_image.Width());
EXPECT_EQ(kHeight, output_image.Height());
for (int i = 0; i < kWidth * kHeight; ++i) {
const uint8 pixel_value = output_image.PixelData()[i];
EXPECT_EQ(i % 255, pixel_value);
}
}
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) {
SetUpRunner();
constexpr int kWidth = 16;
constexpr int kHeight = 8;
const tf::TensorShape tensor_shape(std::vector<tf::int64>{kHeight, kWidth});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto tensor_vec = tensor->flat<float>().data();
// Write a sequence of integers as floats; we expect to read back the same
// values after conversion.
for (int i = 0; i < kWidth * kHeight; ++i) {
tensor_vec[i] = i % 255;
}
const int64 time = 1234;
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kImage).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
EXPECT_EQ(kWidth, output_image.Width());
EXPECT_EQ(kHeight, output_image.Height());
for (int i = 0; i < kWidth * kHeight; ++i) {
const uint8 pixel_value = output_image.PixelData()[i];
EXPECT_EQ(i % 255, pixel_value);
}
}
} // namespace mediapipe


@ -0,0 +1,227 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Calculator converts from one-dimensional Tensor of DT_FLOAT to Matrix
// OR from (batched) two-dimensional Tensor of DT_FLOAT to Matrix.
#include "mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kMatrix[] = "MATRIX";
constexpr char kTensor[] = "TENSOR";
constexpr char kReference[] = "REFERENCE";
::mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
TimeSeriesHeader* header) {
CHECK(header);
if (header_packet.IsEmpty()) {
return ::mediapipe::UnknownError("No header found.");
}
if (!header_packet.ValidateAsType<TimeSeriesHeader>().ok()) {
return ::mediapipe::UnknownError(
"Packet does not contain TimeSeriesHeader.");
}
*header = header_packet.Get<TimeSeriesHeader>();
if (header->has_sample_rate() && header->sample_rate() >= 0 &&
header->has_num_channels() && header->num_channels() >= 0) {
return ::mediapipe::OkStatus();
} else {
std::string error_message =
"TimeSeriesHeader is missing necessary fields: "
"sample_rate or num_channels, or one of their values is negative. ";
#ifndef MEDIAPIPE_MOBILE
absl::StrAppend(&error_message, "Got header:\n",
header->ShortDebugString());
#endif
return ::mediapipe::InvalidArgumentError(error_message);
}
}
} // namespace
// Converts a 1-D or a 2-D Tensor into a 2 dimensional Matrix.
// Input:
// -- 1-D or 2-D Tensor
// Output:
// -- Matrix with the same values as the Tensor
// If the input tensor is 1-D with length n, the output Matrix is an (nx1)
// column vector.
// If the input tensor is 2-D (batched) with shape (mxn), the output Matrix
// has shape (nxm): channels are rows and samples are columns.
//
// Example Config
// node: {
// calculator: "TensorToMatrixCalculator"
// input_stream: "TENSOR:tensor"
// output_stream: "MATRIX:matrix"
// }
//
//
// This calculator produces a TimeSeriesHeader header on its output stream iff
// an input stream is supplied with the REFERENCE tag and that stream has a
// header of type TimeSeriesHeader. This header is modified in two ways:
// - the sample_rate is set to the packet rate of the REFERENCE stream (which
// must have a packet_rate defined in its header). This is under the
// assumption that the packets on the reference stream, input stream, and
// output stream are in a 1:1 correspondence, and that the output packets are
// 1-D column vectors that represent a single sample of output.
// - the TimeSeriesHeader overrides specified in the calculator options are
// then applied, which can override the sample_rate field.
// If the REFERENCE stream is supplied, then the TimeSeriesHeader is verified on
// the input data when it arrives in Process(). In particular, if the header
// states that we produce a 1xD column vector, the input tensor must also be
// 1xD.
//
// This design was discussed in http://g/speakeranalysis/4uyx7cNRwJY and
// http://g/daredevil-project/VB26tcseUy8.
// Example Config
// node: {
// calculator: "TensorToMatrixCalculator"
// input_stream: "TENSOR:tensor"
// input_stream: "REFERENCE:reference_matrix"
// output_stream: "MATRIX:matrix"
// options {
// [mediapipe.TensorToMatrixCalculatorOptions.ext] {
// time_series_header_overrides {
// num_channels: 128
// }
// }
// }
// }
class TensorToMatrixCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
// Store header information so that we can verify the inputs in process().
TimeSeriesHeader header_;
};
REGISTER_CALCULATOR(TensorToMatrixCalculator);
::mediapipe::Status TensorToMatrixCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK_LE(cc->Inputs().NumEntries(), 2)
<< "Only one or two input streams are supported.";
RET_CHECK_GT(cc->Inputs().NumEntries(), 0)
<< "At least one input stream must be provided.";
RET_CHECK(cc->Inputs().HasTag(kTensor))
<< "An input stream for tag: " << kTensor << " must be provided.";
cc->Inputs().Tag(kTensor).Set<tf::Tensor>(
// Input Tensor.
);
if (cc->Inputs().NumEntries() == 2) {
RET_CHECK(cc->Inputs().HasTag(kReference))
<< "An input stream for tag: " << kReference
<< " must be provided when"
" providing two inputs.";
cc->Inputs()
.Tag(kReference)
.Set<Matrix>(
// A reference stream for the header.
);
}
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
cc->Outputs().Tag(kMatrix).Set<Matrix>(
// Output Matrix.
);
return ::mediapipe::OkStatus();
}
::mediapipe::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) {
auto input_header = absl::make_unique<TimeSeriesHeader>();
::mediapipe::Status header_status;
if (cc->Inputs().HasTag(kReference)) {
header_status = FillTimeSeriesHeaderIfValid(
cc->Inputs().Tag(kReference).Header(), input_header.get());
}
if (header_status.ok()) {
if (cc->Options<TensorToMatrixCalculatorOptions>()
.has_time_series_header_overrides()) {
// From design discussions with Daredevil, we only want to support single
// sample per packet for now, so we hardcode the sample_rate based on the
// packet_rate of the REFERENCE and fail noisily if we cannot. An
// alternative would be to calculate the sample_rate from the reference
// sample_rate and the change in num_samples between the reference and
// override headers:
// sample_rate_output = sample_rate_reference /
// (num_samples_override / num_samples_reference)
const TimeSeriesHeader& override_header =
cc->Options<TensorToMatrixCalculatorOptions>()
.time_series_header_overrides();
input_header->MergeFrom(override_header);
CHECK(input_header->has_packet_rate())
<< "The TimeSeriesHeader.packet_rate must be set.";
if (!override_header.has_sample_rate()) {
CHECK_EQ(input_header->num_samples(), 1)
<< "Currently the time series can only output single samples.";
input_header->set_sample_rate(input_header->packet_rate());
}
}
header_ = *input_header;
cc->Outputs().Tag(kMatrix).SetHeader(Adopt(input_header.release()));
}
cc->SetOffset(mediapipe::TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
// Daredevil requested CHECK for noisy failures rather than quieter RET_CHECK
// failures. These are absolute conditions of the graph for the graph to be
// valid, and if it is violated by any input anywhere, the graph will be
// invalid for all inputs. A hard CHECK will enable faster debugging by
// immediately exiting and more prominently displaying error messages.
// Do not replace with RET_CHECKs.
// Verify that each reference stream packet corresponds to a tensor packet
// otherwise the header information is invalid. If we don't have a reference
// stream, Process() is only called when we have an input tensor and this is
// always True.
CHECK(cc->Inputs().HasTag(kTensor))
<< "Tensor stream not available at same timestamp as the reference "
"stream.";
const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims())
<< "Only 1-D or 2-D Tensors can be converted to matrices.";
const int32 length = input_tensor.dim_size(input_tensor.dims() - 1);
const int32 width = (1 == input_tensor.dims()) ? 1 : input_tensor.dim_size(0);
if (header_.has_num_channels()) {
CHECK_EQ(length, header_.num_channels())
<< "The number of channels at runtime does not match the header.";
}
if (header_.has_num_samples()) {
CHECK_EQ(width, header_.num_samples())
<< "The number of samples at runtime does not match the header.";
}
auto output = absl::make_unique<Matrix>(width, length);
*output =
Eigen::MatrixXf::Map(input_tensor.flat<float>().data(), length, width);
cc->Outputs().Tag(kMatrix).Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
} // namespace mediapipe
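The shape logic in `Process()` is easy to misread: `length` (the last tensor dimension) becomes the Matrix row count via `Eigen::MatrixXf::Map(data, length, width)`, and `width` (the batch dimension, or 1 for a 1-D tensor) becomes the column count. A standalone sketch of just that mapping (the function name and `std::vector` shape representation are illustrative, not MediaPipe API); the tests in the companion `_test.cc` file exercise the same cases:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Returns {rows, cols} of the output Matrix for a given tensor shape,
// mirroring Process() above: a length-n 1-D tensor becomes nx1, and a
// batched (m, n) 2-D tensor becomes nxm (channels are rows, samples are
// columns).
std::pair<int64_t, int64_t> OutputMatrixShape(
    const std::vector<int64_t>& tensor_shape) {
  assert(tensor_shape.size() == 1 || tensor_shape.size() == 2);
  const int64_t length = tensor_shape.back();  // num_channels (rows)
  const int64_t width =
      tensor_shape.size() == 1 ? 1 : tensor_shape.front();  // num_samples
  return {length, width};
}
```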


@ -0,0 +1,36 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/formats/time_series_header.proto";
message TensorToMatrixCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorToMatrixCalculatorOptions ext = 136654056;
}
// Any values present in this TimeSeriesHeader override the values in the
// header from the reference stream if the reference stream is used.
// The most common fields to override are the num_channels field which
// typically correspond to the dimensionality of an output embedding and
// the num_samples field which is generally 1 for an embedding.
// If the sampling_rate is not specified in the time_series_header, then
// the packet_rate from the reference stream will be used as the sampling_rate
// which assumes that num_samples is 1.
optional TimeSeriesHeader time_series_header_overrides = 1;
}


@ -0,0 +1,227 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kMatrix[] = "MATRIX";
constexpr char kTensor[] = "TENSOR";
} // namespace
class TensorToMatrixCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner() {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorToMatrixCalculator");
config.add_input_stream("TENSOR:input_tensor");
config.add_output_stream("MATRIX:output_matrix");
runner_ = absl::make_unique<CalculatorRunner>(config);
}
// Creates a reference stream and sets num_channels and num_samples if > 0.
void SetUpRunnerWithReference(int channels, int samples,
int override_channels, bool include_rate) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorToMatrixCalculator");
config.add_input_stream("TENSOR:input_tensor");
config.add_input_stream("REFERENCE:reference");
config.add_output_stream("MATRIX:output_matrix");
if (override_channels > 0) {
config.mutable_options()
->MutableExtension(TensorToMatrixCalculatorOptions::ext)
->mutable_time_series_header_overrides()
->set_num_channels(override_channels);
}
runner_ = absl::make_unique<CalculatorRunner>(config);
auto header = absl::make_unique<TimeSeriesHeader>();
header->set_sample_rate(1.0);
if (channels > 0) {
header->set_num_channels(channels);
}
if (samples > 0) {
header->set_num_samples(samples);
}
if (include_rate) {
header->set_packet_rate(1.0);
}
runner_->MutableInputs()->Tag("REFERENCE").header = Adopt(header.release());
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(TensorToMatrixCalculatorTest, Converts1DTensorToMatrix) {
// This test converts a 1 Dimensional Tensor of length M to a Matrix of Mx1.
SetUpRunner();
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto tensor_vec = tensor->vec<float>();
for (int i = 0; i < 5; ++i) {
// 2^i can be represented exactly in floating point numbers if 'i' is small.
tensor_vec(i) = static_cast<float>(1 << i);
}
const int64 time = 1234;
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kMatrix).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const Matrix& output_matrix = output_packets[0].Get<Matrix>();
EXPECT_EQ(5, output_matrix.rows());
for (int i = 0; i < 5; ++i) {
const float expected = static_cast<float>(1 << i);
EXPECT_EQ(expected, output_matrix(i, 0));
}
}
TEST_F(TensorToMatrixCalculatorTest, Converts2DTensorofWidthOneToMatrix) {
  // This test converts a 2-dimensional Tensor of shape 1xM to an Mx1 Matrix.
SetUpRunner();
const tf::TensorShape tensor_shape(std::vector<tf::int64>({1, 4}));
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto slice = tensor->Slice(0, 1).flat<float>();
for (int i = 0; i < 4; ++i) {
slice(i) = static_cast<float>(1 << i);
}
const int64 time = 1234;
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kMatrix).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const Matrix& output_matrix = output_packets[0].Get<Matrix>();
ASSERT_EQ(1, output_matrix.cols());
EXPECT_EQ(4, output_matrix.rows());
for (int i = 0; i < 4; ++i) {
const float expected = static_cast<float>(1 << i);
EXPECT_EQ(expected, output_matrix(i, 0));
}
}
TEST_F(TensorToMatrixCalculatorTest, Converts2DTensorToMatrix) {
  // This test converts a 2-dimensional Tensor of shape NxM to an MxN Matrix.
SetUpRunner();
const tf::TensorShape tensor_shape(std::vector<tf::int64>({3, 4}));
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto slice = tensor->Slice(0, 1).flat<float>();
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 4; ++j) {
slice(i * 4 + j) = static_cast<float>(i * j);
}
}
const int64 time = 1234;
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kMatrix).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const Matrix& output_matrix = output_packets[0].Get<Matrix>();
ASSERT_EQ(3, output_matrix.cols());
EXPECT_EQ(4, output_matrix.rows());
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 3; ++j) {
const float expected = static_cast<float>(i * j);
EXPECT_EQ(expected, output_matrix(i, j));
}
}
}
TEST_F(TensorToMatrixCalculatorTest, ConvertsWithReferenceTimeSeriesHeader) {
  // This test converts a 1-dimensional Tensor of length M to an Mx1 Matrix.
SetUpRunnerWithReference(5, 1, -1, true);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto tensor_vec = tensor->vec<float>();
for (int i = 0; i < 5; ++i) {
// 2^i can be represented exactly in floating point numbers if 'i' is small.
tensor_vec(i) = static_cast<float>(1 << i);
}
const int64 time = 1234;
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kMatrix).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const Matrix& output_matrix = output_packets[0].Get<Matrix>();
EXPECT_EQ(5, output_matrix.rows());
for (int i = 0; i < 5; ++i) {
const float expected = static_cast<float>(1 << i);
EXPECT_EQ(expected, output_matrix(i, 0));
}
const TimeSeriesHeader& output_header =
runner_->Outputs().Tag(kMatrix).header.Get<TimeSeriesHeader>();
EXPECT_EQ(output_header.num_channels(), 5);
}
TEST_F(TensorToMatrixCalculatorTest, TimeSeriesOverridesWork) {
  // This test converts a 1-dimensional Tensor of length M to an Mx1 Matrix.
SetUpRunnerWithReference(7, 1, 5, true);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto tensor_vec = tensor->vec<float>();
for (int i = 0; i < 5; ++i) {
// 2^i can be represented exactly in floating point numbers if 'i' is small.
tensor_vec(i) = static_cast<float>(1 << i);
}
const int64 time = 1234;
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kMatrix).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const Matrix& output_matrix = output_packets[0].Get<Matrix>();
EXPECT_EQ(5, output_matrix.rows());
for (int i = 0; i < 5; ++i) {
const float expected = static_cast<float>(1 << i);
EXPECT_EQ(expected, output_matrix(i, 0));
}
const TimeSeriesHeader& output_header =
runner_->Outputs().Tag(kMatrix).header.Get<TimeSeriesHeader>();
EXPECT_EQ(output_header.num_channels(), 5);
}
} // namespace mediapipe
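The Converts2DTensorToMatrix test above asserts `output_matrix(i, j) == tensor(j, i)` for a row-major NxM tensor and a column-major MxN Matrix. A useful consequence is that the two layouts share the same underlying buffer, so no data movement is needed. The following standalone sketch demonstrates this with plain `std::vector` standing in for `tf::Tensor` and Eigen; the helper names are illustrative, not part of MediaPipe:

```cpp
#include <cassert>
#include <vector>

// Reads element (r, c) of a column-major matrix stored in `buf` with
// `num_rows` rows: linear index = c * num_rows + r.
float ColMajorAt(const std::vector<float>& buf, int num_rows, int r, int c) {
  return buf[c * num_rows + r];
}

// Builds the row-major 3x4 tensor used by Converts2DTensorToMatrix
// (tensor(i, j) = i * j) and returns its backing buffer unchanged. The same
// bytes, read column-major with 4 rows, are the expected 4x3 Matrix.
std::vector<float> MakeTestTensor() {
  std::vector<float> buf(3 * 4);
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 4; ++j) {
      buf[i * 4 + j] = static_cast<float>(i * j);
    }
  }
  return buf;
}
```

Reading the 12-element buffer column-major with 4 rows yields `matrix(i, j) = i * j`, exactly what the test expects, which is why the calculator can treat the conversion as a reinterpretation rather than a transpose-with-copy.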

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This calculator converts a one-dimensional Tensor of DT_FLOAT to a
// vector<float>, or a (batched) two-dimensional Tensor of DT_FLOAT to a
// vector<vector<float>>.
#include "mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace tf = ::tensorflow;
class TensorToVectorFloatCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
TensorToVectorFloatCalculatorOptions options_;
};
REGISTER_CALCULATOR(TensorToVectorFloatCalculator);
::mediapipe::Status TensorToVectorFloatCalculator::GetContract(
CalculatorContract* cc) {
  // Start with only one input stream.
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
cc->Inputs().Index(0).Set<tf::Tensor>(
// Input Tensor
);
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
const auto& options = cc->Options<TensorToVectorFloatCalculatorOptions>();
if (options.tensor_is_2d()) {
RET_CHECK(!options.flatten_nd());
cc->Outputs().Index(0).Set<std::vector<std::vector<float>>>(
/* "Output vector<vector<float>>." */);
} else {
cc->Outputs().Index(0).Set<std::vector<float>>(
// Output vector<float>.
);
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TensorToVectorFloatCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<TensorToVectorFloatCalculatorOptions>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status TensorToVectorFloatCalculator::Process(
CalculatorContext* cc) {
const tf::Tensor& input_tensor =
cc->Inputs().Index(0).Value().Get<tf::Tensor>();
RET_CHECK(tf::DT_FLOAT == input_tensor.dtype())
<< "expected DT_FLOAT input but got "
<< tensorflow::DataTypeString(input_tensor.dtype());
if (options_.tensor_is_2d()) {
RET_CHECK(2 == input_tensor.dims())
<< "Expected 2-dimensional Tensor, but the tensor shape is: "
<< input_tensor.shape().DebugString();
auto output = absl::make_unique<std::vector<std::vector<float>>>(
input_tensor.dim_size(0), std::vector<float>(input_tensor.dim_size(1)));
for (int i = 0; i < input_tensor.dim_size(0); ++i) {
auto& instance_output = output->at(i);
const auto& slice = input_tensor.Slice(i, i + 1).unaligned_flat<float>();
for (int j = 0; j < input_tensor.dim_size(1); ++j) {
instance_output.at(j) = slice(j);
}
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
} else {
if (!options_.flatten_nd()) {
RET_CHECK(1 == input_tensor.dims())
<< "`flatten_nd` is not set. Expected 1-dimensional Tensor, but the "
<< "tensor shape is: " << input_tensor.shape().DebugString();
}
auto output =
absl::make_unique<std::vector<float>>(input_tensor.NumElements());
const auto& tensor_values = input_tensor.flat<float>();
for (int i = 0; i < input_tensor.NumElements(); ++i) {
output->at(i) = tensor_values(i);
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe
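The `tensor_is_2d` branch above copies each row of a row-major `[batch x width]` tensor into its own inner vector. A minimal standalone sketch of that loop, with `std::vector<float>` standing in for `tf::Tensor` and a hypothetical `UnpackRows` helper:

```cpp
#include <cassert>
#include <vector>

// Mirrors the tensor_is_2d branch of TensorToVectorFloatCalculator: each row
// of a row-major [batch x width] buffer becomes one inner vector.
std::vector<std::vector<float>> UnpackRows(const std::vector<float>& buf,
                                           int batch, int width) {
  std::vector<std::vector<float>> out(batch, std::vector<float>(width));
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < width; ++j) {
      // Row i of the row-major buffer starts at offset i * width.
      out[i][j] = buf[i * width + j];
    }
  }
  return out;
}
```

The real calculator obtains each row via `Tensor::Slice(i, i + 1).unaligned_flat<float>()`; the index arithmetic is the same.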

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorToVectorFloatCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorToVectorFloatCalculatorOptions ext = 120862965;
}
// If true, unpack a 2d tensor (matrix) into a vector<vector<float>>. If
// false, convert a 1d tensor (vector) into a vector<float>.
optional bool tensor_is_2d = 1 [default = false];
// If true, an N-D tensor will be flattened to a vector<float>. This option is
// mutually exclusive with tensor_is_2d.
optional bool flatten_nd = 2 [default = false];
}
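A graph node using these options might look like the following (the stream names are placeholders for illustration):

```
node {
  calculator: "TensorToVectorFloatCalculator"
  input_stream: "input_tensor"
  output_stream: "output_floats"
  options {
    [mediapipe.TensorToVectorFloatCalculatorOptions.ext]: {
      flatten_nd: true
    }
  }
}
```

With `flatten_nd: true` an N-D input tensor is emitted as a single flat `vector<float>`; leaving both options false restricts the input to 1-D tensors.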

// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
class TensorToVectorFloatCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorToVectorFloatCalculator");
config.add_input_stream("input_tensor");
config.add_output_stream("output_tensor");
auto options = config.mutable_options()->MutableExtension(
TensorToVectorFloatCalculatorOptions::ext);
options->set_tensor_is_2d(tensor_is_2d);
options->set_flatten_nd(flatten_nd);
runner_ = absl::make_unique<CalculatorRunner>(config);
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(TensorToVectorFloatCalculatorTest, ConvertsToVectorFloat) {
SetUpRunner(false, false);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto tensor_vec = tensor->vec<float>();
for (int i = 0; i < 5; ++i) {
// 2^i can be represented exactly in floating point numbers if 'i' is small.
tensor_vec(i) = static_cast<float>(1 << i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<float>& output_vector =
output_packets[0].Get<std::vector<float>>();
EXPECT_EQ(5, output_vector.size());
for (int i = 0; i < 5; ++i) {
const float expected = static_cast<float>(1 << i);
EXPECT_EQ(expected, output_vector[i]);
}
}
TEST_F(TensorToVectorFloatCalculatorTest, ConvertsBatchedToVectorVectorFloat) {
SetUpRunner(true, false);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{1, 5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto slice = tensor->Slice(0, 1).flat<float>();
for (int i = 0; i < 5; ++i) {
// 2^i can be represented exactly in floating point numbers if 'i' is small.
slice(i) = static_cast<float>(1 << i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<std::vector<float>>& output_vectors =
output_packets[0].Get<std::vector<std::vector<float>>>();
ASSERT_EQ(1, output_vectors.size());
const std::vector<float>& output_vector = output_vectors[0];
EXPECT_EQ(5, output_vector.size());
for (int i = 0; i < 5; ++i) {
const float expected = static_cast<float>(1 << i);
EXPECT_EQ(expected, output_vector[i]);
}
}
TEST_F(TensorToVectorFloatCalculatorTest, FlattenShouldTakeAllDimensions) {
SetUpRunner(false, true);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 2, 2});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
auto slice = tensor->flat<float>();
for (int i = 0; i < 2 * 2 * 2; ++i) {
// 2^i can be represented exactly in floating point numbers if 'i' is small.
slice(i) = static_cast<float>(1 << i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<float>& output_vector =
output_packets[0].Get<std::vector<float>>();
EXPECT_EQ(2 * 2 * 2, output_vector.size());
for (int i = 0; i < 2 * 2 * 2; ++i) {
const float expected = static_cast<float>(1 << i);
EXPECT_EQ(expected, output_vector[i]);
}
}
} // namespace
} // namespace mediapipe

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/monotonic_clock.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
namespace tf = ::tensorflow;
namespace mediapipe {
namespace {
// This is a simple implementation of a semaphore using standard C++ libraries.
// It is intended to be used only by TensorFlowInferenceCalculator to throttle
// concurrent calls to TensorFlow's Session::Run. This is useful when multiple
// threads execute the graph (e.g. in a MapReduce-style job) without
// overloading the GPU/TPU/etc.
class SimpleSemaphore {
public:
explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {}
SimpleSemaphore(const SimpleSemaphore&) = delete;
SimpleSemaphore(SimpleSemaphore&&) = delete;
// Acquires the semaphore by certain amount.
void Acquire(uint32 amount) {
mutex_.Lock();
while (count_ < amount) {
cond_.Wait(&mutex_);
}
count_ -= amount;
mutex_.Unlock();
}
// Releases the semaphore by certain amount.
void Release(uint32 amount) {
mutex_.Lock();
count_ += amount;
cond_.SignalAll();
mutex_.Unlock();
}
private:
uint32 count_;
absl::Mutex mutex_;
absl::CondVar cond_;
};
} // namespace
// This calculator performs inference on a trained TensorFlow model.
//
// Additional documentation and examples at
// go/mediapipe/tensorflow_in_mediapipe.
//
// TensorFlow Sessions can be created from checkpoint paths, frozen models, or
// the SavedModel system (go/saved-model). See the TensorFlowSessionFrom*
// packet generators for details. Each of these methods defines a mapping
// between MediaPipe streams and TensorFlow tensors. All of this information is
// passed in as an input_side_packet.
//
// The input and output streams are TensorFlow tensors labeled by tags. The tags
// for the streams are matched to feeds and fetches in a TensorFlow session
// a named_signature.generic_signature in the ModelManifest. The
// generic_signature is used as key-value pairs between the MediaPipe tag and
// the TensorFlow tensor. The signature_name in the options proto determines
// which named_signature is used. The keys in the generic_signature must be
// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of
// the tensors corresponding to tags in the signature for input_streams are fed
// to the model and for output_streams the tensors are fetched from the model.
//
// Other calculators are used to convert data to and from tensors; this one
// only
// handles the TensorFlow session and batching. Batching occurs by concatenating
// input tensors along the 0th dimension across timestamps. If the 0th dimension
// is not a batch dimension, this calculator will add a 0th dimension by
// default. Setting add_batch_dim_to_tensors to false disables the dimension
// addition. Once batch_size inputs have been provided, the batch will be run
// and the output tensors sent out on the output streams with timestamps
// corresponding to the input stream packets. Setting the batch_size to 1
// completely disables batching, but is independent of add_batch_dim_to_tensors.
//
// The TensorFlowInferenceCalculator also support feeding states recurrently for
// RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the
// recurrent tensors. Initializing the recurrent state can be handled by the
// GraphTensorsPacketGenerator.
//
// The calculator updates two Counters to report timing information:
// --<name>-TotalTimeUsecs = Total time spent running inference (in usecs),
// --<name>-TotalProcessedTimestamps = # of instances processed
// (approximately batches processed * batch_size),
// where <name> is replaced with CalculatorGraphConfig::Node::name() if it
// exists, or with TensorFlowInferenceCalculator if the name is not set. The
// name must be set for timing information to be instance-specific in graphs
// with multiple TensorFlowInferenceCalculators.
//
// Example config:
// packet_generator {
// packet_generator: "TensorFlowSessionFromSavedModelGenerator"
// output_side_packet: "tensorflow_session"
// options {
// [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: {
// saved_model_path: "/path/to/saved/model"
// signature_name: "mediapipe"
// }
// }
// }
// node {
// calculator: "TensorFlowInferenceCalculator"
// input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag"
// input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag"
// output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag"
// input_side_packet: "SESSION:tensorflow_session"
// }
//
// Where the input and output streams are treated as Packet<tf::Tensor> and
// the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and
// "LABELS" and their respective tensors exported to /path/to/bundle. For an
// example of how this model was exported, see
// tensorflow_inference_test_graph_generator.py
//
// It is possible to use a GraphDef proto that was not exported by the
// exporter (i.e., without a MetaGraph with bindings). Such a GraphDef could
// contain all of its parameters in-lined (for example, it can be the output
// of freeze_graph.py).
// To instantiate a TensorFlow model from a GraphDef file, replace the
// packet_generator above with TensorFlowSessionFromFrozenGraphGenerator:
//
// packet_generator {
// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator"
// output_side_packet: "SESSION:tensorflow_session"
// options {
// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: {
// graph_proto_path: "[PATH]"
// tag_to_tensor_names {
// key: "JPG_STRING"
// value: "input:0"
// }
// tag_to_tensor_names {
// key: "SOFTMAX"
// value: "softmax:0"
// }
// }
// }
// }
//
// It is also possible to use a GraphDef proto and checkpoint file that have not
// been frozen. This can be used to load graphs directly as they have been
// written from training. However, it is more brittle and you are encouraged to
// use one of the more permanent formats described above. To instantiate a
// TensorFlow model from a GraphDef file and checkpoint, replace the
// packet_generator above with TensorFlowSessionFromModelCheckpointGenerator:
//
// packet_generator {
// packet_generator: "TensorFlowSessionFromModelCheckpointGenerator"
// output_side_packet: "SESSION:tensorflow_session"
// options {
// [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: {
// graph_proto_path: "[PATH]"
// model_options {
// checkpoint_path: "[PATH2]"
// }
// tag_to_tensor_names {
// key: "JPG_STRING"
// value: "input:0"
// }
// tag_to_tensor_names {
// key: "SOFTMAX"
// value: "softmax:0"
// }
// }
// }
// }
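Batching, as described above, concatenates input tensors along the 0th dimension; when fewer than batch_size timestamps have arrived (e.g. when the graph closes mid-batch), the batch is padded by replicating the first element and the padded outputs are later discarded. A minimal sketch of that padding step, with `std::vector<float>` standing in for `tf::Tensor` and a hypothetical `PadBatch` helper:

```cpp
#include <algorithm>
#include <vector>

// Pads a partial batch up to `batch_size` by replicating the first element,
// mirroring the resize + std::fill step in OutputBatch below. The caller only
// reads back results for the real (unpadded) batch elements.
std::vector<std::vector<float>> PadBatch(std::vector<std::vector<float>> batch,
                                         int batch_size) {
  const size_t real_size = batch.size();
  batch.resize(batch_size);
  std::fill(batch.begin() + real_size, batch.end(), batch[0]);
  return batch;
}
```

Replication is cheap here for the same reason it is cheap for `tf::Tensor`: copies share the underlying buffer (for `std::vector` the copies are real, but for reference-counted tensors only the handle is duplicated).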
class TensorFlowInferenceCalculator : public CalculatorBase {
public:
// Counters for recording timing information. The actual names have the value
// of CalculatorGraphConfig::Node::name() prepended.
static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs";
static constexpr char kTotalProcessedTimestampsCounterSuffix[] =
"TotalProcessedTimestamps";
static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] =
"TotalSessionRunsTimeUsecs";
static constexpr char kTotalNumSessionRunsCounterSuffix[] =
"TotalNumSessionRuns";
TensorFlowInferenceCalculator() : session_(nullptr) {
clock_ = std::unique_ptr<mediapipe::Clock>(
mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
}
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK(!cc->Inputs().GetTags().empty());
for (const std::string& tag : cc->Inputs().GetTags()) {
// The tensorflow::Tensor with the tag equal to the graph node. May
// have a TimeSeriesHeader if all present TimeSeriesHeaders match.
cc->Inputs().Tag(tag).Set<tf::Tensor>();
}
RET_CHECK(!cc->Outputs().GetTags().empty());
for (const std::string& tag : cc->Outputs().GetTags()) {
// The tensorflow::Tensor with tag equal to the graph node to
// output. Any TimeSeriesHeader from the inputs will be forwarded
// with channels set to 0.
cc->Outputs().Tag(tag).Set<tf::Tensor>();
}
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
// For this calculator it must include a tag_to_tensor_map.
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
cc->InputSidePackets()
.Tag("RECURRENT_INIT_TENSORS")
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
session_ = cc->InputSidePackets()
.Tag("SESSION")
.Get<TensorFlowSession>()
.session.get();
tag_to_tensor_map_ = cc->InputSidePackets()
.Tag("SESSION")
.Get<TensorFlowSession>()
.tag_to_tensor_map;
// Validate and store the recurrent tags
RET_CHECK(options_.has_batch_size());
RET_CHECK(options_.batch_size() == 1 ||
options_.recurrent_tag_pair().empty())
<< "To use recurrent_tag_pairs, batch_size must be 1.";
for (const auto& tag_pair : options_.recurrent_tag_pair()) {
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
RET_CHECK_EQ(tags.size(), 2)
<< "recurrent_tag_pair must be a colon "
"separated std::string with two components: "
<< tag_pair;
RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
<< "Can't find tag '" << tags[0] << "' in signature "
<< options_.signature_name();
RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
<< "Can't find tag '" << tags[1] << "' in signature "
<< options_.signature_name();
recurrent_feed_tags_.insert(tags[0]);
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
}
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
std::map<std::string, tf::Tensor>* init_tensor_map;
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
for (const auto& p : *init_tensor_map) {
input_tensor_batches_[p.first].emplace_back(p.second);
}
}
// Check that all tags are present in this signature bound to tensors.
for (const std::string& tag : cc->Inputs().GetTags()) {
RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tag))
<< "Can't find tag '" << tag << "' in signature "
<< options_.signature_name();
}
for (const std::string& tag : cc->Outputs().GetTags()) {
RET_CHECK(::mediapipe::ContainsKey(tag_to_tensor_map_, tag))
<< "Can't find tag '" << tag << "' in signature "
<< options_.signature_name();
}
if (options_.batch_size() == 1) {
cc->SetOffset(0);
}
return ::mediapipe::OkStatus();
}
// Adds a batch dimension to the input tensor if specified in the calculator
// options.
::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor) {
if (options_.add_batch_dim_to_tensors()) {
tf::TensorShape new_shape(input_tensor->shape());
new_shape.InsertDim(0, 1);
RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
<< "Could not add 0th dimension to tensor without changing its shape."
<< " Current shape: " << input_tensor->shape().DebugString();
}
return ::mediapipe::OkStatus();
}
// Removes the batch dimension of the output tensor if specified in the
// calculator options.
::mediapipe::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
if (options_.add_batch_dim_to_tensors()) {
tf::TensorShape new_shape(output_tensor->shape());
new_shape.RemoveDim(0);
RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape))
<< "Could not remove 0th dimension from tensor without changing its "
<< "shape. Current shape: " << output_tensor->shape().DebugString()
<< " (The expected first dimension is 1 for a batch element.)";
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
std::map<std::string, tf::Tensor> input_tensors_by_tag;
for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) {
if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) {
// Recurrent tensors can be empty.
if (!::mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) {
if (options_.skip_on_missing_features()) {
return ::mediapipe::OkStatus();
} else {
return ::mediapipe::InvalidArgumentError(absl::StrCat(
"Tag ", tag_as_node_name,
" not present at timestamp: ", cc->InputTimestamp().Value()));
}
}
} else {
tf::Tensor input_tensor(
cc->Inputs().Tag(tag_as_node_name).Get<tf::Tensor>());
RET_CHECK_OK(AddBatchDimension(&input_tensor));
if (::mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) {
// If we receive an input on a recurrent tag, override the state.
// It's OK to override the global state because there is just one
// input stream allowed for recurrent tensors.
input_tensor_batches_[tag_as_node_name].clear();
}
input_tensors_by_tag.insert(
std::make_pair(tag_as_node_name, input_tensor));
}
}
batch_timestamps_.emplace_back(cc->InputTimestamp());
for (const auto& input_tensor_and_tag : input_tensors_by_tag) {
input_tensor_batches_[input_tensor_and_tag.first].emplace_back(
input_tensor_and_tag.second);
}
if (batch_timestamps_.size() == options_.batch_size()) {
RETURN_IF_ERROR(OutputBatch(cc));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Close(CalculatorContext* cc) override {
if (!batch_timestamps_.empty()) {
RETURN_IF_ERROR(OutputBatch(cc));
}
return ::mediapipe::OkStatus();
}
// When a batch of input tensors is ready to be run, runs TensorFlow and
// outputs the output tensors. The output tensors have timestamps matching
// the input tensor that formed that batch element. Any requested
// batch_dimension is added and removed. This code takes advantage of the fact
// that copying a tensor shares the same reference-counted, heap allocated
// memory buffer. Therefore, copies are cheap and should not cause the memory
// buffer to fall out of scope. In contrast, concat is only used where
// necessary.
::mediapipe::Status OutputBatch(CalculatorContext* cc) {
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
for (auto& keyed_tensors : input_tensor_batches_) {
if (options_.batch_size() == 1) {
// Short circuit to avoid the cost of deep copying tensors in concat.
if (!keyed_tensors.second.empty()) {
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
keyed_tensors.second[0]);
} else {
// The input buffer can be empty for recurrent tensors.
RET_CHECK(::mediapipe::ContainsKey(recurrent_feed_tags_,
keyed_tensors.first))
<< "A non-recurrent tensor does not have an input: "
<< keyed_tensors.first;
}
} else {
        // Pad by replicating the first tensor, then ignore the padded values.
keyed_tensors.second.resize(options_.batch_size());
std::fill(keyed_tensors.second.begin() + batch_timestamps_.size(),
keyed_tensors.second.end(), keyed_tensors.second[0]);
tf::Tensor concated;
const tf::Status concat_status =
tf::tensor::Concat(keyed_tensors.second, &concated);
CHECK(concat_status.ok()) << concat_status.ToString();
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
concated);
}
}
input_tensor_batches_.clear();
std::vector<mediapipe::ProtoString> output_tensor_names;
std::vector<std::string> output_name_in_signature;
for (const std::string& tag : cc->Outputs().GetTags()) {
output_tensor_names.emplace_back(tag_to_tensor_map_[tag]);
output_name_in_signature.emplace_back(tag);
}
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
// Ensure that we always fetch the recurrent state tensors.
if (std::find(output_name_in_signature.begin(),
output_name_in_signature.end(),
tag_pair.first) == output_name_in_signature.end()) {
output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]);
output_name_in_signature.emplace_back(tag_pair.first);
}
}
std::vector<tf::Tensor> outputs;
SimpleSemaphore* session_run_throttle = nullptr;
if (options_.max_concurrent_session_runs() > 0) {
session_run_throttle =
get_session_run_throttle(options_.max_concurrent_session_runs());
session_run_throttle->Acquire(1);
}
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
const tf::Status tf_status =
session_->Run(input_tensors, output_tensor_names,
{} /* target_node_names */, &outputs);
if (session_run_throttle != nullptr) {
session_run_throttle->Release(1);
}
// RET_CHECK on the tf::Status object itself in order to print an
// informative error message.
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.error_message();
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
->IncrementBy(run_end_time - run_start_time);
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
// Feed back the recurrent state.
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
int pos = std::find(output_name_in_signature.begin(),
output_name_in_signature.end(), tag_pair.first) -
output_name_in_signature.begin();
input_tensor_batches_[tag_pair.second].emplace_back(outputs[pos]);
}
// Split each output tensor into slices of size 1 along the 0th (batch)
// dimension.
std::vector<tf::int64> split_vector(options_.batch_size(), 1);
for (int i = 0; i < output_tensor_names.size(); ++i) {
if (options_.batch_size() == 1) {
if (cc->Outputs().HasTag(output_name_in_signature[i])) {
tf::Tensor output_tensor(outputs[i]);
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
cc->Outputs()
.Tag(output_name_in_signature[i])
.Add(new tf::Tensor(output_tensor), batch_timestamps_[0]);
}
} else {
std::vector<tf::Tensor> split_tensors;
const tf::Status split_status =
tf::tensor::Split(outputs[i], split_vector, &split_tensors);
CHECK(split_status.ok()) << split_status.ToString();
// Loop over timestamps so that we don't copy the padding.
for (int j = 0; j < batch_timestamps_.size(); ++j) {
tf::Tensor output_tensor(split_tensors[j]);
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
cc->Outputs()
.Tag(output_name_in_signature[i])
.Add(new tf::Tensor(output_tensor), batch_timestamps_[j]);
}
}
}
// Get end time and report.
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
cc->GetCounter(kTotalUsecsCounterSuffix)
->IncrementBy(end_time - start_time);
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
->IncrementBy(batch_timestamps_.size());
batch_timestamps_.clear();
return ::mediapipe::OkStatus();
}
private:
// The Session object is provided by a packet factory and is owned by the
// MediaPipe framework. Individual calls are thread-safe, but session state
// may be shared across threads.
tf::Session* session_;
// A mapping between stream tags and the tensor names they are bound to.
std::map<std::string, std::string> tag_to_tensor_map_;
// A mapping between stream tags and the tensors we are collecting as a batch.
std::map<std::string, std::vector<tf::Tensor>> input_tensor_batches_;
// The timestamps that go into a batch.
std::vector<Timestamp> batch_timestamps_;
// The options for the calculator.
TensorFlowInferenceCalculatorOptions options_;
// Store the feed and fetch tags for feed/fetch recurrent networks.
std::set<std::string> recurrent_feed_tags_;
std::map<std::string, std::string> recurrent_fetch_tags_to_feed_tags_;
// Clock used to measure the computation time in OutputBatch().
std::unique_ptr<mediapipe::Clock> clock_;
// Returns the static singleton semaphore used to throttle concurrent
// Session::Run calls. The function-local static is initialized in a
// thread-safe manner on first use and is shared by all instances of this
// calculator in the process. Note that the limit passed on the first call
// is the one that takes effect, and the semaphore is intentionally never
// destroyed.
static SimpleSemaphore* get_session_run_throttle(
    int32 max_concurrent_session_runs) {
  static SimpleSemaphore* session_run_throttle =
      new SimpleSemaphore(max_concurrent_session_runs);
  return session_run_throttle;
}
};
REGISTER_CALCULATOR(TensorFlowInferenceCalculator);
constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[];
constexpr char
TensorFlowInferenceCalculator::kTotalProcessedTimestampsCounterSuffix[];
constexpr char
TensorFlowInferenceCalculator::kTotalSessionRunsTimeUsecsCounterSuffix[];
constexpr char
TensorFlowInferenceCalculator::kTotalNumSessionRunsCounterSuffix[];
} // namespace mediapipe


@ -0,0 +1,79 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorFlowInferenceCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorFlowInferenceCalculatorOptions ext = 113766539;
}
// The signature_name specifies the mapping between stream tags and TensorFlow
// session tensors. The mapping between tags and tensors is encoded as a
// ModelManifest signature. The signatures are keyed by name and the
// named_signature matching signature_name is used by the calculator to
// match stream tags to tensors. The named_signature must be a
// ModelManifest.generic_signature with map keys that are valid tags (i.e.
// [A-Z0-9]*).
optional string signature_name = 1;
// How many elements to batch together and feed into the graph.
// Setting the batch_size to 1 disables batching entirely. You still may or
// may not need to add the batch dimension via the option below depending on
// the input data shape and the model's expectations.
optional int32 batch_size = 2;
// Whether to add a 0th dimension to the input tensors for batching.
// If the 0th dimension is the batch dimension, then the tensors are
// concatenated along that dimension. If the 0th dimension is a data
// dimension, then a new 0th dimension is added before concatenating. If
// added, the extra dimension is removed before outputting the tensor.
// Examples of each case: If you want
// to batch spectra of audio over time for an LSTM, a time-frequency
// representation has a 0th dimension as the batch dimension. If you want to
// batch frames of video that are [width, height, channels], the batch
// dimension needs to be added.
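// For example (shapes are illustrative): with batch_size: 2 and
// add_batch_dim_to_tensors: true, two input tensors of shape
// [width, height, channels] are stacked into one [2, width, height, channels]
// tensor before Session::Run, and each output is split back per timestamp
// with the extra dimension removed.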
optional bool add_batch_dim_to_tensors = 3 [default = true];
// These pairs represent feed and fetch tensors for handling recurrent state.
// Each entry is a colon-separated pair of strings. The first half of each
// string is the signature tag for the feed tensor for recurrent state. The
// second half of the string is the signature tag for the fetch tensor for the
// recurrent state. More than two colon-separated strings is an error. During
// inference, the fetch tensor is fetched at every timestep and will be output
// if there is a corresponding output stream. The behavior of the feed tensor
// is determined by the following conditions in order: If the MediaPipe input
// stream with the matching tag has a packet available, then the input
// packet's tensor is passed in. If no input packet is available and we have
// fetched a tensor from the previous time step, we will feed the tensor from
// the previous timestep back in. If neither tensor is available, no tensor
// will be fed into the model.
// If this flag is set, batch_size must be 1. Do not list recurrent_tag_pair
// tags as initial_state_tags because those are only fed once.
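// Example (tag names are illustrative, not taken from any particular model):
//   recurrent_tag_pair: "LSTM_STATE_IN:LSTM_STATE_OUT"
// Here the tensor fetched under tag LSTM_STATE_OUT at timestep t is fed back
// under tag LSTM_STATE_IN at timestep t + 1, unless an input packet with tag
// LSTM_STATE_IN overrides it.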
repeated string recurrent_tag_pair = 4;
// If set to true, skips inputs for which any of the features are missing.
// If set to false, requires all input features to be available; if any is
// missing, the calculator reports an error.
optional bool skip_on_missing_features = 5 [default = false];
// Maximum number of concurrent TensorFlow Session::Run calls allowed in
// this calculator, to avoid overloading local compute hardware such as a
// TPU. Note that this limit only applies within the local process, not
// "globally" across multiple processes or replicas (if any). Defaults to 0,
// i.e. no limit.
optional int32 max_concurrent_session_runs = 6 [default = 0];
}
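
// Example node in a CalculatorGraphConfig using these options. Stream names
// and option values below are illustrative only:
//
//   node {
//     calculator: "TensorFlowInferenceCalculator"
//     input_side_packet: "SESSION:session"
//     input_stream: "A:tensor_a"
//     output_stream: "MULTIPLIED:tensor_out"
//     options {
//       [mediapipe.TensorFlowInferenceCalculatorOptions.ext] {
//         batch_size: 2
//         add_batch_dim_to_tensors: true
//       }
//     }
//   }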


@ -0,0 +1,512 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
#include "mediapipe/framework/tool/validate_type.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#ifdef __APPLE__
#include <CoreFoundation/CoreFoundation.h>
#endif // defined(__APPLE__)
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
std::string GetGraphDefPath() {
#ifdef __APPLE__
char path[1024];
CFURLRef bundle_url = CFBundleCopyBundleURL(CFBundleGetMainBundle());
CFURLGetFileSystemRepresentation(
bundle_url, true, reinterpret_cast<UInt8*>(path), sizeof(path));
CFRelease(bundle_url);
return ::mediapipe::file::JoinPath(path, "testdata/frozen_graph_def.pb");
#elif defined(__ANDROID__)
char path[1024];
getcwd(path, sizeof(path));
return ::mediapipe::file::JoinPath(path,
"mediapipe/calculators/tensorflow/"
"testdata/frozen_graph_def.pb");
#else
return ::mediapipe::file::JoinPath(
"./",
// This should match the path of the output files
// of the genrule() that generates test model files.
"mediapipe/calculators/tensorflow/testdata/", "frozen_graph_def.pb");
#endif // defined(__APPLE__)
}
} // namespace
class TensorflowInferenceCalculatorTest : public ::testing::Test {
protected:
// Add the input side packet.
void AddSessionInputSidePacket() {
PacketGeneratorOptions extendable_options;
TensorFlowSessionFromFrozenGraphGeneratorOptions* generator_options;
generator_options = extendable_options.MutableExtension(
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
generator_options->set_graph_proto_path(GetGraphDefPath());
(*generator_options->mutable_tag_to_tensor_names())["MULTIPLIED"] =
"multiplied:0";
(*generator_options->mutable_tag_to_tensor_names())["A"] = "a:0";
(*generator_options->mutable_tag_to_tensor_names())["B"] = "b:0";
(*generator_options->mutable_tag_to_tensor_names())["EXPENSIVE"] =
"expensive:0";
PacketSet input_side_packets({});
PacketSet output_side_packets({"SESSION"});
MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options,
input_side_packets, &output_side_packets));
runner_->MutableSidePackets()->Tag("SESSION") =
output_side_packets.Tag("SESSION");
}
// Creates a tensor from the vector and adds it as a Packet on the input
// stream with the provided tag and timestamp.
void AddVectorToInputsAsTensor(const std::vector<int32>& input,
const std::string& tag, int64 time) {
tf::TensorShape tensor_shape;
tensor_shape.AddDim(input.size());
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_INT32, tensor_shape);
for (int i = 0; i < input.size(); ++i) {
tensor->vec<int32>()(i) = input[i];
}
runner_->MutableInputs()->Tag(tag).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_in");
config.add_output_stream("B:tensor_out");
config.add_output_stream("MULTIPLIED:tensor_multiplied");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(false);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({0, 0, 0}, "A", 0);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_b =
runner_->Outputs().Tag("B").packets;
ASSERT_EQ(output_packets_b.size(), 1);
const tf::Tensor& tensor_b = output_packets_b[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({1, 3});
auto expected_tensor = tf::test::AsTensor<int32>({3, 2, 1}, expected_shape);
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_b);
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
expected_tensor = tf::test::AsTensor<int32>({0, 0, 0}, expected_shape);
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_mult);
EXPECT_EQ(1, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(false);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({3});
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10}, expected_shape);
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_mult);
// Add only one of the two expected tensors at the next timestamp, expect
// useful failure message.
AddVectorToInputsAsTensor({1, 2, 3}, "A", 1);
auto run_status = runner_->Run();
ASSERT_FALSE(run_status.ok());
EXPECT_THAT(run_status.ToString(),
testing::HasSubstr("TensorFlowInferenceCalculator"));
EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B"));
}
TEST_F(TensorflowInferenceCalculatorTest, BadTag) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("BAD:tensor_in"); // This one is bad.
config.add_output_stream("B:tensor_out");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
ASSERT_FALSE(runner_->Run().ok());
}
TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
AddVectorToInputsAsTensor({3, 3, 3}, "A", 1);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 1);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32>({9, 12, 15});
tf::test::ExpectTensorEqual<int32>(tensor_mult1, expected_tensor1);
EXPECT_EQ(2, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(2);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
AddVectorToInputsAsTensor({3, 3, 3}, "A", 1);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 1);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32>({9, 12, 15});
tf::test::ExpectTensorEqual<int32>(tensor_mult1, expected_tensor1);
EXPECT_EQ(2, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(3);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
AddVectorToInputsAsTensor({3, 3, 3}, "A", 1);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 1);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32>({9, 12, 15});
tf::test::ExpectTensorEqual<int32>(tensor_mult1, expected_tensor1);
EXPECT_EQ(2, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->add_recurrent_tag_pair("A:MULTIPLIED");
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 1);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
LOG(INFO) << "timestamp: " << 0;
auto expected_tensor = tf::test::AsTensor<int32>({3, 8, 15});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32>({9, 32, 75});
LOG(INFO) << "timestamp: " << 1;
tf::test::ExpectTensorEqual<int32>(tensor_mult1, expected_tensor1);
EXPECT_EQ(2, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->add_recurrent_tag_pair("A:MULTIPLIED");
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({1, 1, 1}, "A", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
AddVectorToInputsAsTensor({1, 1, 1}, "A", 1);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 1);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
LOG(INFO) << "timestamp: " << 0;
auto expected_tensor = tf::test::AsTensor<int32>({3, 4, 5});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32>({3, 4, 5});
LOG(INFO) << "timestamp: " << 1;
tf::test::ExpectTensorEqual<int32>(tensor_mult1, expected_tensor1);
EXPECT_EQ(2, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
// TODO: Investigate this test failure.
TEST_F(TensorflowInferenceCalculatorTest, DISABLED_CheckTiming) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_in");
config.add_output_stream("EXPENSIVE:tensor_expensive");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(false);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({0, 0, 0}, "A", 0);
MEDIAPIPE_ASSERT_OK(runner_->Run());
EXPECT_EQ(1, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
// We only test the timing counter here because we are requesting an
// expensive tensor output. Because the precision on android is
// sometimes closer to milliseconds, we need to request a large tensor
// to be sure this will be greater than zero.
EXPECT_GT(runner_->GetCounter("TensorFlowInferenceCalculator-TotalTimeUsecs")
->Get(),
0);
}
TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(2);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_skip_on_missing_features(false);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
ASSERT_FALSE(runner_->Run().ok());
}
TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature_Skip) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(2);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_skip_on_missing_features(true);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(0, output_packets_mult.size());
}
TEST_F(TensorflowInferenceCalculatorTest,
MissingInputFeature_SkipCheckInternalState) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(2);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_skip_on_missing_features(true);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
AddVectorToInputsAsTensor({3, 3, 3}, "A", 1);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 1);
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({9, 12, 15});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
EXPECT_EQ(1, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
} // namespace mediapipe


@ -0,0 +1,35 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_TENSORFLOW_CALCULATORS_TENSORFLOW_SESSION_H_
#define MEDIAPIPE_TENSORFLOW_CALCULATORS_TENSORFLOW_SESSION_H_
#include <memory>
#include "tensorflow/core/public/session.h"
namespace mediapipe {
struct TensorFlowSession {
// TensorFlow session wrapper to get around the RTTI issue.
std::unique_ptr<tensorflow::Session> session;
// An optional mapping between MediaPipe stream tags and TensorFlow tensor
// names. Creating this mapping when the session is loaded allows the
// tag-to-tensor binding to be defined flexibly across platforms.
std::map<std::string, std::string> tag_to_tensor_map;
};
} // namespace mediapipe
#endif // MEDIAPIPE_TENSORFLOW_CALCULATORS_TENSORFLOW_SESSION_H_


@ -0,0 +1,131 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Reads serialized GraphDef proto. There are three ways to load a model:
// 1. Specify the path to a graph.pb in the calculator options.
// 2. Specify the path to the graph.pb through the
// input_side_packet:STRING_MODEL_FILE_PATH
// 3. Provide a serialized GraphDef through input_side_packet:STRING_MODEL,
// typically provided by EmbeddingFilePacketFactory.
//
// See tensorflow_session_from_frozen_graph_generator.proto for options.
// Produces a TensorFlowSession that TensorFlowInferenceCalculator can use.
#include <string>
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/core/public/session_options.h"
namespace mediapipe {
namespace tf = ::tensorflow;
class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
public:
static ::mediapipe::Status FillExpectations(
const PacketGeneratorOptions& extendable_options,
PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
RET_CHECK(extendable_options.HasExtension(
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext));
const auto& options = extendable_options.GetExtension( // NOLINT
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
    bool has_exactly_one_model =
        !options.graph_proto_path().empty()
            ? !(input_side_packets->HasTag("STRING_MODEL") ||
                input_side_packets->HasTag("STRING_MODEL_FILE_PATH"))
            : (input_side_packets->HasTag("STRING_MODEL") ^
               input_side_packets->HasTag("STRING_MODEL_FILE_PATH"));
RET_CHECK(has_exactly_one_model)
<< "Must have exactly one of graph_proto_path in options or "
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
if (input_side_packets->HasTag("STRING_MODEL")) {
input_side_packets->Tag("STRING_MODEL")
.Set<std::string>(
// String model from embedded path
);
} else if (input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) {
input_side_packets->Tag("STRING_MODEL_FILE_PATH")
.Set<std::string>(
// Filename of std::string model.
);
}
output_side_packets->Tag("SESSION").Set<TensorFlowSession>(
// A TensorFlow model loaded and ready for use along with
// a map from tags to tensor names.
);
RET_CHECK_GT(options.tag_to_tensor_names().size(), 0);
return ::mediapipe::OkStatus();
}
static ::mediapipe::Status Generate(
const PacketGeneratorOptions& packet_generator_options,
const PacketSet& input_side_packets, PacketSet* output_side_packets) {
const TensorFlowSessionFromFrozenGraphGeneratorOptions& options =
packet_generator_options.GetExtension(
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
// Output bundle packet.
auto session = ::absl::make_unique<TensorFlowSession>();
tf::SessionOptions session_options;
session_options.config.CopyFrom(options.config());
std::vector<mediapipe::ProtoString> initialization_op_names;
initialization_op_names.reserve(options.initialization_op_names_size());
for (int i = 0; i < options.initialization_op_names_size(); ++i) {
initialization_op_names.emplace_back(options.initialization_op_names(i));
}
session->session.reset(tf::NewSession(session_options));
std::string graph_def_serialized;
if (input_side_packets.HasTag("STRING_MODEL")) {
graph_def_serialized =
input_side_packets.Tag("STRING_MODEL").Get<std::string>();
} else if (input_side_packets.HasTag("STRING_MODEL_FILE_PATH")) {
const std::string& frozen_graph =
input_side_packets.Tag("STRING_MODEL_FILE_PATH").Get<std::string>();
RET_CHECK_OK(
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
} else {
RET_CHECK_OK(mediapipe::file::GetContents(options.graph_proto_path(),
&graph_def_serialized));
}
tensorflow::GraphDef graph_def;
RET_CHECK(graph_def.ParseFromString(graph_def_serialized));
const tf::Status tf_status = session->session->Create(graph_def);
RET_CHECK(tf_status.ok()) << "Create failed: " << tf_status.error_message();
for (const auto& key_value : options.tag_to_tensor_names()) {
session->tag_to_tensor_map[key_value.first] = key_value.second;
}
if (!initialization_op_names.empty()) {
const tf::Status tf_status =
session->session->Run({}, {}, initialization_op_names, {});
// RET_CHECK on the tf::Status object itself in order to print an
// informative error message.
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.error_message();
}
output_side_packets->Tag("SESSION") = Adopt(session.release());
return ::mediapipe::OkStatus();
}
};
REGISTER_PACKET_GENERATOR(TensorFlowSessionFromFrozenGraphGenerator);
} // namespace mediapipe
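
As a sketch of loading mode 2 from the header comment (the graph path supplied through the STRING_MODEL_FILE_PATH input side packet rather than the options), a graph might contain a fragment like the following; the side-packet name, tag, and tensor name here are illustrative only:

```
packet_generator {
  packet_generator: "TensorFlowSessionFromFrozenGraphGenerator"
  input_side_packet: "STRING_MODEL_FILE_PATH:frozen_graph_path"
  output_side_packet: "SESSION:session"
  options {
    [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: {
      tag_to_tensor_names {
        key: "IMAGE"
        value: "input:0"
      }
    }
  }
}
```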


@ -0,0 +1,72 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/packet_generator.proto";
import "tensorflow/core/protobuf/config.proto";
message TensorFlowSessionFromFrozenGraphGeneratorOptions {
extend mediapipe.PacketGeneratorOptions {
optional TensorFlowSessionFromFrozenGraphGeneratorOptions ext = 160666123;
}
// Path to file containing serialized proto of type tensorflow::GraphDef.
optional string graph_proto_path = 1;
// To run inference with MediaPipe inputs, MediaPipe streams need to be mapped
// to TensorFlow tensors. This map defines which streams are fed into which
// tensors in the model. The MediaPipe tag of the stream is the map key. Tags
// must be capitalized, matching the regex [A-Z0-9_]+, e.g. "JPG_STRING" or
// "SOFTMAX". Those tags can then be used as the MediaPipe tags of the
// input_stream or output_stream of the TensorflowInferenceCalculator that
// consumes the packet produced by this generator. The tensor names must match
// the names of the tensors in the graph that you want to feed or fetch, e.g.
// "DecodeJpeg/contents:0" or "softmax:0". For example, a
// mediapipe graph can include the nodes:
//
// packet_generator {
// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator"
// output_side_packet: "SESSION:session"
// options {
// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: {
// graph_proto_path: "[PATH]"
// tag_to_tensor_names {
// key: "JPG_STRING"
// value: "input:0"
// }
// tag_to_tensor_names {
// key: "SOFTMAX"
// value: "softmax:0"
// }
// }
// }
// }
// node {
// calculator: "TensorflowInferenceCalculator"
// input_side_packet: "SESSION:graph_with_bindings"
// input_stream: "JPG_STRING:jpg_string_tensor"
// output_stream: "SOFTMAX:softmax_tensor"
// }
map<string, string> tag_to_tensor_names = 2;
// Tensorflow session config options.
optional tensorflow.ConfigProto config = 3;
// Graph nodes to run to initialize the model. Any output of these ops is
// ignored.
repeated string initialization_op_names = 4;
}


@ -0,0 +1,288 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_generator.pb.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/tag_map_helper.h"
#include "mediapipe/framework/tool/validate_type.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
std::string GetGraphDefPath() {
return mediapipe::file::JoinPath("./",
"mediapipe/calculators/tensorflow/"
"testdata/frozen_graph_def.pb");
}
// Helper function that creates Tensor INT32 matrix with size 1x3.
tf::Tensor TensorMatrix1x3(const int v1, const int v2, const int v3) {
tf::Tensor tensor(tf::DT_INT32,
tf::TensorShape(std::vector<tf::int64>({1, 3})));
auto matrix = tensor.matrix<int32>();
matrix(0, 0) = v1;
matrix(0, 1) = v2;
matrix(0, 2) = v3;
return tensor;
}
class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test {
protected:
void SetUp() override {
extendable_options_.Clear();
generator_options_ = extendable_options_.MutableExtension(
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
generator_options_->set_graph_proto_path(GetGraphDefPath());
(*generator_options_->mutable_tag_to_tensor_names())["MULTIPLIED"] =
"multiplied:0";
(*generator_options_->mutable_tag_to_tensor_names())["A"] = "a:0";
(*generator_options_->mutable_tag_to_tensor_names())["B"] = "b:0";
generator_options_->mutable_config()->set_intra_op_parallelism_threads(1);
generator_options_->mutable_config()->set_inter_op_parallelism_threads(2);
}
void VerifySignatureMap(PacketSet* output_side_packets) {
const TensorFlowSession& session =
output_side_packets->Tag("SESSION").Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
// Bindings are inserted.
EXPECT_EQ(session.tag_to_tensor_map.size(), 3);
// For some reason, EXPECT_EQ and EXPECT_NE are not working with iterators.
EXPECT_FALSE(session.tag_to_tensor_map.find("A") ==
session.tag_to_tensor_map.end());
EXPECT_FALSE(session.tag_to_tensor_map.find("B") ==
session.tag_to_tensor_map.end());
EXPECT_FALSE(session.tag_to_tensor_map.find("MULTIPLIED") ==
session.tag_to_tensor_map.end());
    // Sanity check: find() returns an iterator equal to end() if the element
    // is not found.
EXPECT_TRUE(session.tag_to_tensor_map.find("Z") ==
session.tag_to_tensor_map.end());
EXPECT_EQ(session.tag_to_tensor_map.at("A"), "a:0");
EXPECT_EQ(session.tag_to_tensor_map.at("B"), "b:0");
EXPECT_EQ(session.tag_to_tensor_map.at("MULTIPLIED"), "multiplied:0");
}
PacketGeneratorOptions extendable_options_;
TensorFlowSessionFromFrozenGraphGeneratorOptions* generator_options_;
};
TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
CreatesPacketWithGraphAndBindings) {
PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie());
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).ValueOrDie());
::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
input_side_packets, &output_side_packets);
MEDIAPIPE_EXPECT_OK(run_status) << run_status.message();
VerifySignatureMap(&output_side_packets);
}
// Integration test. Verifies that TensorFlowInferenceCalculator correctly
// consumes the Packet emitted by this generator.
TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
ProducesPacketUsableByTensorFlowInferenceCalculator) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
node {
calculator: "TensorFlowInferenceCalculator"
input_side_packet: "SESSION:tf_model"
input_stream: "A:a_tensor"
output_stream: "MULTIPLIED:multiplied_tensor"
options {
[mediapipe.TensorFlowInferenceCalculatorOptions.ext] {
batch_size: 5
add_batch_dim_to_tensors: false
}
}
}
packet_generator {
packet_generator: "TensorFlowSessionFromFrozenGraphGenerator"
output_side_packet: "SESSION:tf_model"
options {
[mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: {
$0
}
}
}
input_stream: "a_tensor"
)",
generator_options_->DebugString()));
CalculatorGraph graph;
MEDIAPIPE_ASSERT_OK(graph.Initialize(config));
StatusOrPoller status_or_poller =
graph.AddOutputStreamPoller("multiplied_tensor");
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie());
MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream(
"a_tensor",
Adopt(new auto(TensorMatrix1x3(1, -1, 10))).At(Timestamp(0))));
MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("a_tensor"));
Packet packet;
ASSERT_TRUE(poller.Next(&packet));
// input tensor gets multiplied by [[3, 2, 1]]. Expected output:
tf::Tensor expected_multiplication = TensorMatrix1x3(3, -2, 10);
EXPECT_EQ(expected_multiplication.DebugString(),
packet.Get<tf::Tensor>().DebugString());
ASSERT_FALSE(poller.Next(&packet));
MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
CreatesPacketWithGraphAndBindingsFromInputSidePacket) {
PacketSet input_side_packets(
tool::CreateTagMap({"STRING_MODEL:model"}).ValueOrDie());
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).ValueOrDie());
std::string serialized_graph_contents;
MEDIAPIPE_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents));
generator_options_->clear_graph_proto_path();
input_side_packets.Tag("STRING_MODEL") =
Adopt(new std::string(serialized_graph_contents));
::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
input_side_packets, &output_side_packets);
MEDIAPIPE_EXPECT_OK(run_status) << run_status.message();
VerifySignatureMap(&output_side_packets);
}
TEST_F(
TensorFlowSessionFromFrozenGraphGeneratorTest,
CreatesPacketWithGraphAndBindingsFromInputSidePacketStringModelFilePath) {
PacketSet input_side_packets(
tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).ValueOrDie());
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).ValueOrDie());
generator_options_->clear_graph_proto_path();
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
Adopt(new std::string(GetGraphDefPath()));
::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
input_side_packets, &output_side_packets);
MEDIAPIPE_EXPECT_OK(run_status) << run_status.message();
VerifySignatureMap(&output_side_packets);
}
TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
CheckFailureForOptionsAndInputsProvideGraphDefProto) {
PacketSet input_side_packets(
tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).ValueOrDie());
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).ValueOrDie());
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
Adopt(new std::string(GetGraphDefPath()));
::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
input_side_packets, &output_side_packets);
EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal);
EXPECT_THAT(
run_status.message(),
::testing::HasSubstr("Must have exactly one of graph_proto_path"));
}
TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
CheckFailureForAllInputsProvideGraphDefProto) {
PacketSet input_side_packets(
tool::CreateTagMap(
{"STRING_MODEL_FILE_PATH:model_path", "STRING_MODEL:model"})
.ValueOrDie());
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).ValueOrDie());
std::string serialized_graph_contents;
MEDIAPIPE_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL") =
Adopt(new std::string(serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
Adopt(new std::string(GetGraphDefPath()));
::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
input_side_packets, &output_side_packets);
EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal);
EXPECT_THAT(
run_status.message(),
::testing::HasSubstr("Must have exactly one of graph_proto_path"));
}
TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
CheckFailureForOnlyBothInputSidePacketsProvideGraphDefProto) {
PacketSet input_side_packets(
tool::CreateTagMap(
{"STRING_MODEL_FILE_PATH:model_path", "STRING_MODEL:model"})
.ValueOrDie());
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).ValueOrDie());
std::string serialized_graph_contents;
  MEDIAPIPE_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
                                                   &serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL") =
Adopt(new std::string(serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
Adopt(new std::string(GetGraphDefPath()));
generator_options_->clear_graph_proto_path();
::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
input_side_packets, &output_side_packets);
EXPECT_EQ(run_status.code(), ::mediapipe::StatusCode::kInternal);
EXPECT_THAT(
run_status.message(),
::testing::HasSubstr("Must have exactly one of graph_proto_path"));
}
TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
CheckInitializationOpName) {
PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie());
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).ValueOrDie());
generator_options_->add_initialization_op_names("multiplied:0");
::mediapipe::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
input_side_packets, &output_side_packets);
MEDIAPIPE_EXPECT_OK(run_status);
VerifySignatureMap(&output_side_packets);
}
} // namespace
} // namespace mediapipe


@ -0,0 +1,176 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cctype>
#include <string>
#include <unordered_set>
#if defined(MEDIAPIPE_TPU_SUPPORT)
#include "learning/brain/google/xla/global_tpu_init.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#endif
#if !defined(__ANDROID__)
#include "mediapipe/framework/port/file_helpers.h"
#endif
#include "absl/memory/memory.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
namespace mediapipe {
namespace {
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
// Given the path to a directory containing multiple tensorflow saved models
// in subdirectories, replaces path with the alphabetically last subdirectory.
::mediapipe::Status GetLatestDirectory(std::string* path) {
#if defined(__ANDROID__)
return ::mediapipe::UnimplementedError(
"GetLatestDirectory is not implemented on Android");
#else
std::vector<std::string> saved_models;
RET_CHECK_OK(file::MatchInTopSubdirectories(
*path, tensorflow::kSavedModelFilenamePb, &saved_models));
  RET_CHECK_GT(saved_models.size(), 0)
      << "No exported bundles found in " << *path;
::std::sort(saved_models.begin(), saved_models.end());
*path = std::string(file::Dirname(saved_models.back()));
return ::mediapipe::OkStatus();
#endif
}
// If options.convert_signature_to_tags() is set, converts letters to
// uppercase and replaces /'s with _'s. This enables the standard SavedModel
// classification, regression, and prediction signatures to be used as
// uppercase INPUTS and OUTPUTS tags for streams.
const std::string MaybeConvertSignatureToTag(
const std::string& name,
const TensorFlowSessionFromSavedModelCalculatorOptions& options) {
if (options.convert_signature_to_tags()) {
    std::string output;
    output.reserve(name.length());
    // Uppercase the signature key and replace '/' with '_' so the result is a
    // valid MediaPipe tag.
    for (const unsigned char c : name) {
      output.push_back(c == '/' ? '_' : static_cast<char>(std::toupper(c)));
    }
    return output;
} else {
return name;
}
}
} // namespace
// TensorFlowSessionFromSavedModelCalculator is a MediaPipe packet calculator
// that loads a trained TensorFlow model exported via SavedModel's exporter (see
// go/savedmodel) and returns a Packet containing a unique_ptr to a
// mediapipe::TensorFlowSession, which in turn contains a TensorFlow Session
// ready for execution and a map between tags and tensor names.
//
// Example usage:
// node {
// calculator: "TensorFlowSessionFromSavedModelCalculator"
// output_side_packet: "SESSION:vod_session"
// options {
// [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
// signature_name: "serving_default"
// saved_model_path: "path/to/model"
// }
// }
// }
class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
const auto& options =
cc->Options<TensorFlowSessionFromSavedModelCalculatorOptions>();
const bool has_exactly_one_model =
options.saved_model_path().empty() ==
cc->InputSidePackets().HasTag(kStringSavedModelPath);
    RET_CHECK(has_exactly_one_model)
        << "Must have exactly one of saved_model_path in options or "
           "input_side_packets STRING_SAVED_MODEL_PATH";
// Path of savedmodel.
if (cc->InputSidePackets().HasTag(kStringSavedModelPath)) {
cc->InputSidePackets().Tag(kStringSavedModelPath).Set<std::string>();
}
    // A TensorFlow model loaded and ready for use along with a map from tags
    // to tensor names.
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
const auto& options =
cc->Options<TensorFlowSessionFromSavedModelCalculatorOptions>();
std::string path = cc->InputSidePackets().HasTag(kStringSavedModelPath)
? cc->InputSidePackets()
.Tag(kStringSavedModelPath)
.Get<std::string>()
: options.saved_model_path();
if (options.load_latest_model()) {
RET_CHECK_OK(GetLatestDirectory(&path));
}
// Set user specified tags properly.
// If no tags specified will use tensorflow::kSavedModelTagServe by default.
    std::unordered_set<std::string> tags_set;
    for (const std::string& tag : options.saved_model_tag()) {
      tags_set.insert(tag);
    }
if (tags_set.empty()) {
tags_set.insert(tensorflow::kSavedModelTagServe);
}
tensorflow::RunOptions run_options;
// In the future, could construct session options from the options proto.
tensorflow::SessionOptions session_options;
auto saved_model = absl::make_unique<tensorflow::SavedModelBundle>();
::tensorflow::Status status = tensorflow::LoadSavedModel(
session_options, run_options, path, tags_set, saved_model.get());
if (!status.ok()) {
return ::mediapipe::Status(
static_cast<::mediapipe::StatusCode>(status.code()),
status.error_message());
}
auto session = absl::make_unique<TensorFlowSession>();
session->session = std::move(saved_model->session);
RET_CHECK(!options.signature_name().empty());
const auto& signature_def_map = saved_model->meta_graph_def.signature_def();
const auto& signature_def = signature_def_map.at(options.signature_name());
for (const auto& input_signature : signature_def.inputs()) {
session->tag_to_tensor_map[MaybeConvertSignatureToTag(
input_signature.first, options)] = input_signature.second.name();
}
for (const auto& output_signature : signature_def.outputs()) {
session->tag_to_tensor_map[MaybeConvertSignatureToTag(
output_signature.first, options)] = output_signature.second.name();
}
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release()));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(TensorFlowSessionFromSavedModelCalculator);
} // namespace mediapipe


@ -0,0 +1,58 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorFlowSessionFromSavedModelCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorFlowSessionFromSavedModelCalculatorOptions ext = 244429915;
}
// TODO: SessionBundles provided global step versioning of models
// that let you load the latest model. If there's a similar solution for
// SavedModels, include a flag to load the most recent model.
// Path to a directory containing a trained TensorFlow model as prepared
// by SavedModel (go/saved-model).
optional string saved_model_path = 1;
// The name of the generic signature to load into the mapping from tags to
// tensor names.
optional string signature_name = 2 [default = "serving_default"];
// Whether to convert the signature keys to uppercase and switch /'s to
// _'s, which enables standard signatures to be used as Tags.
optional bool convert_signature_to_tags = 3 [default = true];
// If true, saved_model_path can have multiple exported models in
// subdirectories saved_model_path/%08d and the alphabetically last (i.e.,
// latest checkpoint) model is loaded. Note that saved models are not exported
// in numbered directories by default. If you want to use this feature, you
// need to arrange your directories by global_step or some other order when
// you save your models.
optional bool load_latest_model = 4;
// [DEPRECATED] If true, this calculator will try to initialize local Tensor
// Processing Unit (TPU) hardware so that the Tensorflow session loaded from
// this saved model may benefit from TPU speedups. If you want to use this
// feature, you need to make sure that the calculator runs on a machine that
  // has TPU hardware installed. The saved model should have correct device
  // placements in the graph (with the ops already placed on TPU); typically,
  // if the saved model was exported through TPUEstimator, device placement is
  // taken care of automatically.
optional bool use_tpu = 5 [deprecated = true];
// User specified tags in a saved model.
// If no tag is specified, then use "serve" as the default. Note that in order
// to use TPU accelerator hardware, the tag "tpu" needs to be specified.
repeated string saved_model_tag = 6;
}
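
To illustrate load_latest_model and saved_model_tag described above, a node might be configured as follows; the directory layout and paths are hypothetical:

```
# saved_model_path points at a parent directory containing numbered exports,
# e.g. /models/my_model/00000001 and /models/my_model/00000002; with
# load_latest_model the alphabetically last subdirectory is loaded.
node {
  calculator: "TensorFlowSessionFromSavedModelCalculator"
  output_side_packet: "SESSION:session"
  options {
    [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
      saved_model_path: "/models/my_model"
      load_latest_model: true
      saved_model_tag: "serve"
    }
  }
}
```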


@ -0,0 +1,208 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/tag_map_helper.h"
#include "mediapipe/framework/tool/validate_type.h"
#include "tensorflow/core/framework/tensor.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
std::string GetSavedModelDir() {
std::string out_path =
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
"tensorflow_saved_model/00000000");
return out_path;
}
// Helper function that creates Tensor INT32 matrix with size 1x3.
tf::Tensor TensorMatrix1x3(const int v1, const int v2, const int v3) {
tf::Tensor tensor(tf::DT_INT32,
tf::TensorShape(std::vector<tf::int64>({1, 3})));
auto matrix = tensor.matrix<int32>();
matrix(0, 0) = v1;
matrix(0, 1) = v2;
matrix(0, 2) = v3;
return tensor;
}
class TensorFlowSessionFromSavedModelCalculatorTest : public ::testing::Test {
protected:
void SetUp() override {
extendable_options_.Clear();
options_ = extendable_options_.MutableExtension(
TensorFlowSessionFromSavedModelCalculatorOptions::ext);
options_->set_saved_model_path(GetSavedModelDir());
}
CalculatorOptions extendable_options_;
TensorFlowSessionFromSavedModelCalculatorOptions* options_;
};
TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
CreatesPacketWithGraphAndBindings) {
CalculatorRunner runner(absl::Substitute(R"(
calculator: "TensorFlowSessionFromSavedModelCalculator"
output_side_packet: "SESSION:tf_model"
options {
[mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
$0
}
})",
options_->DebugString()));
MEDIAPIPE_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
// Bindings are inserted.
EXPECT_EQ(session.tag_to_tensor_map.size(), 4);
// For some reason, EXPECT_EQ and EXPECT_NE are not working with iterators.
EXPECT_FALSE(session.tag_to_tensor_map.find("A") ==
session.tag_to_tensor_map.end());
EXPECT_FALSE(session.tag_to_tensor_map.find("B") ==
session.tag_to_tensor_map.end());
EXPECT_FALSE(session.tag_to_tensor_map.find("MULTIPLIED") ==
session.tag_to_tensor_map.end());
EXPECT_FALSE(session.tag_to_tensor_map.find("EXPENSIVE") ==
session.tag_to_tensor_map.end());
  // Sanity check: find() returns an iterator equal to end() if the element
  // is not found.
EXPECT_TRUE(session.tag_to_tensor_map.find("Z") ==
session.tag_to_tensor_map.end());
EXPECT_EQ(session.tag_to_tensor_map.at("A"), "a:0");
EXPECT_EQ(session.tag_to_tensor_map.at("B"), "b:0");
EXPECT_EQ(session.tag_to_tensor_map.at("MULTIPLIED"), "multiplied:0");
EXPECT_EQ(session.tag_to_tensor_map.at("EXPENSIVE"), "expensive:0");
}
TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
CreateSessionFromSidePacket) {
options_->clear_saved_model_path();
CalculatorRunner runner(absl::Substitute(R"(
calculator: "TensorFlowSessionFromSavedModelCalculator"
input_side_packet: "STRING_SAVED_MODEL_PATH:saved_model_dir"
output_side_packet: "SESSION:tf_model"
options {
[mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
$0
}
})",
options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_SAVED_MODEL_PATH") =
MakePacket<std::string>(GetSavedModelDir());
MEDIAPIPE_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
}
// Integration test. Verifies that TensorFlowInferenceCalculator correctly
// consumes the Packet emitted by this factory.
TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
ProducesPacketUsableByTensorFlowInferenceCalculator) {
CalculatorGraphConfig graph_config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
node {
calculator: "TensorFlowInferenceCalculator"
input_side_packet: "SESSION:tf_model"
input_stream: "A:a_tensor"
output_stream: "MULTIPLIED:multiplied_tensor"
options {
[mediapipe.TensorFlowInferenceCalculatorOptions.ext] {
batch_size: 5
add_batch_dim_to_tensors: false
}
}
}
node {
calculator: "TensorFlowSessionFromSavedModelCalculator"
output_side_packet: "SESSION:tf_model"
options {
[mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
$0
}
}
}
input_stream: "a_tensor"
)",
options_->DebugString()));
CalculatorGraph graph;
MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config));
StatusOrPoller status_or_poller =
graph.AddOutputStreamPoller("multiplied_tensor");
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie());
MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream(
"a_tensor",
Adopt(new auto(TensorMatrix1x3(1, -1, 10))).At(Timestamp(0))));
MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("a_tensor"));
Packet packet;
ASSERT_TRUE(poller.Next(&packet));
// Input tensor is multiplied element-wise by [[3, 2, 1]], giving [[3, -2, 10]].
tf::Tensor expected_multiplication = TensorMatrix1x3(3, -2, 10);
EXPECT_EQ(expected_multiplication.DebugString(),
packet.Get<tf::Tensor>().DebugString());
ASSERT_FALSE(poller.Next(&packet));
MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
GetsBundleGivenParentDirectory) {
options_->set_saved_model_path(
std::string(file::SplitPath(GetSavedModelDir()).first));
options_->set_load_latest_model(true);
CalculatorRunner runner(absl::Substitute(R"(
calculator: "TensorFlowSessionFromSavedModelCalculator"
output_side_packet: "SESSION:tf_model"
options {
[mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
$0
}
})",
options_->DebugString()));
MEDIAPIPE_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
}
} // namespace
} // namespace mediapipe
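For reference, the calculator exercised by these tests is typically configured as a graph node like the following sketch (the saved-model path is a hypothetical placeholder, not a path from this repository):

```
node {
  calculator: "TensorFlowSessionFromSavedModelCalculator"
  output_side_packet: "SESSION:tf_model"
  options {
    [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
      saved_model_path: "/path/to/saved_model"  # hypothetical path
    }
  }
}
```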


@ -0,0 +1,166 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#if defined(MEDIAPIPE_TPU_SUPPORT)
#include "learning/brain/google/xla/global_tpu_init.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#endif
#if !defined(__ANDROID__)
#include "mediapipe/framework/port/file_helpers.h"
#endif
#include "absl/memory/memory.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet_generator.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
namespace mediapipe {
namespace {
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
// Given the path to a directory containing multiple tensorflow saved models
// in subdirectories, replaces path with the alphabetically last subdirectory.
::mediapipe::Status GetLatestDirectory(std::string* path) {
#if defined(__ANDROID__)
return ::mediapipe::UnimplementedError(
"GetLatestDirectory is not implemented on Android");
#else
std::vector<std::string> saved_models;
RET_CHECK_OK(file::MatchInTopSubdirectories(
*path, tensorflow::kSavedModelFilenamePb, &saved_models));
RET_CHECK_GT(saved_models.size(), 0)
<< "No exported bundles found in " << *path;
::std::sort(saved_models.begin(), saved_models.end());
*path = std::string(file::Dirname(saved_models.back()));
return ::mediapipe::OkStatus();
#endif
}
// If options.convert_signature_to_tags() is set, converts the signature key
// to uppercase and replaces /'s with _'s. This enables the standard SavedModel
// classification, regression, and prediction signatures to be used as
// uppercase INPUTS and OUTPUTS tags for streams.
std::string MaybeConvertSignatureToTag(
const std::string& name,
const TensorFlowSessionFromSavedModelGeneratorOptions& options) {
if (options.convert_signature_to_tags()) {
std::string output;
output.resize(name.length());
std::transform(name.begin(), name.end(), output.begin(),
[](unsigned char c) { return std::toupper(c); });
std::replace(output.begin(), output.end(), '/', '_');
return output;
} else {
return name;
}
}
} // namespace
// TensorFlowSessionFromSavedModelGenerator is a MediaPipe packet generator
// that loads a trained TensorFlow model exported via SavedModel's exporter (see
// go/savedmodel) and returns a Packet containing a unique_ptr to a
// mediapipe::TensorFlowSession, which in turn contains a TensorFlow Session
// ready for execution and a map between tags and tensor names.
class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
public:
static ::mediapipe::Status FillExpectations(
const PacketGeneratorOptions& extendable_options,
PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
const TensorFlowSessionFromSavedModelGeneratorOptions& options =
extendable_options.GetExtension(
TensorFlowSessionFromSavedModelGeneratorOptions::ext);
const bool has_exactly_one_model =
options.saved_model_path().empty() ==
input_side_packets->HasTag(kStringSavedModelPath);
RET_CHECK(has_exactly_one_model)
<< "Must have exactly one of saved model filepath in options or the "
"STRING_SAVED_MODEL_PATH input side packet.";
// Path of savedmodel.
if (input_side_packets->HasTag(kStringSavedModelPath)) {
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
}
// A TensorFlow model loaded and ready for use, along with a map from tags
// to tensor names.
output_side_packets->Tag("SESSION").Set<TensorFlowSession>();
return ::mediapipe::OkStatus();
}
static ::mediapipe::Status Generate(
const PacketGeneratorOptions& extendable_options,
const PacketSet& input_side_packets, PacketSet* output_side_packets) {
const TensorFlowSessionFromSavedModelGeneratorOptions& options =
extendable_options.GetExtension(
TensorFlowSessionFromSavedModelGeneratorOptions::ext);
std::string path =
input_side_packets.HasTag(kStringSavedModelPath)
? input_side_packets.Tag(kStringSavedModelPath).Get<std::string>()
: options.saved_model_path();
if (options.load_latest_model()) {
RET_CHECK_OK(GetLatestDirectory(&path));
}
// Set user-specified tags. If no tags are specified,
// tensorflow::kSavedModelTagServe is used by default.
std::unordered_set<std::string> tags_set;
for (const std::string& tag : options.saved_model_tag()) {
tags_set.insert(tag);
}
if (tags_set.empty()) {
tags_set.insert(tensorflow::kSavedModelTagServe);
}
tensorflow::RunOptions run_options;
// In the future, could construct session options from the options proto.
tensorflow::SessionOptions session_options;
auto saved_model = absl::make_unique<tensorflow::SavedModelBundle>();
::tensorflow::Status status = tensorflow::LoadSavedModel(
session_options, run_options, path, tags_set, saved_model.get());
if (!status.ok()) {
return ::mediapipe::Status(
static_cast<::mediapipe::StatusCode>(status.code()),
status.error_message());
}
auto session = absl::make_unique<TensorFlowSession>();
session->session = std::move(saved_model->session);
RET_CHECK(!options.signature_name().empty());
const auto& signature_def_map = saved_model->meta_graph_def.signature_def();
const auto& signature_def = signature_def_map.at(options.signature_name());
for (const auto& input_signature : signature_def.inputs()) {
session->tag_to_tensor_map[MaybeConvertSignatureToTag(
input_signature.first, options)] = input_signature.second.name();
}
for (const auto& output_signature : signature_def.outputs()) {
session->tag_to_tensor_map[MaybeConvertSignatureToTag(
output_signature.first, options)] = output_signature.second.name();
}
output_side_packets->Tag("SESSION") = Adopt(session.release());
return ::mediapipe::OkStatus();
}
};
REGISTER_PACKET_GENERATOR(TensorFlowSessionFromSavedModelGenerator);
} // namespace mediapipe


@ -0,0 +1,58 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/packet_generator.proto";
message TensorFlowSessionFromSavedModelGeneratorOptions {
extend mediapipe.PacketGeneratorOptions {
optional TensorFlowSessionFromSavedModelGeneratorOptions ext = 151486368;
}
// TODO: SessionBundle provided global-step versioning of models,
// which lets you load the latest model. If there's a similar solution for
// SavedModels, include a flag to load the most recent model.
// Path to a directory containing a trained TensorFlow model as prepared
// by SavedModel (go/saved-model).
optional string saved_model_path = 1;
// The name of the generic signature to load into the mapping from tags to
// tensor names.
optional string signature_name = 2 [default = "serving_default"];
// Whether to convert the signature keys to uppercase and switch /'s to
// _'s, which enables standard signatures to be used as Tags.
optional bool convert_signature_to_tags = 3 [default = true];
// If true, saved_model_path can have multiple exported models in
// subdirectories saved_model_path/%08d and the alphabetically last (i.e.,
// latest checkpoint) model is loaded. Note that saved models are not exported
// in numbered directories by default. If you want to use this feature, you
// need to arrange your directories by global_step or some other order when
// you save your models.
optional bool load_latest_model = 4;
// [DEPRECATED] If true, this generator will try to initialize local Tensor
// Processing Unit (TPU) hardware so that the TensorFlow session loaded from
// this saved model may benefit from TPU speedups. If you want to use this
// feature, you need to make sure that the calculator runs on a machine that
// has TPU hardware installed. The saved model should have correct device
// placements in the graph (have the ops already placed on TPU), typically if
// the saved model was exported through TPUEstimator then device placement is
// automatically taken care of.
optional bool use_tpu = 5 [deprecated = true];
// User specified tags in a saved model.
// If no tag is specified, then use "serve" as the default. Note that in order
// to use TPU accelerator hardware, the tag "tpu" needs to be specified.
repeated string saved_model_tag = 6;
}
