Project import generated by Copybara.

GitOrigin-RevId: 612e50bb8db2ec3dc1c30049372d87a80c3848db
MediaPipe Team 2020-08-29 23:41:10 -04:00 committed by chuoling
parent a7225b938a
commit c0124fb83c
248 changed files with 5225 additions and 1914 deletions


@@ -1,6 +1,11 @@
global-exclude .git*
global-exclude *_test.py
recursive-include mediapipe/models *.tflite *.txt
include CONTRIBUTING.md
include LICENSE
include MANIFEST.in
include README.md
include requirements.txt
recursive-include mediapipe/modules *.tflite *.txt
recursive-include mediapipe/graphs *.binarypb


@@ -22,27 +22,28 @@ desktop/cloud, web and IoT devices.

## ML solutions in MediaPipe

Face Detection | Face Mesh | Iris 🆕 | Hands | Pose 🆕
Face Detection | Face Mesh | Iris | Hands | Pose | Hair Segmentation
:---: | :---: | :---: | :---: | :---: | :---:
[![face_detection](docs/images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](docs/images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](docs/images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](docs/images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](docs/images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](docs/images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation)

Hair Segmentation | Object Detection | Box Tracking | Objectron | KNIFT
Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT
:---: | :---: | :---: | :---: | :---:
[![object_detection](docs/images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](docs/images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](docs/images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](docs/images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](docs/images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift)

<!-- []() in the first cell is needed to preserve table formatting in GitHub Pages. -->
<!-- Whenever this table is updated, paste a copy to solutions/solutions.md. -->

[]() | Android | iOS | Desktop | Python | Web | Coral
:--- | :-----: | :-: | :-----: | :----: | :-: | :---:
[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | ✅ | ✅
[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | | |
[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | ✅ |
[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | | ✅ |
[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ |
[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | ✅ |
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
@@ -88,6 +89,8 @@ run code search using

## Publications

* [Instant Motion Tracking With MediaPipe](https://mediapipe.page.link/instant-motion-tracking-blog)
  in Google Developers Blog
* [BlazePose - On-device Real-time Body Pose Tracking](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
  in Google AI Blog
* [MediaPipe Iris: Real-time Eye Tracking and Depth Estimation](https://ai.googleblog.com/2020/08/mediapipe-iris-real-time-iris-tracking.html)


@@ -36,6 +36,19 @@ http_archive(
urls = ["https://github.com/bazelbuild/rules_cc/archive/master.zip"],
)
http_archive(
name = "rules_foreign_cc",
strip_prefix = "rules_foreign_cc-master",
url = "https://github.com/bazelbuild/rules_foreign_cc/archive/master.zip",
)
load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies")
rules_foreign_cc_dependencies()
# This is used to select all contents of the archives for CMake-based packages to give CMake access to them.
all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])"""
# GoogleTest/GoogleMock framework. Used by most unit-tests.
# Last updated 2020-06-30.
http_archive(
@@ -68,14 +81,23 @@ http_archive(
url = "https://github.com/gflags/gflags/archive/v2.2.2.zip",
)
# glog v0.3.5
http_archive(
name = "com_github_glog_glog_v_0_3_5",
url = "https://github.com/google/glog/archive/v0.3.5.zip",
sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8",
strip_prefix = "glog-0.3.5",
build_file = "@//third_party:glog.BUILD",
# 2020-08-21
# TODO: Migrate MediaPipe to use com_github_glog_glog on all platforms.
http_archive(
name = "com_github_glog_glog",
strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
urls = [
"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
],
)
http_archive(
name = "com_github_glog_glog_no_gflags",
strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
build_file = "@//third_party:glog_no_gflags.BUILD",
urls = [
"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
],
patches = [
"@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff"
],
@@ -84,16 +106,6 @@ http_archive(
],
)
# 2020-02-16
http_archive(
name = "com_github_glog_glog",
strip_prefix = "glog-3ba8976592274bc1f907c402ce22558011d6fc5e",
sha256 = "feca3c7e29a693cab7887409756d89d342d4a992d54d7c5599bebeae8f7b50be",
urls = [
"https://github.com/google/glog/archive/3ba8976592274bc1f907c402ce22558011d6fc5e.zip",
],
)
# easyexif
http_archive(
name = "easyexif",
@@ -169,6 +181,13 @@ http_archive(
sha256 = "5ba6d0db4e784621fda44a50c58bb23b0892684692f0c623e2063f9c19f192f1"
)
http_archive(
name = "opencv",
build_file_content = all_content,
strip_prefix = "opencv-3.4.10",
urls = ["https://github.com/opencv/opencv/archive/3.4.10.tar.gz"],
)
new_local_repository(
name = "linux_opencv",
build_file = "@//third_party:opencv_linux.BUILD",
@@ -184,13 +203,13 @@ new_local_repository(
new_local_repository(
name = "macos_opencv",
build_file = "@//third_party:opencv_macos.BUILD",
path = "/usr",
path = "/usr/local/opt/opencv@3",
)
new_local_repository(
name = "macos_ffmpeg",
build_file = "@//third_party:ffmpeg_macos.BUILD",
path = "/usr",
path = "/usr/local/opt/ffmpeg",
)
new_local_repository(
@@ -301,9 +320,6 @@ load("@rules_jvm_external//:defs.bzl", "maven_install")
maven_install(
name = "maven",
artifacts = [
"junit:junit:4.12",
"androidx.test.espresso:espresso-core:3.1.1",
"org.hamcrest:hamcrest-library:1.3",
"androidx.concurrent:concurrent-futures:1.0.0-alpha03",
"androidx.lifecycle:lifecycle-common:2.2.0",
"androidx.annotation:annotation:aar:1.1.0",
@@ -314,11 +330,15 @@ maven_install(
"androidx.core:core:aar:1.1.0-rc03",
"androidx.legacy:legacy-support-v4:aar:1.0.0",
"androidx.recyclerview:recyclerview:aar:1.1.0-beta02",
"androidx.test.espresso:espresso-core:3.1.1",
"com.github.bumptech.glide:glide:4.11.0",
"com.google.android.material:material:aar:1.0.0-rc01",
"com.google.code.findbugs:jsr305:3.0.2",
"com.google.flogger:flogger-system-backend:0.3.1",
"com.google.flogger:flogger:0.3.1",
"com.google.guava:guava:27.0.1-android",
"junit:junit:4.12",
"org.hamcrest:hamcrest-library:1.3",
],
repositories = [
"https://jcenter.bintray.com",


@@ -22,8 +22,8 @@ aux_links:
# Footer content appears at the bottom of every page's main content
footer_content: "&copy; 2020 GOOGLE LLC | <a href=\"https://policies.google.com/privacy\">PRIVACY POLICY</a> | <a href=\"https://policies.google.com/terms\">TERMS OF SERVICE</a>"
# Color scheme currently only supports "dark" or nil (default)
# Color scheme currently only supports "dark", "light"/nil (default), or a custom scheme that you define
color_scheme: nil
color_scheme: mediapipe
# Google Analytics Tracking (optional)
ga_tracking: UA-140696581-2


@@ -0,0 +1 @@
$link-color: #0097A7;


@@ -425,7 +425,47 @@ Note: This currently works only on Linux, and please first follow

## Python

### Prerequisite
MediaPipe Python package is available on
[PyPI](https://pypi.org/project/mediapipe/), and can be installed simply by `pip
install mediapipe` on Linux and macOS, as described below in
[Run in python interpreter](#run-in-python-interpreter) and in this
[colab](https://mediapipe.page.link/mp-py-colab).
### Run in Python interpreter
Using [MediaPipe Pose](../solutions/pose.md) as an example:
```bash
# Activate a Python virtual environment.
$ python3 -m venv mp_env && source mp_env/bin/activate
# Install MediaPipe Python package
(mp_env)$ pip install mediapipe
# Run in Python interpreter
(mp_env)$ python3
>>> import mediapipe as mp
>>> pose_tracker = mp.examples.UpperBodyPoseTracker()
# For image input
>>> pose_landmarks, _ = pose_tracker.run(input_file='/path/to/input/file', output_file='/path/to/output/file')
>>> pose_landmarks, annotated_image = pose_tracker.run(input_file='/path/to/file')
# For live camera input
# (Press Esc within the output image window to stop the run or let it self terminate after 30 seconds.)
>>> pose_tracker.run_live()
# Close the tracker.
>>> pose_tracker.close()
```
Tip: Use command `deactivate` to exit the Python virtual environment.
### Building Python package from source
Follow these steps only if you have local changes and need to build the Python
package from source. Otherwise, we strongly encourage our users to simply run
`pip install mediapipe`, which is more convenient and much faster.
1. Make sure that Bazel and OpenCV are correctly installed and configured for
MediaPipe. Please see [Installation](./install.md) for how to setup Bazel
@@ -445,50 +485,23 @@ Note: This currently works only on Linux, and please first follow
$ brew install protobuf
```

3. Activate a Python virtual environment.

```bash
$ python3 -m venv mp_env && source mp_env/bin/activate
```

4. In the virtual environment, go to the MediaPipe repo directory.

5. Install the required Python packages.

```bash
(mp_env)mediapipe$ pip3 install -r requirements.txt
```

6. Generate and install MediaPipe package.

```bash
(mp_env)mediapipe$ python3 setup.py gen_protos
(mp_env)mediapipe$ python3 setup.py install
(mp_env)mediapipe$ python3 setup.py install --link-opencv
```
### Run in Python interpreter
Make sure you are not in the MediaPipe repo directory.
Using [MediaPipe Pose](../solutions/pose.md) as an example:
```bash
(mp_env)$ python3
>>> import mediapipe as mp
>>> pose_tracker = mp.examples.UpperBodyPoseTracker()
# For image input
>>> pose_landmarks, _ = pose_tracker.run(input_file='/path/to/input/file', output_file='/path/to/output/file')
>>> pose_landmarks, annotated_image = pose_tracker.run(input_file='/path/to/file')
# For live camera input
# (Press Esc within the output image window to stop the run or let it self terminate after 30 seconds.)
>>> pose_tracker.run_live()
# Close the tracker.
>>> pose_tracker.close()
```
Tip: Use command `deactivate` to exit the Python virtual environment.

(New binary image file added, 925 KiB; content not shown.)


@@ -22,27 +22,28 @@ desktop/cloud, web and IoT devices.

## ML solutions in MediaPipe

Face Detection | Face Mesh | Iris 🆕 | Hands | Pose 🆕
Face Detection | Face Mesh | Iris | Hands | Pose | Hair Segmentation
:---: | :---: | :---: | :---: | :---: | :---:
[![face_detection](images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation)

Hair Segmentation | Object Detection | Box Tracking | Objectron | KNIFT
Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT
:---: | :---: | :---: | :---: | :---:
[![object_detection](images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift)

<!-- []() in the first cell is needed to preserve table formatting in GitHub Pages. -->
<!-- Whenever this table is updated, paste a copy to solutions/solutions.md. -->

[]() | Android | iOS | Desktop | Python | Web | Coral
:--- | :-----: | :-: | :-----: | :----: | :-: | :---:
[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | ✅ | ✅
[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | | |
[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | ✅ |
[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | | ✅ |
[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ |
[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | ✅ |
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
@@ -88,6 +89,8 @@ run code search using

## Publications

* [Instant Motion Tracking With MediaPipe](https://mediapipe.page.link/instant-motion-tracking-blog)
  in Google Developers Blog
* [BlazePose - On-device Real-time Body Pose Tracking](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
  in Google AI Blog
* [MediaPipe Iris: Real-time Eye Tracking and Depth Estimation](https://ai.googleblog.com/2020/08/mediapipe-iris-real-time-iris-tracking.html)


@@ -2,7 +2,7 @@
layout: default
title: AutoFlip (Saliency-aware Video Cropping)
parent: Solutions
nav_order: 11
nav_order: 12
---

# AutoFlip: Saliency-aware Video Cropping


@@ -0,0 +1,122 @@
---
layout: default
title: Instant Motion Tracking
parent: Solutions
nav_order: 9
---
# MediaPipe Instant Motion Tracking
{: .no_toc }
1. TOC
{:toc}
---
## Overview
Augmented Reality (AR) technology creates fun, engaging, and immersive user
experiences. The ability to perform AR tracking across devices and platforms,
without initialization, remains important to power AR applications at scale.
MediaPipe Instant Motion Tracking provides AR tracking across devices and
platforms without initialization or calibration. It is built upon the
[MediaPipe Box Tracking](./box_tracking.md) solution. With Instant Motion
Tracking, you can easily place virtual 2D and 3D content on static or moving
surfaces, allowing them to seamlessly interact with the real-world environment.
![instant_motion_tracking_android_small](../images/mobile/instant_motion_tracking_android_small.gif) |
:-----------------------------------------------------------------------: |
*Fig 1. Instant Motion Tracking is used to augment the world with a 3D sticker.* |
## Pipeline
The Instant Motion Tracking pipeline is implemented as a MediaPipe
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/instant_motion_tracking.pbtxt),
which internally utilizes a
[RegionTrackingSubgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/subgraphs/region_tracking.pbtxt)
in order to perform anchor tracking for each individual 3D sticker.
We first use a
[StickerManagerCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/calculators/sticker_manager_calculator.cc)
to prepare the individual sticker data for the rest of the application. This
information is then sent to the
[RegionTrackingSubgraph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/subgraphs/region_tracking.pbtxt)
that performs 3D region tracking for sticker placement and rendering. Once
acquired, our tracked sticker regions are sent with user transformations (i.e.
gestures from the user to rotate and zoom the sticker) and IMU data to the
[MatricesManagerCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc),
which turns all our sticker transformation data into a set of model matrices.
This data is handled directly by our
[GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc)
as an input stream, which will render the provided texture and object file using
our matrix specifications. The output of
[GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc)
is a video stream depicting the virtual 3D content rendered on top of the real
world, creating immersive AR experiences for users.
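To make the flow above concrete, here is a rough skeleton of such a graph in MediaPipe's pbtxt format. This is an illustrative sketch only: the calculator and subgraph names come from the links above, but the stream names and tags are assumptions, not copies of the actual instant_motion_tracking.pbtxt.

```
# Illustrative skeleton of the pipeline described above (stream names and
# tags are assumed; see instant_motion_tracking.pbtxt for the real graph).
node {
  calculator: "StickerManagerCalculator"
  input_stream: "PROTO:sticker_proto_string"
  output_stream: "ANCHORS:initial_anchors"
  output_stream: "USER_ROTATIONS:user_rotations"
  output_stream: "USER_SCALINGS:user_scalings"
}
node {
  calculator: "RegionTrackingSubgraph"
  input_stream: "VIDEO:input_video"
  input_stream: "SENTINEL:sticker_sentinel"
  input_stream: "ANCHORS:initial_anchors"
  output_stream: "ANCHORS:tracked_anchors"
}
node {
  calculator: "MatricesManagerCalculator"
  input_stream: "ANCHORS:tracked_anchors"
  input_stream: "IMU_ROTATION:imu_rotation_matrix"
  input_stream: "USER_ROTATIONS:user_rotations"
  input_stream: "USER_SCALINGS:user_scalings"
  output_stream: "MATRICES:model_matrices"
}
node {
  calculator: "GlAnimationOverlayCalculator"
  input_stream: "VIDEO:input_video"
  input_stream: "MODEL_MATRICES:model_matrices"
  input_side_packet: "ANIMATION_ASSET:obj_asset_name"
  output_stream: "output_video"
}
```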
## Using Instant Motion Tracking
With the Instant Motion Tracking MediaPipe [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/instant_motion_tracking.pbtxt),
an application can create an interactive and realistic AR experience by
specifying the required input streams, side packets, and output streams.
The input streams are the following:
* Input Video (GpuBuffer): Video frames to render augmented stickers onto.
* Rotation Matrix (9-element Float Array): The 3x3 row-major rotation
matrix from the device IMU, used to determine the proper orientation of the
device.
* Sticker Proto String (String): A string representing the
serialized [sticker buffer protobuf message](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/calculators/sticker_buffer.proto),
containing a list of all stickers and their attributes.
* Each sticker in the protobuf message has a unique ID used to find its
associated anchors and transforms, an initial anchor placement in a
normalized [0.0, 1.0] 3D space, a user rotation and user scaling transform
on the sticker, and an integer indicating which type of object to render
for the sticker (e.g. 3D asset or GIF). An illustrative sketch of such a
message appears after this list.
* Sticker Sentinel (Integer): When an anchor must be initially placed or
repositioned, this value must be changed to the ID of the anchor to reset from
the sticker buffer protobuf message. If no valid ID is provided, the system
will simply maintain tracking.
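For illustration, a single sticker entry in the serialized sticker buffer message could look roughly like the following in protobuf text format. The field names here are assumptions chosen for readability; the authoritative schema is the sticker_buffer.proto file linked above.

```
# Hypothetical sticker entry (field names assumed; see sticker_buffer.proto).
sticker {
  id: 1            # unique ID used to match anchors and transforms
  x: 0.5           # initial anchor placement in normalized [0.0, 1.0] space
  y: 0.5
  rotation: 0.0    # user rotation applied to the sticker, in radians
  scaling: 1.0     # user scaling applied to the sticker
  render_id: 0     # which type of object to render (e.g. 3D asset or GIF)
}
```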
Side packets are also an integral part of the Instant Motion Tracking solution
to provide device-specific information for the rendering system:
* Field of View (Float): The field of view of the camera in radians.
* Aspect Ratio (Float): The aspect ratio (width / height) of the camera frames
(this ratio corresponds to the image frames themselves, not necessarily the
screen bounds).
* Object Asset (String): The
[GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc)
must be provided with an associated asset file name pointing to the 3D model
to render in the viewfinder.
* (Optional) Texture (ImageFrame on Android, GpuBuffer on iOS): Textures for
the
[GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc)
can be provided either via an input stream (dynamic texturing) or as a side
packet (unchanging texture).
The rendering system for Instant Motion Tracking is powered by OpenGL. For
more information regarding the structure of model matrices and OpenGL
rendering, please visit the [OpenGL Wiki](https://www.khronos.org/opengl/wiki/).
With the specifications above, Instant Motion Tracking can be adapted to any
device that can run the MediaPipe framework and has a working IMU and a
connected camera.
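Putting the streams and side packets described above together, the top level of such a graph could be declared roughly as follows. Again, this is a sketch with assumed names rather than an excerpt from the actual graph file.

```
# Illustrative top-level I/O for the graph described above (names assumed).
input_stream: "input_video"            # GpuBuffer video frames
input_stream: "sticker_sentinel"       # integer: anchor ID to (re)place, or invalid to keep tracking
input_stream: "sticker_proto_string"   # serialized sticker buffer message
input_stream: "imu_rotation_matrix"    # 9 floats: 3x3 row-major rotation matrix
input_side_packet: "vertical_fov_radians"  # camera field of view, in radians
input_side_packet: "aspect_ratio"          # frame width / height
input_side_packet: "obj_asset_name"        # 3D model asset for GlAnimationOverlayCalculator
output_stream: "output_video"          # frames with the rendered virtual content
```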
## Example Apps
Please first see general instructions for
[Android](../getting_started/building_examples.md#android) on how to build
MediaPipe examples.
* Graph: [mediapipe/graphs/instant_motion_tracking/instant_motion_tracking.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/instant_motion_tracking/instant_motion_tracking.pbtxt)
* Android target (or download prebuilt [ARM64 APK](https://drive.google.com/file/d/1KnaBBoKpCHR73nOBJ4fL_YdWVTAcwe6L/view?usp=sharing)):
[`mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking:instantmotiontracking`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/BUILD)
## Resources
* Google Developers Blog:
[Instant Motion Tracking With MediaPipe](https://mediapipe.page.link/instant-motion-tracking-blog)
* Google AI Blog:
[The Instant Motion Tracking Behind Motion Stills AR](https://ai.googleblog.com/2018/02/the-instant-motion-tracking-behind.html)
* Paper:
[Instant Motion Tracking and Its Applications to Augmented Reality](https://arxiv.org/abs/1907.06796)


@@ -55,7 +55,7 @@ that uses a
from the
[face landmark module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark),
an
[iris landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/iris_tracking/iris_landmark_left_and_right_gpu.pbtxt)
[iris landmark subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/iris_landmark/iris_landmark_left_and_right_gpu.pbtxt)
from the
[iris landmark module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/iris_landmark),
and renders using a dedicated
@@ -72,6 +72,11 @@ Note: To visualize a graph, copy the graph and paste it into
to visualize its associated subgraphs, please see
[visualizer documentation](../tools/visualizer.md).
The output of the pipeline is a set of 478 3D landmarks: the 468 face
landmarks from [MediaPipe Face Mesh](./face_mesh.md), with those around the
eyes further refined (see Fig 2), plus 10 additional iris landmarks appended
at the end (5 for each eye; see also Fig 2).
## Models

### Face Detection Model


@@ -2,7 +2,7 @@
layout: default
title: KNIFT (Template-based Feature Matching)
parent: Solutions
nav_order: 10
nav_order: 11
---

# MediaPipe KNIFT


@@ -2,7 +2,7 @@
layout: default
title: Dataset Preparation with MediaSequence
parent: Solutions
nav_order: 12
nav_order: 13
---

# Dataset Preparation with MediaSequence


@@ -2,7 +2,7 @@
layout: default
title: Objectron (3D Object Detection)
parent: Solutions
nav_order: 9
nav_order: 10
---

# MediaPipe Objectron

@@ -161,7 +161,7 @@ to visualize its associated subgraphs, please see

### Objectron for Shoes

* Graph:
[`mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/shoe_classic_occlusion_tracking.pbtxt)
[`mediapipe/graphs/object_detection_3d/shoe_classic_occlusion_tracking.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/shoe_classic_occlusion_tracking.pbtxt)
* Android target:
[(or download prebuilt ARM64 APK)](https://drive.google.com/open?id=1S0K4hbWt3o31FfQ4QU3Rz7IHrvOUMx1d)
[`mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d:objectdetection3d`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD)


@@ -142,10 +142,21 @@ MediaPipe examples.

### Python

Please first see general instructions for
[Python](../getting_started/building_examples.md#python) examples.
MediaPipe Python package is available on
[PyPI](https://pypi.org/project/mediapipe/), and can be installed simply by `pip
install mediapipe` on Linux and macOS, as described below and in this
[colab](https://mediapipe.page.link/mp-py-colab). If you do need to build the
Python package from source, see
[additional instructions](../getting_started/building_examples.md#python).

```bash
# Activate a Python virtual environment.
$ python3 -m venv mp_env && source mp_env/bin/activate

# Install MediaPipe Python package
(mp_env)$ pip install mediapipe

# Run in Python interpreter
(mp_env)$ python3
>>> import mediapipe as mp
>>> pose_tracker = mp.examples.UpperBodyPoseTracker()
@@ -153,6 +164,9 @@ Please first see general instructions for

# For image input
>>> pose_landmarks, _ = pose_tracker.run(input_file='/path/to/input/file', output_file='/path/to/output/file')
>>> pose_landmarks, annotated_image = pose_tracker.run(input_file='/path/to/file')

# To print out the pose landmarks, you can simply do "print(pose_landmarks)".
# However, the data points can be more accessible with the following approach.
>>> [print('x is', data_point.x, 'y is', data_point.y, 'z is', data_point.z, 'visibility is', data_point.visibility) for data_point in pose_landmarks.landmark]

# For live camera input
# (Press Esc within the output image window to stop the run or let it self terminate after 30 seconds.)
@@ -162,6 +176,8 @@ Please first see general instructions for
>>> pose_tracker.close()
```

Tip: Use command `deactivate` to exit the Python virtual environment.

### Web

Please refer to [these instructions](../index.md#mediapipe-on-the-web).


@@ -17,15 +17,16 @@ has_toc: false

<!-- Whenever this table is updated, paste a copy to ../external_index.md. -->

[]() | Android | iOS | Desktop | Python | Web | Coral
:--- | :-----: | :-: | :-----: | :----: | :-: | :---:
[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | ✅ | ✅
[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | | |
[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | ✅ |
[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | | ✅ |
[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ |
[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | ✅ |
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |


@@ -2,7 +2,7 @@
layout: default
title: YouTube-8M Feature Extraction and Model Inference
parent: Solutions
nav_order: 13
nav_order: 14
---

# YouTube-8M Feature Extraction and Model Inference


@@ -144,10 +144,13 @@ we record ten intervals of half a second each. This can be overridden by adding

```bash
profiler_config {
trace_enabled: true
trace_log_path: "/sdcard/profiles"
trace_log_path: "/sdcard/profiles/"
}
```

Note: The forward slash at the end of the `trace_log_path` is necessary for
indicating that `profiles` is a directory (that *should* exist).

* Download the trace files from the device.

```bash


@@ -12,7 +12,3 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import mediapipe.examples.python as examples
from mediapipe.python import *
import mediapipe.util as util


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])  # Apache 2.0
licenses(["notice"])
package(default_visibility = ["//visibility:private"])


@@ -13,7 +13,7 @@
# limitations under the License.
#
licenses(["notice"])  # Apache 2.0
licenses(["notice"])
filegroup(
name = "test_audios",


@@ -15,7 +15,7 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"])  # Apache 2.0
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
@@ -290,7 +290,9 @@ cc_library(
deps = [
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/lite:framework",
@@ -1119,6 +1121,7 @@ cc_library(
":constant_side_packet_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],


@@ -16,7 +16,9 @@
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/integral_types.h"
#include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
@@ -45,6 +47,9 @@ REGISTER_CALCULATOR(ConcatenateFloatVectorCalculator);
typedef ConcatenateVectorCalculator<int32> ConcatenateInt32VectorCalculator;
REGISTER_CALCULATOR(ConcatenateInt32VectorCalculator);
typedef ConcatenateVectorCalculator<uint64> ConcatenateUInt64VectorCalculator;
REGISTER_CALCULATOR(ConcatenateUInt64VectorCalculator);
// Example config:
// node {
//   calculator: "ConcatenateTfLiteTensorVectorCalculator"
@@ -60,6 +65,14 @@ typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark>
ConcatenateLandmarkVectorCalculator;
REGISTER_CALCULATOR(ConcatenateLandmarkVectorCalculator);
typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList>
ConcatenateLandmarListVectorCalculator;
REGISTER_CALCULATOR(ConcatenateLandmarListVectorCalculator);
typedef ConcatenateVectorCalculator<mediapipe::ClassificationList>
ConcatenateClassificationListVectorCalculator;
REGISTER_CALCULATOR(ConcatenateClassificationListVectorCalculator);
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
typedef ConcatenateVectorCalculator<::tflite::gpu::gl::GlBuffer>
ConcatenateGlBufferVectorCalculator;


@@ -15,6 +15,7 @@
#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
#include <string>
#include <type_traits>
#include <vector>
@@ -26,10 +27,10 @@
namespace mediapipe {
// Concatenates several std::vector<T> following stream index order. This class
// assumes that every input stream contains the vector<T> type. To use this
// class for a particular type T, regisiter a calculator using
// ConcatenateVectorCalculator<T>.
// Concatenates several objects of type T or std::vector<T> following stream
// index order. This class assumes that every input stream contains either T or
// vector<T> type. To use this class for a particular type T, regisiter a
// calculator using ConcatenateVectorCalculator<T>.
template <typename T>
class ConcatenateVectorCalculator : public CalculatorBase {
public:
@@ -38,7 +39,8 @@ class ConcatenateVectorCalculator : public CalculatorBase {
RET_CHECK(cc->Outputs().NumEntries() == 1);
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).Set<std::vector<T>>();
// Actual type T or vector<T> will be validated in Process().
cc->Inputs().Index(i).SetAny();
}
cc->Outputs().Index(0).Set<std::vector<T>>();
@@ -69,9 +71,19 @@ class ConcatenateVectorCalculator : public CalculatorBase {
CalculatorContext* cc) {
auto output = absl::make_unique<std::vector<U>>();
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
if (cc->Inputs().Index(i).IsEmpty()) continue;
const std::vector<U>& input = cc->Inputs().Index(i).Get<std::vector<U>>();
output->insert(output->end(), input.begin(), input.end());
auto& input = cc->Inputs().Index(i);
if (input.IsEmpty()) continue;
if (input.Value().ValidateAsType<U>().ok()) {
const U& value = input.Get<U>();
output->push_back(value);
} else if (input.Value().ValidateAsType<std::vector<U>>().ok()) {
const std::vector<U>& value = input.Get<std::vector<U>>();
output->insert(output->end(), value.begin(), value.end());
} else {
return ::mediapipe::InvalidArgumentError("Invalid input stream type.");
}
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
@@ -88,17 +100,32 @@ class ConcatenateVectorCalculator : public CalculatorBase {
CalculatorContext* cc) {
auto output = absl::make_unique<std::vector<U>>();
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
if (cc->Inputs().Index(i).IsEmpty()) continue;
::mediapipe::StatusOr<std::unique_ptr<std::vector<U>>> input_status =
cc->Inputs().Index(i).Value().Consume<std::vector<U>>();
if (input_status.ok()) {
std::unique_ptr<std::vector<U>> input_vector =
std::move(input_status).ValueOrDie();
output->insert(output->end(),
std::make_move_iterator(input_vector->begin()),
std::make_move_iterator(input_vector->end()));
} else {
return input_status.status();
}
auto& input = cc->Inputs().Index(i);
if (input.IsEmpty()) continue;
if (input.Value().ValidateAsType<U>().ok()) {
::mediapipe::StatusOr<std::unique_ptr<U>> value_status =
input.Value().Consume<U>();
if (value_status.ok()) {
std::unique_ptr<U> value = std::move(value_status).ValueOrDie();
output->push_back(std::move(*value));
} else {
return value_status.status();
}
} else if (input.Value().ValidateAsType<std::vector<U>>().ok()) {
::mediapipe::StatusOr<std::unique_ptr<std::vector<U>>> value_status =
input.Value().Consume<std::vector<U>>();
if (value_status.ok()) {
std::unique_ptr<std::vector<U>> value =
std::move(value_status).ValueOrDie();
output->insert(output->end(), std::make_move_iterator(value->begin()),
std::make_move_iterator(value->end()));
} else {
return value_status.status();
}
} else {
return ::mediapipe::InvalidArgumentError("Invalid input stream type.");
}
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
@@ -109,7 +136,7 @@ class ConcatenateVectorCalculator : public CalculatorBase {
::mediapipe::Status ConsumeAndConcatenateVectors(std::false_type,
CalculatorContext* cc) {
return ::mediapipe::InternalError(
"Cannot copy or move input vectors to concatenate them");
"Cannot copy or move inputs to concatenate them");
}
private:
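As the revised comment and Process() logic above show, a concatenate calculator registered for a type T now accepts a mix of single T packets and std::vector&lt;T&gt; packets across its input streams. A hypothetical graph node using the built-in ConcatenateInt32VectorCalculator could therefore look like the sketch below; the stream names are made up for illustration only.

```
# Illustrative node config (stream names hypothetical).
node {
  calculator: "ConcatenateInt32VectorCalculator"
  input_stream: "single_int_value"    # packets carrying a single int32
  input_stream: "int_vector_values"   # packets carrying std::vector<int32>
  output_stream: "concatenated_ints"  # one concatenated vector per timestamp
}
```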


@@ -30,11 +30,29 @@ namespace mediapipe {
typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
REGISTER_CALCULATOR(TestConcatenateIntVectorCalculator);
void AddInputVector(int index, const std::vector<int>& input, int64 timestamp,
CalculatorRunner* runner) {
runner->MutableInputs()->Index(index).packets.push_back(
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
}
void AddInputVectors(const std::vector<std::vector<int>>& inputs,
int64 timestamp, CalculatorRunner* runner) {
for (int i = 0; i < inputs.size(); ++i) {
runner->MutableInputs()->Index(i).packets.push_back(
MakePacket<std::vector<int>>(inputs[i]).At(Timestamp(timestamp)));
AddInputVector(i, inputs[i], timestamp, runner);
}
}
void AddInputItem(int index, int input, int64 timestamp,
CalculatorRunner* runner) {
runner->MutableInputs()->Index(index).packets.push_back(
MakePacket<int>(input).At(Timestamp(timestamp)));
}
void AddInputItems(const std::vector<int>& inputs, int64 timestamp,
CalculatorRunner* runner) {
for (int i = 0; i < inputs.size(); ++i) {
AddInputItem(i, inputs[i], timestamp, runner);
} }
}
@@ -131,6 +149,135 @@ TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamNoOutput) {
EXPECT_EQ(0, outputs.size());
}
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneTimestamp) {
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
/*options_string=*/"", /*num_inputs=*/3,
/*num_outputs=*/1, /*num_side_packets=*/0);
std::vector<int> inputs = {1, 2, 3};
AddInputItems(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
std::vector<int> expected_vector = {1, 2, 3};
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
TEST(TestConcatenateIntVectorCalculatorTest, ItemsTwoInputsAtTwoTimestamps) {
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
/*options_string=*/"", /*num_inputs=*/3,
/*num_outputs=*/1, /*num_side_packets=*/0);
{
std::vector<int> inputs = {1, 2, 3};
AddInputItems(inputs, /*timestamp=*/1, &runner);
}
{
std::vector<int> inputs = {4, 5, 6};
AddInputItems(inputs, /*timestamp=*/2, &runner);
}
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(2, outputs.size());
{
EXPECT_EQ(3, outputs[0].Get<std::vector<int>>().size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
std::vector<int> expected_vector = {1, 2, 3};
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
{
EXPECT_EQ(3, outputs[1].Get<std::vector<int>>().size());
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
std::vector<int> expected_vector = {4, 5, 6};
EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<int>>());
}
}
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneEmptyStreamStillOutput) {
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
/*options_string=*/"", /*num_inputs=*/3,
/*num_outputs=*/1, /*num_side_packets=*/0);
// No third input item.
std::vector<int> inputs = {1, 2};
AddInputItems(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
std::vector<int> expected_vector = {1, 2};
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneEmptyStreamNoOutput) {
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
/*options_string=*/
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
"{only_emit_if_all_present: true}",
/*num_inputs=*/3,
/*num_outputs=*/1, /*num_side_packets=*/0);
// No third input item.
std::vector<int> inputs = {1, 2};
AddInputItems(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(0, outputs.size());
}
TEST(TestConcatenateIntVectorCalculatorTest, MixedVectorsAndItems) {
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
/*options_string=*/"", /*num_inputs=*/4,
/*num_outputs=*/1, /*num_side_packets=*/0);
std::vector<int> vector_0 = {1, 2};
std::vector<int> vector_1 = {3, 4, 5};
int item_0 = 6;
int item_1 = 7;
AddInputVector(/*index*/ 0, vector_0, /*timestamp=*/1, &runner);
AddInputVector(/*index*/ 1, vector_1, /*timestamp=*/1, &runner);
AddInputItem(/*index*/ 2, item_0, /*timestamp=*/1, &runner);
AddInputItem(/*index*/ 3, item_1, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6, 7};
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
TEST(TestConcatenateIntVectorCalculatorTest, MixedVectorsAndItemsAnother) {
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
/*options_string=*/"", /*num_inputs=*/4,
/*num_outputs=*/1, /*num_side_packets=*/0);
int item_0 = 1;
std::vector<int> vector_0 = {2, 3};
std::vector<int> vector_1 = {4, 5, 6};
int item_1 = 7;
AddInputItem(/*index*/ 0, item_0, /*timestamp=*/1, &runner);
AddInputVector(/*index*/ 1, vector_0, /*timestamp=*/1, &runner);
AddInputVector(/*index*/ 2, vector_1, /*timestamp=*/1, &runner);
AddInputItem(/*index*/ 3, item_1, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6, 7};
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
void AddInputVectors(const std::vector<std::vector<float>>& inputs,
                     int64 timestamp, CalculatorRunner* runner) {
  for (int i = 0; i < inputs.size(); ++i) {

View File

@ -18,6 +18,7 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
@ -71,6 +72,8 @@ class ConstantSidePacketCalculator : public CalculatorBase {
      packet.Set<bool>();
    } else if (packet_options.has_string_value()) {
      packet.Set<std::string>();
    } else if (packet_options.has_uint64_value()) {
      packet.Set<uint64>();
    } else {
      return ::mediapipe::InvalidArgumentError(
          "None of supported values were specified in options.");
@ -95,6 +98,8 @@ class ConstantSidePacketCalculator : public CalculatorBase {
      packet.Set(MakePacket<bool>(packet_options.bool_value()));
    } else if (packet_options.has_string_value()) {
      packet.Set(MakePacket<std::string>(packet_options.string_value()));
    } else if (packet_options.has_uint64_value()) {
      packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
    } else {
      return ::mediapipe::InvalidArgumentError(
          "None of supported values were specified in options.");

View File

@ -29,6 +29,7 @@ message ConstantSidePacketCalculatorOptions {
    float float_value = 2;
    bool bool_value = 3;
    string string_value = 4;
    uint64 uint64_value = 5;
  }
}
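As a usage sketch (assuming the node is configured from C++; the side packet name and the value 42 are placeholders), the new uint64_value is set like any other member of the value oneof:

#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

mediapipe::CalculatorGraphConfig::Node MakeUint64SidePacketNode() {
  // Emits a single constant uint64 side packet when the graph starts.
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
      R"pb(
        calculator: "ConstantSidePacketCalculator"
        output_side_packet: "PACKET:some_uint64_packet"
        options: {
          [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
            packet { uint64_value: 42 }
          }
        }
      )pb");
}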

View File

@ -14,7 +14,7 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

licenses(["notice"])

package(default_visibility = ["//visibility:private"])

View File

@ -13,7 +13,7 @@
# limitations under the License.
#

licenses(["notice"])

filegroup(
    name = "test_images",

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

View File

@ -15,7 +15,7 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

licenses(["notice"])

package(default_visibility = ["//visibility:private"])
@ -427,6 +427,10 @@ cc_library(
    deps = [
        ":tensorflow_session",
        ":tensorflow_inference_calculator_cc_proto",
        "//mediapipe/framework:timestamp",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/memory",
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/tool:status_util",
        "@com_google_absl//absl/strings",
@ -434,6 +438,8 @@ cc_library(
        "//mediapipe/framework/deps:clock",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:map_util",
        "//mediapipe/framework:packet",
    ] + select({
        "//conditions:default": [
            "@org_tensorflow//tensorflow/core:framework",

View File

@ -93,7 +93,7 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
  cc->Inputs().Index(0).Set<tf::Tensor>(
      // tensorflow::Tensor stream.
  );
  RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
      << "Only one output stream is supported.";
  if (cc->InputSidePackets().HasTag(kBufferSize)) {

View File

@ -19,16 +19,22 @@
#include <unordered_set>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/monotonic_clock.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@ -77,6 +83,17 @@ class SimpleSemaphore {
  absl::Mutex mutex_;
  absl::CondVar cond_;
};

class InferenceState {
 public:
  InferenceState() : input_tensor_batches_(), batch_timestamps_() {}
  // A mapping between stream tags and the tensors we are collecting as a
  // batch.
  std::map<std::string, std::vector<tf::Tensor>> input_tensor_batches_;
  // The timestamps that go into a batch.
  std::vector<Timestamp> batch_timestamps_;
};

}  // namespace
// This calculator performs inference on a trained TensorFlow model.
@ -218,11 +235,16 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
  }

  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    const auto& options = cc->Options<TensorFlowInferenceCalculatorOptions>();
    RET_CHECK(!cc->Inputs().GetTags().empty());
    for (const std::string& tag : cc->Inputs().GetTags()) {
      // The tensorflow::Tensor with the tag equal to the graph node. May
      // have a TimeSeriesHeader if all present TimeSeriesHeaders match.
      if (!options.batched_input()) {
        cc->Inputs().Tag(tag).Set<tf::Tensor>();
      } else {
        cc->Inputs().Tag(tag).Set<std::vector<mediapipe::Packet>>();
      }
    }
    RET_CHECK(!cc->Outputs().GetTags().empty());
    for (const std::string& tag : cc->Outputs().GetTags()) {
@ -242,6 +264,22 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
    return ::mediapipe::OkStatus();
  }
std::unique_ptr<InferenceState> CreateInferenceState(CalculatorContext* cc)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
std::unique_ptr<InferenceState> inference_state =
absl::make_unique<InferenceState>();
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
std::map<std::string, tf::Tensor>* init_tensor_map;
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
for (const auto& p : *init_tensor_map) {
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
}
}
return inference_state;
}
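For reference, a minimal sketch of how the RECURRENT_INIT_TENSORS side packet consumed above could be supplied in a CalculatorRunner-based test (the tag key "STATE" and the tensor shape are placeholders, not values required by the calculator):

#include <map>
#include <memory>
#include <string>

#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_runner.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tf = ::tensorflow;

void AddRecurrentInitTensors(mediapipe::CalculatorRunner* runner) {
  // The calculator reads the side packet as a unique_ptr-owned
  // std::map<std::string, tf::Tensor>, so adopt the map into a packet.
  auto init_tensors = absl::make_unique<std::map<std::string, tf::Tensor>>();
  // Hypothetical recurrent feed tag "STATE"; tensor contents are left unset.
  (*init_tensors)["STATE"] =
      tf::Tensor(tf::DT_FLOAT, tf::TensorShape({1, 128}));
  runner->MutableSidePackets()->Tag("RECURRENT_INIT_TENSORS") =
      mediapipe::AdoptAsUniquePtr(init_tensors.release());
}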
  ::mediapipe::Status Open(CalculatorContext* cc) override {
    options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
@ -275,15 +313,6 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
      recurrent_feed_tags_.insert(tags[0]);
      recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
    }
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
std::map<std::string, tf::Tensor>* init_tensor_map;
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
for (const auto& p : *init_tensor_map) {
input_tensor_batches_[p.first].emplace_back(p.second);
}
}
    // Check that all tags are present in this signature bound to tensors.
    for (const std::string& tag : cc->Inputs().GetTags()) {
@ -297,9 +326,15 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
          << options_.signature_name();
    }

    {
      absl::WriterMutexLock l(&mutex_);
      inference_state_ = std::unique_ptr<InferenceState>();
    }

    if (options_.batch_size() == 1 || options_.batched_input()) {
      cc->SetOffset(0);
    }

    return ::mediapipe::OkStatus();
  }
@ -316,6 +351,24 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
    return ::mediapipe::OkStatus();
  }
::mediapipe::Status AggregateTensorPacket(
const std::string& tag_name, const Packet& packet,
std::map<Timestamp, std::map<std::string, tf::Tensor>>*
input_tensors_by_tag_by_timestamp,
InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
tf::Tensor input_tensor(packet.Get<tf::Tensor>());
RET_CHECK_OK(AddBatchDimension(&input_tensor));
if (::mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) {
// If we receive an input on a recurrent tag, override the state.
// It's OK to override the global state because there is just one
// input stream allowed for recurrent tensors.
inference_state_->input_tensor_batches_[tag_name].clear();
}
(*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert(
std::make_pair(tag_name, input_tensor));
return ::mediapipe::OkStatus();
}
  // Removes the batch dimension of the output tensor if specified in the
  // calculator options.
  ::mediapipe::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
@ -331,11 +384,19 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    std::unique_ptr<InferenceState> inference_state_to_process;
    {
      absl::WriterMutexLock l(&mutex_);
      if (inference_state_ == nullptr) {
        inference_state_ = CreateInferenceState(cc);
      }
      std::map<Timestamp, std::map<std::string, tf::Tensor>>
          input_tensors_by_tag_by_timestamp;
      for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) {
        if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) {
          // Recurrent tensors can be empty.
          if (!::mediapipe::ContainsKey(recurrent_feed_tags_,
                                        tag_as_node_name)) {
            if (options_.skip_on_missing_features()) {
              return ::mediapipe::OkStatus();
            } else {
@ -344,35 +405,64 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
                  " not present at timestamp: ", cc->InputTimestamp().Value()));
            }
          }
        } else if (options_.batched_input()) {
          const auto& tensor_packets =
              cc->Inputs().Tag(tag_as_node_name).Get<std::vector<Packet>>();
          if (tensor_packets.size() > options_.batch_size()) {
            return ::mediapipe::InvalidArgumentError(absl::StrCat(
                "Batch for tag ", tag_as_node_name,
                " has more packets than batch capacity. batch_size: ",
                options_.batch_size(), " packets: ", tensor_packets.size()));
          }
          for (const auto& packet : tensor_packets) {
            RET_CHECK_OK(AggregateTensorPacket(
                tag_as_node_name, packet, &input_tensors_by_tag_by_timestamp,
                inference_state_.get()));
          }
        } else {
          RET_CHECK_OK(AggregateTensorPacket(
              tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(),
              &input_tensors_by_tag_by_timestamp, inference_state_.get()));
        }
      }

      for (const auto& timestamp_and_input_tensors_by_tag :
           input_tensors_by_tag_by_timestamp) {
        inference_state_->batch_timestamps_.emplace_back(
            timestamp_and_input_tensors_by_tag.first);
        for (const auto& input_tensor_and_tag :
             timestamp_and_input_tensors_by_tag.second) {
          inference_state_->input_tensor_batches_[input_tensor_and_tag.first]
              .emplace_back(input_tensor_and_tag.second);
        }
      }
      if (inference_state_->batch_timestamps_.size() == options_.batch_size() ||
          options_.batched_input()) {
        inference_state_to_process = std::move(inference_state_);
        inference_state_ = std::unique_ptr<InferenceState>();
      }
    }

    if (inference_state_to_process) {
      MP_RETURN_IF_ERROR(
          OutputBatch(cc, std::move(inference_state_to_process)));
    }

    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Close(CalculatorContext* cc) override {
    std::unique_ptr<InferenceState> inference_state_to_process = nullptr;
    {
      absl::WriterMutexLock l(&mutex_);
      if (cc->GraphStatus().ok() && inference_state_ != nullptr &&
          !inference_state_->batch_timestamps_.empty()) {
        inference_state_to_process = std::move(inference_state_);
        inference_state_ = std::unique_ptr<InferenceState>();
      }
    }

    if (inference_state_to_process) {
      MP_RETURN_IF_ERROR(
          OutputBatch(cc, std::move(inference_state_to_process)));
    }

    return ::mediapipe::OkStatus();
  }
@ -385,10 +475,12 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
  // memory buffer. Therefore, copies are cheap and should not cause the memory
  // buffer to fall out of scope. In contrast, concat is only used where
  // necessary.
  ::mediapipe::Status OutputBatch(
      CalculatorContext* cc, std::unique_ptr<InferenceState> inference_state) {
    const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
    std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;

    for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
      if (options_.batch_size() == 1) {
        // Short circuit to avoid the cost of deep copying tensors in concat.
        if (!keyed_tensors.second.empty()) {
@ -404,7 +496,8 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
      } else {
        // Pad by replicating the first tensor, then ignore the values.
        keyed_tensors.second.resize(options_.batch_size());
        std::fill(keyed_tensors.second.begin() +
                      inference_state->batch_timestamps_.size(),
                  keyed_tensors.second.end(), keyed_tensors.second[0]);

        tf::Tensor concated;
        const tf::Status concat_status =
@ -414,7 +507,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
                            concated);
      }
    }
    inference_state->input_tensor_batches_.clear();

    std::vector<mediapipe::ProtoString> output_tensor_names;
    std::vector<std::string> output_name_in_signature;
    for (const std::string& tag : cc->Outputs().GetTags()) {
@ -466,9 +559,11 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
      int pos = std::find(output_name_in_signature.begin(),
                          output_name_in_signature.end(), tag_pair.first) -
                output_name_in_signature.begin();
      inference_state->input_tensor_batches_[tag_pair.second].emplace_back(
          outputs[pos]);
    }

    absl::WriterMutexLock l(&mutex_);
    // Set that we want to split on each index of the 0th dimension.
    std::vector<tf::int64> split_vector(options_.batch_size(), 1);
    for (int i = 0; i < output_tensor_names.size(); ++i) {
@ -478,7 +573,8 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
        RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
        cc->Outputs()
            .Tag(output_name_in_signature[i])
            .Add(new tf::Tensor(output_tensor),
                 inference_state->batch_timestamps_[0]);
      }
    } else {
      std::vector<tf::Tensor> split_tensors;
@ -486,22 +582,30 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
            tf::tensor::Split(outputs[i], split_vector, &split_tensors);
        CHECK(split_status.ok()) << split_status.ToString();
        // Loop over timestamps so that we don't copy the padding.
        for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
          tf::Tensor output_tensor(split_tensors[j]);
          RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
          cc->Outputs()
              .Tag(output_name_in_signature[i])
              .Add(new tf::Tensor(output_tensor),
                   inference_state->batch_timestamps_[j]);
        }
      }
    }

    // Get end time and report.
    const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
    cc->GetCounter(kTotalUsecsCounterSuffix)
        ->IncrementBy(end_time - start_time);
    cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
        ->IncrementBy(inference_state->batch_timestamps_.size());

    // Make sure we hold on to the recursive state.
    if (!options_.recurrent_tag_pair().empty()) {
      inference_state_ = std::move(inference_state);
      inference_state_->batch_timestamps_.clear();
    }

    return ::mediapipe::OkStatus();
  }
@ -514,11 +618,8 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
  // A mapping between stream tags and the tensor names they are bound to.
  std::map<std::string, std::string> tag_to_tensor_map_;

  absl::Mutex mutex_;
  std::unique_ptr<InferenceState> inference_state_ ABSL_GUARDED_BY(mutex_);

  // The options for the calculator.
  TensorFlowInferenceCalculatorOptions options_;

View File

@ -76,4 +76,13 @@ message TensorFlowInferenceCalculatorOptions {
  // only works in the local process, not "globally" across multiple processes
  // or replicas (if any). Default to 0, i.e. no limit.
  optional int32 max_concurrent_session_runs = 6 [default = 0];

  // If turned on, the calculator expects a vector of batched packets as input.
  // This makes it safe to turn on max_in_flight for a batch_size greater than
  // 1, which would otherwise result in non-monotonically increasing
  // timestamps.
  // Use BatchSequentialCalculator to create the batches. The batch_size
  // should agree for both calculators. All the data in a batch is processed
  // together. The BatchSequentialCalculator can't run with max_in_flight.
  optional bool batched_input = 7;
}
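A minimal sketch of a node set up for batched input, mirroring how the calculator's tests build their configs in C++ (the stream names and batch size are placeholders):

#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

CalculatorGraphConfig::Node MakeBatchedInferenceNode() {
  CalculatorGraphConfig::Node config;
  config.set_calculator("TensorFlowInferenceCalculator");
  config.add_input_stream("A:tensor_a");
  config.add_output_stream("MULTIPLIED:tensor_out");
  config.add_input_side_packet("SESSION:session");
  // batched_input is what makes max_in_flight > 1 safe with batch_size > 1.
  config.set_max_in_flight(2);

  CalculatorOptions options;
  auto* ext =
      options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext);
  ext->set_batch_size(2);        // should agree with the upstream batching step
  ext->set_batched_input(true);  // inputs arrive as std::vector<Packet>
  *config.mutable_options() = options;
  return config;
}

}  // namespace mediapipe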

View File

@ -89,17 +89,31 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test {
    output_side_packets.Tag("SESSION");
  }

  // Create a tensor Packet from a vector of values at the given timestamp.
  Packet CreateTensorPacket(const std::vector<int32>& input, int64 time) {
    tf::TensorShape tensor_shape;
    tensor_shape.AddDim(input.size());
    auto tensor = absl::make_unique<tf::Tensor>(tf::DT_INT32, tensor_shape);
    for (int i = 0; i < input.size(); ++i) {
      tensor->vec<int32>()(i) = input[i];
    }
    return Adopt(tensor.release()).At(Timestamp(time));
  }

  // Create tensor from Vector and add as a Packet to the provided tag as input.
  void AddVectorToInputsAsTensor(const std::vector<int32>& input,
                                 const std::string& tag, int64 time) {
    runner_->MutableInputs()->Tag(tag).packets.push_back(
        CreateTensorPacket(input, time));
  }

  // Add a vector of packets as a single batched input Packet to the provided
  // tag.
  void AddVectorToInputsAsPacket(const std::vector<Packet>& packets,
                                 const std::string& tag) {
    CHECK(!packets.empty())
        << "Please specify at least some data in the packet";
    auto packets_ptr = absl::make_unique<std::vector<Packet>>(packets);
    runner_->MutableInputs()->Tag(tag).packets.push_back(
        Adopt(packets_ptr.release()).At(packets.begin()->Timestamp()));
  }

  std::unique_ptr<CalculatorRunner> runner_;
@ -183,6 +197,45 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
  EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B"));
}
TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
config.set_max_in_flight(2);
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(false);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({3});
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10}, expected_shape);
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_mult);
// Add only one of the two expected tensors at the next timestamp, expect
// useful failure message.
AddVectorToInputsAsTensor({1, 2, 3}, "A", 1);
auto run_status = runner_->Run();
ASSERT_FALSE(run_status.ok());
EXPECT_THAT(run_status.ToString(),
testing::HasSubstr("TensorFlowInferenceCalculator"));
EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B"));
}
TEST_F(TensorflowInferenceCalculatorTest, BadTag) {
  CalculatorGraphConfig::Node config;
  config.set_calculator("TensorFlowInferenceCalculator");
@ -235,6 +288,86 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) {
                  ->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed_MaxInFlight) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
config.set_max_in_flight(2);
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
AddVectorToInputsAsTensor({3, 3, 3}, "A", 1);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 1);
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32>({9, 12, 15});
tf::test::ExpectTensorEqual<int32>(tensor_mult1, expected_tensor1);
EXPECT_EQ(2, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
TEST_F(TensorflowInferenceCalculatorTest,
GetMultiBatchComputed_MoreThanMaxInFlight) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
config.set_max_in_flight(2);
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(1);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsTensor({2, 2, 2}, "A", 0);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 0);
AddVectorToInputsAsTensor({3, 3, 3}, "A", 1);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 1);
AddVectorToInputsAsTensor({4, 4, 4}, "A", 2);
AddVectorToInputsAsTensor({3, 4, 5}, "B", 2);
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(3, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32>({9, 12, 15});
tf::test::ExpectTensorEqual<int32>(tensor_mult1, expected_tensor1);
const tf::Tensor& tensor_mult2 = output_packets_mult[2].Get<tf::Tensor>();
auto expected_tensor2 = tf::test::AsTensor<int32>({12, 16, 20});
tf::test::ExpectTensorEqual<int32>(tensor_mult2, expected_tensor2);
EXPECT_EQ(3, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) {
  CalculatorGraphConfig::Node config;
  config.set_calculator("TensorFlowInferenceCalculator");
@ -311,6 +444,66 @@ TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) {
                  ->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, GetBatchComputed_MaxInFlight) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
config.set_max_in_flight(2);
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(2);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batched_input(true);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsPacket(
{CreateTensorPacket({2, 2, 2}, 0), CreateTensorPacket({3, 3, 3}, 1)},
"A");
AddVectorToInputsAsPacket(
{CreateTensorPacket({3, 4, 5}, 0), CreateTensorPacket({3, 4, 5}, 1)},
"B");
AddVectorToInputsAsPacket(
{CreateTensorPacket({4, 4, 4}, 2), CreateTensorPacket({5, 5, 5}, 3)},
"A");
AddVectorToInputsAsPacket(
{CreateTensorPacket({3, 4, 5}, 2), CreateTensorPacket({3, 4, 5}, 3)},
"B");
AddVectorToInputsAsPacket({CreateTensorPacket({6, 6, 6}, 4)}, "A");
AddVectorToInputsAsPacket({CreateTensorPacket({3, 4, 5}, 4)}, "B");
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
ASSERT_EQ(5, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
tf::test::ExpectTensorEqual<int32>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32>({9, 12, 15});
tf::test::ExpectTensorEqual<int32>(tensor_mult1, expected_tensor1);
const tf::Tensor& tensor_mult2 = output_packets_mult[2].Get<tf::Tensor>();
auto expected_tensor2 = tf::test::AsTensor<int32>({12, 16, 20});
tf::test::ExpectTensorEqual<int32>(tensor_mult2, expected_tensor2);
const tf::Tensor& tensor_mult3 = output_packets_mult[3].Get<tf::Tensor>();
auto expected_tensor3 = tf::test::AsTensor<int32>({15, 20, 25});
tf::test::ExpectTensorEqual<int32>(tensor_mult3, expected_tensor3);
const tf::Tensor& tensor_mult4 = output_packets_mult[4].Get<tf::Tensor>();
auto expected_tensor4 = tf::test::AsTensor<int32>({18, 24, 30});
tf::test::ExpectTensorEqual<int32>(tensor_mult4, expected_tensor4);
EXPECT_EQ(5, runner_
->GetCounter(
"TensorFlowInferenceCalculator-TotalProcessedTimestamps")
->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) {
  CalculatorGraphConfig::Node config;
  config.set_calculator("TensorFlowInferenceCalculator");
@ -509,4 +702,40 @@ TEST_F(TensorflowInferenceCalculatorTest,
                  ->Get());
}
TEST_F(TensorflowInferenceCalculatorTest, BatchedInputTooBigBatch) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorFlowInferenceCalculator");
config.add_input_stream("A:tensor_a");
config.add_input_stream("B:tensor_b");
config.add_output_stream("MULTIPLIED:tensor_o1");
config.add_input_side_packet("SESSION:session");
config.set_max_in_flight(2);
CalculatorOptions options;
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batch_size(2);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_add_batch_dim_to_tensors(true);
options.MutableExtension(TensorFlowInferenceCalculatorOptions::ext)
->set_batched_input(true);
*config.mutable_options() = options;
runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket();
AddVectorToInputsAsPacket(
{CreateTensorPacket({2, 2, 2}, 0), CreateTensorPacket({3, 3, 3}, 1),
CreateTensorPacket({4, 4, 4}, 2)},
"A");
AddVectorToInputsAsPacket(
{CreateTensorPacket({3, 4, 5}, 0), CreateTensorPacket({3, 4, 5}, 1),
CreateTensorPacket({3, 4, 5}, 2)},
"B");
auto status = runner_->Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
::testing::HasSubstr(
"has more packets than batch capacity. batch_size: 2 packets: 3"));
}
}  // namespace mediapipe

View File

@ -29,6 +29,7 @@ namespace mediapipe {
// Streams:
const char kBBoxTag[] = "BBOX";
const char kImageTag[] = "IMAGE";
const char kKeypointsTag[] = "KEYPOINTS";
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
const char kForwardFlowImageTag[] = "FORWARD_FLOW_ENCODED";
@ -150,7 +151,6 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
          << "or" << kAudioDecoderOptions;
    }
    // Optional streams.
    if (cc->Outputs().HasTag(kForwardFlowImageTag)) {
      cc->Outputs().Tag(kForwardFlowImageTag).Set<std::string>();
    }
@ -244,6 +244,10 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
    const auto& sequence = cc->InputSidePackets()
                               .Tag(kSequenceExampleTag)
                               .Get<tensorflow::SequenceExample>();
    if (cc->Outputs().HasTag(kKeypointsTag)) {
      keypoint_names_ = absl::StrSplit(options.keypoint_names(), ',');
      default_keypoint_location_ = options.default_keypoint_location();
    }
    if (cc->OutputSidePackets().HasTag(kDataPath)) {
      std::string root_directory = "";
      if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) {
@ -357,7 +361,6 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
      end_timestamp =
          timestamps_[last_timestamp_key_][current_timestamp_index_ + 1];
    }
    for (const auto& map_kv : timestamps_) {
      for (int i = 0; i < map_kv.second.size(); ++i) {
        if (map_kv.second[i] >= start_timestamp &&
@ -454,6 +457,10 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
  int current_timestamp_index_;
  // Store the very first timestamp, so we output everything on the first frame.
  int64 first_timestamp_seen_;
  // List of keypoint names.
  std::vector<std::string> keypoint_names_;
  // Default keypoint location when missing.
  float default_keypoint_location_;
};

REGISTER_CALCULATOR(UnpackMediaSequenceCalculator);
}  // namespace mediapipe
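The keypoint_names option above is consumed as a single comma-separated string. A small stand-alone sketch of the same parsing (the option value shown is a placeholder, not a value the calculator requires):

#include <string>
#include <vector>

#include "absl/strings/str_split.h"

std::vector<std::string> ParseKeypointNames() {
  // e.g. the value of options.keypoint_names()
  const std::string keypoint_names = "NOSE,LEFT_EYE,RIGHT_EYE";
  // Mirrors the Open() logic above: one entry per named keypoint.
  return absl::StrSplit(keypoint_names, ',');
}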

View File

@ -16,7 +16,7 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("@bazel_skylib//lib:selects.bzl", "selects")

licenses(["notice"])

package(default_visibility = ["//visibility:private"])
@ -257,6 +257,7 @@ cc_library(
    }) + select({
        "//conditions:default": [],
        "//mediapipe:android": [
            "//mediapipe/util/android/file/base",
            "@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate",
        ],
    }) + select({

View File

@ -33,6 +33,12 @@
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

#if defined(MEDIAPIPE_ANDROID)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/filesystem.h"
#include "mediapipe/util/android/file/base/helpers.h"
#endif  // ANDROID

#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
@ -219,6 +225,8 @@ class TfLiteInferenceCalculator : public CalculatorBase {
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  ::mediapipe::Status ReadKernelsFromFile();
  ::mediapipe::Status WriteKernelsToFile();
  ::mediapipe::Status LoadModel(CalculatorContext* cc);
  ::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
  ::mediapipe::Status LoadDelegate(CalculatorContext* cc);
@ -273,6 +281,9 @@ class TfLiteInferenceCalculator : public CalculatorBase {
  bool use_quantized_tensors_ = false;

  bool use_advanced_gpu_api_ = false;
  bool use_kernel_caching_ = false;
  std::string cached_kernel_filename_;
};
REGISTER_CALCULATOR(TfLiteInferenceCalculator);
@ -354,6 +365,17 @@ bool ShouldUseGpu(CC* cc) {
      options.has_delegate() &&
      options.delegate().has_gpu() &&
      options.delegate().gpu().use_advanced_gpu_api();
  use_kernel_caching_ =
      use_advanced_gpu_api_ && options.delegate().gpu().use_kernel_caching();

  if (use_kernel_caching_) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID)
    cached_kernel_filename_ =
        "/sdcard/" + mediapipe::File::Basename(options.model_path()) + ".ker";
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID
  }

  if (use_advanced_gpu_api_ && !gpu_input_) {
    LOG(WARNING) << "Cannot use advanced GPU APIs, input must be GPU buffers."
                    "Falling back to the default TFLite API.";
@ -423,7 +445,23 @@ bool ShouldUseGpu(CC* cc) {
  });
}
::mediapipe::Status TfLiteInferenceCalculator::WriteKernelsToFile() {
#if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID)
if (use_kernel_caching_) {
// Save kernel file.
auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
tflite_gpu_runner_->GetSerializedBinaryCache());
std::string cache_str(kernel_cache->begin(), kernel_cache->end());
MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
}
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
  MP_RETURN_IF_ERROR(WriteKernelsToFile());
  return RunInContextIfNeeded([this]() -> ::mediapipe::Status {
    if (delegate_) {
      interpreter_ = nullptr;
@ -635,6 +673,22 @@ bool ShouldUseGpu(CC* cc) {
  return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::ReadKernelsFromFile() {
#if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID)
if (use_kernel_caching_) {
// Load pre-compiled kernel file.
if (mediapipe::File::Exists(cached_kernel_filename_)) {
std::string cache_str;
MP_RETURN_IF_ERROR(
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
}
}
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
    CalculatorContext* cc) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
@ -692,6 +746,9 @@ bool ShouldUseGpu(CC* cc) {
          ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
              gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
    }

    MP_RETURN_IF_ERROR(ReadKernelsFromFile());

    MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

View File

@ -48,6 +48,10 @@ message TfLiteInferenceCalculatorOptions {
    // example:
    //   delegate: { gpu { use_advanced_gpu_api: true } }
    optional bool use_advanced_gpu_api = 1 [default = false];

    // Load pre-compiled serialized binary cache to accelerate init process.
    // Only available for OpenCL delegate on Android.
    optional bool use_kernel_caching = 2 [default = false];
  }

  // Android only.
  message Nnapi {}
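A minimal sketch of enabling the cache from C++ (the model path is a placeholder; per the calculator code above, caching only takes effect with the advanced GPU API on Android, where the cache is written to /sdcard/<model basename>.ker):

#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"

mediapipe::TfLiteInferenceCalculatorOptions MakeCachedGpuOptions() {
  mediapipe::TfLiteInferenceCalculatorOptions options;
  options.set_model_path("mediapipe/models/example_model.tflite");  // placeholder
  auto* gpu = options.mutable_delegate()->mutable_gpu();
  gpu->set_use_advanced_gpu_api(true);
  gpu->set_use_kernel_caching(true);
  return options;
}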

View File

@ -14,7 +14,7 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

licenses(["notice"])

package(default_visibility = ["//visibility:public"])
@ -783,6 +783,7 @@ mediapipe_cc_proto_library(
cc_library(
    name = "landmarks_to_render_data_calculator",
    srcs = ["landmarks_to_render_data_calculator.cc"],
    hdrs = ["landmarks_to_render_data_calculator.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":landmarks_to_render_data_calculator_cc_proto",

View File

@ -389,8 +389,6 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
  // Upload render target to GPU.
  {
    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
    glBindTexture(GL_TEXTURE_2D, image_mat_tex_);
    glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width_canvas_, height_canvas_,
                    GL_RGB, GL_UNSIGNED_BYTE, overlay_image);

View File

@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.h"

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@ -34,8 +35,6 @@ constexpr char kRenderDataTag[] = "RENDER_DATA";
constexpr char kLandmarkLabel[] = "KEYPOINT";
constexpr int kMaxLandmarkThickness = 18;

using ::mediapipe::RenderAnnotation_Point;

inline void SetColor(RenderAnnotation* annotation, const Color& color) {
  annotation->mutable_color()->set_r(color.r());
  annotation->mutable_color()->set_g(color.g());
@ -162,45 +161,6 @@ RenderAnnotation* AddPointRenderData(const Color& landmark_color,
}  // namespace
// A calculator that converts Landmark proto to RenderData proto for
// visualization. The input should be LandmarkList proto. It is also possible
// to specify the connections between landmarks.
//
// Example config:
// node {
// calculator: "LandmarksToRenderDataCalculator"
// input_stream: "NORM_LANDMARKS:landmarks"
// output_stream: "RENDER_DATA:render_data"
// options {
// [LandmarksToRenderDataCalculatorOptions.ext] {
// landmark_connections: [0, 1, 1, 2]
// landmark_color { r: 0 g: 255 b: 0 }
// connection_color { r: 0 g: 255 b: 0 }
// thickness: 4.0
// }
// }
// }
class LandmarksToRenderDataCalculator : public CalculatorBase {
public:
LandmarksToRenderDataCalculator() {}
~LandmarksToRenderDataCalculator() override {}
LandmarksToRenderDataCalculator(const LandmarksToRenderDataCalculator&) =
delete;
LandmarksToRenderDataCalculator& operator=(
const LandmarksToRenderDataCalculator&) = delete;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
LandmarksToRenderDataCalculatorOptions options_;
std::vector<int> landmark_connections_;
};
REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
::mediapipe::Status LandmarksToRenderDataCalculator::GetContract(
    CalculatorContract* cc) {
  RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) ||
@ -354,4 +314,5 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
  return ::mediapipe::OkStatus();
}

REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
}  // namespace mediapipe

View File

@ -0,0 +1,69 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_TO_RENDER_DATA_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_TO_RENDER_DATA_CALCULATOR_H_
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
// A calculator that converts Landmark proto to RenderData proto for
// visualization. The input should be LandmarkList proto. It is also possible
// to specify the connections between landmarks.
//
// Example config:
// node {
// calculator: "LandmarksToRenderDataCalculator"
// input_stream: "NORM_LANDMARKS:landmarks"
// output_stream: "RENDER_DATA:render_data"
// options {
// [LandmarksToRenderDataCalculatorOptions.ext] {
// landmark_connections: [0, 1, 1, 2]
// landmark_color { r: 0 g: 255 b: 0 }
// connection_color { r: 0 g: 255 b: 0 }
// thickness: 4.0
// }
// }
// }
class LandmarksToRenderDataCalculator : public CalculatorBase {
public:
LandmarksToRenderDataCalculator() {}
~LandmarksToRenderDataCalculator() override {}
LandmarksToRenderDataCalculator(const LandmarksToRenderDataCalculator&) =
delete;
LandmarksToRenderDataCalculator& operator=(
const LandmarksToRenderDataCalculator&) = delete;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
protected:
::mediapipe::LandmarksToRenderDataCalculatorOptions options_;
std::vector<int> landmark_connections_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_TO_RENDER_DATA_CALCULATOR_H_

View File

@ -19,7 +19,7 @@ load(
"mediapipe_binary_graph", "mediapipe_binary_graph",
) )
licenses(["notice"]) # Apache 2.0 licenses(["notice"])
package(default_visibility = ["//visibility:private"]) package(default_visibility = ["//visibility:private"])

View File

@ -15,7 +15,7 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0 licenses(["notice"])
package(default_visibility = ["//mediapipe/calculators/video:__subpackages__"]) package(default_visibility = ["//mediapipe/calculators/video:__subpackages__"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
# Basic library common across example apps.
android_library(

View File

@ -80,7 +80,7 @@ public class MainActivity extends AppCompatActivity {
  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(getContentViewLayoutResId());
    try {
      applicationInfo =
@ -112,6 +112,12 @@ public class MainActivity extends AppCompatActivity {
    PermissionHelper.checkAndRequestCameraPermissions(this);
  }

  // Used to obtain the content view for this application. If you are extending this class, and
  // have a custom layout, override this method and return the custom layout.
  protected int getContentViewLayoutResId() {
    return R.layout.activity_main;
  }

  @Override
  protected void onResume() {
    super.onResume();
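The getContentViewLayoutResId() hook added above lets a subclass swap in its own layout without duplicating onCreate(). A minimal sketch of such a subclass (the class and layout names here are illustrative; the instantmotiontracking activity further below does exactly this with its own layout):

// Hypothetical subclass of the basic MainActivity that supplies a custom layout.
public class CustomLayoutActivity extends com.google.mediapipe.apps.basic.MainActivity {
  @Override
  protected int getContentViewLayoutResId() {
    // R.layout.custom_activity_main is an assumed resource for this sketch.
    return R.layout.custom_activity_main;
  }
}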

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -0,0 +1,99 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
java_lite_proto_library(
name = "sticker_buffer_java_proto_lite",
deps = ["//mediapipe/graphs/instant_motion_tracking/calculators:sticker_buffer_proto"],
)
android_library(
name = "instantmotiontracking_lib",
srcs = glob(["*.java"]),
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
resource_files = glob([
"res/layout/**",
"res/drawable/**",
]),
visibility = ["//visibility:public"],
deps = [
":sticker_buffer_java_proto_lite",
"//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib",
"//mediapipe/java/com/google/mediapipe/components:android_components",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//third_party:androidx_appcompat",
"//third_party:androidx_core",
"//third_party:opencv",
"@maven//:androidx_concurrent_concurrent_futures",
"@maven//:com_github_bumptech_glide_glide",
"@maven//:com_google_guava_guava",
],
)
# Include all calculators specific to this project defined by BUILD in graphs
cc_binary(
name = "libmediapipe_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/graphs/instant_motion_tracking:instant_motion_tracking_deps",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
],
)
# Converts the .so cc_binary into a cc_library, to be consumed in an android_binary.
cc_library(
name = "mediapipe_jni_lib",
srcs = [":libmediapipe_jni.so"],
alwayslink = 1,
)
genrule(
name = "asset3d",
srcs = ["//mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/assets:robot/robot.obj.uuu.zip"],
outs = ["robot/robot.obj.uuu"],
cmd = "unzip -p $< > $@",
)
android_binary(
name = "instantmotiontracking",
assets = [
":asset3d",
"//mediapipe/graphs/instant_motion_tracking:instant_motion_tracking.binarypb",
"//mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/assets:gif/gif.obj.uuu",
"//mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/assets:gif/default_gif_texture.jpg",
"//mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/assets:robot/robot_texture.jpg",
],
assets_dir = "",
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
manifest_values = {
"applicationId": "com.google.mediapipe.apps.instantmotiontracking",
"appName": "Instant Motion Tracking",
"mainActivity": ".MainActivity",
"cameraFacingFront": "False",
"binaryGraphName": "instant_motion_tracking.binarypb",
"inputVideoStreamName": "input_video",
"outputVideoStreamName": "output_video",
"flipFramesVertically": "True",
},
multidex = "native",
deps = [
":instantmotiontracking_lib",
":mediapipe_jni_lib",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
],
)

View File

@ -0,0 +1,103 @@
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.apps.instantmotiontracking;
import android.content.ClipDescription;
import android.content.Context;
import android.net.Uri;
import android.os.Bundle;
import androidx.appcompat.widget.AppCompatEditText;
import android.util.AttributeSet;
import android.util.Log;
import android.view.inputmethod.EditorInfo;
import android.view.inputmethod.InputConnection;
import androidx.core.view.inputmethod.EditorInfoCompat;
import androidx.core.view.inputmethod.InputConnectionCompat;
import androidx.core.view.inputmethod.InputContentInfoCompat;
// import android.support.v13.view.inputmethod.EditorInfoCompat;
// import android.support.v13.view.inputmethod.InputConnectionCompat;
// import android.support.v13.view.inputmethod.InputContentInfoCompat;
/**
* This custom EditText implementation uses the existing EditText framework in
* order to develop a GIFEditText input box which is capable of accepting GIF
* animations from the Android system keyboard and returning the GIF location
* via a content URI.
*/
public class GIFEditText extends AppCompatEditText {
private GIFCommitListener gifCommitListener;
public GIFEditText(Context context) {
super(context);
}
public GIFEditText(Context context, AttributeSet attrs) {
super(context, attrs);
}
/**
* onGIFCommit is called once content is pushed to the EditText via the
* Android keyboard.
*/
public interface GIFCommitListener {
void onGIFCommit(Uri contentUri, ClipDescription description);
}
/**
* Used to set the gifCommitListener for this GIFEditText.
*
* @param gifCommitListener handles response to new content pushed to EditText
*/
public void setGIFCommitListener(GIFCommitListener gifCommitListener) {
this.gifCommitListener = gifCommitListener;
}
@Override
public InputConnection onCreateInputConnection(EditorInfo editorInfo) {
final InputConnection inputConnection = super.onCreateInputConnection(editorInfo);
EditorInfoCompat.setContentMimeTypes(editorInfo, new String[] {"image/gif"});
return InputConnectionCompat.createWrapper(
inputConnection,
editorInfo,
new InputConnectionCompat.OnCommitContentListener() {
@Override
public boolean onCommitContent(
final InputContentInfoCompat inputContentInfo, int flags, Bundle opts) {
try {
if (gifCommitListener != null) {
Runnable runnable =
new Runnable() {
@Override
public void run() {
inputContentInfo.requestPermission();
gifCommitListener.onGIFCommit(
inputContentInfo.getContentUri(), inputContentInfo.getDescription());
inputContentInfo.releasePermission();
}
};
new Thread(runnable).start();
}
} catch (RuntimeException e) {
Log.e("GIFEditText", "Input connection to GIF selection failed");
e.printStackTrace();
return false;
}
return true;
}
});
}
}
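A minimal sketch of how an activity might consume this view (it assumes a GIFEditText with id gif_edit_text in the inflated layout, as in the instantmotiontracking layout further below; the log tag is illustrative):

GIFEditText gifInput = findViewById(R.id.gif_edit_text);
gifInput.setGIFCommitListener(
    new GIFEditText.GIFCommitListener() {
      @Override
      public void onGIFCommit(Uri contentUri, ClipDescription description) {
        // Invoked on a background thread started by GIFEditText, so post to
        // the UI thread before touching any views.
        Log.d("GIFEditTextDemo", "GIF committed: " + contentUri);
      }
    });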

View File

@ -0,0 +1,633 @@
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.apps.instantmotiontracking;
import static java.lang.Math.max;
import android.content.ClipDescription;
import android.content.Context;
import android.content.Intent;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.drawable.Drawable;
import android.hardware.Sensor;
import android.hardware.SensorEvent;
import android.hardware.SensorEventListener;
import android.hardware.SensorManager;
import android.net.Uri;
import android.os.Bundle;
import android.util.Log;
import android.util.Size;
import android.view.MotionEvent;
import android.view.SurfaceHolder;
import android.view.View;
import android.view.ViewGroup;
import android.view.inputmethod.InputMethodManager;
import android.widget.ImageButton;
import android.widget.ImageView;
import android.widget.LinearLayout;
import com.bumptech.glide.Glide;
import com.bumptech.glide.load.resource.gif.GifDrawable;
import com.bumptech.glide.request.target.CustomTarget;
import com.bumptech.glide.request.transition.Transition;
import com.google.mediapipe.components.FrameProcessor;
import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Packet;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* This is the MainActivity that handles camera input, IMU sensor data acquisition
* and sticker management for the InstantMotionTracking MediaPipe project.
*/
public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
private static final String TAG = "InstantMotionTrackingMainActivity";
// Allows for automated packet transmission to graph
private MediaPipePacketManager mediaPipePacketManager;
private static final int TARGET_CAMERA_WIDTH = 960;
private static final int TARGET_CAMERA_HEIGHT = 1280;
private static final float TARGET_CAMERA_ASPECT_RATIO =
(float) TARGET_CAMERA_WIDTH / (float) TARGET_CAMERA_HEIGHT;
// Bounds for a single click (sticker anchor reset)
private static final long CLICK_DURATION = 300; // ms
private long clickStartMillis = 0;
private ViewGroup viewGroup;
// Contains dynamic layout of sticker data controller
private LinearLayout buttonLayout;
private ArrayList<StickerManager> stickerArrayList;
// Current sticker being edited by user
private StickerManager currentSticker;
// Trip value used to determine sticker re-anchoring
private static final String STICKER_SENTINEL_TAG = "sticker_sentinel";
private int stickerSentinel = -1;
// Define parameters for 'reactivity' of object
private static final float ROTATION_SPEED = 5.0f;
private static final float SCALING_FACTOR = 0.025f;
// Parameters of device visual field for rendering system
// (68 degrees, 4:3 for Pixel 4)
// TODO : Make acquisition of this information automated
private static final float VERTICAL_FOV_RADIANS = (float) Math.toRadians(68.0);
private static final String FOV_SIDE_PACKET_TAG = "vertical_fov_radians";
private static final String ASPECT_RATIO_SIDE_PACKET_TAG = "aspect_ratio";
private static final String IMU_MATRIX_TAG = "imu_rotation_matrix";
private static final int SENSOR_SAMPLE_DELAY = SensorManager.SENSOR_DELAY_FASTEST;
private final float[] rotationMatrix = new float[9];
private static final String STICKER_PROTO_TAG = "sticker_proto_string";
// Assets for object rendering
// All animation assets and tags for the first asset (1)
private Bitmap asset3dTexture = null;
private static final String ASSET_3D_TEXTURE = "robot/robot_texture.jpg";
private static final String ASSET_3D_FILE = "robot/robot.obj.uuu";
private static final String ASSET_3D_TEXTURE_TAG = "texture_3d";
private static final String ASSET_3D_TAG = "asset_3d";
// All GIF animation assets and tags
private GIFEditText editText;
private ArrayList<Bitmap> gifBitmaps = new ArrayList<>();
private int gifCurrentIndex = 0;
private Bitmap defaultGIFTexture = null; // Texture sent if no gif available
// last time the GIF was updated
private long gifLastFrameUpdateMS = System.currentTimeMillis();
private static final int GIF_FRAME_RATE = 20; // 20 FPS
private static final String GIF_ASPECT_RATIO_TAG = "gif_aspect_ratio";
private static final String DEFAULT_GIF_TEXTURE = "gif/default_gif_texture.jpg";
private static final String GIF_FILE = "gif/gif.obj.uuu";
private static final String GIF_TEXTURE_TAG = "gif_texture";
private static final String GIF_ASSET_TAG = "gif_asset_name";
private int cameraWidth = TARGET_CAMERA_WIDTH;
private int cameraHeight = TARGET_CAMERA_HEIGHT;
@Override
protected Size cameraTargetResolution() {
// Camera size is in landscape, so here we have (height, width)
return new Size(TARGET_CAMERA_HEIGHT, TARGET_CAMERA_WIDTH);
}
@Override
protected Size computeViewSize(int width, int height) {
// Try to force aspect ratio of view size to match our target aspect ratio
return new Size(height, (int) (height * TARGET_CAMERA_ASPECT_RATIO));
}
@Override
protected void onPreviewDisplaySurfaceChanged(
SurfaceHolder holder, int format, int width, int height) {
super.onPreviewDisplaySurfaceChanged(holder, format, width, height);
boolean isCameraRotated = cameraHelper.isCameraRotated();
// cameraImageSize computation logic duplicated from base MainActivity
Size viewSize = computeViewSize(width, height);
Size cameraImageSize = cameraHelper.computeDisplaySizeFromViewSize(viewSize);
cameraWidth =
isCameraRotated ? cameraImageSize.getHeight() : cameraImageSize.getWidth();
cameraHeight =
isCameraRotated ? cameraImageSize.getWidth() : cameraImageSize.getHeight();
}
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
editText = findViewById(R.id.gif_edit_text);
editText.setGIFCommitListener(
new GIFEditText.GIFCommitListener() {
@Override
public void onGIFCommit(Uri contentUri, ClipDescription description) {
// The application must have permission to access the GIF content
grantUriPermission(
"com.google.mediapipe.apps.instantmotiontracking",
contentUri,
Intent.FLAG_GRANT_READ_URI_PERMISSION);
// Set GIF frames from content URI
setGIFBitmaps(contentUri.toString());
// Close the keyboard upon GIF acquisition
closeKeyboard();
}
});
// Send loaded 3d render assets as side packets to graph
prepareDemoAssets();
AndroidPacketCreator packetCreator = processor.getPacketCreator();
Map<String, Packet> inputSidePackets = new HashMap<>();
inputSidePackets.put(ASSET_3D_TEXTURE_TAG,
packetCreator.createRgbaImageFrame(asset3dTexture));
inputSidePackets.put(ASSET_3D_TAG,
packetCreator.createString(ASSET_3D_FILE));
inputSidePackets.put(GIF_ASSET_TAG,
packetCreator.createString(GIF_FILE));
processor.setInputSidePackets(inputSidePackets);
// Add frame listener to PacketManagement system
mediaPipePacketManager = new MediaPipePacketManager();
processor.setOnWillAddFrameListener(mediaPipePacketManager);
// Send device properties to render objects via OpenGL
Map<String, Packet> devicePropertiesSidePackets = new HashMap<>();
// TODO: Note that if our actual camera stream resolution does not match the
// requested aspect ratio, then we will need to update the value used for
// this packet, or else tracking results will be off.
devicePropertiesSidePackets.put(
ASPECT_RATIO_SIDE_PACKET_TAG, packetCreator.createFloat32(TARGET_CAMERA_ASPECT_RATIO));
devicePropertiesSidePackets.put(
FOV_SIDE_PACKET_TAG, packetCreator.createFloat32(VERTICAL_FOV_RADIANS));
processor.setInputSidePackets(devicePropertiesSidePackets);
// Begin with 0 stickers in dataset
stickerArrayList = new ArrayList<>();
currentSticker = null;
SensorManager sensorManager = (SensorManager) getSystemService(SENSOR_SERVICE);
List<Sensor> sensorList = sensorManager.getSensorList(Sensor.TYPE_ROTATION_VECTOR);
sensorManager.registerListener(
new SensorEventListener() {
private final float[] rotMatFromVec = new float[9];
@Override
public void onAccuracyChanged(Sensor sensor, int accuracy) {}
// Update procedure on sensor adjustment (phone changes orientation)
@Override
public void onSensorChanged(SensorEvent event) {
// Get the Rotation Matrix from the Rotation Vector
SensorManager.getRotationMatrixFromVector(rotMatFromVec, event.values);
// AXIS_MINUS_X is used to remap the rotation matrix for left hand
// rules in the MediaPipe graph
SensorManager.remapCoordinateSystem(
rotMatFromVec, SensorManager.AXIS_MINUS_X, SensorManager.AXIS_Y, rotationMatrix);
}
},
(Sensor) sensorList.get(0),
SENSOR_SAMPLE_DELAY);
// Mechanisms for zoom, pinch, rotation, tap gestures
buttonLayout = (LinearLayout) findViewById(R.id.button_layout);
viewGroup = findViewById(R.id.preview_display_layout);
viewGroup.setOnTouchListener(
new View.OnTouchListener() {
@Override
public boolean onTouch(View v, MotionEvent event) {
return manageUiTouch(event);
}
});
refreshUi();
}
// Obtain our custom activity_main layout for InstantMotionTracking
@Override
protected int getContentViewLayoutResId() {
return R.layout.instant_motion_tracking_activity_main;
}
// Manages a touch event in order to perform placement/rotation/scaling gestures
// on virtual sticker objects.
private boolean manageUiTouch(MotionEvent event) {
if (currentSticker != null) {
switch (event.getAction()) {
// Detecting a single click for object re-anchoring
case (MotionEvent.ACTION_DOWN):
clickStartMillis = System.currentTimeMillis();
break;
case (MotionEvent.ACTION_UP):
if (System.currentTimeMillis() - clickStartMillis <= CLICK_DURATION) {
recordClick(event);
}
break;
case (MotionEvent.ACTION_MOVE):
// Rotation and scaling are independent events and can occur simultaneously
if (event.getPointerCount() == 2) {
if (event.getHistorySize() > 1) {
// Calculate user scaling of sticker
float newScaleFactor = getNewScaleFactor(event, currentSticker.getScaleFactor());
currentSticker.setScaleFactor(newScaleFactor);
// calculate rotation (radians) for dynamic y-axis rotations
float rotationIncrement = calculateRotationRadians(event);
currentSticker.setRotation(currentSticker.getRotation() + rotationIncrement);
}
}
break;
default:
// fall out
}
}
return true;
}
// Returns a float value that is equal to the radians of rotation from a two-finger
// MotionEvent recorded by the OnTouchListener.
private static float calculateRotationRadians(MotionEvent event) {
float tangentA =
(float) Math.atan2(event.getY(1) - event.getY(0), event.getX(1) - event.getX(0));
float tangentB =
(float)
Math.atan2(
event.getHistoricalY(1, 0) - event.getHistoricalY(0, 0),
event.getHistoricalX(1, 0) - event.getHistoricalX(0, 0));
float angle = ((float) Math.toDegrees(tangentA - tangentB)) % 360f;
angle += ((angle < -180f) ? +360f : ((angle > 180f) ? -360f : 0.0f));
float rotationIncrement = (float) (Math.PI * ((angle * ROTATION_SPEED) / 180));
return rotationIncrement;
}
// Returns a float value that is equal to the translation distance between
// two-fingers that move in a pinch/spreading direction.
private static float getNewScaleFactor(MotionEvent event, float currentScaleFactor) {
double newDistance = getDistance(event.getX(0), event.getY(0), event.getX(1), event.getY(1));
double oldDistance =
getDistance(
event.getHistoricalX(0, 0),
event.getHistoricalY(0, 0),
event.getHistoricalX(1, 0),
event.getHistoricalY(1, 0));
float signFloat =
(newDistance < oldDistance)
? -SCALING_FACTOR
: SCALING_FACTOR; // Are they moving towards each other?
currentScaleFactor *= (1f + signFloat);
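// With SCALING_FACTOR = 0.025f this nudges the scale by +/-2.5% per qualifying move event.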
return currentScaleFactor;
}
// Called if a single touch event is recorded on the screen and used to set the
// new anchor position for the current sticker in focus.
private void recordClick(MotionEvent event) {
// First normalize our click position w.r.t. the view display
float x = (event.getX() / viewGroup.getWidth());
float y = (event.getY() / viewGroup.getHeight());
// MediaPipe can automatically crop our camera stream when displaying it to
// our surface, which can throw off our touch point calculations. So we need
// to replicate that logic here. See FrameScaleMode::kFillAndCrop usage in
// gl_quad_renderer.cc for more details.
float widthRatio = (float) viewGroup.getWidth() / (float) cameraWidth;
float heightRatio = (float) viewGroup.getHeight() / (float) cameraHeight;
float maxRatio = max(widthRatio, heightRatio);
widthRatio /= maxRatio;
heightRatio /= maxRatio;
// Now we scale by the scale factors, and then reposition (since cropping
// is always centered)
x *= widthRatio;
x += 0.5f * (1.0f - widthRatio);
y *= heightRatio;
y += 0.5f * (1.0f - heightRatio);
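// Illustrative example: with a 1080x1920 view and a 960x1280 camera image,
// widthRatio = 1.125 and heightRatio = 1.5; dividing both by maxRatio (1.5)
// gives 0.75 and 1.0, so x is squeezed into [0.125, 0.875] while y is left
// unchanged, matching the centered crop of the preview.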
// Finally, we can pass our adjusted x and y points to the StickerManager
currentSticker.setAnchorCoordinate(x, y);
stickerSentinel = currentSticker.getstickerId();
}
// Provided the X and Y coordinates of two points, the distance between them
// will be returned.
private static double getDistance(double x1, double y1, double x2, double y2) {
return Math.hypot((y2 - y1), (x2 - x1));
}
// Called upon each button click, and used to populate the buttonLayout with the
// current sticker data in addition to sticker controls (delete, back, change render).
private void refreshUi() {
if (currentSticker != null) { // A sticker is currently selected for editing
buttonLayout.removeAllViews();
ImageButton deleteSticker = new ImageButton(this);
setControlButtonDesign(deleteSticker, R.drawable.baseline_clear_24);
deleteSticker.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
if (currentSticker != null) {
stickerArrayList.remove(currentSticker);
currentSticker = null;
refreshUi();
}
}
});
// Go to home sticker menu
ImageButton goBack = new ImageButton(this);
setControlButtonDesign(goBack, R.drawable.baseline_arrow_back_24);
goBack.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
currentSticker = null;
refreshUi();
}
});
// Change sticker to next possible render
ImageButton loopRender = new ImageButton(this);
setControlButtonDesign(loopRender, R.drawable.baseline_loop_24);
loopRender.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
currentSticker.setRender(currentSticker.getRender().iterate());
refreshUi();
}
});
buttonLayout.addView(deleteSticker);
buttonLayout.addView(goBack);
buttonLayout.addView(loopRender);
// Add the GIF search option if current sticker is GIF
if (currentSticker.getRender() == StickerManager.Render.GIF) {
ImageButton gifSearch = new ImageButton(this);
setControlButtonDesign(gifSearch, R.drawable.baseline_search_24);
gifSearch.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
// Clear the text field to prevent text artifacts in GIF selection
editText.setText("");
// Open the Keyboard to allow user input
openKeyboard();
}
});
buttonLayout.addView(gifSearch);
}
} else {
buttonLayout.removeAllViews();
// Display stickers
for (final StickerManager sticker : stickerArrayList) {
final ImageButton stickerButton = new ImageButton(this);
stickerButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
currentSticker = sticker;
refreshUi();
}
});
if (sticker.getRender() == StickerManager.Render.GIF) {
setControlButtonDesign(stickerButton, R.drawable.asset_gif_preview);
} else if (sticker.getRender() == StickerManager.Render.ASSET_3D) {
setStickerButtonDesign(stickerButton, R.drawable.asset_3d_preview);
}
buttonLayout.addView(stickerButton);
}
ImageButton addSticker = new ImageButton(this);
setControlButtonDesign(addSticker, R.drawable.baseline_add_24);
addSticker.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
StickerManager newSticker = new StickerManager();
stickerArrayList.add(newSticker);
currentSticker = newSticker;
refreshUi();
}
});
ImageButton clearStickers = new ImageButton(this);
setControlButtonDesign(clearStickers, R.drawable.baseline_clear_all_24);
clearStickers.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
stickerArrayList.clear();
refreshUi();
}
});
buttonLayout.addView(addSticker);
buttonLayout.addView(clearStickers);
}
}
// Sets ImageButton UI for Control Buttons.
private void setControlButtonDesign(ImageButton btn, int imageDrawable) {
// btn.setImageDrawable(getResources().getDrawable(imageDrawable));
btn.setImageDrawable(getDrawable(imageDrawable));
btn.setBackgroundColor(Color.parseColor("#00ffffff"));
btn.setColorFilter(Color.parseColor("#0494a4"));
btn.setLayoutParams(new LinearLayout.LayoutParams(200, 200));
btn.setPadding(25, 25, 25, 25);
btn.setScaleType(ImageView.ScaleType.FIT_XY);
}
// Sets ImageButton UI for Sticker Buttons.
private void setStickerButtonDesign(ImageButton btn, int imageDrawable) {
btn.setImageDrawable(getDrawable(imageDrawable));
btn.setBackground(getDrawable(R.drawable.circle_button));
btn.setLayoutParams(new LinearLayout.LayoutParams(250, 250));
btn.setPadding(25, 25, 25, 25);
btn.setScaleType(ImageView.ScaleType.CENTER_INSIDE);
}
// Used to set ArrayList of Bitmap frames
private void setGIFBitmaps(String gifUrl) {
gifBitmaps = new ArrayList<>(); // Empty the bitmap array
Glide.with(this)
.asGif()
.load(gifUrl)
.into(
new CustomTarget<GifDrawable>() {
@Override
public void onLoadCleared(Drawable placeholder) {}
@Override
public void onResourceReady(
GifDrawable resource, Transition<? super GifDrawable> transition) {
try {
Object startConstant = resource.getConstantState();
Field frameManager = startConstant.getClass().getDeclaredField("frameLoader");
frameManager.setAccessible(true);
Object frameLoader = frameManager.get(startConstant);
Field decoder = frameLoader.getClass().getDeclaredField("gifDecoder");
decoder.setAccessible(true);
Object frameObject = (decoder.get(frameLoader));
for (int i = 0; i < resource.getFrameCount(); i++) {
frameObject.getClass().getMethod("advance").invoke(frameObject);
Bitmap bmp =
(Bitmap)
frameObject.getClass().getMethod("getNextFrame").invoke(frameObject);
gifBitmaps.add(flipHorizontal(bmp));
}
} catch (Exception e) {
Log.e(TAG, "", e);
}
}
});
}
// Bitmaps must be flipped due to native acquisition of frames from Android OS
private static Bitmap flipHorizontal(Bitmap bmp) {
Matrix matrix = new Matrix();
// Flip Bitmap frames horizontally
matrix.preScale(-1.0f, 1.0f);
return Bitmap.createBitmap(bmp, 0, 0, bmp.getWidth(), bmp.getHeight(), matrix, true);
}
// Function that is continuously called in order to time GIF frame updates
private void updateGIFFrame() {
long millisPerFrame = 1000 / GIF_FRAME_RATE;
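// e.g. with GIF_FRAME_RATE = 20, millisPerFrame is 1000 / 20 = 50 ms, so the
// frame index advances at most once every 50 ms regardless of how often this
// method is called.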
if (System.currentTimeMillis() - gifLastFrameUpdateMS >= millisPerFrame) {
// Update GIF timestamp
gifLastFrameUpdateMS = System.currentTimeMillis();
// Cycle through every possible frame and avoid a divide by 0
gifCurrentIndex = gifBitmaps.isEmpty() ? 1 : (gifCurrentIndex + 1) % gifBitmaps.size();
}
}
// Called once to popup the Keyboard via Android OS with focus set to editText
private void openKeyboard() {
editText.requestFocus();
InputMethodManager imm = (InputMethodManager) getSystemService(Context.INPUT_METHOD_SERVICE);
imm.showSoftInput(editText, InputMethodManager.SHOW_IMPLICIT);
}
// Called once to close the Keyboard via Android OS
private void closeKeyboard() {
View view = this.getCurrentFocus();
if (view != null) {
InputMethodManager imm = (InputMethodManager) getSystemService(Context.INPUT_METHOD_SERVICE);
imm.hideSoftInputFromWindow(view.getWindowToken(), 0);
}
}
private void prepareDemoAssets() {
// We render from raw data with openGL, so disable decoding preprocessing
BitmapFactory.Options decodeOptions = new BitmapFactory.Options();
decodeOptions.inScaled = false;
decodeOptions.inDither = false;
decodeOptions.inPremultiplied = false;
try {
InputStream inputStream = getAssets().open(DEFAULT_GIF_TEXTURE);
defaultGIFTexture =
flipHorizontal(
BitmapFactory.decodeStream(inputStream, null /*outPadding*/, decodeOptions));
inputStream.close();
} catch (Exception e) {
Log.e(TAG, "Error parsing object texture; error: ", e);
throw new IllegalStateException(e);
}
try {
InputStream inputStream = getAssets().open(ASSET_3D_TEXTURE);
asset3dTexture = BitmapFactory.decodeStream(inputStream, null /*outPadding*/, decodeOptions);
inputStream.close();
} catch (Exception e) {
Log.e(TAG, "Error parsing object texture; error: ", e);
throw new IllegalStateException(e);
}
}
private class MediaPipePacketManager implements FrameProcessor.OnWillAddFrameListener {
@Override
public void onWillAddFrame(long timestamp) {
// set current GIF bitmap as default texture
Bitmap currentGIFBitmap = defaultGIFTexture;
// If current index is in bounds, display current frame
if (gifCurrentIndex <= gifBitmaps.size() - 1) {
currentGIFBitmap = gifBitmaps.get(gifCurrentIndex);
}
// Update to next GIF frame based on timing and frame rate
updateGIFFrame();
// Calculate and set the aspect ratio of the GIF
float gifAspectRatio =
(float) currentGIFBitmap.getWidth() / (float) currentGIFBitmap.getHeight();
Packet stickerSentinelPacket = processor.getPacketCreator().createInt32(stickerSentinel);
// Sticker sentinel value must be reset for next graph iteration
stickerSentinel = -1;
// Initialize the sticker data protobuf packet
Packet stickerProtoDataPacket =
processor
.getPacketCreator()
.createSerializedProto(StickerManager.getMessageLiteData(stickerArrayList));
// Define and set the IMU sensory information float array
Packet imuDataPacket = processor.getPacketCreator().createFloat32Array(rotationMatrix);
// Communicate GIF textures (dynamic texturing) to graph
Packet gifTexturePacket = processor.getPacketCreator().createRgbaImageFrame(currentGIFBitmap);
Packet gifAspectRatioPacket = processor.getPacketCreator().createFloat32(gifAspectRatio);
processor
.getGraph()
.addConsumablePacketToInputStream(STICKER_SENTINEL_TAG, stickerSentinelPacket, timestamp);
processor
.getGraph()
.addConsumablePacketToInputStream(STICKER_PROTO_TAG, stickerProtoDataPacket, timestamp);
processor
.getGraph()
.addConsumablePacketToInputStream(IMU_MATRIX_TAG, imuDataPacket, timestamp);
processor
.getGraph()
.addConsumablePacketToInputStream(GIF_TEXTURE_TAG, gifTexturePacket, timestamp);
processor
.getGraph()
.addConsumablePacketToInputStream(GIF_ASPECT_RATIO_TAG, gifAspectRatioPacket, timestamp);
stickerSentinelPacket.release();
stickerProtoDataPacket.release();
imuDataPacket.release();
gifTexturePacket.release();
gifAspectRatioPacket.release();
}
}
}

View File

@ -0,0 +1,191 @@
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.apps.instantmotiontracking;
import com.google.mediapipe.graphs.instantmotiontracking.StickerBufferProto.Sticker;
import com.google.mediapipe.graphs.instantmotiontracking.StickerBufferProto.StickerRoll;
import java.util.ArrayList;
/**
* This class represents a single sticker object placed in the
* instantmotiontracking system. Each StickerManager represents a unique object
* to render and manipulate in an AR scene.
* <p>A sticker has a sticker_id (a unique integer identifying a sticker object
* to render), x and y normalized anchor coordinates [0.0-1.0], user inputs for
* rotation in radians, scaling, and a renderID (another unique integer which
* determines what object model to render for this unique sticker).
*/
public class StickerManager {
/** All types of possible objects to render for our application. */
public enum Render {
// Every possible render for a sticker object
GIF,
ASSET_3D;
/**
* Once called, will set the value of the current render to the next
* possible Render available. If all possible Renders have been iterated
* through, the function will loop and set to the first available Render.
*/
public Render iterate() {
int newEnumIdx = (this.ordinal() + 1) % Render.values().length;
return Render.values()[newEnumIdx];
}
}
// Current render of the sticker object
private Render currentRender;
// Normalized X and Y coordinates of anchor
// (0,0) lies at top-left corner of screen
// (1.0,1.0) lies at bottom-right corner of screen
private float anchorX;
private float anchorY;
// Rotation in radians from user
private float userRotation = 0f;
// Scaling factor as defined by user (defaults to 1.0)
private float userScalingFactor = 1f;
// Unique sticker integer ID
private final int stickerId;
// Used to determine next stickerId
private static int globalIDLimit = 1;
/**
* Used to create a StickerManager object with a newly generated stickerId and a
* default Render of the first possible render in our Render enum.
*/
public StickerManager() {
// Every sticker will have a default render of the first 3D asset
this.currentRender = Render.values()[1];
// StickerManager will render out of view by default
this.setAnchorCoordinate(2.0f, 2.0f);
// Set the global sticker ID limit for the next sticker
stickerId = StickerManager.globalIDLimit++;
}
/**
* Used to create a StickerManager object with a newly generated stickerId.
*
* @param render initial Render of the new StickerManager object
*/
public StickerManager(Render render) {
this.currentRender = render;
// StickerManager will render out of view by default
this.setAnchorCoordinate(2.0f, 2.0f);
// Set the global sticker ID limit for the next sticker
stickerId = StickerManager.globalIDLimit++;
}
/**
* Used to get the sticker ID of the object.
*
* @return integer of the unique sticker ID
*/
public int getstickerId() {
return this.stickerId;
}
/**
* Used to update or reset the anchor positions in normalized [0.0-1.0]
* coordinate space for the sticker object.
*
* @param normalizedX normalized X coordinate for the new anchor position
* @param normalizedY normalized Y coordinate for the new anchor position
*/
public void setAnchorCoordinate(float normalizedX, float normalizedY) {
this.anchorX = normalizedX;
this.anchorY = normalizedY;
}
/** Returns the normalized X anchor coordinate of the sticker object. */
public float getAnchorX() {
return anchorX;
}
/** Returns the normalized Y anchor coordinate of the sticker object. */
public float getAnchorY() {
return anchorY;
}
/** Returns current asset to be rendered for this sticker object. */
public Render getRender() {
return currentRender;
}
/** Set render for this sticker object */
public void setRender(Render render) {
this.currentRender = render;
}
/**
* Sets new user value of rotation radians. This rotation is not cumulative,
* and must be set to an absolute value of rotation applied to the object.
*
* @param radians specified radians to rotate the sticker object by
*/
public void setRotation(float radians) {
this.userRotation = radians;
}
/** Returns current user radian rotation setting. */
public float getRotation() {
return this.userRotation;
}
/**
* Sets new user scale factor. This factor will be proportional to the scale
* of the sticker object.
*
* @param scaling scale factor to be applied
*/
public void setScaleFactor(float scaling) {
this.userScalingFactor = scaling;
}
/** Returns current user scale factor setting. */
public float getScaleFactor() {
return this.userScalingFactor;
}
/**
* This method converts an ArrayList of stickers to a MessageLite object
* which can be passed directly to the MediaPipe graph.
*
* @param stickerArrayList ArrayList of StickerManager objects to convert to data string
* @return MessageLite protobuffer of all sticker data
*/
public static StickerRoll getMessageLiteData(
ArrayList<StickerManager> stickerArrayList) {
StickerRoll.Builder stickerRollBuilder
= StickerRoll.newBuilder();
for (final StickerManager sticker : stickerArrayList) {
Sticker protoSticker =
Sticker.newBuilder()
.setId(sticker.getstickerId())
.setX(sticker.getAnchorX())
.setY(sticker.getAnchorY())
.setRotation(sticker.getRotation())
.setScale(sticker.getScaleFactor())
.setRenderId(sticker.getRender().ordinal())
.build();
stickerRollBuilder.addSticker(protoSticker);
}
return stickerRollBuilder.build();
}
}
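A minimal usage sketch for this class (illustrative only; the instantmotiontracking MainActivity above shows the real wiring, and StickerRoll refers to the proto imported at the top of this file):

ArrayList<StickerManager> stickers = new ArrayList<>();
StickerManager robot = new StickerManager(StickerManager.Render.ASSET_3D);
robot.setAnchorCoordinate(0.5f, 0.5f);      // center of the view
robot.setRotation((float) (Math.PI / 4));   // 45 degrees
robot.setScaleFactor(1.5f);
stickers.add(robot);

// Serialize every sticker into the proto the MediaPipe graph consumes.
StickerRoll roll = StickerManager.getMessageLiteData(stickers);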

View File

@ -0,0 +1,21 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
exports_files(
srcs = glob(["**"]),
)

Binary file not shown (image added, 108 KiB).

View File

@ -0,0 +1,16 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M11.5,9h1.5v6h-1.5z"/>
<path
android:fillColor="@android:color/white"
android:pathData="M9,9H6c-0.6,0 -1,0.5 -1,1v4c0,0.5 0.4,1 1,1h3c0.6,0 1,-0.5 1,-1v-2H8.5v1.5h-2v-3H10V10C10,9.5 9.6,9 9,9z"/>
<path
android:fillColor="@android:color/white"
android:pathData="M19,10.5l0,-1.5l-4.5,0l0,6l1.5,0l0,-2l2,0l0,-1.5l-2,0l0,-1z"/>
</vector>

View File

@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M19,13h-6v6h-2v-6H5v-2h6V5h2v6h6v2z"/>
</vector>

View File

@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M20,11H7.83l5.59,-5.59L12,4l-8,8 8,8 1.41,-1.41L7.83,13H20v-2z"/>
</vector>

View File

@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M19,6.41L17.59,5 12,10.59 6.41,5 5,6.41 10.59,12 5,17.59 6.41,19 12,13.41 17.59,19 19,17.59 13.41,12z"/>
</vector>

View File

@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M5,13h14v-2L5,11v2zM3,17h14v-2L3,15v2zM7,7v2h14L21,7L7,7z"/>
</vector>

View File

@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M12,4L12,1L8,5l4,4L12,6c3.31,0 6,2.69 6,6 0,1.01 -0.25,1.97 -0.7,2.8l1.46,1.46C19.54,15.03 20,13.57 20,12c0,-4.42 -3.58,-8 -8,-8zM12,18c-3.31,0 -6,-2.69 -6,-6 0,-1.01 0.25,-1.97 0.7,-2.8L5.24,7.74C4.46,8.97 4,10.43 4,12c0,4.42 3.58,8 8,8v3l4,-4 -4,-4v3z"/>
</vector>

View File

@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M15.5,14h-0.79l-0.28,-0.27C15.41,12.59 16,11.11 16,9.5 16,5.91 13.09,3 9.5,3S3,5.91 3,9.5 5.91,16 9.5,16c1.61,0 3.09,-0.59 4.23,-1.57l0.27,0.28v0.79l5,4.99L20.49,19l-4.99,-5zM9.5,14C7.01,14 5,11.99 5,9.5S7.01,5 9.5,5 14,7.01 14,9.5 11.99,14 9.5,14z"/>
</vector>

View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<selector xmlns:android="http://schemas.android.com/apk/res/android">
<item>
<shape android:shape="oval">
</shape>
</item>
</selector>

View File

@ -0,0 +1,62 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:id="@+id/constraint_layout">
<FrameLayout
android:id="@+id/preview_display_layout"
android:layout_width="fill_parent"
android:layout_height="fill_parent"
android:layout_weight="1">
<TextView
android:id="@+id/no_camera_access_view"
android:layout_height="fill_parent"
android:layout_width="fill_parent"
android:gravity="center"
android:text="@string/no_camera_access" />
</FrameLayout>
<LinearLayout
android:layout_width="wrap_content"
android:layout_height="fill_parent"
android:gravity="top"
android:orientation="vertical"
app:layout_constraintBottom_toTopOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintTop_toTopOf="parent">
<LinearLayout
android:id="@+id/button_layout"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:gravity="top"
android:orientation="vertical"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintTop_toTopOf="parent"/>
<com.google.mediapipe.apps.instantmotiontracking.GIFEditText
android:id="@+id/gif_edit_text"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_margin="0dp"
android:background="@null"
android:gravity="center"
android:hint=""
android:text=""
android:padding="0dp"
android:cursorVisible="false"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />
</LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
@ -57,6 +57,7 @@ android_binary(
    deps = [
        ":mediapipe_jni_lib",
        "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib",
        "//mediapipe/framework/formats:landmark_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
    ],
)

View File

@ -15,7 +15,13 @@
package com.google.mediapipe.apps.iristrackinggpu;

import android.graphics.SurfaceTexture;
import android.os.Bundle;
import android.util.Log;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.HashMap;
import java.util.Map;
@ -24,6 +30,7 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
  private static final String TAG = "MainActivity";
  private static final String FOCAL_LENGTH_STREAM_NAME = "focal_length_pixel";
  private static final String OUTPUT_LANDMARKS_STREAM_NAME = "face_landmarks_with_iris";

  @Override
  protected void onCameraStarted(SurfaceTexture surfaceTexture) {
@ -37,4 +44,55 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
      processor.setInputSidePackets(inputSidePackets);
    }
  }
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
// To show verbose logging, run:
// adb shell setprop log.tag.MainActivity VERBOSE
if (Log.isLoggable(TAG, Log.VERBOSE)) {
processor.addPacketCallback(
OUTPUT_LANDMARKS_STREAM_NAME,
(packet) -> {
byte[] landmarksRaw = PacketGetter.getProtoBytes(packet);
try {
NormalizedLandmarkList landmarks = NormalizedLandmarkList.parseFrom(landmarksRaw);
if (landmarks == null) {
Log.v(TAG, "[TS:" + packet.getTimestamp() + "] No landmarks.");
return;
}
Log.v(
TAG,
"[TS:"
+ packet.getTimestamp()
+ "] #Landmarks for face (including iris): "
+ landmarks.getLandmarkCount());
Log.v(TAG, getLandmarksDebugString(landmarks));
} catch (InvalidProtocolBufferException e) {
Log.e(TAG, "Couldn't Exception received - " + e);
return;
}
});
}
}
private static String getLandmarksDebugString(NormalizedLandmarkList landmarks) {
int landmarkIndex = 0;
String landmarksString = "";
for (NormalizedLandmark landmark : landmarks.getLandmarkList()) {
landmarksString +=
"\t\tLandmark["
+ landmarkIndex
+ "]: ("
+ landmark.getX()
+ ", "
+ landmark.getY()
+ ", "
+ landmark.getZ()
+ ")\n";
++landmarkIndex;
}
return landmarksString;
}
}

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
@ -72,11 +72,11 @@ android_binary(
    ] + select({
        "//conditions:default": [
            "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/sneaker:model.obj.uuu",
            "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/sneaker:texture.jpg",
        ],
        ":use_chair_model": [
            "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/chair:model.obj.uuu",
            "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/assets/chair:texture.jpg",
        ],
    }),
    assets_dir = "",

View File

@ -31,7 +31,7 @@ import java.util.Map;
public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
  private static final String TAG = "MainActivity";
  private static final String OBJ_TEXTURE = "texture.jpg";
  private static final String OBJ_FILE = "model.obj.uuu";
  private static final String BOX_TEXTURE = "classic_colors.png";
  private static final String BOX_FILE = "box.obj.uuu";

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:public"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:public"])

Binary file not shown (image added, 420 KiB).

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:public"])

Binary file not shown (image added, 385 KiB).

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = [
    "//visibility:public",

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = [
    "//visibility:public",

View File

@@ -14,7 +14,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library"
 # See the License for the specific language governing permissions and
 # limitations under the License.
-licenses(["notice"])  # Apache 2.0
+licenses(["notice"])
 package(default_visibility = ["//mediapipe/examples:__subpackages__"])


@@ -14,7 +14,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library"
 # See the License for the specific language governing permissions and
 # limitations under the License.
-licenses(["notice"])  # Apache 2.0
+licenses(["notice"])
 package(default_visibility = [
     "//mediapipe/examples:__subpackages__",


@@ -173,6 +173,7 @@ TEST(ContentZoomingCalculatorTest, PanConfig) {
   auto* options = config.mutable_options()->MutableExtension(
       ContentZoomingCalculatorOptions::ext);
   options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0);
+  options->mutable_kinematic_options_pan()->set_update_rate_seconds(2);
   options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
   options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
   auto runner = ::absl::make_unique<CalculatorRunner>(config);
@@ -191,6 +192,7 @@ TEST(ContentZoomingCalculatorTest, TiltConfig) {
       ContentZoomingCalculatorOptions::ext);
   options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
   options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(0.0);
+  options->mutable_kinematic_options_tilt()->set_update_rate_seconds(2);
   options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
   auto runner = ::absl::make_unique<CalculatorRunner>(config);
   AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
@@ -209,6 +211,7 @@ TEST(ContentZoomingCalculatorTest, ZoomConfig) {
   options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
   options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
   options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0);
+  options->mutable_kinematic_options_zoom()->set_update_rate_seconds(2);
   auto runner = ::absl::make_unique<CalculatorRunner>(config);
   AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
   AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
@@ -345,8 +348,13 @@ TEST(ContentZoomingCalculatorTest, ZoomTestPairSize) {
 }
 TEST(ContentZoomingCalculatorTest, ZoomTestNearOutsideBorder) {
-  auto runner = ::absl::make_unique<CalculatorRunner>(
-      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
+  auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
+  auto* options = config.mutable_options()->MutableExtension(
+      ContentZoomingCalculatorOptions::ext);
+  options->mutable_kinematic_options_pan()->set_update_rate_seconds(2);
+  options->mutable_kinematic_options_tilt()->set_update_rate_seconds(2);
+  options->mutable_kinematic_options_zoom()->set_update_rate_seconds(2);
+  auto runner = ::absl::make_unique<CalculatorRunner>(config);
   AddDetection(cv::Rect_<float>(.95, .95, .05, .05), 0, runner.get());
   AddDetection(cv::Rect_<float>(.9, .9, .1, .1), 1000000, runner.get());
   MP_ASSERT_OK(runner->Run());
@@ -357,8 +365,13 @@ TEST(ContentZoomingCalculatorTest, ZoomTestNearOutsideBorder) {
 }
 TEST(ContentZoomingCalculatorTest, ZoomTestNearInsideBorder) {
-  auto runner = ::absl::make_unique<CalculatorRunner>(
-      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
+  auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
+  auto* options = config.mutable_options()->MutableExtension(
+      ContentZoomingCalculatorOptions::ext);
+  options->mutable_kinematic_options_pan()->set_update_rate_seconds(2);
+  options->mutable_kinematic_options_tilt()->set_update_rate_seconds(2);
+  options->mutable_kinematic_options_zoom()->set_update_rate_seconds(2);
+  auto runner = ::absl::make_unique<CalculatorRunner>(config);
   AddDetection(cv::Rect_<float>(0, 0, .05, .05), 0, runner.get());
   AddDetection(cv::Rect_<float>(0, 0, .1, .1), 1000000, runner.get());
   MP_ASSERT_OK(runner->Run());


@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-licenses(["notice"])  # Apache 2.0
+licenses(["notice"])
 filegroup(
     name = "test_images",


@@ -14,7 +14,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library"
 # See the License for the specific language governing permissions and
 # limitations under the License.
-licenses(["notice"])  # Apache 2.0
+licenses(["notice"])
 package(default_visibility = ["//mediapipe/examples:__subpackages__"])


@@ -12,6 +12,8 @@ namespace autoflip {
   current_velocity_deg_per_s_ = 0;
   RET_CHECK_GT(pixels_per_degree_, 0)
       << "pixels_per_degree must be larger than 0.";
+  RET_CHECK_GE(options_.update_rate_seconds(), 0)
+      << "update_rate_seconds must be greater than 0.";
   RET_CHECK_GE(options_.min_motion_to_reframe(), options_.reframe_window())
       << "Reframe window cannot exceed min_motion_to_reframe.";
   return ::mediapipe::OkStatus();
@@ -41,9 +43,10 @@ namespace autoflip {
   // Observed velocity and then weighted update of this velocity.
   double observed_velocity = delta_degs / delta_t;
-  double updated_velocity =
-      current_velocity_deg_per_s_ * (1 - options_.update_rate()) +
-      observed_velocity * options_.update_rate();
+  double update_rate = std::min(delta_t / options_.update_rate_seconds(),
+                                options_.max_update_rate());
+  double updated_velocity = current_velocity_deg_per_s_ * (1 - update_rate) +
+                            observed_velocity * update_rate;
   // Limited current velocity.
   current_velocity_deg_per_s_ =
       updated_velocity > 0 ? fmin(updated_velocity, options_.max_velocity())
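The change above replaces the fixed update_rate blend weight with one derived from the elapsed time between frames, so the exponential smoothing of camera velocity behaves consistently across frame rates. A minimal standalone sketch of the idea (the function name and signature are illustrative, not MediaPipe API):

#include <algorithm>

// Exponentially smooths an observed camera velocity. The blend weight is
// scaled by the frame interval, so a 30 FPS stream (delta_t_s ~= 0.033) and a
// 10 FPS stream (delta_t_s = 0.1) converge at a similar wall-clock rate.
double SmoothVelocity(double current_velocity, double observed_velocity,
                      double delta_t_s, double update_rate_seconds,
                      double max_update_rate) {
  const double update_rate =
      std::min(delta_t_s / update_rate_seconds, max_update_rate);
  return current_velocity * (1 - update_rate) +
         observed_velocity * update_rate;
}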


@@ -5,7 +5,7 @@ package mediapipe.autoflip;
 message KinematicOptions {
   // Weighted update of new camera velocity (measurement) vs current state
   // (prediction).
-  optional double update_rate = 1 [default = 0.5];
+  optional double update_rate = 1 [default = 0.5, deprecated = true];
   // Max velocity (degrees per second) that the camera can move.
   optional double max_velocity = 2 [default = 18];
   // Min motion (in degrees) to react in pixels.
@@ -15,4 +15,9 @@ message KinematicOptions {
   // total reframe distance on average. Value cannot exceed
   // min_motion_to_reframe value.
   optional float reframe_window = 4 [default = 0];
+  // Calculation of internal velocity state is:
+  //   min((delta_time_s / update_rate_seconds), max_update_rate)
+  // where delta_time_s is the time since the last frame.
+  optional double update_rate_seconds = 5 [default = 0.20];
+  optional double max_update_rate = 6 [default = 0.8];
 }
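With the defaults declared above (update_rate_seconds = 0.20, max_update_rate = 0.8), the per-frame blend weight grows with the frame interval and is capped for very slow streams. A small illustrative calculation, assuming only the formula quoted in the proto comment:

#include <algorithm>
#include <cstdio>

int main() {
  const double update_rate_seconds = 0.20;  // KinematicOptions default
  const double max_update_rate = 0.8;       // KinematicOptions default
  for (double fps : {30.0, 10.0, 2.0}) {
    const double delta_t_s = 1.0 / fps;
    const double update_rate =
        std::min(delta_t_s / update_rate_seconds, max_update_rate);
    // Prints roughly: 30 FPS -> 0.17, 10 FPS -> 0.50, 2 FPS -> 0.80 (capped).
    std::printf("%2.0f FPS: update_rate = %.2f\n", fps, update_rate);
  }
  return 0;
}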


@@ -85,7 +85,8 @@ TEST(KinematicPathSolverTest, PassEnoughMotionLargeImg) {
   KinematicOptions options;
   // Set min motion to 1deg
   options.set_min_motion_to_reframe(1.0);
-  options.set_update_rate(1);
+  options.set_update_rate_seconds(.0000001);
+  options.set_max_update_rate(1.0);
   options.set_max_velocity(1000);
   // Set degrees / pixel to 16.6
   KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView);
@@ -102,7 +103,8 @@ TEST(KinematicPathSolverTest, PassEnoughMotionSmallImg) {
   KinematicOptions options;
   // Set min motion to 2deg
   options.set_min_motion_to_reframe(1.0);
-  options.set_update_rate(1);
+  options.set_update_rate_seconds(.0000001);
+  options.set_max_update_rate(1.0);
   options.set_max_velocity(18);
   // Set degrees / pixel to 8.3
   KinematicPathSolver solver(options, 0, 500, 500.0 / kWidthFieldOfView);
@@ -132,7 +134,8 @@ TEST(KinematicPathSolverTest, PassReframeWindow) {
   KinematicOptions options;
   // Set min motion to 1deg
   options.set_min_motion_to_reframe(1.0);
-  options.set_update_rate(1);
+  options.set_update_rate_seconds(.0000001);
+  options.set_max_update_rate(1.0);
   options.set_max_velocity(1000);
   // Set reframe window size to .75 for test.
   options.set_reframe_window(0.75);
@@ -147,10 +150,41 @@ TEST(KinematicPathSolverTest, PassReframeWindow) {
   EXPECT_EQ(state, 507);
 }
+TEST(KinematicPathSolverTest, PassUpdateRate30FPS) {
+  KinematicOptions options;
+  options.set_min_motion_to_reframe(1.0);
+  options.set_update_rate_seconds(.25);
+  options.set_max_update_rate(0.8);
+  options.set_max_velocity(18);
+  KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView);
+  int state;
+  MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0));
+  MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1 / 30));
+  MP_ASSERT_OK(solver.GetState(&state));
+  // (0.033 / .25) * 20 ≈ 2.7, so one frame moves ~3px of the 20px offset.
+  EXPECT_EQ(state, 503);
+}
+TEST(KinematicPathSolverTest, PassUpdateRate10FPS) {
+  KinematicOptions options;
+  options.set_min_motion_to_reframe(1.0);
+  options.set_update_rate_seconds(.25);
+  options.set_max_update_rate(0.8);
+  options.set_max_velocity(18);
+  KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView);
+  int state;
+  MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0));
+  MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1 / 10));
+  MP_ASSERT_OK(solver.GetState(&state));
+  // (0.1 / .25) * 20 = 8, so one frame moves 8px of the 20px offset.
+  EXPECT_EQ(state, 508);
+}
 TEST(KinematicPathSolverTest, PassUpdateRate) {
   KinematicOptions options;
   options.set_min_motion_to_reframe(1.0);
-  options.set_update_rate(0.25);
+  options.set_update_rate_seconds(4);
+  options.set_max_update_rate(1.0);
   options.set_max_velocity(18);
   KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView);
   int state;
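For reference, the expected states in the two new frame-rate tests follow directly from the time-based blend weight. A hedged back-of-the-envelope check, assuming the solver starts from zero velocity and neither the min-motion threshold nor the velocity cap interferes:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const double update_rate_seconds = 0.25;  // as set in the tests above
  const double max_update_rate = 0.8;
  const double start_px = 500.0, target_px = 520.0;
  for (double fps : {30.0, 10.0}) {
    const double delta_t_s = 1.0 / fps;
    const double update_rate =
        std::min(delta_t_s / update_rate_seconds, max_update_rate);
    // One observation moves the camera by update_rate * 20 px:
    // 30 FPS -> ~503, 10 FPS -> 508, matching the EXPECT_EQ values above.
    std::printf("%2.0f FPS: state ~= %.0f\n", fps,
                std::round(start_px + update_rate * (target_px - start_px)));
  }
  return 0;
}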


@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-licenses(["notice"])  # Apache 2.0
+licenses(["notice"])
 package(default_visibility = ["//visibility:public"])

Some files were not shown because too many files have changed in this diff.