Project import generated by Copybara.

GitOrigin-RevId: 27c70b5fe62ab71189d358ca122ee4b19c817a8f
This commit is contained in:
MediaPipe Team 2021-07-23 17:09:32 -07:00 committed by chuoling
parent 374f5e2e7e
commit 50c92c6623
158 changed files with 4704 additions and 621 deletions

View File

@ -45,7 +45,7 @@ Hair Segmentation
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | [Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |
@ -79,6 +79,13 @@ run code search using
## Publications ## Publications
* [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html)
in Google Developers Blog
* [Prosthesis control via Mirru App using MediaPipe hand tracking](https://developers.googleblog.com/2021/05/control-your-mirru-prosthesis-with-mediapipe-hand-tracking.html)
in Google Developers Blog
* [SignAll SDK: Sign language interface using MediaPipe is now available for
developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html)
in Google Developers Blog
* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) * [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
in Google AI Blog in Google AI Blog
* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html)

View File

@ -53,19 +53,12 @@ rules_foreign_cc_dependencies()
all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])""" all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])"""
# GoogleTest/GoogleMock framework. Used by most unit-tests. # GoogleTest/GoogleMock framework. Used by most unit-tests.
# Last updated 2020-06-30. # Last updated 2021-07-02.
http_archive( http_archive(
name = "com_google_googletest", name = "com_google_googletest",
urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"], urls = ["https://github.com/google/googletest/archive/4ec4cd23f486bf70efcc5d2caa40f24368f752e3.zip"],
patches = [ strip_prefix = "googletest-4ec4cd23f486bf70efcc5d2caa40f24368f752e3",
# fix for https://github.com/google/googletest/issues/2817 sha256 = "de682ea824bfffba05b4e33b67431c247397d6175962534305136aa06f92e049",
"@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff"
],
patch_args = [
"-p1",
],
strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e",
sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895",
) )
# Google Benchmark library. # Google Benchmark library.
@ -353,9 +346,9 @@ maven_install(
"com.google.android.material:material:aar:1.0.0-rc01", "com.google.android.material:material:aar:1.0.0-rc01",
"com.google.auto.value:auto-value:1.8.1", "com.google.auto.value:auto-value:1.8.1",
"com.google.auto.value:auto-value-annotations:1.8.1", "com.google.auto.value:auto-value-annotations:1.8.1",
"com.google.code.findbugs:jsr305:3.0.2", "com.google.code.findbugs:jsr305:latest.release",
"com.google.flogger:flogger-system-backend:0.3.1", "com.google.flogger:flogger-system-backend:latest.release",
"com.google.flogger:flogger:0.3.1", "com.google.flogger:flogger:latest.release",
"com.google.guava:guava:27.0.1-android", "com.google.guava:guava:27.0.1-android",
"com.google.guava:listenablefuture:1.0", "com.google.guava:listenablefuture:1.0",
"junit:junit:4.12", "junit:junit:4.12",

View File

@ -113,9 +113,9 @@ each project.
androidTestImplementation 'androidx.test.ext:junit:1.1.0' androidTestImplementation 'androidx.test.ext:junit:1.1.0'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1' androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
// MediaPipe deps // MediaPipe deps
implementation 'com.google.flogger:flogger:0.3.1' implementation 'com.google.flogger:flogger:latest.release'
implementation 'com.google.flogger:flogger-system-backend:0.3.1' implementation 'com.google.flogger:flogger-system-backend:latest.release'
implementation 'com.google.code.findbugs:jsr305:3.0.2' implementation 'com.google.code.findbugs:jsr305:latest.release'
implementation 'com.google.guava:guava:27.0.1-android' implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4' implementation 'com.google.protobuf:protobuf-java:3.11.4'
// CameraX core library // CameraX core library

View File

@ -31,8 +31,8 @@ stream on an Android device.
## Setup ## Setup
1. Install MediaPipe on your system, see [MediaPipe installation guide] for 1. Install MediaPipe on your system, see
details. [MediaPipe installation guide](./install.md) for details.
2. Install Android Development SDK and Android NDK. See how to do so also in 2. Install Android Development SDK and Android NDK. See how to do so also in
[MediaPipe installation guide]. [MediaPipe installation guide].
3. Enable [developer options] on your Android device. 3. Enable [developer options] on your Android device.
@ -770,7 +770,6 @@ If you ran into any issues, please see the full code of the tutorial
[`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java [`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java
[`FrameLayout`]:https://developer.android.com/reference/android/widget/FrameLayout [`FrameLayout`]:https://developer.android.com/reference/android/widget/FrameLayout
[`FrameProcessor`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java [`FrameProcessor`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java
[MediaPipe installation guide]:./install.md
[`PermissionHelper`]: https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java [`PermissionHelper`]: https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java
[`SurfaceHolder.Callback`]:https://developer.android.com/reference/android/view/SurfaceHolder.Callback.html [`SurfaceHolder.Callback`]:https://developer.android.com/reference/android/view/SurfaceHolder.Callback.html
[`SurfaceView`]:https://developer.android.com/reference/android/view/SurfaceView [`SurfaceView`]:https://developer.android.com/reference/android/view/SurfaceView

View File

@ -31,8 +31,8 @@ stream on an iOS device.
## Setup ## Setup
1. Install MediaPipe on your system, see [MediaPipe installation guide] for 1. Install MediaPipe on your system, see
details. [MediaPipe installation guide](./install.md) for details.
2. Setup your iOS device for development. 2. Setup your iOS device for development.
3. Setup [Bazel] on your system to build and deploy the iOS app. 3. Setup [Bazel] on your system to build and deploy the iOS app.
@ -560,6 +560,5 @@ appropriate `BUILD` file dependencies for the edge detection graph.
[Bazel]:https://bazel.build/ [Bazel]:https://bazel.build/
[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt [`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt
[MediaPipe installation guide]:./install.md
[common]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common) [common]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common)
[helloworld]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld) [helloworld]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld)

View File

@ -22,6 +22,7 @@ Solution | NPM Package | Example
[Face Detection][Fd-pg] | [@mediapipe/face_detection][Fd-npm] | [mediapipe.dev/demo/face_detection][Fd-demo] [Face Detection][Fd-pg] | [@mediapipe/face_detection][Fd-npm] | [mediapipe.dev/demo/face_detection][Fd-demo]
[Hands][H-pg] | [@mediapipe/hands][H-npm] | [mediapipe.dev/demo/hands][H-demo] [Hands][H-pg] | [@mediapipe/hands][H-npm] | [mediapipe.dev/demo/hands][H-demo]
[Holistic][Ho-pg] | [@mediapipe/holistic][Ho-npm] | [mediapipe.dev/demo/holistic][Ho-demo] [Holistic][Ho-pg] | [@mediapipe/holistic][Ho-npm] | [mediapipe.dev/demo/holistic][Ho-demo]
[Objectron][Ob-pg] | [@mediapipe/objectron][Ob-npm] | [mediapipe.dev/demo/objectron][Ob-demo]
[Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo] [Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo]
[Selfie Segmentation][S-pg] | [@mediapipe/selfie_segmentation][S-npm] | [mediapipe.dev/demo/selfie_segmentation][S-demo] [Selfie Segmentation][S-pg] | [@mediapipe/selfie_segmentation][S-npm] | [mediapipe.dev/demo/selfie_segmentation][S-demo]
@ -67,33 +68,24 @@ affecting your work, restrict your request to a `<minor>` number. e.g.,
[F-pg]: ../solutions/face_mesh#javascript-solution-api [F-pg]: ../solutions/face_mesh#javascript-solution-api
[Fd-pg]: ../solutions/face_detection#javascript-solution-api [Fd-pg]: ../solutions/face_detection#javascript-solution-api
[H-pg]: ../solutions/hands#javascript-solution-api [H-pg]: ../solutions/hands#javascript-solution-api
[Ob-pg]: ../solutions/objectron#javascript-solution-api
[P-pg]: ../solutions/pose#javascript-solution-api [P-pg]: ../solutions/pose#javascript-solution-api
[S-pg]: ../solutions/selfie_segmentation#javascript-solution-api [S-pg]: ../solutions/selfie_segmentation#javascript-solution-api
[Ho-npm]: https://www.npmjs.com/package/@mediapipe/holistic [Ho-npm]: https://www.npmjs.com/package/@mediapipe/holistic
[F-npm]: https://www.npmjs.com/package/@mediapipe/face_mesh [F-npm]: https://www.npmjs.com/package/@mediapipe/face_mesh
[Fd-npm]: https://www.npmjs.com/package/@mediapipe/face_detection [Fd-npm]: https://www.npmjs.com/package/@mediapipe/face_detection
[H-npm]: https://www.npmjs.com/package/@mediapipe/hands [H-npm]: https://www.npmjs.com/package/@mediapipe/hands
[Ob-npm]: https://www.npmjs.com/package/@mediapipe/objectron
[P-npm]: https://www.npmjs.com/package/@mediapipe/pose [P-npm]: https://www.npmjs.com/package/@mediapipe/pose
[S-npm]: https://www.npmjs.com/package/@mediapipe/selfie_segmentation [S-npm]: https://www.npmjs.com/package/@mediapipe/selfie_segmentation
[draw-npm]: https://www.npmjs.com/package/@mediapipe/drawing_utils [draw-npm]: https://www.npmjs.com/package/@mediapipe/drawing_utils
[cam-npm]: https://www.npmjs.com/package/@mediapipe/camera_utils [cam-npm]: https://www.npmjs.com/package/@mediapipe/camera_utils
[ctrl-npm]: https://www.npmjs.com/package/@mediapipe/control_utils [ctrl-npm]: https://www.npmjs.com/package/@mediapipe/control_utils
[Ho-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/holistic
[F-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_mesh
[Fd-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_detection
[H-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/hands
[P-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/pose
[P-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/selfie_segmentation
[Ho-pen]: https://code.mediapipe.dev/codepen/holistic
[F-pen]: https://code.mediapipe.dev/codepen/face_mesh
[Fd-pen]: https://code.mediapipe.dev/codepen/face_detection
[H-pen]: https://code.mediapipe.dev/codepen/hands
[P-pen]: https://code.mediapipe.dev/codepen/pose
[S-pen]: https://code.mediapipe.dev/codepen/selfie_segmentation
[Ho-demo]: https://mediapipe.dev/demo/holistic [Ho-demo]: https://mediapipe.dev/demo/holistic
[F-demo]: https://mediapipe.dev/demo/face_mesh [F-demo]: https://mediapipe.dev/demo/face_mesh
[Fd-demo]: https://mediapipe.dev/demo/face_detection [Fd-demo]: https://mediapipe.dev/demo/face_detection
[H-demo]: https://mediapipe.dev/demo/hands [H-demo]: https://mediapipe.dev/demo/hands
[Ob-demo]: https://mediapipe.dev/demo/objectron
[P-demo]: https://mediapipe.dev/demo/pose [P-demo]: https://mediapipe.dev/demo/pose
[S-demo]: https://mediapipe.dev/demo/selfie_segmentation [S-demo]: https://mediapipe.dev/demo/selfie_segmentation
[npm]: https://www.npmjs.com/package/@mediapipe [npm]: https://www.npmjs.com/package/@mediapipe

View File

@ -74,7 +74,7 @@ Mapping\[str, Packet\] | std::map<std::string, Packet> | create_st
np.ndarray<br>(cv.mat and PIL.Image) | mp::ImageFrame | create_image_frame(<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;format=ImageFormat.SRGB,<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;data=mat) | get_image_frame(packet) np.ndarray<br>(cv.mat and PIL.Image) | mp::ImageFrame | create_image_frame(<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;format=ImageFormat.SRGB,<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;data=mat) | get_image_frame(packet)
np.ndarray | mp::Matrix | create_matrix(data) | get_matrix(packet) np.ndarray | mp::Matrix | create_matrix(data) | get_matrix(packet)
Google Proto Message | Google Proto Message | create_proto(proto) | get_proto(packet) Google Proto Message | Google Proto Message | create_proto(proto) | get_proto(packet)
List\[Proto\] | std::vector\<Proto\> | create_proto_vector(proto_list) | get_proto_list(packet) List\[Proto\] | std::vector\<Proto\> | n/a | get_proto_list(packet)
It's not uncommon that users create custom C++ classes and and send those into It's not uncommon that users create custom C++ classes and and send those into
the graphs and calculators. To allow the custom classes to be used in Python the graphs and calculators. To allow the custom classes to be used in Python

View File

@ -45,7 +45,7 @@ Hair Segmentation
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | [Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |
@ -79,6 +79,13 @@ run code search using
## Publications ## Publications
* [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html)
in Google Developers Blog
* [Prosthesis control via Mirru App using MediaPipe hand tracking](https://developers.googleblog.com/2021/05/control-your-mirru-prosthesis-with-mediapipe-hand-tracking.html)
in Google Developers Blog
* [SignAll SDK: Sign language interface using MediaPipe is now available for
developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html)
in Google Developers Blog
* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) * [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
in Google AI Blog in Google AI Blog
* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html)

View File

@ -224,29 +224,33 @@ where object detection simply runs on every image. Default to `0.99`.
#### model_name #### model_name
Name of the model to use for predicting 3D bounding box landmarks. Currently supports Name of the model to use for predicting 3D bounding box landmarks. Currently
`{'Shoe', 'Chair', 'Cup', 'Camera'}`. supports `{'Shoe', 'Chair', 'Cup', 'Camera'}`. Default to `Shoe`.
#### focal_length #### focal_length
Camera focal length `(fx, fy)`, by default is defined in By default, camera focal length defined in [NDC space](#ndc-space), i.e., `(fx,
[NDC space](#ndc-space). To use focal length `(fx_pixel, fy_pixel)` in fy)`. Default to `(1.0, 1.0)`. To specify focal length in
[pixel space](#pixel-space), users should provide `image_size` = `(image_width, [pixel space](#pixel-space) instead, i.e., `(fx_pixel, fy_pixel)`, users should
image_height)` to enable conversions inside the API. For further details about provide [`image_size`](#image_size) = `(image_width, image_height)` to enable
NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). conversions inside the API. For further details about NDC and pixel space,
please see [Coordinate Systems](#coordinate-systems).
#### principal_point #### principal_point
Camera principal point `(px, py)`, by default is defined in By default, camera principal point defined in [NDC space](#ndc-space), i.e.,
[NDC space](#ndc-space). To use principal point `(px_pixel, py_pixel)` in `(px, py)`. Default to `(0.0, 0.0)`. To specify principal point in
[pixel space](#pixel-space), users should provide `image_size` = `(image_width, [pixel space](#pixel-space), i.e.,`(px_pixel, py_pixel)`, users should provide
image_height)` to enable conversions inside the API. For further details about [`image_size`](#image_size) = `(image_width, image_height)` to enable
NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). conversions inside the API. For further details about NDC and pixel space,
please see [Coordinate Systems](#coordinate-systems).
#### image_size #### image_size
(**Optional**) size `(image_width, image_height)` of the input image, **ONLY** **Specify only when [`focal_length`](#focal_length) and
needed when use `focal_length` and `principal_point` in pixel space. [`principal_point`](#principal_point) are specified in pixel space.**
Size of the input image, i.e., `(image_width, image_height)`.
### Output ### Output
@ -356,6 +360,89 @@ with mp_objectron.Objectron(static_image_mode=False,
cap.release() cap.release()
``` ```
## JavaScript Solution API
Please first see general [introduction](../getting_started/javascript.md) on
MediaPipe in JavaScript, then learn more in the companion [web demo](#resources)
and the following usage example.
Supported configuration options:
* [staticImageMode](#static_image_mode)
* [maxNumObjects](#max_num_objects)
* [minDetectionConfidence](#min_detection_confidence)
* [minTrackingConfidence](#min_tracking_confidence)
* [modelName](#model_name)
* [focalLength](#focal_length)
* [principalPoint](#principal_point)
* [imageSize](#image_size)
```html
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/camera_utils/camera_utils.js" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/control_utils/control_utils.js" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/drawing_utils/control_utils_3d.js" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/drawing_utils/drawing_utils.js" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/objectron/objectron.js" crossorigin="anonymous"></script>
</head>
<body>
<div class="container">
<video class="input_video"></video>
<canvas class="output_canvas" width="1280px" height="720px"></canvas>
</div>
</body>
</html>
```
```javascript
<script type="module">
const videoElement = document.getElementsByClassName('input_video')[0];
const canvasElement = document.getElementsByClassName('output_canvas')[0];
const canvasCtx = canvasElement.getContext('2d');
function onResults(results) {
canvasCtx.save();
canvasCtx.drawImage(
results.image, 0, 0, canvasElement.width, canvasElement.height);
if (!!results.objectDetections) {
for (const detectedObject of results.objectDetections) {
// Reformat keypoint information as landmarks, for easy drawing.
const landmarks: mpObjectron.Point2D[] =
detectedObject.keypoints.map(x => x.point2d);
// Draw bounding box.
drawingUtils.drawConnectors(canvasCtx, landmarks,
mpObjectron.BOX_CONNECTIONS, {color: '#FF0000'});
// Draw centroid.
drawingUtils.drawLandmarks(canvasCtx, [landmarks[0]], {color: '#FFFFFF'});
}
}
canvasCtx.restore();
}
const objectron = new Objectron({locateFile: (file) => {
return `https://cdn.jsdelivr.net/npm/@mediapipe/objectron/${file}`;
}});
objectron.setOptions({
modelName: 'Chair',
maxNumObjects: 3,
});
objectron.onResults(onResults);
const camera = new Camera(videoElement, {
onFrame: async () => {
await objectron.send({image: videoElement});
},
width: 1280,
height: 720
});
camera.start();
</script>
```
## Example Apps ## Example Apps
Please first see general instructions for Please first see general instructions for
@ -561,11 +648,15 @@ py = -py_pixel * 2.0 / image_height + 1.0
[Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html) [Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html)
* Google AI Blog: * Google AI Blog:
[Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html) [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in CVPR 2021 * Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the
Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in
CVPR 2021
* Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak * Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
Shape Supervision](https://arxiv.org/abs/2003.03522) Shape Supervision](https://arxiv.org/abs/2003.03522)
* Paper: * Paper:
[Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8) [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth Workshop on Computer Vision for AR/VR, CVPR 2020 ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth
Workshop on Computer Vision for AR/VR, CVPR 2020
* [Models and model cards](./models.md#objectron) * [Models and model cards](./models.md#objectron)
* [Web demo](https://code.mediapipe.dev/codepen/objectron)
* [Python Colab](https://mediapipe.page.link/objectron_py_colab) * [Python Colab](https://mediapipe.page.link/objectron_py_colab)

View File

@ -96,6 +96,7 @@ Supported configuration options:
```python ```python
import cv2 import cv2
import mediapipe as mp import mediapipe as mp
import numpy as np
mp_drawing = mp.solutions.drawing_utils mp_drawing = mp.solutions.drawing_utils
mp_selfie_segmentation = mp.solutions.selfie_segmentation mp_selfie_segmentation = mp.solutions.selfie_segmentation

View File

@ -29,7 +29,7 @@ has_toc: false
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | | [Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |

View File

@ -140,6 +140,16 @@ mediapipe_proto_library(
], ],
) )
mediapipe_proto_library(
name = "graph_profile_calculator_proto",
srcs = ["graph_profile_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library( cc_library(
name = "add_header_calculator", name = "add_header_calculator",
srcs = ["add_header_calculator.cc"], srcs = ["add_header_calculator.cc"],
@ -1200,3 +1210,45 @@ cc_test(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
cc_library(
name = "graph_profile_calculator",
srcs = ["graph_profile_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":graph_profile_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "graph_profile_calculator_test",
srcs = ["graph_profile_calculator_test.cc"],
deps = [
":graph_profile_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:test_calculators",
"//mediapipe/framework/deps:clock",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:threadpool",
"//mediapipe/framework/tool:simulation_clock_executor",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
],
)

View File

@ -0,0 +1,70 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/calculators/core/graph_profile_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// This calculator periodically copies the GraphProfile from
// mediapipe::GraphProfiler::CaptureProfile to the "PROFILE" output stream.
//
// Example config:
// node {
// calculator: "GraphProfileCalculator"
// output_stream: "FRAME:any_frame"
// output_stream: "PROFILE:graph_profile"
// }
//
class GraphProfileCalculator : public Node {
public:
static constexpr Input<AnyType>::Multiple kFrameIn{"FRAME"};
static constexpr Output<GraphProfile> kProfileOut{"PROFILE"};
MEDIAPIPE_NODE_CONTRACT(kFrameIn, kProfileOut);
static absl::Status UpdateContract(CalculatorContract* cc) {
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
auto options = cc->Options<::mediapipe::GraphProfileCalculatorOptions>();
if (prev_profile_ts_ == Timestamp::Unset() ||
cc->InputTimestamp() - prev_profile_ts_ >= options.profile_interval()) {
prev_profile_ts_ = cc->InputTimestamp();
GraphProfile result;
MP_RETURN_IF_ERROR(cc->GetProfilingContext()->CaptureProfile(&result));
kProfileOut(cc).Send(result);
}
return absl::OkStatus();
}
private:
Timestamp prev_profile_ts_;
};
MEDIAPIPE_REGISTER_NODE(GraphProfileCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,30 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
message GraphProfileCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional GraphProfileCalculatorOptions ext = 367481815;
}
// The interval in microseconds between successive reported GraphProfiles.
optional int64 profile_interval = 1 [default = 1000000];
}

View File

@ -0,0 +1,207 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/port/threadpool.h"
#include "mediapipe/framework/tool/simulation_clock_executor.h"
// Tests for GraphProfileCalculator.
using testing::ElementsAre;
namespace mediapipe {
namespace {
using mediapipe::Clock;
// A Calculator with a fixed Process call latency.
class SleepCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->SetTimestampOffset(TimestampDiff(0));
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
clock_->Sleep(absl::Milliseconds(5));
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
return absl::OkStatus();
}
std::shared_ptr<::mediapipe::Clock> clock_ = nullptr;
};
REGISTER_CALCULATOR(SleepCalculator);
// Tests showing GraphProfileCalculator reporting GraphProfile output packets.
class GraphProfileCalculatorTest : public ::testing::Test {
protected:
void SetUpProfileGraph() {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
node {
calculator: 'SleepCalculator'
input_side_packet: 'CLOCK:sync_clock'
input_stream: 'input_packets_0'
output_stream: 'output_packets_1'
}
node {
calculator: "GraphProfileCalculator"
options: {
[mediapipe.GraphProfileCalculatorOptions.ext]: {
profile_interval: 25000
}
}
input_stream: "FRAME:output_packets_1"
output_stream: "PROFILE:output_packets_0"
}
)",
&graph_config_));
}
static Packet PacketAt(int64 ts) {
return Adopt(new int64(999)).At(Timestamp(ts));
}
static Packet None() { return Packet().At(Timestamp::OneOverPostStream()); }
static bool IsNone(const Packet& packet) {
return packet.Timestamp() == Timestamp::OneOverPostStream();
}
// Return the values of the timestamps of a vector of Packets.
static std::vector<int64> TimestampValues(
const std::vector<Packet>& packets) {
std::vector<int64> result;
for (const Packet& p : packets) {
result.push_back(p.Timestamp().Value());
}
return result;
}
// Runs a CalculatorGraph with a series of packet sets.
// Returns a vector of packets from each graph output stream.
void RunGraph(const std::vector<std::vector<Packet>>& input_sets,
std::vector<Packet>* output_packets) {
// Register output packet observers.
tool::AddVectorSink("output_packets_0", &graph_config_, output_packets);
// Start running the graph.
std::shared_ptr<SimulationClockExecutor> executor(
new SimulationClockExecutor(3 /*num_threads*/));
CalculatorGraph graph;
MP_ASSERT_OK(graph.SetExecutor("", executor));
graph.profiler()->SetClock(executor->GetClock());
MP_ASSERT_OK(graph.Initialize(graph_config_));
executor->GetClock()->ThreadStart();
MP_ASSERT_OK(graph.StartRun({
{"sync_clock",
Adopt(new std::shared_ptr<::mediapipe::Clock>(executor->GetClock()))},
}));
// Send each packet to the graph in the specified order.
for (int t = 0; t < input_sets.size(); t++) {
const std::vector<Packet>& input_set = input_sets[t];
for (int i = 0; i < input_set.size(); i++) {
const Packet& packet = input_set[i];
if (!IsNone(packet)) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
absl::StrCat("input_packets_", i), packet));
}
executor->GetClock()->Sleep(absl::Milliseconds(10));
}
}
MP_ASSERT_OK(graph.CloseAllInputStreams());
executor->GetClock()->Sleep(absl::Milliseconds(100));
executor->GetClock()->ThreadFinish();
MP_ASSERT_OK(graph.WaitUntilDone());
}
CalculatorGraphConfig graph_config_;
};
TEST_F(GraphProfileCalculatorTest, GraphProfile) {
SetUpProfileGraph();
auto profiler_config = graph_config_.mutable_profiler_config();
profiler_config->set_enable_profiler(true);
profiler_config->set_trace_enabled(false);
profiler_config->set_trace_log_disabled(true);
profiler_config->set_enable_stream_latency(true);
profiler_config->set_calculator_filter(".*Calculator");
// Run the graph with a series of packet sets.
std::vector<std::vector<Packet>> input_sets = {
{PacketAt(10000)}, //
{PacketAt(20000)}, //
{PacketAt(30000)}, //
{PacketAt(40000)},
};
std::vector<Packet> output_packets;
RunGraph(input_sets, &output_packets);
// Validate the output packets.
EXPECT_THAT(TimestampValues(output_packets), //
ElementsAre(10000, 40000));
GraphProfile expected_profile =
mediapipe::ParseTextProtoOrDie<GraphProfile>(R"pb(
calculator_profiles {
name: "GraphProfileCalculator"
open_runtime: 0
process_runtime { total: 0 count: 3 }
process_input_latency { total: 15000 count: 3 }
process_output_latency { total: 15000 count: 3 }
input_stream_profiles {
name: "output_packets_1"
back_edge: false
latency { total: 0 count: 3 }
}
}
calculator_profiles {
name: "SleepCalculator"
open_runtime: 0
process_runtime { total: 15000 count: 3 }
process_input_latency { total: 0 count: 3 }
process_output_latency { total: 15000 count: 3 }
input_stream_profiles {
name: "input_packets_0"
back_edge: false
latency { total: 0 count: 3 }
}
})pb");
EXPECT_THAT(output_packets[1].Get<GraphProfile>(),
mediapipe::EqualsProto(expected_profile));
}
} // namespace
} // namespace mediapipe

View File

@ -240,7 +240,7 @@ absl::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) {
auto input_mat = mediapipe::formats::MatView(&input_frame); auto input_mat = mediapipe::formats::MatView(&input_frame);
// Only 1 or 3 channel images supported by OpenCV. // Only 1 or 3 channel images supported by OpenCV.
if ((input_mat.channels() == 1 || input_mat.channels() == 3)) { if (!(input_mat.channels() == 1 || input_mat.channels() == 3)) {
return absl::InternalError( return absl::InternalError(
"CPU filtering supports only 1 or 3 channel input images."); "CPU filtering supports only 1 or 3 channel input images.");
} }

View File

@ -36,7 +36,7 @@ using GpuBuffer = mediapipe::GpuBuffer;
// stored on the target storage (CPU vs GPU) specified in the calculator option. // stored on the target storage (CPU vs GPU) specified in the calculator option.
// //
// The clone shares ownership of the input pixel data on the existing storage. // The clone shares ownership of the input pixel data on the existing storage.
// If the target storage is diffrent from the existing one, then the data is // If the target storage is different from the existing one, then the data is
// further copied there. // further copied there.
// //
// Example usage: // Example usage:

View File

@ -33,7 +33,7 @@ class InferenceCalculatorSelectorImpl
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
const CalculatorGraphConfig::Node& subgraph_node) { const CalculatorGraphConfig::Node& subgraph_node) {
const auto& options = const auto& options =
Subgraph::GetOptions<::mediapipe::InferenceCalculatorOptions>( Subgraph::GetOptions<mediapipe::InferenceCalculatorOptions>(
subgraph_node); subgraph_node);
std::vector<absl::string_view> impls; std::vector<absl::string_view> impls;
const bool should_use_gpu = const bool should_use_gpu =

View File

@ -99,8 +99,13 @@ class InferenceCalculator : public NodeIntf {
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"}; static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"}; static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
static constexpr SideInput<std::string>::Optional kNnApiDelegateCacheDir{
"NNAPI_CACHE_DIR"};
static constexpr SideInput<std::string>::Optional kNnApiDelegateModelToken{
"NNAPI_MODEL_TOKEN"};
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel, MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
kOutTensors); kOutTensors, kNnApiDelegateCacheDir,
kNnApiDelegateModelToken);
protected: protected:
using TfLiteDelegatePtr = using TfLiteDelegatePtr =

View File

@ -67,9 +67,32 @@ message InferenceCalculatorOptions {
// Only available for OpenCL delegate on Android. // Only available for OpenCL delegate on Android.
// Kernel caching will only be enabled if this path is set. // Kernel caching will only be enabled if this path is set.
optional string cached_kernel_path = 2; optional string cached_kernel_path = 2;
// Encapsulated compilation/runtime tradeoffs.
enum InferenceUsage {
UNSPECIFIED = 0;
// InferenceRunner will be used only once. Therefore, it is important to
// minimize bootstrap time as well.
FAST_SINGLE_ANSWER = 1;
// Prefer maximizing the throughput. Same inference runner will be used
// repeatedly on different inputs.
SUSTAINED_SPEED = 2;
} }
optional InferenceUsage usage = 5 [default = SUSTAINED_SPEED];
}
// Android only. // Android only.
message Nnapi {} message Nnapi {
// Directory to store compilation cache. If unspecified, NNAPI will not
// try caching the compilation.
optional string cache_dir = 1;
// Unique token identifying the model. It is the caller's responsibility
// to ensure there is no clash of the tokens. If unspecified, NNAPI will
// not try caching the compilation.
optional string model_token = 2;
}
message Xnnpack { message Xnnpack {
// Number of threads for XNNPACK delegate. (By default, calculator tries // Number of threads for XNNPACK delegate. (By default, calculator tries
// to choose optimal number of threads depending on the device.) // to choose optimal number of threads depending on the device.)

View File

@ -181,9 +181,21 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
// Attempt to use NNAPI. // Attempt to use NNAPI.
// If not supported, the default CPU delegate will be created and used. // If not supported, the default CPU delegate will be created and used.
interpreter_->SetAllowFp16PrecisionForFp32(1); interpreter_->SetAllowFp16PrecisionForFp32(1);
delegate_ = TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { tflite::StatefulNnApiDelegate::Options options;
// No need to free according to tflite::NnApiDelegate() documentation. const auto& nnapi = calculator_opts.delegate().nnapi();
}); // Set up cache_dir and model_token for NNAPI compilation cache.
options.cache_dir =
nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr;
if (!kNnApiDelegateCacheDir(cc).IsEmpty()) {
options.cache_dir = kNnApiDelegateCacheDir(cc).Get().c_str();
}
options.model_token =
nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr;
if (!kNnApiDelegateModelToken(cc).IsEmpty()) {
options.model_token = kNnApiDelegateModelToken(cc).Get().c_str();
}
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
[](TfLiteDelegate*) {});
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk); kTfLiteOk);
return absl::OkStatus(); return absl::OkStatus();

View File

@ -18,6 +18,7 @@
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/util/tflite/config.h" #include "mediapipe/util/tflite/config.h"
@ -65,6 +66,8 @@ class InferenceCalculatorGlImpl
bool allow_precision_loss_ = false; bool allow_precision_loss_ = false;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api
tflite_gpu_runner_api_; tflite_gpu_runner_api_;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage
tflite_gpu_runner_usage_;
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED #if MEDIAPIPE_TFLITE_GPU_SUPPORTED
@ -96,6 +99,7 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
options.delegate().gpu().use_advanced_gpu_api(); options.delegate().gpu().use_advanced_gpu_api();
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
tflite_gpu_runner_api_ = options.delegate().gpu().api(); tflite_gpu_runner_api_ = options.delegate().gpu().api();
tflite_gpu_runner_usage_ = options.delegate().gpu().usage();
use_kernel_caching_ = use_advanced_gpu_api_ && use_kernel_caching_ = use_advanced_gpu_api_ &&
options.delegate().gpu().has_cached_kernel_path(); options.delegate().gpu().has_cached_kernel_path();
use_gpu_delegate_ = !use_advanced_gpu_api_; use_gpu_delegate_ = !use_advanced_gpu_api_;
@ -253,9 +257,27 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
: tflite::gpu::InferencePriority::MAX_PRECISION; : tflite::gpu::InferencePriority::MAX_PRECISION;
options.priority2 = tflite::gpu::InferencePriority::AUTO; options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO; options.priority3 = tflite::gpu::InferencePriority::AUTO;
switch (tflite_gpu_runner_usage_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
FAST_SINGLE_ANSWER: {
options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
SUSTAINED_SPEED: {
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: {
return absl::InternalError("inference usage need to be specified.");
}
}
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options); tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
switch (tflite_gpu_runner_api_) { switch (tflite_gpu_runner_api_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
// Do not need to force any specific API.
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: { case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: {
tflite_gpu_runner_->ForceOpenGL(); tflite_gpu_runner_->ForceOpenGL();
break; break;
@ -264,10 +286,6 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
tflite_gpu_runner_->ForceOpenCL(); tflite_gpu_runner_->ForceOpenCL();
break; break;
} }
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
// Do not need to force any specific API.
break;
}
} }
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true)); model, op_resolver, /*allow_quant_ops=*/true));

View File

@ -864,6 +864,7 @@ cc_test(
"//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",

View File

@ -243,11 +243,6 @@ class PackMediaSequenceCalculator : public CalculatorBase {
} }
} }
if (cc->Outputs().HasTag(kSequenceExampleTag)) {
cc->Outputs()
.Tag(kSequenceExampleTag)
.SetNextTimestampBound(Timestamp::Max());
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -305,7 +300,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
if (cc->Outputs().HasTag(kSequenceExampleTag)) { if (cc->Outputs().HasTag(kSequenceExampleTag)) {
cc->Outputs() cc->Outputs()
.Tag(kSequenceExampleTag) .Tag(kSequenceExampleTag)
.Add(sequence_.release(), Timestamp::PostStream()); .Add(sequence_.release(), options.output_as_zero_timestamp()
? Timestamp(0ll)
: Timestamp::PostStream());
} }
sequence_.reset(); sequence_.reset();

View File

@ -65,4 +65,7 @@ message PackMediaSequenceCalculatorOptions {
// If true, will return an error status if an output sequence would be too // If true, will return an error status if an output sequence would be too
// many bytes to serialize. // many bytes to serialize.
optional bool skip_large_sequences = 7 [default = true]; optional bool skip_large_sequences = 7 [default = true];
// If true/false, outputs the SequenceExample at timestamp 0/PostStream.
optional bool output_as_zero_timestamp = 8 [default = false];
} }

View File

@ -29,6 +29,7 @@
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/sequence/media_sequence.h" #include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/example/feature.pb.h"
@ -43,8 +44,9 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test {
protected: protected:
void SetUpCalculator(const std::vector<std::string>& input_streams, void SetUpCalculator(const std::vector<std::string>& input_streams,
const tf::Features& features, const tf::Features& features,
bool output_only_if_all_present, const bool output_only_if_all_present,
bool replace_instead_of_append) { const bool replace_instead_of_append,
const bool output_as_zero_timestamp = false) {
CalculatorGraphConfig::Node config; CalculatorGraphConfig::Node config;
config.set_calculator("PackMediaSequenceCalculator"); config.set_calculator("PackMediaSequenceCalculator");
config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence"); config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence");
@ -57,6 +59,7 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test {
*options->mutable_context_feature_map() = features; *options->mutable_context_feature_map() = features;
options->set_output_only_if_all_present(output_only_if_all_present); options->set_output_only_if_all_present(output_only_if_all_present);
options->set_replace_data_instead_of_append(replace_instead_of_append); options->set_replace_data_instead_of_append(replace_instead_of_append);
options->set_output_as_zero_timestamp(output_as_zero_timestamp);
runner_ = ::absl::make_unique<CalculatorRunner>(config); runner_ = ::absl::make_unique<CalculatorRunner>(config);
} }
@ -194,6 +197,29 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) {
} }
} }
TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) {
SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
int num_timesteps = 2;
for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag("FLOAT_FEATURE_TEST")
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
EXPECT_EQ(output_packets[0].Timestamp().Value(), 0ll);
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
SetUpCalculator( SetUpCalculator(
{"FLOAT_CONTEXT_FEATURE_TEST:test", "FLOAT_CONTEXT_FEATURE_OTHER:test2"}, {"FLOAT_CONTEXT_FEATURE_TEST:test", "FLOAT_CONTEXT_FEATURE_OTHER:test2"},

View File

@ -292,6 +292,8 @@ class TfLiteInferenceCalculator : public CalculatorBase {
bool allow_precision_loss_ = false; bool allow_precision_loss_ = false;
mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::Api mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::Api
tflite_gpu_runner_api_; tflite_gpu_runner_api_;
mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::InferenceUsage
tflite_gpu_runner_usage_;
bool use_kernel_caching_ = false; bool use_kernel_caching_ = false;
std::string cached_kernel_filename_; std::string cached_kernel_filename_;
@ -377,6 +379,7 @@ absl::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) {
options.delegate().gpu().use_advanced_gpu_api(); options.delegate().gpu().use_advanced_gpu_api();
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
tflite_gpu_runner_api_ = options.delegate().gpu().api(); tflite_gpu_runner_api_ = options.delegate().gpu().api();
tflite_gpu_runner_usage_ = options.delegate().gpu().usage();
use_kernel_caching_ = use_advanced_gpu_api_ && use_kernel_caching_ = use_advanced_gpu_api_ &&
options.delegate().gpu().has_cached_kernel_path(); options.delegate().gpu().has_cached_kernel_path();
@ -733,7 +736,23 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
: tflite::gpu::InferencePriority::MAX_PRECISION; : tflite::gpu::InferencePriority::MAX_PRECISION;
options.priority2 = tflite::gpu::InferencePriority::AUTO; options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO; options.priority3 = tflite::gpu::InferencePriority::AUTO;
switch (tflite_gpu_runner_usage_) {
case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::
FAST_SINGLE_ANSWER: {
options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
break;
}
case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::
SUSTAINED_SPEED: {
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
break;
}
case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::
UNSPECIFIED: {
return absl::InternalError("inference usage need to be specified.");
}
}
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options); tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
switch (tflite_gpu_runner_api_) { switch (tflite_gpu_runner_api_) {
case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::OPENGL: { case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::OPENGL: {
@ -878,11 +897,15 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
// Attempt to use NNAPI. // Attempt to use NNAPI.
// If not supported, the default CPU delegate will be created and used. // If not supported, the default CPU delegate will be created and used.
interpreter_->SetAllowFp16PrecisionForFp32(1); interpreter_->SetAllowFp16PrecisionForFp32(1);
delegate_ = tflite::StatefulNnApiDelegate::Options options;
TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { const auto& nnapi = calculator_opts.delegate().nnapi();
// No need to free according to tflite::NnApiDelegate() // Set up cache_dir and model_token for NNAPI compilation cache.
// documentation. if (nnapi.has_cache_dir() && nnapi.has_model_token()) {
}); options.cache_dir = nnapi.cache_dir().c_str();
options.model_token = nnapi.model_token().c_str();
}
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
[](TfLiteDelegate*) {});
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk); kTfLiteOk);
return absl::OkStatus(); return absl::OkStatus();

View File

@ -67,9 +67,31 @@ message TfLiteInferenceCalculatorOptions {
// Only available for OpenCL delegate on Android. // Only available for OpenCL delegate on Android.
// Kernel caching will only be enabled if this path is set. // Kernel caching will only be enabled if this path is set.
optional string cached_kernel_path = 2; optional string cached_kernel_path = 2;
// Encapsulated compilation/runtime tradeoffs.
enum InferenceUsage {
UNSPECIFIED = 0;
// InferenceRunner will be used only once. Therefore, it is important to
// minimize bootstrap time as well.
FAST_SINGLE_ANSWER = 1;
// Prefer maximizing the throughput. Same inference runner will be used
// repeatedly on different inputs.
SUSTAINED_SPEED = 2;
}
optional InferenceUsage usage = 5 [default = SUSTAINED_SPEED];
} }
// Android only. // Android only.
message Nnapi {} message Nnapi {
// Directory to store compilation cache. If unspecified, NNAPI will not
// try caching the compilation.
optional string cache_dir = 1;
// Unique token identifying the model. It is the caller's responsibility
// to ensure there is no clash of the tokens. If unspecified, NNAPI will
// not try caching the compilation.
optional string model_token = 2;
}
message Xnnpack { message Xnnpack {
// Number of threads for XNNPACK delegate. (By default, calculator tries // Number of threads for XNNPACK delegate. (By default, calculator tries
// to choose optimal number of threads depending on the device.) // to choose optimal number of threads depending on the device.)

View File

@ -0,0 +1,21 @@
# Copyright 2021 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
filegroup(
name = "resource_files",
srcs = glob(["res/**"]),
visibility = ["//mediapipe/examples/android/solutions:__subpackages__"],
)

View File

@ -0,0 +1,24 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
buildscript {
repositories {
google()
mavenCentral()
}
dependencies {
classpath "com.android.tools.build:gradle:4.2.0"
// NOTE: Do not place your application dependencies here; they belong
// in the individual module build.gradle files
}
}
allprojects {
repositories {
google()
mavenCentral()
}
}
task clean(type: Delete) {
delete rootProject.buildDir
}

View File

@ -0,0 +1,17 @@
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app"s APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true

View File

@ -0,0 +1,5 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

185
mediapipe/examples/android/solutions/gradlew vendored Executable file
View File

@ -0,0 +1,185 @@
#!/usr/bin/env sh
#
# Copyright 2015 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
exec "$JAVACMD" "$@"

View File

@ -0,0 +1,89 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega

View File

@ -0,0 +1,50 @@
plugins {
id 'com.android.application'
}
android {
compileSdkVersion 30
buildToolsVersion "30.0.3"
defaultConfig {
applicationId "com.google.mediapipe.apps.hands"
minSdkVersion 21
targetSdkVersion 30
versionCode 1
versionName "1.0"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar'])
implementation 'androidx.appcompat:appcompat:1.3.0'
implementation 'com.google.android.material:material:1.3.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
testImplementation 'junit:junit:4.+'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
// MediaPipe hands solution API and solution-core.
implementation 'com.google.mediapipe:solution-core:latest.release'
implementation 'com.google.mediapipe:hands:latest.release'
// MediaPipe deps
implementation 'com.google.flogger:flogger:latest.release'
implementation 'com.google.flogger:flogger-system-backend:latest.release'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4'
// CameraX core library
def camerax_version = "1.0.0-beta10"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
implementation "androidx.camera:camera-lifecycle:$camerax_version"
}

View File

@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View File

@ -0,0 +1,31 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.examples.hands">
<uses-sdk
android:minSdkVersion="21"
android:targetSdkVersion="30" />
<!-- For loading images from gallery -->
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<!-- For using the camera -->
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<application
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="MediaPipe Hands"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<activity android:name=".MainActivity">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@ -0,0 +1,40 @@
# Copyright 2021 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
android_binary(
name = "hands",
srcs = glob(["**/*.java"]),
custom_package = "com.google.mediapipe.examples.hands",
manifest = "AndroidManifest.xml",
manifest_values = {
"applicationId": "com.google.mediapipe.examples.hands",
},
multidex = "native",
resource_files = ["//mediapipe/examples/android/solutions:resource_files"],
deps = [
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/solutioncore:camera_input",
"//mediapipe/java/com/google/mediapipe/solutioncore:mediapipe_jni_lib",
"//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering",
"//mediapipe/java/com/google/mediapipe/solutions/hands",
"//third_party:androidx_appcompat",
"//third_party:androidx_constraint_layout",
"@maven//:androidx_concurrent_concurrent_futures",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,129 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.examples.hands;
import android.opengl.GLES20;
import android.opengl.Matrix;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.solutioncore.ResultGlBoundary;
import com.google.mediapipe.solutioncore.ResultGlRenderer;
import com.google.mediapipe.solutions.hands.Hands;
import com.google.mediapipe.solutions.hands.HandsResult;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.List;
/** A custom implementation of {@link ResultGlRenderer} to render MediaPope Hands results. */
public class HandsResultGlRenderer implements ResultGlRenderer<HandsResult> {
private static final String TAG = "HandsResultGlRenderer";
private static final float CONNECTION_THICKNESS = 20.0f;
private static final String VERTEX_SHADER =
"uniform mat4 uTransformMatrix;\n"
+ "attribute vec4 vPosition;\n"
+ "void main() {\n"
+ " gl_Position = uTransformMatrix * vPosition;\n"
+ "}";
private static final String FRAGMENT_SHADER =
"precision mediump float;\n"
+ "void main() {\n"
+ " gl_FragColor = vec4(0, 1, 0, 1);\n"
+ "}";
private int program;
private int positionHandle;
private int transformMatrixHandle;
private final float[] transformMatrix = new float[16];
private FloatBuffer vertexBuffer;
private int loadShader(int type, String shaderCode) {
int shader = GLES20.glCreateShader(type);
GLES20.glShaderSource(shader, shaderCode);
GLES20.glCompileShader(shader);
return shader;
}
@Override
public void setupRendering() {
program = GLES20.glCreateProgram();
int vertexShader = loadShader(GLES20.GL_VERTEX_SHADER, VERTEX_SHADER);
int fragmentShader = loadShader(GLES20.GL_FRAGMENT_SHADER, FRAGMENT_SHADER);
GLES20.glAttachShader(program, vertexShader);
GLES20.glAttachShader(program, fragmentShader);
GLES20.glLinkProgram(program);
positionHandle = GLES20.glGetAttribLocation(program, "vPosition");
transformMatrixHandle = GLES20.glGetUniformLocation(program, "uTransformMatrix");
}
@Override
public void renderResult(HandsResult result, ResultGlBoundary boundary) {
if (result == null) {
return;
}
GLES20.glUseProgram(program);
// Sets the transform matrix to align the result rendering with the scaled output texture.
Matrix.setIdentityM(transformMatrix, 0);
Matrix.scaleM(
transformMatrix,
0,
2 / (boundary.right() - boundary.left()),
2 / (boundary.top() - boundary.bottom()),
1.0f);
GLES20.glUniformMatrix4fv(transformMatrixHandle, 1, false, transformMatrix, 0);
GLES20.glLineWidth(CONNECTION_THICKNESS);
int numHands = result.multiHandLandmarks().size();
for (int i = 0; i < numHands; ++i) {
drawLandmarks(result.multiHandLandmarks().get(i).getLandmarkList());
}
}
/**
* Calls this to delete the shader program.
*
* <p>This is only necessary if one wants to release the program while keeping the context around.
*/
public void release() {
GLES20.glDeleteProgram(program);
}
// TODO: Better hand landmark and hand connection drawing.
private void drawLandmarks(List<NormalizedLandmark> handLandmarkList) {
for (Hands.Connection c : Hands.HAND_CONNECTIONS) {
float[] vertex = new float[4];
NormalizedLandmark start = handLandmarkList.get(c.start());
vertex[0] = normalizedLandmarkValue(start.getX());
vertex[1] = normalizedLandmarkValue(start.getY());
NormalizedLandmark end = handLandmarkList.get(c.end());
vertex[2] = normalizedLandmarkValue(end.getX());
vertex[3] = normalizedLandmarkValue(end.getY());
vertexBuffer =
ByteBuffer.allocateDirect(vertex.length * 4)
.order(ByteOrder.nativeOrder())
.asFloatBuffer()
.put(vertex);
vertexBuffer.position(0);
GLES20.glEnableVertexAttribArray(positionHandle);
GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer);
GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2);
}
}
// Normalizes the value from the landmark value range:[0, 1] to the standard OpenGL coordinate
// value range: [-1, 1].
private float normalizedLandmarkValue(float value) {
return value * 2 - 1;
}
}

View File

@ -0,0 +1,95 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.examples.hands;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.widget.ImageView;
import com.google.mediapipe.formats.proto.LandmarkProto;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.solutions.hands.Hands;
import com.google.mediapipe.solutions.hands.HandsResult;
import java.util.List;
/** An ImageView implementation for displaying MediaPipe Hands results. */
public class HandsResultImageView extends ImageView {
private static final String TAG = "HandsResultImageView";
private static final int LANDMARK_COLOR = Color.RED;
private static final int LANDMARK_RADIUS = 15;
private static final int CONNECTION_COLOR = Color.GREEN;
private static final int CONNECTION_THICKNESS = 10;
public HandsResultImageView(Context context) {
super(context);
setScaleType(ImageView.ScaleType.FIT_CENTER);
}
/**
* Sets a {@link HandsResult} to render.
*
* @param result a {@link HandsResult} object that contains the solution outputs and the input
* {@link Bitmap}.
*/
public void setHandsResult(HandsResult result) {
if (result == null) {
return;
}
Bitmap bmInput = result.inputBitmap();
int width = bmInput.getWidth();
int height = bmInput.getHeight();
Bitmap bmOutput = Bitmap.createBitmap(width, height, bmInput.getConfig());
Canvas canvas = new Canvas(bmOutput);
canvas.drawBitmap(bmInput, new Matrix(), null);
int numHands = result.multiHandLandmarks().size();
for (int i = 0; i < numHands; ++i) {
drawLandmarksOnCanvas(
result.multiHandLandmarks().get(i).getLandmarkList(), canvas, width, height);
}
postInvalidate();
setImageBitmap(bmOutput);
}
// TODO: Better hand landmark and hand connection drawing.
private void drawLandmarksOnCanvas(
List<NormalizedLandmark> handLandmarkList, Canvas canvas, int width, int height) {
// Draw connections.
for (Hands.Connection c : Hands.HAND_CONNECTIONS) {
Paint connectionPaint = new Paint();
connectionPaint.setColor(CONNECTION_COLOR);
connectionPaint.setStrokeWidth(CONNECTION_THICKNESS);
NormalizedLandmark start = handLandmarkList.get(c.start());
NormalizedLandmark end = handLandmarkList.get(c.end());
canvas.drawLine(
start.getX() * width,
start.getY() * height,
end.getX() * width,
end.getY() * height,
connectionPaint);
}
Paint landmarkPaint = new Paint();
landmarkPaint.setColor(LANDMARK_COLOR);
// Draw landmarks.
for (LandmarkProto.NormalizedLandmark landmark : handLandmarkList) {
canvas.drawCircle(
landmark.getX() * width, landmark.getY() * height, LANDMARK_RADIUS, landmarkPaint);
}
}
}

View File

@ -0,0 +1,241 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.examples.hands;
import android.content.Intent;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.provider.MediaStore;
import androidx.appcompat.app.AppCompatActivity;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.FrameLayout;
import androidx.activity.result.ActivityResultLauncher;
import androidx.activity.result.contract.ActivityResultContracts;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.solutioncore.CameraInput;
import com.google.mediapipe.solutioncore.SolutionGlSurfaceView;
import com.google.mediapipe.solutions.hands.HandLandmark;
import com.google.mediapipe.solutions.hands.Hands;
import com.google.mediapipe.solutions.hands.HandsOptions;
import com.google.mediapipe.solutions.hands.HandsResult;
import java.io.IOException;
/** Main activity of MediaPipe Hands app. */
public class MainActivity extends AppCompatActivity {
private static final String TAG = "MainActivity";
private Hands hands;
private int mode = HandsOptions.STATIC_IMAGE_MODE;
// Image demo UI and image loader components.
private Button loadImageButton;
private ActivityResultLauncher<Intent> imageGetter;
private HandsResultImageView imageView;
// Live camera demo UI and camera components.
private Button startCameraButton;
private CameraInput cameraInput;
private SolutionGlSurfaceView<HandsResult> glSurfaceView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
setupStaticImageDemoUiComponents();
setupLiveDemoUiComponents();
}
@Override
protected void onResume() {
super.onResume();
if (mode == HandsOptions.STREAMING_MODE) {
// Restarts the camera and the opengl surface rendering.
cameraInput = new CameraInput(this);
cameraInput.setCameraNewFrameListener(textureFrame -> hands.send(textureFrame));
glSurfaceView.post(this::startCamera);
glSurfaceView.setVisibility(View.VISIBLE);
}
}
@Override
protected void onPause() {
super.onPause();
if (mode == HandsOptions.STREAMING_MODE) {
stopLiveDemo();
}
}
/** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData());
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
}
if (bitmap != null) {
hands.send(bitmap);
}
}
}
});
loadImageButton = (Button) findViewById(R.id.button_load_picture);
loadImageButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
if (mode == HandsOptions.STREAMING_MODE) {
stopLiveDemo();
}
if (hands == null || mode != HandsOptions.STATIC_IMAGE_MODE) {
setupStaticImageModePipeline();
}
// Reads images from gallery.
Intent gallery =
new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI);
imageGetter.launch(gallery);
}
});
imageView = new HandsResultImageView(this);
}
/** The core MediaPipe Hands setup workflow for its static image mode. */
private void setupStaticImageModePipeline() {
// Initializes a new MediaPipe Hands instance in the static image mode.
mode = HandsOptions.STATIC_IMAGE_MODE;
if (hands != null) {
hands.close();
}
hands = new Hands(this, HandsOptions.builder().setMode(mode).build());
// Connects MediaPipe Hands to the user-defined HandsResultImageView.
hands.setResultListener(
handsResult -> {
logWristLandmark(handsResult, /*showPixelValues=*/ true);
runOnUiThread(() -> imageView.setHandsResult(handsResult));
});
hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe hands error:" + message));
// Updates the preview layout.
FrameLayout frameLayout = (FrameLayout) findViewById(R.id.preview_display_layout);
frameLayout.removeAllViewsInLayout();
imageView.setImageDrawable(null);
frameLayout.addView(imageView);
imageView.setVisibility(View.VISIBLE);
}
/** Sets up the UI components for the live demo with camera input. */
private void setupLiveDemoUiComponents() {
startCameraButton = (Button) findViewById(R.id.button_start_camera);
startCameraButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
if (hands == null || mode != HandsOptions.STREAMING_MODE) {
setupStreamingModePipeline();
}
}
});
}
/** The core MediaPipe Hands setup workflow for its streaming mode. */
private void setupStreamingModePipeline() {
// Initializes a new MediaPipe Hands instance in the streaming mode.
mode = HandsOptions.STREAMING_MODE;
if (hands != null) {
hands.close();
}
hands = new Hands(this, HandsOptions.builder().setMode(mode).build());
hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe hands error:" + message));
// Initializes a new CameraInput instance and connects it to MediaPipe Hands.
cameraInput = new CameraInput(this);
cameraInput.setCameraNewFrameListener(textureFrame -> hands.send(textureFrame));
// Initalizes a new Gl surface view with a user-defined HandsResultGlRenderer.
glSurfaceView =
new SolutionGlSurfaceView<>(this, hands.getGlContext(), hands.getGlMajorVersion());
glSurfaceView.setSolutionResultRenderer(new HandsResultGlRenderer());
glSurfaceView.setRenderInputImage(true);
hands.setResultListener(
handsResult -> {
logWristLandmark(handsResult, /*showPixelValues=*/ false);
glSurfaceView.setRenderData(handsResult);
glSurfaceView.requestRender();
});
// The runnable to start camera after the gl surface view is attached.
glSurfaceView.post(this::startCamera);
// Updates the preview layout.
FrameLayout frameLayout = (FrameLayout) findViewById(R.id.preview_display_layout);
imageView.setVisibility(View.GONE);
frameLayout.removeAllViewsInLayout();
frameLayout.addView(glSurfaceView);
glSurfaceView.setVisibility(View.VISIBLE);
frameLayout.requestLayout();
}
private void startCamera() {
cameraInput.start(
this,
hands.getGlContext(),
CameraInput.CameraFacing.FRONT,
glSurfaceView.getWidth(),
glSurfaceView.getHeight());
}
private void stopLiveDemo() {
if (cameraInput != null) {
cameraInput.stop();
}
if (glSurfaceView != null) {
glSurfaceView.setVisibility(View.GONE);
}
}
private void logWristLandmark(HandsResult result, boolean showPixelValues) {
NormalizedLandmark wristLandmark = Hands.getHandLandmark(result, 0, HandLandmark.WRIST);
// For Bitmaps, show the pixel values. For texture inputs, show the normoralized cooridanates.
if (showPixelValues) {
int width = result.inputBitmap().getWidth();
int height = result.inputBitmap().getHeight();
Log.i(
TAG,
"MediaPipe Hand wrist coordinates (pixel values): x= "
+ wristLandmark.getX() * width
+ " y="
+ wristLandmark.getY() * height);
} else {
Log.i(
TAG,
"MediaPipe Hand wrist normalized coordinates (value range: [0, 1]): x= "
+ wristLandmark.getX()
+ " y="
+ wristLandmark.getY());
}
}
}

View File

@ -0,0 +1 @@
../../../res

View File

@ -0,0 +1,34 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportHeight="108"
android:viewportWidth="108">
<path
android:fillType="evenOdd"
android:pathData="M32,64C32,64 38.39,52.99 44.13,50.95C51.37,48.37 70.14,49.57 70.14,49.57L108.26,87.69L108,109.01L75.97,107.97L32,64Z"
android:strokeColor="#00000000"
android:strokeWidth="1">
<aapt:attr name="android:fillColor">
<gradient
android:endX="78.5885"
android:endY="90.9159"
android:startX="48.7653"
android:startY="61.0927"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M66.94,46.02L66.94,46.02C72.44,50.07 76,56.61 76,64L32,64C32,56.61 35.56,50.11 40.98,46.06L36.18,41.19C35.45,40.45 35.45,39.3 36.18,38.56C36.91,37.81 38.05,37.81 38.78,38.56L44.25,44.05C47.18,42.57 50.48,41.71 54,41.71C57.48,41.71 60.78,42.57 63.68,44.05L69.11,38.56C69.84,37.81 70.98,37.81 71.71,38.56C72.44,39.3 72.44,40.45 71.71,41.19L66.94,46.02ZM62.94,56.92C64.08,56.92 65,56.01 65,54.88C65,53.76 64.08,52.85 62.94,52.85C61.8,52.85 60.88,53.76 60.88,54.88C60.88,56.01 61.8,56.92 62.94,56.92ZM45.06,56.92C46.2,56.92 47.13,56.01 47.13,54.88C47.13,53.76 46.2,52.85 45.06,52.85C43.92,52.85 43,53.76 43,54.88C43,56.01 43.92,56.92 45.06,56.92Z"
android:strokeColor="#00000000"
android:strokeWidth="1" />
</vector>

View File

@ -0,0 +1,74 @@
<?xml version="1.0" encoding="utf-8"?>
<vector
android:height="108dp"
android:width="108dp"
android:viewportHeight="108"
android:viewportWidth="108"
xmlns:android="http://schemas.android.com/apk/res/android">
<path android:fillColor="#26A69A"
android:pathData="M0,0h108v108h-108z"/>
<path android:fillColor="#00000000" android:pathData="M9,0L9,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,0L19,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M29,0L29,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M39,0L39,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M49,0L49,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M59,0L59,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M69,0L69,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M79,0L79,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M89,0L89,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M99,0L99,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,9L108,9"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,19L108,19"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,29L108,29"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,39L108,39"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,49L108,49"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,59L108,59"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,69L108,69"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,79L108,79"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,89L108,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,99L108,99"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,29L89,29"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,39L89,39"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,49L89,49"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,59L89,59"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,69L89,69"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,79L89,79"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M29,19L29,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M39,19L39,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M49,19L49,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M59,19L59,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M69,19L69,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M79,19L79,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
</vector>

View File

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical">
<LinearLayout
android:id="@+id/buttons"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:gravity="center"
android:orientation="horizontal">
<Button
android:id="@+id/button_load_picture"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Load Picture" />
<Button
android:id="@+id/button_start_camera"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Start Camera" />
</LinearLayout>
<FrameLayout
android:id="@+id/preview_display_layout"
android:layout_width="match_parent"
android:layout_height="match_parent">
<TextView
android:id="@+id/no_view"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:gravity="center"
android:text="Please press any button above to start" />
</FrameLayout>
</LinearLayout>

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background"/>
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
</adaptive-icon>

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background"/>
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
</adaptive-icon>

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 959 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 900 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>

View File

@ -0,0 +1,3 @@
<resources>
<string name="no_camera_access" translatable="false">Please grant camera permissions.</string>
</resources>

View File

@ -0,0 +1,11 @@
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
</style>
</resources>

View File

@ -0,0 +1,2 @@
rootProject.name = "mediapipe-solutions-examples"
include ':hands'

View File

@ -291,7 +291,6 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
CheckBorder(static_features, 1000, 1000, 495, 395); CheckBorder(static_features, 1000, 1000, 495, 395);
} }
#if 0
TEST(ContentZoomingCalculatorTest, ZoomTestFullPTZ) { TEST(ContentZoomingCalculatorTest, ZoomTestFullPTZ) {
auto runner = ::absl::make_unique<CalculatorRunner>( auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD)); ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
@ -727,8 +726,8 @@ TEST(ContentZoomingCalculatorTest, ResolutionChangeZoomingWithCache) {
auto runner = ::absl::make_unique<CalculatorRunner>(config); auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket< runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache); mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
1000, runner.get()); runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 2000000, 500, 500, AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 2000000, 500, 500,
runner.get()); runner.get());
MP_ASSERT_OK(runner->Run()); MP_ASSERT_OK(runner->Run());
@ -752,7 +751,6 @@ TEST(ContentZoomingCalculatorTest, MaxZoomValue) {
CheckCropRect(500, 500, 916, 916, 0, CheckCropRect(500, 500, 916, 916, 0,
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
#endif
TEST(ContentZoomingCalculatorTest, MaxZoomValueOverride) { TEST(ContentZoomingCalculatorTest, MaxZoomValueOverride) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigF); auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigF);
@ -781,7 +779,6 @@ TEST(ContentZoomingCalculatorTest, MaxZoomValueOverride) {
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
#if 0
TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) { TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD); auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension( auto* options = config.mutable_options()->MutableExtension(
@ -969,7 +966,6 @@ TEST(ContentZoomingCalculatorTest, ProvidesConstantFirstRect) {
EXPECT_EQ(first_rect.height(), rect.height()); EXPECT_EQ(first_rect.height(), rect.height());
} }
} }
#endif
} // namespace } // namespace
} // namespace autoflip } // namespace autoflip

View File

@ -71,7 +71,10 @@ message FaceBoxAdjusterCalculatorOptions {
optional int32 max_facesize_history_us = 9 [default = 300000000]; optional int32 max_facesize_history_us = 9 [default = 300000000];
// Scale factor of face width to shift based on pan look angle. // Scale factor of face width to shift based on pan look angle.
optional float pan_position_shift_scale = 15 [default = 0.5]; optional float pan_position_shift_scale = 15
[default = 1.0, deprecated = true];
// Scale factor of face height to shift based on tilt look angle. // Scale factor of face height to shift based on tilt look angle.
optional float tilt_position_shift_scale = 16 [default = 0.5]; optional float tilt_position_shift_scale = 16 [default = 0.25];
// Scale factor of face width to shift based on roll look angle.
optional float roll_position_shift_scale = 17 [default = 0.7];
} }

View File

@ -29,6 +29,7 @@ import os
import plistlib import plistlib
import re import re
import subprocess import subprocess
from typing import Optional
import uuid import uuid
# This script is meant to be located in the MediaPipe iOS examples directory # This script is meant to be located in the MediaPipe iOS examples directory
@ -79,7 +80,7 @@ def configure_bundle_id_prefix(
return bundle_id_prefix return bundle_id_prefix
def get_app_id(profile_path) -> str: def get_app_id(profile_path) -> Optional[str]:
try: try:
plist = subprocess.check_output( plist = subprocess.check_output(
["security", "cms", "-D", "-i", profile_path]) ["security", "cms", "-D", "-i", profile_path])

View File

@ -222,10 +222,10 @@ cc_library(
"//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework:stream_handler_cc_proto",
"//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_map", "//mediapipe/framework/tool:options_map",
"//mediapipe/framework/tool:packet_generator_wrapper_calculator_cc_proto",
"//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
], ],
@ -299,7 +299,6 @@ cc_library(
":graph_service", ":graph_service",
":graph_service_manager", ":graph_service_manager",
":input_stream_manager", ":input_stream_manager",
":input_stream_shard",
":output_side_packet_impl", ":output_side_packet_impl",
":output_stream", ":output_stream",
":output_stream_manager", ":output_stream_manager",
@ -317,8 +316,6 @@ cc_library(
":timestamp", ":timestamp",
":validated_graph_config", ":validated_graph_config",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:packet_factory_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework:thread_pool_executor_cc_proto",
@ -327,6 +324,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
@ -336,8 +334,8 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location", "//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/profiler:graph_profiler",
"//mediapipe/framework/tool:fill_packet_set", "//mediapipe/framework/tool:fill_packet_set",
"//mediapipe/framework/tool:packet_generator_wrapper_calculator",
"//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:status_util",
"//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map",
"//mediapipe/framework/tool:validate", "//mediapipe/framework/tool:validate",
@ -345,10 +343,7 @@ cc_library(
"//mediapipe/gpu:graph_support", "//mediapipe/gpu:graph_support",
"//mediapipe/util:cpu_util", "//mediapipe/util:cpu_util",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": ["//mediapipe/gpu:gpu_shared_data_internal"],
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/gpu:gpu_service",
],
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
}), }),
) )
@ -389,13 +384,11 @@ cc_library(
":input_side_packet_handler", ":input_side_packet_handler",
":input_stream_handler", ":input_stream_handler",
":input_stream_manager", ":input_stream_manager",
":input_stream_shard",
":legacy_calculator_support", ":legacy_calculator_support",
":mediapipe_profiling", ":mediapipe_profiling",
":output_side_packet_impl", ":output_side_packet_impl",
":output_stream_handler", ":output_stream_handler",
":output_stream_manager", ":output_stream_manager",
":output_stream_shard",
":packet", ":packet",
":packet_set", ":packet_set",
":packet_type", ":packet_type",
@ -404,14 +397,12 @@ cc_library(
":validated_graph_config", ":validated_graph_config",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework:stream_handler_cc_proto",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location", "//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/profiler:graph_profiler",
"//mediapipe/framework/stream_handler:default_input_stream_handler", "//mediapipe/framework/stream_handler:default_input_stream_handler",
"//mediapipe/framework/stream_handler:in_order_output_stream_handler", "//mediapipe/framework/stream_handler:in_order_output_stream_handler",
"//mediapipe/framework/tool:name_util", "//mediapipe/framework/tool:name_util",
@ -421,6 +412,7 @@ cc_library(
"//mediapipe/gpu:graph_support", "//mediapipe/gpu:graph_support",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
], ],
@ -1588,6 +1580,7 @@ cc_test(
":packet", ":packet",
":packet_test_cc_proto", ":packet_test_cc_proto",
":type_map", ":type_map",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -212,6 +212,9 @@ message ProfilerConfig {
// False specifies an event for each calculator invocation. // False specifies an event for each calculator invocation.
// True specifies a separate event for each start and finish time. // True specifies a separate event for each start and finish time.
bool trace_log_instant_events = 17; bool trace_log_instant_events = 17;
// Limits calculator-profile histograms to a subset of calculators.
string calculator_filter = 18;
} }
// Describes the topology and function of a MediaPipe Graph. The graph of // Describes the topology and function of a MediaPipe Graph. The graph of

View File

@ -14,16 +14,40 @@
#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_contract.h"
#include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/packet_generator_wrapper_calculator.pb.h"
#include "mediapipe/framework/tool/tag_map.h" #include "mediapipe/framework/tool/tag_map.h"
namespace mediapipe { namespace mediapipe {
namespace {
CalculatorGraphConfig::Node MakePacketGeneratorWrapperConfig(
const PacketGeneratorConfig& node, const std::string& package) {
CalculatorGraphConfig::Node wrapper_node;
wrapper_node.set_calculator("PacketGeneratorWrapperCalculator");
*wrapper_node.mutable_input_side_packet() = node.input_side_packet();
*wrapper_node.mutable_output_side_packet() = node.output_side_packet();
auto* wrapper_options = wrapper_node.mutable_options()->MutableExtension(
mediapipe::PacketGeneratorWrapperCalculatorOptions::ext);
wrapper_options->set_packet_generator(node.packet_generator());
wrapper_options->set_package(package);
if (node.has_options()) {
*wrapper_options->mutable_options() = node.options();
}
return wrapper_node;
}
} // anonymous namespace
absl::Status CalculatorContract::Initialize( absl::Status CalculatorContract::Initialize(
const CalculatorGraphConfig::Node& node) { const CalculatorGraphConfig::Node& node) {
std::vector<absl::Status> statuses; std::vector<absl::Status> statuses;
@ -74,7 +98,8 @@ absl::Status CalculatorContract::Initialize(
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status CalculatorContract::Initialize(const PacketGeneratorConfig& node) { absl::Status CalculatorContract::Initialize(const PacketGeneratorConfig& node,
const std::string& package) {
std::vector<absl::Status> statuses; std::vector<absl::Status> statuses;
auto input_side_packet_statusor = auto input_side_packet_statusor =
@ -101,6 +126,11 @@ absl::Status CalculatorContract::Initialize(const PacketGeneratorConfig& node) {
return std::move(builder); return std::move(builder);
} }
wrapper_config_ = std::make_unique<CalculatorGraphConfig::Node>(
MakePacketGeneratorWrapperConfig(node, package));
options_.Initialize(*wrapper_config_);
inputs_ = absl::make_unique<PacketTypeSet>(0);
outputs_ = absl::make_unique<PacketTypeSet>(0);
input_side_packets_ = absl::make_unique<PacketTypeSet>( input_side_packets_ = absl::make_unique<PacketTypeSet>(
std::move(input_side_packet_statusor).value()); std::move(input_side_packet_statusor).value());
output_side_packets_ = absl::make_unique<PacketTypeSet>( output_side_packets_ = absl::make_unique<PacketTypeSet>(

View File

@ -48,7 +48,8 @@ namespace mediapipe {
class CalculatorContract { class CalculatorContract {
public: public:
absl::Status Initialize(const CalculatorGraphConfig::Node& node); absl::Status Initialize(const CalculatorGraphConfig::Node& node);
absl::Status Initialize(const PacketGeneratorConfig& node); absl::Status Initialize(const PacketGeneratorConfig& node,
const std::string& package);
absl::Status Initialize(const StatusHandlerConfig& node); absl::Status Initialize(const StatusHandlerConfig& node);
void SetNodeName(const std::string& node_name) { node_name_ = node_name; } void SetNodeName(const std::string& node_name) { node_name_ = node_name; }
@ -163,7 +164,14 @@ class CalculatorContract {
template <class T> template <class T>
void GetNodeOptions(T* result) const; void GetNodeOptions(T* result) const;
// When creating a contract for a PacketGenerator, we define a configuration
// for a wrapper calculator, for use by CalculatorNode.
const CalculatorGraphConfig::Node& GetWrapperConfig() const {
return *wrapper_config_;
}
const CalculatorGraphConfig::Node* node_config_ = nullptr; const CalculatorGraphConfig::Node* node_config_ = nullptr;
std::unique_ptr<CalculatorGraphConfig::Node> wrapper_config_;
tool::OptionsMap options_; tool::OptionsMap options_;
std::unique_ptr<PacketTypeSet> inputs_; std::unique_ptr<PacketTypeSet> inputs_;
std::unique_ptr<PacketTypeSet> outputs_; std::unique_ptr<PacketTypeSet> outputs_;
@ -175,6 +183,8 @@ class CalculatorContract {
std::map<std::string, GraphServiceRequest> service_requests_; std::map<std::string, GraphServiceRequest> service_requests_;
bool process_timestamps_ = false; bool process_timestamps_ = false;
TimestampDiff timestamp_offset_ = TimestampDiff::Unset(); TimestampDiff timestamp_offset_ = TimestampDiff::Unset();
friend class CalculatorNode;
}; };
} // namespace mediapipe } // namespace mediapipe

View File

@ -80,7 +80,7 @@ TEST(CalculatorContractTest, PacketGenerator) {
output_side_packet: "content_fingerprint" output_side_packet: "content_fingerprint"
)pb"); )pb");
CalculatorContract contract; CalculatorContract contract;
MP_EXPECT_OK(contract.Initialize(node)); MP_EXPECT_OK(contract.Initialize(node, ""));
EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1); EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1);
EXPECT_EQ(contract.OutputSidePackets().NumEntries(), 4); EXPECT_EQ(contract.OutputSidePackets().NumEntries(), 4);
} }

View File

@ -26,6 +26,7 @@
#include "absl/container/fixed_array.h" #include "absl/container/fixed_array.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
@ -84,9 +85,9 @@ void CalculatorGraph::ScheduleAllOpenableNodes() {
// node->ReadyForOpen() only before any node or graph input stream has // node->ReadyForOpen() only before any node or graph input stream has
// propagated header packets or generated output side packets, either of // propagated header packets or generated output side packets, either of
// which may cause a downstream node to be scheduled for OpenNode(). // which may cause a downstream node to be scheduled for OpenNode().
for (CalculatorNode& node : *nodes_) { for (auto& node : nodes_) {
if (node.ReadyForOpen()) { if (node->ReadyForOpen()) {
scheduler_.ScheduleNodeForOpen(&node); scheduler_.ScheduleNodeForOpen(node.get());
} }
} }
} }
@ -234,15 +235,15 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
std::vector<absl::Status> errors; std::vector<absl::Status> errors;
// Create and initialize all the nodes in the graph. // Create and initialize all the nodes in the graph.
nodes_ = absl::make_unique<absl::FixedArray<CalculatorNode>>(
validated_graph_->CalculatorInfos().size());
for (int node_id = 0; node_id < validated_graph_->CalculatorInfos().size(); for (int node_id = 0; node_id < validated_graph_->CalculatorInfos().size();
++node_id) { ++node_id) {
// buffer_size_hint will be positive if one was specified in // buffer_size_hint will be positive if one was specified in
// the graph proto. // the graph proto.
int buffer_size_hint = 0; int buffer_size_hint = 0;
const absl::Status result = (*nodes_)[node_id].Initialize( NodeTypeInfo::NodeRef node_ref(NodeTypeInfo::NodeType::CALCULATOR, node_id);
validated_graph_.get(), node_id, input_stream_managers_.get(), nodes_.push_back(absl::make_unique<CalculatorNode>());
const absl::Status result = nodes_.back()->Initialize(
validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(), output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_); &buffer_size_hint, profiler_);
if (buffer_size_hint > 0) { if (buffer_size_hint > 0) {
@ -263,6 +264,38 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status CalculatorGraph::InitializePacketGeneratorNodes(
const std::vector<int>& non_scheduled_generators) {
// Do not add wrapper nodes again if we are running the graph multiple times.
if (packet_generator_nodes_added_) return absl::OkStatus();
packet_generator_nodes_added_ = true;
// Use a local variable to avoid needing to lock errors_.
std::vector<absl::Status> errors;
for (int index : non_scheduled_generators) {
// This is never used by the packet generator wrapper.
int buffer_size_hint = 0;
NodeTypeInfo::NodeRef node_ref(NodeTypeInfo::NodeType::PACKET_GENERATOR,
index);
nodes_.push_back(absl::make_unique<CalculatorNode>());
const absl::Status result = nodes_.back()->Initialize(
validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_);
if (!result.ok()) {
// Collect as many errors as we can before failing.
errors.push_back(result);
}
}
if (!errors.empty()) {
return tool::CombinedStatus(
"CalculatorGraph::InitializePacketGeneratorNodes failed: ", errors);
}
return absl::OkStatus();
}
absl::Status CalculatorGraph::InitializeProfiler() { absl::Status CalculatorGraph::InitializeProfiler() {
profiler_->Initialize(*validated_graph_); profiler_->Initialize(*validated_graph_);
return absl::OkStatus(); return absl::OkStatus();
@ -528,8 +561,8 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
std::map<std::string, Packet> additional_side_packets; std::map<std::string, Packet> additional_side_packets;
bool update_sp = false; bool update_sp = false;
bool uses_gpu = false; bool uses_gpu = false;
for (const auto& node : *nodes_) { for (const auto& node : nodes_) {
if (node.UsesGpu()) { if (node->UsesGpu()) {
uses_gpu = true; uses_gpu = true;
break; break;
} }
@ -571,9 +604,9 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
} }
// Set up executors. // Set up executors.
for (auto& node : *nodes_) { for (auto& node : nodes_) {
if (node.UsesGpu()) { if (node->UsesGpu()) {
MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(&node)); MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get()));
} }
} }
for (const auto& name_executor : gpu_resources->GetGpuExecutors()) { for (const auto& name_executor : gpu_resources->GetGpuExecutors()) {
@ -616,8 +649,10 @@ absl::Status CalculatorGraph::PrepareForRun(
} }
current_run_side_packets_.clear(); current_run_side_packets_.clear();
std::vector<int> non_scheduled_generators;
absl::Status generator_status = packet_generator_graph_.RunGraphSetup( absl::Status generator_status = packet_generator_graph_.RunGraphSetup(
*input_side_packets, &current_run_side_packets_); *input_side_packets, &current_run_side_packets_,
&non_scheduled_generators);
CallStatusHandlers(GraphRunState::PRE_RUN, generator_status); CallStatusHandlers(GraphRunState::PRE_RUN, generator_status);
@ -650,6 +685,8 @@ absl::Status CalculatorGraph::PrepareForRun(
} }
scheduler_.Reset(); scheduler_.Reset();
MP_RETURN_IF_ERROR(InitializePacketGeneratorNodes(non_scheduled_generators));
{ {
absl::MutexLock lock(&full_input_streams_mutex_); absl::MutexLock lock(&full_input_streams_mutex_);
// Initialize a count per source node to store the number of input streams // Initialize a count per source node to store the number of input streams
@ -671,22 +708,22 @@ absl::Status CalculatorGraph::PrepareForRun(
output_side_packets_[index].PrepareForRun( output_side_packets_[index].PrepareForRun(
std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1)); std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1));
} }
for (CalculatorNode& node : *nodes_) { for (auto& node : nodes_) {
InputStreamManager::QueueSizeCallback queue_size_callback = InputStreamManager::QueueSizeCallback queue_size_callback =
std::bind(&CalculatorGraph::UpdateThrottledNodes, this, std::bind(&CalculatorGraph::UpdateThrottledNodes, this,
std::placeholders::_1, std::placeholders::_2); std::placeholders::_1, std::placeholders::_2);
node.SetQueueSizeCallbacks(queue_size_callback, queue_size_callback); node->SetQueueSizeCallbacks(queue_size_callback, queue_size_callback);
scheduler_.AssignNodeToSchedulerQueue(&node); scheduler_.AssignNodeToSchedulerQueue(node.get());
// TODO: update calculator node to use GraphServiceManager // TODO: update calculator node to use GraphServiceManager
// instead of service packets? // instead of service packets?
const absl::Status result = node.PrepareForRun( const absl::Status result = node->PrepareForRun(
current_run_side_packets_, service_manager_.ServicePackets(), current_run_side_packets_, service_manager_.ServicePackets(),
std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_, std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_,
&node), node.get()),
std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_, std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_,
&node), node.get()),
std::bind(&internal::Scheduler::ScheduleNodeIfNotThrottled, &scheduler_, std::bind(&internal::Scheduler::ScheduleNodeIfNotThrottled, &scheduler_,
&node, std::placeholders::_1), node.get(), std::placeholders::_1),
std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1), std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1),
counter_factory_.get()); counter_factory_.get());
if (!result.ok()) { if (!result.ok()) {
@ -714,8 +751,8 @@ absl::Status CalculatorGraph::PrepareForRun(
// Ensure that the latest value of max queue size is passed to all input // Ensure that the latest value of max queue size is passed to all input
// streams. // streams.
for (auto& node : *nodes_) { for (auto& node : nodes_) {
node.SetMaxInputStreamQueueSize(max_queue_size_); node->SetMaxInputStreamQueueSize(max_queue_size_);
} }
// Allow graph input streams to override the global max queue size. // Allow graph input streams to override the global max queue size.
@ -729,9 +766,9 @@ absl::Status CalculatorGraph::PrepareForRun(
(*stream)->SetMaxQueueSize(name_max.second); (*stream)->SetMaxQueueSize(name_max.second);
} }
for (CalculatorNode& node : *nodes_) { for (auto& node : nodes_) {
if (node.IsSource()) { if (node->IsSource()) {
scheduler_.AddUnopenedSourceNode(&node); scheduler_.AddUnopenedSourceNode(node.get());
has_sources_ = true; has_sources_ = true;
} }
} }
@ -1077,7 +1114,7 @@ void CalculatorGraph::UpdateThrottledNodes(InputStreamManager* stream,
} }
} else { } else {
if (!is_throttled) { if (!is_throttled) {
CalculatorNode& node = (*nodes_)[node_id]; CalculatorNode& node = *nodes_[node_id];
// Add this node to the scheduler queue if possible. // Add this node to the scheduler queue if possible.
if (node.Active() && !node.Closed()) { if (node.Active() && !node.Closed()) {
nodes_to_schedule.emplace_back(&node); nodes_to_schedule.emplace_back(&node);
@ -1244,8 +1281,8 @@ void CalculatorGraph::CleanupAfterRun(absl::Status* status) {
MEDIAPIPE_CHECK_OK(*status); MEDIAPIPE_CHECK_OK(*status);
} }
for (CalculatorNode& node : *nodes_) { for (auto& node : nodes_) {
node.CleanupAfterRun(*status); node->CleanupAfterRun(*status);
} }
for (auto& graph_output_stream : graph_output_streams_) { for (auto& graph_output_stream : graph_output_streams_) {

View File

@ -486,6 +486,8 @@ class CalculatorGraph {
absl::Status InitializeStreams(); absl::Status InitializeStreams();
absl::Status InitializeProfiler(); absl::Status InitializeProfiler();
absl::Status InitializeCalculatorNodes(); absl::Status InitializeCalculatorNodes();
absl::Status InitializePacketGeneratorNodes(
const std::vector<int>& non_scheduled_generators);
// Iterates through all nodes and schedules any that can be opened. // Iterates through all nodes and schedules any that can be opened.
void ScheduleAllOpenableNodes(); void ScheduleAllOpenableNodes();
@ -556,7 +558,8 @@ class CalculatorGraph {
std::unique_ptr<InputStreamManager[]> input_stream_managers_; std::unique_ptr<InputStreamManager[]> input_stream_managers_;
std::unique_ptr<OutputStreamManager[]> output_stream_managers_; std::unique_ptr<OutputStreamManager[]> output_stream_managers_;
std::unique_ptr<OutputSidePacketImpl[]> output_side_packets_; std::unique_ptr<OutputSidePacketImpl[]> output_side_packets_;
std::unique_ptr<absl::FixedArray<CalculatorNode>> nodes_; std::vector<std::unique_ptr<CalculatorNode>> nodes_;
bool packet_generator_nodes_added_ = false;
// The graph output streams. // The graph output streams.
std::vector<std::shared_ptr<internal::GraphOutputStream>> std::vector<std::shared_ptr<internal::GraphOutputStream>>

View File

@ -52,6 +52,25 @@ class OutputSidePacketInProcessCalculator : public CalculatorBase {
}; };
REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator); REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator);
// Takes an input side packet and passes it as an output side packet.
class OutputSidePacketInOpenCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Index(0).SetAny();
cc->OutputSidePackets().Index(0).SetSameAs(
&cc->InputSidePackets().Index(0));
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
cc->OutputSidePackets().Index(0).Set(cc->InputSidePackets().Index(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
};
REGISTER_CALCULATOR(OutputSidePacketInOpenCalculator);
// Takes an input stream packet and counts the number of the packets it // Takes an input stream packet and counts the number of the packets it
// receives. Outputs the total number of packets as a side packet in Close. // receives. Outputs the total number of packets as a side packet in Close.
class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase { class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase {
@ -802,5 +821,80 @@ TEST(CalculatorGraph, OutputSidePacketCached) {
} }
} }
TEST(CalculatorGraph, GeneratorAfterCalculatorOpen) {
CalculatorGraph graph;
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_side_packet: "offset"
node {
calculator: "OutputSidePacketInOpenCalculator"
input_side_packet: "offset"
output_side_packet: "offset1"
}
packet_generator {
packet_generator: 'PassThroughGenerator'
input_side_packet: 'offset1'
output_side_packet: 'offset_out'
}
node {
calculator: "SidePacketToStreamPacketCalculator"
input_side_packet: "offset_out"
output_stream: "output"
}
)pb");
MP_ASSERT_OK(graph.Initialize(config));
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.ObserveOutputStream(
"output", [&output_packets](const Packet& packet) {
output_packets.push_back(packet);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({{"offset", MakePacket<TimestampDiff>(100)}}));
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(1, output_packets.size());
EXPECT_EQ(100, output_packets[0].Get<TimestampDiff>().Value());
}
TEST(CalculatorGraph, GeneratorAfterCalculatorProcess) {
CalculatorGraph graph;
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "offset"
node {
calculator: "OutputSidePacketInProcessCalculator"
input_stream: "offset"
output_side_packet: "offset"
}
packet_generator {
packet_generator: 'PassThroughGenerator'
input_side_packet: 'offset'
output_side_packet: 'offset_out'
}
node {
calculator: "SidePacketToStreamPacketCalculator"
input_side_packet: "offset_out"
output_stream: "output"
}
)pb");
MP_ASSERT_OK(graph.Initialize(config));
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.ObserveOutputStream(
"output", [&output_packets](const Packet& packet) {
output_packets.push_back(packet);
return absl::OkStatus();
}));
// Run twice to verify that we don't duplicate wrapper nodes.
for (int run = 0; run < 2; ++run) {
output_packets.clear();
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"offset", MakePacket<TimestampDiff>(100).At(Timestamp(0))));
MP_ASSERT_OK(graph.CloseInputStream("offset"));
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(1, output_packets.size());
EXPECT_EQ(100, output_packets[0].Get<TimestampDiff>().Value());
}
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -1133,24 +1133,6 @@ class CheckInputTimestamp2SinkCalculator : public CalculatorBase {
}; };
REGISTER_CALCULATOR(CheckInputTimestamp2SinkCalculator); REGISTER_CALCULATOR(CheckInputTimestamp2SinkCalculator);
// Takes an input stream packet and passes it (with timestamp removed) as an
// output side packet.
class OutputSidePacketInProcessCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
cc->OutputSidePackets().Index(0).Set(
cc->Inputs().Index(0).Value().At(Timestamp::Unset()));
return absl::OkStatus();
}
};
REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator);
// A calculator checks if either of two input streams contains a packet and // A calculator checks if either of two input streams contains a packet and
// sends the packet to the single output stream with the same timestamp. // sends the packet to the single output stream with the same timestamp.
class SimpleMuxCalculator : public CalculatorBase { class SimpleMuxCalculator : public CalculatorBase {

View File

@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
@ -119,79 +120,92 @@ Timestamp CalculatorNode::SourceProcessOrder(
} }
absl::Status CalculatorNode::Initialize( absl::Status CalculatorNode::Initialize(
const ValidatedGraphConfig* validated_graph, int node_id, const ValidatedGraphConfig* validated_graph, NodeTypeInfo::NodeRef node_ref,
InputStreamManager* input_stream_managers, InputStreamManager* input_stream_managers,
OutputStreamManager* output_stream_managers, OutputStreamManager* output_stream_managers,
OutputSidePacketImpl* output_side_packets, int* buffer_size_hint, OutputSidePacketImpl* output_side_packets, int* buffer_size_hint,
std::shared_ptr<ProfilingContext> profiling_context) { std::shared_ptr<ProfilingContext> profiling_context) {
RET_CHECK(buffer_size_hint) << "buffer_size_hint is NULL"; RET_CHECK(buffer_size_hint) << "buffer_size_hint is NULL";
node_id_ = node_id;
validated_graph_ = validated_graph; validated_graph_ = validated_graph;
profiling_context_ = profiling_context; profiling_context_ = profiling_context;
const CalculatorGraphConfig::Node& node_config = const CalculatorGraphConfig::Node* node_config;
validated_graph_->Config().node(node_id_); if (node_ref.type == NodeTypeInfo::NodeType::CALCULATOR) {
name_ = tool::CanonicalNodeName(validated_graph_->Config(), node_id_); node_config = &validated_graph_->Config().node(node_ref.index);
name_ = tool::CanonicalNodeName(validated_graph_->Config(), node_ref.index);
max_in_flight_ = node_config.max_in_flight(); node_type_info_ = &validated_graph_->CalculatorInfos()[node_ref.index];
max_in_flight_ = max_in_flight_ ? max_in_flight_ : 1; } else if (node_ref.type == NodeTypeInfo::NodeType::PACKET_GENERATOR) {
if (!node_config.executor().empty()) { const PacketGeneratorConfig& pg_config =
executor_ = node_config.executor(); validated_graph_->Config().packet_generator(node_ref.index);
name_ = absl::StrCat("__pg_", node_ref.index, "_",
pg_config.packet_generator());
node_type_info_ = &validated_graph_->GeneratorInfos()[node_ref.index];
node_config = &node_type_info_->Contract().GetWrapperConfig();
} else {
return absl::InvalidArgumentError(
"node_ref is not a calculator or packet generator");
} }
source_layer_ = node_config.source_layer();
const NodeTypeInfo& node_type_info = max_in_flight_ = node_config->max_in_flight();
validated_graph_->CalculatorInfos()[node_id_]; max_in_flight_ = max_in_flight_ ? max_in_flight_ : 1;
const CalculatorContract& contract = node_type_info.Contract(); if (!node_config->executor().empty()) {
executor_ = node_config->executor();
}
source_layer_ = node_config->source_layer();
const CalculatorContract& contract = node_type_info_->Contract();
uses_gpu_ = uses_gpu_ =
node_type_info.InputSidePacketTypes().HasTag(kGpuSharedTagName) || node_type_info_->InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
ContainsKey(node_type_info.Contract().ServiceRequests(), kGpuService.key); ContainsKey(node_type_info_->Contract().ServiceRequests(),
kGpuService.key);
// TODO Propagate types between calculators when SetAny is used. // TODO Propagate types between calculators when SetAny is used.
MP_RETURN_IF_ERROR(InitializeOutputSidePackets( MP_RETURN_IF_ERROR(InitializeOutputSidePackets(
node_type_info.OutputSidePacketTypes(), output_side_packets)); node_type_info_->OutputSidePacketTypes(), output_side_packets));
MP_RETURN_IF_ERROR(InitializeInputSidePackets(output_side_packets)); MP_RETURN_IF_ERROR(InitializeInputSidePackets(output_side_packets));
MP_RETURN_IF_ERROR(InitializeOutputStreamHandler( MP_RETURN_IF_ERROR(
node_config.output_stream_handler(), node_type_info.OutputStreamTypes())); InitializeOutputStreamHandler(node_config->output_stream_handler(),
node_type_info_->OutputStreamTypes()));
MP_RETURN_IF_ERROR(InitializeOutputStreams(output_stream_managers)); MP_RETURN_IF_ERROR(InitializeOutputStreams(output_stream_managers));
calculator_state_ = absl::make_unique<CalculatorState>( calculator_state_ = absl::make_unique<CalculatorState>(
name_, node_id_, node_config.calculator(), node_config, name_, node_ref.index, node_config->calculator(), *node_config,
profiling_context_); profiling_context_);
// Inform the scheduler that this node has buffering behavior and that the // Inform the scheduler that this node has buffering behavior and that the
// maximum input queue size should be adjusted accordingly. // maximum input queue size should be adjusted accordingly.
*buffer_size_hint = node_config.buffer_size_hint(); *buffer_size_hint = node_config->buffer_size_hint();
calculator_context_manager_.Initialize( calculator_context_manager_.Initialize(
calculator_state_.get(), node_type_info.InputStreamTypes().TagMap(), calculator_state_.get(), node_type_info_->InputStreamTypes().TagMap(),
node_type_info.OutputStreamTypes().TagMap(), node_type_info_->OutputStreamTypes().TagMap(),
/*calculator_run_in_parallel=*/max_in_flight_ > 1); /*calculator_run_in_parallel=*/max_in_flight_ > 1);
// The graph specified InputStreamHandler takes priority. // The graph specified InputStreamHandler takes priority.
const bool graph_specified = const bool graph_specified =
node_config.input_stream_handler().has_input_stream_handler(); node_config->input_stream_handler().has_input_stream_handler();
const bool calc_specified = !(node_type_info.GetInputStreamHandler().empty()); const bool calc_specified =
!(node_type_info_->GetInputStreamHandler().empty());
// Only use calculator ISH if available, and if the graph ISH is not set. // Only use calculator ISH if available, and if the graph ISH is not set.
InputStreamHandlerConfig handler_config; InputStreamHandlerConfig handler_config;
const bool use_calc_specified = calc_specified && !graph_specified; const bool use_calc_specified = calc_specified && !graph_specified;
if (use_calc_specified) { if (use_calc_specified) {
*(handler_config.mutable_input_stream_handler()) = *(handler_config.mutable_input_stream_handler()) =
node_type_info.GetInputStreamHandler(); node_type_info_->GetInputStreamHandler();
*(handler_config.mutable_options()) = *(handler_config.mutable_options()) =
node_type_info.GetInputStreamHandlerOptions(); node_type_info_->GetInputStreamHandlerOptions();
} }
// Use calculator or graph specified InputStreamHandler, or the default ISH // Use calculator or graph specified InputStreamHandler, or the default ISH
// already set from graph. // already set from graph.
MP_RETURN_IF_ERROR(InitializeInputStreamHandler( MP_RETURN_IF_ERROR(InitializeInputStreamHandler(
use_calc_specified ? handler_config : node_config.input_stream_handler(), use_calc_specified ? handler_config : node_config->input_stream_handler(),
node_type_info.InputStreamTypes())); node_type_info_->InputStreamTypes()));
for (auto& stream : output_stream_handler_->OutputStreams()) { for (auto& stream : output_stream_handler_->OutputStreams()) {
stream->Spec()->offset_enabled = stream->Spec()->offset_enabled =
@ -209,9 +223,7 @@ absl::Status CalculatorNode::InitializeOutputSidePackets(
OutputSidePacketImpl* output_side_packets) { OutputSidePacketImpl* output_side_packets) {
output_side_packets_ = output_side_packets_ =
absl::make_unique<OutputSidePacketSet>(output_side_packet_types.TagMap()); absl::make_unique<OutputSidePacketSet>(output_side_packet_types.TagMap());
const NodeTypeInfo& node_type_info = int base_index = node_type_info_->OutputSidePacketBaseIndex();
validated_graph_->CalculatorInfos()[node_id_];
int base_index = node_type_info.OutputSidePacketBaseIndex();
RET_CHECK_LE(0, base_index); RET_CHECK_LE(0, base_index);
for (CollectionItemId id = output_side_packets_->BeginId(); for (CollectionItemId id = output_side_packets_->BeginId();
id < output_side_packets_->EndId(); ++id) { id < output_side_packets_->EndId(); ++id) {
@ -223,13 +235,11 @@ absl::Status CalculatorNode::InitializeOutputSidePackets(
absl::Status CalculatorNode::InitializeInputSidePackets( absl::Status CalculatorNode::InitializeInputSidePackets(
OutputSidePacketImpl* output_side_packets) { OutputSidePacketImpl* output_side_packets) {
const NodeTypeInfo& node_type_info = int base_index = node_type_info_->InputSidePacketBaseIndex();
validated_graph_->CalculatorInfos()[node_id_];
int base_index = node_type_info.InputSidePacketBaseIndex();
RET_CHECK_LE(0, base_index); RET_CHECK_LE(0, base_index);
// Set all the mirrors. // Set all the mirrors.
for (CollectionItemId id = node_type_info.InputSidePacketTypes().BeginId(); for (CollectionItemId id = node_type_info_->InputSidePacketTypes().BeginId();
id < node_type_info.InputSidePacketTypes().EndId(); ++id) { id < node_type_info_->InputSidePacketTypes().EndId(); ++id) {
int output_side_packet_index = int output_side_packet_index =
validated_graph_->InputSidePacketInfos()[base_index + id.value()] validated_graph_->InputSidePacketInfos()[base_index + id.value()]
.upstream; .upstream;
@ -252,11 +262,9 @@ absl::Status CalculatorNode::InitializeInputSidePackets(
absl::Status CalculatorNode::InitializeOutputStreams( absl::Status CalculatorNode::InitializeOutputStreams(
OutputStreamManager* output_stream_managers) { OutputStreamManager* output_stream_managers) {
RET_CHECK(output_stream_managers) << "output_stream_managers is NULL"; RET_CHECK(output_stream_managers) << "output_stream_managers is NULL";
const NodeTypeInfo& node_type_info = RET_CHECK_LE(0, node_type_info_->OutputStreamBaseIndex());
validated_graph_->CalculatorInfos()[node_id_];
RET_CHECK_LE(0, node_type_info.OutputStreamBaseIndex());
OutputStreamManager* current_output_stream_managers = OutputStreamManager* current_output_stream_managers =
&output_stream_managers[node_type_info.OutputStreamBaseIndex()]; &output_stream_managers[node_type_info_->OutputStreamBaseIndex()];
return output_stream_handler_->InitializeOutputStreamManagers( return output_stream_handler_->InitializeOutputStreamManagers(
current_output_stream_managers); current_output_stream_managers);
} }
@ -266,20 +274,18 @@ absl::Status CalculatorNode::InitializeInputStreams(
OutputStreamManager* output_stream_managers) { OutputStreamManager* output_stream_managers) {
RET_CHECK(input_stream_managers) << "input_stream_managers is NULL"; RET_CHECK(input_stream_managers) << "input_stream_managers is NULL";
RET_CHECK(output_stream_managers) << "output_stream_managers is NULL"; RET_CHECK(output_stream_managers) << "output_stream_managers is NULL";
const NodeTypeInfo& node_type_info = RET_CHECK_LE(0, node_type_info_->InputStreamBaseIndex());
validated_graph_->CalculatorInfos()[node_id_];
RET_CHECK_LE(0, node_type_info.InputStreamBaseIndex());
InputStreamManager* current_input_stream_managers = InputStreamManager* current_input_stream_managers =
&input_stream_managers[node_type_info.InputStreamBaseIndex()]; &input_stream_managers[node_type_info_->InputStreamBaseIndex()];
MP_RETURN_IF_ERROR(input_stream_handler_->InitializeInputStreamManagers( MP_RETURN_IF_ERROR(input_stream_handler_->InitializeInputStreamManagers(
current_input_stream_managers)); current_input_stream_managers));
// Set all the mirrors. // Set all the mirrors.
for (CollectionItemId id = node_type_info.InputStreamTypes().BeginId(); for (CollectionItemId id = node_type_info_->InputStreamTypes().BeginId();
id < node_type_info.InputStreamTypes().EndId(); ++id) { id < node_type_info_->InputStreamTypes().EndId(); ++id) {
int output_stream_index = int output_stream_index =
validated_graph_ validated_graph_
->InputStreamInfos()[node_type_info.InputStreamBaseIndex() + ->InputStreamInfos()[node_type_info_->InputStreamBaseIndex() +
id.value()] id.value()]
.upstream; .upstream;
RET_CHECK_LE(0, output_stream_index); RET_CHECK_LE(0, output_stream_index);
@ -287,7 +293,7 @@ absl::Status CalculatorNode::InitializeInputStreams(
&output_stream_managers[output_stream_index]; &output_stream_managers[output_stream_index];
VLOG(2) << "Adding mirror for input stream with id " << id.value() VLOG(2) << "Adding mirror for input stream with id " << id.value()
<< " and flat index " << " and flat index "
<< node_type_info.InputStreamBaseIndex() + id.value() << node_type_info_->InputStreamBaseIndex() + id.value()
<< " which will be connected to output stream with flat index " << " which will be connected to output stream with flat index "
<< output_stream_index; << output_stream_index;
origin_output_stream_manager->AddMirror(input_stream_handler_.get(), id); origin_output_stream_manager->AddMirror(input_stream_handler_.get(), id);
@ -391,10 +397,9 @@ absl::Status CalculatorNode::PrepareForRun(
std::move(schedule_callback), error_callback); std::move(schedule_callback), error_callback);
output_stream_handler_->PrepareForRun(error_callback); output_stream_handler_->PrepareForRun(error_callback);
const PacketTypeSet* packet_types = const auto& contract = node_type_info_->Contract();
&validated_graph_->CalculatorInfos()[node_id_].InputSidePacketTypes();
input_side_packet_types_ = RemoveOmittedPacketTypes( input_side_packet_types_ = RemoveOmittedPacketTypes(
*packet_types, all_side_packets, validated_graph_); contract.InputSidePackets(), all_side_packets, validated_graph_);
MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun( MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun(
input_side_packet_types_.get(), all_side_packets, input_side_packet_types_.get(), all_side_packets,
[this]() { CalculatorNode::InputSidePacketsReady(); }, [this]() { CalculatorNode::InputSidePacketsReady(); },
@ -404,8 +409,6 @@ absl::Status CalculatorNode::PrepareForRun(
calculator_state_->SetOutputSidePackets(output_side_packets_.get()); calculator_state_->SetOutputSidePackets(output_side_packets_.get());
calculator_state_->SetCounterFactory(counter_factory); calculator_state_->SetCounterFactory(counter_factory);
const auto& contract =
validated_graph_->CalculatorInfos()[node_id_].Contract();
for (const auto& svc_req : contract.ServiceRequests()) { for (const auto& svc_req : contract.ServiceRequests()) {
const auto& req = svc_req.second; const auto& req = svc_req.second;
auto it = service_packets.find(req.Service().key); auto it = service_packets.find(req.Service().key);

View File

@ -70,7 +70,9 @@ class CalculatorNode {
CalculatorNode(); CalculatorNode();
CalculatorNode(const CalculatorNode&) = delete; CalculatorNode(const CalculatorNode&) = delete;
CalculatorNode& operator=(const CalculatorNode&) = delete; CalculatorNode& operator=(const CalculatorNode&) = delete;
int Id() const { return node_id_; } int Id() const {
return node_type_info_ ? node_type_info_->Node().index : -1;
}
// Returns a value according to which the scheduler queue determines the // Returns a value according to which the scheduler queue determines the
// relative priority between runnable source nodes; a smaller value means // relative priority between runnable source nodes; a smaller value means
@ -106,7 +108,7 @@ class CalculatorNode {
// OutputSidePacketImpls corresponding to the output side packet indexes in // OutputSidePacketImpls corresponding to the output side packet indexes in
// validated_graph. // validated_graph.
absl::Status Initialize(const ValidatedGraphConfig* validated_graph, absl::Status Initialize(const ValidatedGraphConfig* validated_graph,
int node_id, NodeTypeInfo::NodeRef node_ref,
InputStreamManager* input_stream_managers, InputStreamManager* input_stream_managers,
OutputStreamManager* output_stream_managers, OutputStreamManager* output_stream_managers,
OutputSidePacketImpl* output_side_packets, OutputSidePacketImpl* output_side_packets,
@ -287,7 +289,6 @@ class CalculatorNode {
// Keeps data which a Calculator subclass needs access to. // Keeps data which a Calculator subclass needs access to.
std::unique_ptr<CalculatorState> calculator_state_; std::unique_ptr<CalculatorState> calculator_state_;
int node_id_ = -1;
std::string name_; // Optional user-defined name std::string name_; // Optional user-defined name
// Name of the executor which the node will execute on. If empty, the node // Name of the executor which the node will execute on. If empty, the node
// will execute on the default executor. // will execute on the default executor.
@ -372,6 +373,8 @@ class CalculatorNode {
internal::SchedulerQueue* scheduler_queue_ = nullptr; internal::SchedulerQueue* scheduler_queue_ = nullptr;
const ValidatedGraphConfig* validated_graph_ = nullptr; const ValidatedGraphConfig* validated_graph_ = nullptr;
const NodeTypeInfo* node_type_info_ = nullptr;
}; };
} // namespace mediapipe } // namespace mediapipe

View File

@ -158,11 +158,11 @@ class CalculatorNodeTest : public ::testing::Test {
input_side_packets_.emplace("input_a", Adopt(new int(42))); input_side_packets_.emplace("input_a", Adopt(new int(42)));
input_side_packets_.emplace("input_b", Adopt(new int(42))); input_side_packets_.emplace("input_b", Adopt(new int(42)));
node_.reset(new CalculatorNode()); node_ = absl::make_unique<CalculatorNode>();
MP_ASSERT_OK(node_->Initialize( MP_ASSERT_OK(node_->Initialize(
&validated_graph_, 2, input_stream_managers_.get(), &validated_graph_, {NodeTypeInfo::NodeType::CALCULATOR, 2},
output_stream_managers_.get(), output_side_packets_.get(), input_stream_managers_.get(), output_stream_managers_.get(),
&buffer_size_hint_, graph_profiler_)); output_side_packets_.get(), &buffer_size_hint_, graph_profiler_));
} }
absl::Status PrepareNodeForRun() { absl::Status PrepareNodeForRun() {

View File

@ -30,6 +30,14 @@ bzl_library(
visibility = ["//mediapipe/framework:__subpackages__"], visibility = ["//mediapipe/framework:__subpackages__"],
) )
bzl_library(
name = "descriptor_set_bzl",
srcs = [
"descriptor_set.bzl",
],
visibility = ["//mediapipe/framework:__subpackages__"],
)
proto_library( proto_library(
name = "proto_descriptor_proto", name = "proto_descriptor_proto",
srcs = ["proto_descriptor.proto"], srcs = ["proto_descriptor.proto"],
@ -281,7 +289,8 @@ cc_library(
# Use this library through "mediapipe/framework/port:gtest_main". # Use this library through "mediapipe/framework/port:gtest_main".
visibility = ["//mediapipe/framework/port:__pkg__"], visibility = ["//mediapipe/framework/port:__pkg__"],
deps = [ deps = [
":status", "//mediapipe/framework/port:statusor",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
], ],
) )

View File

@ -0,0 +1,139 @@
"""Outputs a FileDescriptorSet with all transitive dependencies.
Copied from tools/build_defs/proto/descriptor_set.bzl.
"""
TransitiveDescriptorInfo = provider(
"The transitive descriptors from a set of protos.",
fields = ["descriptors"],
)
DirectDescriptorInfo = provider(
"The direct descriptors from a set of protos.",
fields = ["descriptors"],
)
def calculate_transitive_descriptor_set(actions, deps, output):
"""Calculates the transitive dependencies of the deps.
Args:
actions: the actions (typically ctx.actions) used to run commands
deps: the deps to get the transitive dependencies of
output: the output file the data will be written to
Returns:
The same output file passed as the input arg, for convenience.
"""
# Join all proto descriptors in a single file.
transitive_descriptor_sets = depset(transitive = [
dep[ProtoInfo].transitive_descriptor_sets if ProtoInfo in dep else dep[TransitiveDescriptorInfo].descriptors
for dep in deps
])
args = actions.args()
args.use_param_file(param_file_arg = "--arg-file=%s")
args.add_all(transitive_descriptor_sets)
# Because `xargs` must take its arguments before the command to execute,
# we cannot simply put a reference to the argument list at the end, as in the
# case of param file spooling, since the entire argument list will get
# replaced by "--arg-file=bazel-out/..." which needs to be an `xargs`
# argument rather than a `cat` argument.
#
# We look to see if the first argument begins with a '--arg-file=' and
# selectively choose xargs vs. just supplying the arguments to `cat`.
actions.run_shell(
outputs = [output],
inputs = transitive_descriptor_sets,
progress_message = "Joining descriptors.",
command = ("if [[ \"$1\" =~ ^--arg-file=.* ]]; then xargs \"$1\" cat; " +
"else cat \"$@\"; fi >{output}".format(output = output.path)),
arguments = [args],
)
return output
def _transitive_descriptor_set_impl(ctx):
"""Combine descriptors for all transitive proto dependencies into one file.
Warning: Concatenating all of the descriptor files with a single `cat` command
could exceed system limits (1MB+). For example, a dependency on gwslog.proto
will trigger this edge case.
When writing new code, prefer to accept a list of descriptor files instead of
just one so that this limitation won't impact you.
"""
output = ctx.actions.declare_file(ctx.attr.name + "-transitive-descriptor-set.proto.bin")
calculate_transitive_descriptor_set(ctx.actions, ctx.attr.deps, output)
return DefaultInfo(
files = depset([output]),
runfiles = ctx.runfiles(files = [output]),
)
# transitive_descriptor_set outputs a single file containing a binary
# FileDescriptorSet with all transitive dependencies of the given proto
# dependencies.
#
# Example usage:
#
# transitive_descriptor_set(
# name = "my_descriptors",
# deps = [":my_proto"],
# )
transitive_descriptor_set = rule(
attrs = {
"deps": attr.label_list(providers = [[ProtoInfo], [TransitiveDescriptorInfo]]),
},
outputs = {
"out": "%{name}-transitive-descriptor-set.proto.bin",
},
implementation = _transitive_descriptor_set_impl,
)
def calculate_direct_descriptor_set(actions, deps, output):
"""Calculates the direct dependencies of the deps.
Args:
actions: the actions (typically ctx.actions) used to run commands
deps: the deps to get the direct dependencies of
output: the output file the data will be written to
Returns:
The same output file passed as the input arg, for convenience.
"""
descriptor_set = depset(
[dep[ProtoInfo].direct_descriptor_set for dep in deps if ProtoInfo in dep],
transitive = [dep[DirectDescriptorInfo].descriptors for dep in deps if ProtoInfo not in dep],
)
actions.run_shell(
outputs = [output],
inputs = descriptor_set,
progress_message = "Joining direct descriptors.",
command = ("cat %s > %s") % (
" ".join([d.path for d in descriptor_set.to_list()]),
output.path,
),
)
return output
def _direct_descriptor_set_impl(ctx):
calculate_direct_descriptor_set(ctx.actions, ctx.attr.deps, ctx.outputs.out)
# direct_descriptor_set outputs a single file containing a binary
# FileDescriptorSet with all direct, non transitive dependencies of
# the given proto dependencies.
#
# Example usage:
#
# direct_descriptor_set(
# name = "my_direct_descriptors",
# deps = [":my_proto"],
# )
direct_descriptor_set = rule(
attrs = {
"deps": attr.label_list(providers = [[ProtoInfo], [DirectDescriptorInfo]]),
},
outputs = {
"out": "%{name}-direct-descriptor-set.proto.bin",
},
implementation = _direct_descriptor_set_impl,
)

View File

@ -15,48 +15,74 @@
#ifndef MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_ #ifndef MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_
#define MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_ #define MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_
#include <memory>
#include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
namespace mediapipe { namespace mediapipe {
namespace internal { class ProtoMatcher {
public:
using is_gtest_matcher = void;
using MessageType = proto_ns::MessageLite;
explicit ProtoMatcher(const MessageType& message)
: message_(CloneMessage(message)) {}
bool MatchAndExplain(const MessageType& m,
testing::MatchResultListener*) const {
return EqualsMessage(*message_, m);
}
bool MatchAndExplain(const MessageType* m,
testing::MatchResultListener*) const {
return EqualsMessage(*message_, *m);
}
void DescribeTo(std::ostream* os) const {
*os << "has the same serialization as " << ExpectedMessageDescription();
}
void DescribeNegationTo(std::ostream* os) const {
*os << "does not have the same serialization as "
<< ExpectedMessageDescription();
}
private:
std::unique_ptr<MessageType> CloneMessage(const MessageType& message) {
std::unique_ptr<MessageType> clone(message.New());
clone->CheckTypeAndMergeFrom(message);
return clone;
}
bool EqualsMessage(const proto_ns::MessageLite& m_1, bool EqualsMessage(const proto_ns::MessageLite& m_1,
const proto_ns::MessageLite& m_2) { const proto_ns::MessageLite& m_2) const {
std::string s_1, s_2; std::string s_1, s_2;
m_1.SerializeToString(&s_1); m_1.SerializeToString(&s_1);
m_2.SerializeToString(&s_2); m_2.SerializeToString(&s_2);
return s_1 == s_2; return s_1 == s_2;
} }
} // namespace internal
template <typename MessageType> std::string ExpectedMessageDescription() const {
class ProtoMatcher : public testing::MatcherInterface<MessageType> {
using MatchResultListener = testing::MatchResultListener;
public:
explicit ProtoMatcher(const MessageType& message) : message_(message) {}
virtual bool MatchAndExplain(MessageType m, MatchResultListener*) const {
return internal::EqualsMessage(message_, m);
}
virtual void DescribeTo(::std::ostream* os) const {
#if defined(MEDIAPIPE_PROTO_LITE) #if defined(MEDIAPIPE_PROTO_LITE)
*os << "Protobuf messages have identical serializations."; return "the expected message";
#else #else
*os << message_.DebugString(); return message_->DebugString();
#endif #endif
} }
private: const std::shared_ptr<MessageType> message_;
const MessageType message_;
}; };
template <typename MessageType> inline ProtoMatcher EqualsProto(const proto_ns::MessageLite& message) {
inline testing::PolymorphicMatcher<ProtoMatcher<MessageType>> EqualsProto( return ProtoMatcher(message);
const MessageType& message) { }
return testing::PolymorphicMatcher<ProtoMatcher<MessageType>>(
ProtoMatcher<MessageType>(message)); // for Pointwise
MATCHER(EqualsProto, "") {
const auto& a = ::testing::get<0>(arg);
const auto& b = ::testing::get<1>(arg);
return ::testing::ExplainMatchResult(EqualsProto(b), a, result_listener);
} }
} // namespace mediapipe } // namespace mediapipe

View File

@ -15,24 +15,102 @@
#ifndef MEDIAPIPE_DEPS_STATUS_MATCHERS_H_ #ifndef MEDIAPIPE_DEPS_STATUS_MATCHERS_H_
#define MEDIAPIPE_DEPS_STATUS_MATCHERS_H_ #define MEDIAPIPE_DEPS_STATUS_MATCHERS_H_
#include "absl/status/status.h"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "mediapipe/framework/deps/status.h" #include "mediapipe/framework/port/statusor.h"
namespace mediapipe { namespace mediapipe {
inline const ::absl::Status& GetStatus(const ::absl::Status& status) {
return status;
}
template <typename T>
inline const ::absl::Status& GetStatus(const ::absl::StatusOr<T>& status) {
return status.status();
}
// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a
// reference to StatusOr<T>.
template <typename StatusOrType>
class IsOkAndHoldsMatcherImpl
: public ::testing::MatcherInterface<StatusOrType> {
public:
typedef
typename std::remove_reference<StatusOrType>::type::value_type value_type;
template <typename InnerMatcher>
explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher)
: inner_matcher_(::testing::SafeMatcherCast<const value_type&>(
std::forward<InnerMatcher>(inner_matcher))) {}
void DescribeTo(std::ostream* os) const override {
*os << "is OK and has a value that ";
inner_matcher_.DescribeTo(os);
}
void DescribeNegationTo(std::ostream* os) const override {
*os << "isn't OK or has a value that ";
inner_matcher_.DescribeNegationTo(os);
}
bool MatchAndExplain(
StatusOrType actual_value,
::testing::MatchResultListener* result_listener) const override {
if (!actual_value.ok()) {
*result_listener << "which has status " << actual_value.status();
return false;
}
::testing::StringMatchResultListener inner_listener;
const bool matches =
inner_matcher_.MatchAndExplain(*actual_value, &inner_listener);
const std::string inner_explanation = inner_listener.str();
if (!inner_explanation.empty()) {
*result_listener << "which contains value "
<< ::testing::PrintToString(*actual_value) << ", "
<< inner_explanation;
}
return matches;
}
private:
const ::testing::Matcher<const value_type&> inner_matcher_;
};
// Implements IsOkAndHolds(m) as a polymorphic matcher.
template <typename InnerMatcher>
class IsOkAndHoldsMatcher {
public:
explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher)
: inner_matcher_(std::move(inner_matcher)) {}
// Converts this polymorphic matcher to a monomorphic matcher of the
// given type. StatusOrType can be either StatusOr<T> or a
// reference to StatusOr<T>.
template <typename StatusOrType>
operator ::testing::Matcher<StatusOrType>() const { // NOLINT
return ::testing::Matcher<StatusOrType>(
new IsOkAndHoldsMatcherImpl<const StatusOrType&>(inner_matcher_));
}
private:
const InnerMatcher inner_matcher_;
};
// Monomorphic implementation of matcher IsOk() for a given type T. // Monomorphic implementation of matcher IsOk() for a given type T.
// T can be Status, StatusOr<>, or a reference to either of them. // T can be Status, StatusOr<>, or a reference to either of them.
template <typename T> template <typename T>
class MonoIsOkMatcherImpl : public testing::MatcherInterface<T> { class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> {
public: public:
void DescribeTo(std::ostream* os) const override { *os << "is OK"; } void DescribeTo(std::ostream* os) const override { *os << "is OK"; }
void DescribeNegationTo(std::ostream* os) const override { void DescribeNegationTo(std::ostream* os) const override {
*os << "is not OK"; *os << "is not OK";
} }
bool MatchAndExplain(T actual_value, bool MatchAndExplain(T actual_value,
testing::MatchResultListener*) const override { ::testing::MatchResultListener*) const override {
return actual_value.ok(); return GetStatus(actual_value).ok();
} }
}; };
@ -40,11 +118,20 @@ class MonoIsOkMatcherImpl : public testing::MatcherInterface<T> {
class IsOkMatcher { class IsOkMatcher {
public: public:
template <typename T> template <typename T>
operator testing::Matcher<T>() const { // NOLINT operator ::testing::Matcher<T>() const { // NOLINT
return testing::Matcher<T>(new MonoIsOkMatcherImpl<T>()); return ::testing::Matcher<T>(new MonoIsOkMatcherImpl<T>());
} }
}; };
// Returns a gMock matcher that matches a StatusOr<> whose status is
// OK and whose value matches the inner matcher.
template <typename InnerMatcher>
IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type> IsOkAndHolds(
InnerMatcher&& inner_matcher) {
return IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>(
std::forward<InnerMatcher>(inner_matcher));
}
// Returns a gMock matcher that matches a Status or StatusOr<> which is OK. // Returns a gMock matcher that matches a Status or StatusOr<> which is OK.
inline IsOkMatcher IsOk() { return IsOkMatcher(); } inline IsOkMatcher IsOk() { return IsOkMatcher(); }

View File

@ -54,6 +54,7 @@ mediapipe_register_type(
types = [ types = [
"::mediapipe::Classification", "::mediapipe::Classification",
"::mediapipe::ClassificationList", "::mediapipe::ClassificationList",
"::mediapipe::ClassificationListCollection",
"::std::vector<::mediapipe::Classification>", "::std::vector<::mediapipe::Classification>",
"::std::vector<::mediapipe::ClassificationList>", "::std::vector<::mediapipe::ClassificationList>",
], ],
@ -262,8 +263,10 @@ mediapipe_register_type(
types = [ types = [
"::mediapipe::Landmark", "::mediapipe::Landmark",
"::mediapipe::LandmarkList", "::mediapipe::LandmarkList",
"::mediapipe::LandmarkListCollection",
"::mediapipe::NormalizedLandmark", "::mediapipe::NormalizedLandmark",
"::mediapipe::NormalizedLandmarkList", "::mediapipe::NormalizedLandmarkList",
"::mediapipe::NormalizedLandmarkListCollection",
"::std::vector<::mediapipe::Landmark>", "::std::vector<::mediapipe::Landmark>",
"::std::vector<::mediapipe::LandmarkList>", "::std::vector<::mediapipe::LandmarkList>",
"::std::vector<::mediapipe::NormalizedLandmark>", "::std::vector<::mediapipe::NormalizedLandmark>",

View File

@ -39,3 +39,8 @@ message Classification {
message ClassificationList { message ClassificationList {
repeated Classification classification = 1; repeated Classification classification = 1;
} }
// Group of ClassificationList protos.
message ClassificationListCollection {
repeated ClassificationList classification_list = 1;
}

View File

@ -47,6 +47,11 @@ message LandmarkList {
repeated Landmark landmark = 1; repeated Landmark landmark = 1;
} }
// Group of LandmarkList protos.
message LandmarkListCollection {
repeated LandmarkList landmark_list = 1;
}
// A normalized version of above Landmark proto. All coordinates should be // A normalized version of above Landmark proto. All coordinates should be
// within [0, 1]. // within [0, 1].
message NormalizedLandmark { message NormalizedLandmark {
@ -61,3 +66,8 @@ message NormalizedLandmark {
message NormalizedLandmarkList { message NormalizedLandmarkList {
repeated NormalizedLandmark landmark = 1; repeated NormalizedLandmark landmark = 1;
} }
// Group of NormalizedLandmarkList protos.
message NormalizedLandmarkListCollection {
repeated NormalizedLandmarkList landmark_list = 1;
}

View File

@ -127,7 +127,7 @@ const proto_ns::MessageLite& Packet::GetProtoMessageLite() const {
} }
StatusOr<std::vector<const proto_ns::MessageLite*>> StatusOr<std::vector<const proto_ns::MessageLite*>>
Packet::GetVectorOfProtoMessageLitePtrs() { Packet::GetVectorOfProtoMessageLitePtrs() const {
if (holder_ == nullptr) { if (holder_ == nullptr) {
return absl::InternalError("Packet is empty."); return absl::InternalError("Packet is empty.");
} }

View File

@ -175,7 +175,7 @@ class Packet {
// Note: This function is meant to be used internally within the MediaPipe // Note: This function is meant to be used internally within the MediaPipe
// framework only. // framework only.
StatusOr<std::vector<const proto_ns::MessageLite*>> StatusOr<std::vector<const proto_ns::MessageLite*>>
GetVectorOfProtoMessageLitePtrs(); GetVectorOfProtoMessageLitePtrs() const;
// Returns an error if the packet does not contain data of type T. // Returns an error if the packet does not contain data of type T.
template <typename T> template <typename T>
@ -391,7 +391,7 @@ class HolderBase {
// underlying object is a vector of protocol buffer objects, otherwise, // underlying object is a vector of protocol buffer objects, otherwise,
// returns an error. // returns an error.
virtual StatusOr<std::vector<const proto_ns::MessageLite*>> virtual StatusOr<std::vector<const proto_ns::MessageLite*>>
GetVectorOfProtoMessageLite() = 0; GetVectorOfProtoMessageLite() const = 0;
private: private:
size_t type_id_; size_t type_id_;
@ -563,7 +563,7 @@ class Holder : public HolderBase {
// underlying object is a vector of protocol buffer objects, otherwise, // underlying object is a vector of protocol buffer objects, otherwise,
// returns an error. // returns an error.
StatusOr<std::vector<const proto_ns::MessageLite*>> StatusOr<std::vector<const proto_ns::MessageLite*>>
GetVectorOfProtoMessageLite() override { GetVectorOfProtoMessageLite() const override {
return ConvertToVectorOfProtoMessageLitePtrs(ptr_, is_proto_vector<T>()); return ConvertToVectorOfProtoMessageLitePtrs(ptr_, is_proto_vector<T>());
} }

View File

@ -370,7 +370,8 @@ absl::Status PacketGeneratorGraph::Initialize(
absl::Status PacketGeneratorGraph::RunGraphSetup( absl::Status PacketGeneratorGraph::RunGraphSetup(
const std::map<std::string, Packet>& input_side_packets, const std::map<std::string, Packet>& input_side_packets,
std::map<std::string, Packet>* output_side_packets) const { std::map<std::string, Packet>* output_side_packets,
std::vector<int>* non_scheduled_generators) const {
*output_side_packets = base_packets_; *output_side_packets = base_packets_;
for (const std::pair<const std::string, Packet>& item : input_side_packets) { for (const std::pair<const std::string, Packet>& item : input_side_packets) {
auto iter = output_side_packets->find(item.first); auto iter = output_side_packets->find(item.first);
@ -380,7 +381,9 @@ absl::Status PacketGeneratorGraph::RunGraphSetup(
} }
output_side_packets->insert(iter, item); output_side_packets->insert(iter, item);
} }
std::vector<int> non_scheduled_generators; std::vector<int> non_scheduled_generators_local;
if (!non_scheduled_generators)
non_scheduled_generators = &non_scheduled_generators_local;
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
validated_graph_->CanAcceptSidePackets(input_side_packets)); validated_graph_->CanAcceptSidePackets(input_side_packets));
@ -389,11 +392,7 @@ absl::Status PacketGeneratorGraph::RunGraphSetup(
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
validated_graph_->ValidateRequiredSidePackets(*output_side_packets)); validated_graph_->ValidateRequiredSidePackets(*output_side_packets));
MP_RETURN_IF_ERROR(ExecuteGenerators( MP_RETURN_IF_ERROR(ExecuteGenerators(
output_side_packets, &non_scheduled_generators, /*initial=*/false)); output_side_packets, non_scheduled_generators, /*initial=*/false));
RET_CHECK(non_scheduled_generators.empty())
<< "Some Generators were unrunnable (validation should have failed).\n"
"Generator indexes: "
<< absl::StrJoin(non_scheduled_generators, ", ");
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -76,7 +76,8 @@ class PacketGeneratorGraph {
// must now be runnable) to produce output_side_packets. // must now be runnable) to produce output_side_packets.
virtual absl::Status RunGraphSetup( virtual absl::Status RunGraphSetup(
const std::map<std::string, Packet>& input_side_packets, const std::map<std::string, Packet>& input_side_packets,
std::map<std::string, Packet>* output_side_packets) const; std::map<std::string, Packet>* output_side_packets,
std::vector<int>* non_scheduled_generators = nullptr) const;
// Get the base packets: the packets which are produced when Initialize // Get the base packets: the packets which are produced when Initialize
// is called. // is called.

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/packet_test.pb.h" #include "mediapipe/framework/packet_test.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
@ -214,6 +215,21 @@ TEST(PacketTest, ValidateAsProtoMessageLite) {
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
} }
TEST(PacketTest, GetVectorOfProtos) {
std::vector<mediapipe::PacketTestProto> protos(2);
protos[0].add_x(123);
protos[1].add_x(456);
// Normally we'd move here, but we copy to use the protos for comparison.
const Packet packet =
MakePacket<std::vector<mediapipe::PacketTestProto>>(protos);
auto maybe_proto_ptrs = packet.GetVectorOfProtoMessageLitePtrs();
EXPECT_THAT(maybe_proto_ptrs,
IsOkAndHolds(testing::Pointwise(EqualsProto(), protos)));
const Packet wrong = MakePacket<int>(1);
EXPECT_THAT(wrong.GetVectorOfProtoMessageLitePtrs(), testing::Not(IsOk()));
}
TEST(PacketTest, SyncedPacket) { TEST(PacketTest, SyncedPacket) {
Packet synced_packet = AdoptAsSyncedPacket(new int(100)); Packet synced_packet = AdoptAsSyncedPacket(new int(100));
Packet value_packet = Packet value_packet =

View File

@ -4,6 +4,7 @@
""".bzl file for mediapipe open source build configs.""" """.bzl file for mediapipe open source build configs."""
load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library") load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library")
load("//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_options_library")
def provided_args(**kwargs): def provided_args(**kwargs):
"""Returns the keyword arguments omitting None arguments.""" """Returns the keyword arguments omitting None arguments."""
@ -47,6 +48,7 @@ def mediapipe_proto_library(
def_objc_proto = True, def_objc_proto = True,
def_java_proto = True, def_java_proto = True,
def_jspb_proto = True, def_jspb_proto = True,
def_options_lib = True,
portable_deps = None): portable_deps = None):
"""Defines the proto_library targets needed for all mediapipe platforms. """Defines the proto_library targets needed for all mediapipe platforms.
@ -67,6 +69,7 @@ def mediapipe_proto_library(
def_objc_proto: define the objc_proto_library target def_objc_proto: define the objc_proto_library target
def_java_proto: define the java_proto_library target def_java_proto: define the java_proto_library target
def_jspb_proto: define the jspb_proto_library target def_jspb_proto: define the jspb_proto_library target
def_options_lib: define the mediapipe_options_library target
""" """
_ignore = [def_portable_proto, def_objc_proto, def_java_proto, def_jspb_proto, portable_deps] _ignore = [def_portable_proto, def_objc_proto, def_java_proto, def_jspb_proto, portable_deps]
@ -116,6 +119,17 @@ def mediapipe_proto_library(
compatible_with = compatible_with, compatible_with = compatible_with,
)) ))
if def_options_lib:
cc_deps = replace_deps(deps, "_proto", "_cc_proto")
mediapipe_options_library(**provided_args(
name = replace_suffix(name, "_proto", "_options_lib"),
proto_lib = name,
deps = cc_deps,
visibility = visibility,
testonly = testonly,
compatible_with = compatible_with,
))
def mediapipe_py_proto_library( def mediapipe_py_proto_library(
name, name,
srcs, srcs,

View File

@ -113,6 +113,7 @@ cc_library(
"//mediapipe/framework/port:advanced_proto_lite", "//mediapipe/framework/port:advanced_proto_lite",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:re2",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:name_util", "//mediapipe/framework/tool:name_util",

View File

@ -24,6 +24,7 @@
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/re2.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/profiler/profiler_resource_util.h" #include "mediapipe/framework/profiler/profiler_resource_util.h"
@ -71,6 +72,8 @@ bool IsTracerEnabled(const ProfilerConfig& profiler_config) {
} }
// Returns true if trace events are written to a log file. // Returns true if trace events are written to a log file.
// Note that for now, file output is only for graph-trace and not for
// calculator-profile.
bool IsTraceLogEnabled(const ProfilerConfig& profiler_config) { bool IsTraceLogEnabled(const ProfilerConfig& profiler_config) {
return IsTracerEnabled(profiler_config) && return IsTracerEnabled(profiler_config) &&
!profiler_config.trace_log_disabled(); !profiler_config.trace_log_disabled();
@ -117,6 +120,39 @@ PacketInfo* GetPacketInfo(PacketInfoMap* map, const PacketId& packet_id) {
} // namespace } // namespace
// Builds GraphProfile records from profiler timing data.
class GraphProfiler::GraphProfileBuilder {
public:
GraphProfileBuilder(GraphProfiler* profiler)
: profiler_(profiler), calculator_regex_(".*") {
auto& filter = profiler_->profiler_config().calculator_filter();
calculator_regex_ = filter.empty() ? calculator_regex_ : RE2(filter);
}
bool ProfileIncluded(const CalculatorProfile& p) {
return RE2::FullMatch(p.name(), calculator_regex_);
}
private:
GraphProfiler* profiler_;
RE2 calculator_regex_;
};
GraphProfiler::GraphProfiler()
: is_initialized_(false),
is_profiling_(false),
calculator_profiles_(1000),
packets_info_(1000),
is_running_(false),
previous_log_end_time_(absl::InfinitePast()),
previous_log_index_(-1),
validated_graph_(nullptr) {
clock_ = std::shared_ptr<mediapipe::Clock>(
mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
}
GraphProfiler::~GraphProfiler() {}
void GraphProfiler::Initialize( void GraphProfiler::Initialize(
const ValidatedGraphConfig& validated_graph_config) { const ValidatedGraphConfig& validated_graph_config) {
absl::WriterMutexLock lock(&profiler_mutex_); absl::WriterMutexLock lock(&profiler_mutex_);
@ -156,6 +192,7 @@ void GraphProfiler::Initialize(
CHECK(iter.second) << absl::Substitute( CHECK(iter.second) << absl::Substitute(
"Calculator \"$0\" has already been added.", node_name); "Calculator \"$0\" has already been added.", node_name);
} }
profile_builder_ = std::make_unique<GraphProfileBuilder>(this);
is_initialized_ = true; is_initialized_ = true;
} }
@ -554,17 +591,45 @@ class OstreamStream : public proto_ns::io::ZeroCopyOutputStream {
}; };
// Sets the canonical node name in each CalculatorGraphConfig::Node // Sets the canonical node name in each CalculatorGraphConfig::Node
// and also in GraphTrace. // and also in the GraphTrace if present.
void AssignNodeNames(GraphProfile* profile) { void AssignNodeNames(GraphProfile* profile) {
CalculatorGraphConfig* graph_config = profile->mutable_config(); CalculatorGraphConfig* graph_config = profile->mutable_config();
GraphTrace* graph_trace = profile->mutable_graph_trace(0); GraphTrace* graph_trace = profile->graph_trace_size() > 0
? profile->mutable_graph_trace(0)
: nullptr;
if (graph_trace) {
graph_trace->clear_calculator_name(); graph_trace->clear_calculator_name();
}
for (int i = 0; i < graph_config->node().size(); ++i) { for (int i = 0; i < graph_config->node().size(); ++i) {
std::string node_name = CanonicalNodeName(*graph_config, i); std::string node_name = CanonicalNodeName(*graph_config, i);
graph_config->mutable_node(i)->set_name(node_name); graph_config->mutable_node(i)->set_name(node_name);
if (graph_trace) {
graph_trace->add_calculator_name(node_name); graph_trace->add_calculator_name(node_name);
} }
} }
}
// Clears fields containing their default values.
void CleanTimeHistogram(TimeHistogram* histogram) {
if (histogram->num_intervals() == 1) {
histogram->clear_num_intervals();
}
if (histogram->interval_size_usec() == 1000000) {
histogram->clear_interval_size_usec();
}
}
// Clears fields containing their default values.
void CleanCalculatorProfiles(GraphProfile* profile) {
for (CalculatorProfile& p : *profile->mutable_calculator_profiles()) {
CleanTimeHistogram(p.mutable_process_runtime());
CleanTimeHistogram(p.mutable_process_input_latency());
CleanTimeHistogram(p.mutable_process_output_latency());
for (StreamProfile& s : *p.mutable_input_stream_profiles()) {
CleanTimeHistogram(s.mutable_latency());
}
}
}
absl::StatusOr<std::string> GraphProfiler::GetTraceLogPath() { absl::StatusOr<std::string> GraphProfiler::GetTraceLogPath() {
if (!IsTraceLogEnabled(profiler_config_)) { if (!IsTraceLogEnabled(profiler_config_)) {
@ -588,12 +653,14 @@ absl::Status GraphProfiler::CaptureProfile(GraphProfile* result) {
absl::Time end_time = absl::Time end_time =
clock_->TimeNow() - clock_->TimeNow() -
absl::Microseconds(profiler_config_.trace_log_margin_usec()); absl::Microseconds(profiler_config_.trace_log_margin_usec());
if (tracer()) {
GraphTrace* trace = result->add_graph_trace(); GraphTrace* trace = result->add_graph_trace();
if (!profiler_config_.trace_log_instant_events()) { if (!profiler_config_.trace_log_instant_events()) {
tracer()->GetTrace(previous_log_end_time_, end_time, trace); tracer()->GetTrace(previous_log_end_time_, end_time, trace);
} else { } else {
tracer()->GetLog(previous_log_end_time_, end_time, trace); tracer()->GetLog(previous_log_end_time_, end_time, trace);
} }
}
previous_log_end_time_ = end_time; previous_log_end_time_ = end_time;
// Record the latest CalculatorProfiles. // Record the latest CalculatorProfiles.
@ -601,9 +668,12 @@ absl::Status GraphProfiler::CaptureProfile(GraphProfile* result) {
std::vector<CalculatorProfile> profiles; std::vector<CalculatorProfile> profiles;
status.Update(GetCalculatorProfiles(&profiles)); status.Update(GetCalculatorProfiles(&profiles));
for (CalculatorProfile& p : profiles) { for (CalculatorProfile& p : profiles) {
if (profile_builder_->ProfileIncluded(p)) {
*result->mutable_calculator_profiles()->Add() = std::move(p); *result->mutable_calculator_profiles()->Add() = std::move(p);
} }
}
this->Reset(); this->Reset();
CleanCalculatorProfiles(result);
return status; return status;
} }

View File

@ -97,18 +97,8 @@ class GraphProfilerTestPeer;
// The client can overwrite this by calling SetClock(). // The client can overwrite this by calling SetClock().
class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> { class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
public: public:
GraphProfiler() GraphProfiler();
: is_initialized_(false), ~GraphProfiler();
is_profiling_(false),
calculator_profiles_(1000),
packets_info_(1000),
is_running_(false),
previous_log_end_time_(absl::InfinitePast()),
previous_log_index_(-1),
validated_graph_(nullptr) {
clock_ = std::shared_ptr<mediapipe::Clock>(
mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
}
// Not copyable or movable. // Not copyable or movable.
GraphProfiler(const GraphProfiler&) = delete; GraphProfiler(const GraphProfiler&) = delete;
@ -230,6 +220,8 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
int64 start_time_usec_; int64 start_time_usec_;
}; };
const ProfilerConfig& profiler_config() { return profiler_config_; }
private: private:
// This can be used to add packet info for the input streams to the graph. // This can be used to add packet info for the input streams to the graph.
// It treats the stream defined by |stream_name| as a stream produced by a // It treats the stream defined by |stream_name| as a stream produced by a
@ -303,6 +295,7 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
// Helper method to get the clock time in microsecond. // Helper method to get the clock time in microsecond.
int64 TimeNowUsec() { return ToUnixMicros(clock_->TimeNow()); } int64 TimeNowUsec() { return ToUnixMicros(clock_->TimeNow()); }
private:
// The settings for this tracer. // The settings for this tracer.
ProfilerConfig profiler_config_; ProfilerConfig profiler_config_;
@ -345,6 +338,10 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
// The configuration for the graph being profiled. // The configuration for the graph being profiled.
const ValidatedGraphConfig* validated_graph_; const ValidatedGraphConfig* validated_graph_;
// A private resource for creating GraphProfiles.
class GraphProfileBuilder;
std::unique_ptr<GraphProfileBuilder> profile_builder_;
// For testing. // For testing.
friend GraphProfilerTestPeer; friend GraphProfilerTestPeer;
}; };

View File

@ -1205,5 +1205,68 @@ TEST(GraphProfilerTest, ParallelReads) {
EXPECT_EQ(1001, out_1_packets.size()); EXPECT_EQ(1001, out_1_packets.size());
} }
// Returns the set of calculator names in a GraphProfile captured from
// CalculatorGraph initialized from a certain CalculatorGraphConfig.
std::set<std::string> GetCalculatorNames(const CalculatorGraphConfig& config) {
std::set<std::string> result;
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config));
GraphProfile profile;
MP_EXPECT_OK(graph.profiler()->CaptureProfile(&profile));
for (auto& p : profile.calculator_profiles()) {
result.insert(p.name());
}
return result;
}
TEST(GraphProfilerTest, CalculatorProfileFilter) {
CalculatorGraphConfig config;
QCHECK(proto2::TextFormat::ParseFromString(R"(
profiler_config {
enable_profiler: true
}
node {
calculator: "RangeCalculator"
input_side_packet: "range_step"
output_stream: "out"
output_stream: "sum"
output_stream: "mean"
}
node {
calculator: "PassThroughCalculator"
input_stream: "out"
input_stream: "sum"
input_stream: "mean"
output_stream: "out_1"
output_stream: "sum_1"
output_stream: "mean_1"
}
output_stream: "OUT:0:the_integers"
)",
&config));
std::set<std::string> expected_names;
expected_names = {"RangeCalculator", "PassThroughCalculator"};
EXPECT_EQ(GetCalculatorNames(config), expected_names);
*config.mutable_profiler_config()->mutable_calculator_filter() =
"RangeCalculator";
expected_names = {"RangeCalculator"};
EXPECT_EQ(GetCalculatorNames(config), expected_names);
*config.mutable_profiler_config()->mutable_calculator_filter() = "Range.*";
expected_names = {"RangeCalculator"};
EXPECT_EQ(GetCalculatorNames(config), expected_names);
*config.mutable_profiler_config()->mutable_calculator_filter() =
".*Calculator";
expected_names = {"RangeCalculator", "PassThroughCalculator"};
EXPECT_EQ(GetCalculatorNames(config), expected_names);
*config.mutable_profiler_config()->mutable_calculator_filter() = ".*Clock.*";
expected_names = {};
EXPECT_EQ(GetCalculatorNames(config), expected_names);
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,63 +13,40 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
licenses(["notice"]) licenses(["notice"])
package(default_visibility = ["//visibility:private"]) package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") mediapipe_proto_library(
proto_library(
name = "sky_light_calculator_proto", name = "sky_light_calculator_proto",
srcs = ["sky_light_calculator.proto"], srcs = ["sky_light_calculator.proto"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "sky_light_calculator_cc_proto",
srcs = ["sky_light_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":sky_light_calculator_proto"],
) )
proto_library( mediapipe_proto_library(
name = "night_light_calculator_proto", name = "night_light_calculator_proto",
srcs = ["night_light_calculator.proto"], srcs = ["night_light_calculator.proto"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "night_light_calculator_cc_proto",
srcs = ["night_light_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":night_light_calculator_proto"], deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
) )
proto_library( mediapipe_proto_library(
name = "zoo_mutator_proto", name = "zoo_mutator_proto",
srcs = ["zoo_mutator.proto"], srcs = ["zoo_mutator.proto"],
deps = ["@com_google_protobuf//:any_proto"], deps = ["@com_google_protobuf//:any_proto"],
) )
mediapipe_cc_proto_library( mediapipe_proto_library(
name = "zoo_mutator_cc_proto",
srcs = ["zoo_mutator.proto"],
cc_deps = ["@com_google_protobuf//:cc_wkt_protos"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":zoo_mutator_proto"],
)
proto_library(
name = "zoo_mutation_calculator_proto", name = "zoo_mutation_calculator_proto",
srcs = ["zoo_mutation_calculator.proto"], srcs = ["zoo_mutation_calculator.proto"],
features = ["-proto_dynamic_mode_static_link"],
visibility = ["//mediapipe:__subpackages__"], visibility = ["//mediapipe:__subpackages__"],
deps = [ deps = [
":zoo_mutator_proto", ":zoo_mutator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:packet_factory_proto",
"//mediapipe/framework:packet_generator_proto", "//mediapipe/framework:packet_generator_proto",
], ],
) )

Some files were not shown because too many files have changed in this diff Show More