Project import generated by Copybara.

GitOrigin-RevId: 27c70b5fe62ab71189d358ca122ee4b19c817a8f
This commit is contained in:
MediaPipe Team 2021-07-23 17:09:32 -07:00 committed by chuoling
parent 374f5e2e7e
commit 50c92c6623
158 changed files with 4704 additions and 621 deletions

View File

@ -45,7 +45,7 @@ Hair Segmentation
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |
@ -79,6 +79,13 @@ run code search using
## Publications
* [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html)
in Google Developers Blog
* [Prosthesis control via Mirru App using MediaPipe hand tracking](https://developers.googleblog.com/2021/05/control-your-mirru-prosthesis-with-mediapipe-hand-tracking.html)
in Google Developers Blog
* [SignAll SDK: Sign language interface using MediaPipe is now available for
developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html)
in Google Developers Blog
* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
in Google AI Blog
* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html)

View File

@ -53,19 +53,12 @@ rules_foreign_cc_dependencies()
all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])"""
# GoogleTest/GoogleMock framework. Used by most unit-tests.
# Last updated 2020-06-30.
# Last updated 2021-07-02.
http_archive(
name = "com_google_googletest",
urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"],
patches = [
# fix for https://github.com/google/googletest/issues/2817
"@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff"
],
patch_args = [
"-p1",
],
strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e",
sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895",
urls = ["https://github.com/google/googletest/archive/4ec4cd23f486bf70efcc5d2caa40f24368f752e3.zip"],
strip_prefix = "googletest-4ec4cd23f486bf70efcc5d2caa40f24368f752e3",
sha256 = "de682ea824bfffba05b4e33b67431c247397d6175962534305136aa06f92e049",
)
# Google Benchmark library.
@ -353,9 +346,9 @@ maven_install(
"com.google.android.material:material:aar:1.0.0-rc01",
"com.google.auto.value:auto-value:1.8.1",
"com.google.auto.value:auto-value-annotations:1.8.1",
"com.google.code.findbugs:jsr305:3.0.2",
"com.google.flogger:flogger-system-backend:0.3.1",
"com.google.flogger:flogger:0.3.1",
"com.google.code.findbugs:jsr305:latest.release",
"com.google.flogger:flogger-system-backend:latest.release",
"com.google.flogger:flogger:latest.release",
"com.google.guava:guava:27.0.1-android",
"com.google.guava:listenablefuture:1.0",
"junit:junit:4.12",

View File

@ -113,9 +113,9 @@ each project.
androidTestImplementation 'androidx.test.ext:junit:1.1.0'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
// MediaPipe deps
implementation 'com.google.flogger:flogger:0.3.1'
implementation 'com.google.flogger:flogger-system-backend:0.3.1'
implementation 'com.google.code.findbugs:jsr305:3.0.2'
implementation 'com.google.flogger:flogger:latest.release'
implementation 'com.google.flogger:flogger-system-backend:latest.release'
implementation 'com.google.code.findbugs:jsr305:latest.release'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4'
// CameraX core library

View File

@ -31,8 +31,8 @@ stream on an Android device.
## Setup
1. Install MediaPipe on your system, see [MediaPipe installation guide] for
details.
1. Install MediaPipe on your system, see
[MediaPipe installation guide](./install.md) for details.
2. Install Android Development SDK and Android NDK. See how to do so also in
[MediaPipe installation guide].
3. Enable [developer options] on your Android device.
@ -770,7 +770,6 @@ If you ran into any issues, please see the full code of the tutorial
[`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java
[`FrameLayout`]:https://developer.android.com/reference/android/widget/FrameLayout
[`FrameProcessor`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java
[MediaPipe installation guide]:./install.md
[`PermissionHelper`]: https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java
[`SurfaceHolder.Callback`]:https://developer.android.com/reference/android/view/SurfaceHolder.Callback.html
[`SurfaceView`]:https://developer.android.com/reference/android/view/SurfaceView

View File

@ -31,8 +31,8 @@ stream on an iOS device.
## Setup
1. Install MediaPipe on your system, see [MediaPipe installation guide] for
details.
1. Install MediaPipe on your system, see
[MediaPipe installation guide](./install.md) for details.
2. Setup your iOS device for development.
3. Setup [Bazel] on your system to build and deploy the iOS app.
@ -560,6 +560,5 @@ appropriate `BUILD` file dependencies for the edge detection graph.
[Bazel]:https://bazel.build/
[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt
[MediaPipe installation guide]:./install.md
[common]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common)
[helloworld]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld)

View File

@ -22,6 +22,7 @@ Solution | NPM Package | Example
[Face Detection][Fd-pg] | [@mediapipe/face_detection][Fd-npm] | [mediapipe.dev/demo/face_detection][Fd-demo]
[Hands][H-pg] | [@mediapipe/hands][H-npm] | [mediapipe.dev/demo/hands][H-demo]
[Holistic][Ho-pg] | [@mediapipe/holistic][Ho-npm] | [mediapipe.dev/demo/holistic][Ho-demo]
[Objectron][Ob-pg] | [@mediapipe/objectron][Ob-npm] | [mediapipe.dev/demo/objectron][Ob-demo]
[Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo]
[Selfie Segmentation][S-pg] | [@mediapipe/selfie_segmentation][S-npm] | [mediapipe.dev/demo/selfie_segmentation][S-demo]
@ -67,33 +68,24 @@ affecting your work, restrict your request to a `<minor>` number. e.g.,
[F-pg]: ../solutions/face_mesh#javascript-solution-api
[Fd-pg]: ../solutions/face_detection#javascript-solution-api
[H-pg]: ../solutions/hands#javascript-solution-api
[Ob-pg]: ../solutions/objectron#javascript-solution-api
[P-pg]: ../solutions/pose#javascript-solution-api
[S-pg]: ../solutions/selfie_segmentation#javascript-solution-api
[Ho-npm]: https://www.npmjs.com/package/@mediapipe/holistic
[F-npm]: https://www.npmjs.com/package/@mediapipe/face_mesh
[Fd-npm]: https://www.npmjs.com/package/@mediapipe/face_detection
[H-npm]: https://www.npmjs.com/package/@mediapipe/hands
[Ob-npm]: https://www.npmjs.com/package/@mediapipe/objectron
[P-npm]: https://www.npmjs.com/package/@mediapipe/pose
[S-npm]: https://www.npmjs.com/package/@mediapipe/selfie_segmentation
[draw-npm]: https://www.npmjs.com/package/@mediapipe/drawing_utils
[cam-npm]: https://www.npmjs.com/package/@mediapipe/camera_utils
[ctrl-npm]: https://www.npmjs.com/package/@mediapipe/control_utils
[Ho-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/holistic
[F-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_mesh
[Fd-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_detection
[H-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/hands
[P-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/pose
[S-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/selfie_segmentation
[Ho-pen]: https://code.mediapipe.dev/codepen/holistic
[F-pen]: https://code.mediapipe.dev/codepen/face_mesh
[Fd-pen]: https://code.mediapipe.dev/codepen/face_detection
[H-pen]: https://code.mediapipe.dev/codepen/hands
[P-pen]: https://code.mediapipe.dev/codepen/pose
[S-pen]: https://code.mediapipe.dev/codepen/selfie_segmentation
[Ho-demo]: https://mediapipe.dev/demo/holistic
[F-demo]: https://mediapipe.dev/demo/face_mesh
[Fd-demo]: https://mediapipe.dev/demo/face_detection
[H-demo]: https://mediapipe.dev/demo/hands
[Ob-demo]: https://mediapipe.dev/demo/objectron
[P-demo]: https://mediapipe.dev/demo/pose
[S-demo]: https://mediapipe.dev/demo/selfie_segmentation
[npm]: https://www.npmjs.com/package/@mediapipe

View File

@ -74,7 +74,7 @@ Mapping\[str, Packet\] | std::map<std::string, Packet> | create_st
np.ndarray<br>(cv.mat and PIL.Image) | mp::ImageFrame | create_image_frame(<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;format=ImageFormat.SRGB,<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;data=mat) | get_image_frame(packet)
np.ndarray | mp::Matrix | create_matrix(data) | get_matrix(packet)
Google Proto Message | Google Proto Message | create_proto(proto) | get_proto(packet)
List\[Proto\] | std::vector\<Proto\> | create_proto_vector(proto_list) | get_proto_list(packet)
List\[Proto\] | std::vector\<Proto\> | n/a | get_proto_list(packet)
It's not uncommon that users create custom C++ classes and send those into
the graphs and calculators. To allow the custom classes to be used in Python

View File

@ -45,7 +45,7 @@ Hair Segmentation
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |
@ -79,6 +79,13 @@ run code search using
## Publications
* [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html)
in Google Developers Blog
* [Prosthesis control via Mirru App using MediaPipe hand tracking](https://developers.googleblog.com/2021/05/control-your-mirru-prosthesis-with-mediapipe-hand-tracking.html)
in Google Developers Blog
* [SignAll SDK: Sign language interface using MediaPipe is now available for
developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html)
in Google Developers Blog
* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
in Google AI Blog
* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html)

View File

@ -224,29 +224,33 @@ where object detection simply runs on every image. Default to `0.99`.
#### model_name
Name of the model to use for predicting 3D bounding box landmarks. Currently supports
`{'Shoe', 'Chair', 'Cup', 'Camera'}`.
Name of the model to use for predicting 3D bounding box landmarks. Currently
supports `{'Shoe', 'Chair', 'Cup', 'Camera'}`. Default to `Shoe`.
#### focal_length
Camera focal length `(fx, fy)`, by default is defined in
[NDC space](#ndc-space). To use focal length `(fx_pixel, fy_pixel)` in
[pixel space](#pixel-space), users should provide `image_size` = `(image_width,
image_height)` to enable conversions inside the API. For further details about
NDC and pixel space, please see [Coordinate Systems](#coordinate-systems).
By default, camera focal length defined in [NDC space](#ndc-space), i.e., `(fx,
fy)`. Default to `(1.0, 1.0)`. To specify focal length in
[pixel space](#pixel-space) instead, i.e., `(fx_pixel, fy_pixel)`, users should
provide [`image_size`](#image_size) = `(image_width, image_height)` to enable
conversions inside the API. For further details about NDC and pixel space,
please see [Coordinate Systems](#coordinate-systems).
#### principal_point
Camera principal point `(px, py)`, by default is defined in
[NDC space](#ndc-space). To use principal point `(px_pixel, py_pixel)` in
[pixel space](#pixel-space), users should provide `image_size` = `(image_width,
image_height)` to enable conversions inside the API. For further details about
NDC and pixel space, please see [Coordinate Systems](#coordinate-systems).
By default, camera principal point defined in [NDC space](#ndc-space), i.e.,
`(px, py)`. Default to `(0.0, 0.0)`. To specify principal point in
[pixel space](#pixel-space), i.e., `(px_pixel, py_pixel)`, users should provide
[`image_size`](#image_size) = `(image_width, image_height)` to enable
conversions inside the API. For further details about NDC and pixel space,
please see [Coordinate Systems](#coordinate-systems).
#### image_size
(**Optional**) size `(image_width, image_height)` of the input image, **ONLY**
needed when use `focal_length` and `principal_point` in pixel space.
**Specify only when [`focal_length`](#focal_length) and
[`principal_point`](#principal_point) are specified in pixel space.**
Size of the input image, i.e., `(image_width, image_height)`.
### Output
@ -356,6 +360,89 @@ with mp_objectron.Objectron(static_image_mode=False,
cap.release()
```
## JavaScript Solution API
Please first see general [introduction](../getting_started/javascript.md) on
MediaPipe in JavaScript, then learn more in the companion [web demo](#resources)
and the following usage example.
Supported configuration options:
* [staticImageMode](#static_image_mode)
* [maxNumObjects](#max_num_objects)
* [minDetectionConfidence](#min_detection_confidence)
* [minTrackingConfidence](#min_tracking_confidence)
* [modelName](#model_name)
* [focalLength](#focal_length)
* [principalPoint](#principal_point)
* [imageSize](#image_size)
```html
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/camera_utils/camera_utils.js" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/control_utils/control_utils.js" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/drawing_utils/control_utils_3d.js" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/drawing_utils/drawing_utils.js" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/objectron/objectron.js" crossorigin="anonymous"></script>
</head>
<body>
<div class="container">
<video class="input_video"></video>
<canvas class="output_canvas" width="1280px" height="720px"></canvas>
</div>
</body>
</html>
```
```javascript
<script type="module">
const videoElement = document.getElementsByClassName('input_video')[0];
const canvasElement = document.getElementsByClassName('output_canvas')[0];
const canvasCtx = canvasElement.getContext('2d');
function onResults(results) {
canvasCtx.save();
canvasCtx.drawImage(
results.image, 0, 0, canvasElement.width, canvasElement.height);
if (!!results.objectDetections) {
for (const detectedObject of results.objectDetections) {
// Reformat keypoint information as landmarks, for easy drawing.
const landmarks = detectedObject.keypoints.map(x => x.point2d);
// Draw bounding box.
drawingUtils.drawConnectors(canvasCtx, landmarks,
mpObjectron.BOX_CONNECTIONS, {color: '#FF0000'});
// Draw centroid.
drawingUtils.drawLandmarks(canvasCtx, [landmarks[0]], {color: '#FFFFFF'});
}
}
canvasCtx.restore();
}
const objectron = new Objectron({locateFile: (file) => {
return `https://cdn.jsdelivr.net/npm/@mediapipe/objectron/${file}`;
}});
objectron.setOptions({
modelName: 'Chair',
maxNumObjects: 3,
});
objectron.onResults(onResults);
const camera = new Camera(videoElement, {
onFrame: async () => {
await objectron.send({image: videoElement});
},
width: 1280,
height: 720
});
camera.start();
</script>
```
## Example Apps
Please first see general instructions for
@ -561,11 +648,15 @@ py = -py_pixel * 2.0 / image_height + 1.0
[Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html)
* Google AI Blog:
[Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in CVPR 2021
* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the
Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in
CVPR 2021
* Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
Shape Supervision](https://arxiv.org/abs/2003.03522)
* Paper:
[Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth Workshop on Computer Vision for AR/VR, CVPR 2020
([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth
Workshop on Computer Vision for AR/VR, CVPR 2020
* [Models and model cards](./models.md#objectron)
* [Web demo](https://code.mediapipe.dev/codepen/objectron)
* [Python Colab](https://mediapipe.page.link/objectron_py_colab)

View File

@ -96,6 +96,7 @@ Supported configuration options:
```python
import cv2
import mediapipe as mp
import numpy as np
mp_drawing = mp.solutions.drawing_utils
mp_selfie_segmentation = mp.solutions.selfie_segmentation

View File

@ -29,7 +29,7 @@ has_toc: false
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |

View File

@ -140,6 +140,16 @@ mediapipe_proto_library(
],
)
# Proto options for GraphProfileCalculator (profile_interval etc.).
mediapipe_proto_library(
name = "graph_profile_calculator_proto",
srcs = ["graph_profile_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "add_header_calculator",
srcs = ["add_header_calculator.cc"],
@ -1200,3 +1210,45 @@ cc_test(
"@com_google_absl//absl/strings",
],
)
# Calculator that periodically emits the graph's GraphProfile on a
# "PROFILE" output stream.  alwayslink keeps the registration symbol alive.
cc_library(
name = "graph_profile_calculator",
srcs = ["graph_profile_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":graph_profile_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
# Unit test for GraphProfileCalculator; uses the simulation clock executor
# for deterministic per-node latencies.
cc_test(
name = "graph_profile_calculator_test",
srcs = ["graph_profile_calculator_test.cc"],
deps = [
":graph_profile_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:test_calculators",
"//mediapipe/framework/deps:clock",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:threadpool",
"//mediapipe/framework/tool:simulation_clock_executor",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
],
)

View File

@ -0,0 +1,70 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/calculators/core/graph_profile_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// This calculator periodically copies the GraphProfile from
// mediapipe::GraphProfiler::CaptureProfile to the "PROFILE" output stream.
//
// Example config:
// node {
// calculator: "GraphProfileCalculator"
// output_stream: "FRAME:any_frame"
// output_stream: "PROFILE:graph_profile"
// }
//
class GraphProfileCalculator : public Node {
 public:
  // Any number of input streams tagged "FRAME"; their packets are used only
  // as a clock to decide when the next profile snapshot is due.
  static constexpr Input<AnyType>::Multiple kFrameIn{"FRAME"};
  // Emits the GraphProfile captured from the graph's profiling context.
  static constexpr Output<GraphProfile> kProfileOut{"PROFILE"};

  MEDIAPIPE_NODE_CONTRACT(kFrameIn, kProfileOut);

  static absl::Status UpdateContract(CalculatorContract* cc) {
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    // Options() returns a const reference; binding by reference avoids
    // copying the options proto on every Process() call.
    const auto& options =
        cc->Options<::mediapipe::GraphProfileCalculatorOptions>();

    // Emit a profile for the first packet seen, and afterwards at most once
    // per profile_interval (measured in timestamp units, i.e. microseconds).
    if (prev_profile_ts_ == Timestamp::Unset() ||
        cc->InputTimestamp() - prev_profile_ts_ >= options.profile_interval()) {
      prev_profile_ts_ = cc->InputTimestamp();
      GraphProfile result;
      MP_RETURN_IF_ERROR(cc->GetProfilingContext()->CaptureProfile(&result));
      kProfileOut(cc).Send(result);
    }
    return absl::OkStatus();
  }

 private:
  // Timestamp of the most recently emitted profile; explicitly Unset() until
  // the first profile has been sent.
  Timestamp prev_profile_ts_ = Timestamp::Unset();
};

MEDIAPIPE_REGISTER_NODE(GraphProfileCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,30 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options for GraphProfileCalculator, which periodically copies the graph's
// GraphProfile to its "PROFILE" output stream.
message GraphProfileCalculatorOptions {
// Registers these options as an extension of the generic CalculatorOptions.
extend mediapipe.CalculatorOptions {
optional GraphProfileCalculatorOptions ext = 367481815;
}
// The interval in microseconds between successive reported GraphProfiles.
optional int64 profile_interval = 1 [default = 1000000];
}

View File

@ -0,0 +1,207 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/port/threadpool.h"
#include "mediapipe/framework/tool/simulation_clock_executor.h"
// Tests for GraphProfileCalculator.
using testing::ElementsAre;
namespace mediapipe {
namespace {
using mediapipe::Clock;
// A Calculator with a fixed 5 ms Process-call latency, slept on the simulated
// clock so tests get deterministic per-node runtimes.
class SleepCalculator : public CalculatorBase {
public:
// One pass-through stream of any type, plus a "CLOCK" side packet carrying
// the shared (simulated) clock to sleep on.  TimestampOffset(0) propagates
// input timestamps to the output unchanged.
static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->SetTimestampOffset(TimestampDiff(0));
return absl::OkStatus();
}
// Captures the clock from the side packet for use in Process().
absl::Status Open(CalculatorContext* cc) final {
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
return absl::OkStatus();
}
// Sleeps 5 ms on the (simulated) clock, then forwards the input packet.
absl::Status Process(CalculatorContext* cc) final {
clock_->Sleep(absl::Milliseconds(5));
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
return absl::OkStatus();
}
// Clock shared with the test harness; set in Open().
std::shared_ptr<::mediapipe::Clock> clock_ = nullptr;
};
REGISTER_CALCULATOR(SleepCalculator);
// Tests showing GraphProfileCalculator reporting GraphProfile output packets.
class GraphProfileCalculatorTest : public ::testing::Test {
protected:
// Builds a two-node graph: SleepCalculator (5 ms latency) feeding a
// GraphProfileCalculator configured to emit at most one profile per 25 ms
// of packet-timestamp progress.
void SetUpProfileGraph() {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
node {
calculator: 'SleepCalculator'
input_side_packet: 'CLOCK:sync_clock'
input_stream: 'input_packets_0'
output_stream: 'output_packets_1'
}
node {
calculator: "GraphProfileCalculator"
options: {
[mediapipe.GraphProfileCalculatorOptions.ext]: {
profile_interval: 25000
}
}
input_stream: "FRAME:output_packets_1"
output_stream: "PROFILE:output_packets_0"
}
)",
&graph_config_));
}
// Convenience: an int64 packet at the given timestamp.
static Packet PacketAt(int64 ts) {
return Adopt(new int64(999)).At(Timestamp(ts));
}
// Sentinel packet representing "no input this round"; detected via its
// OneOverPostStream timestamp.
static Packet None() { return Packet().At(Timestamp::OneOverPostStream()); }
static bool IsNone(const Packet& packet) {
return packet.Timestamp() == Timestamp::OneOverPostStream();
}
// Return the values of the timestamps of a vector of Packets.
static std::vector<int64> TimestampValues(
const std::vector<Packet>& packets) {
std::vector<int64> result;
for (const Packet& p : packets) {
result.push_back(p.Timestamp().Value());
}
return result;
}
// Runs a CalculatorGraph with a series of packet sets.
// Returns a vector of packets from each graph output stream.
void RunGraph(const std::vector<std::vector<Packet>>& input_sets,
std::vector<Packet>* output_packets) {
// Register output packet observers.
tool::AddVectorSink("output_packets_0", &graph_config_, output_packets);
// Start running the graph.
std::shared_ptr<SimulationClockExecutor> executor(
new SimulationClockExecutor(3 /*num_threads*/));
CalculatorGraph graph;
MP_ASSERT_OK(graph.SetExecutor("", executor));
// The profiler must use the same simulated clock as the executor so the
// expected runtimes/latencies in the test are deterministic.
graph.profiler()->SetClock(executor->GetClock());
MP_ASSERT_OK(graph.Initialize(graph_config_));
executor->GetClock()->ThreadStart();
MP_ASSERT_OK(graph.StartRun({
{"sync_clock",
Adopt(new std::shared_ptr<::mediapipe::Clock>(executor->GetClock()))},
}));
// Send each packet to the graph in the specified order.
for (int t = 0; t < input_sets.size(); t++) {
const std::vector<Packet>& input_set = input_sets[t];
for (int i = 0; i < input_set.size(); i++) {
const Packet& packet = input_set[i];
if (!IsNone(packet)) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
absl::StrCat("input_packets_", i), packet));
}
// Advance simulated time 10 ms between packets.
executor->GetClock()->Sleep(absl::Milliseconds(10));
}
}
MP_ASSERT_OK(graph.CloseAllInputStreams());
// Let in-flight work drain on the simulated clock before shutdown.
executor->GetClock()->Sleep(absl::Milliseconds(100));
executor->GetClock()->ThreadFinish();
MP_ASSERT_OK(graph.WaitUntilDone());
}
CalculatorGraphConfig graph_config_;
};
// End-to-end check: with profile_interval=25000, packets at ts 10000..40000
// should trigger profile emissions at ts 10000 (first packet) and 40000
// (first packet >= 25000 units later), and the captured profile should match
// the golden proto below exactly.
TEST_F(GraphProfileCalculatorTest, GraphProfile) {
SetUpProfileGraph();
auto profiler_config = graph_config_.mutable_profiler_config();
profiler_config->set_enable_profiler(true);
profiler_config->set_trace_enabled(false);
profiler_config->set_trace_log_disabled(true);
profiler_config->set_enable_stream_latency(true);
profiler_config->set_calculator_filter(".*Calculator");
// Run the graph with a series of packet sets.
std::vector<std::vector<Packet>> input_sets = {
{PacketAt(10000)}, //
{PacketAt(20000)}, //
{PacketAt(30000)}, //
{PacketAt(40000)},
};
std::vector<Packet> output_packets;
RunGraph(input_sets, &output_packets);
// Validate the output packets.
EXPECT_THAT(TimestampValues(output_packets), //
ElementsAre(10000, 40000));
// Golden profile: SleepCalculator accounts for 3 x 5 ms = 15000 us of
// process runtime, which shows up as input latency on the downstream
// GraphProfileCalculator.  NOTE(review): counts are 3 (not 4) because the
// profile is captured while the 4th packet is still being processed.
GraphProfile expected_profile =
mediapipe::ParseTextProtoOrDie<GraphProfile>(R"pb(
calculator_profiles {
name: "GraphProfileCalculator"
open_runtime: 0
process_runtime { total: 0 count: 3 }
process_input_latency { total: 15000 count: 3 }
process_output_latency { total: 15000 count: 3 }
input_stream_profiles {
name: "output_packets_1"
back_edge: false
latency { total: 0 count: 3 }
}
}
calculator_profiles {
name: "SleepCalculator"
open_runtime: 0
process_runtime { total: 15000 count: 3 }
process_input_latency { total: 0 count: 3 }
process_output_latency { total: 15000 count: 3 }
input_stream_profiles {
name: "input_packets_0"
back_edge: false
latency { total: 0 count: 3 }
}
})pb");
EXPECT_THAT(output_packets[1].Get<GraphProfile>(),
mediapipe::EqualsProto(expected_profile));
}
} // namespace
} // namespace mediapipe

View File

@ -240,7 +240,7 @@ absl::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) {
auto input_mat = mediapipe::formats::MatView(&input_frame);
// Only 1 or 3 channel images supported by OpenCV.
if ((input_mat.channels() == 1 || input_mat.channels() == 3)) {
if (!(input_mat.channels() == 1 || input_mat.channels() == 3)) {
return absl::InternalError(
"CPU filtering supports only 1 or 3 channel input images.");
}

View File

@ -36,7 +36,7 @@ using GpuBuffer = mediapipe::GpuBuffer;
// stored on the target storage (CPU vs GPU) specified in the calculator option.
//
// The clone shares ownership of the input pixel data on the existing storage.
// If the target storage is diffrent from the existing one, then the data is
// If the target storage is different from the existing one, then the data is
// further copied there.
//
// Example usage:

View File

@ -33,7 +33,7 @@ class InferenceCalculatorSelectorImpl
absl::StatusOr<CalculatorGraphConfig> GetConfig(
const CalculatorGraphConfig::Node& subgraph_node) {
const auto& options =
Subgraph::GetOptions<::mediapipe::InferenceCalculatorOptions>(
Subgraph::GetOptions<mediapipe::InferenceCalculatorOptions>(
subgraph_node);
std::vector<absl::string_view> impls;
const bool should_use_gpu =

View File

@ -99,8 +99,13 @@ class InferenceCalculator : public NodeIntf {
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
static constexpr SideInput<std::string>::Optional kNnApiDelegateCacheDir{
"NNAPI_CACHE_DIR"};
static constexpr SideInput<std::string>::Optional kNnApiDelegateModelToken{
"NNAPI_MODEL_TOKEN"};
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
kOutTensors);
kOutTensors, kNnApiDelegateCacheDir,
kNnApiDelegateModelToken);
protected:
using TfLiteDelegatePtr =

View File

@ -67,9 +67,32 @@ message InferenceCalculatorOptions {
// Only available for OpenCL delegate on Android.
// Kernel caching will only be enabled if this path is set.
optional string cached_kernel_path = 2;
// Encapsulated compilation/runtime tradeoffs.
enum InferenceUsage {
UNSPECIFIED = 0;
// InferenceRunner will be used only once. Therefore, it is important to
// minimize bootstrap time as well.
FAST_SINGLE_ANSWER = 1;
// Prefer maximizing the throughput. Same inference runner will be used
// repeatedly on different inputs.
SUSTAINED_SPEED = 2;
}
optional InferenceUsage usage = 5 [default = SUSTAINED_SPEED];
}
// Android only.
message Nnapi {}
message Nnapi {
// Directory to store compilation cache. If unspecified, NNAPI will not
// try caching the compilation.
optional string cache_dir = 1;
// Unique token identifying the model. It is the caller's responsibility
// to ensure there is no clash of the tokens. If unspecified, NNAPI will
// not try caching the compilation.
optional string model_token = 2;
}
message Xnnpack {
// Number of threads for XNNPACK delegate. (By default, calculator tries
// to choose optimal number of threads depending on the device.)

View File

@ -181,9 +181,21 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
// Attempt to use NNAPI.
// If not supported, the default CPU delegate will be created and used.
interpreter_->SetAllowFp16PrecisionForFp32(1);
delegate_ = TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) {
// No need to free according to tflite::NnApiDelegate() documentation.
});
tflite::StatefulNnApiDelegate::Options options;
const auto& nnapi = calculator_opts.delegate().nnapi();
// Set up cache_dir and model_token for NNAPI compilation cache.
options.cache_dir =
nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr;
if (!kNnApiDelegateCacheDir(cc).IsEmpty()) {
options.cache_dir = kNnApiDelegateCacheDir(cc).Get().c_str();
}
options.model_token =
nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr;
if (!kNnApiDelegateModelToken(cc).IsEmpty()) {
options.model_token = kNnApiDelegateModelToken(cc).Get().c_str();
}
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
[](TfLiteDelegate*) {});
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
return absl::OkStatus();

View File

@ -18,6 +18,7 @@
#include <vector>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/util/tflite/config.h"
@ -65,6 +66,8 @@ class InferenceCalculatorGlImpl
bool allow_precision_loss_ = false;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api
tflite_gpu_runner_api_;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage
tflite_gpu_runner_usage_;
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
@ -96,6 +99,7 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
options.delegate().gpu().use_advanced_gpu_api();
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
tflite_gpu_runner_api_ = options.delegate().gpu().api();
tflite_gpu_runner_usage_ = options.delegate().gpu().usage();
use_kernel_caching_ = use_advanced_gpu_api_ &&
options.delegate().gpu().has_cached_kernel_path();
use_gpu_delegate_ = !use_advanced_gpu_api_;
@ -253,9 +257,27 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
: tflite::gpu::InferencePriority::MAX_PRECISION;
options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO;
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
switch (tflite_gpu_runner_usage_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
FAST_SINGLE_ANSWER: {
options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
SUSTAINED_SPEED: {
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: {
return absl::InternalError("inference usage need to be specified.");
}
}
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
switch (tflite_gpu_runner_api_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
// Do not need to force any specific API.
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: {
tflite_gpu_runner_->ForceOpenGL();
break;
@ -264,10 +286,6 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
tflite_gpu_runner_->ForceOpenCL();
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
// Do not need to force any specific API.
break;
}
}
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));

View File

@ -864,6 +864,7 @@ cc_test(
"//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",

View File

@ -243,11 +243,6 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
}
if (cc->Outputs().HasTag(kSequenceExampleTag)) {
cc->Outputs()
.Tag(kSequenceExampleTag)
.SetNextTimestampBound(Timestamp::Max());
}
return absl::OkStatus();
}
@ -305,7 +300,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
if (cc->Outputs().HasTag(kSequenceExampleTag)) {
cc->Outputs()
.Tag(kSequenceExampleTag)
.Add(sequence_.release(), Timestamp::PostStream());
.Add(sequence_.release(), options.output_as_zero_timestamp()
? Timestamp(0ll)
: Timestamp::PostStream());
}
sequence_.reset();

View File

@ -65,4 +65,7 @@ message PackMediaSequenceCalculatorOptions {
// If true, will return an error status if an output sequence would be too
// many bytes to serialize.
optional bool skip_large_sequences = 7 [default = true];
// If true/false, outputs the SequenceExample at timestamp 0/PostStream.
optional bool output_as_zero_timestamp = 8 [default = false];
}

View File

@ -29,6 +29,7 @@
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
@ -43,8 +44,9 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test {
protected:
void SetUpCalculator(const std::vector<std::string>& input_streams,
const tf::Features& features,
bool output_only_if_all_present,
bool replace_instead_of_append) {
const bool output_only_if_all_present,
const bool replace_instead_of_append,
const bool output_as_zero_timestamp = false) {
CalculatorGraphConfig::Node config;
config.set_calculator("PackMediaSequenceCalculator");
config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence");
@ -57,6 +59,7 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test {
*options->mutable_context_feature_map() = features;
options->set_output_only_if_all_present(output_only_if_all_present);
options->set_replace_data_instead_of_append(replace_instead_of_append);
options->set_output_as_zero_timestamp(output_as_zero_timestamp);
runner_ = ::absl::make_unique<CalculatorRunner>(config);
}
@ -194,6 +197,29 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) {
}
}
// Verifies that enabling output_as_zero_timestamp in
// PackMediaSequenceCalculatorOptions makes the calculator emit the packed
// SequenceExample at Timestamp(0) instead of Timestamp::PostStream().
TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) {
  // Last argument enables output_as_zero_timestamp.
  SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true);
  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
  int num_timesteps = 2;
  for (int i = 0; i < num_timesteps; ++i) {
    // 2 << i yields a distinct feature value per timestep.
    auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
    runner_->MutableInputs()
        ->Tag("FLOAT_FEATURE_TEST")
        .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
  }
  runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
      Adopt(input_sequence.release());
  MP_ASSERT_OK(runner_->Run());
  const std::vector<Packet>& output_packets =
      runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
  // Exactly one packed sequence is produced, and it carries timestamp 0
  // rather than PostStream.
  ASSERT_EQ(1, output_packets.size());
  EXPECT_EQ(output_packets[0].Timestamp().Value(), 0ll);
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
SetUpCalculator(
{"FLOAT_CONTEXT_FEATURE_TEST:test", "FLOAT_CONTEXT_FEATURE_OTHER:test2"},

View File

@ -292,6 +292,8 @@ class TfLiteInferenceCalculator : public CalculatorBase {
bool allow_precision_loss_ = false;
mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::Api
tflite_gpu_runner_api_;
mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::InferenceUsage
tflite_gpu_runner_usage_;
bool use_kernel_caching_ = false;
std::string cached_kernel_filename_;
@ -377,6 +379,7 @@ absl::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) {
options.delegate().gpu().use_advanced_gpu_api();
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
tflite_gpu_runner_api_ = options.delegate().gpu().api();
tflite_gpu_runner_usage_ = options.delegate().gpu().usage();
use_kernel_caching_ = use_advanced_gpu_api_ &&
options.delegate().gpu().has_cached_kernel_path();
@ -733,7 +736,23 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
: tflite::gpu::InferencePriority::MAX_PRECISION;
options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO;
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
switch (tflite_gpu_runner_usage_) {
case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::
FAST_SINGLE_ANSWER: {
options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
break;
}
case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::
SUSTAINED_SPEED: {
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
break;
}
case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::
UNSPECIFIED: {
return absl::InternalError("inference usage need to be specified.");
}
}
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
switch (tflite_gpu_runner_api_) {
case mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::OPENGL: {
@ -878,11 +897,15 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
// Attempt to use NNAPI.
// If not supported, the default CPU delegate will be created and used.
interpreter_->SetAllowFp16PrecisionForFp32(1);
delegate_ =
TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) {
// No need to free according to tflite::NnApiDelegate()
// documentation.
});
tflite::StatefulNnApiDelegate::Options options;
const auto& nnapi = calculator_opts.delegate().nnapi();
// Set up cache_dir and model_token for NNAPI compilation cache.
if (nnapi.has_cache_dir() && nnapi.has_model_token()) {
options.cache_dir = nnapi.cache_dir().c_str();
options.model_token = nnapi.model_token().c_str();
}
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
[](TfLiteDelegate*) {});
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
return absl::OkStatus();

View File

@ -67,9 +67,31 @@ message TfLiteInferenceCalculatorOptions {
// Only available for OpenCL delegate on Android.
// Kernel caching will only be enabled if this path is set.
optional string cached_kernel_path = 2;
// Encapsulated compilation/runtime tradeoffs.
enum InferenceUsage {
UNSPECIFIED = 0;
// InferenceRunner will be used only once. Therefore, it is important to
// minimize bootstrap time as well.
FAST_SINGLE_ANSWER = 1;
// Prefer maximizing the throughput. Same inference runner will be used
// repeatedly on different inputs.
SUSTAINED_SPEED = 2;
}
optional InferenceUsage usage = 5 [default = SUSTAINED_SPEED];
}
// Android only.
message Nnapi {}
message Nnapi {
// Directory to store compilation cache. If unspecified, NNAPI will not
// try caching the compilation.
optional string cache_dir = 1;
// Unique token identifying the model. It is the caller's responsibility
// to ensure there is no clash of the tokens. If unspecified, NNAPI will
// not try caching the compilation.
optional string model_token = 2;
}
message Xnnpack {
// Number of threads for XNNPACK delegate. (By default, calculator tries
// to choose optimal number of threads depending on the device.)

View File

@ -0,0 +1,21 @@
# Copyright 2021 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
filegroup(
name = "resource_files",
srcs = glob(["res/**"]),
visibility = ["//mediapipe/examples/android/solutions:__subpackages__"],
)

View File

@ -0,0 +1,24 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
buildscript {
repositories {
google()
mavenCentral()
}
dependencies {
classpath "com.android.tools.build:gradle:4.2.0"
// NOTE: Do not place your application dependencies here; they belong
// in the individual module build.gradle files
}
}
allprojects {
repositories {
google()
mavenCentral()
}
}
task clean(type: Delete) {
delete rootProject.buildDir
}

View File

@ -0,0 +1,17 @@
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true

View File

@ -0,0 +1,5 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

185
mediapipe/examples/android/solutions/gradlew vendored Executable file
View File

@ -0,0 +1,185 @@
#!/usr/bin/env sh
#
# Copyright 2015 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
exec "$JAVACMD" "$@"

View File

@ -0,0 +1,89 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega

View File

@ -0,0 +1,50 @@
plugins {
id 'com.android.application'
}
android {
compileSdkVersion 30
buildToolsVersion "30.0.3"
defaultConfig {
applicationId "com.google.mediapipe.apps.hands"
minSdkVersion 21
targetSdkVersion 30
versionCode 1
versionName "1.0"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar'])
implementation 'androidx.appcompat:appcompat:1.3.0'
implementation 'com.google.android.material:material:1.3.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
testImplementation 'junit:junit:4.+'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
// MediaPipe hands solution API and solution-core.
implementation 'com.google.mediapipe:solution-core:latest.release'
implementation 'com.google.mediapipe:hands:latest.release'
// MediaPipe deps
implementation 'com.google.flogger:flogger:latest.release'
implementation 'com.google.flogger:flogger-system-backend:latest.release'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4'
// CameraX core library
def camerax_version = "1.0.0-beta10"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
implementation "androidx.camera:camera-lifecycle:$camerax_version"
}

View File

@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View File

@ -0,0 +1,31 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.examples.hands">
<uses-sdk
android:minSdkVersion="21"
android:targetSdkVersion="30" />
<!-- For loading images from gallery -->
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<!-- For using the camera -->
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<application
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="MediaPipe Hands"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<activity android:name=".MainActivity">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@ -0,0 +1,40 @@
# Copyright 2021 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
android_binary(
name = "hands",
srcs = glob(["**/*.java"]),
custom_package = "com.google.mediapipe.examples.hands",
manifest = "AndroidManifest.xml",
manifest_values = {
"applicationId": "com.google.mediapipe.examples.hands",
},
multidex = "native",
resource_files = ["//mediapipe/examples/android/solutions:resource_files"],
deps = [
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/solutioncore:camera_input",
"//mediapipe/java/com/google/mediapipe/solutioncore:mediapipe_jni_lib",
"//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering",
"//mediapipe/java/com/google/mediapipe/solutions/hands",
"//third_party:androidx_appcompat",
"//third_party:androidx_constraint_layout",
"@maven//:androidx_concurrent_concurrent_futures",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,129 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.examples.hands;
import android.opengl.GLES20;
import android.opengl.Matrix;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.solutioncore.ResultGlBoundary;
import com.google.mediapipe.solutioncore.ResultGlRenderer;
import com.google.mediapipe.solutions.hands.Hands;
import com.google.mediapipe.solutions.hands.HandsResult;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.List;
/** A custom implementation of {@link ResultGlRenderer} to render MediaPipe Hands results. */
public class HandsResultGlRenderer implements ResultGlRenderer<HandsResult> {
  private static final String TAG = "HandsResultGlRenderer";

  // Line width (in pixels) used when drawing hand connections.
  private static final float CONNECTION_THICKNESS = 20.0f;
  private static final String VERTEX_SHADER =
      "uniform mat4 uTransformMatrix;\n"
          + "attribute vec4 vPosition;\n"
          + "void main() {\n"
          + "  gl_Position = uTransformMatrix * vPosition;\n"
          + "}";
  private static final String FRAGMENT_SHADER =
      "precision mediump float;\n"
          + "void main() {\n"
          + "  gl_FragColor = vec4(0, 1, 0, 1);\n"
          + "}";
  private int program;
  private int positionHandle;
  private int transformMatrixHandle;
  private final float[] transformMatrix = new float[16];
  // One connection is two 2D endpoints -> 4 floats (16 bytes). Allocated once
  // and reused for every connection to avoid a direct-buffer allocation per
  // connection per frame in the render loop.
  private final FloatBuffer vertexBuffer =
      ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();

  // Compiles a shader of the given type from source.
  // NOTE(review): compile status is not checked here; a failed compile only
  // surfaces later at link/draw time — consider glGetShaderiv.
  private int loadShader(int type, String shaderCode) {
    int shader = GLES20.glCreateShader(type);
    GLES20.glShaderSource(shader, shaderCode);
    GLES20.glCompileShader(shader);
    return shader;
  }

  /** Compiles and links the shader program and caches attribute/uniform handles. */
  @Override
  public void setupRendering() {
    program = GLES20.glCreateProgram();
    int vertexShader = loadShader(GLES20.GL_VERTEX_SHADER, VERTEX_SHADER);
    int fragmentShader = loadShader(GLES20.GL_FRAGMENT_SHADER, FRAGMENT_SHADER);
    GLES20.glAttachShader(program, vertexShader);
    GLES20.glAttachShader(program, fragmentShader);
    GLES20.glLinkProgram(program);
    positionHandle = GLES20.glGetAttribLocation(program, "vPosition");
    transformMatrixHandle = GLES20.glGetUniformLocation(program, "uTransformMatrix");
  }

  /** Draws the hand connections of every detected hand in {@code result}. */
  @Override
  public void renderResult(HandsResult result, ResultGlBoundary boundary) {
    if (result == null) {
      return;
    }
    GLES20.glUseProgram(program);
    // Sets the transform matrix to align the result rendering with the scaled output texture.
    Matrix.setIdentityM(transformMatrix, 0);
    Matrix.scaleM(
        transformMatrix,
        0,
        2 / (boundary.right() - boundary.left()),
        2 / (boundary.top() - boundary.bottom()),
        1.0f);
    GLES20.glUniformMatrix4fv(transformMatrixHandle, 1, false, transformMatrix, 0);
    GLES20.glLineWidth(CONNECTION_THICKNESS);

    int numHands = result.multiHandLandmarks().size();
    for (int i = 0; i < numHands; ++i) {
      drawLandmarks(result.multiHandLandmarks().get(i).getLandmarkList());
    }
  }

  /**
   * Calls this to delete the shader program.
   *
   * <p>This is only necessary if one wants to release the program while keeping the context around.
   */
  public void release() {
    GLES20.glDeleteProgram(program);
  }

  // TODO: Better hand landmark and hand connection drawing.
  // Draws one GL_LINES segment per entry in Hands.HAND_CONNECTIONS, mapping
  // each landmark from [0, 1] into clip space via normalizedLandmarkValue.
  private void drawLandmarks(List<NormalizedLandmark> handLandmarkList) {
    float[] vertex = new float[4];
    for (Hands.Connection c : Hands.HAND_CONNECTIONS) {
      NormalizedLandmark start = handLandmarkList.get(c.start());
      vertex[0] = normalizedLandmarkValue(start.getX());
      vertex[1] = normalizedLandmarkValue(start.getY());
      NormalizedLandmark end = handLandmarkList.get(c.end());
      vertex[2] = normalizedLandmarkValue(end.getX());
      vertex[3] = normalizedLandmarkValue(end.getY());
      vertexBuffer.clear();
      vertexBuffer.put(vertex).position(0);
      GLES20.glEnableVertexAttribArray(positionHandle);
      GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer);
      GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2);
    }
  }

  // Normalizes the value from the landmark value range:[0, 1] to the standard OpenGL coordinate
  // value range: [-1, 1].
  private float normalizedLandmarkValue(float value) {
    return value * 2 - 1;
  }
}

View File

@ -0,0 +1,95 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.examples.hands;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.widget.ImageView;
import com.google.mediapipe.formats.proto.LandmarkProto;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.solutions.hands.Hands;
import com.google.mediapipe.solutions.hands.HandsResult;
import java.util.List;
/** An ImageView implementation for displaying MediaPipe Hands results. */
public class HandsResultImageView extends ImageView {
private static final String TAG = "HandsResultImageView";
private static final int LANDMARK_COLOR = Color.RED;
private static final int LANDMARK_RADIUS = 15;
private static final int CONNECTION_COLOR = Color.GREEN;
private static final int CONNECTION_THICKNESS = 10;
public HandsResultImageView(Context context) {
super(context);
setScaleType(ImageView.ScaleType.FIT_CENTER);
}
/**
* Sets a {@link HandsResult} to render.
*
* @param result a {@link HandsResult} object that contains the solution outputs and the input
* {@link Bitmap}.
*/
public void setHandsResult(HandsResult result) {
if (result == null) {
return;
}
Bitmap bmInput = result.inputBitmap();
int width = bmInput.getWidth();
int height = bmInput.getHeight();
Bitmap bmOutput = Bitmap.createBitmap(width, height, bmInput.getConfig());
Canvas canvas = new Canvas(bmOutput);
canvas.drawBitmap(bmInput, new Matrix(), null);
int numHands = result.multiHandLandmarks().size();
for (int i = 0; i < numHands; ++i) {
drawLandmarksOnCanvas(
result.multiHandLandmarks().get(i).getLandmarkList(), canvas, width, height);
}
postInvalidate();
setImageBitmap(bmOutput);
}
// TODO: Better hand landmark and hand connection drawing.
private void drawLandmarksOnCanvas(
List<NormalizedLandmark> handLandmarkList, Canvas canvas, int width, int height) {
// Draw connections.
for (Hands.Connection c : Hands.HAND_CONNECTIONS) {
Paint connectionPaint = new Paint();
connectionPaint.setColor(CONNECTION_COLOR);
connectionPaint.setStrokeWidth(CONNECTION_THICKNESS);
NormalizedLandmark start = handLandmarkList.get(c.start());
NormalizedLandmark end = handLandmarkList.get(c.end());
canvas.drawLine(
start.getX() * width,
start.getY() * height,
end.getX() * width,
end.getY() * height,
connectionPaint);
}
Paint landmarkPaint = new Paint();
landmarkPaint.setColor(LANDMARK_COLOR);
// Draw landmarks.
for (LandmarkProto.NormalizedLandmark landmark : handLandmarkList) {
canvas.drawCircle(
landmark.getX() * width, landmark.getY() * height, LANDMARK_RADIUS, landmarkPaint);
}
}
}

View File

@ -0,0 +1,241 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.examples.hands;
import android.content.Intent;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.provider.MediaStore;
import androidx.appcompat.app.AppCompatActivity;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.FrameLayout;
import androidx.activity.result.ActivityResultLauncher;
import androidx.activity.result.contract.ActivityResultContracts;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.solutioncore.CameraInput;
import com.google.mediapipe.solutioncore.SolutionGlSurfaceView;
import com.google.mediapipe.solutions.hands.HandLandmark;
import com.google.mediapipe.solutions.hands.Hands;
import com.google.mediapipe.solutions.hands.HandsOptions;
import com.google.mediapipe.solutions.hands.HandsResult;
import java.io.IOException;
/** Main activity of MediaPipe Hands app. */
public class MainActivity extends AppCompatActivity {
private static final String TAG = "MainActivity";
private Hands hands;
private int mode = HandsOptions.STATIC_IMAGE_MODE;
// Image demo UI and image loader components.
private Button loadImageButton;
private ActivityResultLauncher<Intent> imageGetter;
private HandsResultImageView imageView;
// Live camera demo UI and camera components.
private Button startCameraButton;
private CameraInput cameraInput;
private SolutionGlSurfaceView<HandsResult> glSurfaceView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
setupStaticImageDemoUiComponents();
setupLiveDemoUiComponents();
}
@Override
protected void onResume() {
super.onResume();
if (mode == HandsOptions.STREAMING_MODE) {
// Restarts the camera and the opengl surface rendering.
cameraInput = new CameraInput(this);
cameraInput.setCameraNewFrameListener(textureFrame -> hands.send(textureFrame));
glSurfaceView.post(this::startCamera);
glSurfaceView.setVisibility(View.VISIBLE);
}
}
@Override
protected void onPause() {
super.onPause();
if (mode == HandsOptions.STREAMING_MODE) {
stopLiveDemo();
}
}
/** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData());
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
}
if (bitmap != null) {
hands.send(bitmap);
}
}
}
});
loadImageButton = (Button) findViewById(R.id.button_load_picture);
loadImageButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
if (mode == HandsOptions.STREAMING_MODE) {
stopLiveDemo();
}
if (hands == null || mode != HandsOptions.STATIC_IMAGE_MODE) {
setupStaticImageModePipeline();
}
// Reads images from gallery.
Intent gallery =
new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI);
imageGetter.launch(gallery);
}
});
imageView = new HandsResultImageView(this);
}
/** The core MediaPipe Hands setup workflow for its static image mode. */
private void setupStaticImageModePipeline() {
// Initializes a new MediaPipe Hands instance in the static image mode.
mode = HandsOptions.STATIC_IMAGE_MODE;
if (hands != null) {
hands.close();
}
hands = new Hands(this, HandsOptions.builder().setMode(mode).build());
// Connects MediaPipe Hands to the user-defined HandsResultImageView.
hands.setResultListener(
handsResult -> {
logWristLandmark(handsResult, /*showPixelValues=*/ true);
runOnUiThread(() -> imageView.setHandsResult(handsResult));
});
hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe hands error:" + message));
// Updates the preview layout.
FrameLayout frameLayout = (FrameLayout) findViewById(R.id.preview_display_layout);
frameLayout.removeAllViewsInLayout();
imageView.setImageDrawable(null);
frameLayout.addView(imageView);
imageView.setVisibility(View.VISIBLE);
}
/** Sets up the UI components for the live demo with camera input. */
private void setupLiveDemoUiComponents() {
startCameraButton = (Button) findViewById(R.id.button_start_camera);
startCameraButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View v) {
if (hands == null || mode != HandsOptions.STREAMING_MODE) {
setupStreamingModePipeline();
}
}
});
}
/** The core MediaPipe Hands setup workflow for its streaming mode. */
private void setupStreamingModePipeline() {
// Initializes a new MediaPipe Hands instance in the streaming mode.
mode = HandsOptions.STREAMING_MODE;
if (hands != null) {
hands.close();
}
hands = new Hands(this, HandsOptions.builder().setMode(mode).build());
hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe hands error:" + message));
// Initializes a new CameraInput instance and connects it to MediaPipe Hands.
cameraInput = new CameraInput(this);
cameraInput.setCameraNewFrameListener(textureFrame -> hands.send(textureFrame));
// Initalizes a new Gl surface view with a user-defined HandsResultGlRenderer.
glSurfaceView =
new SolutionGlSurfaceView<>(this, hands.getGlContext(), hands.getGlMajorVersion());
glSurfaceView.setSolutionResultRenderer(new HandsResultGlRenderer());
glSurfaceView.setRenderInputImage(true);
hands.setResultListener(
handsResult -> {
logWristLandmark(handsResult, /*showPixelValues=*/ false);
glSurfaceView.setRenderData(handsResult);
glSurfaceView.requestRender();
});
// The runnable to start camera after the gl surface view is attached.
glSurfaceView.post(this::startCamera);
// Updates the preview layout.
FrameLayout frameLayout = (FrameLayout) findViewById(R.id.preview_display_layout);
imageView.setVisibility(View.GONE);
frameLayout.removeAllViewsInLayout();
frameLayout.addView(glSurfaceView);
glSurfaceView.setVisibility(View.VISIBLE);
frameLayout.requestLayout();
}
private void startCamera() {
cameraInput.start(
this,
hands.getGlContext(),
CameraInput.CameraFacing.FRONT,
glSurfaceView.getWidth(),
glSurfaceView.getHeight());
}
private void stopLiveDemo() {
if (cameraInput != null) {
cameraInput.stop();
}
if (glSurfaceView != null) {
glSurfaceView.setVisibility(View.GONE);
}
}
private void logWristLandmark(HandsResult result, boolean showPixelValues) {
NormalizedLandmark wristLandmark = Hands.getHandLandmark(result, 0, HandLandmark.WRIST);
// For Bitmaps, show the pixel values. For texture inputs, show the normoralized cooridanates.
if (showPixelValues) {
int width = result.inputBitmap().getWidth();
int height = result.inputBitmap().getHeight();
Log.i(
TAG,
"MediaPipe Hand wrist coordinates (pixel values): x= "
+ wristLandmark.getX() * width
+ " y="
+ wristLandmark.getY() * height);
} else {
Log.i(
TAG,
"MediaPipe Hand wrist normalized coordinates (value range: [0, 1]): x= "
+ wristLandmark.getX()
+ " y="
+ wristLandmark.getY());
}
}
}

View File

@ -0,0 +1 @@
../../../res

View File

@ -0,0 +1,34 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportHeight="108"
android:viewportWidth="108">
<path
android:fillType="evenOdd"
android:pathData="M32,64C32,64 38.39,52.99 44.13,50.95C51.37,48.37 70.14,49.57 70.14,49.57L108.26,87.69L108,109.01L75.97,107.97L32,64Z"
android:strokeColor="#00000000"
android:strokeWidth="1">
<aapt:attr name="android:fillColor">
<gradient
android:endX="78.5885"
android:endY="90.9159"
android:startX="48.7653"
android:startY="61.0927"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M66.94,46.02L66.94,46.02C72.44,50.07 76,56.61 76,64L32,64C32,56.61 35.56,50.11 40.98,46.06L36.18,41.19C35.45,40.45 35.45,39.3 36.18,38.56C36.91,37.81 38.05,37.81 38.78,38.56L44.25,44.05C47.18,42.57 50.48,41.71 54,41.71C57.48,41.71 60.78,42.57 63.68,44.05L69.11,38.56C69.84,37.81 70.98,37.81 71.71,38.56C72.44,39.3 72.44,40.45 71.71,41.19L66.94,46.02ZM62.94,56.92C64.08,56.92 65,56.01 65,54.88C65,53.76 64.08,52.85 62.94,52.85C61.8,52.85 60.88,53.76 60.88,54.88C60.88,56.01 61.8,56.92 62.94,56.92ZM45.06,56.92C46.2,56.92 47.13,56.01 47.13,54.88C47.13,53.76 46.2,52.85 45.06,52.85C43.92,52.85 43,53.76 43,54.88C43,56.01 43.92,56.92 45.06,56.92Z"
android:strokeColor="#00000000"
android:strokeWidth="1" />
</vector>

View File

@ -0,0 +1,74 @@
<?xml version="1.0" encoding="utf-8"?>
<vector
android:height="108dp"
android:width="108dp"
android:viewportHeight="108"
android:viewportWidth="108"
xmlns:android="http://schemas.android.com/apk/res/android">
<path android:fillColor="#26A69A"
android:pathData="M0,0h108v108h-108z"/>
<path android:fillColor="#00000000" android:pathData="M9,0L9,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,0L19,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M29,0L29,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M39,0L39,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M49,0L49,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M59,0L59,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M69,0L69,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M79,0L79,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M89,0L89,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M99,0L99,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,9L108,9"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,19L108,19"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,29L108,29"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,39L108,39"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,49L108,49"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,59L108,59"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,69L108,69"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,79L108,79"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,89L108,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,99L108,99"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,29L89,29"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,39L89,39"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,49L89,49"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,59L89,59"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,69L89,69"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,79L89,79"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M29,19L29,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M39,19L39,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M49,19L49,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M59,19L59,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M69,19L69,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M79,19L79,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
</vector>

View File

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical">
<LinearLayout
android:id="@+id/buttons"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:gravity="center"
android:orientation="horizontal">
<Button
android:id="@+id/button_load_picture"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Load Picture" />
<Button
android:id="@+id/button_start_camera"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Start Camera" />
</LinearLayout>
<FrameLayout
android:id="@+id/preview_display_layout"
android:layout_width="match_parent"
android:layout_height="match_parent">
<TextView
android:id="@+id/no_view"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:gravity="center"
android:text="Please press any button above to start" />
</FrameLayout>
</LinearLayout>

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background"/>
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
</adaptive-icon>

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background"/>
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
</adaptive-icon>

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 959 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 900 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>

View File

@ -0,0 +1,3 @@
<resources>
<string name="no_camera_access" translatable="false">Please grant camera permissions.</string>
</resources>

View File

@ -0,0 +1,11 @@
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
</style>
</resources>

View File

@ -0,0 +1,2 @@
rootProject.name = "mediapipe-solutions-examples"
include ':hands'

View File

@ -291,7 +291,6 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
CheckBorder(static_features, 1000, 1000, 495, 395);
}
#if 0
TEST(ContentZoomingCalculatorTest, ZoomTestFullPTZ) {
auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
@ -727,8 +726,8 @@ TEST(ContentZoomingCalculatorTest, ResolutionChangeZoomingWithCache) {
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000,
1000, runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 2000000, 500, 500,
runner.get());
MP_ASSERT_OK(runner->Run());
@ -752,7 +751,6 @@ TEST(ContentZoomingCalculatorTest, MaxZoomValue) {
CheckCropRect(500, 500, 916, 916, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
#endif
TEST(ContentZoomingCalculatorTest, MaxZoomValueOverride) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigF);
@ -781,7 +779,6 @@ TEST(ContentZoomingCalculatorTest, MaxZoomValueOverride) {
runner->Outputs().Tag("CROP_RECT").packets);
}
#if 0
TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
@ -969,7 +966,6 @@ TEST(ContentZoomingCalculatorTest, ProvidesConstantFirstRect) {
EXPECT_EQ(first_rect.height(), rect.height());
}
}
#endif
} // namespace
} // namespace autoflip

View File

@ -71,7 +71,10 @@ message FaceBoxAdjusterCalculatorOptions {
optional int32 max_facesize_history_us = 9 [default = 300000000];
// Scale factor of face width to shift based on pan look angle.
optional float pan_position_shift_scale = 15 [default = 0.5];
optional float pan_position_shift_scale = 15
[default = 1.0, deprecated = true];
// Scale factor of face height to shift based on tilt look angle.
optional float tilt_position_shift_scale = 16 [default = 0.5];
optional float tilt_position_shift_scale = 16 [default = 0.25];
// Scale factor of face width to shift based on roll look angle.
optional float roll_position_shift_scale = 17 [default = 0.7];
}

View File

@ -29,6 +29,7 @@ import os
import plistlib
import re
import subprocess
from typing import Optional
import uuid
# This script is meant to be located in the MediaPipe iOS examples directory
@ -79,7 +80,7 @@ def configure_bundle_id_prefix(
return bundle_id_prefix
def get_app_id(profile_path) -> str:
def get_app_id(profile_path) -> Optional[str]:
try:
plist = subprocess.check_output(
["security", "cms", "-D", "-i", profile_path])

View File

@ -222,10 +222,10 @@ cc_library(
"//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework:stream_handler_cc_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_map",
"//mediapipe/framework/tool:packet_generator_wrapper_calculator_cc_proto",
"//mediapipe/framework/tool:tag_map",
"@com_google_absl//absl/memory",
],
@ -299,7 +299,6 @@ cc_library(
":graph_service",
":graph_service_manager",
":input_stream_manager",
":input_stream_shard",
":output_side_packet_impl",
":output_stream",
":output_stream_manager",
@ -317,8 +316,6 @@ cc_library(
":timestamp",
":validated_graph_config",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:packet_factory_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework:thread_pool_executor_cc_proto",
@ -327,6 +324,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
@ -336,8 +334,8 @@ cc_library(
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status",
"//mediapipe/framework/profiler:graph_profiler",
"//mediapipe/framework/tool:fill_packet_set",
"//mediapipe/framework/tool:packet_generator_wrapper_calculator",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/tool:tag_map",
"//mediapipe/framework/tool:validate",
@ -345,10 +343,7 @@ cc_library(
"//mediapipe/gpu:graph_support",
"//mediapipe/util:cpu_util",
] + select({
"//conditions:default": [
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/gpu:gpu_service",
],
"//conditions:default": ["//mediapipe/gpu:gpu_shared_data_internal"],
"//mediapipe/gpu:disable_gpu": [],
}),
)
@ -389,13 +384,11 @@ cc_library(
":input_side_packet_handler",
":input_stream_handler",
":input_stream_manager",
":input_stream_shard",
":legacy_calculator_support",
":mediapipe_profiling",
":output_side_packet_impl",
":output_stream_handler",
":output_stream_manager",
":output_stream_shard",
":packet",
":packet_set",
":packet_type",
@ -404,14 +397,12 @@ cc_library(
":validated_graph_config",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:stream_handler_cc_proto",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status",
"//mediapipe/framework/profiler:graph_profiler",
"//mediapipe/framework/stream_handler:default_input_stream_handler",
"//mediapipe/framework/stream_handler:in_order_output_stream_handler",
"//mediapipe/framework/tool:name_util",
@ -421,6 +412,7 @@ cc_library(
"//mediapipe/gpu:graph_support",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
],
@ -1588,6 +1580,7 @@ cc_test(
":packet",
":packet_test_cc_proto",
":type_map",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",

View File

@ -212,6 +212,9 @@ message ProfilerConfig {
// False specifies an event for each calculator invocation.
// True specifies a separate event for each start and finish time.
bool trace_log_instant_events = 17;
// Limits calculator-profile histograms to a subset of calculators.
string calculator_filter = 18;
}
// Describes the topology and function of a MediaPipe Graph. The graph of

View File

@ -14,16 +14,40 @@
#include "mediapipe/framework/calculator_contract.h"
#include <memory>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/packet_generator_wrapper_calculator.pb.h"
#include "mediapipe/framework/tool/tag_map.h"
namespace mediapipe {
namespace {
// Builds a CalculatorGraphConfig::Node that runs the given PacketGenerator
// through PacketGeneratorWrapperCalculator. The generator's side-packet
// wiring is mirrored onto the wrapper node, and the generator name, the
// `package` string, and (if present) the generator's options are stored in
// the wrapper's PacketGeneratorWrapperCalculatorOptions.
CalculatorGraphConfig::Node MakePacketGeneratorWrapperConfig(
    const PacketGeneratorConfig& node, const std::string& package) {
  CalculatorGraphConfig::Node wrapper_node;
  wrapper_node.set_calculator("PacketGeneratorWrapperCalculator");
  // Mirror the generator's input/output side packets on the wrapper node.
  *wrapper_node.mutable_input_side_packet() = node.input_side_packet();
  *wrapper_node.mutable_output_side_packet() = node.output_side_packet();
  auto* wrapper_options = wrapper_node.mutable_options()->MutableExtension(
      mediapipe::PacketGeneratorWrapperCalculatorOptions::ext);
  wrapper_options->set_packet_generator(node.packet_generator());
  wrapper_options->set_package(package);
  // Only copy options if the generator actually set them.
  if (node.has_options()) {
    *wrapper_options->mutable_options() = node.options();
  }
  return wrapper_node;
}
} // anonymous namespace
absl::Status CalculatorContract::Initialize(
const CalculatorGraphConfig::Node& node) {
std::vector<absl::Status> statuses;
@ -74,7 +98,8 @@ absl::Status CalculatorContract::Initialize(
return absl::OkStatus();
}
absl::Status CalculatorContract::Initialize(const PacketGeneratorConfig& node) {
absl::Status CalculatorContract::Initialize(const PacketGeneratorConfig& node,
const std::string& package) {
std::vector<absl::Status> statuses;
auto input_side_packet_statusor =
@ -101,6 +126,11 @@ absl::Status CalculatorContract::Initialize(const PacketGeneratorConfig& node) {
return std::move(builder);
}
wrapper_config_ = std::make_unique<CalculatorGraphConfig::Node>(
MakePacketGeneratorWrapperConfig(node, package));
options_.Initialize(*wrapper_config_);
inputs_ = absl::make_unique<PacketTypeSet>(0);
outputs_ = absl::make_unique<PacketTypeSet>(0);
input_side_packets_ = absl::make_unique<PacketTypeSet>(
std::move(input_side_packet_statusor).value());
output_side_packets_ = absl::make_unique<PacketTypeSet>(

View File

@ -48,7 +48,8 @@ namespace mediapipe {
class CalculatorContract {
public:
absl::Status Initialize(const CalculatorGraphConfig::Node& node);
absl::Status Initialize(const PacketGeneratorConfig& node);
absl::Status Initialize(const PacketGeneratorConfig& node,
const std::string& package);
absl::Status Initialize(const StatusHandlerConfig& node);
void SetNodeName(const std::string& node_name) { node_name_ = node_name; }
@ -163,7 +164,14 @@ class CalculatorContract {
template <class T>
void GetNodeOptions(T* result) const;
// When creating a contract for a PacketGenerator, we define a configuration
// for a wrapper calculator, for use by CalculatorNode.
const CalculatorGraphConfig::Node& GetWrapperConfig() const {
return *wrapper_config_;
}
const CalculatorGraphConfig::Node* node_config_ = nullptr;
std::unique_ptr<CalculatorGraphConfig::Node> wrapper_config_;
tool::OptionsMap options_;
std::unique_ptr<PacketTypeSet> inputs_;
std::unique_ptr<PacketTypeSet> outputs_;
@ -175,6 +183,8 @@ class CalculatorContract {
std::map<std::string, GraphServiceRequest> service_requests_;
bool process_timestamps_ = false;
TimestampDiff timestamp_offset_ = TimestampDiff::Unset();
friend class CalculatorNode;
};
} // namespace mediapipe

View File

@ -80,7 +80,7 @@ TEST(CalculatorContractTest, PacketGenerator) {
output_side_packet: "content_fingerprint"
)pb");
CalculatorContract contract;
MP_EXPECT_OK(contract.Initialize(node));
MP_EXPECT_OK(contract.Initialize(node, ""));
EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1);
EXPECT_EQ(contract.OutputSidePackets().NumEntries(), 4);
}

View File

@ -26,6 +26,7 @@
#include "absl/container/fixed_array.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
@ -84,9 +85,9 @@ void CalculatorGraph::ScheduleAllOpenableNodes() {
// node->ReadyForOpen() only before any node or graph input stream has
// propagated header packets or generated output side packets, either of
// which may cause a downstream node to be scheduled for OpenNode().
for (CalculatorNode& node : *nodes_) {
if (node.ReadyForOpen()) {
scheduler_.ScheduleNodeForOpen(&node);
for (auto& node : nodes_) {
if (node->ReadyForOpen()) {
scheduler_.ScheduleNodeForOpen(node.get());
}
}
}
@ -234,15 +235,15 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
std::vector<absl::Status> errors;
// Create and initialize all the nodes in the graph.
nodes_ = absl::make_unique<absl::FixedArray<CalculatorNode>>(
validated_graph_->CalculatorInfos().size());
for (int node_id = 0; node_id < validated_graph_->CalculatorInfos().size();
++node_id) {
// buffer_size_hint will be positive if one was specified in
// the graph proto.
int buffer_size_hint = 0;
const absl::Status result = (*nodes_)[node_id].Initialize(
validated_graph_.get(), node_id, input_stream_managers_.get(),
NodeTypeInfo::NodeRef node_ref(NodeTypeInfo::NodeType::CALCULATOR, node_id);
nodes_.push_back(absl::make_unique<CalculatorNode>());
const absl::Status result = nodes_.back()->Initialize(
validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_);
if (buffer_size_hint > 0) {
@ -263,6 +264,38 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
return absl::OkStatus();
}
// Appends wrapper CalculatorNodes for the packet generators that could not be
// run during graph setup, so they execute as part of the calculator graph.
// Idempotent across runs: wrapper nodes are only added the first time.
absl::Status CalculatorGraph::InitializePacketGeneratorNodes(
    const std::vector<int>& non_scheduled_generators) {
  // Do not add wrapper nodes again if we are running the graph multiple times.
  if (packet_generator_nodes_added_) {
    return absl::OkStatus();
  }
  packet_generator_nodes_added_ = true;

  // Collect failures locally so we never need to lock errors_.
  std::vector<absl::Status> statuses;
  for (const int generator_index : non_scheduled_generators) {
    NodeTypeInfo::NodeRef node_ref(NodeTypeInfo::NodeType::PACKET_GENERATOR,
                                   generator_index);
    nodes_.push_back(absl::make_unique<CalculatorNode>());
    // The buffer size hint is never used by the packet generator wrapper.
    int buffer_size_hint = 0;
    const absl::Status status = nodes_.back()->Initialize(
        validated_graph_.get(), node_ref, input_stream_managers_.get(),
        output_stream_managers_.get(), output_side_packets_.get(),
        &buffer_size_hint, profiler_);
    // Keep going so we can report as many errors as possible at once.
    if (!status.ok()) {
      statuses.push_back(status);
    }
  }
  if (statuses.empty()) {
    return absl::OkStatus();
  }
  return tool::CombinedStatus(
      "CalculatorGraph::InitializePacketGeneratorNodes failed: ", statuses);
}
absl::Status CalculatorGraph::InitializeProfiler() {
profiler_->Initialize(*validated_graph_);
return absl::OkStatus();
@ -528,8 +561,8 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
std::map<std::string, Packet> additional_side_packets;
bool update_sp = false;
bool uses_gpu = false;
for (const auto& node : *nodes_) {
if (node.UsesGpu()) {
for (const auto& node : nodes_) {
if (node->UsesGpu()) {
uses_gpu = true;
break;
}
@ -571,9 +604,9 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
}
// Set up executors.
for (auto& node : *nodes_) {
if (node.UsesGpu()) {
MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(&node));
for (auto& node : nodes_) {
if (node->UsesGpu()) {
MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get()));
}
}
for (const auto& name_executor : gpu_resources->GetGpuExecutors()) {
@ -616,8 +649,10 @@ absl::Status CalculatorGraph::PrepareForRun(
}
current_run_side_packets_.clear();
std::vector<int> non_scheduled_generators;
absl::Status generator_status = packet_generator_graph_.RunGraphSetup(
*input_side_packets, &current_run_side_packets_);
*input_side_packets, &current_run_side_packets_,
&non_scheduled_generators);
CallStatusHandlers(GraphRunState::PRE_RUN, generator_status);
@ -650,6 +685,8 @@ absl::Status CalculatorGraph::PrepareForRun(
}
scheduler_.Reset();
MP_RETURN_IF_ERROR(InitializePacketGeneratorNodes(non_scheduled_generators));
{
absl::MutexLock lock(&full_input_streams_mutex_);
// Initialize a count per source node to store the number of input streams
@ -671,22 +708,22 @@ absl::Status CalculatorGraph::PrepareForRun(
output_side_packets_[index].PrepareForRun(
std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1));
}
for (CalculatorNode& node : *nodes_) {
for (auto& node : nodes_) {
InputStreamManager::QueueSizeCallback queue_size_callback =
std::bind(&CalculatorGraph::UpdateThrottledNodes, this,
std::placeholders::_1, std::placeholders::_2);
node.SetQueueSizeCallbacks(queue_size_callback, queue_size_callback);
scheduler_.AssignNodeToSchedulerQueue(&node);
node->SetQueueSizeCallbacks(queue_size_callback, queue_size_callback);
scheduler_.AssignNodeToSchedulerQueue(node.get());
// TODO: update calculator node to use GraphServiceManager
// instead of service packets?
const absl::Status result = node.PrepareForRun(
const absl::Status result = node->PrepareForRun(
current_run_side_packets_, service_manager_.ServicePackets(),
std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_,
&node),
node.get()),
std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_,
&node),
node.get()),
std::bind(&internal::Scheduler::ScheduleNodeIfNotThrottled, &scheduler_,
&node, std::placeholders::_1),
node.get(), std::placeholders::_1),
std::bind(&CalculatorGraph::RecordError, this, std::placeholders::_1),
counter_factory_.get());
if (!result.ok()) {
@ -714,8 +751,8 @@ absl::Status CalculatorGraph::PrepareForRun(
// Ensure that the latest value of max queue size is passed to all input
// streams.
for (auto& node : *nodes_) {
node.SetMaxInputStreamQueueSize(max_queue_size_);
for (auto& node : nodes_) {
node->SetMaxInputStreamQueueSize(max_queue_size_);
}
// Allow graph input streams to override the global max queue size.
@ -729,9 +766,9 @@ absl::Status CalculatorGraph::PrepareForRun(
(*stream)->SetMaxQueueSize(name_max.second);
}
for (CalculatorNode& node : *nodes_) {
if (node.IsSource()) {
scheduler_.AddUnopenedSourceNode(&node);
for (auto& node : nodes_) {
if (node->IsSource()) {
scheduler_.AddUnopenedSourceNode(node.get());
has_sources_ = true;
}
}
@ -1077,7 +1114,7 @@ void CalculatorGraph::UpdateThrottledNodes(InputStreamManager* stream,
}
} else {
if (!is_throttled) {
CalculatorNode& node = (*nodes_)[node_id];
CalculatorNode& node = *nodes_[node_id];
// Add this node to the scheduler queue if possible.
if (node.Active() && !node.Closed()) {
nodes_to_schedule.emplace_back(&node);
@ -1244,8 +1281,8 @@ void CalculatorGraph::CleanupAfterRun(absl::Status* status) {
MEDIAPIPE_CHECK_OK(*status);
}
for (CalculatorNode& node : *nodes_) {
node.CleanupAfterRun(*status);
for (auto& node : nodes_) {
node->CleanupAfterRun(*status);
}
for (auto& graph_output_stream : graph_output_streams_) {

View File

@ -486,6 +486,8 @@ class CalculatorGraph {
absl::Status InitializeStreams();
absl::Status InitializeProfiler();
absl::Status InitializeCalculatorNodes();
absl::Status InitializePacketGeneratorNodes(
const std::vector<int>& non_scheduled_generators);
// Iterates through all nodes and schedules any that can be opened.
void ScheduleAllOpenableNodes();
@ -556,7 +558,8 @@ class CalculatorGraph {
std::unique_ptr<InputStreamManager[]> input_stream_managers_;
std::unique_ptr<OutputStreamManager[]> output_stream_managers_;
std::unique_ptr<OutputSidePacketImpl[]> output_side_packets_;
std::unique_ptr<absl::FixedArray<CalculatorNode>> nodes_;
std::vector<std::unique_ptr<CalculatorNode>> nodes_;
bool packet_generator_nodes_added_ = false;
// The graph output streams.
std::vector<std::shared_ptr<internal::GraphOutputStream>>

View File

@ -52,6 +52,25 @@ class OutputSidePacketInProcessCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator);
// Takes an input side packet and passes it as an output side packet.
class OutputSidePacketInOpenCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // The output side packet mirrors whatever type the input side packet has.
    auto& input = cc->InputSidePackets().Index(0);
    input.SetAny();
    cc->OutputSidePackets().Index(0).SetSameAs(&input);
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    // Forward the side packet as soon as the calculator opens.
    cc->OutputSidePackets().Index(0).Set(cc->InputSidePackets().Index(0));
    return absl::OkStatus();
  }

  // Nothing to do per packet; the side packet was already emitted in Open().
  absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
};
REGISTER_CALCULATOR(OutputSidePacketInOpenCalculator);
// Takes an input stream packet and counts the number of the packets it
// receives. Outputs the total number of packets as a side packet in Close.
class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase {
@ -802,5 +821,80 @@ TEST(CalculatorGraph, OutputSidePacketCached) {
}
}
// Verifies that a packet generator can consume a side packet that only
// becomes available once a calculator's Open() has run.
TEST(CalculatorGraph, GeneratorAfterCalculatorOpen) {
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_side_packet: "offset"
        node {
          calculator: "OutputSidePacketInOpenCalculator"
          input_side_packet: "offset"
          output_side_packet: "offset1"
        }
        packet_generator {
          packet_generator: 'PassThroughGenerator'
          input_side_packet: 'offset1'
          output_side_packet: 'offset_out'
        }
        node {
          calculator: "SidePacketToStreamPacketCalculator"
          input_side_packet: "offset_out"
          output_stream: "output"
        }
      )pb");
  MP_ASSERT_OK(graph.Initialize(config));

  // Record every packet that reaches the graph's "output" stream.
  std::vector<Packet> observed;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "output", [&observed](const Packet& p) {
        observed.push_back(p);
        return absl::OkStatus();
      }));

  MP_ASSERT_OK(graph.StartRun({{"offset", MakePacket<TimestampDiff>(100)}}));
  MP_ASSERT_OK(graph.WaitUntilDone());

  // The side packet must have flowed through the generator exactly once.
  ASSERT_EQ(1, observed.size());
  EXPECT_EQ(100, observed[0].Get<TimestampDiff>().Value());
}
// Verifies that a packet generator can consume a side packet produced during
// a calculator's Process(), and that rerunning the graph does not add
// duplicate wrapper nodes.
TEST(CalculatorGraph, GeneratorAfterCalculatorProcess) {
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "offset"
        node {
          calculator: "OutputSidePacketInProcessCalculator"
          input_stream: "offset"
          output_side_packet: "offset"
        }
        packet_generator {
          packet_generator: 'PassThroughGenerator'
          input_side_packet: 'offset'
          output_side_packet: 'offset_out'
        }
        node {
          calculator: "SidePacketToStreamPacketCalculator"
          input_side_packet: "offset_out"
          output_stream: "output"
        }
      )pb");
  MP_ASSERT_OK(graph.Initialize(config));

  // Record every packet that reaches the graph's "output" stream.
  std::vector<Packet> observed;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "output", [&observed](const Packet& p) {
        observed.push_back(p);
        return absl::OkStatus();
      }));

  // Run twice to verify that we don't duplicate wrapper nodes.
  for (int run = 0; run < 2; ++run) {
    observed.clear();
    MP_ASSERT_OK(graph.StartRun({}));
    MP_ASSERT_OK(graph.AddPacketToInputStream(
        "offset", MakePacket<TimestampDiff>(100).At(Timestamp(0))));
    MP_ASSERT_OK(graph.CloseInputStream("offset"));
    MP_ASSERT_OK(graph.WaitUntilDone());
    ASSERT_EQ(1, observed.size());
    EXPECT_EQ(100, observed[0].Get<TimestampDiff>().Value());
  }
}
} // namespace
} // namespace mediapipe

View File

@ -1133,24 +1133,6 @@ class CheckInputTimestamp2SinkCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(CheckInputTimestamp2SinkCalculator);
// Takes an input stream packet and passes it (with timestamp removed) as an
// output side packet.
class OutputSidePacketInProcessCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // The output side packet mirrors whatever type the input stream has.
    auto& input = cc->Inputs().Index(0);
    input.SetAny();
    cc->OutputSidePackets().Index(0).SetSameAs(&input);
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    // Side packets are untimed, so strip the stream timestamp before setting.
    const Packet& packet = cc->Inputs().Index(0).Value();
    cc->OutputSidePackets().Index(0).Set(packet.At(Timestamp::Unset()));
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator);
// A calculator checks if either of two input streams contains a packet and
// sends the packet to the single output stream with the same timestamp.
class SimpleMuxCalculator : public CalculatorBase {

View File

@ -20,6 +20,7 @@
#include <utility>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
@ -119,79 +120,92 @@ Timestamp CalculatorNode::SourceProcessOrder(
}
absl::Status CalculatorNode::Initialize(
const ValidatedGraphConfig* validated_graph, int node_id,
const ValidatedGraphConfig* validated_graph, NodeTypeInfo::NodeRef node_ref,
InputStreamManager* input_stream_managers,
OutputStreamManager* output_stream_managers,
OutputSidePacketImpl* output_side_packets, int* buffer_size_hint,
std::shared_ptr<ProfilingContext> profiling_context) {
RET_CHECK(buffer_size_hint) << "buffer_size_hint is NULL";
node_id_ = node_id;
validated_graph_ = validated_graph;
profiling_context_ = profiling_context;
const CalculatorGraphConfig::Node& node_config =
validated_graph_->Config().node(node_id_);
name_ = tool::CanonicalNodeName(validated_graph_->Config(), node_id_);
max_in_flight_ = node_config.max_in_flight();
max_in_flight_ = max_in_flight_ ? max_in_flight_ : 1;
if (!node_config.executor().empty()) {
executor_ = node_config.executor();
const CalculatorGraphConfig::Node* node_config;
if (node_ref.type == NodeTypeInfo::NodeType::CALCULATOR) {
node_config = &validated_graph_->Config().node(node_ref.index);
name_ = tool::CanonicalNodeName(validated_graph_->Config(), node_ref.index);
node_type_info_ = &validated_graph_->CalculatorInfos()[node_ref.index];
} else if (node_ref.type == NodeTypeInfo::NodeType::PACKET_GENERATOR) {
const PacketGeneratorConfig& pg_config =
validated_graph_->Config().packet_generator(node_ref.index);
name_ = absl::StrCat("__pg_", node_ref.index, "_",
pg_config.packet_generator());
node_type_info_ = &validated_graph_->GeneratorInfos()[node_ref.index];
node_config = &node_type_info_->Contract().GetWrapperConfig();
} else {
return absl::InvalidArgumentError(
"node_ref is not a calculator or packet generator");
}
source_layer_ = node_config.source_layer();
const NodeTypeInfo& node_type_info =
validated_graph_->CalculatorInfos()[node_id_];
const CalculatorContract& contract = node_type_info.Contract();
max_in_flight_ = node_config->max_in_flight();
max_in_flight_ = max_in_flight_ ? max_in_flight_ : 1;
if (!node_config->executor().empty()) {
executor_ = node_config->executor();
}
source_layer_ = node_config->source_layer();
const CalculatorContract& contract = node_type_info_->Contract();
uses_gpu_ =
node_type_info.InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
ContainsKey(node_type_info.Contract().ServiceRequests(), kGpuService.key);
node_type_info_->InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
ContainsKey(node_type_info_->Contract().ServiceRequests(),
kGpuService.key);
// TODO Propagate types between calculators when SetAny is used.
MP_RETURN_IF_ERROR(InitializeOutputSidePackets(
node_type_info.OutputSidePacketTypes(), output_side_packets));
node_type_info_->OutputSidePacketTypes(), output_side_packets));
MP_RETURN_IF_ERROR(InitializeInputSidePackets(output_side_packets));
MP_RETURN_IF_ERROR(InitializeOutputStreamHandler(
node_config.output_stream_handler(), node_type_info.OutputStreamTypes()));
MP_RETURN_IF_ERROR(
InitializeOutputStreamHandler(node_config->output_stream_handler(),
node_type_info_->OutputStreamTypes()));
MP_RETURN_IF_ERROR(InitializeOutputStreams(output_stream_managers));
calculator_state_ = absl::make_unique<CalculatorState>(
name_, node_id_, node_config.calculator(), node_config,
name_, node_ref.index, node_config->calculator(), *node_config,
profiling_context_);
// Inform the scheduler that this node has buffering behavior and that the
// maximum input queue size should be adjusted accordingly.
*buffer_size_hint = node_config.buffer_size_hint();
*buffer_size_hint = node_config->buffer_size_hint();
calculator_context_manager_.Initialize(
calculator_state_.get(), node_type_info.InputStreamTypes().TagMap(),
node_type_info.OutputStreamTypes().TagMap(),
calculator_state_.get(), node_type_info_->InputStreamTypes().TagMap(),
node_type_info_->OutputStreamTypes().TagMap(),
/*calculator_run_in_parallel=*/max_in_flight_ > 1);
// The graph specified InputStreamHandler takes priority.
const bool graph_specified =
node_config.input_stream_handler().has_input_stream_handler();
const bool calc_specified = !(node_type_info.GetInputStreamHandler().empty());
node_config->input_stream_handler().has_input_stream_handler();
const bool calc_specified =
!(node_type_info_->GetInputStreamHandler().empty());
// Only use calculator ISH if available, and if the graph ISH is not set.
InputStreamHandlerConfig handler_config;
const bool use_calc_specified = calc_specified && !graph_specified;
if (use_calc_specified) {
*(handler_config.mutable_input_stream_handler()) =
node_type_info.GetInputStreamHandler();
node_type_info_->GetInputStreamHandler();
*(handler_config.mutable_options()) =
node_type_info.GetInputStreamHandlerOptions();
node_type_info_->GetInputStreamHandlerOptions();
}
// Use calculator or graph specified InputStreamHandler, or the default ISH
// already set from graph.
MP_RETURN_IF_ERROR(InitializeInputStreamHandler(
use_calc_specified ? handler_config : node_config.input_stream_handler(),
node_type_info.InputStreamTypes()));
use_calc_specified ? handler_config : node_config->input_stream_handler(),
node_type_info_->InputStreamTypes()));
for (auto& stream : output_stream_handler_->OutputStreams()) {
stream->Spec()->offset_enabled =
@ -209,9 +223,7 @@ absl::Status CalculatorNode::InitializeOutputSidePackets(
OutputSidePacketImpl* output_side_packets) {
output_side_packets_ =
absl::make_unique<OutputSidePacketSet>(output_side_packet_types.TagMap());
const NodeTypeInfo& node_type_info =
validated_graph_->CalculatorInfos()[node_id_];
int base_index = node_type_info.OutputSidePacketBaseIndex();
int base_index = node_type_info_->OutputSidePacketBaseIndex();
RET_CHECK_LE(0, base_index);
for (CollectionItemId id = output_side_packets_->BeginId();
id < output_side_packets_->EndId(); ++id) {
@ -223,13 +235,11 @@ absl::Status CalculatorNode::InitializeOutputSidePackets(
absl::Status CalculatorNode::InitializeInputSidePackets(
OutputSidePacketImpl* output_side_packets) {
const NodeTypeInfo& node_type_info =
validated_graph_->CalculatorInfos()[node_id_];
int base_index = node_type_info.InputSidePacketBaseIndex();
int base_index = node_type_info_->InputSidePacketBaseIndex();
RET_CHECK_LE(0, base_index);
// Set all the mirrors.
for (CollectionItemId id = node_type_info.InputSidePacketTypes().BeginId();
id < node_type_info.InputSidePacketTypes().EndId(); ++id) {
for (CollectionItemId id = node_type_info_->InputSidePacketTypes().BeginId();
id < node_type_info_->InputSidePacketTypes().EndId(); ++id) {
int output_side_packet_index =
validated_graph_->InputSidePacketInfos()[base_index + id.value()]
.upstream;
@ -252,11 +262,9 @@ absl::Status CalculatorNode::InitializeInputSidePackets(
absl::Status CalculatorNode::InitializeOutputStreams(
OutputStreamManager* output_stream_managers) {
RET_CHECK(output_stream_managers) << "output_stream_managers is NULL";
const NodeTypeInfo& node_type_info =
validated_graph_->CalculatorInfos()[node_id_];
RET_CHECK_LE(0, node_type_info.OutputStreamBaseIndex());
RET_CHECK_LE(0, node_type_info_->OutputStreamBaseIndex());
OutputStreamManager* current_output_stream_managers =
&output_stream_managers[node_type_info.OutputStreamBaseIndex()];
&output_stream_managers[node_type_info_->OutputStreamBaseIndex()];
return output_stream_handler_->InitializeOutputStreamManagers(
current_output_stream_managers);
}
@ -266,20 +274,18 @@ absl::Status CalculatorNode::InitializeInputStreams(
OutputStreamManager* output_stream_managers) {
RET_CHECK(input_stream_managers) << "input_stream_managers is NULL";
RET_CHECK(output_stream_managers) << "output_stream_managers is NULL";
const NodeTypeInfo& node_type_info =
validated_graph_->CalculatorInfos()[node_id_];
RET_CHECK_LE(0, node_type_info.InputStreamBaseIndex());
RET_CHECK_LE(0, node_type_info_->InputStreamBaseIndex());
InputStreamManager* current_input_stream_managers =
&input_stream_managers[node_type_info.InputStreamBaseIndex()];
&input_stream_managers[node_type_info_->InputStreamBaseIndex()];
MP_RETURN_IF_ERROR(input_stream_handler_->InitializeInputStreamManagers(
current_input_stream_managers));
// Set all the mirrors.
for (CollectionItemId id = node_type_info.InputStreamTypes().BeginId();
id < node_type_info.InputStreamTypes().EndId(); ++id) {
for (CollectionItemId id = node_type_info_->InputStreamTypes().BeginId();
id < node_type_info_->InputStreamTypes().EndId(); ++id) {
int output_stream_index =
validated_graph_
->InputStreamInfos()[node_type_info.InputStreamBaseIndex() +
->InputStreamInfos()[node_type_info_->InputStreamBaseIndex() +
id.value()]
.upstream;
RET_CHECK_LE(0, output_stream_index);
@ -287,7 +293,7 @@ absl::Status CalculatorNode::InitializeInputStreams(
&output_stream_managers[output_stream_index];
VLOG(2) << "Adding mirror for input stream with id " << id.value()
<< " and flat index "
<< node_type_info.InputStreamBaseIndex() + id.value()
<< node_type_info_->InputStreamBaseIndex() + id.value()
<< " which will be connected to output stream with flat index "
<< output_stream_index;
origin_output_stream_manager->AddMirror(input_stream_handler_.get(), id);
@ -391,10 +397,9 @@ absl::Status CalculatorNode::PrepareForRun(
std::move(schedule_callback), error_callback);
output_stream_handler_->PrepareForRun(error_callback);
const PacketTypeSet* packet_types =
&validated_graph_->CalculatorInfos()[node_id_].InputSidePacketTypes();
const auto& contract = node_type_info_->Contract();
input_side_packet_types_ = RemoveOmittedPacketTypes(
*packet_types, all_side_packets, validated_graph_);
contract.InputSidePackets(), all_side_packets, validated_graph_);
MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun(
input_side_packet_types_.get(), all_side_packets,
[this]() { CalculatorNode::InputSidePacketsReady(); },
@ -404,8 +409,6 @@ absl::Status CalculatorNode::PrepareForRun(
calculator_state_->SetOutputSidePackets(output_side_packets_.get());
calculator_state_->SetCounterFactory(counter_factory);
const auto& contract =
validated_graph_->CalculatorInfos()[node_id_].Contract();
for (const auto& svc_req : contract.ServiceRequests()) {
const auto& req = svc_req.second;
auto it = service_packets.find(req.Service().key);

View File

@ -70,7 +70,9 @@ class CalculatorNode {
CalculatorNode();
CalculatorNode(const CalculatorNode&) = delete;
CalculatorNode& operator=(const CalculatorNode&) = delete;
int Id() const { return node_id_; }
int Id() const {
return node_type_info_ ? node_type_info_->Node().index : -1;
}
// Returns a value according to which the scheduler queue determines the
// relative priority between runnable source nodes; a smaller value means
@ -106,7 +108,7 @@ class CalculatorNode {
// OutputSidePacketImpls corresponding to the output side packet indexes in
// validated_graph.
absl::Status Initialize(const ValidatedGraphConfig* validated_graph,
int node_id,
NodeTypeInfo::NodeRef node_ref,
InputStreamManager* input_stream_managers,
OutputStreamManager* output_stream_managers,
OutputSidePacketImpl* output_side_packets,
@ -287,7 +289,6 @@ class CalculatorNode {
// Keeps data which a Calculator subclass needs access to.
std::unique_ptr<CalculatorState> calculator_state_;
int node_id_ = -1;
std::string name_; // Optional user-defined name
// Name of the executor which the node will execute on. If empty, the node
// will execute on the default executor.
@ -372,6 +373,8 @@ class CalculatorNode {
internal::SchedulerQueue* scheduler_queue_ = nullptr;
const ValidatedGraphConfig* validated_graph_ = nullptr;
const NodeTypeInfo* node_type_info_ = nullptr;
};
} // namespace mediapipe

View File

@ -158,11 +158,11 @@ class CalculatorNodeTest : public ::testing::Test {
input_side_packets_.emplace("input_a", Adopt(new int(42)));
input_side_packets_.emplace("input_b", Adopt(new int(42)));
node_.reset(new CalculatorNode());
node_ = absl::make_unique<CalculatorNode>();
MP_ASSERT_OK(node_->Initialize(
&validated_graph_, 2, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint_, graph_profiler_));
&validated_graph_, {NodeTypeInfo::NodeType::CALCULATOR, 2},
input_stream_managers_.get(), output_stream_managers_.get(),
output_side_packets_.get(), &buffer_size_hint_, graph_profiler_));
}
absl::Status PrepareNodeForRun() {

View File

@ -30,6 +30,14 @@ bzl_library(
visibility = ["//mediapipe/framework:__subpackages__"],
)
bzl_library(
name = "descriptor_set_bzl",
srcs = [
"descriptor_set.bzl",
],
visibility = ["//mediapipe/framework:__subpackages__"],
)
proto_library(
name = "proto_descriptor_proto",
srcs = ["proto_descriptor.proto"],
@ -281,7 +289,8 @@ cc_library(
# Use this library through "mediapipe/framework/port:gtest_main".
visibility = ["//mediapipe/framework/port:__pkg__"],
deps = [
":status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
],
)

View File

@ -0,0 +1,139 @@
"""Outputs a FileDescriptorSet with all transitive dependencies.
Copied from tools/build_defs/proto/descriptor_set.bzl.
"""
TransitiveDescriptorInfo = provider(
"The transitive descriptors from a set of protos.",
fields = ["descriptors"],
)
DirectDescriptorInfo = provider(
"The direct descriptors from a set of protos.",
fields = ["descriptors"],
)
def calculate_transitive_descriptor_set(actions, deps, output):
    """Calculates the transitive dependencies of the deps.

    Args:
      actions: the actions (typically ctx.actions) used to run commands
      deps: the deps to get the transitive dependencies of
      output: the output file the data will be written to

    Returns:
      The same output file passed as the input arg, for convenience.
    """

    # Join all proto descriptors in a single file. A dep may either be a
    # proto_library (carries ProtoInfo) or a rule that re-exports descriptors
    # via TransitiveDescriptorInfo.
    transitive_descriptor_sets = depset(transitive = [
        dep[ProtoInfo].transitive_descriptor_sets if ProtoInfo in dep else dep[TransitiveDescriptorInfo].descriptors
        for dep in deps
    ])

    # Spill long argument lists to a params file so the shell command stays
    # under the system command-line length limit.
    args = actions.args()
    args.use_param_file(param_file_arg = "--arg-file=%s")
    args.add_all(transitive_descriptor_sets)

    # Because `xargs` must take its arguments before the command to execute,
    # we cannot simply put a reference to the argument list at the end, as in the
    # case of param file spooling, since the entire argument list will get
    # replaced by "--arg-file=bazel-out/..." which needs to be an `xargs`
    # argument rather than a `cat` argument.
    #
    # We look to see if the first argument begins with a '--arg-file=' and
    # selectively choose xargs vs. just supplying the arguments to `cat`.
    actions.run_shell(
        outputs = [output],
        inputs = transitive_descriptor_sets,
        progress_message = "Joining descriptors.",
        command = ("if [[ \"$1\" =~ ^--arg-file=.* ]]; then xargs \"$1\" cat; " +
                   "else cat \"$@\"; fi >{output}".format(output = output.path)),
        arguments = [args],
    )
    return output
def _transitive_descriptor_set_impl(ctx):
    """Combine descriptors for all transitive proto dependencies into one file.

    Warning: Concatenating all of the descriptor files with a single `cat` command
    could exceed system limits (1MB+). For example, a dependency on gwslog.proto
    will trigger this edge case.

    When writing new code, prefer to accept a list of descriptor files instead of
    just one so that this limitation won't impact you.
    """

    # NOTE(review): the rule also predeclares an output with this same name via
    # its `outputs` dict; confirm this declare_file does not collide with that
    # predeclared output.
    output = ctx.actions.declare_file(ctx.attr.name + "-transitive-descriptor-set.proto.bin")
    calculate_transitive_descriptor_set(ctx.actions, ctx.attr.deps, output)

    # Expose the joined descriptor set both as the default output and in
    # runfiles so downstream tools can read it at runtime.
    return DefaultInfo(
        files = depset([output]),
        runfiles = ctx.runfiles(files = [output]),
    )
# transitive_descriptor_set outputs a single file containing a binary
# FileDescriptorSet with all transitive dependencies of the given proto
# dependencies.
#
# Example usage:
#
#   transitive_descriptor_set(
#       name = "my_descriptors",
#       deps = [":my_proto"],
#   )
transitive_descriptor_set = rule(
    implementation = _transitive_descriptor_set_impl,
    # Deps may carry ProtoInfo directly or re-export descriptors via
    # TransitiveDescriptorInfo.
    attrs = {
        "deps": attr.label_list(providers = [[ProtoInfo], [TransitiveDescriptorInfo]]),
    },
    outputs = {
        "out": "%{name}-transitive-descriptor-set.proto.bin",
    },
)
def calculate_direct_descriptor_set(actions, deps, output):
    """Calculates the direct dependencies of the deps.

    Args:
      actions: the actions (typically ctx.actions) used to run commands
      deps: the deps to get the direct dependencies of
      output: the output file the data will be written to

    Returns:
      The same output file passed as the input arg, for convenience.
    """

    # Direct descriptors come straight from ProtoInfo; deps without ProtoInfo
    # must re-export them via DirectDescriptorInfo.
    descriptor_set = depset(
        [dep[ProtoInfo].direct_descriptor_set for dep in deps if ProtoInfo in dep],
        transitive = [dep[DirectDescriptorInfo].descriptors for dep in deps if ProtoInfo not in dep],
    )

    # NOTE(review): unlike calculate_transitive_descriptor_set, the file paths
    # are inlined into the command string here; a very large direct descriptor
    # set could exceed the argument-length limit — confirm this is acceptable
    # for expected inputs.
    actions.run_shell(
        outputs = [output],
        inputs = descriptor_set,
        progress_message = "Joining direct descriptors.",
        command = ("cat %s > %s") % (
            " ".join([d.path for d in descriptor_set.to_list()]),
            output.path,
        ),
    )
    return output
def _direct_descriptor_set_impl(ctx):
    # Write the direct (non-transitive) descriptor set to the predeclared
    # output file (ctx.outputs.out).
    calculate_direct_descriptor_set(ctx.actions, ctx.attr.deps, ctx.outputs.out)
# direct_descriptor_set outputs a single file containing a binary
# FileDescriptorSet with all direct, non transitive dependencies of
# the given proto dependencies.
#
# Example usage:
#
#   direct_descriptor_set(
#       name = "my_direct_descriptors",
#       deps = [":my_proto"],
#   )
direct_descriptor_set = rule(
    implementation = _direct_descriptor_set_impl,
    # Deps may carry ProtoInfo directly or re-export descriptors via
    # DirectDescriptorInfo.
    attrs = {
        "deps": attr.label_list(providers = [[ProtoInfo], [DirectDescriptorInfo]]),
    },
    outputs = {
        "out": "%{name}-direct-descriptor-set.proto.bin",
    },
)
)

View File

@ -15,48 +15,74 @@
#ifndef MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_
#define MEDIAPIPE_DEPS_MESSAGE_MATCHERS_H_
#include <memory>
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h"
namespace mediapipe {
namespace internal {
bool EqualsMessage(const proto_ns::MessageLite& m_1,
const proto_ns::MessageLite& m_2) {
std::string s_1, s_2;
m_1.SerializeToString(&s_1);
m_2.SerializeToString(&s_2);
return s_1 == s_2;
}
} // namespace internal
template <typename MessageType>
class ProtoMatcher : public testing::MatcherInterface<MessageType> {
using MatchResultListener = testing::MatchResultListener;
class ProtoMatcher {
public:
explicit ProtoMatcher(const MessageType& message) : message_(message) {}
virtual bool MatchAndExplain(MessageType m, MatchResultListener*) const {
return internal::EqualsMessage(message_, m);
using is_gtest_matcher = void;
using MessageType = proto_ns::MessageLite;
explicit ProtoMatcher(const MessageType& message)
: message_(CloneMessage(message)) {}
bool MatchAndExplain(const MessageType& m,
testing::MatchResultListener*) const {
return EqualsMessage(*message_, m);
}
bool MatchAndExplain(const MessageType* m,
testing::MatchResultListener*) const {
return EqualsMessage(*message_, *m);
}
virtual void DescribeTo(::std::ostream* os) const {
#if defined(MEDIAPIPE_PROTO_LITE)
*os << "Protobuf messages have identical serializations.";
#else
*os << message_.DebugString();
#endif
void DescribeTo(std::ostream* os) const {
*os << "has the same serialization as " << ExpectedMessageDescription();
}
void DescribeNegationTo(std::ostream* os) const {
*os << "does not have the same serialization as "
<< ExpectedMessageDescription();
}
private:
const MessageType message_;
std::unique_ptr<MessageType> CloneMessage(const MessageType& message) {
std::unique_ptr<MessageType> clone(message.New());
clone->CheckTypeAndMergeFrom(message);
return clone;
}
bool EqualsMessage(const proto_ns::MessageLite& m_1,
const proto_ns::MessageLite& m_2) const {
std::string s_1, s_2;
m_1.SerializeToString(&s_1);
m_2.SerializeToString(&s_2);
return s_1 == s_2;
}
std::string ExpectedMessageDescription() const {
#if defined(MEDIAPIPE_PROTO_LITE)
return "the expected message";
#else
return message_->DebugString();
#endif
}
const std::shared_ptr<MessageType> message_;
};
template <typename MessageType>
inline testing::PolymorphicMatcher<ProtoMatcher<MessageType>> EqualsProto(
const MessageType& message) {
return testing::PolymorphicMatcher<ProtoMatcher<MessageType>>(
ProtoMatcher<MessageType>(message));
// Returns a gMock matcher verifying that a protobuf message has the same
// serialization as `message`.
inline ProtoMatcher EqualsProto(const proto_ns::MessageLite& message) {
  return ProtoMatcher(message);
}
// 2-tuple overload for use with testing::Pointwise: compares corresponding
// elements of two proto containers by delegating to EqualsProto above.
MATCHER(EqualsProto, "") {
  const auto& a = ::testing::get<0>(arg);
  const auto& b = ::testing::get<1>(arg);
  return ::testing::ExplainMatchResult(EqualsProto(b), a, result_listener);
}
} // namespace mediapipe

View File

@ -15,24 +15,102 @@
#ifndef MEDIAPIPE_DEPS_STATUS_MATCHERS_H_
#define MEDIAPIPE_DEPS_STATUS_MATCHERS_H_
#include "absl/status/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "mediapipe/framework/deps/status.h"
#include "mediapipe/framework/port/statusor.h"
namespace mediapipe {
// Overload set letting the matchers below treat absl::Status and
// absl::StatusOr<T> uniformly: identity for a plain Status.
inline const ::absl::Status& GetStatus(const ::absl::Status& status) {
  return status;
}

// Returns the Status contained in a StatusOr<T>.
template <typename T>
inline const ::absl::Status& GetStatus(const ::absl::StatusOr<T>& status) {
  return status.status();
}
// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a
// reference to StatusOr<T>.
template <typename StatusOrType>
class IsOkAndHoldsMatcherImpl
    : public ::testing::MatcherInterface<StatusOrType> {
 public:
  // The T wrapped by the StatusOr, with the reference stripped first.
  typedef
      typename std::remove_reference<StatusOrType>::type::value_type value_type;

  // SafeMatcherCast fixes the inner matcher's argument type to
  // `const value_type&`, rejecting unsafe conversions at compile time.
  template <typename InnerMatcher>
  explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher)
      : inner_matcher_(::testing::SafeMatcherCast<const value_type&>(
            std::forward<InnerMatcher>(inner_matcher))) {}

  void DescribeTo(std::ostream* os) const override {
    *os << "is OK and has a value that ";
    inner_matcher_.DescribeTo(os);
  }

  void DescribeNegationTo(std::ostream* os) const override {
    *os << "isn't OK or has a value that ";
    inner_matcher_.DescribeNegationTo(os);
  }

  bool MatchAndExplain(
      StatusOrType actual_value,
      ::testing::MatchResultListener* result_listener) const override {
    // A non-OK status fails immediately; surface the status for diagnosis.
    if (!actual_value.ok()) {
      *result_listener << "which has status " << actual_value.status();
      return false;
    }

    // Run the inner matcher on the contained value, forwarding its
    // explanation (if any) together with the printed value so failure
    // messages are self-describing.
    ::testing::StringMatchResultListener inner_listener;
    const bool matches =
        inner_matcher_.MatchAndExplain(*actual_value, &inner_listener);
    const std::string inner_explanation = inner_listener.str();
    if (!inner_explanation.empty()) {
      *result_listener << "which contains value "
                       << ::testing::PrintToString(*actual_value) << ", "
                       << inner_explanation;
    }
    return matches;
  }

 private:
  const ::testing::Matcher<const value_type&> inner_matcher_;
};
// Implements IsOkAndHolds(m) as a polymorphic matcher.
template <typename InnerMatcher>
class IsOkAndHoldsMatcher {
 public:
  explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher)
      : inner_matcher_(std::move(inner_matcher)) {}

  // Converts this polymorphic matcher to a monomorphic matcher of the
  // given type. StatusOrType can be either StatusOr<T> or a
  // reference to StatusOr<T>.
  template <typename StatusOrType>
  operator ::testing::Matcher<StatusOrType>() const {  // NOLINT
    // The returned Matcher takes ownership of the heap-allocated impl.
    return ::testing::Matcher<StatusOrType>(
        new IsOkAndHoldsMatcherImpl<const StatusOrType&>(inner_matcher_));
  }

 private:
  const InnerMatcher inner_matcher_;
};
// Monomorphic implementation of matcher IsOk() for a given type T.
// T can be Status, StatusOr<>, or a reference to either of them.
template <typename T>
class MonoIsOkMatcherImpl : public testing::MatcherInterface<T> {
class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> {
public:
void DescribeTo(std::ostream* os) const override { *os << "is OK"; }
void DescribeNegationTo(std::ostream* os) const override {
*os << "is not OK";
}
bool MatchAndExplain(T actual_value,
testing::MatchResultListener*) const override {
return actual_value.ok();
::testing::MatchResultListener*) const override {
return GetStatus(actual_value).ok();
}
};
@ -40,11 +118,20 @@ class MonoIsOkMatcherImpl : public testing::MatcherInterface<T> {
class IsOkMatcher {
public:
template <typename T>
operator testing::Matcher<T>() const { // NOLINT
return testing::Matcher<T>(new MonoIsOkMatcherImpl<T>());
operator ::testing::Matcher<T>() const { // NOLINT
return ::testing::Matcher<T>(new MonoIsOkMatcherImpl<T>());
}
};
// Returns a gMock matcher that matches a StatusOr<> whose status is
// OK and whose value matches the inner matcher.
template <typename InnerMatcher>
IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type> IsOkAndHolds(
    InnerMatcher&& inner_matcher) {
  // std::decay strips references/cv so the matcher stores its own copy.
  return IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>(
      std::forward<InnerMatcher>(inner_matcher));
}
// Returns a gMock matcher that matches a Status or StatusOr<> which is OK.
// Both argument types are supported via IsOkMatcher's templated conversion.
inline IsOkMatcher IsOk() { return IsOkMatcher(); }

View File

@ -54,6 +54,7 @@ mediapipe_register_type(
types = [
"::mediapipe::Classification",
"::mediapipe::ClassificationList",
"::mediapipe::ClassificationListCollection",
"::std::vector<::mediapipe::Classification>",
"::std::vector<::mediapipe::ClassificationList>",
],
@ -262,8 +263,10 @@ mediapipe_register_type(
types = [
"::mediapipe::Landmark",
"::mediapipe::LandmarkList",
"::mediapipe::LandmarkListCollection",
"::mediapipe::NormalizedLandmark",
"::mediapipe::NormalizedLandmarkList",
"::mediapipe::NormalizedLandmarkListCollection",
"::std::vector<::mediapipe::Landmark>",
"::std::vector<::mediapipe::LandmarkList>",
"::std::vector<::mediapipe::NormalizedLandmark>",

View File

@ -39,3 +39,8 @@ message Classification {
message ClassificationList {
repeated Classification classification = 1;
}
// Group of ClassificationList protos.
message ClassificationListCollection {
  // One ClassificationList per grouped item; element semantics are defined
  // by the producing calculator.
  repeated ClassificationList classification_list = 1;
}

View File

@ -47,6 +47,11 @@ message LandmarkList {
repeated Landmark landmark = 1;
}
// Group of LandmarkList protos.
message LandmarkListCollection {
repeated LandmarkList landmark_list = 1;
}
// A normalized version of above Landmark proto. All coordinates should be
// within [0, 1].
message NormalizedLandmark {
@ -61,3 +66,8 @@ message NormalizedLandmark {
message NormalizedLandmarkList {
repeated NormalizedLandmark landmark = 1;
}
// Group of NormalizedLandmarkList protos.
message NormalizedLandmarkListCollection {
  // One NormalizedLandmarkList per grouped item; coordinates within each
  // list are normalized to [0, 1] per NormalizedLandmark above.
  repeated NormalizedLandmarkList landmark_list = 1;
}

View File

@ -428,37 +428,37 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
} else
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
// Transfer data from texture if not transferred from SSBO/MTLBuffer
// yet.
if (valid_ & kValidOpenGlTexture2d) {
gl_context_->Run([this]() {
const int padded_size =
texture_height_ * texture_width_ * 4 * element_size();
auto temp_buffer = absl::make_unique<uint8_t[]>(padded_size);
uint8_t* buffer = temp_buffer.get();
// Transfer data from texture if not transferred from SSBO/MTLBuffer
// yet.
if (valid_ & kValidOpenGlTexture2d) {
gl_context_->Run([this]() {
const int padded_size =
texture_height_ * texture_width_ * 4 * element_size();
auto temp_buffer = absl::make_unique<uint8_t[]>(padded_size);
uint8_t* buffer = temp_buffer.get();
glBindFramebuffer(GL_FRAMEBUFFER, frame_buffer_);
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_TEXTURE_2D, opengl_texture2d_, 0);
glPixelStorei(GL_PACK_ALIGNMENT, 4);
glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT,
buffer);
glBindFramebuffer(GL_FRAMEBUFFER, frame_buffer_);
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_TEXTURE_2D, opengl_texture2d_, 0);
glPixelStorei(GL_PACK_ALIGNMENT, 4);
glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT,
buffer);
uint8_t* dest_buffer = reinterpret_cast<uint8_t*>(cpu_buffer_);
const int actual_depth_size =
BhwcDepthFromShape(shape_) * element_size();
const int num_slices = (BhwcDepthFromShape(shape_) + 3) / 4;
const int padded_depth_size = num_slices * 4 * element_size();
const int num_elements = BhwcWidthFromShape(shape_) *
BhwcHeightFromShape(shape_) *
BhwcBatchFromShape(shape_);
for (int e = 0; e < num_elements; e++) {
std::memcpy(dest_buffer, buffer, actual_depth_size);
dest_buffer += actual_depth_size;
buffer += padded_depth_size;
}
});
}
uint8_t* dest_buffer = reinterpret_cast<uint8_t*>(cpu_buffer_);
const int actual_depth_size =
BhwcDepthFromShape(shape_) * element_size();
const int num_slices = (BhwcDepthFromShape(shape_) + 3) / 4;
const int padded_depth_size = num_slices * 4 * element_size();
const int num_elements = BhwcWidthFromShape(shape_) *
BhwcHeightFromShape(shape_) *
BhwcBatchFromShape(shape_);
for (int e = 0; e < num_elements; e++) {
std::memcpy(dest_buffer, buffer, actual_depth_size);
dest_buffer += actual_depth_size;
buffer += padded_depth_size;
}
});
}
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
valid_ |= kValidCpu;
}

View File

@ -127,7 +127,7 @@ const proto_ns::MessageLite& Packet::GetProtoMessageLite() const {
}
StatusOr<std::vector<const proto_ns::MessageLite*>>
Packet::GetVectorOfProtoMessageLitePtrs() {
Packet::GetVectorOfProtoMessageLitePtrs() const {
if (holder_ == nullptr) {
return absl::InternalError("Packet is empty.");
}

View File

@ -175,7 +175,7 @@ class Packet {
// Note: This function is meant to be used internally within the MediaPipe
// framework only.
StatusOr<std::vector<const proto_ns::MessageLite*>>
GetVectorOfProtoMessageLitePtrs();
GetVectorOfProtoMessageLitePtrs() const;
// Returns an error if the packet does not contain data of type T.
template <typename T>
@ -391,7 +391,7 @@ class HolderBase {
// underlying object is a vector of protocol buffer objects, otherwise,
// returns an error.
virtual StatusOr<std::vector<const proto_ns::MessageLite*>>
GetVectorOfProtoMessageLite() = 0;
GetVectorOfProtoMessageLite() const = 0;
private:
size_t type_id_;
@ -563,7 +563,7 @@ class Holder : public HolderBase {
// underlying object is a vector of protocol buffer objects, otherwise,
// returns an error.
StatusOr<std::vector<const proto_ns::MessageLite*>>
GetVectorOfProtoMessageLite() override {
GetVectorOfProtoMessageLite() const override {
return ConvertToVectorOfProtoMessageLitePtrs(ptr_, is_proto_vector<T>());
}

View File

@ -370,7 +370,8 @@ absl::Status PacketGeneratorGraph::Initialize(
absl::Status PacketGeneratorGraph::RunGraphSetup(
const std::map<std::string, Packet>& input_side_packets,
std::map<std::string, Packet>* output_side_packets) const {
std::map<std::string, Packet>* output_side_packets,
std::vector<int>* non_scheduled_generators) const {
*output_side_packets = base_packets_;
for (const std::pair<const std::string, Packet>& item : input_side_packets) {
auto iter = output_side_packets->find(item.first);
@ -380,7 +381,9 @@ absl::Status PacketGeneratorGraph::RunGraphSetup(
}
output_side_packets->insert(iter, item);
}
std::vector<int> non_scheduled_generators;
std::vector<int> non_scheduled_generators_local;
if (!non_scheduled_generators)
non_scheduled_generators = &non_scheduled_generators_local;
MP_RETURN_IF_ERROR(
validated_graph_->CanAcceptSidePackets(input_side_packets));
@ -389,11 +392,7 @@ absl::Status PacketGeneratorGraph::RunGraphSetup(
MP_RETURN_IF_ERROR(
validated_graph_->ValidateRequiredSidePackets(*output_side_packets));
MP_RETURN_IF_ERROR(ExecuteGenerators(
output_side_packets, &non_scheduled_generators, /*initial=*/false));
RET_CHECK(non_scheduled_generators.empty())
<< "Some Generators were unrunnable (validation should have failed).\n"
"Generator indexes: "
<< absl::StrJoin(non_scheduled_generators, ", ");
output_side_packets, non_scheduled_generators, /*initial=*/false));
return absl::OkStatus();
}

View File

@ -76,7 +76,8 @@ class PacketGeneratorGraph {
// must now be runnable) to produce output_side_packets.
virtual absl::Status RunGraphSetup(
const std::map<std::string, Packet>& input_side_packets,
std::map<std::string, Packet>* output_side_packets) const;
std::map<std::string, Packet>* output_side_packets,
std::vector<int>* non_scheduled_generators = nullptr) const;
// Get the base packets: the packets which are produced when Initialize
// is called.

View File

@ -21,6 +21,7 @@
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/packet_test.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h"
@ -214,6 +215,21 @@ TEST(PacketTest, ValidateAsProtoMessageLite) {
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
}
// Verifies GetVectorOfProtoMessageLitePtrs: returns pointers to the packet's
// protos for a vector-of-proto payload, and a non-OK status otherwise.
TEST(PacketTest, GetVectorOfProtos) {
  std::vector<mediapipe::PacketTestProto> protos(2);
  protos[0].add_x(123);
  protos[1].add_x(456);
  // Normally we'd move here, but we copy to use the protos for comparison.
  const Packet packet =
      MakePacket<std::vector<mediapipe::PacketTestProto>>(protos);
  auto maybe_proto_ptrs = packet.GetVectorOfProtoMessageLitePtrs();
  // The returned pointers must match the originals element-wise by
  // serialization (Pointwise + tuple EqualsProto).
  EXPECT_THAT(maybe_proto_ptrs,
              IsOkAndHolds(testing::Pointwise(EqualsProto(), protos)));
  // A non-proto payload must yield an error status, not a crash.
  const Packet wrong = MakePacket<int>(1);
  EXPECT_THAT(wrong.GetVectorOfProtoMessageLitePtrs(), testing::Not(IsOk()));
}
TEST(PacketTest, SyncedPacket) {
Packet synced_packet = AdoptAsSyncedPacket(new int(100));
Packet value_packet =

View File

@ -4,6 +4,7 @@
""".bzl file for mediapipe open source build configs."""
load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library")
load("//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_options_library")
def provided_args(**kwargs):
"""Returns the keyword arguments omitting None arguments."""
@ -47,6 +48,7 @@ def mediapipe_proto_library(
def_objc_proto = True,
def_java_proto = True,
def_jspb_proto = True,
def_options_lib = True,
portable_deps = None):
"""Defines the proto_library targets needed for all mediapipe platforms.
@ -67,6 +69,7 @@ def mediapipe_proto_library(
def_objc_proto: define the objc_proto_library target
def_java_proto: define the java_proto_library target
def_jspb_proto: define the jspb_proto_library target
def_options_lib: define the mediapipe_options_library target
"""
_ignore = [def_portable_proto, def_objc_proto, def_java_proto, def_jspb_proto, portable_deps]
@ -116,6 +119,17 @@ def mediapipe_proto_library(
compatible_with = compatible_with,
))
if def_options_lib:
cc_deps = replace_deps(deps, "_proto", "_cc_proto")
mediapipe_options_library(**provided_args(
name = replace_suffix(name, "_proto", "_options_lib"),
proto_lib = name,
deps = cc_deps,
visibility = visibility,
testonly = testonly,
compatible_with = compatible_with,
))
def mediapipe_py_proto_library(
name,
srcs,

View File

@ -113,6 +113,7 @@ cc_library(
"//mediapipe/framework/port:advanced_proto_lite",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:re2",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:name_util",

View File

@ -24,6 +24,7 @@
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/re2.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/profiler/profiler_resource_util.h"
@ -71,6 +72,8 @@ bool IsTracerEnabled(const ProfilerConfig& profiler_config) {
}
// Returns true if trace events are written to a log file.
// Note that for now, file output is only for graph-trace and not for
// calculator-profile.
bool IsTraceLogEnabled(const ProfilerConfig& profiler_config) {
return IsTracerEnabled(profiler_config) &&
!profiler_config.trace_log_disabled();
@ -117,6 +120,39 @@ PacketInfo* GetPacketInfo(PacketInfoMap* map, const PacketId& packet_id) {
} // namespace
// Builds GraphProfile records from profiler timing data.
class GraphProfiler::GraphProfileBuilder {
 public:
  // `profiler` must outlive this builder; it is stored unowned.
  // Marked explicit: a single-argument constructor should not act as an
  // implicit conversion from GraphProfiler*.
  explicit GraphProfileBuilder(GraphProfiler* profiler)
      : profiler_(profiler), calculator_regex_(".*") {
    // An empty calculator_filter admits every calculator; otherwise the
    // user-supplied pattern replaces the match-all default.
    auto& filter = profiler_->profiler_config().calculator_filter();
    calculator_regex_ = filter.empty() ? calculator_regex_ : RE2(filter);
  }

  // Returns true if profile `p`'s calculator name fully matches the
  // configured filter regex. Const: performs no mutation.
  bool ProfileIncluded(const CalculatorProfile& p) const {
    return RE2::FullMatch(p.name(), calculator_regex_);
  }

 private:
  GraphProfiler* profiler_;
  RE2 calculator_regex_;
};
// Constructs an idle profiler; Initialize() must be called before use.
// The 1000 passed to calculator_profiles_ and packets_info_ is presumably a
// capacity hint for those containers — confirm against their declared types.
GraphProfiler::GraphProfiler()
    : is_initialized_(false),
      is_profiling_(false),
      calculator_profiles_(1000),
      packets_info_(1000),
      is_running_(false),
      previous_log_end_time_(absl::InfinitePast()),
      previous_log_index_(-1),
      validated_graph_(nullptr) {
  // A monotonic clock keeps profiled durations immune to wall-clock jumps.
  clock_ = std::shared_ptr<mediapipe::Clock>(
      mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
}

GraphProfiler::~GraphProfiler() {}
void GraphProfiler::Initialize(
const ValidatedGraphConfig& validated_graph_config) {
absl::WriterMutexLock lock(&profiler_mutex_);
@ -156,6 +192,7 @@ void GraphProfiler::Initialize(
CHECK(iter.second) << absl::Substitute(
"Calculator \"$0\" has already been added.", node_name);
}
profile_builder_ = std::make_unique<GraphProfileBuilder>(this);
is_initialized_ = true;
}
@ -554,15 +591,43 @@ class OstreamStream : public proto_ns::io::ZeroCopyOutputStream {
};
// Sets the canonical node name in each CalculatorGraphConfig::Node
// and also in GraphTrace.
// and also in the GraphTrace if present.
// Sets the canonical node name on every CalculatorGraphConfig::Node and
// mirrors those names into the first GraphTrace when one is present.
void AssignNodeNames(GraphProfile* profile) {
  CalculatorGraphConfig* config = profile->mutable_config();
  GraphTrace* trace = nullptr;
  if (profile->graph_trace_size() > 0) {
    trace = profile->mutable_graph_trace(0);
    trace->clear_calculator_name();
  }
  for (int node_id = 0; node_id < config->node().size(); ++node_id) {
    const std::string node_name = CanonicalNodeName(*config, node_id);
    config->mutable_node(node_id)->set_name(node_name);
    if (trace != nullptr) {
      trace->add_calculator_name(node_name);
    }
  }
}
// Clears histogram fields that hold their default values so serialized
// profiles omit them.
void CleanTimeHistogram(TimeHistogram* histogram) {
  // 1-second (1e6 usec) buckets and a single interval are the defaults.
  if (histogram->interval_size_usec() == 1000000) {
    histogram->clear_interval_size_usec();
  }
  if (histogram->num_intervals() == 1) {
    histogram->clear_num_intervals();
  }
}
// Strips default-valued histogram fields from every calculator profile
// (runtime, input/output latency, and per-stream latencies) to keep the
// emitted GraphProfile compact.
void CleanCalculatorProfiles(GraphProfile* profile) {
  for (CalculatorProfile& calculator : *profile->mutable_calculator_profiles()) {
    CleanTimeHistogram(calculator.mutable_process_runtime());
    CleanTimeHistogram(calculator.mutable_process_input_latency());
    CleanTimeHistogram(calculator.mutable_process_output_latency());
    for (StreamProfile& stream : *calculator.mutable_input_stream_profiles()) {
      CleanTimeHistogram(stream.mutable_latency());
    }
  }
}
@ -588,11 +653,13 @@ absl::Status GraphProfiler::CaptureProfile(GraphProfile* result) {
absl::Time end_time =
clock_->TimeNow() -
absl::Microseconds(profiler_config_.trace_log_margin_usec());
GraphTrace* trace = result->add_graph_trace();
if (!profiler_config_.trace_log_instant_events()) {
tracer()->GetTrace(previous_log_end_time_, end_time, trace);
} else {
tracer()->GetLog(previous_log_end_time_, end_time, trace);
if (tracer()) {
GraphTrace* trace = result->add_graph_trace();
if (!profiler_config_.trace_log_instant_events()) {
tracer()->GetTrace(previous_log_end_time_, end_time, trace);
} else {
tracer()->GetLog(previous_log_end_time_, end_time, trace);
}
}
previous_log_end_time_ = end_time;
@ -601,9 +668,12 @@ absl::Status GraphProfiler::CaptureProfile(GraphProfile* result) {
std::vector<CalculatorProfile> profiles;
status.Update(GetCalculatorProfiles(&profiles));
for (CalculatorProfile& p : profiles) {
*result->mutable_calculator_profiles()->Add() = std::move(p);
if (profile_builder_->ProfileIncluded(p)) {
*result->mutable_calculator_profiles()->Add() = std::move(p);
}
}
this->Reset();
CleanCalculatorProfiles(result);
return status;
}

View File

@ -97,18 +97,8 @@ class GraphProfilerTestPeer;
// The client can overwrite this by calling SetClock().
class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
public:
GraphProfiler()
: is_initialized_(false),
is_profiling_(false),
calculator_profiles_(1000),
packets_info_(1000),
is_running_(false),
previous_log_end_time_(absl::InfinitePast()),
previous_log_index_(-1),
validated_graph_(nullptr) {
clock_ = std::shared_ptr<mediapipe::Clock>(
mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
}
GraphProfiler();
~GraphProfiler();
// Not copyable or movable.
GraphProfiler(const GraphProfiler&) = delete;
@ -230,6 +220,8 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
int64 start_time_usec_;
};
const ProfilerConfig& profiler_config() { return profiler_config_; }
private:
// This can be used to add packet info for the input streams to the graph.
// It treats the stream defined by |stream_name| as a stream produced by a
@ -303,6 +295,7 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
// Helper method to get the clock time in microsecond.
int64 TimeNowUsec() { return ToUnixMicros(clock_->TimeNow()); }
private:
// The settings for this tracer.
ProfilerConfig profiler_config_;
@ -345,6 +338,10 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
// The configuration for the graph being profiled.
const ValidatedGraphConfig* validated_graph_;
// A private resource for creating GraphProfiles.
class GraphProfileBuilder;
std::unique_ptr<GraphProfileBuilder> profile_builder_;
// For testing.
friend GraphProfilerTestPeer;
};

View File

@ -1205,5 +1205,68 @@ TEST(GraphProfilerTest, ParallelReads) {
EXPECT_EQ(1001, out_1_packets.size());
}
// Returns the set of calculator names in a GraphProfile captured from
// CalculatorGraph initialized from a certain CalculatorGraphConfig.
std::set<std::string> GetCalculatorNames(const CalculatorGraphConfig& config) {
  std::set<std::string> result;
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(config));
  GraphProfile profile;
  // The graph is never run; CaptureProfile still emits one CalculatorProfile
  // entry per calculator admitted by the profiler's calculator_filter.
  MP_EXPECT_OK(graph.profiler()->CaptureProfile(&profile));
  for (auto& p : profile.calculator_profiles()) {
    result.insert(p.name());
  }
  return result;
}
// Verifies that profiler_config.calculator_filter restricts which
// calculators appear in a captured GraphProfile, and that the filter is
// interpreted as a full-match regular expression.
TEST(GraphProfilerTest, CalculatorProfileFilter) {
  CalculatorGraphConfig config;
  QCHECK(proto2::TextFormat::ParseFromString(R"(
    profiler_config {
      enable_profiler: true
    }
    node {
      calculator: "RangeCalculator"
      input_side_packet: "range_step"
      output_stream: "out"
      output_stream: "sum"
      output_stream: "mean"
    }
    node {
      calculator: "PassThroughCalculator"
      input_stream: "out"
      input_stream: "sum"
      input_stream: "mean"
      output_stream: "out_1"
      output_stream: "sum_1"
      output_stream: "mean_1"
    }
    output_stream: "OUT:0:the_integers"
    )",
                                             &config));

  // No filter set: every calculator is profiled.
  std::set<std::string> expected_names;
  expected_names = {"RangeCalculator", "PassThroughCalculator"};
  EXPECT_EQ(GetCalculatorNames(config), expected_names);

  // Exact-name filter keeps only the named calculator.
  *config.mutable_profiler_config()->mutable_calculator_filter() =
      "RangeCalculator";
  expected_names = {"RangeCalculator"};
  EXPECT_EQ(GetCalculatorNames(config), expected_names);

  // The filter is a regex matched against the full calculator name.
  *config.mutable_profiler_config()->mutable_calculator_filter() = "Range.*";
  expected_names = {"RangeCalculator"};
  EXPECT_EQ(GetCalculatorNames(config), expected_names);

  *config.mutable_profiler_config()->mutable_calculator_filter() =
      ".*Calculator";
  expected_names = {"RangeCalculator", "PassThroughCalculator"};
  EXPECT_EQ(GetCalculatorNames(config), expected_names);

  // A pattern matching no calculator yields an empty profile set.
  *config.mutable_profiler_config()->mutable_calculator_filter() = ".*Clock.*";
  expected_names = {};
  EXPECT_EQ(GetCalculatorNames(config), expected_names);
}
} // namespace
} // namespace mediapipe

Some files were not shown because too many files have changed in this diff Show More