Project import generated by Copybara.
GitOrigin-RevId: ea8d45731f5a052f79745e35bfd8240d6ac568d2
This commit is contained in:
parent
38be2ec58f
commit
39309bedba
|
@ -102,6 +102,8 @@ run code search using
|
|||
|
||||
## Publications
|
||||
|
||||
* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
|
||||
in Google AI Blog
|
||||
* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html)
|
||||
in Google AI Blog
|
||||
* [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html)
|
||||
|
|
|
@ -364,9 +364,9 @@ http_archive(
|
|||
)
|
||||
|
||||
#Tensorflow repo should always go after the other external dependencies.
|
||||
# 2020-12-07
|
||||
_TENSORFLOW_GIT_COMMIT = "f556709f4df005ad57fd24d5eaa0d9380128d3ba"
|
||||
_TENSORFLOW_SHA256= "9e157d4723921b48a974f645f70d07c8fd3c363569a0ac6ee85fec114d6459ea"
|
||||
# 2020-12-09
|
||||
_TENSORFLOW_GIT_COMMIT = "0eadbb13cef1226b1bae17c941f7870734d97f8a"
|
||||
_TENSORFLOW_SHA256= "4ae06daa5b09c62f31b7bc1f781fd59053f286dd64355830d8c2ac601b795ef0"
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
urls = [
|
||||
|
|
|
@ -94,7 +94,8 @@ for app in ${apps}; do
|
|||
else
|
||||
graph_name="${target_name}/${target_name}"
|
||||
fi
|
||||
if [[ ${target_name} == "iris_tracking" ||
|
||||
if [[ ${target_name} == "holistic_tracking" ||
|
||||
${target_name} == "iris_tracking" ||
|
||||
${target_name} == "pose_tracking" ||
|
||||
${target_name} == "upper_body_pose_tracking" ]]; then
|
||||
graph_suffix="cpu"
|
||||
|
|
|
@ -31,16 +31,16 @@ install --user six`.
|
|||
[Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
|
||||
to install Bazel 3.4 or higher.
|
||||
|
||||
For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu only, Bazel needs
|
||||
For Nvidia Jetson and Raspberry Pi devices with aarch64 Linux, Bazel needs
|
||||
to be built from source:
|
||||
|
||||
```bash
|
||||
# For Bazel 3.4.0
|
||||
mkdir $HOME/bazel-3.4.0
|
||||
cd $HOME/bazel-3.4.0
|
||||
wget https://github.com/bazelbuild/bazel/releases/download/3.4.0/bazel-3.4.0-dist.zip
|
||||
# For Bazel 3.4.1
|
||||
mkdir $HOME/bazel-3.4.1
|
||||
cd $HOME/bazel-3.4.1
|
||||
wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-dist.zip
|
||||
sudo apt-get install build-essential openjdk-8-jdk python zip unzip
|
||||
unzip bazel-3.4.0-dist.zip
|
||||
unzip bazel-3.4.1-dist.zip
|
||||
env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
|
||||
sudo cp output/bazel /usr/local/bin/
|
||||
```
|
||||
|
@ -338,14 +338,7 @@ build issues.
|
|||
|
||||
2. Install Bazel.
|
||||
|
||||
Option 1. Use package manager tool to install Bazel
|
||||
|
||||
```bash
|
||||
$ brew install bazel
|
||||
# Run 'bazel version' to check version of bazel
|
||||
```
|
||||
|
||||
Option 2. Follow the official
|
||||
Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
|
||||
to install Bazel 3.4 or higher.
|
||||
|
||||
|
@ -604,14 +597,14 @@ cameras. Alternatively, you use a video file as input.
|
|||
|
||||
```bash
|
||||
username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \
|
||||
https://storage.googleapis.com/bazel/3.4.0/release/bazel-3.4.0-installer-linux-x86_64.sh && \
|
||||
sudo mkdir -p /usr/local/bazel/3.4.0 && \
|
||||
chmod 755 bazel-3.4.0-installer-linux-x86_64.sh && \
|
||||
sudo ./bazel-3.4.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.0 && \
|
||||
source /usr/local/bazel/3.4.0/lib/bazel/bin/bazel-complete.bash
|
||||
https://storage.googleapis.com/bazel/3.4.1/release/bazel-3.4.1-installer-linux-x86_64.sh && \
|
||||
sudo mkdir -p /usr/local/bazel/3.4.1 && \
|
||||
chmod 755 bazel-3.4.1-installer-linux-x86_64.sh && \
|
||||
sudo ./bazel-3.4.1-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.1 && \
|
||||
source /usr/local/bazel/3.4.1/lib/bazel/bin/bazel-complete.bash
|
||||
|
||||
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.0/lib/bazel/bin/bazel version && \
|
||||
alias bazel='/usr/local/bazel/3.4.0/lib/bazel/bin/bazel'
|
||||
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.1/lib/bazel/bin/bazel version && \
|
||||
alias bazel='/usr/local/bazel/3.4.1/lib/bazel/bin/bazel'
|
||||
```
|
||||
|
||||
6. Checkout MediaPipe repository.
|
||||
|
|
|
@ -70,6 +70,11 @@ Python package from source. Otherwise, we strongly encourage our users to simply
|
|||
run `pip install mediapipe` to use the ready-to-use solutions, more convenient
|
||||
and much faster.
|
||||
|
||||
MediaPipe PyPI currently doesn't provide aarch64 Python wheel
|
||||
files. For building and using MediaPipe Python on aarch64 Linux systems such as
|
||||
Nvidia Jetson and Raspberry Pi, please read
|
||||
[here](https://github.com/jiuqiant/mediapipe-python-aarch64).
|
||||
|
||||
1. Make sure that Bazel and OpenCV are correctly installed and configured for
|
||||
MediaPipe. Please see [Installation](./install.md) for how to setup Bazel
|
||||
and OpenCV for MediaPipe on Linux and macOS.
|
||||
|
@ -82,12 +87,18 @@ and much faster.
|
|||
$ sudo apt install python3-dev
|
||||
$ sudo apt install python3-venv
|
||||
$ sudo apt install -y protobuf-compiler
|
||||
|
||||
# If you need to build opencv from source.
|
||||
$ sudo apt install cmake
|
||||
```
|
||||
|
||||
macOS:
|
||||
|
||||
```bash
|
||||
$ brew install protobuf
|
||||
|
||||
# If you need to build opencv from source.
|
||||
$ brew install cmake
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
@ -118,3 +129,10 @@ and much faster.
|
|||
(mp_env)mediapipe$ python3 setup.py gen_protos
|
||||
(mp_env)mediapipe$ python3 setup.py install --link-opencv
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```bash
|
||||
(mp_env)mediapipe$ python3 setup.py gen_protos
|
||||
(mp_env)mediapipe$ python3 setup.py bdist_wheel
|
||||
```
|
||||
|
|
|
@ -102,6 +102,8 @@ run code search using
|
|||
|
||||
## Publications
|
||||
|
||||
* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
|
||||
in Google AI Blog
|
||||
* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html)
|
||||
in Google AI Blog
|
||||
* [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html)
|
||||
|
|
|
@ -405,7 +405,7 @@ on how to build MediaPipe examples.
|
|||
## Resources
|
||||
|
||||
* Google AI Blog:
|
||||
[MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction on Device](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
|
||||
[MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
|
||||
* [Models and model cards](./models.md#holistic)
|
||||
|
||||
[Colab]:https://mediapipe.page.link/holistic_py_colab
|
||||
|
|
|
@ -2,33 +2,33 @@
|
|||
"additionalFilePaths" : [
|
||||
"/BUILD",
|
||||
"mediapipe/BUILD",
|
||||
"mediapipe/objc/BUILD",
|
||||
"mediapipe/framework/BUILD",
|
||||
"mediapipe/gpu/BUILD",
|
||||
"mediapipe/objc/testing/app/BUILD",
|
||||
"mediapipe/examples/ios/common/BUILD",
|
||||
"mediapipe/examples/ios/helloworld/BUILD",
|
||||
"mediapipe/examples/ios/facedetectioncpu/BUILD",
|
||||
"mediapipe/examples/ios/facedetectiongpu/BUILD",
|
||||
"mediapipe/examples/ios/faceeffect/BUILD",
|
||||
"mediapipe/examples/ios/facemeshgpu/BUILD",
|
||||
"mediapipe/examples/ios/handdetectiongpu/BUILD",
|
||||
"mediapipe/examples/ios/handtrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/helloworld/BUILD",
|
||||
"mediapipe/examples/ios/holistictrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/iristrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/objectdetectioncpu/BUILD",
|
||||
"mediapipe/examples/ios/objectdetectiongpu/BUILD",
|
||||
"mediapipe/examples/ios/posetrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD"
|
||||
"mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD",
|
||||
"mediapipe/framework/BUILD",
|
||||
"mediapipe/gpu/BUILD",
|
||||
"mediapipe/objc/BUILD",
|
||||
"mediapipe/objc/testing/app/BUILD"
|
||||
],
|
||||
"buildTargets" : [
|
||||
"//mediapipe/examples/ios/helloworld:HelloWorldApp",
|
||||
"//mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp",
|
||||
"//mediapipe/examples/ios/facedetectiongpu:FaceDetectionGpuApp",
|
||||
"//mediapipe/examples/ios/faceeffect:FaceEffectApp",
|
||||
"//mediapipe/examples/ios/facemeshgpu:FaceMeshGpuApp",
|
||||
"//mediapipe/examples/ios/handdetectiongpu:HandDetectionGpuApp",
|
||||
"//mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/helloworld:HelloWorldApp",
|
||||
"//mediapipe/examples/ios/holistictrackinggpu:HolisticTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/iristrackinggpu:IrisTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/objectdetectioncpu:ObjectDetectionCpuApp",
|
||||
|
@ -91,13 +91,13 @@
|
|||
"mediapipe/examples/ios",
|
||||
"mediapipe/examples/ios/common",
|
||||
"mediapipe/examples/ios/common/Base.lproj",
|
||||
"mediapipe/examples/ios/helloworld",
|
||||
"mediapipe/examples/ios/facedetectioncpu",
|
||||
"mediapipe/examples/ios/facedetectiongpu",
|
||||
"mediapipe/examples/ios/faceeffect",
|
||||
"mediapipe/examples/ios/faceeffect/Base.lproj",
|
||||
"mediapipe/examples/ios/handdetectiongpu",
|
||||
"mediapipe/examples/ios/handtrackinggpu",
|
||||
"mediapipe/examples/ios/helloworld",
|
||||
"mediapipe/examples/ios/holistictrackinggpu",
|
||||
"mediapipe/examples/ios/iristrackinggpu",
|
||||
"mediapipe/examples/ios/objectdetectioncpu",
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
"packages" : [
|
||||
"",
|
||||
"mediapipe",
|
||||
"mediapipe/objc",
|
||||
"mediapipe/examples/ios",
|
||||
"mediapipe/examples/ios/facedetectioncpu",
|
||||
"mediapipe/examples/ios/facedetectiongpu",
|
||||
|
@ -22,7 +21,8 @@
|
|||
"mediapipe/examples/ios/objectdetectioncpu",
|
||||
"mediapipe/examples/ios/objectdetectiongpu",
|
||||
"mediapipe/examples/ios/posetrackinggpu",
|
||||
"mediapipe/examples/ios/upperbodyposetrackinggpu"
|
||||
"mediapipe/examples/ios/upperbodyposetrackinggpu",
|
||||
"mediapipe/objc"
|
||||
],
|
||||
"projectName" : "Mediapipe",
|
||||
"workspaceRoot" : "../.."
|
||||
|
|
|
@ -146,6 +146,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
|
@ -286,6 +287,7 @@ cc_library(
|
|||
deps = [
|
||||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -393,6 +395,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -406,7 +409,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@eigen_archive//:eigen",
|
||||
|
@ -422,7 +425,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@eigen_archive//:eigen",
|
||||
|
@ -438,6 +441,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/stream_handler:mux_input_stream_handler",
|
||||
],
|
||||
|
@ -606,6 +610,7 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
|
||||
|
@ -958,6 +963,7 @@ cc_library(
|
|||
deps = [
|
||||
":sequence_shift_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -1007,6 +1013,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
|
@ -1042,6 +1049,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
|
|
|
@ -12,11 +12,13 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Attach the header from a stream or side input to another stream.
|
||||
//
|
||||
|
@ -42,49 +44,40 @@ namespace mediapipe {
|
|||
// output_stream: "audio_with_header"
|
||||
// }
|
||||
//
|
||||
class AddHeaderCalculator : public CalculatorBase {
|
||||
class AddHeaderCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
bool has_side_input = false;
|
||||
bool has_header_stream = false;
|
||||
if (cc->InputSidePackets().HasTag("HEADER")) {
|
||||
cc->InputSidePackets().Tag("HEADER").SetAny();
|
||||
has_side_input = true;
|
||||
}
|
||||
if (cc->Inputs().HasTag("HEADER")) {
|
||||
cc->Inputs().Tag("HEADER").SetNone();
|
||||
has_header_stream = true;
|
||||
}
|
||||
if (has_side_input == has_header_stream) {
|
||||
static constexpr Input<NoneType>::Optional kHeader{"HEADER"};
|
||||
static constexpr SideInput<AnyType>::Optional kHeaderSide{"HEADER"};
|
||||
static constexpr Input<AnyType> kData{"DATA"};
|
||||
static constexpr Output<SameType<kData>> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kHeader, kHeaderSide, kData, kOut);
|
||||
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
|
||||
if (kHeader(cc).IsConnected() == kHeaderSide(cc).IsConnected()) {
|
||||
return mediapipe::InvalidArgumentError(
|
||||
"Header must be provided via exactly one of side input and input "
|
||||
"stream");
|
||||
}
|
||||
cc->Inputs().Tag("DATA").SetAny();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Tag("DATA"));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
Packet header;
|
||||
if (cc->InputSidePackets().HasTag("HEADER")) {
|
||||
header = cc->InputSidePackets().Tag("HEADER");
|
||||
}
|
||||
if (cc->Inputs().HasTag("HEADER")) {
|
||||
header = cc->Inputs().Tag("HEADER").Header();
|
||||
}
|
||||
const PacketBase& header =
|
||||
kHeader(cc).IsConnected() ? kHeader(cc).Header() : kHeaderSide(cc);
|
||||
if (!header.IsEmpty()) {
|
||||
cc->Outputs().Index(0).SetHeader(header);
|
||||
kOut(cc).SetHeader(header);
|
||||
}
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Tag("DATA").Value());
|
||||
kOut(cc).Send(kData(cc).packet());
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(AddHeaderCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(AddHeaderCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
|
@ -23,27 +24,24 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Concatenates several NormalizedLandmarkList protos following stream index
|
||||
// order. This class assumes that every input stream contains a
|
||||
// NormalizedLandmarkList proto object.
|
||||
class ConcatenateNormalizedLandmarkListCalculator : public CalculatorBase {
|
||||
class ConcatenateNormalizedLandmarkListCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().NumEntries() != 0);
|
||||
RET_CHECK(cc->Outputs().NumEntries() == 1);
|
||||
static constexpr Input<NormalizedLandmarkList>::Multiple kIn{""};
|
||||
static constexpr Output<NormalizedLandmarkList> kOut{""};
|
||||
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
cc->Inputs().Index(i).Set<NormalizedLandmarkList>();
|
||||
}
|
||||
|
||||
cc->Outputs().Index(0).Set<NormalizedLandmarkList>();
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK_GE(kIn(cc).Count(), 1);
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
only_emit_if_all_present_ =
|
||||
cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>()
|
||||
.only_emit_if_all_present();
|
||||
|
@ -52,32 +50,29 @@ class ConcatenateNormalizedLandmarkListCalculator : public CalculatorBase {
|
|||
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
if (only_emit_if_all_present_) {
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
if (cc->Inputs().Index(i).IsEmpty()) return mediapipe::OkStatus();
|
||||
for (int i = 0; i < kIn(cc).Count(); ++i) {
|
||||
if (kIn(cc)[i].IsEmpty()) return mediapipe::OkStatus();
|
||||
}
|
||||
}
|
||||
|
||||
NormalizedLandmarkList output;
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
if (cc->Inputs().Index(i).IsEmpty()) continue;
|
||||
const NormalizedLandmarkList& input =
|
||||
cc->Inputs().Index(i).Get<NormalizedLandmarkList>();
|
||||
for (int i = 0; i < kIn(cc).Count(); ++i) {
|
||||
if (kIn(cc)[i].IsEmpty()) continue;
|
||||
const NormalizedLandmarkList& input = *kIn(cc)[i];
|
||||
for (int j = 0; j < input.landmark_size(); ++j) {
|
||||
const NormalizedLandmark& input_landmark = input.landmark(j);
|
||||
*output.add_landmark() = input_landmark;
|
||||
*output.add_landmark() = input.landmark(j);
|
||||
}
|
||||
}
|
||||
cc->Outputs().Index(0).AddPacket(
|
||||
MakePacket<NormalizedLandmarkList>(output).At(cc->InputTimestamp()));
|
||||
kOut(cc).Send(std::move(output));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
bool only_emit_if_all_present_;
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListCalculator);
|
||||
|
||||
REGISTER_CALCULATOR(ConcatenateNormalizedLandmarkListCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
|
|
|
@ -15,10 +15,12 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Given two input streams (A, B), output a single stream containing a pair<A,
|
||||
// B>.
|
||||
|
@ -30,32 +32,27 @@ namespace mediapipe {
|
|||
// input_stream: "packet_b"
|
||||
// output_stream: "output_pair_a_b"
|
||||
// }
|
||||
class MakePairCalculator : public CalculatorBase {
|
||||
class MakePairCalculator : public Node {
|
||||
public:
|
||||
MakePairCalculator() {}
|
||||
~MakePairCalculator() override {}
|
||||
static constexpr Input<AnyType>::Multiple kIn{""};
|
||||
// Note that currently api2::Packet is a different type from mediapipe::Packet
|
||||
static constexpr Output<std::pair<mediapipe::Packet, mediapipe::Packet>>
|
||||
kPair{""};
|
||||
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Inputs().Index(1).SetAny();
|
||||
cc->Outputs().Index(0).Set<std::pair<Packet, Packet>>();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kPair);
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK_EQ(kIn(cc).Count(), 2);
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
cc->Outputs().Index(0).Add(
|
||||
new std::pair<Packet, Packet>(cc->Inputs().Index(0).Value(),
|
||||
cc->Inputs().Index(1).Value()),
|
||||
cc->InputTimestamp());
|
||||
kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()});
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(MakePairCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(MakePairCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -13,11 +13,13 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
// Perform a (left) matrix multiply. Meaning (output = A * input)
|
||||
// where A is the matrix which is provided as an input side packet.
|
||||
//
|
||||
|
@ -28,39 +30,22 @@ namespace mediapipe {
|
|||
// output_stream: "multiplied_samples"
|
||||
// input_side_packet: "multiplication_matrix"
|
||||
// }
|
||||
class MatrixMultiplyCalculator : public CalculatorBase {
|
||||
class MatrixMultiplyCalculator : public Node {
|
||||
public:
|
||||
MatrixMultiplyCalculator() {}
|
||||
~MatrixMultiplyCalculator() override {}
|
||||
static constexpr Input<Matrix> kIn{""};
|
||||
static constexpr Output<Matrix> kOut{""};
|
||||
static constexpr SideInput<Matrix> kSide{""};
|
||||
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kSide);
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
REGISTER_CALCULATOR(MatrixMultiplyCalculator);
|
||||
|
||||
// static
|
||||
mediapipe::Status MatrixMultiplyCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>();
|
||||
cc->Outputs().Index(0).Set<Matrix>();
|
||||
cc->InputSidePackets().Index(0).Set<Matrix>();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status MatrixMultiplyCalculator::Open(CalculatorContext* cc) {
|
||||
// The output is at the same timestamp as the input.
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
MEDIAPIPE_REGISTER_NODE(MatrixMultiplyCalculator);
|
||||
|
||||
mediapipe::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) {
|
||||
Matrix* multiplied = new Matrix();
|
||||
*multiplied = cc->InputSidePackets().Index(0).Get<Matrix>() *
|
||||
cc->Inputs().Index(0).Get<Matrix>();
|
||||
cc->Outputs().Index(0).Add(multiplied, cc->InputTimestamp());
|
||||
kOut(cc).Send(*kSide(cc) * *kIn(cc));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -13,11 +13,13 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Subtract input matrix from the side input matrix and vice versa. The matrices
|
||||
// must have the same dimension.
|
||||
|
@ -41,83 +43,40 @@ namespace mediapipe {
|
|||
// input_side_packet: "MINUEND:side_matrix"
|
||||
// output_stream: "output_matrix"
|
||||
// }
|
||||
class MatrixSubtractCalculator : public CalculatorBase {
|
||||
class MatrixSubtractCalculator : public Node {
|
||||
public:
|
||||
MatrixSubtractCalculator() {}
|
||||
~MatrixSubtractCalculator() override {}
|
||||
static constexpr Input<Matrix>::SideFallback kMinuend{"MINUEND"};
|
||||
static constexpr Input<Matrix>::SideFallback kSubtrahend{"SUBTRAHEND"};
|
||||
static constexpr Output<Matrix> kOut{""};
|
||||
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
MEDIAPIPE_NODE_CONTRACT(kMinuend, kSubtrahend, kOut);
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc);
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
bool subtract_from_input_ = false;
|
||||
};
|
||||
REGISTER_CALCULATOR(MatrixSubtractCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(MatrixSubtractCalculator);
|
||||
|
||||
// static
|
||||
mediapipe::Status MatrixSubtractCalculator::GetContract(
|
||||
mediapipe::Status MatrixSubtractCalculator::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
if (cc->Inputs().NumEntries() != 1 ||
|
||||
cc->InputSidePackets().NumEntries() != 1) {
|
||||
return mediapipe::InvalidArgumentError(
|
||||
"MatrixSubtractCalculator only accepts exactly one input stream and "
|
||||
"one "
|
||||
"input side packet");
|
||||
}
|
||||
if (cc->Inputs().HasTag("MINUEND") &&
|
||||
cc->InputSidePackets().HasTag("SUBTRAHEND")) {
|
||||
cc->Inputs().Tag("MINUEND").Set<Matrix>();
|
||||
cc->InputSidePackets().Tag("SUBTRAHEND").Set<Matrix>();
|
||||
} else if (cc->Inputs().HasTag("SUBTRAHEND") &&
|
||||
cc->InputSidePackets().HasTag("MINUEND")) {
|
||||
cc->Inputs().Tag("SUBTRAHEND").Set<Matrix>();
|
||||
cc->InputSidePackets().Tag("MINUEND").Set<Matrix>();
|
||||
} else {
|
||||
return mediapipe::InvalidArgumentError(
|
||||
"Must specify exactly one minuend and one subtrahend.");
|
||||
}
|
||||
cc->Outputs().Index(0).Set<Matrix>();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status MatrixSubtractCalculator::Open(CalculatorContext* cc) {
|
||||
// The output is at the same timestamp as the input.
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
if (cc->Inputs().HasTag("MINUEND")) {
|
||||
subtract_from_input_ = true;
|
||||
}
|
||||
// TODO: the next restriction could be relaxed.
|
||||
RET_CHECK(kMinuend(cc).IsStream() ^ kSubtrahend(cc).IsStream())
|
||||
<< "MatrixSubtractCalculator only accepts exactly one input stream and "
|
||||
"one input side packet";
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) {
|
||||
Matrix* subtracted = new Matrix();
|
||||
if (subtract_from_input_) {
|
||||
const Matrix& input_matrix = cc->Inputs().Tag("MINUEND").Get<Matrix>();
|
||||
const Matrix& side_input_matrix =
|
||||
cc->InputSidePackets().Tag("SUBTRAHEND").Get<Matrix>();
|
||||
if (input_matrix.rows() != side_input_matrix.rows() ||
|
||||
input_matrix.cols() != side_input_matrix.cols()) {
|
||||
return mediapipe::InvalidArgumentError(
|
||||
"Input matrix and the input side matrix must have the same "
|
||||
"dimension.");
|
||||
}
|
||||
*subtracted = input_matrix - side_input_matrix;
|
||||
} else {
|
||||
const Matrix& input_matrix = cc->Inputs().Tag("SUBTRAHEND").Get<Matrix>();
|
||||
const Matrix& side_input_matrix =
|
||||
cc->InputSidePackets().Tag("MINUEND").Get<Matrix>();
|
||||
if (input_matrix.rows() != side_input_matrix.rows() ||
|
||||
input_matrix.cols() != side_input_matrix.cols()) {
|
||||
return mediapipe::InvalidArgumentError(
|
||||
"Input matrix and the input side matrix must have the same "
|
||||
"dimension.");
|
||||
}
|
||||
*subtracted = side_input_matrix - input_matrix;
|
||||
const Matrix& minuend = *kMinuend(cc);
|
||||
const Matrix& subtrahend = *kSubtrahend(cc);
|
||||
if (minuend.rows() != subtrahend.rows() ||
|
||||
minuend.cols() != subtrahend.cols()) {
|
||||
return mediapipe::InvalidArgumentError(
|
||||
"Minuend and subtrahend must have the same dimensions.");
|
||||
}
|
||||
cc->Outputs().Index(0).Add(subtracted, cc->InputTimestamp());
|
||||
kOut(cc).Send(minuend - subtrahend);
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -89,9 +89,8 @@ TEST(MatrixSubtractCalculatorTest, WrongConfig2) {
|
|||
)");
|
||||
CalculatorRunner runner(node_config);
|
||||
auto status = runner.Run();
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr("specify exactly one minuend and one subtrahend."));
|
||||
EXPECT_THAT(status.message(), testing::HasSubstr("must be connected"));
|
||||
EXPECT_THAT(status.message(), testing::HasSubstr("not both"));
|
||||
}
|
||||
|
||||
TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
|
@ -30,6 +31,7 @@
|
|||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// A calculator that converts a Matrix M to a vector containing all the
|
||||
// entries of M in column-major order.
|
||||
|
@ -40,33 +42,20 @@ namespace mediapipe {
|
|||
// input_stream: "input_matrix"
|
||||
// output_stream: "column_major_vector"
|
||||
// }
|
||||
class MatrixToVectorCalculator : public CalculatorBase {
|
||||
class MatrixToVectorCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Input Packet containing a Matrix.
|
||||
);
|
||||
cc->Outputs().Index(0).Set<std::vector<float>>(
|
||||
// Output Packet containing a vector, one for each input Packet.
|
||||
);
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
static constexpr Input<Matrix> kIn{""};
|
||||
static constexpr Output<std::vector<float>> kOut{""};
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
// Outputs a packet containing a vector for each input packet.
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
REGISTER_CALCULATOR(MatrixToVectorCalculator);
|
||||
|
||||
mediapipe::Status MatrixToVectorCalculator::Open(CalculatorContext* cc) {
|
||||
// Inform the framework that we don't alter timestamps.
|
||||
cc->SetOffset(mediapipe::TimestampDiff(0));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
MEDIAPIPE_REGISTER_NODE(MatrixToVectorCalculator);
|
||||
|
||||
mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) {
|
||||
const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
|
||||
const Matrix& input = *kIn(cc);
|
||||
auto output = absl::make_unique<std::vector<float>>();
|
||||
|
||||
// The following lines work to convert the Matrix to a vector because Matrix
|
||||
|
@ -76,8 +65,9 @@ mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) {
|
|||
Eigen::Map<Matrix>(output->data(), input.rows(), input.cols());
|
||||
output_as_matrix = input;
|
||||
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
kOut(cc).Send(std::move(output));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -12,11 +12,13 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// This calculator takes a set of input streams and combines them into a single
|
||||
// output stream. The packets from different streams do not need to contain the
|
||||
|
@ -41,40 +43,31 @@ namespace mediapipe {
|
|||
// output_stream: "merged_shot_infos"
|
||||
// }
|
||||
//
|
||||
class MergeCalculator : public CalculatorBase {
|
||||
class MergeCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK_GT(cc->Inputs().NumEntries(), 0)
|
||||
<< "Needs at least one input stream";
|
||||
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
|
||||
if (cc->Inputs().NumEntries() == 1) {
|
||||
static constexpr Input<AnyType>::Multiple kIn{""};
|
||||
static constexpr Output<AnyType> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream";
|
||||
if (kIn(cc).Count() == 1) {
|
||||
LOG(WARNING)
|
||||
<< "MergeCalculator expects multiple input streams to merge but is "
|
||||
"receiving only one. Make sure the calculator is configured "
|
||||
"correctly or consider removing this calculator to reduce "
|
||||
"unnecessary overhead.";
|
||||
}
|
||||
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
cc->Inputs().Index(i).SetAny();
|
||||
}
|
||||
cc->Outputs().Index(0).SetAny();
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) final {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) final {
|
||||
// Output the packet from the first input stream with a packet ready at this
|
||||
// timestamp.
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
if (!cc->Inputs().Index(i).IsEmpty()) {
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(i).Value());
|
||||
for (int i = 0; i < kIn(cc).Count(); ++i) {
|
||||
if (!kIn(cc)[i].IsEmpty()) {
|
||||
kOut(cc).Send(kIn(cc)[i].packet());
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
}
|
||||
|
@ -86,6 +79,7 @@ class MergeCalculator : public CalculatorBase {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(MergeCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(MergeCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -12,97 +12,45 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
constexpr char kSelectTag[] = "SELECT";
|
||||
constexpr char kInputTag[] = "INPUT";
|
||||
} // namespace
|
||||
namespace api2 {
|
||||
|
||||
// A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ...,
|
||||
// using the integer value (0, 1, ...) in the packet on the kSelectTag input
|
||||
// using the integer value (0, 1, ...) in the packet on the "SELECT" input
|
||||
// stream, and passes the packet on the selected input stream to the "OUTPUT"
|
||||
// output stream.
|
||||
// The kSelectTag input can also be passed in as an input side packet, instead
|
||||
// of as an input stream. Either of input stream or input side packet must be
|
||||
// specified but not both.
|
||||
//
|
||||
// Note that this calculator defaults to use MuxInputStreamHandler, which is
|
||||
// required for this calculator. However, it can be overridden to work with
|
||||
// other InputStreamHandlers. Check out the unit tests on for an example usage
|
||||
// with DefaultInputStreamHandler.
|
||||
class MuxCalculator : public CalculatorBase {
|
||||
// TODO: why would you need to use DefaultISH? Perhaps b/167596925?
|
||||
class MuxCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status CheckAndInitAllowDisallowInputs(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kSelectTag) ^
|
||||
cc->InputSidePackets().HasTag(kSelectTag));
|
||||
if (cc->Inputs().HasTag(kSelectTag)) {
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>();
|
||||
} else {
|
||||
cc->InputSidePackets().Tag(kSelectTag).Set<int>();
|
||||
}
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
static constexpr Input<int>::SideFallback kSelect{"SELECT"};
|
||||
// TODO: this currently sets them all to Any independently, instead
|
||||
// of the first being Any and the others being SameAs.
|
||||
static constexpr Input<AnyType>::Multiple kIn{"INPUT"};
|
||||
static constexpr Output<SameType<kIn>> kOut{"OUTPUT"};
|
||||
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc));
|
||||
CollectionItemId data_input_id = cc->Inputs().BeginId(kInputTag);
|
||||
PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
|
||||
data_input0->SetAny();
|
||||
++data_input_id;
|
||||
for (; data_input_id < cc->Inputs().EndId(kInputTag); ++data_input_id) {
|
||||
cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
|
||||
}
|
||||
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
|
||||
cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0);
|
||||
|
||||
cc->SetInputStreamHandler("MuxInputStreamHandler");
|
||||
MediaPipeOptions options;
|
||||
cc->SetInputStreamHandlerOptions(options);
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) final {
|
||||
use_side_packet_select_ = false;
|
||||
if (cc->InputSidePackets().HasTag(kSelectTag)) {
|
||||
use_side_packet_select_ = true;
|
||||
selected_index_ = cc->InputSidePackets().Tag(kSelectTag).Get<int>();
|
||||
} else {
|
||||
select_input_ = cc->Inputs().GetId(kSelectTag, 0);
|
||||
}
|
||||
data_input_base_ = cc->Inputs().GetId(kInputTag, 0);
|
||||
num_data_inputs_ = cc->Inputs().NumEntries(kInputTag);
|
||||
output_ = cc->Outputs().GetId("OUTPUT", 0);
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
MEDIAPIPE_NODE_CONTRACT(kSelect, kIn, kOut,
|
||||
StreamHandler("MuxInputStreamHandler"));
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) final {
|
||||
int select = use_side_packet_select_
|
||||
? selected_index_
|
||||
: cc->Inputs().Get(select_input_).Get<int>();
|
||||
RET_CHECK(0 <= select && select < num_data_inputs_);
|
||||
if (!cc->Inputs().Get(data_input_base_ + select).IsEmpty()) {
|
||||
cc->Outputs().Get(output_).AddPacket(
|
||||
cc->Inputs().Get(data_input_base_ + select).Value());
|
||||
int select = *kSelect(cc);
|
||||
RET_CHECK(0 <= select && select < kIn(cc).Count());
|
||||
if (!kIn(cc)[select].IsEmpty()) {
|
||||
kOut(cc).Send(kIn(cc)[select].packet());
|
||||
}
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
CollectionItemId select_input_;
|
||||
CollectionItemId data_input_base_;
|
||||
int num_data_inputs_ = 0;
|
||||
CollectionItemId output_;
|
||||
bool use_side_packet_select_;
|
||||
int selected_index_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(MuxCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(MuxCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -14,12 +14,14 @@
|
|||
|
||||
#include <deque>
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// PreviousLoopbackCalculator is useful when a graph needs to process an input
|
||||
// together with some previous output.
|
||||
|
@ -51,15 +53,19 @@ namespace mediapipe {
|
|||
// input_stream: "PREV_TRACK:prev_output"
|
||||
// output_stream: "TRACK:output"
|
||||
// }
|
||||
class PreviousLoopbackCalculator : public CalculatorBase {
|
||||
class PreviousLoopbackCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Get("MAIN", 0).SetAny();
|
||||
cc->Inputs().Get("LOOP", 0).SetAny();
|
||||
cc->Outputs().Get("PREV_LOOP", 0).SetSameAs(&(cc->Inputs().Get("LOOP", 0)));
|
||||
// TODO: an optional PREV_TIMESTAMP output could be added to
|
||||
// carry the original timestamp of the packet on PREV_LOOP.
|
||||
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||
static constexpr Input<AnyType> kMain{"MAIN"};
|
||||
static constexpr Input<AnyType> kLoop{"LOOP"};
|
||||
static constexpr Output<SameType<kLoop>> kPrevLoop{"PREV_LOOP"};
|
||||
// TODO: an optional PREV_TIMESTAMP output could be added to
|
||||
// carry the original timestamp of the packet on PREV_LOOP.
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop,
|
||||
StreamHandler("ImmediateInputStreamHandler"),
|
||||
TimestampChange::Arbitrary());
|
||||
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
|
||||
// Process() function is invoked in response to MAIN/LOOP stream timestamp
|
||||
// bound updates.
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
|
@ -67,12 +73,7 @@ class PreviousLoopbackCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) final {
|
||||
main_id_ = cc->Inputs().GetId("MAIN", 0);
|
||||
loop_id_ = cc->Inputs().GetId("LOOP", 0);
|
||||
prev_loop_id_ = cc->Outputs().GetId("PREV_LOOP", 0);
|
||||
cc->Outputs()
|
||||
.Get(prev_loop_id_)
|
||||
.SetHeader(cc->Inputs().Get(loop_id_).Header());
|
||||
kPrevLoop(cc).SetHeader(kLoop(cc).Header());
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -82,48 +83,47 @@ class PreviousLoopbackCalculator : public CalculatorBase {
|
|||
// packets within the same stream. Calculator tracks and operates on such
|
||||
// packets.
|
||||
|
||||
const Packet& main_packet = cc->Inputs().Get(main_id_).Value();
|
||||
if (prev_main_ts_ < main_packet.Timestamp()) {
|
||||
const PacketBase& main_packet = kMain(cc).packet();
|
||||
if (prev_main_ts_ < main_packet.timestamp()) {
|
||||
Timestamp loop_timestamp;
|
||||
if (!main_packet.IsEmpty()) {
|
||||
loop_timestamp = prev_non_empty_main_ts_;
|
||||
prev_non_empty_main_ts_ = main_packet.Timestamp();
|
||||
prev_non_empty_main_ts_ = main_packet.timestamp();
|
||||
} else {
|
||||
// Calculator advances PREV_LOOP timestamp bound in response to empty
|
||||
// MAIN packet, hence not caring about corresponding loop packet.
|
||||
loop_timestamp = Timestamp::Unset();
|
||||
}
|
||||
main_packet_specs_.push_back({main_packet.Timestamp(), loop_timestamp});
|
||||
prev_main_ts_ = main_packet.Timestamp();
|
||||
main_packet_specs_.push_back({main_packet.timestamp(), loop_timestamp});
|
||||
prev_main_ts_ = main_packet.timestamp();
|
||||
}
|
||||
|
||||
const Packet& loop_packet = cc->Inputs().Get(loop_id_).Value();
|
||||
if (prev_loop_ts_ < loop_packet.Timestamp()) {
|
||||
const PacketBase& loop_packet = kLoop(cc).packet();
|
||||
if (prev_loop_ts_ < loop_packet.timestamp()) {
|
||||
loop_packets_.push_back(loop_packet);
|
||||
prev_loop_ts_ = loop_packet.Timestamp();
|
||||
prev_loop_ts_ = loop_packet.timestamp();
|
||||
}
|
||||
|
||||
auto& prev_loop = cc->Outputs().Get(prev_loop_id_);
|
||||
while (!main_packet_specs_.empty() && !loop_packets_.empty()) {
|
||||
// The earliest MAIN packet.
|
||||
const MainPacketSpec& main_spec = main_packet_specs_.front();
|
||||
// The earliest LOOP packet.
|
||||
const Packet& loop_candidate = loop_packets_.front();
|
||||
const PacketBase& loop_candidate = loop_packets_.front();
|
||||
// Match LOOP and MAIN packets.
|
||||
if (main_spec.loop_timestamp < loop_candidate.Timestamp()) {
|
||||
if (main_spec.loop_timestamp < loop_candidate.timestamp()) {
|
||||
// No LOOP packet can match the MAIN packet under review.
|
||||
prev_loop.SetNextTimestampBound(main_spec.timestamp + 1);
|
||||
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
|
||||
main_packet_specs_.pop_front();
|
||||
} else if (main_spec.loop_timestamp > loop_candidate.Timestamp()) {
|
||||
} else if (main_spec.loop_timestamp > loop_candidate.timestamp()) {
|
||||
// No MAIN packet can match the LOOP packet under review.
|
||||
loop_packets_.pop_front();
|
||||
} else {
|
||||
// Exact match found.
|
||||
if (loop_candidate.IsEmpty()) {
|
||||
// However, LOOP packet is empty.
|
||||
prev_loop.SetNextTimestampBound(main_spec.timestamp + 1);
|
||||
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
|
||||
} else {
|
||||
prev_loop.AddPacket(loop_candidate.At(main_spec.timestamp));
|
||||
kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
|
||||
}
|
||||
loop_packets_.pop_front();
|
||||
main_packet_specs_.pop_front();
|
||||
|
@ -135,7 +135,7 @@ class PreviousLoopbackCalculator : public CalculatorBase {
|
|||
// b) Empty MAIN packet has been received with Timestamp::Max() indicating
|
||||
// MAIN is done.
|
||||
if (main_spec.timestamp == Timestamp::Done().PreviousAllowedInStream()) {
|
||||
prev_loop.Close();
|
||||
kPrevLoop(cc).Close();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -150,10 +150,6 @@ class PreviousLoopbackCalculator : public CalculatorBase {
|
|||
Timestamp loop_timestamp;
|
||||
};
|
||||
|
||||
CollectionItemId main_id_;
|
||||
CollectionItemId loop_id_;
|
||||
CollectionItemId prev_loop_id_;
|
||||
|
||||
// Contains specs for MAIN packets which only can be:
|
||||
// - non-empty packets
|
||||
// - empty packets indicating timestamp bound updates
|
||||
|
@ -169,12 +165,13 @@ class PreviousLoopbackCalculator : public CalculatorBase {
|
|||
// - empty packets indicating timestamp bound updates
|
||||
//
|
||||
// Sorted according to packet timestamps.
|
||||
std::deque<Packet> loop_packets_;
|
||||
std::deque<PacketBase> loop_packets_;
|
||||
// Using "Timestamp::Unset" instead of "Timestamp::Unstarted" in order to
|
||||
// allow addition of the very first empty packet (which doesn't indicate
|
||||
// timestamp bound change necessarily).
|
||||
Timestamp prev_loop_ts_ = Timestamp::Unset();
|
||||
};
|
||||
REGISTER_CALCULATOR(PreviousLoopbackCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(PreviousLoopbackCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -15,9 +15,11 @@
|
|||
#include <deque>
|
||||
|
||||
#include "mediapipe/calculators/core/sequence_shift_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// A Calculator that shifts the timestamps of packets along a stream. Packets on
|
||||
// the input stream are output with a timestamp of the packet given by packet
|
||||
|
@ -28,24 +30,19 @@ namespace mediapipe {
|
|||
// of -1, the first packet on the stream will be dropped, the second will be
|
||||
// output with the timestamp of the first, the third with the timestamp of the
|
||||
// second, and so on.
|
||||
class SequenceShiftCalculator : public CalculatorBase {
|
||||
class SequenceShiftCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
if (cc->InputSidePackets().HasTag(kPacketOffsetTag)) {
|
||||
cc->InputSidePackets().Tag(kPacketOffsetTag).Set<int>();
|
||||
}
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
static constexpr Input<AnyType> kIn{""};
|
||||
static constexpr SideInput<int>::Optional kOffset{"PACKET_OFFSET"};
|
||||
static constexpr Output<SameType<kIn>> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOffset, kOut, TimestampChange::Arbitrary());
|
||||
|
||||
// Reads from options to set cache_size_ and packet_offset_.
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
static constexpr const char* kPacketOffsetTag = "PACKET_OFFSET";
|
||||
|
||||
// A positive offset means we want a packet to be output with the timestamp of
|
||||
// a later packet. Stores packets waiting for their output timestamps and
|
||||
// outputs a single packet when the cache fills.
|
||||
|
@ -58,7 +55,7 @@ class SequenceShiftCalculator : public CalculatorBase {
|
|||
|
||||
// Storage for packets waiting to be output when packet_offset > 0. When cache
|
||||
// is full, oldest packet is output with current timestamp.
|
||||
std::deque<Packet> packet_cache_;
|
||||
std::deque<PacketBase> packet_cache_;
|
||||
|
||||
// Storage for previous timestamps used when packet_offset < 0. When cache is
|
||||
// full, oldest timestamp is used for current packet.
|
||||
|
@ -70,14 +67,11 @@ class SequenceShiftCalculator : public CalculatorBase {
|
|||
// the timestamp of packet[i + packet_offset]; equal to abs(packet_offset).
|
||||
int cache_size_;
|
||||
};
|
||||
REGISTER_CALCULATOR(SequenceShiftCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator);
|
||||
|
||||
mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) {
|
||||
packet_offset_ =
|
||||
cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset();
|
||||
if (cc->InputSidePackets().HasTag(kPacketOffsetTag)) {
|
||||
packet_offset_ = cc->InputSidePackets().Tag(kPacketOffsetTag).Get<int>();
|
||||
}
|
||||
packet_offset_ = kOffset(cc).GetOr(
|
||||
cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset());
|
||||
cache_size_ = abs(packet_offset_);
|
||||
// An offset of zero is a no-op, but someone might still request it.
|
||||
if (packet_offset_ == 0) {
|
||||
|
@ -92,7 +86,7 @@ mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) {
|
|||
} else if (packet_offset_ < 0) {
|
||||
ProcessNegativeOffset(cc);
|
||||
} else {
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||
kOut(cc).Send(kIn(cc).packet());
|
||||
}
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -100,23 +94,22 @@ mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) {
|
|||
void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
|
||||
if (packet_cache_.size() >= cache_size_) {
|
||||
// Ready to output oldest packet with current timestamp.
|
||||
cc->Outputs().Index(0).AddPacket(
|
||||
packet_cache_.front().At(cc->InputTimestamp()));
|
||||
kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp()));
|
||||
packet_cache_.pop_front();
|
||||
}
|
||||
// Store current packet for later output.
|
||||
packet_cache_.push_back(cc->Inputs().Index(0).Value());
|
||||
packet_cache_.push_back(kIn(cc).packet());
|
||||
}
|
||||
|
||||
void SequenceShiftCalculator::ProcessNegativeOffset(CalculatorContext* cc) {
|
||||
if (timestamp_cache_.size() >= cache_size_) {
|
||||
// Ready to output current packet with oldest timestamp.
|
||||
cc->Outputs().Index(0).AddPacket(
|
||||
cc->Inputs().Index(0).Value().At(timestamp_cache_.front()));
|
||||
kOut(cc).Send(kIn(cc).packet().At(timestamp_cache_.front()));
|
||||
timestamp_cache_.pop_front();
|
||||
}
|
||||
// Store current timestamp for use by a future packet.
|
||||
timestamp_cache_.push_back(cc->InputTimestamp());
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -61,7 +61,7 @@ class StringToIntCalculatorTemplate : public CalculatorBase {
|
|||
using StringToIntCalculator = StringToIntCalculatorTemplate<int>;
|
||||
REGISTER_CALCULATOR(StringToIntCalculator);
|
||||
|
||||
using StringToUintCalculator = StringToIntCalculatorTemplate<uint>;
|
||||
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
|
||||
REGISTER_CALCULATOR(StringToUintCalculator);
|
||||
|
||||
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
|
||||
|
|
|
@ -59,6 +59,7 @@ cc_library(
|
|||
deps = [
|
||||
":inference_calculator_cc_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
|
@ -234,6 +235,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:port",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
|
@ -286,6 +288,7 @@ cc_library(
|
|||
deps = [
|
||||
":tensors_to_landmarks_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
|
@ -317,6 +320,7 @@ cc_library(
|
|||
deps = [
|
||||
":tensors_to_floats_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
],
|
||||
|
@ -355,6 +359,7 @@ cc_library(
|
|||
":tensors_to_classification_calculator_cc_proto",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:location",
|
||||
|
@ -421,6 +426,7 @@ cc_library(
|
|||
":image_to_tensor_converter",
|
||||
":image_to_tensor_converter_opencv",
|
||||
":image_to_tensor_utils",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mediapipe/calculators/tensor/image_to_tensor_converter.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
@ -45,16 +46,15 @@
|
|||
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
namespace {
|
||||
constexpr char kInputCpu[] = "IMAGE";
|
||||
constexpr char kInputGpu[] = "IMAGE_GPU";
|
||||
constexpr char kOutputMatrix[] = "MATRIX";
|
||||
constexpr char kOutput[] = "TENSORS";
|
||||
constexpr char kInputNormRect[] = "NORM_RECT";
|
||||
constexpr char kOutputLetterboxPadding[] = "LETTERBOX_PADDING";
|
||||
} // namespace
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
#if MEDIAPIPE_DISABLE_GPU
|
||||
// Just a placeholder to not have to depend on mediapipe::GpuBuffer.
|
||||
using GpuBuffer = AnyType;
|
||||
#else
|
||||
using GpuBuffer = mediapipe::GpuBuffer;
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
// Converts image into Tensor, possibly with cropping, resizing and
|
||||
// normalization, according to specified inputs and options.
|
||||
|
@ -110,9 +110,21 @@ namespace mediapipe {
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
class ImageToTensorCalculator : public CalculatorBase {
|
||||
class ImageToTensorCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
static constexpr Input<mediapipe::ImageFrame>::Optional kInCpu{"IMAGE"};
|
||||
static constexpr Input<GpuBuffer>::Optional kInGpu{"IMAGE_GPU"};
|
||||
static constexpr Input<mediapipe::NormalizedRect>::Optional kInNormRect{
|
||||
"NORM_RECT"};
|
||||
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
|
||||
static constexpr Output<std::array<float, 4>>::Optional kOutLetterboxPadding{
|
||||
"LETTERBOX_PADDING"};
|
||||
static constexpr Output<std::array<float, 16>>::Optional kOutMatrix{"MATRIX"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kInCpu, kInGpu, kInNormRect, kOutTensors,
|
||||
kOutLetterboxPadding, kOutMatrix);
|
||||
|
||||
static ::mediapipe::Status UpdateContract(CalculatorContract* cc) {
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
|
||||
|
@ -126,24 +138,10 @@ class ImageToTensorCalculator : public CalculatorBase {
|
|||
RET_CHECK_GT(options.output_tensor_height(), 0)
|
||||
<< "Valid output tensor height is required.";
|
||||
|
||||
if (cc->Inputs().HasTag(kInputNormRect)) {
|
||||
cc->Inputs().Tag(kInputNormRect).Set<mediapipe::NormalizedRect>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kOutputLetterboxPadding)) {
|
||||
cc->Outputs().Tag(kOutputLetterboxPadding).Set<std::array<float, 4>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kOutputMatrix)) {
|
||||
cc->Outputs().Tag(kOutputMatrix).Set<std::array<float, 16>>();
|
||||
}
|
||||
RET_CHECK(kInCpu(cc).IsConnected() ^ kInGpu(cc).IsConnected())
|
||||
<< "One and only one of CPU or GPU input is expected.";
|
||||
|
||||
const bool has_cpu_input = cc->Inputs().HasTag(kInputCpu);
|
||||
const bool has_gpu_input = cc->Inputs().HasTag(kInputGpu);
|
||||
RET_CHECK_EQ((has_cpu_input ? 1 : 0) + (has_gpu_input ? 1 : 0), 1)
|
||||
<< "Either CPU or GPU input is expected, not both.";
|
||||
|
||||
if (has_cpu_input) {
|
||||
cc->Inputs().Tag(kInputCpu).Set<mediapipe::ImageFrame>();
|
||||
} else if (has_gpu_input) {
|
||||
if (kInGpu(cc).IsConnected()) {
|
||||
#if MEDIAPIPE_DISABLE_GPU
|
||||
return mediapipe::UnimplementedError("GPU processing is disabled");
|
||||
#else
|
||||
|
@ -153,25 +151,20 @@ class ImageToTensorCalculator : public CalculatorBase {
|
|||
#else
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
cc->Inputs().Tag(kInputGpu).Set<mediapipe::GpuBuffer>();
|
||||
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
cc->Outputs().Tag(kOutput).Set<std::vector<Tensor>>();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) {
|
||||
// Makes sure outputs' next timestamp bound update is handled automatically
|
||||
// by the framework.
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
output_width_ = options_.output_tensor_width();
|
||||
output_height_ = options_.output_tensor_height();
|
||||
range_min_ = options_.output_tensor_float_range().min();
|
||||
range_max_ = options_.output_tensor_float_range().max();
|
||||
|
||||
if (cc->Inputs().HasTag(kInputCpu)) {
|
||||
if (kInCpu(cc).IsConnected()) {
|
||||
ASSIGN_OR_RETURN(converter_, CreateOpenCvConverter(cc, GetBorderMode()));
|
||||
} else {
|
||||
#if MEDIAPIPE_DISABLE_GPU
|
||||
|
@ -196,21 +189,20 @@ class ImageToTensorCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) {
|
||||
const InputStreamShard& input = cc->Inputs().Tag(
|
||||
cc->Inputs().HasTag(kInputCpu) ? kInputCpu : kInputGpu);
|
||||
if (input.IsEmpty()) {
|
||||
const PacketBase& image_packet =
|
||||
kInCpu(cc).IsConnected() ? kInCpu(cc).packet() : kInGpu(cc).packet();
|
||||
if (image_packet.IsEmpty()) {
|
||||
// Timestamp bound update happens automatically. (See Open().)
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
absl::optional<mediapipe::NormalizedRect> norm_rect;
|
||||
if (cc->Inputs().HasTag(kInputNormRect)) {
|
||||
if (cc->Inputs().Tag(kInputNormRect).IsEmpty()) {
|
||||
if (kInNormRect(cc).IsConnected()) {
|
||||
if (kInNormRect(cc).IsEmpty()) {
|
||||
// Timestamp bound update happens automatically. (See Open().)
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
norm_rect =
|
||||
cc->Inputs().Tag(kInputNormRect).Get<mediapipe::NormalizedRect>();
|
||||
norm_rect = *kInNormRect(cc);
|
||||
if (norm_rect->width() == 0 && norm_rect->height() == 0) {
|
||||
// WORKAROUND: some existing graphs may use sentinel rects {width=0,
|
||||
// height=0, ...} quite often and calculator has to handle them
|
||||
|
@ -223,27 +215,20 @@ class ImageToTensorCalculator : public CalculatorBase {
|
|||
}
|
||||
}
|
||||
|
||||
const Packet& image_packet = input.Value();
|
||||
const Size& size = converter_->GetImageSize(image_packet);
|
||||
RotatedRect roi = GetRoi(size.width, size.height, norm_rect);
|
||||
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
|
||||
options_.output_tensor_height(),
|
||||
options_.keep_aspect_ratio(), &roi));
|
||||
if (cc->Outputs().HasTag(kOutputLetterboxPadding)) {
|
||||
cc->Outputs()
|
||||
.Tag(kOutputLetterboxPadding)
|
||||
.AddPacket(MakePacket<std::array<float, 4>>(padding).At(
|
||||
cc->InputTimestamp()));
|
||||
if (kOutLetterboxPadding(cc).IsConnected()) {
|
||||
kOutLetterboxPadding(cc).Send(padding);
|
||||
}
|
||||
if (cc->Outputs().HasTag(kOutputMatrix)) {
|
||||
if (kOutMatrix(cc).IsConnected()) {
|
||||
std::array<float, 16> matrix;
|
||||
GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
|
||||
/*flip_horizontaly=*/false,
|
||||
&matrix);
|
||||
cc->Outputs()
|
||||
.Tag(kOutputMatrix)
|
||||
.AddPacket(MakePacket<std::array<float, 16>>(std::move(matrix))
|
||||
.At(cc->InputTimestamp()));
|
||||
kOutMatrix(cc).Send(std::move(matrix));
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
|
@ -251,11 +236,9 @@ class ImageToTensorCalculator : public CalculatorBase {
|
|||
converter_->Convert(image_packet, roi, {output_width_, output_height_},
|
||||
range_min_, range_max_));
|
||||
|
||||
std::vector<Tensor> result;
|
||||
result.push_back(std::move(tensor));
|
||||
cc->Outputs().Tag(kOutput).AddPacket(
|
||||
MakePacket<std::vector<Tensor>>(std::move(result))
|
||||
.At(cc->InputTimestamp()));
|
||||
auto result = std::make_unique<std::vector<Tensor>>();
|
||||
result->push_back(std::move(tensor));
|
||||
kOutTensors(cc).Send(std::move(result));
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -286,6 +269,7 @@ class ImageToTensorCalculator : public CalculatorBase {
|
|||
float range_max_ = 1.0f;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(ImageToTensorCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -88,7 +89,6 @@ bool ShouldUseGpu(const mediapipe::InferenceCalculatorOptions& options) {
|
|||
}
|
||||
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
} // namespace
|
||||
|
||||
#if defined(MEDIAPIPE_EDGE_TPU)
|
||||
#include "edgetpu.h"
|
||||
|
@ -112,7 +112,10 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
|
|||
}
|
||||
#endif // MEDIAPIPE_EDGE_TPU
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
namespace {
|
||||
|
@ -224,12 +227,19 @@ int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) {
|
|||
// Tensors are assumed to be ordered correctly (sequentially added to model).
|
||||
// Input tensors are assumed to be of the correct size and already normalized.
|
||||
|
||||
class InferenceCalculator : public CalculatorBase {
|
||||
class InferenceCalculator : public Node {
|
||||
public:
|
||||
using TfLiteDelegatePtr =
|
||||
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
|
||||
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
|
||||
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
|
||||
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
||||
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
|
||||
kOutTensors);
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc);
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
@ -239,11 +249,12 @@ class InferenceCalculator : public CalculatorBase {
|
|||
mediapipe::Status ReadKernelsFromFile();
|
||||
mediapipe::Status WriteKernelsToFile();
|
||||
mediapipe::Status LoadModel(CalculatorContext* cc);
|
||||
mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
|
||||
mediapipe::StatusOr<mediapipe::Packet> GetModelAsPacket(
|
||||
const CalculatorContext& cc);
|
||||
mediapipe::Status LoadDelegate(CalculatorContext* cc);
|
||||
mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
|
||||
|
||||
Packet model_packet_;
|
||||
mediapipe::Packet model_packet_;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
TfLiteDelegatePtr delegate_;
|
||||
|
||||
|
@ -277,28 +288,13 @@ class InferenceCalculator : public CalculatorBase {
|
|||
std::string cached_kernel_filename_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(InferenceCalculator);
|
||||
|
||||
mediapipe::Status InferenceCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
|
||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<Tensor>>();
|
||||
RET_CHECK(cc->Outputs().HasTag(kTensorsTag));
|
||||
cc->Outputs().Tag(kTensorsTag).Set<std::vector<Tensor>>();
|
||||
MEDIAPIPE_REGISTER_NODE(InferenceCalculator);
|
||||
|
||||
mediapipe::Status InferenceCalculator::UpdateContract(CalculatorContract* cc) {
|
||||
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
||||
RET_CHECK(!options.model_path().empty() ^
|
||||
cc->InputSidePackets().HasTag("MODEL"))
|
||||
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
|
||||
<< "Either model as side packet or model path in options is required.";
|
||||
|
||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||
cc->InputSidePackets()
|
||||
.Tag("CUSTOM_OP_RESOLVER")
|
||||
.Set<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("MODEL")) {
|
||||
cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
|
||||
}
|
||||
|
||||
if (ShouldUseGpu(options)) {
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
|
@ -310,8 +306,6 @@ mediapipe::Status InferenceCalculator::GetContract(CalculatorContract* cc) {
|
|||
}
|
||||
|
||||
mediapipe::Status InferenceCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE || MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
||||
if (ShouldUseGpu(options)) {
|
||||
|
@ -361,11 +355,10 @@ mediapipe::Status InferenceCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
mediapipe::Status InferenceCalculator::Process(CalculatorContext* cc) {
|
||||
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
|
||||
if (kInTensors(cc).IsEmpty()) {
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK(!input_tensors.empty());
|
||||
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
|
||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
@ -509,9 +502,7 @@ mediapipe::Status InferenceCalculator::Process(CalculatorContext* cc) {
|
|||
output_tensors->back().bytes());
|
||||
}
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsTag)
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
kOutTensors(cc).Send(std::move(output_tensors));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -575,12 +566,9 @@ mediapipe::Status InferenceCalculator::InitTFLiteGPURunner(
|
|||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
|
||||
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver;
|
||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||
op_resolver = cc->InputSidePackets()
|
||||
.Tag("CUSTOM_OP_RESOLVER")
|
||||
.Get<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
}
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||
kSideInCustomOpResolver(cc).GetOr(
|
||||
tflite::ops::builtin::BuiltinOpResolver());
|
||||
|
||||
// Create runner
|
||||
tflite::gpu::InferenceOptions options;
|
||||
|
@ -629,12 +617,9 @@ mediapipe::Status InferenceCalculator::InitTFLiteGPURunner(
|
|||
mediapipe::Status InferenceCalculator::LoadModel(CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
|
||||
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver;
|
||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||
op_resolver = cc->InputSidePackets()
|
||||
.Tag("CUSTOM_OP_RESOLVER")
|
||||
.Get<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
}
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||
kSideInCustomOpResolver(cc).GetOr(
|
||||
tflite::ops::builtin::BuiltinOpResolver());
|
||||
|
||||
#if defined(MEDIAPIPE_EDGE_TPU)
|
||||
interpreter_ =
|
||||
|
@ -659,7 +644,7 @@ mediapipe::Status InferenceCalculator::LoadModel(CalculatorContext* cc) {
|
|||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::StatusOr<Packet> InferenceCalculator::GetModelAsPacket(
|
||||
mediapipe::StatusOr<mediapipe::Packet> InferenceCalculator::GetModelAsPacket(
|
||||
const CalculatorContext& cc) {
|
||||
const auto& options = cc.Options<mediapipe::InferenceCalculatorOptions>();
|
||||
if (!options.model_path().empty()) {
|
||||
|
@ -845,4 +830,5 @@ mediapipe::Status InferenceCalculator::LoadDelegate(CalculatorContext* cc) {
|
|||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "absl/strings/str_format.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
@ -32,6 +33,7 @@
|
|||
#endif
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Convert result tensors from classification models into MediaPipe
|
||||
// classifications.
|
||||
|
@ -57,9 +59,12 @@ namespace mediapipe {
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
class TensorsToClassificationCalculator : public CalculatorBase {
|
||||
class TensorsToClassificationCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
|
||||
static constexpr Output<ClassificationList> kOutClassificationList{
|
||||
"CLASSIFICATIONS"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kInTensors, kOutClassificationList);
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
@ -71,28 +76,10 @@ class TensorsToClassificationCalculator : public CalculatorBase {
|
|||
std::unordered_map<int, std::string> label_map_;
|
||||
bool label_map_loaded_ = false;
|
||||
};
|
||||
REGISTER_CALCULATOR(TensorsToClassificationCalculator);
|
||||
|
||||
mediapipe::Status TensorsToClassificationCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(!cc->Inputs().GetTags().empty());
|
||||
RET_CHECK(!cc->Outputs().GetTags().empty());
|
||||
|
||||
if (cc->Inputs().HasTag("TENSORS")) {
|
||||
cc->Inputs().Tag("TENSORS").Set<std::vector<Tensor>>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
|
||||
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
|
||||
}
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator);
|
||||
|
||||
mediapipe::Status TensorsToClassificationCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
options_ =
|
||||
cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>();
|
||||
|
||||
|
@ -118,9 +105,7 @@ mediapipe::Status TensorsToClassificationCalculator::Open(
|
|||
|
||||
mediapipe::Status TensorsToClassificationCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS").Get<std::vector<Tensor>>();
|
||||
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
|
||||
int num_classes = input_tensors[0].shape().num_elements();
|
||||
|
@ -182,10 +167,7 @@ mediapipe::Status TensorsToClassificationCalculator::Process(
|
|||
raw_classification_list->DeleteSubrange(
|
||||
top_k_, raw_classification_list->size() - top_k_);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("CLASSIFICATIONS")
|
||||
.Add(classification_list.release(), cc->InputTimestamp());
|
||||
|
||||
kOutClassificationList(cc).Send(std::move(classification_list));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -194,4 +176,5 @@ mediapipe::Status TensorsToClassificationCalculator::Close(
|
|||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "absl/strings/str_format.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
|
@ -47,9 +48,6 @@
|
|||
namespace {
|
||||
constexpr int kNumInputTensorsWithAnchors = 3;
|
||||
constexpr int kNumCoordsPerBox = 4;
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kAnchorsTag[] = "ANCHORS";
|
||||
|
||||
bool CanUseGpu() {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || MEDIAPIPE_METAL_ENABLED
|
||||
|
@ -63,6 +61,7 @@ bool CanUseGpu() {
|
|||
} // namespace
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -128,9 +127,14 @@ void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
class TensorsToDetectionsCalculator : public CalculatorBase {
|
||||
class TensorsToDetectionsCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
|
||||
static constexpr SideInput<std::vector<Anchor>>::Optional kInAnchors{
|
||||
"ANCHORS"};
|
||||
static constexpr Output<std::vector<Detection>> kOutDetections{"DETECTIONS"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kInTensors, kInAnchors, kOutDetections);
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc);
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
@ -161,7 +165,6 @@ class TensorsToDetectionsCalculator : public CalculatorBase {
|
|||
|
||||
::mediapipe::TensorsToDetectionsCalculatorOptions options_;
|
||||
std::vector<Anchor> anchors_;
|
||||
bool side_packet_anchors_{};
|
||||
|
||||
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
|
@ -179,22 +182,10 @@ class TensorsToDetectionsCalculator : public CalculatorBase {
|
|||
bool gpu_input_ = false;
|
||||
bool anchors_init_ = false;
|
||||
};
|
||||
REGISTER_CALCULATOR(TensorsToDetectionsCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);
|
||||
|
||||
mediapipe::Status TensorsToDetectionsCalculator::GetContract(
|
||||
mediapipe::Status TensorsToDetectionsCalculator::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
|
||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<Tensor>>();
|
||||
|
||||
RET_CHECK(cc->Outputs().HasTag(kDetectionsTag));
|
||||
cc->Outputs().Tag(kDetectionsTag).Set<std::vector<Detection>>();
|
||||
|
||||
if (cc->InputSidePackets().UsesTags()) {
|
||||
if (cc->InputSidePackets().HasTag(kAnchorsTag)) {
|
||||
cc->InputSidePackets().Tag(kAnchorsTag).Set<std::vector<Anchor>>();
|
||||
}
|
||||
}
|
||||
|
||||
if (CanUseGpu()) {
|
||||
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
|
@ -207,8 +198,6 @@ mediapipe::Status TensorsToDetectionsCalculator::GetContract(
|
|||
}
|
||||
|
||||
mediapipe::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
side_packet_anchors_ = cc->InputSidePackets().HasTag(kAnchorsTag);
|
||||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
||||
if (CanUseGpu()) {
|
||||
|
@ -226,18 +215,12 @@ mediapipe::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
|
|||
|
||||
mediapipe::Status TensorsToDetectionsCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
auto output_detections = absl::make_unique<std::vector<Detection>>();
|
||||
|
||||
bool gpu_processing = false;
|
||||
if (CanUseGpu()) {
|
||||
// Use GPU processing only if at least one input tensor is already on GPU
|
||||
// (to avoid CPU->GPU overhead).
|
||||
for (const auto& tensor :
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>()) {
|
||||
for (const auto& tensor : *kInTensors(cc)) {
|
||||
if (tensor.ready_on_gpu()) {
|
||||
gpu_processing = true;
|
||||
break;
|
||||
|
@ -251,18 +234,13 @@ mediapipe::Status TensorsToDetectionsCalculator::Process(
|
|||
MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
|
||||
}
|
||||
|
||||
// Output
|
||||
cc->Outputs()
|
||||
.Tag(kDetectionsTag)
|
||||
.Add(output_detections.release(), cc->InputTimestamp());
|
||||
|
||||
kOutDetections(cc).Send(std::move(output_detections));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU(
|
||||
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
|
||||
if (input_tensors.size() == 2 ||
|
||||
input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
|
@ -294,10 +272,8 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU(
|
|||
auto anchor_view = anchor_tensor->GetCpuReadView();
|
||||
auto raw_anchors = anchor_view.buffer<float>();
|
||||
ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_);
|
||||
} else if (side_packet_anchors_) {
|
||||
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
|
||||
anchors_ =
|
||||
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
|
||||
} else if (!kInAnchors(cc).IsEmpty()) {
|
||||
anchors_ = *kInAnchors(cc);
|
||||
} else {
|
||||
return mediapipe::UnavailableError("No anchor data available.");
|
||||
}
|
||||
|
@ -391,8 +367,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU(
|
|||
|
||||
mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU(
|
||||
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK_GE(input_tensors.size(), 2);
|
||||
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
|
||||
|
||||
|
@ -400,21 +375,20 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
&output_detections]()
|
||||
-> mediapipe::Status {
|
||||
if (!anchors_init_) {
|
||||
if (side_packet_anchors_) {
|
||||
CHECK(!cc->InputSidePackets().Tag(kAnchorsTag).IsEmpty());
|
||||
const auto& anchors =
|
||||
cc->InputSidePackets().Tag(kAnchorsTag).Get<std::vector<Anchor>>();
|
||||
auto anchors_view = raw_anchors_buffer_->GetCpuWriteView();
|
||||
auto raw_anchors = anchors_view.buffer<float>();
|
||||
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors);
|
||||
} else {
|
||||
CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
|
||||
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
auto read_view = input_tensors[2].GetOpenGlBufferReadView();
|
||||
glBindBuffer(GL_COPY_READ_BUFFER, read_view.name());
|
||||
auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView();
|
||||
glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name());
|
||||
glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
|
||||
input_tensors[2].bytes());
|
||||
} else if (!kInAnchors(cc).IsEmpty()) {
|
||||
const auto& anchors = *kInAnchors(cc);
|
||||
auto anchors_view = raw_anchors_buffer_->GetCpuWriteView();
|
||||
auto raw_anchors = anchors_view.buffer<float>();
|
||||
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors);
|
||||
} else {
|
||||
return mediapipe::UnavailableError("No anchor data available.");
|
||||
}
|
||||
anchors_init_ = true;
|
||||
}
|
||||
|
@ -464,14 +438,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
#elif MEDIAPIPE_METAL_ENABLED
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
if (!anchors_init_) {
|
||||
if (side_packet_anchors_) {
|
||||
CHECK(!cc->InputSidePackets().Tag(kAnchorsTag).IsEmpty());
|
||||
const auto& anchors =
|
||||
cc->InputSidePackets().Tag(kAnchorsTag).Get<std::vector<Anchor>>();
|
||||
auto raw_anchors_view = raw_anchors_buffer_->GetCpuWriteView();
|
||||
ConvertAnchorsToRawValues(anchors, num_boxes_,
|
||||
raw_anchors_view.buffer<float>());
|
||||
} else {
|
||||
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
|
||||
auto command_buffer = [gpu_helper_ commandBuffer];
|
||||
auto src_buffer = input_tensors[2].GetMtlBufferReadView(command_buffer);
|
||||
|
@ -486,6 +453,13 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
size:input_tensors[2].bytes()];
|
||||
[blit_command endEncoding];
|
||||
[command_buffer commit];
|
||||
} else if (!kInAnchors(cc).IsEmpty()) {
|
||||
const auto& anchors = *kInAnchors(cc);
|
||||
auto raw_anchors_view = raw_anchors_buffer_->GetCpuWriteView();
|
||||
ConvertAnchorsToRawValues(anchors, num_boxes_,
|
||||
raw_anchors_view.buffer<float>());
|
||||
} else {
|
||||
return mediapipe::UnavailableError("No anchor data available.");
|
||||
}
|
||||
anchors_init_ = true;
|
||||
}
|
||||
|
@ -1157,4 +1131,5 @@ kernel void scoreKernel(
|
|||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/tensor/tensors_to_floats_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -43,47 +44,39 @@ inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); }
|
|||
// input_stream: "TENSORS:tensors"
|
||||
// output_stream: "FLOATS:floats"
|
||||
// }
|
||||
class TensorsToFloatsCalculator : public CalculatorBase {
|
||||
namespace api2 {
|
||||
class TensorsToFloatsCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
|
||||
static constexpr Output<float>::Optional kOutFloat{"FLOAT"};
|
||||
static constexpr Output<std::vector<float>>::Optional kOutFloats{"FLOATS"};
|
||||
MEDIAPIPE_NODE_INTERFACE(TensorsToFloatsCalculator, kInTensors, kOutFloat,
|
||||
kOutFloats);
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc);
|
||||
mediapipe::Status Open(CalculatorContext* cc) final;
|
||||
mediapipe::Status Process(CalculatorContext* cc) final;
|
||||
|
||||
private:
|
||||
::mediapipe::TensorsToFloatsCalculatorOptions options_;
|
||||
};
|
||||
REGISTER_CALCULATOR(TensorsToFloatsCalculator);
|
||||
MEDIAPIPE_REGISTER_NODE(TensorsToFloatsCalculator);
|
||||
|
||||
mediapipe::Status TensorsToFloatsCalculator::GetContract(
|
||||
mediapipe::Status TensorsToFloatsCalculator::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("TENSORS"));
|
||||
RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT"));
|
||||
|
||||
cc->Inputs().Tag("TENSORS").Set<std::vector<Tensor>>();
|
||||
if (cc->Outputs().HasTag("FLOATS")) {
|
||||
cc->Outputs().Tag("FLOATS").Set<std::vector<float>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("FLOAT")) {
|
||||
cc->Outputs().Tag("FLOAT").Set<float>();
|
||||
}
|
||||
|
||||
// Only exactly a single output allowed.
|
||||
RET_CHECK(kOutFloat(cc).IsConnected() ^ kOutFloats(cc).IsConnected());
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status TensorsToFloatsCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
options_ = cc->Options<::mediapipe::TensorsToFloatsCalculatorOptions>();
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) {
|
||||
RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty());
|
||||
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS").Get<std::vector<Tensor>>();
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK(!input_tensors.empty());
|
||||
// TODO: Add option to specify which tensor to take from.
|
||||
auto view = input_tensors[0].GetCpuReadView();
|
||||
auto raw_floats = view.buffer<float>();
|
||||
|
@ -100,18 +93,15 @@ mediapipe::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) {
|
|||
break;
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("FLOAT")) {
|
||||
// TODO: Could add an index in the option to specifiy returning one
|
||||
// value of a float array.
|
||||
if (kOutFloat(cc).IsConnected()) {
|
||||
// TODO: Could add an index in the option to specifiy returning
|
||||
// one value of a float array.
|
||||
RET_CHECK_EQ(num_values, 1);
|
||||
cc->Outputs().Tag("FLOAT").AddPacket(
|
||||
MakePacket<float>(output_floats->at(0)).At(cc->InputTimestamp()));
|
||||
kOutFloat(cc).Send(output_floats->at(0));
|
||||
} else {
|
||||
kOutFloats(cc).Send(std::move(output_floats));
|
||||
}
|
||||
if (cc->Outputs().HasTag("FLOATS")) {
|
||||
cc->Outputs().Tag("FLOATS").Add(output_floats.release(),
|
||||
cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -13,12 +13,14 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -85,9 +87,18 @@ float ApplyActivation(
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
class TensorsToLandmarksCalculator : public CalculatorBase {
|
||||
class TensorsToLandmarksCalculator : public Node {
|
||||
public:
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
|
||||
static constexpr Input<bool>::SideFallback::Optional kFlipHorizontally{
|
||||
"FLIP_HORIZONTALLY"};
|
||||
static constexpr Input<bool>::SideFallback::Optional kFlipVertically{
|
||||
"FLIP_VERTICALLY"};
|
||||
static constexpr Output<LandmarkList>::Optional kOutLandmarkList{"LANDMARKS"};
|
||||
static constexpr Output<NormalizedLandmarkList>::Optional
|
||||
kOutNormalizedLandmarkList{"NORM_LANDMARKS"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kInTensors, kFlipHorizontally, kFlipVertically,
|
||||
kOutLandmarkList, kOutNormalizedLandmarkList);
|
||||
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
@ -95,100 +106,39 @@ class TensorsToLandmarksCalculator : public CalculatorBase {
|
|||
private:
|
||||
mediapipe::Status LoadOptions(CalculatorContext* cc);
|
||||
int num_landmarks_ = 0;
|
||||
bool flip_vertically_ = false;
|
||||
bool flip_horizontally_ = false;
|
||||
|
||||
::mediapipe::TensorsToLandmarksCalculatorOptions options_;
|
||||
};
|
||||
REGISTER_CALCULATOR(TensorsToLandmarksCalculator);
|
||||
|
||||
mediapipe::Status TensorsToLandmarksCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(!cc->Inputs().GetTags().empty());
|
||||
RET_CHECK(!cc->Outputs().GetTags().empty());
|
||||
|
||||
if (cc->Inputs().HasTag("TENSORS")) {
|
||||
cc->Inputs().Tag("TENSORS").Set<std::vector<Tensor>>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("FLIP_HORIZONTALLY")) {
|
||||
cc->Inputs().Tag("FLIP_HORIZONTALLY").Set<bool>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("FLIP_VERTICALLY")) {
|
||||
cc->Inputs().Tag("FLIP_VERTICALLY").Set<bool>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY")) {
|
||||
cc->InputSidePackets().Tag("FLIP_HORIZONTALLY").Set<bool>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("FLIP_VERTICALLY")) {
|
||||
cc->InputSidePackets().Tag("FLIP_VERTICALLY").Set<bool>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("LANDMARKS")) {
|
||||
cc->Outputs().Tag("LANDMARKS").Set<LandmarkList>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
|
||||
cc->Outputs().Tag("NORM_LANDMARKS").Set<NormalizedLandmarkList>();
|
||||
}
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
MEDIAPIPE_REGISTER_NODE(TensorsToLandmarksCalculator);
|
||||
|
||||
mediapipe::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
||||
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
|
||||
if (kOutNormalizedLandmarkList(cc).IsConnected()) {
|
||||
RET_CHECK(options_.has_input_image_height() &&
|
||||
options_.has_input_image_width())
|
||||
<< "Must provide input with/height for getting normalized landmarks.";
|
||||
}
|
||||
if (cc->Outputs().HasTag("LANDMARKS") &&
|
||||
(options_.flip_vertically() || options_.flip_horizontally() ||
|
||||
cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY") ||
|
||||
cc->InputSidePackets().HasTag("FLIP_VERTICALLY"))) {
|
||||
if (kOutLandmarkList(cc).IsConnected() &&
|
||||
(options_.flip_horizontally() || options_.flip_vertically() ||
|
||||
kFlipHorizontally(cc).IsConnected() ||
|
||||
kFlipVertically(cc).IsConnected())) {
|
||||
RET_CHECK(options_.has_input_image_height() &&
|
||||
options_.has_input_image_width())
|
||||
<< "Must provide input with/height for using flip_vertically option "
|
||||
"when outputing landmarks in absolute coordinates.";
|
||||
<< "Must provide input with/height for using flipping when outputing "
|
||||
"landmarks in absolute coordinates.";
|
||||
}
|
||||
|
||||
flip_horizontally_ =
|
||||
cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY")
|
||||
? cc->InputSidePackets().Tag("FLIP_HORIZONTALLY").Get<bool>()
|
||||
: options_.flip_horizontally();
|
||||
|
||||
flip_vertically_ =
|
||||
cc->InputSidePackets().HasTag("FLIP_VERTICALLY")
|
||||
? cc->InputSidePackets().Tag("FLIP_VERTICALLY").Get<bool>()
|
||||
: options_.flip_vertically();
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
|
||||
// Override values if specified so.
|
||||
if (cc->Inputs().HasTag("FLIP_HORIZONTALLY") &&
|
||||
!cc->Inputs().Tag("FLIP_HORIZONTALLY").IsEmpty()) {
|
||||
flip_horizontally_ = cc->Inputs().Tag("FLIP_HORIZONTALLY").Get<bool>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("FLIP_VERTICALLY") &&
|
||||
!cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) {
|
||||
flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get<bool>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().Tag("TENSORS").IsEmpty()) {
|
||||
if (kInTensors(cc).IsEmpty()) {
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
bool flip_horizontally =
|
||||
kFlipHorizontally(cc).GetOr(options_.flip_horizontally());
|
||||
bool flip_vertically = kFlipVertically(cc).GetOr(options_.flip_vertically());
|
||||
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS").Get<std::vector<Tensor>>();
|
||||
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
int num_values = input_tensors[0].shape().num_elements();
|
||||
const int num_dimensions = num_values / num_landmarks_;
|
||||
CHECK_GT(num_dimensions, 0);
|
||||
|
@ -202,13 +152,13 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
|
|||
const int offset = ld * num_dimensions;
|
||||
Landmark* landmark = output_landmarks.add_landmark();
|
||||
|
||||
if (flip_horizontally_) {
|
||||
if (flip_horizontally) {
|
||||
landmark->set_x(options_.input_image_width() - raw_landmarks[offset]);
|
||||
} else {
|
||||
landmark->set_x(raw_landmarks[offset]);
|
||||
}
|
||||
if (num_dimensions > 1) {
|
||||
if (flip_vertically_) {
|
||||
if (flip_vertically) {
|
||||
landmark->set_y(options_.input_image_height() -
|
||||
raw_landmarks[offset + 1]);
|
||||
} else {
|
||||
|
@ -229,7 +179,7 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
// Output normalized landmarks if required.
|
||||
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
|
||||
if (kOutNormalizedLandmarkList(cc).IsConnected()) {
|
||||
NormalizedLandmarkList output_norm_landmarks;
|
||||
for (int i = 0; i < output_landmarks.landmark_size(); ++i) {
|
||||
const Landmark& landmark = output_landmarks.landmark(i);
|
||||
|
@ -246,18 +196,12 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
|
|||
norm_landmark->set_presence(landmark.presence());
|
||||
}
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("NORM_LANDMARKS")
|
||||
.AddPacket(MakePacket<NormalizedLandmarkList>(output_norm_landmarks)
|
||||
.At(cc->InputTimestamp()));
|
||||
kOutNormalizedLandmarkList(cc).Send(std::move(output_norm_landmarks));
|
||||
}
|
||||
|
||||
// Output absolute landmarks.
|
||||
if (cc->Outputs().HasTag("LANDMARKS")) {
|
||||
cc->Outputs()
|
||||
.Tag("LANDMARKS")
|
||||
.AddPacket(MakePacket<LandmarkList>(output_landmarks)
|
||||
.At(cc->InputTimestamp()));
|
||||
if (kOutLandmarkList(cc).IsConnected()) {
|
||||
kOutLandmarkList(cc).Send(std::move(output_landmarks));
|
||||
}
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
|
@ -272,4 +216,5 @@ mediapipe::Status TensorsToLandmarksCalculator::LoadOptions(
|
|||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -396,6 +396,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util/sequence:media_sequence",
|
||||
"//mediapipe/util/sequence:media_sequence_util",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
|
@ -841,6 +842,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:opencv_imgcodecs",
|
||||
"//mediapipe/util/sequence:media_sequence",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
||||
|
@ -57,7 +58,7 @@ namespace mpms = mediapipe::mediasequence;
|
|||
// bounding boxes from vector<Detections>, and streams with the
|
||||
// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector<float>'s
|
||||
// associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints
|
||||
// from unordered_map<std::string, vector<pair<float, float>>>. "IMAGE_${NAME}",
|
||||
// from flat_hash_map<std::string, vector<pair<float, float>>>. "IMAGE_${NAME}",
|
||||
// "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of
|
||||
// each stream, which allows for multiple image streams to be included. However,
|
||||
// the default names are suppored by more tools.
|
||||
|
@ -131,8 +132,8 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
}
|
||||
cc->Inputs()
|
||||
.Tag(tag)
|
||||
.Set<std::unordered_map<std::string,
|
||||
std::vector<std::pair<float, float>>>>();
|
||||
.Set<absl::flat_hash_map<std::string,
|
||||
std::vector<std::pair<float, float>>>>();
|
||||
}
|
||||
if (absl::StartsWith(tag, kBBoxTag)) {
|
||||
std::string key = "";
|
||||
|
@ -348,7 +349,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
const auto& keypoints =
|
||||
cc->Inputs()
|
||||
.Tag(tag)
|
||||
.Get<std::unordered_map<
|
||||
.Get<absl::flat_hash_map<
|
||||
std::string, std::vector<std::pair<float, float>>>>();
|
||||
for (const auto& pair : keypoints) {
|
||||
std::string prefix = mpms::merge_prefix(key, pair.first);
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
||||
|
@ -537,8 +538,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
|
|||
std::string test_video_id = "test_video_id";
|
||||
mpms::SetClipMediaId(test_video_id, input_sequence.get());
|
||||
|
||||
std::unordered_map<std::string, std::vector<std::pair<float, float>>> points =
|
||||
{{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
|
||||
absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>>
|
||||
points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
|
||||
runner_->MutableInputs()
|
||||
->Tag("KEYPOINTS_TEST")
|
||||
.packets.push_back(PointToForeign(&points).At(Timestamp(0)));
|
||||
|
|
|
@ -450,7 +450,9 @@ mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
#else
|
||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
if (gpu_inference_) {
|
||||
// Metal delegate supports external command encoder only if all input and
|
||||
// output buffers are on GPU.
|
||||
if (gpu_inference_ && gpu_input_ && gpu_output_) {
|
||||
RET_CHECK(
|
||||
TFLGpuDelegateSetCommandEncoder(delegate_.get(), compute_encoder));
|
||||
}
|
||||
|
|
|
@ -934,11 +934,22 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "local_file_contents_calculator_proto",
|
||||
srcs = ["local_file_contents_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "local_file_contents_calculator",
|
||||
srcs = ["local_file_contents_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":local_file_contents_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/calculators/util/local_file_contents_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
|
@ -78,6 +79,8 @@ class LocalFileContentsCalculator : public CalculatorBase {
|
|||
mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
CollectionItemId input_id = cc->InputSidePackets().BeginId(kFilePathTag);
|
||||
CollectionItemId output_id = cc->OutputSidePackets().BeginId(kContentsTag);
|
||||
auto options = cc->Options<mediapipe::LocalFileContentsCalculatorOptions>();
|
||||
|
||||
// Number of inputs and outpus is the same according to the contract.
|
||||
for (; input_id != cc->InputSidePackets().EndId(kFilePathTag);
|
||||
++input_id, ++output_id) {
|
||||
|
@ -86,7 +89,8 @@ class LocalFileContentsCalculator : public CalculatorBase {
|
|||
ASSIGN_OR_RETURN(file_path, PathToResourceAsFile(file_path));
|
||||
|
||||
std::string contents;
|
||||
MP_RETURN_IF_ERROR(GetResourceContents(file_path, &contents));
|
||||
MP_RETURN_IF_ERROR(
|
||||
GetResourceContents(file_path, &contents, options.read_as_binary()));
|
||||
cc->OutputSidePackets().Get(output_id).Set(
|
||||
MakePacket<std::string>(std::move(contents)));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message LocalFileContentsCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional LocalFileContentsCalculatorOptions ext = 346849340;
|
||||
}
|
||||
|
||||
// If true, set the file open mode to 'rb'. Otherwise, set the mode to 'r'.
|
||||
optional bool read_as_binary = 1 [default = true];
|
||||
}
|
|
@ -56,8 +56,10 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
|
|||
} catch (NameNotFoundException e) {
|
||||
Log.e(TAG, "Cannot find application info: " + e);
|
||||
}
|
||||
|
||||
// Get allowed object category.
|
||||
String categoryName = applicationInfo.metaData.getString("categoryName");
|
||||
// Get maximum allowed number of objects.
|
||||
int maxNumObjects = applicationInfo.metaData.getInt("maxNumObjects");
|
||||
float[] modelScale = parseFloatArrayFromString(
|
||||
applicationInfo.metaData.getString("modelScale"));
|
||||
float[] modelTransform = parseFloatArrayFromString(
|
||||
|
@ -70,6 +72,7 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
|
|||
inputSidePackets.put("obj_texture", packetCreator.createRgbaImageFrame(objTexture));
|
||||
inputSidePackets.put("box_texture", packetCreator.createRgbaImageFrame(boxTexture));
|
||||
inputSidePackets.put("allowed_labels", packetCreator.createString(categoryName));
|
||||
inputSidePackets.put("max_num_objects", packetCreator.createInt32(maxNumObjects));
|
||||
inputSidePackets.put("model_scale", packetCreator.createFloat32Array(modelScale));
|
||||
inputSidePackets.put("model_transformation", packetCreator.createFloat32Array(modelTransform));
|
||||
processor.setInputSidePackets(inputSidePackets);
|
||||
|
@ -118,8 +121,8 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
|
|||
} catch (RuntimeException e) {
|
||||
Log.e(
|
||||
TAG,
|
||||
"MediaPipeException encountered adding packets to width and height"
|
||||
+ " input streams.");
|
||||
"MediaPipeException encountered adding packets to input_width and input_height"
|
||||
+ " input streams.", e);
|
||||
}
|
||||
widthPacket.release();
|
||||
heightPacket.release();
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
<application>
|
||||
<meta-data android:name="categoryName" android:value="Camera"/>
|
||||
<meta-data android:name="maxNumObjects" android:value="5"/>
|
||||
<meta-data android:name="modelScale" android:value="250, 250, 250"/>
|
||||
<meta-data android:name="modelTransformation" android:value="1.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 1.0, 0.0,
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
<application>
|
||||
<meta-data android:name="categoryName" android:value="Chair"/>
|
||||
<meta-data android:name="maxNumObjects" android:value="5"/>
|
||||
<meta-data android:name="modelScale" android:value="0.1, 0.05, 0.1"/>
|
||||
<meta-data android:name="modelTransformation" android:value="1.0, 0.0, 0.0, 0.0,
|
||||
0.0, 1.0, 0.0, -10.0,
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
<application>
|
||||
<meta-data android:name="categoryName" android:value="Coffee cup,Mug"/>
|
||||
<meta-data android:name="maxNumObjects" android:value="5"/>
|
||||
<meta-data android:name="modelScale" android:value="500, 500, 500"/>
|
||||
<meta-data android:name="modelTransformation" android:value="1.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 1.0, -0.001,
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
<application>
|
||||
<meta-data android:name="categoryName" android:value="Footwear"/>
|
||||
<meta-data android:name="maxNumObjects" android:value="5"/>
|
||||
<meta-data android:name="modelScale" android:value="0.25, 0.25, 0.12"/>
|
||||
<meta-data android:name="modelTransformation" android:value="1.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 1.0, 0.0,
|
||||
|
|
|
@ -52,8 +52,9 @@ mediapipe::Status KinematicPathSolver::AddObservation(int position,
|
|||
}
|
||||
}
|
||||
|
||||
double delta_degs = (Median(raw_positions_at_time_) - current_position_px_) /
|
||||
pixels_per_degree_;
|
||||
int filtered_position = Median(raw_positions_at_time_);
|
||||
double delta_degs =
|
||||
(filtered_position - current_position_px_) / pixels_per_degree_;
|
||||
|
||||
// If the motion is smaller than the min_motion_to_reframe and camera is
|
||||
// stationary, don't use the update.
|
||||
|
@ -68,14 +69,14 @@ mediapipe::Status KinematicPathSolver::AddObservation(int position,
|
|||
} else if (delta_degs > 0) {
|
||||
// Apply new position, less the reframe window size.
|
||||
target_position_px_ =
|
||||
position - pixels_per_degree_ * options_.reframe_window();
|
||||
filtered_position - pixels_per_degree_ * options_.reframe_window();
|
||||
delta_degs =
|
||||
(target_position_px_ - current_position_px_) / pixels_per_degree_;
|
||||
motion_state_ = true;
|
||||
} else {
|
||||
// Apply new position, plus the reframe window size.
|
||||
target_position_px_ =
|
||||
position + pixels_per_degree_ * options_.reframe_window();
|
||||
filtered_position + pixels_per_degree_ * options_.reframe_window();
|
||||
delta_degs =
|
||||
(target_position_px_ - current_position_px_) / pixels_per_degree_;
|
||||
motion_state_ = true;
|
||||
|
|
|
@ -351,7 +351,6 @@ cc_library(
|
|||
":calculator_base",
|
||||
":calculator_context",
|
||||
":calculator_context_manager",
|
||||
":calculator_registry_util",
|
||||
":calculator_state",
|
||||
":counter_factory",
|
||||
":input_side_packet_handler",
|
||||
|
@ -405,27 +404,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "calculator_registry_util",
|
||||
srcs = ["calculator_registry_util.cc"],
|
||||
hdrs = ["calculator_registry_util.h"],
|
||||
visibility = [
|
||||
":mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":calculator_base",
|
||||
":calculator_context",
|
||||
":calculator_state",
|
||||
":collection",
|
||||
":collection_item_id",
|
||||
":packet_set",
|
||||
":timestamp",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"//mediapipe/framework/tool:tag_map",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "calculator_runner",
|
||||
testonly = 1,
|
||||
|
@ -1128,7 +1106,6 @@ cc_library(
|
|||
deps = [
|
||||
":calculator_base",
|
||||
":calculator_contract",
|
||||
":calculator_registry_util",
|
||||
":legacy_calculator_support",
|
||||
":packet",
|
||||
":packet_generator",
|
||||
|
|
231
mediapipe/framework/api2/BUILD
Normal file
231
mediapipe/framework/api2/BUILD
Normal file
|
@ -0,0 +1,231 @@
|
|||
package(
|
||||
default_visibility = [":preview_users"],
|
||||
features = ["-use_header_modules"],
|
||||
)
|
||||
|
||||
# API2 is in preview mode. Internal clients are welcome and encouraged to try
|
||||
# it out, but be aware that there may be more changes before release. Please
|
||||
# add your package to this list and reach out to the MediaPipe team (use
|
||||
# camillol@ as the CL reviewer).
|
||||
package_group(
|
||||
name = "preview_users",
|
||||
packages = [
|
||||
"//mediapipe/...",
|
||||
"//video/content_analysis/...",
|
||||
],
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "const_str",
|
||||
hdrs = ["const_str.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "builder",
|
||||
hdrs = ["builder.h"],
|
||||
deps = [
|
||||
":const_str",
|
||||
":contract",
|
||||
":node",
|
||||
":packet",
|
||||
":port",
|
||||
"//mediapipe/framework:calculator_base",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "builder_test",
|
||||
srcs = ["builder_test.cc"],
|
||||
deps = [
|
||||
":builder",
|
||||
":node",
|
||||
":packet",
|
||||
":port",
|
||||
":tag",
|
||||
":test_contracts",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/deps:message_matchers",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "contract",
|
||||
hdrs = ["contract.h"],
|
||||
deps = [
|
||||
":const_str",
|
||||
":packet",
|
||||
":port",
|
||||
":tag",
|
||||
":tuple",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
"//mediapipe/framework:output_side_packet",
|
||||
"//mediapipe/framework/port:logging",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "contract_test",
|
||||
srcs = ["contract_test.cc"],
|
||||
deps = [
|
||||
":contract",
|
||||
":port",
|
||||
":tag",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "node",
|
||||
srcs = ["node.cc"],
|
||||
hdrs = ["node.h"],
|
||||
deps = [
|
||||
":const_str",
|
||||
":contract",
|
||||
":packet",
|
||||
":port",
|
||||
"//mediapipe/framework:calculator_base",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
"//mediapipe/framework:subgraph",
|
||||
"//mediapipe/framework/deps:no_destructor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_contracts",
|
||||
testonly = 1,
|
||||
hdrs = ["test_contracts.h"],
|
||||
deps = [
|
||||
":node",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "node_test",
|
||||
srcs = ["node_test.cc"],
|
||||
deps = [
|
||||
":node",
|
||||
":packet",
|
||||
":port",
|
||||
":test_contracts",
|
||||
":tuple",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "packet",
|
||||
srcs = ["packet.cc"],
|
||||
hdrs = ["packet.h"],
|
||||
deps = [
|
||||
":tuple",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/port:logging",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "packet_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"packet_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":packet",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "port",
|
||||
hdrs = ["port.h"],
|
||||
deps = [
|
||||
":const_str",
|
||||
":packet",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
"//mediapipe/framework:output_side_packet",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "port_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"port_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":port",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "subgraph_test",
|
||||
srcs = ["subgraph_test.cc"],
|
||||
deps = [
|
||||
":builder",
|
||||
":node",
|
||||
":packet",
|
||||
":port",
|
||||
":test_contracts",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/deps:message_matchers",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:subgraph_expansion",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tag",
|
||||
hdrs = ["tag.h"],
|
||||
deps = [":const_str"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tag_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"tag_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tag",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tuple",
|
||||
hdrs = ["tuple.h"],
|
||||
deps = ["@com_google_absl//absl/meta:type_traits"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tuple_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"tuple_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tuple",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
111
mediapipe/framework/api2/README.md
Normal file
111
mediapipe/framework/api2/README.md
Normal file
|
@ -0,0 +1,111 @@
|
|||
# Experimental new APIs
|
||||
|
||||
This directory defines new APIs for MediaPipe:
|
||||
|
||||
- Node API, an update to the Calculator API for defining MediaPipe components.
|
||||
- Builder API, for assembling CalculatorGraphConfigs with C++, as an alternative
|
||||
to using the proto API directly.
|
||||
|
||||
The code is working, and the new APIs interoperate fully with the existing
|
||||
framework code. They are considered a work in progress, but are being released
|
||||
now so we can begin adopting them in our calculators.
|
||||
|
||||
Developers are welcome to try out these APIs as early adopters, but should
|
||||
expect breaking changes. The placement of this code under the `mediapipe::api2`
|
||||
namespace is not final.
|
||||
|
||||
## Node API
|
||||
|
||||
This API can be used to define calculators. It is designed to be more type-safe
|
||||
and less verbose than the original API.
|
||||
|
||||
Input/output ports (streams and side packets) can now be declared as typed
|
||||
constants, instead of using plain strings for access.
|
||||
|
||||
For example, instead of
|
||||
|
||||
```
|
||||
constexpr char kSelectTag[] = "SELECT";
|
||||
if (cc->Inputs().HasTag(kSelectTag)) {
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>();
|
||||
}
|
||||
```
|
||||
|
||||
you can write
|
||||
|
||||
```
|
||||
static constexpr Input<int>::Optional kSelect{"SELECT"};
|
||||
```
|
||||
|
||||
Instead of setting up the contract procedurally in `GetContract`, add ports to
|
||||
the contract declaratively, as follows:
|
||||
|
||||
```
|
||||
MEDIAPIPE_NODE_CONTRACT(kInput, kOutput);
|
||||
```
|
||||
|
||||
To access an input in Process, instead of
|
||||
|
||||
```
|
||||
int select = cc->Inputs().Tag(kSelectTag).Get<int>();
|
||||
```
|
||||
|
||||
write
|
||||
|
||||
```
|
||||
int select = kSelectTag(cc).Get(); // alternative: *kSelectTag(cc)
|
||||
```
|
||||
|
||||
Sets of multiple ports can be declared with `::Multiple`. Note, also, that a tag
|
||||
string must always be provided when declaring a port; use `""` for untagged
|
||||
ports. For example:
|
||||
|
||||
|
||||
```
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
cc->Inputs().Index(i).SetAny();
|
||||
}
|
||||
```
|
||||
|
||||
becomes
|
||||
|
||||
```
|
||||
static constexpr Input<AnyType>::Multiple kIn{""};
|
||||
```
|
||||
|
||||
For output ports, the payload can be passed directly to the `Send` method. For
|
||||
example, instead of
|
||||
|
||||
```
|
||||
cc->Outputs().Index(0).Add(
|
||||
new std::pair<Packet, Packet>(cc->Inputs().Index(0).Value(),
|
||||
cc->Inputs().Index(1).Value()),
|
||||
cc->InputTimestamp());
|
||||
```
|
||||
|
||||
you can write
|
||||
|
||||
```
|
||||
kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()});
|
||||
```
|
||||
|
||||
The input timestamp is propagated to the outputs by default. If your calculator
|
||||
wants to alter timestamps, it must add a `TimestampChange` entry to its contract
|
||||
declaration. For example:
|
||||
|
||||
```
|
||||
MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop,
|
||||
StreamHandler("ImmediateInputStreamHandler"),
|
||||
TimestampChange::Arbitrary());
|
||||
```
|
||||
|
||||
Several calculators in
|
||||
[`calculators/core`](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core) and
|
||||
[`calculators/tensor`](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/tensor)
|
||||
have been updated to use this API. Reference them for more examples.
|
||||
|
||||
More complete documentation will be provided in the future.
|
||||
|
||||
## Builder API
|
||||
|
||||
Documentation will be provided in the future.
|
576
mediapipe/framework/api2/builder.h
Normal file
576
mediapipe/framework/api2/builder.h
Normal file
|
@ -0,0 +1,576 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "mediapipe/framework/api2/const_str.h"
|
||||
#include "mediapipe/framework/api2/contract.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_base.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace builder {
|
||||
|
||||
template <typename T>
|
||||
T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, int index) {
|
||||
auto& vec = *vecp;
|
||||
if (vec.size() <= index) {
|
||||
vec.resize(index + 1);
|
||||
}
|
||||
if (vec[index] == nullptr) {
|
||||
vec[index] = absl::make_unique<T>();
|
||||
}
|
||||
return *vec[index];
|
||||
}
|
||||
|
||||
struct TagIndexLocation {
|
||||
const std::string& tag;
|
||||
std::size_t index;
|
||||
std::size_t count;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class TagIndexMap {
|
||||
public:
|
||||
std::vector<std::unique_ptr<T>>& operator[](const std::string& tag) {
|
||||
return map_[tag];
|
||||
}
|
||||
|
||||
void Visit(std::function<void(const TagIndexLocation&, const T&)> fun) const {
|
||||
for (const auto& tagged : map_) {
|
||||
TagIndexLocation loc{tagged.first, 0, tagged.second.size()};
|
||||
for (const auto& item : tagged.second) {
|
||||
fun(loc, *item);
|
||||
++loc.index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(std::function<void(const TagIndexLocation&, T*)> fun) {
|
||||
for (auto& tagged : map_) {
|
||||
TagIndexLocation loc{tagged.first, 0, tagged.second.size()};
|
||||
for (auto& item : tagged.second) {
|
||||
fun(loc, item.get());
|
||||
++loc.index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note: entries are held by a unique_ptr to ensure pointers remain valid.
|
||||
// Should use absl::flat_hash_map but ordering keys for now.
|
||||
std::map<std::string, std::vector<std::unique_ptr<T>>> map_;
|
||||
};
|
||||
|
||||
// These structs are used internally to store information about the endpoints
|
||||
// of a connection.
|
||||
struct SourceBase;
|
||||
struct DestinationBase {
|
||||
SourceBase* source = nullptr;
|
||||
};
|
||||
struct SourceBase {
|
||||
std::vector<DestinationBase*> dests_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
// Following existing GraphConfig usage, we allow using a multiport as a single
|
||||
// port as well. This is necessary for generic nodes, since we have no
|
||||
// information about which ports are meant to be multiports or not, but it is
|
||||
// also convenient with typed nodes.
|
||||
template <typename Single>
|
||||
class MultiPort : public Single {
|
||||
public:
|
||||
using Base = typename Single::Base;
|
||||
|
||||
explicit MultiPort(std::vector<std::unique_ptr<Base>>* vec)
|
||||
: Single(vec), vec_(*vec) {}
|
||||
|
||||
Single operator[](int index) {
|
||||
CHECK_GE(index, 0);
|
||||
return Single{&GetWithAutoGrow(&vec_, index)};
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Base>>& vec_;
|
||||
};
|
||||
|
||||
// These classes wrap references to the underlying source/destination
|
||||
// endpoints, adding type information and the user-visible API.
|
||||
template <bool AllowMultiple, bool IsSide, typename T = internal::Generic>
|
||||
class DestinationImpl {
|
||||
public:
|
||||
using Base = DestinationBase;
|
||||
|
||||
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
|
||||
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
|
||||
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
|
||||
DestinationBase& base_;
|
||||
};
|
||||
|
||||
template <bool IsSide, typename T>
|
||||
class DestinationImpl<true, IsSide, T>
|
||||
: public MultiPort<DestinationImpl<false, IsSide, T>> {
|
||||
public:
|
||||
using MultiPort<DestinationImpl<false, IsSide, T>>::MultiPort;
|
||||
};
|
||||
|
||||
template <bool AllowMultiple, bool IsSide, typename T = internal::Generic>
|
||||
class SourceImpl {
|
||||
public:
|
||||
using Base = SourceBase;
|
||||
|
||||
// Src is used as the return type of fluent methods below. Since these are
|
||||
// single-port methods, it is desirable to always decay to a reference to the
|
||||
// single-port superclass, even if they are called on a multiport.
|
||||
using Src = SourceImpl<false, IsSide, T>;
|
||||
template <typename U>
|
||||
using Dst = DestinationImpl<false, IsSide, U>;
|
||||
|
||||
// clang-format off
|
||||
template <typename U>
|
||||
struct AllowConnection : public std::integral_constant<bool,
|
||||
std::is_same<T, U>{} || std::is_same<T, internal::Generic>{} ||
|
||||
std::is_same<U, internal::Generic>{}> {};
|
||||
// clang-format on
|
||||
|
||||
explicit SourceImpl(std::vector<std::unique_ptr<Base>>* vec)
|
||||
: SourceImpl(&GetWithAutoGrow(vec, 0)) {}
|
||||
explicit SourceImpl(SourceBase* base) : base_(*base) {}
|
||||
|
||||
template <typename U,
|
||||
typename std::enable_if<AllowConnection<U>{}, int>::type = 0>
|
||||
Src& AddTarget(const Dst<U>& dest) {
|
||||
CHECK(dest.base_.source == nullptr);
|
||||
dest.base_.source = &base_;
|
||||
base_.dests_.emplace_back(&dest.base_);
|
||||
return *this;
|
||||
}
|
||||
Src& SetName(std::string name) {
|
||||
base_.name_ = std::move(name);
|
||||
return *this;
|
||||
}
|
||||
template <typename U>
|
||||
Src& operator>>(const Dst<U>& dest) {
|
||||
return AddTarget(dest);
|
||||
}
|
||||
|
||||
private:
|
||||
SourceBase& base_;
|
||||
};
|
||||
|
||||
template <bool IsSide, typename T>
|
||||
class SourceImpl<true, IsSide, T>
|
||||
: public MultiPort<SourceImpl<false, IsSide, T>> {
|
||||
public:
|
||||
using MultiPort<SourceImpl<false, IsSide, T>>::MultiPort;
|
||||
};
|
||||
|
||||
// A source and a destination correspond to an output/input stream on a node,
|
||||
// and a side source and side destination correspond to an output/input side
|
||||
// packet.
|
||||
// For graph inputs/outputs, however, the inputs are sources, and the outputs
|
||||
// are destinations. This is because graph ports are connected "from inside"
|
||||
// when building the graph.
|
||||
template <bool AllowMultiple = false, typename T = internal::Generic>
|
||||
using Source = SourceImpl<AllowMultiple, false, T>;
|
||||
template <bool AllowMultiple = false, typename T = internal::Generic>
|
||||
using SideSource = SourceImpl<AllowMultiple, true, T>;
|
||||
template <bool AllowMultiple = false, typename T = internal::Generic>
|
||||
using Destination = DestinationImpl<AllowMultiple, false, T>;
|
||||
template <bool AllowMultiple = false, typename T = internal::Generic>
|
||||
using SideDestination = DestinationImpl<AllowMultiple, true, T>;
|
||||
|
||||
class NodeBase {
|
||||
public:
|
||||
// TODO: right now access to an indexed port is made directly by
|
||||
// specifying both a tag and an index. It would be better to represent this
|
||||
// as a two-step lookup, first getting a multi-port, and then accessing one
|
||||
// of its entries by index. However, for nodes without visible contracts we
|
||||
// can't know whether a tag is indexable or not, so we would need the
|
||||
// multi-port to also be usable as a port directly (representing index 0).
|
||||
Source<true> Out(const std::string& tag) {
|
||||
return Source<true>(&out_streams_[tag]);
|
||||
}
|
||||
|
||||
Destination<true> In(const std::string& tag) {
|
||||
return Destination<true>(&in_streams_[tag]);
|
||||
}
|
||||
|
||||
SideSource<true> SideOut(const std::string& tag) {
|
||||
return SideSource<true>(&out_sides_[tag]);
|
||||
}
|
||||
|
||||
SideDestination<true> SideIn(const std::string& tag) {
|
||||
return SideDestination<true>(&in_sides_[tag]);
|
||||
}
|
||||
|
||||
// Convenience methods for accessing purely index-based ports.
|
||||
Source<false> Out(int index) { return Out("")[index]; }
|
||||
|
||||
Destination<false> In(int index) { return In("")[index]; }
|
||||
|
||||
SideSource<false> SideOut(int index) { return SideOut("")[index]; }
|
||||
|
||||
SideDestination<false> SideIn(int index) { return SideIn("")[index]; }
|
||||
|
||||
template <typename T>
|
||||
T& GetOptions() {
|
||||
options_used_ = true;
|
||||
return *options_.MutableExtension(T::ext);
|
||||
}
|
||||
|
||||
protected:
|
||||
NodeBase(std::string type) : type_(std::move(type)) {}
|
||||
|
||||
std::string type_;
|
||||
TagIndexMap<DestinationBase> in_streams_;
|
||||
TagIndexMap<SourceBase> out_streams_;
|
||||
TagIndexMap<DestinationBase> in_sides_;
|
||||
TagIndexMap<SourceBase> out_sides_;
|
||||
CalculatorOptions options_;
|
||||
// ideally we'd just check if any extensions are set on options_
|
||||
bool options_used_ = false;
|
||||
friend class Graph;
|
||||
};
|
||||
|
||||
template <class Calc = internal::Generic>
|
||||
class Node;
|
||||
#if __cplusplus >= 201703L
|
||||
// Deduction guide to silence -Wctad-maybe-unsupported.
|
||||
explicit Node()->Node<internal::Generic>;
|
||||
#endif // C++17
|
||||
|
||||
template <>
|
||||
class Node<internal::Generic> : public NodeBase {
|
||||
public:
|
||||
Node(std::string type) : NodeBase(std::move(type)) {}
|
||||
};
|
||||
|
||||
using GenericNode = Node<internal::Generic>;
|
||||
|
||||
template <template <bool, class> class BP, class Port, class TagIndexMapT>
|
||||
auto MakeBuilderPort(const Port& port, TagIndexMapT& streams) {
|
||||
return BP<Port::kMultiple, typename Port::PayloadT>(&streams[port.Tag()]);
|
||||
}
|
||||
|
||||
template <class Calc>
|
||||
class Node : public NodeBase {
|
||||
public:
|
||||
Node() : NodeBase(Calc::kCalculatorName) {}
|
||||
// Overrides the built-in calculator type std::string with the provided
|
||||
// argument. Can be used to create nodes from pure interfaces.
|
||||
// TODO: only use this for pure interfaces
|
||||
Node(const std::string& type_override) : NodeBase(type_override) {}
|
||||
|
||||
// These methods only allow access to ports declared in the contract.
|
||||
// The argument must be a tag object created with the MPP_TAG macro.
|
||||
// These objects encode the tag in their type, which allows us to return
|
||||
// a result with the appropriate payload type depending on the tag.
|
||||
template <class Tag>
|
||||
auto Out(Tag tag) {
|
||||
constexpr auto& port = Calc::Contract::TaggedOutputs::get(tag);
|
||||
return MakeBuilderPort<Source>(port, out_streams_);
|
||||
}
|
||||
|
||||
template <class Tag>
|
||||
auto In(Tag tag) {
|
||||
constexpr auto& port = Calc::Contract::TaggedInputs::get(tag);
|
||||
return MakeBuilderPort<Destination>(port, in_streams_);
|
||||
}
|
||||
|
||||
template <class Tag>
|
||||
auto SideOut(Tag tag) {
|
||||
constexpr auto& port = Calc::Contract::TaggedSideOutputs::get(tag);
|
||||
return MakeBuilderPort<SideSource>(port, out_sides_);
|
||||
}
|
||||
|
||||
template <class Tag>
|
||||
auto SideIn(Tag tag) {
|
||||
constexpr auto& port = Calc::Contract::TaggedSideInputs::get(tag);
|
||||
return MakeBuilderPort<SideDestination>(port, in_sides_);
|
||||
}
|
||||
|
||||
// We could allow using the non-checked versions with typed nodes too, but
|
||||
// we don't.
|
||||
// using NodeBase::Out;
|
||||
// using NodeBase::In;
|
||||
// using NodeBase::SideOut;
|
||||
// using NodeBase::SideIn;
|
||||
};
|
||||
|
||||
// For legacy PacketGenerators.
|
||||
class PacketGenerator {
|
||||
public:
|
||||
PacketGenerator(std::string type) : type_(std::move(type)) {}
|
||||
|
||||
SideSource<true> SideOut(const std::string& tag) {
|
||||
return SideSource<true>(&out_sides_[tag]);
|
||||
}
|
||||
|
||||
SideDestination<true> SideIn(const std::string& tag) {
|
||||
return SideDestination<true>(&in_sides_[tag]);
|
||||
}
|
||||
|
||||
// Convenience methods for accessing purely index-based ports.
|
||||
SideSource<false> SideOut(int index) { return SideOut("")[index]; }
|
||||
SideDestination<false> SideIn(int index) { return SideIn("")[index]; }
|
||||
|
||||
template <typename T>
|
||||
T& GetOptions() {
|
||||
options_used_ = true;
|
||||
return *options_.MutableExtension(T::ext);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string type_;
|
||||
TagIndexMap<DestinationBase> in_sides_;
|
||||
TagIndexMap<SourceBase> out_sides_;
|
||||
mediapipe::PacketGeneratorOptions options_;
|
||||
// ideally we'd just check if any extensions are set on options_
|
||||
bool options_used_ = false;
|
||||
friend class Graph;
|
||||
};
|
||||
|
||||
class Graph {
|
||||
public:
|
||||
void SetType(std::string type) { type_ = std::move(type); }
|
||||
|
||||
// Creates a node of a specific type. Should be used for calculators whose
|
||||
// contract is available.
|
||||
template <class Calc>
|
||||
Node<Calc>& AddNode() {
|
||||
auto node = std::make_unique<Node<Calc>>();
|
||||
auto node_p = node.get();
|
||||
nodes_.emplace_back(std::move(node));
|
||||
return *node_p;
|
||||
}
|
||||
|
||||
// Creates a node of a specific type. Should be used for pure interfaces,
|
||||
// which do not have a built-in type std::string.
|
||||
template <class Calc>
|
||||
Node<Calc>& AddNode(const std::string& type) {
|
||||
auto node = std::make_unique<Node<Calc>>(type);
|
||||
auto node_p = node.get();
|
||||
nodes_.emplace_back(std::move(node));
|
||||
return *node_p;
|
||||
}
|
||||
|
||||
// Creates a generic node, with no compile-time checking of inputs and
|
||||
// outputs. This can be used for calculators whose contract is not visible.
|
||||
GenericNode& AddNode(const std::string& type) {
|
||||
auto node = std::make_unique<GenericNode>(type);
|
||||
auto node_p = node.get();
|
||||
nodes_.emplace_back(std::move(node));
|
||||
return *node_p;
|
||||
}
|
||||
|
||||
// For legacy PacketGenerators.
|
||||
PacketGenerator& AddPacketGenerator(const std::string& type) {
|
||||
auto node = std::make_unique<PacketGenerator>(type);
|
||||
auto node_p = node.get();
|
||||
packet_gens_.emplace_back(std::move(node));
|
||||
return *node_p;
|
||||
}
|
||||
|
||||
// Graph ports, non-typed.
|
||||
Source<true> In(const std::string& graph_input) {
|
||||
return graph_boundary_.Out(graph_input);
|
||||
}
|
||||
|
||||
Destination<true> Out(const std::string& graph_output) {
|
||||
return graph_boundary_.In(graph_output);
|
||||
}
|
||||
|
||||
SideSource<true> SideIn(const std::string& graph_input) {
|
||||
return graph_boundary_.SideOut(graph_input);
|
||||
}
|
||||
|
||||
SideDestination<true> SideOut(const std::string& graph_output) {
|
||||
return graph_boundary_.SideIn(graph_output);
|
||||
}
|
||||
|
||||
// Convenience methods for accessing purely index-based ports.
|
||||
Source<false> In(int index) { return In("")[0]; }
|
||||
|
||||
Destination<false> Out(int index) { return Out("")[0]; }
|
||||
|
||||
SideSource<false> SideIn(int index) { return SideIn("")[0]; }
|
||||
|
||||
SideDestination<false> SideOut(int index) { return SideOut("")[0]; }
|
||||
|
||||
// Graph ports, typed.
|
||||
// TODO: make graph_boundary_ a typed node!
|
||||
template <class PortT, class Payload = typename PortT::PayloadT,
|
||||
class Src = Source<PortT::kMultiple, Payload>>
|
||||
Src In(const PortT& graph_input) {
|
||||
return Src(&graph_boundary_.out_streams_[graph_input.Tag()]);
|
||||
}
|
||||
|
||||
template <class PortT, class Payload = typename PortT::PayloadT,
|
||||
class Dst = Destination<PortT::kMultiple, Payload>>
|
||||
Dst Out(const PortT& graph_output) {
|
||||
return Dst(&graph_boundary_.in_streams_[graph_output.Tag()]);
|
||||
}
|
||||
|
||||
template <class PortT, class Payload = typename PortT::PayloadT,
|
||||
class Src = SideSource<PortT::kMultiple, Payload>>
|
||||
Src SideIn(const PortT& graph_input) {
|
||||
return Src(&graph_boundary_.out_sides_[graph_input.Tag()]);
|
||||
}
|
||||
|
||||
template <class PortT, class Payload = typename PortT::PayloadT,
|
||||
class Dst = SideDestination<PortT::kMultiple, Payload>>
|
||||
Dst SideOut(const PortT& graph_output) {
|
||||
return Dst(&graph_boundary_.in_sides_[graph_output.Tag()]);
|
||||
}
|
||||
|
||||
// Returns the graph config. This can be used to instantiate and run the
|
||||
// graph.
|
||||
CalculatorGraphConfig GetConfig() {
|
||||
CalculatorGraphConfig config;
|
||||
if (!type_.empty()) {
|
||||
config.set_type(type_);
|
||||
}
|
||||
FixUnnamedConnections();
|
||||
CHECK_OK(UpdateBoundaryConfig(&config));
|
||||
for (const std::unique_ptr<NodeBase>& node : nodes_) {
|
||||
auto* out_node = config.add_node();
|
||||
CHECK_OK(UpdateNodeConfig(*node, out_node));
|
||||
}
|
||||
for (const std::unique_ptr<PacketGenerator>& node : packet_gens_) {
|
||||
auto* out_node = config.add_packet_generator();
|
||||
CHECK_OK(UpdateNodeConfig(*node, out_node));
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
private:
|
||||
void FixUnnamedConnections(NodeBase* node, int* unnamed_count) {
|
||||
node->out_streams_.Visit([&](const TagIndexLocation&, SourceBase* source) {
|
||||
if (source->name_.empty()) {
|
||||
source->name_ = absl::StrCat("__stream_", (*unnamed_count)++);
|
||||
}
|
||||
});
|
||||
node->out_sides_.Visit([&](const TagIndexLocation&, SourceBase* source) {
|
||||
if (source->name_.empty()) {
|
||||
source->name_ = absl::StrCat("__side_packet_", (*unnamed_count)++);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void FixUnnamedConnections() {
|
||||
int unnamed_count = 0;
|
||||
FixUnnamedConnections(&graph_boundary_, &unnamed_count);
|
||||
for (std::unique_ptr<NodeBase>& node : nodes_) {
|
||||
FixUnnamedConnections(node.get(), &unnamed_count);
|
||||
}
|
||||
for (std::unique_ptr<PacketGenerator>& node : packet_gens_) {
|
||||
node->out_sides_.Visit([&](const TagIndexLocation&, SourceBase* source) {
|
||||
if (source->name_.empty()) {
|
||||
source->name_ = absl::StrCat("__side_packet_", unnamed_count++);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
std::string TaggedName(const TagIndexLocation& loc, const std::string& name) {
|
||||
if (loc.tag.empty()) {
|
||||
// ParseTagIndexName does not allow using explicit indices without tags,
|
||||
// while ParseTagIndex does. There is no explanation for this discrepancy
|
||||
// in the CLs that introduced them (cl/143209019, cl/156499931).
|
||||
// TODO: decide whether we should just allow it.
|
||||
return name;
|
||||
} else {
|
||||
if (loc.count <= 1) {
|
||||
return absl::StrCat(loc.tag, ":", name);
|
||||
} else {
|
||||
return absl::StrCat(loc.tag, ":", loc.index, ":", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mediapipe::Status UpdateNodeConfig(const NodeBase& node,
|
||||
CalculatorGraphConfig::Node* config) {
|
||||
config->set_calculator(node.type_);
|
||||
node.in_streams_.Visit(
|
||||
[&](const TagIndexLocation& loc, const DestinationBase& endpoint) {
|
||||
CHECK(endpoint.source != nullptr);
|
||||
config->add_input_stream(TaggedName(loc, endpoint.source->name_));
|
||||
});
|
||||
node.out_streams_.Visit(
|
||||
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
|
||||
config->add_output_stream(TaggedName(loc, endpoint.name_));
|
||||
});
|
||||
node.in_sides_.Visit([&](const TagIndexLocation& loc,
|
||||
const DestinationBase& endpoint) {
|
||||
CHECK(endpoint.source != nullptr);
|
||||
config->add_input_side_packet(TaggedName(loc, endpoint.source->name_));
|
||||
});
|
||||
node.out_sides_.Visit(
|
||||
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
|
||||
config->add_output_side_packet(TaggedName(loc, endpoint.name_));
|
||||
});
|
||||
if (node.options_used_) {
|
||||
*config->mutable_options() = node.options_;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
mediapipe::Status UpdateNodeConfig(const PacketGenerator& node,
|
||||
PacketGeneratorConfig* config) {
|
||||
config->set_packet_generator(node.type_);
|
||||
node.in_sides_.Visit([&](const TagIndexLocation& loc,
|
||||
const DestinationBase& endpoint) {
|
||||
CHECK(endpoint.source != nullptr);
|
||||
config->add_input_side_packet(TaggedName(loc, endpoint.source->name_));
|
||||
});
|
||||
node.out_sides_.Visit(
|
||||
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
|
||||
config->add_output_side_packet(TaggedName(loc, endpoint.name_));
|
||||
});
|
||||
if (node.options_used_) {
|
||||
*config->mutable_options() = node.options_;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
// For special boundary node.
|
||||
mediapipe::Status UpdateBoundaryConfig(CalculatorGraphConfig* config) {
|
||||
graph_boundary_.in_streams_.Visit(
|
||||
[&](const TagIndexLocation& loc, const DestinationBase& endpoint) {
|
||||
CHECK(endpoint.source != nullptr);
|
||||
config->add_output_stream(TaggedName(loc, endpoint.source->name_));
|
||||
});
|
||||
graph_boundary_.out_streams_.Visit(
|
||||
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
|
||||
config->add_input_stream(TaggedName(loc, endpoint.name_));
|
||||
});
|
||||
graph_boundary_.in_sides_.Visit([&](const TagIndexLocation& loc,
|
||||
const DestinationBase& endpoint) {
|
||||
CHECK(endpoint.source != nullptr);
|
||||
config->add_output_side_packet(TaggedName(loc, endpoint.source->name_));
|
||||
});
|
||||
graph_boundary_.out_sides_.Visit(
|
||||
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
|
||||
config->add_input_side_packet(TaggedName(loc, endpoint.name_));
|
||||
});
|
||||
return {};
|
||||
}
|
||||
|
||||
std::string type_;
|
||||
std::vector<std::unique_ptr<NodeBase>> nodes_;
|
||||
std::vector<std::unique_ptr<PacketGenerator>> packet_gens_;
|
||||
// Special node representing graph inputs and outputs.
|
||||
NodeBase graph_boundary_{"__GRAPH__"};
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
190
mediapipe/framework/api2/builder_test.cc
Normal file
190
mediapipe/framework/api2/builder_test.cc
Normal file
|
@ -0,0 +1,190 @@
|
|||
#include "mediapipe/framework/api2/builder.h"
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/api2/tag.h"
|
||||
#include "mediapipe/framework/api2/test_contracts.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/deps/message_matchers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace test {
|
||||
|
||||
TEST(BuilderTest, BuildGraph) {
|
||||
builder::Graph graph;
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
auto& bar = graph.AddNode("Bar");
|
||||
graph.In("IN").SetName("base") >> foo.In("BASE");
|
||||
graph.SideIn("SIDE").SetName("side") >> foo.SideIn("SIDE");
|
||||
foo.Out("OUT") >> bar.In("IN");
|
||||
bar.Out("OUT").SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "IN:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:out"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:__stream_0"
|
||||
}
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:__stream_0"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
template <class FooT>
|
||||
void BuildGraphTypedTest() {
|
||||
builder::Graph graph;
|
||||
auto& foo = graph.AddNode<FooT>();
|
||||
auto& bar = graph.AddNode<Bar>();
|
||||
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
|
||||
graph.SideIn("SIDE").SetName("side") >> foo.SideIn(MPP_TAG("BIAS"));
|
||||
foo.Out(MPP_TAG("OUT")) >> bar.In(MPP_TAG("IN"));
|
||||
bar.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
absl::Substitute(R"(
|
||||
input_stream: "IN:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:out"
|
||||
node {
|
||||
calculator: "$0"
|
||||
input_stream: "BASE:base"
|
||||
input_side_packet: "BIAS:side"
|
||||
output_stream: "OUT:__stream_0"
|
||||
}
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:__stream_0"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)",
|
||||
FooT::kCalculatorName));
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest<Foo>(); }
|
||||
|
||||
TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<Foo2>(); }
|
||||
|
||||
TEST(BuilderTest, FanOut) {
|
||||
builder::Graph graph;
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
auto& adder = graph.AddNode("FloatAdder");
|
||||
graph.In("IN").SetName("base") >> foo.In("BASE");
|
||||
foo.Out("OUT") >> adder.In("IN")[0];
|
||||
foo.Out("OUT") >> adder.In("IN")[1];
|
||||
adder.Out("OUT").SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "IN:base"
|
||||
output_stream: "OUT:out"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:base"
|
||||
output_stream: "OUT:__stream_0"
|
||||
}
|
||||
node {
|
||||
calculator: "FloatAdder"
|
||||
input_stream: "IN:0:__stream_0"
|
||||
input_stream: "IN:1:__stream_0"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, TypedMultiple) {
|
||||
builder::Graph graph;
|
||||
auto& foo = graph.AddNode<Foo>();
|
||||
auto& adder = graph.AddNode<FloatAdder>();
|
||||
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
|
||||
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0];
|
||||
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1];
|
||||
adder.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "IN:base"
|
||||
output_stream: "OUT:out"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:base"
|
||||
output_stream: "OUT:__stream_0"
|
||||
}
|
||||
node {
|
||||
calculator: "FloatAdder"
|
||||
input_stream: "IN:0:__stream_0"
|
||||
input_stream: "IN:1:__stream_0"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, PacketGenerator) {
|
||||
builder::Graph graph;
|
||||
auto& generator = graph.AddPacketGenerator("FloatGenerator");
|
||||
graph.SideIn("IN") >> generator.SideIn("IN");
|
||||
generator.SideOut("OUT") >> graph.SideOut("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_side_packet: "IN:__side_packet_0"
|
||||
output_side_packet: "OUT:__side_packet_1"
|
||||
packet_generator {
|
||||
packet_generator: "FloatGenerator"
|
||||
input_side_packet: "IN:__side_packet_0"
|
||||
output_side_packet: "OUT:__side_packet_1"
|
||||
}
|
||||
)");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, EmptyTag) {
|
||||
builder::Graph graph;
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
graph.In("A").SetName("a") >> foo.In("")[0];
|
||||
graph.In("C").SetName("c") >> foo.In("")[2];
|
||||
graph.In("B").SetName("b") >> foo.In("")[1];
|
||||
foo.Out("")[0].SetName("x") >> graph.Out("ONE");
|
||||
foo.Out("")[1].SetName("y") >> graph.Out("TWO");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "A:a"
|
||||
input_stream: "B:b"
|
||||
input_stream: "C:c"
|
||||
output_stream: "ONE:x"
|
||||
output_stream: "TWO:y"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "a"
|
||||
input_stream: "b"
|
||||
input_stream: "c"
|
||||
output_stream: "x"
|
||||
output_stream: "y"
|
||||
}
|
||||
)");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
43
mediapipe/framework/api2/const_str.h
Normal file
43
mediapipe/framework/api2/const_str.h
Normal file
|
@ -0,0 +1,43 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_CONST_STR_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_CONST_STR_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// This class stores a constant std::string that can be inspected at compile
|
||||
// time in constexpr code.
|
||||
class const_str {
|
||||
public:
|
||||
constexpr const_str(std::size_t size, const char* data)
|
||||
: len_(size - 1), data_(data) {}
|
||||
|
||||
template <std::size_t N>
|
||||
explicit constexpr const_str(const char (&str)[N]) : const_str(N, str) {}
|
||||
|
||||
constexpr std::size_t len() const { return len_; }
|
||||
constexpr const char* data() const { return data_; }
|
||||
|
||||
constexpr bool operator==(const const_str& other) const {
|
||||
return len_ == other.len_ && equal(len_, data_, other.data_);
|
||||
}
|
||||
|
||||
constexpr char operator[](const std::size_t idx) const {
|
||||
return idx <= len_ ? data_[idx] : '\0';
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr bool equal(std::size_t len, const char* const p,
|
||||
const char* const q) {
|
||||
return len == 0 || (*p == *q && equal(len - 1, p + 1, q + 1));
|
||||
}
|
||||
|
||||
const std::size_t len_;
|
||||
const char* const data_;
|
||||
};
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_CONST_STR_H_
|
387
mediapipe/framework/api2/contract.h
Normal file
387
mediapipe/framework/api2/contract.h
Normal file
|
@ -0,0 +1,387 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_CONTRACT_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_CONTRACT_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/const_str.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/api2/tag.h"
|
||||
#include "mediapipe/framework/api2/tuple.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/output_side_packet.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
class StreamHandler {
|
||||
public:
|
||||
template <std::size_t N>
|
||||
explicit constexpr StreamHandler(const char (&name)[N]) : name_(N, name) {}
|
||||
|
||||
const const_str& name() { return name_; }
|
||||
|
||||
mediapipe::Status AddToContract(CalculatorContract* cc) const {
|
||||
cc->SetInputStreamHandler(name_.data());
|
||||
return {};
|
||||
}
|
||||
|
||||
private:
|
||||
const const_str name_;
|
||||
};
|
||||
|
||||
class TimestampChange {
|
||||
public:
|
||||
// Note: we don't use TimestampDiff as an argument because it's not constexpr.
|
||||
static constexpr TimestampChange Offset(int64_t offset) {
|
||||
return TimestampChange(offset);
|
||||
}
|
||||
|
||||
static constexpr TimestampChange Arbitrary() {
|
||||
// Same value as used for Timestamp::Unset.
|
||||
return TimestampChange(kUnset);
|
||||
}
|
||||
|
||||
mediapipe::Status AddToContract(CalculatorContract* cc) const {
|
||||
if (offset_ != kUnset) cc->SetTimestampOffset(offset_);
|
||||
return {};
|
||||
}
|
||||
|
||||
private:
|
||||
constexpr TimestampChange(int64_t offset) : offset_(offset) {}
|
||||
static constexpr int64_t kUnset = std::numeric_limits<int64_t>::min();
|
||||
int64_t offset_;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <class Base>
|
||||
struct IsSubclass {
|
||||
template <class T>
|
||||
using pred = std::is_base_of<Base, std::decay_t<T>>;
|
||||
};
|
||||
|
||||
template <class T, class = void>
|
||||
struct HasProcessMethod : std::false_type {};
|
||||
|
||||
template <class T>
|
||||
struct HasProcessMethod<
|
||||
T, std::void_t<decltype(mediapipe::Status(
|
||||
std::declval<std::decay_t<T>>().Process(
|
||||
std::declval<mediapipe::CalculatorContext*>())))>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class = void>
|
||||
struct HasNestedItems : std::false_type {};
|
||||
|
||||
template <class T>
|
||||
struct HasNestedItems<
|
||||
T, std::void_t<decltype(std::declval<std::decay_t<T>>().nested_items())>>
|
||||
: std::true_type {};
|
||||
|
||||
// Helper to construct a tuple of Tag types (see tag.h) from a tuple of ports.
|
||||
template <class TupleRef>
|
||||
struct TagTuple {
|
||||
template <std::size_t J>
|
||||
struct S {
|
||||
const const_str tag{std::get<J>(TupleRef::get()).tag_};
|
||||
};
|
||||
|
||||
template <std::size_t... I>
|
||||
static constexpr auto Make(std::index_sequence<I...> indices) {
|
||||
return std::make_tuple(mediapipe::api2::internal::tag_build(S<I>{})...);
|
||||
}
|
||||
|
||||
static constexpr auto Make() {
|
||||
using TupleT = decltype(TupleRef::get());
|
||||
return Make(internal::tuple_index_sequence<TupleT>());
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to access a tuple of ports by static tag. Attempts to look up a
|
||||
// missing tag will not compile.
|
||||
template <class TupleRef>
|
||||
struct TaggedAccess {
|
||||
// This is not functionally necessary (we could do the tag search directly
|
||||
// on the port tuple), but it gives a more readable error message when the
|
||||
// static_assert below fails.
|
||||
static constexpr auto kTagTuple = TagTuple<TupleRef>::Make();
|
||||
|
||||
template <class Tag>
|
||||
static constexpr auto& get(Tag tag) {
|
||||
constexpr auto i =
|
||||
internal::tuple_find([tag](auto x) { return x == tag; }, kTagTuple);
|
||||
static_assert(i < std::tuple_size_v<decltype(kTagTuple)>, "tag not found");
|
||||
return std::get<i>(TupleRef::get());
|
||||
}
|
||||
};
|
||||
|
||||
template <class... T>
|
||||
constexpr auto ExtractNestedItems(std::tuple<T...> tuple) {
|
||||
return internal::flatten_tuple(internal::map_tuple(
|
||||
[](auto&& item) {
|
||||
if constexpr (HasNestedItems<decltype(item)>{}) {
|
||||
return std::tuple_cat(std::make_tuple(item), item.nested_items());
|
||||
} else {
|
||||
return std::make_tuple(item);
|
||||
}
|
||||
},
|
||||
tuple));
|
||||
}
|
||||
|
||||
// Internal contract type. Takes a list of ports or other contract items.
|
||||
template <typename... T>
|
||||
class Contract {
|
||||
public:
|
||||
constexpr Contract(std::tuple<T...> tuple) : items(tuple) {}
|
||||
constexpr Contract(T&&... args)
|
||||
: Contract(std::tuple<T...>{std::move(args)...}) {}
|
||||
|
||||
mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) const {
|
||||
std::vector<mediapipe::Status> statuses;
|
||||
auto store_status = [&statuses](mediapipe::Status status) {
|
||||
if (!status.ok()) statuses.push_back(std::move(status));
|
||||
};
|
||||
internal::tuple_for_each(
|
||||
[cc, &store_status](auto&& item) {
|
||||
store_status(item.AddToContract(cc));
|
||||
},
|
||||
all_items);
|
||||
|
||||
if (timestamp_change_count() == 0) {
|
||||
// Default to SetOffset(0);
|
||||
store_status(TimestampChange::Offset(0).AddToContract(cc));
|
||||
}
|
||||
|
||||
if (statuses.empty()) return {};
|
||||
if (statuses.size() == 1) return statuses[0];
|
||||
return tool::CombinedStatus("Multiple errors", statuses);
|
||||
}
|
||||
|
||||
std::tuple<T...> items;
|
||||
|
||||
// TODO: when forwarding nested items (e.g. ports), check for conflicts.
|
||||
decltype(ExtractNestedItems(items)) all_items{ExtractNestedItems(items)};
|
||||
|
||||
constexpr auto inputs() const {
|
||||
return internal::filter_tuple<IsSubclass<InputBase>::pred>(all_items);
|
||||
}
|
||||
constexpr auto outputs() const {
|
||||
return internal::filter_tuple<IsSubclass<OutputBase>::pred>(all_items);
|
||||
}
|
||||
constexpr auto side_inputs() const {
|
||||
return internal::filter_tuple<IsSubclass<SideInputBase>::pred>(all_items);
|
||||
}
|
||||
constexpr auto side_outputs() const {
|
||||
return internal::filter_tuple<IsSubclass<SideOutputBase>::pred>(all_items);
|
||||
}
|
||||
|
||||
constexpr auto timestamp_change_count() const {
|
||||
return internal::filtered_tuple_indices<IsSubclass<TimestampChange>::pred>(
|
||||
all_items)
|
||||
.size();
|
||||
}
|
||||
|
||||
constexpr auto process_items() const {
|
||||
return internal::filter_tuple<HasProcessMethod>(all_items);
|
||||
}
|
||||
};
|
||||
|
||||
// Helpers to construct a Contract.
|
||||
template <typename... T>
|
||||
constexpr auto MakeContract(T&&... args) {
|
||||
return Contract<T...>(std::forward<T>(args)...);
|
||||
}
|
||||
|
||||
template <typename... T>
|
||||
constexpr auto MakeContract(const std::tuple<T...>& tuple) {
|
||||
return Contract<T...>(tuple);
|
||||
}
|
||||
|
||||
// Helper for accessing the ports of a Contract by static tags.
|
||||
template <typename C2T, const C2T& c2>
|
||||
class TaggedContract {
|
||||
public:
|
||||
constexpr TaggedContract() = default;
|
||||
|
||||
static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
|
||||
return c2.GetContract(cc);
|
||||
}
|
||||
|
||||
template <class Tuple, Tuple (C2T::*member)() const>
|
||||
struct GetMember {
|
||||
static constexpr const auto get() { return (c2.*member)(); }
|
||||
};
|
||||
|
||||
using TaggedInputs =
|
||||
TaggedAccess<GetMember<decltype(c2.inputs()), &C2T::inputs>>;
|
||||
using TaggedOutputs =
|
||||
TaggedAccess<GetMember<decltype(c2.outputs()), &C2T::outputs>>;
|
||||
using TaggedSideInputs =
|
||||
TaggedAccess<GetMember<decltype(c2.side_inputs()), &C2T::side_inputs>>;
|
||||
using TaggedSideOutputs =
|
||||
TaggedAccess<GetMember<decltype(c2.side_outputs()), &C2T::side_outputs>>;
|
||||
};
|
||||
|
||||
// Support for function-based Process.
|
||||
|
||||
template <class T>
|
||||
struct IsInputPort
|
||||
: std::bool_constant<std::is_base_of<InputBase, std::decay_t<T>>{} ||
|
||||
std::is_base_of<SideInputBase, std::decay_t<T>>{}> {};
|
||||
|
||||
template <class T>
|
||||
struct IsOutputPort
|
||||
: std::bool_constant<std::is_base_of<OutputBase, std::decay_t<T>>{} ||
|
||||
std::is_base_of<SideOutputBase, std::decay_t<T>>{}> {};
|
||||
|
||||
// Helper class that converts a port specification into a function argument.
|
||||
template <class P>
|
||||
class PortArg {
|
||||
public:
|
||||
PortArg(CalculatorContext* cc, const P& port) : cc_(cc), port_(port) {}
|
||||
|
||||
using PayloadT = typename P::PayloadT;
|
||||
|
||||
operator const PayloadT&() { return port_(cc_).Get(); }
|
||||
|
||||
operator Packet<typename P::value_t>() { return port_(cc_); }
|
||||
|
||||
operator PacketBase() { return port_(cc_).packet(); }
|
||||
|
||||
private:
|
||||
CalculatorContext* cc_;
|
||||
const P& port_;
|
||||
};
|
||||
|
||||
template <class P>
|
||||
auto MakePortArg(CalculatorContext* cc, const P& port) {
|
||||
return PortArg<P>(cc, port);
|
||||
}
|
||||
|
||||
// Helper class that takes a function result and sends it into outputs.
|
||||
template <class... P>
|
||||
class OutputSender {
|
||||
public:
|
||||
OutputSender(P&&... args) : outputs_(args...) {}
|
||||
OutputSender(std::tuple<P...>&& args) : outputs_(args) {}
|
||||
|
||||
template <class R, std::enable_if_t<sizeof...(P) == 1, int> = 0>
|
||||
mediapipe::Status operator()(CalculatorContext* cc,
|
||||
mediapipe::StatusOr<R>&& result) {
|
||||
if (result.ok()) {
|
||||
return this(cc, result.ValueOrDie());
|
||||
} else {
|
||||
return result.status();
|
||||
}
|
||||
}
|
||||
|
||||
template <class R, std::enable_if_t<sizeof...(P) == 1, int> = 0>
|
||||
mediapipe::Status operator()(CalculatorContext* cc, R&& result) {
|
||||
std::get<0>(outputs_)(cc).Send(std::forward<R>(result));
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class... R>
|
||||
mediapipe::Status operator()(CalculatorContext* cc,
|
||||
mediapipe::StatusOr<std::tuple<R...>>&& result) {
|
||||
if (result.ok()) {
|
||||
return this(cc, result.ValueOrDie());
|
||||
} else {
|
||||
return result.status();
|
||||
}
|
||||
}
|
||||
|
||||
template <class... R>
|
||||
mediapipe::Status operator()(CalculatorContext* cc,
|
||||
std::tuple<R...>&& result) {
|
||||
static_assert(sizeof...(P) == sizeof...(R), "");
|
||||
internal::tuple_for_each(
|
||||
[cc, &result](const auto& port, auto i_const) {
|
||||
constexpr std::size_t i = decltype(i_const)::value;
|
||||
port(cc).Send(std::get<i>(result));
|
||||
},
|
||||
outputs_);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::tuple<P...> outputs_;
|
||||
};
|
||||
|
||||
template <class... P>
|
||||
auto MakeOutputSender(P&&... args) {
|
||||
return OutputSender<P...>(std::forward<P>(args)...);
|
||||
}
|
||||
|
||||
template <class... P>
|
||||
auto MakeOutputSender(std::tuple<P...>&& args) {
|
||||
return OutputSender<P...>(std::forward<std::tuple<P...>>(args));
|
||||
}
|
||||
|
||||
// Contract item that specifies that certain I/O ports are handled by invoking
|
||||
// a specific function.
|
||||
template <class F, class... P>
|
||||
class FunCaller {
|
||||
public:
|
||||
constexpr FunCaller(F&& f, P&&... args) : f_(f), args_(args...) {}
|
||||
|
||||
auto operator()(CalculatorContext* cc) const {
|
||||
auto output_sender = MakeOutputSender(outputs());
|
||||
// tuple_apply gives better error messages than std::apply if the argument
|
||||
// types don't match.
|
||||
return output_sender(
|
||||
cc, internal::tuple_apply(f_, internal::map_tuple(
|
||||
[cc](const auto& port) {
|
||||
return MakePortArg(cc, port);
|
||||
},
|
||||
inputs())));
|
||||
}
|
||||
|
||||
auto inputs() const { return internal::filter_tuple<IsInputPort>(args_); }
|
||||
auto outputs() const { return internal::filter_tuple<IsOutputPort>(args_); }
|
||||
|
||||
mediapipe::Status AddToContract(CalculatorContract* cc) const { return {}; }
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) const { return (*this)(cc); }
|
||||
|
||||
constexpr std::tuple<P...> nested_items() const { return args_; }
|
||||
|
||||
F f_;
|
||||
std::tuple<P...> args_;
|
||||
};
|
||||
|
||||
// Helper function to invoke function callers in Process.
|
||||
|
||||
// TODO: implement multiple callers for syncsets.
|
||||
template <class... T>
|
||||
mediapipe::Status ProcessFnCallers(CalculatorContext* cc,
|
||||
std::tuple<T...> callers);
|
||||
|
||||
inline mediapipe::Status ProcessFnCallers(CalculatorContext* cc, std::tuple<>) {
|
||||
return mediapipe::InternalError("Process unimplemented");
|
||||
}
|
||||
|
||||
template <class T>
|
||||
mediapipe::Status ProcessFnCallers(CalculatorContext* cc,
|
||||
std::tuple<T> callers) {
|
||||
return std::get<0>(callers).Process(cc);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// Function used to add a process function to a calculator contract.
|
||||
template <class F, class... P>
|
||||
constexpr auto ProcessFn(F&& f, P&&... args) {
|
||||
return internal::FunCaller<F, P...>(std::forward<F>(f),
|
||||
std::forward<P>(args)...);
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_CONTRACT_H_
|
73
mediapipe/framework/api2/contract_test.cc
Normal file
73
mediapipe/framework/api2/contract_test.cc
Normal file
|
@ -0,0 +1,73 @@
|
|||
#include "mediapipe/framework/api2/contract.h"
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
struct ProcessItem {
|
||||
mediapipe::Status Process(CalculatorContext* cc) { return {}; }
|
||||
};
|
||||
|
||||
struct ItemWithNested {
|
||||
constexpr auto nested_items() { return std::make_tuple(Input<char>{"FWD"}); }
|
||||
};
|
||||
|
||||
static constexpr auto kTestContract = internal::MakeContract(
|
||||
Input<int>{"BASE"}, Input<float>::Optional{"SCALE"}, Output<float>{"OUT"},
|
||||
SideInput<float>::Optional{"BIAS"}, SideOutput<char>{"SIDE"},
|
||||
ProcessItem{});
|
||||
|
||||
static_assert(std::tuple_size_v<decltype(kTestContract.inputs())> == 2, "");
|
||||
static_assert(std::tuple_size_v<decltype(kTestContract.outputs())> == 1, "");
|
||||
static_assert(std::tuple_size_v<decltype(kTestContract.side_inputs())> == 1,
|
||||
"");
|
||||
static_assert(std::tuple_size_v<decltype(kTestContract.side_outputs())> == 1,
|
||||
"");
|
||||
|
||||
static_assert(internal::HasProcessMethod<ProcessItem>{}, "");
|
||||
static_assert(!internal::HasProcessMethod<Input<int>>{}, "");
|
||||
|
||||
static_assert(std::tuple_size_v<decltype(kTestContract.process_items())> == 1,
|
||||
"");
|
||||
|
||||
static constexpr auto kExtractNested1 = internal::ExtractNestedItems(
|
||||
std::make_tuple(Input<int>{"BASE"}, Input<float>::Optional{"SCALE"},
|
||||
Output<float>{"OUT"}));
|
||||
|
||||
static_assert(std::tuple_size_v<decltype(kExtractNested1)> == 3, "");
|
||||
|
||||
static constexpr auto kExtractNested2 = internal::ExtractNestedItems(
|
||||
std::make_tuple(Input<int>{"BASE"}, Input<float>::Optional{"SCALE"},
|
||||
Output<float>{"OUT"}, ItemWithNested{}));
|
||||
static_assert(std::tuple_size_v<decltype(kExtractNested2)> == 5, "");
|
||||
|
||||
using TaggedTestContract =
|
||||
internal::TaggedContract<decltype(kTestContract), kTestContract>;
|
||||
|
||||
static constexpr auto kBASE = MPP_TAG("BASE");
|
||||
static constexpr auto kSCALE = MPP_TAG("SCALE");
|
||||
static constexpr auto kBIAS = MPP_TAG("BIAS");
|
||||
static constexpr auto kOUT = MPP_TAG("OUT");
|
||||
static constexpr auto kSIDE = MPP_TAG("SIDE");
|
||||
|
||||
static_assert(TaggedTestContract::TaggedInputs::get(kBASE).tag_ == kBASE.kStr,
|
||||
"");
|
||||
static_assert(TaggedTestContract::TaggedInputs::get(kSCALE).tag_ == kSCALE.kStr,
|
||||
"");
|
||||
static_assert(TaggedTestContract::TaggedOutputs::get(kOUT).tag_ == kOUT.kStr,
|
||||
"");
|
||||
static_assert(TaggedTestContract::TaggedSideInputs::get(kBIAS).tag_ ==
|
||||
kBIAS.kStr,
|
||||
"");
|
||||
static_assert(TaggedTestContract::TaggedSideOutputs::get(kSIDE).tag_ ==
|
||||
kSIDE.kStr,
|
||||
"");
|
||||
|
||||
} // namespace
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
9
mediapipe/framework/api2/node.cc
Normal file
9
mediapipe/framework/api2/node.cc
Normal file
|
@ -0,0 +1,9 @@
|
|||
#include "mediapipe/framework/api2/node.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
Node::~Node() {}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
248
mediapipe/framework/api2/node.h
Normal file
248
mediapipe/framework/api2/node.h
Normal file
|
@ -0,0 +1,248 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_NODE_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_NODE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/api2/const_str.h"
|
||||
#include "mediapipe/framework/api2/contract.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_base.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/deps/no_destructor.h"
|
||||
#include "mediapipe/framework/subgraph.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
class NodeIntf {};
|
||||
|
||||
class Node : public CalculatorBase {
|
||||
public:
|
||||
virtual ~Node();
|
||||
};
|
||||
|
||||
} // namespace api2
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <class T>
|
||||
class CalculatorBaseFactoryFor<
|
||||
T,
|
||||
typename std::enable_if<std::is_base_of<mediapipe::api2::Node, T>{}>::type>
|
||||
: public CalculatorBaseFactory {
|
||||
public:
|
||||
mediapipe::Status GetContract(CalculatorContract* cc) final {
|
||||
auto status = T::Contract::GetContract(cc);
|
||||
if (status.ok()) {
|
||||
status = UpdateContract<T>(cc);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
std::unique_ptr<CalculatorBase> CreateCalculator(
|
||||
CalculatorContext* calculator_context) final {
|
||||
return absl::make_unique<T>();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename U>
|
||||
auto UpdateContract(CalculatorContract* cc)
|
||||
-> decltype(U::UpdateContract(cc)) {
|
||||
return U::UpdateContract(cc);
|
||||
}
|
||||
template <typename U>
|
||||
mediapipe::Status UpdateContract(...) {
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
namespace api2 {
|
||||
namespace internal {
|
||||
|
||||
// Defining a member of this type causes P to be ODR-used, which forces its
|
||||
// instantiation if it's a static member of a template.
|
||||
// Previously we depended on the pointer's value to determine whether the size
|
||||
// of a character array is 0 or 1, forcing it to be instantiated so the
|
||||
// compiler can determine the object's layout. But using it as a template
|
||||
// argument is more compact.
|
||||
template <auto* P>
|
||||
struct ForceStaticInstantiation {
|
||||
#ifdef _MSC_VER
|
||||
// Just having it as the template argument does not count as a use for
|
||||
// MSVC.
|
||||
static constexpr bool Use() { return P != nullptr; }
|
||||
char force_static[Use()];
|
||||
#endif // _MSC_VER
|
||||
};
|
||||
|
||||
// Helper template for forcing the definition of a static registration token.
|
||||
template <typename T>
|
||||
struct NodeRegistrationStatic {
|
||||
static NoDestructor<mediapipe::RegistrationToken> registration;
|
||||
|
||||
static mediapipe::RegistrationToken Make() {
|
||||
return mediapipe::CalculatorBaseRegistry::Register(
|
||||
T::kCalculatorName,
|
||||
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>);
|
||||
}
|
||||
|
||||
using RequireStatics = ForceStaticInstantiation<®istration>;
|
||||
};
|
||||
|
||||
// Static members of template classes can be defined in the header.
|
||||
template <typename T>
|
||||
NoDestructor<mediapipe::RegistrationToken>
|
||||
NodeRegistrationStatic<T>::registration(NodeRegistrationStatic<T>::Make());
|
||||
|
||||
template <typename T>
|
||||
struct SubgraphRegistrationImpl {
|
||||
static NoDestructor<mediapipe::RegistrationToken> registration;
|
||||
|
||||
static mediapipe::RegistrationToken Make() {
|
||||
return mediapipe::SubgraphRegistry::Register(T::kCalculatorName,
|
||||
absl::make_unique<T>);
|
||||
}
|
||||
|
||||
using RequireStatics = ForceStaticInstantiation<®istration>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
NoDestructor<mediapipe::RegistrationToken>
|
||||
SubgraphRegistrationImpl<T>::registration(
|
||||
SubgraphRegistrationImpl<T>::Make());
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// By passing the Impl parameter, registration is done automatically. No need
|
||||
// to use MEDIAPIPE_NODE_IMPLEMENTATION.
|
||||
// For backward compatibility, Impl can be omitted; use
|
||||
// MEDIAPIPE_NODE_IMPLEMENTATION with this.
|
||||
// TODO: migrate and remove.
|
||||
template <class Impl = void>
|
||||
class RegisteredNode;
|
||||
|
||||
template <class Impl>
|
||||
class RegisteredNode : public Node {
|
||||
private:
|
||||
// The member below triggers instantiation of the registration static.
|
||||
// Note that the constructor of calculator subclasses is only invoked through
|
||||
// the registration token, and so we cannot simply use the static in the
|
||||
// constructor.
|
||||
typename internal::NodeRegistrationStatic<Impl>::RequireStatics register_;
|
||||
};
|
||||
|
||||
// No-op version for backwards compatibility.
|
||||
template <>
|
||||
class RegisteredNode<void> : public Node {};
|
||||
|
||||
template <class Impl>
|
||||
struct FunctionNode : public RegisteredNode<Impl> {
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
return internal::ProcessFnCallers(cc, Impl::kContract.process_items());
|
||||
}
|
||||
};
|
||||
|
||||
template <class Intf, class Impl = void>
|
||||
class NodeImpl : public RegisteredNode<Impl>, public Intf {
|
||||
protected:
|
||||
// These methods allow accessing a node's ports by tag. This can be useful in
|
||||
// a few cases, e.g. if the port is not available as a named constant.
|
||||
// They parallel the corresponding methods on builder nodes.
|
||||
template <class Tag>
|
||||
static constexpr auto Out(Tag t) {
|
||||
return Intf::Contract::TaggedOutputs::get(t);
|
||||
}
|
||||
|
||||
template <class Tag>
|
||||
static constexpr auto In(Tag t) {
|
||||
return Intf::Contract::TaggedInputs::get(t);
|
||||
}
|
||||
|
||||
template <class Tag>
|
||||
static constexpr auto SideOut(Tag t) {
|
||||
return Intf::Contract::TaggedSideOutputs::get(t);
|
||||
}
|
||||
|
||||
template <class Tag>
|
||||
static constexpr auto SideIn(Tag t) {
|
||||
return Intf::Contract::TaggedSideInputs::get(t);
|
||||
}
|
||||
|
||||
// Convenience.
|
||||
template <class Tag, class CC>
|
||||
static auto Out(Tag t, CC cc) {
|
||||
return Out(t)(cc);
|
||||
}
|
||||
template <class Tag, class CC>
|
||||
static auto In(Tag t, CC cc) {
|
||||
return In(t)(cc);
|
||||
}
|
||||
template <class Tag, class CC>
|
||||
static auto SideOut(Tag t, CC cc) {
|
||||
return SideOut(t)(cc);
|
||||
}
|
||||
template <class Tag, class CC>
|
||||
static auto SideIn(Tag t, CC cc) {
|
||||
return SideIn(t)(cc);
|
||||
}
|
||||
};
|
||||
|
||||
// This macro is used to define the contract, without also giving the
|
||||
// node a type name. It can be used directly in pure interfaces.
|
||||
#define MEDIAPIPE_NODE_CONTRACT(...) \
|
||||
static constexpr auto kContract = \
|
||||
mediapipe::api2::internal::MakeContract(__VA_ARGS__); \
|
||||
using Contract = \
|
||||
typename mediapipe::api2::internal::TaggedContract<decltype(kContract), \
|
||||
kContract>;
|
||||
|
||||
// This macro is used to define the contract and the type name of a node.
|
||||
// This saves the name of the calculator, making it available to the
|
||||
// implementation too, and to the registration macro for it. The reason is
|
||||
// that the name must be available with the contract (so that it can be used
|
||||
// to build a graph config, for instance); however, it is the implementation
|
||||
// that needs to be registered.
|
||||
// TODO: rename to MEDIAPIPE_NODE_DECLARATION?
|
||||
// TODO: more detailed explanation.
|
||||
#define MEDIAPIPE_NODE_INTERFACE(name, ...) \
|
||||
static constexpr char kCalculatorName[] = #name; \
|
||||
MEDIAPIPE_NODE_CONTRACT(__VA_ARGS__)
|
||||
|
||||
// TODO: verify that the subgraph config fully implements the
|
||||
// declared interface.
|
||||
template <class Intf, class Impl>
|
||||
class SubgraphImpl : public Subgraph, public Intf {
|
||||
private:
|
||||
typename internal::SubgraphRegistrationImpl<Impl>::RequireStatics register_;
|
||||
};
|
||||
|
||||
// This macro is used to register a calculator that does not use automatic
|
||||
// registration. Deprecated.
|
||||
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
|
||||
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
|
||||
REGISTRY_STATIC_VAR(calculator_registration, __LINE__)( \
|
||||
mediapipe::CalculatorBaseRegistry::Register( \
|
||||
Impl::kCalculatorName, \
|
||||
absl::make_unique< \
|
||||
mediapipe::internal::CalculatorBaseFactoryFor<Impl>>))
|
||||
|
||||
// This macro is used to register a non-split-contract calculator. Deprecated.
|
||||
#define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)
|
||||
|
||||
// This macro is used to define a subgraph that does not use automatic
|
||||
// registration. Deprecated.
|
||||
#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \
|
||||
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
|
||||
REGISTRY_STATIC_VAR(subgraph_registration, \
|
||||
__LINE__)(mediapipe::SubgraphRegistry::Register( \
|
||||
Impl::kCalculatorName, absl::make_unique<Impl>))
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_NODE_H_
|
527
mediapipe/framework/api2/node_test.cc
Normal file
527
mediapipe/framework/api2/node_test.cc
Normal file
|
@ -0,0 +1,527 @@
|
|||
#include "mediapipe/framework/api2/node.h"
|
||||
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/api2/test_contracts.h"
|
||||
#include "mediapipe/framework/api2/tuple.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace test {
|
||||
|
||||
using testing::ElementsAre;
|
||||
|
||||
// Returns the packet values for a vector of Packets.
|
||||
template <typename T>
|
||||
std::vector<T> PacketValues(const std::vector<mediapipe::Packet>& packets) {
|
||||
std::vector<T> result;
|
||||
for (const auto& packet : packets) {
|
||||
result.push_back(packet.Get<T>());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
class FooImpl : public NodeImpl<Foo, FooImpl> {
|
||||
public:
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
float bias = kBias(cc).GetOr(0.0);
|
||||
float scale = kScale(cc).GetOr(1.0);
|
||||
kOut(cc).Send(*kBase(cc) * scale + bias);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class Foo3 : public FunctionNode<Foo3> {
|
||||
public:
|
||||
static constexpr Input<int> kBase{"BASE"};
|
||||
static constexpr Input<float>::Optional kScale{"SCALE"};
|
||||
static constexpr Output<float> kOut{"OUT"};
|
||||
static constexpr SideInput<float>::Optional kBias{"BIAS"};
|
||||
|
||||
static float foo(int base, Packet<float> bias, Packet<float> scale) {
|
||||
return base * scale.GetOr(1.0) + bias.GetOr(0.0);
|
||||
}
|
||||
|
||||
// TODO: add support for methods.
|
||||
MEDIAPIPE_NODE_INTERFACE(Foo3, ProcessFn(&foo, kBase, kBias, kScale, kOut));
|
||||
};
|
||||
|
||||
class Foo4 : public FunctionNode<Foo4> {
|
||||
public:
|
||||
static float foo(int base, Packet<float> bias, Packet<float> scale) {
|
||||
return base * scale.GetOr(1.0) + bias.GetOr(0.0);
|
||||
}
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(Foo4, ProcessFn(&foo, Input<int>{"BASE"},
|
||||
SideInput<float>::Optional{"BIAS"},
|
||||
Input<float>::Optional{"SCALE"},
|
||||
Output<float>{"OUT"}));
|
||||
};
|
||||
|
||||
class Foo5 : public FunctionNode<Foo5> {
|
||||
public:
|
||||
MEDIAPIPE_NODE_INTERFACE(
|
||||
Foo5, ProcessFn(
|
||||
[](int base, Packet<float> bias, Packet<float> scale) {
|
||||
return base * scale.GetOr(1.0) + bias.GetOr(0.0);
|
||||
},
|
||||
Input<int>{"BASE"}, SideInput<float>::Optional{"BIAS"},
|
||||
Input<float>::Optional{"SCALE"}, Output<float>{"OUT"}));
|
||||
};
|
||||
|
||||
class Foo2Impl : public NodeImpl<Foo2, Foo2Impl> {
|
||||
public:
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
float bias = SideIn(MPP_TAG("BIAS"), cc).GetOr(0.0);
|
||||
float scale = In(MPP_TAG("SCALE"), cc).GetOr(1.0);
|
||||
Out(MPP_TAG("OUT"), cc).Send(*In(MPP_TAG("BASE"), cc) * scale + bias);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class BarImpl : public NodeImpl<Bar, BarImpl> {
|
||||
public:
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
Packet p = kIn(cc);
|
||||
kOut(cc).Send(p);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class BazImpl : public NodeImpl<Baz> {
|
||||
public:
|
||||
static mediapipe::Status UpdateContract(CalculatorContract* cc) { return {}; }
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
for (int i = 0; i < kData(cc).Count(); ++i) {
|
||||
kDataOut(cc)[i].Send(kData(cc)[i]);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_NODE_IMPLEMENTATION(BazImpl);
|
||||
|
||||
class IntForwarderImpl : public NodeImpl<IntForwarder, IntForwarderImpl> {
|
||||
public:
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
kOut(cc).Send(*kIn(cc));
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
class ToFloatImpl : public NodeImpl<ToFloat, ToFloatImpl> {
|
||||
public:
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
kIn(cc).Visit([cc](auto x) { kOut(cc).Send(x); });
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
TEST(NodeTest, GetContract) {
|
||||
// In the old API, contracts are defined "backwards"; first you fill it in
|
||||
// with what you have in the graph, then you let the calculator fill it in
|
||||
// with what it expects, and then you see if they match.
|
||||
const CalculatorGraphConfig::Node node_config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:base"
|
||||
input_stream: "SCALE:scale"
|
||||
output_stream: "OUT:out"
|
||||
)");
|
||||
mediapipe::CalculatorContract contract;
|
||||
MP_EXPECT_OK(contract.Initialize(node_config));
|
||||
MP_EXPECT_OK(Foo::Contract::GetContract(&contract));
|
||||
MP_EXPECT_OK(ValidatePacketTypeSet(contract.Inputs()));
|
||||
MP_EXPECT_OK(ValidatePacketTypeSet(contract.Outputs()));
|
||||
}
|
||||
|
||||
TEST(NodeTest, GetContractMulti) {
|
||||
const CalculatorGraphConfig::Node node_config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "Baz"
|
||||
input_stream: "DATA:0:b"
|
||||
input_stream: "DATA:1:c"
|
||||
output_stream: "DATA:0:d"
|
||||
output_stream: "DATA:1:e"
|
||||
)");
|
||||
mediapipe::CalculatorContract contract;
|
||||
MP_EXPECT_OK(contract.Initialize(node_config));
|
||||
MP_EXPECT_OK(Baz::Contract::GetContract(&contract));
|
||||
MP_EXPECT_OK(ValidatePacketTypeSet(contract.Inputs()));
|
||||
MP_EXPECT_OK(ValidatePacketTypeSet(contract.Outputs()));
|
||||
}
|
||||
|
||||
TEST(NodeTest, CreateByName) {
|
||||
MP_EXPECT_OK(CalculatorBaseRegistry::CreateByName("Foo"));
|
||||
}
|
||||
|
||||
void RunFooCalculatorInGraph(const std::string& foo_name) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
absl::Substitute(R"(
|
||||
input_stream: "base"
|
||||
input_stream: "scale"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "$0"
|
||||
input_stream: "BASE:base"
|
||||
input_stream: "SCALE:scale"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)",
|
||||
foo_name));
|
||||
std::vector<mediapipe::Packet> out_packets;
|
||||
tool::AddVectorSink("out", &config, &out_packets);
|
||||
mediapipe::CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(config, {}));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"base", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"scale", mediapipe::MakePacket<float>(2.0).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(PacketValues<float>(out_packets), testing::ElementsAre(20.0));
|
||||
}
|
||||
|
||||
TEST(NodeTest, RunInGraph) { RunFooCalculatorInGraph("Foo"); }
|
||||
|
||||
TEST(NodeTest, RunInGraph3) { RunFooCalculatorInGraph("Foo3"); }
|
||||
|
||||
TEST(NodeTest, RunInGraph4) { RunFooCalculatorInGraph("Foo4"); }
|
||||
|
||||
TEST(NodeTest, RunInGraph5) { RunFooCalculatorInGraph("Foo5"); }
|
||||
|
||||
TEST(NodeTest, OptionalStream) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "base"
|
||||
input_side_packet: "bias"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:base"
|
||||
input_side_packet: "BIAS:bias"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
std::vector<mediapipe::Packet> out_packets;
|
||||
tool::AddVectorSink("out", &config, &out_packets);
|
||||
mediapipe::CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(config, {}));
|
||||
MP_EXPECT_OK(graph.StartRun({{"bias", mediapipe::MakePacket<float>(30.0)}}));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"base", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(PacketValues<float>(out_packets), testing::ElementsAre(40.0));
|
||||
}
|
||||
|
||||
TEST(NodeTest, DynamicTypes) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "in"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:in"
|
||||
output_stream: "OUT:bar"
|
||||
}
|
||||
node {
|
||||
calculator: "IntForwarder"
|
||||
input_stream: "IN:bar"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
std::vector<mediapipe::Packet> out_packets;
|
||||
tool::AddVectorSink("out", &config, &out_packets);
|
||||
mediapipe::CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(config, {}));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(PacketValues<int>(out_packets), testing::ElementsAre(10));
|
||||
}
|
||||
|
||||
TEST(NodeTest, MultiPort) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "in0"
|
||||
input_stream: "in1"
|
||||
output_stream: "out0"
|
||||
output_stream: "out1"
|
||||
node {
|
||||
calculator: "Baz"
|
||||
input_stream: "DATA:0:in0"
|
||||
input_stream: "DATA:1:in1"
|
||||
output_stream: "DATA:0:baz0"
|
||||
output_stream: "DATA:1:baz1"
|
||||
}
|
||||
node {
|
||||
calculator: "IntForwarder"
|
||||
input_stream: "IN:baz0"
|
||||
output_stream: "OUT:out0"
|
||||
}
|
||||
node {
|
||||
calculator: "IntForwarder"
|
||||
input_stream: "IN:baz1"
|
||||
output_stream: "OUT:out1"
|
||||
}
|
||||
)");
|
||||
std::vector<mediapipe::Packet> out0_packets;
|
||||
std::vector<mediapipe::Packet> out1_packets;
|
||||
tool::AddVectorSink("out0", &config, &out0_packets);
|
||||
tool::AddVectorSink("out1", &config, &out1_packets);
|
||||
mediapipe::CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(config, {}));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in0", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in1", mediapipe::MakePacket<int>(5).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in0", mediapipe::MakePacket<int>(15).At(Timestamp(2))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in1", mediapipe::MakePacket<int>(7).At(Timestamp(2))));
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
std::vector<int> out0_values;
|
||||
std::vector<int> out1_values;
|
||||
for (auto& packet : out0_packets) {
|
||||
out0_values.push_back(packet.Get<int>());
|
||||
}
|
||||
for (auto& packet : out1_packets) {
|
||||
out1_values.push_back(packet.Get<int>());
|
||||
}
|
||||
EXPECT_EQ(out0_values, (std::vector<int>{10, 15}));
|
||||
EXPECT_EQ(out1_values, (std::vector<int>{5, 7}));
|
||||
}
|
||||
|
||||
struct SideFallback : public Node {
|
||||
static constexpr Input<int> kIn{"IN"};
|
||||
static constexpr Input<int>::SideFallback kFactor{"FACTOR"};
|
||||
static constexpr Output<int> kOut{"OUT"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kFactor, kOut);
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
kOut(cc).Send(kIn(cc).Get() * kFactor(cc).Get());
|
||||
return {};
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(SideFallback);
|
||||
|
||||
TEST(NodeTest, SideFallbackWithStream) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "in"
|
||||
input_stream: "factor"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "SideFallback"
|
||||
input_stream: "IN:in"
|
||||
input_stream: "FACTOR:factor"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
std::vector<int> outputs;
|
||||
mediapipe::CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(config, {}));
|
||||
MP_EXPECT_OK(
|
||||
graph.ObserveOutputStream("out", [&outputs](const mediapipe::Packet& p) {
|
||||
outputs.push_back(p.Get<int>());
|
||||
return mediapipe::OkStatus();
|
||||
}));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in", mediapipe::MakePacket<int>(10).At(Timestamp(0))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"factor", mediapipe::MakePacket<int>(2).At(Timestamp(0))));
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
EXPECT_EQ(outputs, std::vector<int>{20});
|
||||
}
|
||||
|
||||
TEST(NodeTest, SideFallbackWithSide) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "in"
|
||||
input_side_packet: "factor"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "SideFallback"
|
||||
input_stream: "IN:in"
|
||||
input_side_packet: "FACTOR:factor"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
std::vector<int> outputs;
|
||||
mediapipe::CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(config, {}));
|
||||
MP_EXPECT_OK(
|
||||
graph.ObserveOutputStream("out", [&outputs](const mediapipe::Packet& p) {
|
||||
outputs.push_back(p.Get<int>());
|
||||
return mediapipe::OkStatus();
|
||||
}));
|
||||
MP_EXPECT_OK(graph.StartRun({{"factor", mediapipe::MakePacket<int>(2)}}));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in", mediapipe::MakePacket<int>(10).At(Timestamp(0))));
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
EXPECT_EQ(outputs, std::vector<int>{20});
|
||||
}
|
||||
|
||||
TEST(NodeTest, SideFallbackWithNone) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "in"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "SideFallback"
|
||||
input_stream: "IN:in"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
std::vector<int> outputs;
|
||||
mediapipe::CalculatorGraph graph;
|
||||
auto status = graph.Initialize(config, {});
|
||||
EXPECT_THAT(status.message(), testing::HasSubstr("must be connected"));
|
||||
}
|
||||
|
||||
TEST(NodeTest, SideFallbackWithBoth) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "in"
|
||||
input_stream: "factor"
|
||||
input_side_packet: "factor_side"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "SideFallback"
|
||||
input_stream: "IN:in"
|
||||
input_stream: "FACTOR:factor"
|
||||
input_side_packet: "FACTOR:factor_side"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
std::vector<int> outputs;
|
||||
mediapipe::CalculatorGraph graph;
|
||||
auto status = graph.Initialize(config, {});
|
||||
EXPECT_THAT(status.message(), testing::HasSubstr("not both"));
|
||||
}
|
||||
|
||||
TEST(NodeTest, OneOf) {
|
||||
CalculatorGraphConfig config =
|
||||
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "in"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "ToFloat"
|
||||
input_stream: "IN:in"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)");
|
||||
std::vector<mediapipe::Packet> out_packets;
|
||||
tool::AddVectorSink("out", &config, &out_packets);
|
||||
mediapipe::CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(config, {}));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in", mediapipe::MakePacket<float>(5.0).At(Timestamp(2))));
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(PacketValues<float>(out_packets), testing::ElementsAre(10, 5.0));
|
||||
}
|
||||
|
||||
struct DropEvenTimestamps : public Node {
|
||||
static constexpr Input<AnyType> kIn{"IN"};
|
||||
static constexpr Output<SameType<kIn>> kOut{"OUT"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
if (cc->InputTimestamp().Value() % 2) {
|
||||
kOut(cc).Send(kIn(cc));
|
||||
}
|
||||
return {};
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(DropEvenTimestamps);
|
||||
|
||||
struct ListIntPackets : public Node {
|
||||
static constexpr Input<int>::Multiple kIn{"INT"};
|
||||
static constexpr Output<std::string> kOut{"STR"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
std::string result = absl::StrCat(cc->InputTimestamp().DebugString(), ":");
|
||||
for (int i = 0; i < kIn(cc).Count(); ++i) {
|
||||
if (kIn(cc)[i].IsEmpty()) {
|
||||
absl::StrAppend(&result, " empty");
|
||||
} else {
|
||||
absl::StrAppend(&result, " ", *kIn(cc)[i]);
|
||||
}
|
||||
}
|
||||
kOut(cc).Send(std::move(result));
|
||||
return {};
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ListIntPackets);
|
||||
|
||||
TEST(NodeTest, DefaultTimestampChange0) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "a"
|
||||
input_stream: "b"
|
||||
output_stream: "out"
|
||||
node {
|
||||
calculator: "DropEvenTimestamps"
|
||||
input_stream: "IN:a"
|
||||
output_stream: "OUT:a2"
|
||||
}
|
||||
node {
|
||||
calculator: "IntForwarder"
|
||||
input_stream: "IN:a2"
|
||||
output_stream: "OUT:a3"
|
||||
}
|
||||
node {
|
||||
calculator: "ListIntPackets"
|
||||
input_stream: "INT:0:a3"
|
||||
input_stream: "INT:1:b"
|
||||
output_stream: "STR:out"
|
||||
}
|
||||
)");
|
||||
std::vector<mediapipe::Packet> out_packets;
|
||||
tool::AddVectorSink("out", &config, &out_packets);
|
||||
mediapipe::CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(config, {}));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"a", mediapipe::MakePacket<int>(10).At(Timestamp(2))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"b", mediapipe::MakePacket<int>(10).At(Timestamp(2))));
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
// The packet sent to a should have been dropped, but the timestamp bound
|
||||
// should be forwarded by IntForwarder, and ListIntPackets should have run.
|
||||
EXPECT_THAT(PacketValues<std::string>(out_packets),
|
||||
testing::ElementsAre("2: empty 10"));
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
15
mediapipe/framework/api2/packet.cc
Normal file
15
mediapipe/framework/api2/packet.cc
Normal file
|
@ -0,0 +1,15 @@
|
|||
#include "mediapipe/framework/api2/packet.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
PacketBase FromOldPacket(const mediapipe::Packet& op) {
|
||||
return PacketBase(packet_internal::GetHolderShared(op)).At(op.Timestamp());
|
||||
}
|
||||
|
||||
mediapipe::Packet ToOldPacket(const PacketBase& p) {
|
||||
return mediapipe::packet_internal::Create(p.payload_, p.timestamp_);
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
353
mediapipe/framework/api2/packet.h
Normal file
353
mediapipe/framework/api2/packet.h
Normal file
|
@ -0,0 +1,353 @@
|
|||
// This file defines a typed Packet type. It fully interoperates with the older
|
||||
// mediapipe::Packet; creating an api::Packet<T> that refers to an existing
|
||||
// Packet (or vice versa) is cheap, just like copying a Packet. Ownership of
|
||||
// the payload is shared. Consider this as a typed view into the same data.
|
||||
//
|
||||
// Conversion is currently done explicitly with the FromOldPacket and
|
||||
// ToOldPacket functions, but calculator code does not need to concern itself
|
||||
// with it.
|
||||
|
||||
#ifndef MEDIAPIPE_FRAMEWORK_API2_PACKET_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_PACKET_H_
|
||||
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "mediapipe/framework/api2/tuple.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using Timestamp = mediapipe::Timestamp;
|
||||
using HolderBase = mediapipe::packet_internal::HolderBase;
|
||||
|
||||
template <typename T>
|
||||
class Packet;
|
||||
|
||||
// Type-erased packet.
|
||||
class PacketBase {
|
||||
public:
|
||||
// Empty.
|
||||
PacketBase() = default;
|
||||
// Copy.
|
||||
PacketBase(const PacketBase&) = default;
|
||||
PacketBase& operator=(const PacketBase&) = default;
|
||||
// Move.
|
||||
PacketBase(PacketBase&&) = default;
|
||||
PacketBase& operator=(PacketBase&&) = default;
|
||||
|
||||
// Get timestamp.
|
||||
Timestamp timestamp() const { return timestamp_; }
|
||||
// The original API has a Timestamp method, but it shadows the Timestamp
|
||||
// type within this class, which is annoying.
|
||||
// Timestamp Timestamp() const { return timestamp_; }
|
||||
|
||||
PacketBase At(Timestamp timestamp) const&;
|
||||
PacketBase At(Timestamp timestamp) &&;
|
||||
|
||||
bool IsEmpty() const { return payload_ == nullptr; }
|
||||
|
||||
template <typename T>
|
||||
Packet<T> As() const;
|
||||
|
||||
// Returns the reference to the object of type T if it contains
|
||||
// one, crashes otherwise.
|
||||
template <typename T>
|
||||
const T& Get() const;
|
||||
|
||||
// Conversion to old Packet type.
|
||||
operator mediapipe::Packet() const { return ToOldPacket(*this); }
|
||||
|
||||
protected:
|
||||
explicit PacketBase(std::shared_ptr<HolderBase> payload)
|
||||
: payload_(std::move(payload)) {}
|
||||
|
||||
std::shared_ptr<HolderBase> payload_;
|
||||
Timestamp timestamp_;
|
||||
|
||||
template <typename T>
|
||||
friend PacketBase PacketBaseAdopting(const T* ptr);
|
||||
friend PacketBase FromOldPacket(const mediapipe::Packet& op);
|
||||
friend mediapipe::Packet ToOldPacket(const PacketBase& p);
|
||||
};
|
||||
|
||||
PacketBase FromOldPacket(const mediapipe::Packet& op);
|
||||
mediapipe::Packet ToOldPacket(const PacketBase& p);
|
||||
|
||||
template <typename T>
|
||||
inline const T& PacketBase::Get() const {
|
||||
CHECK(payload_);
|
||||
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
|
||||
CHECK(typed_payload) << absl::StrCat(
|
||||
"The Packet stores \"", payload_->DebugTypeName(), "\", but \"",
|
||||
MediaPipeTypeStringOrDemangled<T>(), "\" was requested.");
|
||||
return typed_payload->data();
|
||||
}
|
||||
|
||||
// This is used to indicate that the packet could be holding one of a set of
|
||||
// types, e.g. Packet<OneOf<A, B>>.
|
||||
//
|
||||
// A Packet<OneOf<T...>> has an interface similar to std::variant<T...>.
|
||||
// However, we cannot use std::variant directly, since it requires that the
|
||||
// contained object be stored in place within the variant.
|
||||
// Suppose we have a stream that accepts an Image or an ImageFrame, and it
|
||||
// receives a Packet<ImageFrame>. To present it as a
|
||||
// std::variant<Image, ImageFrame> we would have to move the ImageFrame into
|
||||
// the variant (or copy it), but that is not compatible with Packet's existing
|
||||
// ownership model.
|
||||
// We could have Get() return a std::variant<std::reference_wrapper<Image>,
|
||||
// std::reference_wrapper<ImageFrame>>, but that would just make user code more
|
||||
// convoluted.
|
||||
//
|
||||
// TODO: should we just use Packet<T...>?
|
||||
template <class... T>
|
||||
struct OneOf {};
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <class T>
|
||||
inline void CheckCompatibleType(const HolderBase& holder, internal::Wrap<T>) {
|
||||
const packet_internal::Holder<T>* typed_payload = holder.As<T>();
|
||||
CHECK(typed_payload) << absl::StrCat(
|
||||
"The Packet stores \"", holder.DebugTypeName(), "\", but \"",
|
||||
MediaPipeTypeStringOrDemangled<T>(), "\" was requested.");
|
||||
// CHECK(payload_->has_type<T>());
|
||||
}
|
||||
|
||||
template <class... T>
|
||||
inline void CheckCompatibleType(const HolderBase& holder,
|
||||
internal::Wrap<OneOf<T...>>) {
|
||||
bool compatible = (holder.As<T>() || ...);
|
||||
CHECK(compatible)
|
||||
<< "The Packet stores \"" << holder.DebugTypeName() << "\", but one of "
|
||||
<< absl::StrJoin(
|
||||
{absl::StrCat("\"", MediaPipeTypeStringOrDemangled<T>(), "\"")...},
|
||||
", ")
|
||||
<< " was requested.";
|
||||
}
|
||||
|
||||
struct Generic {
|
||||
Generic() = delete;
|
||||
};
|
||||
|
||||
}; // namespace internal
|
||||
|
||||
template <typename T>
|
||||
inline Packet<T> PacketBase::As() const {
|
||||
if (!payload_) return Packet<T>().At(timestamp_);
|
||||
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
|
||||
internal::CheckCompatibleType(*payload_, internal::Wrap<T>{});
|
||||
return Packet<T>(payload_).At(timestamp_);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Packet<internal::Generic> PacketBase::As<internal::Generic>() const;
|
||||
|
||||
template <typename T = internal::Generic>
|
||||
class Packet;
|
||||
#if __cplusplus >= 201703L
|
||||
// Deduction guide to silence -Wctad-maybe-unsupported.
|
||||
explicit Packet()->Packet<internal::Generic>;
|
||||
#endif // C++17
|
||||
|
||||
template <>
|
||||
class Packet<internal::Generic> : public PacketBase {
|
||||
public:
|
||||
Packet() = default;
|
||||
|
||||
Packet<internal::Generic> At(Timestamp timestamp) const&;
|
||||
Packet<internal::Generic> At(Timestamp timestamp) &&;
|
||||
|
||||
protected:
|
||||
explicit Packet(std::shared_ptr<HolderBase> payload)
|
||||
: PacketBase(std::move(payload)) {}
|
||||
|
||||
friend PacketBase;
|
||||
};
|
||||
|
||||
// Having Packet<T> subclass Packet<Generic> will require hiding some methods
|
||||
// like As. May be better not to subclass, and allow implicit conversion
|
||||
// instead.
|
||||
template <typename T>
|
||||
class Packet : public Packet<internal::Generic> {
|
||||
public:
|
||||
Packet() = default;
|
||||
|
||||
Packet<T> At(Timestamp timestamp) const&;
|
||||
Packet<T> At(Timestamp timestamp) &&;
|
||||
|
||||
const T& Get() const {
|
||||
CHECK(payload_);
|
||||
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
|
||||
CHECK(typed_payload);
|
||||
return typed_payload->data();
|
||||
}
|
||||
const T& operator*() const { return Get(); }
|
||||
|
||||
template <typename U>
|
||||
T GetOr(U&& v) const {
|
||||
return IsEmpty() ? static_cast<T>(absl::forward<U>(v)) : **this;
|
||||
}
|
||||
|
||||
private:
|
||||
explicit Packet(std::shared_ptr<HolderBase> payload)
|
||||
: Packet<internal::Generic>(std::move(payload)) {}
|
||||
|
||||
friend PacketBase;
|
||||
template <typename U, typename... Args>
|
||||
friend Packet<U> MakePacket(Args&&... args);
|
||||
template <typename U>
|
||||
friend Packet<U> PacketAdopting(const U* ptr);
|
||||
template <typename U>
|
||||
friend Packet<U> PacketAdopting(std::unique_ptr<U> ptr);
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
template <class... F>
|
||||
struct Overload : F... {
|
||||
using F::operator()...;
|
||||
};
|
||||
template <class... F>
|
||||
explicit Overload(F...) -> Overload<F...>;
|
||||
|
||||
template <class T, class... U>
|
||||
struct First {
|
||||
using type = T;
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
template <class... T>
|
||||
class Packet<OneOf<T...>> : public PacketBase {
|
||||
public:
|
||||
Packet() = default;
|
||||
|
||||
template <class U>
|
||||
using AllowedType = std::enable_if_t<(std::is_same_v<U, T> || ...)>;
|
||||
|
||||
template <class U, class = AllowedType<U>>
|
||||
Packet(const Packet<U>& p) : PacketBase(p) {}
|
||||
template <class U, class = AllowedType<U>>
|
||||
Packet<OneOf<T...>>& operator=(const Packet<U>& p) {
|
||||
PacketBase::operator=(p);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class U, class = AllowedType<U>>
|
||||
Packet(Packet<U>&& p) : PacketBase(std::move(p)) {}
|
||||
template <class U, class = AllowedType<U>>
|
||||
Packet<OneOf<T...>>& operator=(Packet<U>&& p) {
|
||||
PacketBase::operator=(std::move(p));
|
||||
return *this;
|
||||
}
|
||||
|
||||
Packet<OneOf<T...>> At(Timestamp timestamp) const& {
|
||||
return Packet<OneOf<T...>>(*this).At(timestamp);
|
||||
}
|
||||
Packet<OneOf<T...>> At(Timestamp timestamp) && {
|
||||
timestamp_ = timestamp;
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
template <class U, class = AllowedType<U>>
|
||||
const U& Get() const {
|
||||
CHECK(payload_);
|
||||
packet_internal::Holder<U>* typed_payload = payload_->As<U>();
|
||||
CHECK(typed_payload);
|
||||
return typed_payload->data();
|
||||
}
|
||||
|
||||
template <class U, class = AllowedType<U>>
|
||||
bool Has() const {
|
||||
return payload_ && payload_->As<U>();
|
||||
}
|
||||
|
||||
template <class... F>
|
||||
auto Visit(const F&... args) const {
|
||||
CHECK(payload_);
|
||||
auto f = internal::Overload{args...};
|
||||
using FirstT = typename internal::First<T...>::type;
|
||||
using ResultType = absl::result_of_t<decltype(f)(const FirstT&)>;
|
||||
static_assert(
|
||||
(std::is_same_v<ResultType, absl::result_of_t<decltype(f)(const T&)>> &&
|
||||
...),
|
||||
"All visitor overloads must have the same return type");
|
||||
return Invoke<decltype(f), T...>(f);
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit Packet(std::shared_ptr<HolderBase> payload)
|
||||
: PacketBase(std::move(payload)) {}
|
||||
|
||||
friend PacketBase;
|
||||
|
||||
private:
|
||||
template <class F, class U>
|
||||
auto Invoke(const F& f) const {
|
||||
return f(Get<U>());
|
||||
}
|
||||
|
||||
template <class F, class U, class V, class... W>
|
||||
auto Invoke(const F& f) const {
|
||||
return Has<U>() ? f(Get<U>()) : Invoke<F, V, W...>(f);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
inline Packet<internal::Generic> PacketBase::As<internal::Generic>() const {
|
||||
if (!payload_) return Packet<internal::Generic>().At(timestamp_);
|
||||
return Packet<internal::Generic>(payload_).At(timestamp_);
|
||||
}
|
||||
|
||||
inline PacketBase PacketBase::At(Timestamp timestamp) const& {
|
||||
return PacketBase(*this).At(timestamp);
|
||||
}
|
||||
|
||||
inline PacketBase PacketBase::At(Timestamp timestamp) && {
|
||||
timestamp_ = timestamp;
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline Packet<T> Packet<T>::At(Timestamp timestamp) const& {
|
||||
return Packet<T>(*this).At(timestamp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline Packet<T> Packet<T>::At(Timestamp timestamp) && {
|
||||
timestamp_ = timestamp;
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
inline Packet<internal::Generic> Packet<internal::Generic>::At(
|
||||
Timestamp timestamp) const& {
|
||||
return Packet<internal::Generic>(*this).At(timestamp);
|
||||
}
|
||||
|
||||
inline Packet<internal::Generic> Packet<internal::Generic>::At(
|
||||
Timestamp timestamp) && {
|
||||
timestamp_ = timestamp;
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
Packet<T> MakePacket(Args&&... args) {
|
||||
return Packet<T>(std::make_shared<packet_internal::Holder<T>>(
|
||||
new T(std::forward<Args>(args)...)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Packet<T> PacketAdopting(const T* ptr) {
|
||||
return Packet<T>(std::make_shared<packet_internal::Holder<T>>(ptr));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Packet<T> PacketAdopting(std::unique_ptr<T> ptr) {
|
||||
return Packet<T>(std::make_shared<packet_internal::Holder<T>>(ptr.release()));
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_PACKET_H_
|
16
mediapipe/framework/api2/packet_nc.cc
Normal file
16
mediapipe/framework/api2/packet_nc.cc
Normal file
|
@ -0,0 +1,16 @@
|
|||
#include "mediapipe/framework/api2/packet.h"
|
||||
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
#if defined(TEST_NO_ASSIGN_WRONG_PACKET_TYPE)
|
||||
void AssignWrongPacketType() { Packet<int> p = MakePacket<float>(1.0); }
|
||||
#elif defined(TEST_NO_ASSIGN_GENERIC_TO_SPECIFIC)
|
||||
void AssignWrongPacketType() {
|
||||
Packet<> p = MakePacket<float>(1.0);
|
||||
Packet<int> p2 = p;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
}; // namespace api2
|
195
mediapipe/framework/api2/packet_test.cc
Normal file
195
mediapipe/framework/api2/packet_test.cc
Normal file
|
@ -0,0 +1,195 @@
|
|||
#include "mediapipe/framework/api2/packet.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
class LiveCheck {
|
||||
public:
|
||||
explicit LiveCheck(bool* alive) : alive_(*alive) { alive_ = true; }
|
||||
~LiveCheck() { alive_ = false; }
|
||||
|
||||
private:
|
||||
bool& alive_;
|
||||
};
|
||||
|
||||
TEST(PacketTest, PacketBaseDefault) {
|
||||
PacketBase p;
|
||||
EXPECT_TRUE(p.IsEmpty());
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketBaseNonEmpty) {
|
||||
PacketBase p = PacketAdopting(new int(5));
|
||||
EXPECT_FALSE(p.IsEmpty());
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketBaseRefCount) {
|
||||
bool alive = false;
|
||||
PacketBase p = PacketAdopting(new LiveCheck(&alive));
|
||||
EXPECT_TRUE(alive);
|
||||
PacketBase p2 = p;
|
||||
p = {};
|
||||
EXPECT_TRUE(alive);
|
||||
p2 = {};
|
||||
EXPECT_FALSE(alive);
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketBaseSame) {
|
||||
int* ip = new int(5);
|
||||
PacketBase p = PacketAdopting(ip);
|
||||
PacketBase p2 = p;
|
||||
EXPECT_EQ(&p2.Get<int>(), ip);
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketNonEmpty) {
|
||||
Packet<int> p = MakePacket<int>(5);
|
||||
EXPECT_FALSE(p.IsEmpty());
|
||||
}
|
||||
|
||||
TEST(PacketTest, Get) {
|
||||
Packet<int> p = MakePacket<int>(5);
|
||||
EXPECT_EQ(*p, 5);
|
||||
EXPECT_EQ(p.Get(), 5);
|
||||
}
|
||||
|
||||
TEST(PacketTest, GetOr) {
|
||||
Packet<int> p_0 = MakePacket<int>(0);
|
||||
Packet<int> p_5 = MakePacket<int>(5);
|
||||
Packet<int> p_empty;
|
||||
EXPECT_EQ(p_0.GetOr(1), 0);
|
||||
EXPECT_EQ(p_5.GetOr(1), 5);
|
||||
EXPECT_EQ(p_empty.GetOr(1), 1);
|
||||
}
|
||||
|
||||
// This show how GetOr can be used with a lambda that is only called if the "or"
|
||||
// case is needed. Can be useful when generating the fallback value is
|
||||
// expensive.
|
||||
// We could also add an overload to GetOr for types which are not convertible to
|
||||
// T, but are callable and return T.
|
||||
// TODO: consider adding it to make things easier.
|
||||
template <typename F>
|
||||
struct Lazy {
|
||||
F f;
|
||||
using ValueT = decltype(f());
|
||||
Lazy(F fun) : f(fun) {}
|
||||
operator ValueT() const { return f(); }
|
||||
};
|
||||
template <typename F>
|
||||
Lazy(F f) -> Lazy<F>;
|
||||
|
||||
TEST(PacketTest, GetOrLazy) {
|
||||
int expensive_call_count = 0;
|
||||
auto expensive_string_generation = [&expensive_call_count] {
|
||||
++expensive_call_count;
|
||||
return "an expensive fallback";
|
||||
};
|
||||
|
||||
auto p_hello = MakePacket<std::string>("hello");
|
||||
Packet<std::string> p_empty;
|
||||
|
||||
EXPECT_EQ(p_hello.GetOr(Lazy(expensive_string_generation)), "hello");
|
||||
EXPECT_EQ(expensive_call_count, 0);
|
||||
EXPECT_EQ(p_empty.GetOr(Lazy(expensive_string_generation)),
|
||||
"an expensive fallback");
|
||||
EXPECT_EQ(expensive_call_count, 1);
|
||||
}
|
||||
|
||||
TEST(PacketTest, OneOf) {
|
||||
Packet<OneOf<std::string, int>> p = MakePacket<std::string>("hi");
|
||||
EXPECT_TRUE(p.Has<std::string>());
|
||||
EXPECT_FALSE(p.Has<int>());
|
||||
EXPECT_EQ(p.Get<std::string>(), "hi");
|
||||
std::string out =
|
||||
p.Visit([](std::string s) { return absl::StrCat("string: ", s); },
|
||||
[](int i) { return absl::StrCat("int: ", i); });
|
||||
EXPECT_EQ(out, "string: hi");
|
||||
|
||||
p = MakePacket<int>(2);
|
||||
EXPECT_FALSE(p.Has<std::string>());
|
||||
EXPECT_TRUE(p.Has<int>());
|
||||
EXPECT_EQ(p.Get<int>(), 2);
|
||||
out = p.Visit([](std::string s) { return absl::StrCat("string: ", s); },
|
||||
[](int i) { return absl::StrCat("int: ", i); });
|
||||
EXPECT_EQ(out, "int: 2");
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketRefCount) {
|
||||
bool alive = false;
|
||||
auto p = MakePacket<LiveCheck>(&alive);
|
||||
EXPECT_TRUE(alive);
|
||||
auto p2 = p;
|
||||
p = {};
|
||||
EXPECT_TRUE(alive);
|
||||
p2 = {};
|
||||
EXPECT_FALSE(alive);
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketTimestamp) {
|
||||
auto p = MakePacket<int>(5);
|
||||
EXPECT_EQ(p.timestamp(), Timestamp::Unset());
|
||||
auto p2 = p.At(Timestamp(1));
|
||||
EXPECT_EQ(p.timestamp(), Timestamp::Unset());
|
||||
EXPECT_EQ(p2.timestamp(), Timestamp(1));
|
||||
auto p3 = std::move(p2).At(Timestamp(3));
|
||||
EXPECT_EQ(p3.timestamp(), Timestamp(3));
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketFromGeneric) {
|
||||
Packet<> pb = PacketAdopting(new int(5));
|
||||
Packet<int> p = pb.As<int>();
|
||||
EXPECT_EQ(p.Get(), 5);
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketAdopting) {
|
||||
Packet<float> p = PacketAdopting(new float(1.0));
|
||||
EXPECT_FALSE(p.IsEmpty());
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketGeneric) {
|
||||
// With C++17, Packet<> could be written simply as Packet.
|
||||
Packet<> p = PacketAdopting(new float(1.0));
|
||||
EXPECT_FALSE(p.IsEmpty());
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketGenericTimestamp) {
|
||||
Packet<> p = MakePacket<int>(5);
|
||||
EXPECT_EQ(p.timestamp(), mediapipe::Timestamp::Unset());
|
||||
auto p2 = p.At(Timestamp(1));
|
||||
EXPECT_EQ(p.timestamp(), mediapipe::Timestamp::Unset());
|
||||
EXPECT_EQ(p2.timestamp(), Timestamp(1));
|
||||
auto p3 = std::move(p2).At(Timestamp(3));
|
||||
EXPECT_EQ(p3.timestamp(), Timestamp(3));
|
||||
}
|
||||
|
||||
TEST(PacketTest, FromOldPacket) {
|
||||
mediapipe::Packet op = mediapipe::MakePacket<int>(7);
|
||||
Packet<int> p = FromOldPacket(op).As<int>();
|
||||
EXPECT_EQ(p.Get(), 7);
|
||||
}
|
||||
|
||||
TEST(PacketTest, ToOldPacket) {
|
||||
auto p = MakePacket<int>(7);
|
||||
mediapipe::Packet op = ToOldPacket(p);
|
||||
EXPECT_EQ(op.Get<int>(), 7);
|
||||
}
|
||||
|
||||
TEST(PacketTest, OldRefCounting) {
|
||||
bool alive = false;
|
||||
PacketBase p = PacketAdopting(new LiveCheck(&alive));
|
||||
EXPECT_TRUE(alive);
|
||||
mediapipe::Packet op = ToOldPacket(p);
|
||||
p = {};
|
||||
EXPECT_TRUE(alive);
|
||||
PacketBase p2 = FromOldPacket(op);
|
||||
op = {};
|
||||
EXPECT_TRUE(alive);
|
||||
p2 = {};
|
||||
EXPECT_FALSE(alive);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
628
mediapipe/framework/api2/port.h
Normal file
628
mediapipe/framework/api2/port.h
Normal file
|
@ -0,0 +1,628 @@
|
|||
// This file defines an API to define a node's ports in a concise, type-safe
|
||||
// way. Example usage in a node:
|
||||
//
|
||||
// static constexpr Input<int> kBase("IN");
|
||||
// static constexpr Output<float> kOut("OUT");
|
||||
// static constexpr SideInput<float>::Optional kDelta("DELTA");
|
||||
// static constexpr SideOutput<float> kForward("FORWARD");
|
||||
//
|
||||
// Pass a CalculatorContext to a port to access the inputs or outputs in the
|
||||
// context. For example:
|
||||
//
|
||||
// kBase(cc) yields an InputShardAccess<int>
|
||||
// kOut(cc) yields an OutputShardAccess<float>
|
||||
// kDelta(cc) yields an InputSidePacketAccess<float>
|
||||
// kForward(cc) yields an OutputSidePacketAccess<float>
|
||||
|
||||
#ifndef MEDIAPIPE_FRAMEWORK_API2_PORT_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_PORT_H_
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/api2/const_str.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/output_side_packet.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// typeid is not constexpr, but a pointer to this is.
|
||||
template <typename T>
|
||||
size_t get_type_hash() {
|
||||
return typeid(T).hash_code();
|
||||
}
|
||||
|
||||
using type_id_fptr = size_t (*)();
|
||||
|
||||
// This is a base class for various types of port. It is not meant to be used
|
||||
// directly by node code.
|
||||
class PortBase {
|
||||
public:
|
||||
constexpr PortBase(std::size_t tag_size, const char* tag,
|
||||
type_id_fptr get_type_id, bool optional, bool multiple)
|
||||
: tag_(tag_size, tag),
|
||||
optional_(optional),
|
||||
multiple_(multiple),
|
||||
type_id_getter_(get_type_id) {}
|
||||
|
||||
bool IsOptional() const { return optional_; }
|
||||
bool IsMultiple() const { return multiple_; }
|
||||
const char* Tag() const { return tag_.data(); }
|
||||
|
||||
size_t type_id() const { return type_id_getter_(); }
|
||||
|
||||
const const_str tag_;
|
||||
const bool optional_;
|
||||
const bool multiple_;
|
||||
|
||||
protected:
|
||||
type_id_fptr type_id_getter_;
|
||||
};
|
||||
|
||||
// These four base classes are used to distinguish between ports of different
|
||||
// kinds. They are not meant to be used directly by node code.
|
||||
class InputBase : public PortBase {
|
||||
using PortBase::PortBase;
|
||||
};
|
||||
class OutputBase : public PortBase {
|
||||
using PortBase::PortBase;
|
||||
};
|
||||
class SideInputBase : public PortBase {
|
||||
using PortBase::PortBase;
|
||||
};
|
||||
class SideOutputBase : public PortBase {
|
||||
using PortBase::PortBase;
|
||||
};
|
||||
|
||||
struct NoneType {
|
||||
private:
|
||||
NoneType() = delete;
|
||||
};
|
||||
|
||||
struct DynamicType {};
|
||||
|
||||
struct AnyType : public DynamicType {};
|
||||
|
||||
template <auto& P>
|
||||
class SameType : public DynamicType {
|
||||
public:
|
||||
static constexpr const decltype(P)& kPort = P;
|
||||
};
|
||||
|
||||
class PacketTypeAccess;
|
||||
class PacketTypeAccessFallback;
|
||||
template <typename T>
|
||||
class InputShardAccess;
|
||||
template <typename T>
|
||||
class OutputShardAccess;
|
||||
template <typename T>
|
||||
class InputSidePacketAccess;
|
||||
template <typename T>
|
||||
class OutputSidePacketAccess;
|
||||
template <typename T>
|
||||
class InputShardOrSideAccess;
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Forward declaration for AddToContract friend.
|
||||
template <typename...>
|
||||
class Contract;
|
||||
|
||||
template <class CC>
|
||||
auto GetCollection(CC* cc, const InputBase& port) -> decltype(cc->Inputs()) {
|
||||
return cc->Inputs();
|
||||
}
|
||||
|
||||
template <class CC>
|
||||
auto GetCollection(CC* cc, const SideInputBase& port)
|
||||
-> decltype(cc->InputSidePackets()) {
|
||||
return cc->InputSidePackets();
|
||||
}
|
||||
|
||||
template <class CC>
|
||||
auto GetCollection(CC* cc, const OutputBase& port) -> decltype(cc->Outputs()) {
|
||||
return cc->Outputs();
|
||||
}
|
||||
|
||||
template <class CC>
|
||||
auto GetCollection(CC* cc, const SideOutputBase& port)
|
||||
-> decltype(cc->OutputSidePackets()) {
|
||||
return cc->OutputSidePackets();
|
||||
}
|
||||
|
||||
template <class Collection>
|
||||
auto GetOrNull(Collection& collection, const std::string& tag, int index)
|
||||
-> decltype(&collection.Get(std::declval<CollectionItemId>())) {
|
||||
CollectionItemId id = collection.GetId(tag, index);
|
||||
return id.IsValid() ? &collection.Get(id) : nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct IsOneOf : std::false_type {};
|
||||
|
||||
template <class... T>
|
||||
struct IsOneOf<OneOf<T...>> : std::true_type {};
|
||||
|
||||
template <typename T, typename std::enable_if<
|
||||
!std::is_base_of<DynamicType, T>{} && !IsOneOf<T>{},
|
||||
int>::type = 0>
|
||||
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
||||
pt.Set<T>();
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_base_of<DynamicType, T>{},
|
||||
int>::type = 0>
|
||||
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
||||
pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag()));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void SetType<AnyType>(CalculatorContract* cc, PacketType& pt) {
|
||||
pt.SetAny();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void SetType<NoneType>(CalculatorContract* cc, PacketType& pt) {
|
||||
// This is used for header-only streams. Should it be removed?
|
||||
pt.SetNone();
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<IsOneOf<T>{}, int>::type = 0>
|
||||
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
||||
pt.SetAny();
|
||||
}
|
||||
|
||||
template <typename ValueT>
|
||||
InputShardAccess<ValueT> SinglePortAccess(mediapipe::CalculatorContext* cc,
|
||||
const InputStreamShard* stream) {
|
||||
return InputShardAccess<ValueT>(*cc, stream);
|
||||
}
|
||||
|
||||
template <typename ValueT>
|
||||
OutputShardAccess<ValueT> SinglePortAccess(mediapipe::CalculatorContext* cc,
|
||||
OutputStreamShard* stream) {
|
||||
return OutputShardAccess<ValueT>(*cc, stream);
|
||||
}
|
||||
|
||||
template <typename ValueT>
|
||||
InputSidePacketAccess<ValueT> SinglePortAccess(
|
||||
mediapipe::CalculatorContext* cc, const mediapipe::Packet* packet) {
|
||||
return InputSidePacketAccess<ValueT>(packet);
|
||||
}
|
||||
|
||||
template <typename ValueT>
|
||||
OutputSidePacketAccess<ValueT> SinglePortAccess(
|
||||
mediapipe::CalculatorContext* cc, OutputSidePacket* osp) {
|
||||
return OutputSidePacketAccess<ValueT>(osp);
|
||||
}
|
||||
|
||||
template <typename ValueT>
|
||||
InputShardOrSideAccess<ValueT> SinglePortAccess(
|
||||
mediapipe::CalculatorContext* cc, const InputStreamShard* stream,
|
||||
const mediapipe::Packet* packet) {
|
||||
return InputShardOrSideAccess<ValueT>(*cc, stream, packet);
|
||||
}
|
||||
|
||||
template <typename ValueT>
|
||||
PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc,
|
||||
PacketType* pt);
|
||||
|
||||
template <typename ValueT>
|
||||
PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc,
|
||||
PacketType* pt, bool is_stream);
|
||||
|
||||
template <typename ValueT, typename PortT, class CC>
|
||||
auto AccessPort(std::false_type, const PortT& port, CC* cc) {
|
||||
auto& collection = GetCollection(cc, port);
|
||||
return SinglePortAccess<ValueT>(
|
||||
cc, internal::GetOrNull(collection, port.Tag(), 0));
|
||||
}
|
||||
|
||||
template <typename ValueT, typename X, class CC>
|
||||
class MultiplePortAccess {
|
||||
public:
|
||||
MultiplePortAccess(CC* cc, X* first, int count)
|
||||
: cc_(cc), first_(first), count_(count) {}
|
||||
|
||||
// TODO: maybe this should be size(), like in a standard C++
|
||||
// container?
|
||||
int Count() { return count_; }
|
||||
auto operator[](int pos) {
|
||||
CHECK_GE(pos, 0);
|
||||
CHECK_LT(pos, count_);
|
||||
return SinglePortAccess<ValueT>(cc_, &first_[pos]);
|
||||
}
|
||||
|
||||
// TODO: add begin/end.
|
||||
|
||||
private:
|
||||
CC* cc_;
|
||||
X* first_;
|
||||
int count_;
|
||||
};
|
||||
|
||||
template <typename ValueT, typename PortT, class CC>
|
||||
auto AccessPort(std::true_type, const PortT& port, CC* cc) {
|
||||
auto& collection = GetCollection(cc, port);
|
||||
auto* first = internal::GetOrNull(collection, port.Tag(), 0);
|
||||
using EntryT = typename std::remove_pointer<decltype(first)>::type;
|
||||
return MultiplePortAccess<ValueT, EntryT, CC>(
|
||||
cc, first, collection.NumEntries(port.Tag()));
|
||||
}
|
||||
|
||||
template <class Base>
|
||||
struct SideBase;
|
||||
|
||||
template <>
|
||||
struct SideBase<InputBase> {
|
||||
using type = SideInputBase;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
|
||||
template <typename T, typename std::enable_if<
|
||||
!std::is_base_of<DynamicType, T>{}, int>::type = 0>
|
||||
auto ActualValueT(T) -> T;
|
||||
|
||||
auto ActualValueT(DynamicType) -> internal::Generic;
|
||||
|
||||
template <typename Base, typename ValueT, bool IsOptional = false,
|
||||
bool IsMultiple = false>
|
||||
class SideFallbackT;
|
||||
|
||||
// This template is used to define a port. Nodes should use it through one
|
||||
// of the aliases below (Input, Output, SideInput, SideOutput).
|
||||
template <typename Base, typename ValueT, bool IsOptionalV = false,
|
||||
bool IsMultipleV = false>
|
||||
class PortCommon : public Base {
|
||||
public:
|
||||
using value_t = ValueT;
|
||||
static constexpr bool kOptional = IsOptionalV;
|
||||
static constexpr bool kMultiple = IsMultipleV;
|
||||
|
||||
using Optional = PortCommon<Base, ValueT, true, IsMultipleV>;
|
||||
using Multiple = PortCommon<Base, ValueT, IsOptionalV, true>;
|
||||
using SideFallback = SideFallbackT<Base, ValueT, IsOptionalV, IsMultipleV>;
|
||||
|
||||
template <std::size_t N>
|
||||
explicit constexpr PortCommon(const char (&tag)[N])
|
||||
: Base(N, tag, &get_type_hash<ValueT>, IsOptionalV, IsMultipleV) {}
|
||||
|
||||
using PayloadT = decltype(ActualValueT(std::declval<ValueT>()));
|
||||
|
||||
auto operator()(CalculatorContext* cc) const {
|
||||
return internal::AccessPort<PayloadT>(
|
||||
std::integral_constant<bool, IsMultipleV>{}, *this, cc);
|
||||
}
|
||||
|
||||
auto operator()(CalculatorContract* cc) const {
|
||||
return internal::AccessPort<PayloadT>(
|
||||
std::integral_constant<bool, IsMultipleV>{}, *this, cc);
|
||||
}
|
||||
|
||||
private:
|
||||
mediapipe::Status AddToContract(CalculatorContract* cc) const {
|
||||
if (kMultiple) {
|
||||
AddMultiple(cc);
|
||||
} else {
|
||||
auto& pt = internal::GetCollection(cc, *this).Tag(this->Tag());
|
||||
internal::SetType<value_t>(cc, pt);
|
||||
if (kOptional) {
|
||||
pt.Optional();
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void AddMultiple(CalculatorContract* cc) const {
|
||||
auto& collection = internal::GetCollection(cc, *this);
|
||||
int count = collection.NumEntries(this->Tag());
|
||||
for (int i = 0; i < count; ++i) {
|
||||
internal::SetType<value_t>(cc, collection.Get(this->Tag(), i));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename...>
|
||||
friend class internal::Contract;
|
||||
template <typename B, typename VT, bool, bool>
|
||||
friend class mediapipe::api2::SideFallbackT;
|
||||
};
|
||||
|
||||
// Use one of these templates to define a port in node code.
|
||||
template <typename T = internal::Generic>
|
||||
using Input = PortCommon<InputBase, T>;
|
||||
|
||||
template <typename T = internal::Generic>
|
||||
using Output = PortCommon<OutputBase, T>;
|
||||
|
||||
template <typename T = internal::Generic>
|
||||
using SideInput = PortCommon<SideInputBase, T>;
|
||||
|
||||
template <typename T = internal::Generic>
|
||||
using SideOutput = PortCommon<SideOutputBase, T>;
|
||||
|
||||
template <typename Base, typename ValueT, bool IsOptionalV, bool IsMultipleV>
|
||||
class SideFallbackT : public Base {
|
||||
public:
|
||||
using value_t = ValueT;
|
||||
static constexpr bool kOptional = IsOptionalV;
|
||||
static constexpr bool kMultiple = IsMultipleV;
|
||||
using Optional = SideFallbackT<Base, ValueT, true, IsMultipleV>;
|
||||
using PayloadT = decltype(ActualValueT(std::declval<ValueT>()));
|
||||
|
||||
const char* Tag() const { return stream_port.Tag(); }
|
||||
|
||||
auto operator()(CalculatorContract* cc) const {
|
||||
bool is_stream = true;
|
||||
auto& stream_collection = internal::GetCollection(cc, stream_port);
|
||||
auto* packet_type = internal::GetOrNull(stream_collection, Tag(), 0);
|
||||
if (packet_type == nullptr) {
|
||||
auto& side_collection = internal::GetCollection(cc, side_port);
|
||||
packet_type = internal::GetOrNull(side_collection, Tag(), 0);
|
||||
is_stream = false;
|
||||
}
|
||||
return internal::SinglePortAccess<PayloadT>(cc, packet_type, is_stream);
|
||||
}
|
||||
|
||||
auto operator()(CalculatorContext* cc) const {
|
||||
auto& stream_collection = internal::GetCollection(cc, stream_port);
|
||||
auto& side_collection = internal::GetCollection(cc, side_port);
|
||||
return internal::SinglePortAccess<PayloadT>(
|
||||
cc, internal::GetOrNull(stream_collection, Tag(), 0),
|
||||
internal::GetOrNull(side_collection, Tag(), 0));
|
||||
}
|
||||
|
||||
template <std::size_t N>
|
||||
explicit constexpr SideFallbackT(const char (&tag)[N])
|
||||
: Base(N, tag, &get_type_hash<ValueT>, IsOptionalV, IsMultipleV),
|
||||
stream_port(tag),
|
||||
side_port(tag) {}
|
||||
|
||||
protected:
|
||||
mediapipe::Status AddToContract(CalculatorContract* cc) const {
|
||||
stream_port.AddToContract(cc);
|
||||
side_port.AddToContract(cc);
|
||||
int connected_count =
|
||||
stream_port(cc).IsConnected() + side_port(cc).IsConnected();
|
||||
if (connected_count > 1)
|
||||
return mediapipe::InvalidArgumentError(absl::StrCat(
|
||||
Tag(),
|
||||
" can be connected as a stream or as a side packet, but not both"));
|
||||
if (!IsOptionalV && connected_count == 0)
|
||||
return mediapipe::InvalidArgumentError(
|
||||
absl::StrCat(Tag(), " must be connected"));
|
||||
return {};
|
||||
}
|
||||
|
||||
using StreamPort = PortCommon<Base, ValueT, true, IsMultipleV>;
|
||||
using SidePort = PortCommon<typename internal::SideBase<Base>::type, ValueT,
|
||||
true, IsMultipleV>;
|
||||
StreamPort stream_port;
|
||||
SidePort side_port;
|
||||
|
||||
template <typename...>
|
||||
friend class internal::Contract;
|
||||
};
|
||||
|
||||
// An OutputShardAccess is returned when accessing an output stream within a
|
||||
// CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to
|
||||
// OutputStreamShard. Like that class, this class will not be usually named in
|
||||
// calculator code, but used as a temporary object (e.g. kOut(cc).Send(...)).
|
||||
class OutputShardAccessBase {
|
||||
public:
|
||||
OutputShardAccessBase(const CalculatorContext& cc, OutputStreamShard* output)
|
||||
: context_(cc), output_(output) {}
|
||||
|
||||
void SetNextTimestampBound(Timestamp timestamp) {
|
||||
if (output_) output_->SetNextTimestampBound(timestamp);
|
||||
}
|
||||
|
||||
bool IsClosed() { return output_ ? output_->IsClosed() : true; }
|
||||
void Close() {
|
||||
if (output_) output_->Close();
|
||||
}
|
||||
|
||||
bool IsConnected() { return output_ != nullptr; }
|
||||
|
||||
protected:
|
||||
const CalculatorContext& context_;
|
||||
OutputStreamShard* output_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class OutputShardAccess : public OutputShardAccessBase {
|
||||
public:
|
||||
void Send(Packet<T>&& packet) {
|
||||
if (output_) output_->AddPacket(ToOldPacket(std::move(packet)));
|
||||
}
|
||||
|
||||
void Send(const Packet<T>& packet) {
|
||||
if (output_) output_->AddPacket(ToOldPacket(packet));
|
||||
}
|
||||
|
||||
void Send(const T& payload, Timestamp time) {
|
||||
Send(api2::MakePacket<T>(payload).At(time));
|
||||
}
|
||||
|
||||
void Send(const T& payload) { Send(payload, context_.InputTimestamp()); }
|
||||
|
||||
void Send(std::unique_ptr<T> payload, Timestamp time) {
|
||||
Send(api2::PacketAdopting(std::move(payload)).At(time));
|
||||
}
|
||||
|
||||
void Send(std::unique_ptr<T> payload) {
|
||||
Send(std::move(payload), context_.InputTimestamp());
|
||||
}
|
||||
|
||||
private:
|
||||
OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output)
|
||||
: OutputShardAccessBase(cc, output) {}
|
||||
|
||||
friend OutputShardAccess<T> internal::SinglePortAccess<T>(
|
||||
mediapipe::CalculatorContext*, OutputStreamShard*);
|
||||
};
|
||||
|
||||
template <>
|
||||
class OutputShardAccess<internal::Generic> : public OutputShardAccessBase {
|
||||
public:
|
||||
void Send(PacketBase&& packet) {
|
||||
if (output_) output_->AddPacket(ToOldPacket(std::move(packet)));
|
||||
}
|
||||
|
||||
void Send(const PacketBase& packet) {
|
||||
if (output_) output_->AddPacket(ToOldPacket(packet));
|
||||
}
|
||||
|
||||
void SetHeader(const PacketBase& header) {
|
||||
if (output_) output_->SetHeader(ToOldPacket(header));
|
||||
}
|
||||
|
||||
private:
|
||||
OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output)
|
||||
: OutputShardAccessBase(cc, output) {}
|
||||
|
||||
friend OutputShardAccess<internal::Generic>
|
||||
internal::SinglePortAccess<internal::Generic>(mediapipe::CalculatorContext*,
|
||||
OutputStreamShard*);
|
||||
};
|
||||
|
||||
// Equivalent of OutputShardAccess, but for side packets.
|
||||
template <typename T>
|
||||
class OutputSidePacketAccess {
|
||||
public:
|
||||
void Set(Packet<T> packet) {
|
||||
if (output_) output_->Set(ToOldPacket(std::move(packet)));
|
||||
}
|
||||
|
||||
void Set(const T& payload) { Set(MakePacket<T>(payload)); }
|
||||
|
||||
private:
|
||||
OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {}
|
||||
OutputSidePacket* output_;
|
||||
|
||||
friend OutputSidePacketAccess<T> internal::SinglePortAccess<T>(
|
||||
mediapipe::CalculatorContext*, OutputSidePacket*);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class InputShardAccess : public Packet<T> {
|
||||
public:
|
||||
const PacketBase& packet() const& { return *this; }
|
||||
// Since InputShardAccess is currently created as a temporary, this avoids
|
||||
// easy mistakes with dangling references.
|
||||
PacketBase packet() const&& { return *this; }
|
||||
|
||||
bool IsDone() const { return stream_->IsDone(); }
|
||||
bool IsConnected() { return stream_ != nullptr; }
|
||||
|
||||
PacketBase Header() const { return FromOldPacket(stream_->Header()); }
|
||||
|
||||
private:
|
||||
InputShardAccess(const CalculatorContext&, const InputStreamShard* stream)
|
||||
: Packet<T>(stream ? FromOldPacket(stream->Value()).template As<T>()
|
||||
: Packet<T>()),
|
||||
stream_(stream) {}
|
||||
const InputStreamShard* stream_;
|
||||
|
||||
friend InputShardAccess<T> internal::SinglePortAccess<T>(
|
||||
mediapipe::CalculatorContext*, const InputStreamShard*);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class InputSidePacketAccess : public Packet<T> {
|
||||
public:
|
||||
const PacketBase& packet() const& { return *this; }
|
||||
PacketBase packet() const&& { return *this; }
|
||||
|
||||
bool IsConnected() { return connected_; }
|
||||
|
||||
private:
|
||||
InputSidePacketAccess(const mediapipe::Packet* packet)
|
||||
: Packet<T>(packet ? FromOldPacket(*packet).template As<T>()
|
||||
: Packet<T>()),
|
||||
connected_(packet != nullptr) {}
|
||||
bool connected_;
|
||||
|
||||
friend InputSidePacketAccess<T> internal::SinglePortAccess<T>(
|
||||
mediapipe::CalculatorContext*, const mediapipe::Packet*);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class InputShardOrSideAccess : public Packet<T> {
|
||||
public:
|
||||
const PacketBase& packet() const& { return *this; }
|
||||
PacketBase packet() const&& { return *this; }
|
||||
|
||||
bool IsDone() const { return stream_->IsDone(); }
|
||||
bool IsConnected() { return connected_; }
|
||||
bool IsStream() { return stream_ != nullptr; }
|
||||
|
||||
PacketBase Header() const { return FromOldPacket(stream_->Header()); }
|
||||
|
||||
private:
|
||||
InputShardOrSideAccess(const CalculatorContext&,
|
||||
const InputStreamShard* stream,
|
||||
const mediapipe::Packet* packet)
|
||||
: Packet<T>(stream ? FromOldPacket(stream->Value()).template As<T>()
|
||||
: packet ? FromOldPacket(*packet).template As<T>()
|
||||
: Packet<T>()),
|
||||
stream_(stream),
|
||||
connected_(stream_ != nullptr || packet != nullptr) {}
|
||||
const InputStreamShard* stream_;
|
||||
bool connected_;
|
||||
|
||||
friend InputShardOrSideAccess<T> internal::SinglePortAccess<T>(
|
||||
mediapipe::CalculatorContext*, const InputStreamShard*,
|
||||
const mediapipe::Packet*);
|
||||
};
|
||||
|
||||
class PacketTypeAccess {
|
||||
public:
|
||||
bool IsConnected() { return packet_type_ != nullptr; }
|
||||
|
||||
protected:
|
||||
PacketTypeAccess(PacketType* pt) : packet_type_(pt) {}
|
||||
PacketType* packet_type_;
|
||||
|
||||
template <typename T>
|
||||
friend PacketTypeAccess internal::SinglePortAccess(
|
||||
mediapipe::CalculatorContract*, PacketType*);
|
||||
};
|
||||
|
||||
class PacketTypeAccessFallback : public PacketTypeAccess {
|
||||
public:
|
||||
bool IsStream() { return is_stream_; }
|
||||
|
||||
private:
|
||||
PacketTypeAccessFallback(PacketType* pt, bool is_stream)
|
||||
: PacketTypeAccess(pt), is_stream_(is_stream) {}
|
||||
bool is_stream_;
|
||||
|
||||
template <typename T>
|
||||
friend PacketTypeAccessFallback internal::SinglePortAccess(
|
||||
mediapipe::CalculatorContract*, PacketType*, bool);
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
template <typename ValueT>
|
||||
PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc,
|
||||
PacketType* pt) {
|
||||
return PacketTypeAccess(pt);
|
||||
}
|
||||
template <typename ValueT>
|
||||
PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc,
|
||||
PacketType* pt, bool is_stream) {
|
||||
return PacketTypeAccessFallback(pt, is_stream);
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_PORT_H_
|
26
mediapipe/framework/api2/port_test.cc
Normal file
26
mediapipe/framework/api2/port_test.cc
Normal file
|
@ -0,0 +1,26 @@
|
|||
#include "mediapipe/framework/api2/port.h"
|
||||
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
TEST(PortTest, IntInput) {
|
||||
static constexpr auto port = Input<int>("FOO");
|
||||
EXPECT_EQ(port.type_id(), typeid(int).hash_code());
|
||||
}
|
||||
|
||||
TEST(PortTest, OptionalInput) {
|
||||
static constexpr auto port = Input<float>::Optional("BAR");
|
||||
EXPECT_TRUE(port.IsOptional());
|
||||
}
|
||||
|
||||
TEST(PortTest, Tag) {
|
||||
static constexpr auto port = Input<int>("FOO");
|
||||
EXPECT_EQ(std::string(port.Tag()), "FOO");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
157
mediapipe/framework/api2/subgraph_test.cc
Normal file
157
mediapipe/framework/api2/subgraph_test.cc
Normal file
|
@ -0,0 +1,157 @@
|
|||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/api2/test_contracts.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/deps/message_matchers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/subgraph_expansion.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace test {
|
||||
|
||||
class FooBarImpl1 : public SubgraphImpl<FooBar1, FooBarImpl1> {
|
||||
public:
|
||||
mediapipe::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
const SubgraphOptions& /*options*/) {
|
||||
builder::Graph graph;
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
auto& bar = graph.AddNode("Bar");
|
||||
graph.In(kIn) >> foo.In("BASE");
|
||||
foo.Out("OUT") >> bar.In("IN");
|
||||
bar.Out("OUT") >> graph.Out(kOut);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
};
|
||||
|
||||
class FooBarImpl2 : public SubgraphImpl<FooBar2, FooBarImpl2> {
|
||||
public:
|
||||
mediapipe::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
const SubgraphOptions& /*options*/) {
|
||||
builder::Graph graph;
|
||||
auto& foo = graph.AddNode<Foo>();
|
||||
auto& bar = graph.AddNode<Bar>();
|
||||
graph.In(kIn) >> foo.In(MPP_TAG("BASE"));
|
||||
foo.Out(MPP_TAG("OUT")) >> bar.In(MPP_TAG("IN"));
|
||||
bar.Out(MPP_TAG("OUT")) >> graph.Out(kOut);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(SubgraphTest, SubgraphConfig) {
|
||||
CalculatorGraphConfig subgraph = FooBarImpl1().GetConfig({}).ValueOrDie();
|
||||
const CalculatorGraphConfig expected_graph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "IN:__stream_0"
|
||||
output_stream: "OUT:__stream_2"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:__stream_0"
|
||||
output_stream: "OUT:__stream_1"
|
||||
}
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:__stream_1"
|
||||
output_stream: "OUT:__stream_2"
|
||||
}
|
||||
)");
|
||||
EXPECT_THAT(subgraph, EqualsProto(expected_graph));
|
||||
}
|
||||
|
||||
TEST(SubgraphTest, TypedSubgraphConfig) {
|
||||
CalculatorGraphConfig subgraph = FooBarImpl2().GetConfig({}).ValueOrDie();
|
||||
const CalculatorGraphConfig expected_graph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "IN:__stream_0"
|
||||
output_stream: "OUT:__stream_2"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:__stream_0"
|
||||
output_stream: "OUT:__stream_1"
|
||||
}
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:__stream_1"
|
||||
output_stream: "OUT:__stream_2"
|
||||
}
|
||||
)");
|
||||
EXPECT_THAT(subgraph, EqualsProto(expected_graph));
|
||||
}
|
||||
|
||||
TEST(SubgraphTest, ProtoApiConfig) {
|
||||
CalculatorGraphConfig graph;
|
||||
graph.add_input_stream("IN:__stream_0");
|
||||
graph.add_output_stream("OUT:__stream_2");
|
||||
auto* foo = graph.add_node();
|
||||
foo->set_calculator("Foo");
|
||||
foo->add_input_stream("BASE:__stream_0");
|
||||
foo->add_output_stream("OUT:__stream_1");
|
||||
auto* bar = graph.add_node();
|
||||
bar->set_calculator("Bar");
|
||||
bar->add_input_stream("IN:__stream_1");
|
||||
bar->add_output_stream("OUT:__stream_2");
|
||||
|
||||
const CalculatorGraphConfig expected_graph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "IN:__stream_0"
|
||||
output_stream: "OUT:__stream_2"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:__stream_0"
|
||||
output_stream: "OUT:__stream_1"
|
||||
}
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:__stream_1"
|
||||
output_stream: "OUT:__stream_2"
|
||||
}
|
||||
)");
|
||||
EXPECT_THAT(graph, EqualsProto(expected_graph));
|
||||
}
|
||||
|
||||
TEST(SubgraphTest, ExpandSubgraphs) {
|
||||
CalculatorGraphConfig supergraph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
node {
|
||||
name: "simple_source"
|
||||
calculator: "SomeSourceCalculator"
|
||||
output_stream: "foo"
|
||||
}
|
||||
node {
|
||||
calculator: "FooBar"
|
||||
input_stream: "IN:foo"
|
||||
output_stream: "OUT:output"
|
||||
}
|
||||
)");
|
||||
const CalculatorGraphConfig expected_graph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
node {
|
||||
name: "simple_source"
|
||||
calculator: "SomeSourceCalculator"
|
||||
output_stream: "foo"
|
||||
}
|
||||
node {
|
||||
name: "foobar__Foo"
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:foo"
|
||||
output_stream: "OUT:foobar____stream_1"
|
||||
}
|
||||
node {
|
||||
name: "foobar__Bar"
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:foobar____stream_1"
|
||||
output_stream: "OUT:output"
|
||||
}
|
||||
)");
|
||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||
EXPECT_THAT(supergraph, EqualsProto(expected_graph));
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
72
mediapipe/framework/api2/tag.h
Normal file
72
mediapipe/framework/api2/tag.h
Normal file
|
@ -0,0 +1,72 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_TAG_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_TAG_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mediapipe/framework/api2/const_str.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// This template is used to define a separate type for each tag.
|
||||
// This makes it possible to obtain results of different types depending on
|
||||
// the tag. See MPP_TAG below for usage examples.
|
||||
template <char... C>
|
||||
struct Tag {
|
||||
static constexpr char const kChars[sizeof...(C) + 1] = {C..., '\0'};
|
||||
static constexpr const_str const kStr{kChars};
|
||||
static const std::string str() {
|
||||
return std::string(kStr.data(), kStr.len());
|
||||
}
|
||||
|
||||
template <char... Q>
|
||||
constexpr bool operator==(const Tag<Q...>& other) const {
|
||||
return kStr == other.kStr;
|
||||
}
|
||||
template <char... Q>
|
||||
constexpr bool operator!=(const Tag<Q...>& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
template <char... C>
|
||||
constexpr bool is_tag(Tag<C...>) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
constexpr bool is_tag(A) {
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <typename S, std::size_t... I>
|
||||
constexpr auto tag_build_impl(S, std::index_sequence<I...>)
|
||||
-> Tag<S().tag[I]...> {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
constexpr auto tag_build(S) {
|
||||
return tag_build_impl(S(), std::make_index_sequence<S().tag.len()>{});
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// Use this to create typed tag objects.
|
||||
// For example:
|
||||
// auto kFOO = MPP_TAG(FOO);
|
||||
// auto kBAR = MPP_TAG(BAR);
|
||||
#define MPP_TAG(s) \
|
||||
([] { \
|
||||
struct S { \
|
||||
const const_str tag{s}; \
|
||||
}; \
|
||||
return ::mediapipe::api2::internal::tag_build(S()); \
|
||||
}())
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_TAG_H_
|
48
mediapipe/framework/api2/tag_test.cc
Normal file
48
mediapipe/framework/api2/tag_test.cc
Normal file
|
@ -0,0 +1,48 @@
|
|||
#include "mediapipe/framework/api2/tag.h"
|
||||
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
template <typename A, typename B>
|
||||
constexpr bool same_type(A, B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
constexpr bool same_type(A, A) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto kFOO = MPP_TAG("FOO");
|
||||
auto kFOO2 = MPP_TAG("FOO");
|
||||
auto kBAR = MPP_TAG("BAR");
|
||||
|
||||
TEST(TagTest, String) {
|
||||
EXPECT_EQ(kFOO.str(), "FOO");
|
||||
EXPECT_EQ(kBAR.str(), "BAR");
|
||||
}
|
||||
|
||||
// Separate invocations of MPP_TAG with the same std::string produce objects of
|
||||
// the same type.
|
||||
TEST(TagTest, SameType) { EXPECT_TRUE(same_type(kFOO, kFOO2)); }
|
||||
|
||||
// Different tags have different types.
|
||||
TEST(TagTest, DifferentType) { EXPECT_FALSE(same_type(kFOO, kBAR)); }
|
||||
|
||||
TEST(TagTest, Equal) {
|
||||
EXPECT_EQ(kFOO, kFOO2);
|
||||
EXPECT_NE(kFOO, kBAR);
|
||||
}
|
||||
|
||||
TEST(TagTest, IsTag) {
|
||||
EXPECT_TRUE(is_tag(kFOO));
|
||||
EXPECT_FALSE(is_tag("FOO"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
87
mediapipe/framework/api2/test_contracts.h
Normal file
87
mediapipe/framework/api2/test_contracts.h
Normal file
|
@ -0,0 +1,87 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_TEST_CONTRACTS_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_TEST_CONTRACTS_H_
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace test {
|
||||
|
||||
struct Foo : public NodeIntf {
|
||||
static constexpr Input<int> kBase{"BASE"};
|
||||
static constexpr Input<float>::Optional kScale{"SCALE"};
|
||||
static constexpr Output<float> kOut{"OUT"};
|
||||
static constexpr SideInput<float>::Optional kBias{"BIAS"};
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(Foo, kBase, kScale, kOut, kBias);
|
||||
};
|
||||
|
||||
struct Foo2 : public NodeIntf {
|
||||
// clang-format off
|
||||
static constexpr auto kPorts = std::make_tuple(
|
||||
Input<int>{"BASE"},
|
||||
Input<float>::Optional{"SCALE"},
|
||||
Output<float>{"OUT"},
|
||||
SideInput<float>::Optional{"BIAS"}
|
||||
);
|
||||
// clang-format on
|
||||
MEDIAPIPE_NODE_INTERFACE(Foo2, kPorts);
|
||||
};
|
||||
|
||||
struct Bar : public NodeIntf {
|
||||
static constexpr Input<AnyType> kIn{"IN"};
|
||||
// Should all outputs be treated as optional by default?
|
||||
static constexpr Output<SameType<kIn>>::Optional kOut{"OUT"};
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(Bar, kIn, kOut);
|
||||
};
|
||||
|
||||
struct Baz : public NodeIntf {
|
||||
static constexpr Input<AnyType>::Multiple kData{"DATA"};
|
||||
// Should all outputs be treated as optional by default?
|
||||
static constexpr Output<SameType<kData>>::Multiple kDataOut{"DATA"};
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(Baz, kData, kDataOut);
|
||||
};
|
||||
|
||||
struct IntForwarder : public NodeIntf {
|
||||
static constexpr Input<int> kIn{"IN"};
|
||||
static constexpr Output<int> kOut{"OUT"};
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(IntForwarder, kIn, kOut);
|
||||
};
|
||||
|
||||
struct FloatAdder : public NodeIntf {
|
||||
static constexpr Input<float>::Multiple kIn{"IN"};
|
||||
static constexpr Output<float> kOut{"OUT"};
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(FloatAdder, kIn, kOut);
|
||||
};
|
||||
|
||||
struct ToFloat : public NodeIntf {
|
||||
static constexpr Input<OneOf<float, int>> kIn{"IN"};
|
||||
static constexpr Output<float> kOut{"OUT"};
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(ToFloat, kIn, kOut);
|
||||
};
|
||||
|
||||
struct FooBar : public NodeIntf {
|
||||
static constexpr Input<int> kIn{"IN"};
|
||||
static constexpr Output<float> kOut{"OUT"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
};
|
||||
|
||||
struct FooBar1 : public FooBar {
|
||||
static constexpr char kCalculatorName[] = "FooBar";
|
||||
};
|
||||
|
||||
struct FooBar2 : public FooBar {
|
||||
static constexpr char kCalculatorName[] = "FooBar2";
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_TEST_CONTRACTS_H_
|
187
mediapipe/framework/api2/tuple.h
Normal file
187
mediapipe/framework/api2/tuple.h
Normal file
|
@ -0,0 +1,187 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_TUPLE_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_TUPLE_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/meta/type_traits.h"
|
||||
|
||||
// This file contains utilities for working with constexpr tuples.
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace internal {
|
||||
|
||||
// Defines a std::index_sequence with indices for each item of the tuple.
|
||||
template <class Tuple>
|
||||
using tuple_index_sequence =
|
||||
std::make_index_sequence<std::tuple_size_v<std::decay_t<Tuple>>>;
|
||||
|
||||
// Concatenates two std::index_sequences.
|
||||
template <std::size_t... I, std::size_t... J>
|
||||
constexpr auto index_sequence_cat(std::index_sequence<I...>,
|
||||
std::index_sequence<J...>)
|
||||
-> std::index_sequence<I..., J...> {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <std::size_t... I, std::size_t... J, class... Tail>
|
||||
constexpr auto index_sequence_cat(std::index_sequence<I...>,
|
||||
std::index_sequence<J...>, Tail... tail) {
|
||||
return index_sequence_cat(std::index_sequence<I..., J...>(), tail...);
|
||||
}
|
||||
|
||||
template <template <typename...> class Pred, typename Tuple, std::size_t... I>
|
||||
constexpr auto filtered_tuple_indices_impl(Tuple&& t,
|
||||
std::index_sequence<I...>) {
|
||||
return index_sequence_cat(
|
||||
std::conditional_t<
|
||||
Pred<std::tuple_element_t<I, std::decay_t<Tuple>>>::value,
|
||||
std::index_sequence<I>, std::index_sequence<>>{}...);
|
||||
}
|
||||
|
||||
// Returns a std::index_sequence with the indices of the tuple items whose
|
||||
// type satisfied Pred.
|
||||
template <template <typename...> class Pred, typename Tuple>
|
||||
constexpr auto filtered_tuple_indices(Tuple&& tuple) {
|
||||
return filtered_tuple_indices_impl<Pred>(tuple,
|
||||
tuple_index_sequence<Tuple>());
|
||||
}
|
||||
|
||||
// Convenience type to pass any type as a value.
|
||||
template <typename T>
|
||||
struct Wrap {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <class F, typename Tuple, std::size_t... I>
|
||||
constexpr auto filtered_tuple_indices_impl(Tuple&& t,
|
||||
std::index_sequence<I...>) {
|
||||
return index_sequence_cat(
|
||||
std::conditional_t<
|
||||
F{}(Wrap<std::tuple_element_t<I, std::decay_t<Tuple>>>{}),
|
||||
std::index_sequence<I>, std::index_sequence<>>{}...);
|
||||
}
|
||||
|
||||
// Returns a std::index_sequence with the indices of the tuple items for which
|
||||
// F{}(Wrap<item_type>) returns true.
|
||||
template <class F, typename Tuple>
|
||||
constexpr auto filtered_tuple_indices(Tuple&& tuple) {
|
||||
return filtered_tuple_indices_impl<F>(std::forward<Tuple>(tuple),
|
||||
tuple_index_sequence<Tuple>());
|
||||
}
|
||||
|
||||
// Returns a tuple of references to the tuple items with the specified indices.
|
||||
template <typename Tuple, std::size_t... I>
|
||||
constexpr auto select_tuple_indices(Tuple&& tuple, std::index_sequence<I...>) {
|
||||
return std::forward_as_tuple(std::get<I>(std::forward<Tuple>(tuple))...);
|
||||
}
|
||||
|
||||
// Returns a tuple of references to the tuple items whose types satisfy Pred.
|
||||
template <template <typename...> class Pred, typename Tuple>
|
||||
constexpr auto filter_tuple(Tuple&& t) {
|
||||
return select_tuple_indices(std::forward<Tuple>(t),
|
||||
filtered_tuple_indices<Pred>(t));
|
||||
}
|
||||
|
||||
// Returns a tuple of references to the tuple items for which
|
||||
// F{}(Wrap<item_type>) returns true.
|
||||
template <typename F, typename Tuple>
|
||||
constexpr auto filter_tuple(Tuple&& t) {
|
||||
return select_tuple_indices(
|
||||
std::forward<Tuple>(t),
|
||||
filtered_tuple_indices<F>(std::forward<Tuple>(t)));
|
||||
}
|
||||
|
||||
// TODO: ensure only one of these is enabled?
|
||||
template <class F, class T, class I>
|
||||
constexpr auto call_with_optional_index(F&& f, T&& t, I i)
|
||||
-> absl::void_t<decltype(f(std::forward<T>(t), i))> {
|
||||
return f(std::forward<T>(t), i);
|
||||
}
|
||||
|
||||
template <class F, class T, class I>
|
||||
constexpr auto call_with_optional_index(F&& f, T&& t, I i)
|
||||
-> absl::void_t<decltype(f(std::forward<T>(t)))> {
|
||||
return f(std::forward<T>(t));
|
||||
}
|
||||
|
||||
template <class F, class Tuple, std::size_t... I>
|
||||
constexpr void tuple_for_each_impl(F&& f, Tuple&& tuple,
|
||||
std::index_sequence<I...>) {
|
||||
int unpack[] = {
|
||||
(call_with_optional_index(std::forward<F>(f),
|
||||
std::get<I>(std::forward<Tuple>(tuple)),
|
||||
std::integral_constant<std::size_t, I>{}),
|
||||
0)...};
|
||||
(void)unpack;
|
||||
}
|
||||
|
||||
// Invokes f for each item in tuple.
|
||||
// If f takes one argument, it will be called as f(item).
|
||||
// If f takes two arguments, it will be called as
|
||||
// f(item, std::integral_constant<std::size_t, index>{}).
|
||||
template <class F, class Tuple>
|
||||
constexpr void tuple_for_each(F&& f, Tuple&& tuple) {
|
||||
return tuple_for_each_impl(std::forward<F>(f), std::forward<Tuple>(tuple),
|
||||
tuple_index_sequence<Tuple>());
|
||||
}
|
||||
|
||||
template <class F, class Tuple, std::size_t... I>
|
||||
constexpr auto map_tuple_impl(F&& f, Tuple&& tuple, std::index_sequence<I...>) {
|
||||
return std::make_tuple(f(std::get<I>(std::forward<Tuple>(tuple)))...);
|
||||
}
|
||||
|
||||
// Returns a tuple where each item is the result of calling f on the
|
||||
// corresponding item of the provided tuple.
|
||||
template <class F, class Tuple>
|
||||
constexpr auto map_tuple(F&& f, Tuple&& tuple) {
|
||||
return map_tuple_impl(std::forward<F>(f), std::forward<Tuple>(tuple),
|
||||
tuple_index_sequence<Tuple>());
|
||||
}
|
||||
|
||||
template <class F, class Tuple, std::size_t... I>
|
||||
constexpr auto tuple_apply_impl(F&& f, Tuple&& tuple,
|
||||
std::index_sequence<I...>) {
|
||||
return f(std::get<I>(std::forward<Tuple>(tuple))...);
|
||||
}
|
||||
|
||||
// Invokes f passing the tuple's items as arguments.
|
||||
template <class F, class Tuple>
|
||||
constexpr auto tuple_apply(F&& f, Tuple&& tuple) {
|
||||
return tuple_apply_impl(std::forward<F>(f), std::forward<Tuple>(tuple),
|
||||
tuple_index_sequence<Tuple>());
|
||||
}
|
||||
|
||||
// Returns the index [0, tuple_size) of the first item for which f returns true,
|
||||
// or tuple_size if no such item is found.
|
||||
template <class F, class Tuple, std::size_t i = 0>
|
||||
constexpr std::enable_if_t<i == std::tuple_size_v<std::decay_t<Tuple>>,
|
||||
std::size_t>
|
||||
tuple_find(F&& f, Tuple&& tuple) {
|
||||
return i;
|
||||
}
|
||||
|
||||
template <class F, class Tuple, std::size_t i = 0>
|
||||
constexpr std::enable_if_t<i != std::tuple_size_v<std::decay_t<Tuple>>,
|
||||
std::size_t>
|
||||
tuple_find(F&& f, Tuple&& tuple) {
|
||||
if (f(std::get<i>(std::forward<Tuple>(tuple)))) {
|
||||
return i;
|
||||
}
|
||||
return tuple_find<F, Tuple, i + 1>(std::forward<F>(f),
|
||||
std::forward<Tuple>(tuple));
|
||||
}
|
||||
|
||||
template <class Tuple>
|
||||
constexpr auto flatten_tuple(Tuple&& tuple) {
|
||||
return tuple_apply([](auto&&... args) { return std::tuple_cat(args...); },
|
||||
tuple);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_TUPLE_H_
|
147
mediapipe/framework/api2/tuple_test.cc
Normal file
147
mediapipe/framework/api2/tuple_test.cc
Normal file
|
@ -0,0 +1,147 @@
|
|||
#include "mediapipe/framework/api2/tuple.h"
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace internal {
|
||||
namespace {
|
||||
|
||||
template <typename A, typename B>
|
||||
constexpr bool same_type(A, B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
constexpr bool same_type(A, A) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <std::size_t... I>
|
||||
using iseq = std::index_sequence<I...>;
|
||||
|
||||
TEST(TupleTest, IndexSeq) {
|
||||
EXPECT_TRUE(
|
||||
same_type(iseq<0, 1, 2>(), index_sequence_cat(iseq<0, 1>(), iseq<2>())));
|
||||
EXPECT_TRUE(same_type(iseq<0, 1, 2>(),
|
||||
index_sequence_cat(iseq<0, 1>(), iseq<>(), iseq<2>())));
|
||||
}
|
||||
|
||||
TEST(TupleTest, FilteredIndices) {
|
||||
EXPECT_TRUE(same_type(
|
||||
filtered_tuple_indices<std::is_integral>(std::tuple<int, float, char>()),
|
||||
iseq<0, 2>()));
|
||||
}
|
||||
|
||||
TEST(TupleTest, SelectIndices) {
|
||||
auto t = std::make_tuple(5.0, 10, "hi");
|
||||
EXPECT_EQ((select_tuple_indices(t, iseq<0, 2>())),
|
||||
(std::make_tuple(5.0, "hi")));
|
||||
}
|
||||
|
||||
TEST(TupleTest, FilterTuple) {
|
||||
auto t = std::make_tuple(5.0, 10, "hi");
|
||||
EXPECT_EQ((filter_tuple<std::is_integral>(t)), (std::make_tuple(10)));
|
||||
}
|
||||
|
||||
TEST(TupleTest, FilterTupleRefs) {
|
||||
auto t = std::make_tuple(5.0, 10, "hi");
|
||||
auto tr = filter_tuple<std::is_integral>(t);
|
||||
int x;
|
||||
EXPECT_TRUE(same_type(tr, std::tuple<int&>{x}));
|
||||
EXPECT_FALSE(same_type(tr, std::tuple<int>{x}));
|
||||
auto tr_copy =
|
||||
std::apply([](auto&&... item) { return std::make_tuple(item...); },
|
||||
filter_tuple<std::is_integral>(t));
|
||||
EXPECT_TRUE(same_type(tr_copy, std::tuple<int>{x}));
|
||||
}
|
||||
|
||||
struct is_integral {
|
||||
template <class W>
|
||||
constexpr bool operator()(W&&) {
|
||||
return std::is_integral<typename W::type>{};
|
||||
}
|
||||
};
|
||||
|
||||
TEST(TupleTest, FilteredIndices2) {
|
||||
EXPECT_TRUE(same_type(
|
||||
filtered_tuple_indices<is_integral>(std::tuple<int, float, char>()),
|
||||
iseq<0, 2>()));
|
||||
}
|
||||
|
||||
// TEST(TupleTest, FilterTuple2) {
|
||||
// auto t = std::make_tuple(5.0, 10, "hi");
|
||||
// auto is_int = [](auto&& x) {
|
||||
// return std::is_integral_v<decltype(x)>;
|
||||
// };
|
||||
// EXPECT_EQ((filter_tuple(is_int, t)), (std::make_tuple(10)));
|
||||
// }
|
||||
|
||||
TEST(TupleTest, ForEach) {
|
||||
auto t = std::make_tuple(5.0, 10, "hi");
|
||||
std::vector<std::string> s;
|
||||
tuple_for_each([&s](auto&& item) { s.push_back(absl::StrCat(item)); }, t);
|
||||
EXPECT_EQ(s, (std::vector<std::string>{"5", "10", "hi"}));
|
||||
}
|
||||
|
||||
TEST(TupleTest, ForEachWithIndex) {
|
||||
auto t = std::make_tuple(5.0, 10, "hi");
|
||||
std::vector<std::string> s;
|
||||
tuple_for_each(
|
||||
[&s](auto&& item, std::size_t i) {
|
||||
s.push_back(absl::StrCat(i, ":", item));
|
||||
},
|
||||
t);
|
||||
EXPECT_EQ(s, (std::vector<std::string>{"0:5", "1:10", "2:hi"}));
|
||||
}
|
||||
|
||||
TEST(TupleTest, ForEachZip) {
|
||||
auto t = std::make_tuple(5.0, 10, "hi");
|
||||
auto u = std::make_tuple(2.0, 3, "lo");
|
||||
std::vector<std::string> s;
|
||||
tuple_for_each(
|
||||
[&s, &u](auto&& item, auto i_const) {
|
||||
constexpr std::size_t i = decltype(i_const)::value;
|
||||
s.push_back(absl::StrCat(i, ":", item, ",", std::get<i>(u)));
|
||||
},
|
||||
t);
|
||||
EXPECT_EQ(s, (std::vector<std::string>{"0:5,2", "1:10,3", "2:hi,lo"}));
|
||||
}
|
||||
|
||||
TEST(TupleTest, Apply) {
|
||||
auto t = std::make_tuple(5.0, 10, "hi");
|
||||
std::string s = tuple_apply(
|
||||
[](float f, int i, const char* s) { return absl::StrCat(f, i, s); }, t);
|
||||
EXPECT_EQ(s, "510hi");
|
||||
}
|
||||
|
||||
TEST(TupleTest, Map) {
|
||||
auto t = std::make_tuple(5.0, 10, 2L);
|
||||
auto t2 = map_tuple([](auto x) { return x * 2; }, t);
|
||||
EXPECT_EQ(t2, std::make_tuple(10.0, 20, 4L));
|
||||
}
|
||||
|
||||
TEST(TupleFind, Find) {
|
||||
auto t = std::make_tuple(5.0, 10, 2L);
|
||||
auto i = tuple_find([](auto x) { return x > 3; }, t);
|
||||
EXPECT_EQ(i, 0);
|
||||
}
|
||||
|
||||
TEST(TupleFind, Flatten) {
|
||||
auto t1 = std::make_tuple(5.0, 10);
|
||||
auto t2 = std::make_tuple(2L);
|
||||
auto t = std::make_tuple(t1, t2);
|
||||
auto tf = flatten_tuple(t);
|
||||
EXPECT_EQ(tf, std::make_tuple(5.0, 10, 2L));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace internal
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
136
mediapipe/framework/api2/type_list.h
Normal file
136
mediapipe/framework/api2/type_list.h
Normal file
|
@ -0,0 +1,136 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_TYPE_LIST_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_TYPE_LIST_H_
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace types {
|
||||
|
||||
// A list of types. This allows us to store a template parameter pack.
|
||||
template <typename... Args>
|
||||
struct List {};
|
||||
|
||||
// Concatenate two lists.
|
||||
template <typename... As, typename... Bs>
|
||||
auto concat(List<As...>, List<Bs...>) -> List<As..., Bs...> {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Filter a list using a predicate.
|
||||
template <template <typename> class Pred, typename... Args>
|
||||
auto filter(List<Args...>) -> List<Args...> {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <template <typename> class Pred, typename Head, typename... Tail>
|
||||
auto filter(List<Head, Tail...>) -> decltype(concat(
|
||||
typename std::conditional<Pred<Head>::value, List<Head>, List<>>::type{},
|
||||
filter<Pred>(List<Tail...>{}))) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename Pred>
|
||||
auto filter(Pred, List<>) -> List<> {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename Pred, typename Head, typename... Tail>
|
||||
auto filter(Pred pred, List<Head, Tail...>) -> decltype(concat(
|
||||
typename std::conditional<pred(Head{}), List<Head>, List<>>::type{},
|
||||
filter(pred, List<Tail...>{}))) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Invoke a template using a list's types as parameters.
|
||||
template <template <typename...> class T, typename... Args>
|
||||
auto apply(List<Args...>) -> T<Args...> {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Wraps a single type. The wrapper can always be instantiated as a value,
|
||||
// even if T cannot.
|
||||
template <typename T>
|
||||
struct Wrap {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
// Find first match for a predicate.
|
||||
template <template <typename> class Pred, typename... Args>
|
||||
auto find(List<Args...>) -> Wrap<void> {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <template <typename> class Pred, typename Head, typename... Tail>
|
||||
auto find(List<Head, Tail...>) ->
|
||||
typename std::conditional<Pred<Head>::value, Wrap<Head>,
|
||||
decltype(find<Pred>(List<Tail...>{}))>::type {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class Pred, typename... Args>
|
||||
auto find(Pred, List<Args...>) -> Wrap<void> {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class Pred, typename Head, typename... Tail>
|
||||
auto find(Pred pred, List<Head, Tail...>) ->
|
||||
typename std::conditional<pred(Head{}), Wrap<Head>,
|
||||
decltype(find(pred, List<Tail...>{}))>::type {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Apply a function to each item in a list.
|
||||
template <template <typename> class Fun, typename... Items>
|
||||
auto map(List<Items...>) -> List<typename Fun<Items>::type...> {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Get the list's head.
|
||||
template <typename... Args>
|
||||
constexpr auto head(List<Args...>) -> Wrap<void> {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename H, typename... T>
|
||||
constexpr auto head(List<H, T...>) -> Wrap<H> {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Get the list's length.
|
||||
template <typename... Args>
|
||||
constexpr std::size_t length(List<Args...>) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename H, typename... T>
|
||||
constexpr std::size_t length(List<H, T...>) {
|
||||
return length(List<T...>{}) + 1;
|
||||
}
|
||||
|
||||
// Add indices.
|
||||
template <std::size_t I, typename T>
|
||||
struct IndexedType {
|
||||
static constexpr std::size_t kIndex = I;
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename... Args, std::size_t... Is>
|
||||
auto enumerate_impl(List<Args...>, std::index_sequence<Is...>)
|
||||
-> List<IndexedType<Is, Args>...> {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
auto enumerate(List<Args...> a)
|
||||
-> decltype(enumerate_impl(a, std::index_sequence_for<Args...>{})) {
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace types
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_TYPE_LIST_H_
|
101
mediapipe/framework/api2/type_list_test.cc
Normal file
101
mediapipe/framework/api2/type_list_test.cc
Normal file
|
@ -0,0 +1,101 @@
|
|||
#include "mediapipe/framework/api2/type_list.h"
|
||||
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace types {
|
||||
namespace {
|
||||
|
||||
template <typename A, typename B>
|
||||
constexpr bool same_type(A, B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
constexpr bool same_type(A, A) {
|
||||
return true;
|
||||
}
|
||||
|
||||
struct Foo {};
|
||||
struct Bar {};
|
||||
struct Baz {};
|
||||
|
||||
TEST(TypeListFTest, SameType) {
|
||||
EXPECT_FALSE(same_type(List<Foo>{}, List<>{}));
|
||||
EXPECT_TRUE(same_type(List<Foo>{}, List<Foo>{}));
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Length) {
|
||||
EXPECT_EQ(length(List<float, int>{}), 2);
|
||||
EXPECT_EQ(length(List<>{}), 0);
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Head) {
|
||||
using Empty = List<>;
|
||||
using ListA = List<Foo, Bar>;
|
||||
EXPECT_TRUE(same_type(Wrap<Foo>{}, head(ListA{})));
|
||||
EXPECT_TRUE(same_type(Wrap<void>{}, head(Empty{})));
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Concat) {
|
||||
using Empty = List<>;
|
||||
using ListA = List<Foo>;
|
||||
|
||||
EXPECT_TRUE(same_type(ListA{}, concat(ListA{}, Empty{})));
|
||||
EXPECT_TRUE(same_type(concat(ListA{}, Empty{}), ListA{}));
|
||||
|
||||
using ListB = List<Bar, Baz>;
|
||||
EXPECT_TRUE(same_type(concat(ListA{}, ListB{}), List<Foo, Bar, Baz>{}));
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Filter) {
|
||||
EXPECT_TRUE(same_type(filter<std::is_integral>(List<>{}), List<>{}));
|
||||
EXPECT_TRUE(same_type(filter<std::is_integral>(List<int, float, char>{}),
|
||||
List<int, char>{}));
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Filter2) {
|
||||
constexpr auto is_integral = [](auto x) {
|
||||
return std::is_integral<decltype(x)>{};
|
||||
};
|
||||
auto x = filter(is_integral, List<>{});
|
||||
EXPECT_TRUE(same_type(x, List<>{}));
|
||||
auto y = filter(is_integral, List<int, float, char>{});
|
||||
EXPECT_TRUE(same_type(y, List<int, char>{}));
|
||||
auto z = filter([](auto x) { return std::is_integral<decltype(x)>{}; },
|
||||
List<int, double>{});
|
||||
EXPECT_TRUE(same_type(z, List<int>{}));
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Find) {
|
||||
EXPECT_TRUE(same_type(find<std::is_integral>(List<>{}), Wrap<void>{}));
|
||||
EXPECT_TRUE(
|
||||
same_type(find<std::is_integral>(List<float, int>{}), Wrap<int>()));
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Find2) {
|
||||
constexpr auto is_integral = [](auto x) {
|
||||
return std::is_integral<decltype(x)>{};
|
||||
};
|
||||
EXPECT_TRUE(same_type(find(is_integral, List<>{}), Wrap<void>{}));
|
||||
EXPECT_TRUE(same_type(find(is_integral, List<float, int>{}), Wrap<int>()));
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Map) {
|
||||
EXPECT_TRUE(
|
||||
same_type(map<std::remove_cv>(List<const int, const float, const char>{}),
|
||||
List<int, float, char>{}));
|
||||
}
|
||||
|
||||
TEST(TypeListFTest, Enumerate) {
|
||||
EXPECT_TRUE(same_type(enumerate(List<int, float, char>{}),
|
||||
List<IndexedType<0, int>, IndexedType<1, float>,
|
||||
IndexedType<2, char>>{}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace types
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <type_traits>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/deps/registration.h"
|
||||
|
@ -168,22 +169,23 @@ class CalculatorBase {
|
|||
virtual Timestamp SourceProcessOrder(const CalculatorContext* cc) const;
|
||||
};
|
||||
|
||||
using CalculatorBaseRegistry =
|
||||
GlobalFactoryRegistry<std::unique_ptr<CalculatorBase>>;
|
||||
namespace api2 {
|
||||
class Node;
|
||||
} // namespace api2
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Gives access to the static functions within subclasses of CalculatorBase.
|
||||
// This adds functionality akin to virtual static functions.
|
||||
class StaticAccessToCalculatorBase {
|
||||
class CalculatorBaseFactory {
|
||||
public:
|
||||
virtual ~StaticAccessToCalculatorBase() {}
|
||||
virtual ~CalculatorBaseFactory() {}
|
||||
virtual mediapipe::Status GetContract(CalculatorContract* cc) = 0;
|
||||
virtual std::unique_ptr<CalculatorBase> CreateCalculator(
|
||||
CalculatorContext* calculator_context) = 0;
|
||||
virtual std::string ContractMethodName() { return "GetContract"; }
|
||||
};
|
||||
|
||||
using StaticAccessToCalculatorBaseRegistry =
|
||||
GlobalFactoryRegistry<std::unique_ptr<StaticAccessToCalculatorBase>>;
|
||||
|
||||
// Functions for checking that the calculator has the required GetContract.
|
||||
template <class T>
|
||||
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
|
||||
|
@ -197,14 +199,21 @@ constexpr bool CalculatorHasGetContract(...) {
|
|||
|
||||
// Provides access to the static functions within a specific subclass
|
||||
// of CalculatorBase.
|
||||
template <typename CalculatorBaseSubclass>
|
||||
class StaticAccessToCalculatorBaseTyped : public StaticAccessToCalculatorBase {
|
||||
template <class T, class Enable = void>
|
||||
class CalculatorBaseFactoryFor : public CalculatorBaseFactory {
|
||||
static_assert(std::is_base_of<mediapipe::CalculatorBase, T>::value,
|
||||
"Classes registered with REGISTER_CALCULATOR must be "
|
||||
"subclasses of mediapipe::CalculatorBase.");
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class CalculatorBaseFactoryFor<
|
||||
T,
|
||||
typename std::enable_if<std::is_base_of<mediapipe::CalculatorBase, T>{} &&
|
||||
!std::is_base_of<mediapipe::api2::Node, T>{}>::type>
|
||||
: public CalculatorBaseFactory {
|
||||
public:
|
||||
static_assert(
|
||||
std::is_base_of<mediapipe::CalculatorBase, CalculatorBaseSubclass>::value,
|
||||
"Classes registered with REGISTER_CALCULATOR must be "
|
||||
"subclasses of mediapipe::CalculatorBase.");
|
||||
static_assert(CalculatorHasGetContract<CalculatorBaseSubclass>(nullptr),
|
||||
static_assert(CalculatorHasGetContract<T>(nullptr),
|
||||
"GetContract() must be defined with the correct signature in "
|
||||
"every calculator.");
|
||||
|
||||
|
@ -213,12 +222,20 @@ class StaticAccessToCalculatorBaseTyped : public StaticAccessToCalculatorBase {
|
|||
mediapipe::Status GetContract(CalculatorContract* cc) final {
|
||||
// CalculatorBaseSubclass must implement this function, since it is not
|
||||
// implemented in the parent class.
|
||||
return CalculatorBaseSubclass::GetContract(cc);
|
||||
return T::GetContract(cc);
|
||||
}
|
||||
|
||||
std::unique_ptr<CalculatorBase> CreateCalculator(
|
||||
CalculatorContext* calculator_context) final {
|
||||
return absl::make_unique<T>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using CalculatorBaseRegistry =
|
||||
GlobalFactoryRegistry<std::unique_ptr<internal::CalculatorBaseFactory>>;
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_
|
||||
|
|
|
@ -128,7 +128,6 @@ TEST(CalculatorTest, SourceProcessOrder) {
|
|||
CalculatorContext calculator_context(&calculator_state,
|
||||
tool::CreateTagMap({}).ValueOrDie(),
|
||||
output_stream_managers.TagMap());
|
||||
InputStreamShardSet& input_set = calculator_context.Inputs();
|
||||
OutputStreamShardSet& output_set = calculator_context.Outputs();
|
||||
output_set.Index(0).SetSpec(output_stream_managers.Index(0).Spec());
|
||||
output_set.Index(0).SetNextTimestampBound(Timestamp(10));
|
||||
|
@ -137,19 +136,6 @@ TEST(CalculatorTest, SourceProcessOrder) {
|
|||
CalculatorContextManager().PushInputTimestampToContext(
|
||||
&calculator_context, Timestamp::Unstarted());
|
||||
|
||||
InputStreamSet input_streams(input_set.TagMap());
|
||||
OutputStreamSet output_streams(output_set.TagMap());
|
||||
for (CollectionItemId id = input_streams.BeginId();
|
||||
id < input_streams.EndId(); ++id) {
|
||||
input_streams.Get(id) = &input_set.Get(id);
|
||||
}
|
||||
for (CollectionItemId id = output_streams.BeginId();
|
||||
id < output_streams.EndId(); ++id) {
|
||||
output_streams.Get(id) = &output_set.Get(id);
|
||||
}
|
||||
calculator_state.SetInputStreamSet(&input_streams);
|
||||
calculator_state.SetOutputStreamSet(&output_streams);
|
||||
|
||||
test_ns::DeadEndCalculator calculator;
|
||||
EXPECT_EQ(Timestamp(10), calculator.SourceProcessOrder(&calculator_context));
|
||||
output_set.Index(0).SetNextTimestampBound(Timestamp(100));
|
||||
|
@ -202,7 +188,8 @@ TEST(CalculatorTest, CreateByNameWhitelisted) {
|
|||
// Register a whitelisted calculator.
|
||||
CalculatorBaseRegistry::Register(
|
||||
"::mediapipe::test_ns::whitelisted_ns::DeadCalculator",
|
||||
absl::make_unique<mediapipe::test_ns::whitelisted_ns::DeadCalculator>);
|
||||
absl::make_unique<internal::CalculatorBaseFactoryFor<
|
||||
mediapipe::test_ns::whitelisted_ns::DeadCalculator>>);
|
||||
|
||||
// A whitelisted calculator can be found in its own namespace.
|
||||
MP_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( //
|
||||
|
|
|
@ -66,11 +66,26 @@ void CalculatorContext::SetOffset(TimestampDiff offset) {
|
|||
}
|
||||
|
||||
const InputStreamSet& CalculatorContext::InputStreams() const {
|
||||
return calculator_state_->InputStreams();
|
||||
if (!input_streams_) {
|
||||
input_streams_ = absl::make_unique<InputStreamSet>(inputs_.TagMap());
|
||||
for (CollectionItemId id = input_streams_->BeginId();
|
||||
id < input_streams_->EndId(); ++id) {
|
||||
input_streams_->Get(id) = const_cast<InputStreamShard*>(&inputs_.Get(id));
|
||||
}
|
||||
}
|
||||
return *input_streams_;
|
||||
}
|
||||
|
||||
const OutputStreamSet& CalculatorContext::OutputStreams() const {
|
||||
return calculator_state_->OutputStreams();
|
||||
if (!output_streams_) {
|
||||
output_streams_ = absl::make_unique<OutputStreamSet>(outputs_.TagMap());
|
||||
for (CollectionItemId id = output_streams_->BeginId();
|
||||
id < output_streams_->EndId(); ++id) {
|
||||
output_streams_->Get(id) =
|
||||
const_cast<OutputStreamShard*>(&outputs_.Get(id));
|
||||
}
|
||||
}
|
||||
return *output_streams_;
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -163,6 +163,10 @@ class CalculatorContext {
|
|||
CalculatorState* calculator_state_;
|
||||
InputStreamShardSet inputs_;
|
||||
OutputStreamShardSet outputs_;
|
||||
// Created on-demand when needed by legacy APIs. No synchronization needed
|
||||
// because all possible callers are already serialized.
|
||||
mutable std::unique_ptr<InputStreamSet> input_streams_;
|
||||
mutable std::unique_ptr<OutputStreamSet> output_streams_;
|
||||
// The queue of timestamp values to Process() in this calculator context.
|
||||
std::queue<Timestamp> input_timestamps_;
|
||||
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_base.h"
|
||||
#include "mediapipe/framework/calculator_registry_util.h"
|
||||
#include "mediapipe/framework/counter_factory.h"
|
||||
#include "mediapipe/framework/input_stream_manager.h"
|
||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||
|
@ -374,15 +373,12 @@ mediapipe::Status CalculatorNode::PrepareForRun(
|
|||
MP_RETURN_IF_ERROR(calculator_context_manager_.PrepareForRun(std::bind(
|
||||
&CalculatorNode::ConnectShardsToStreams, this, std::placeholders::_1)));
|
||||
|
||||
auto calculator_statusor = CreateCalculator(
|
||||
input_stream_handler_->InputTagMap(),
|
||||
output_stream_handler_->OutputTagMap(), validated_graph_->Package(),
|
||||
calculator_state_.get(),
|
||||
ASSIGN_OR_RETURN(
|
||||
auto calculator_factory,
|
||||
CalculatorBaseRegistry::CreateByNameInNamespace(
|
||||
validated_graph_->Package(), calculator_state_->CalculatorType()));
|
||||
calculator_ = calculator_factory->CreateCalculator(
|
||||
calculator_context_manager_.GetDefaultCalculatorContext());
|
||||
if (!calculator_statusor.ok()) {
|
||||
return calculator_statusor.status();
|
||||
}
|
||||
calculator_ = std::move(calculator_statusor).ValueOrDie();
|
||||
|
||||
needs_to_close_ = false;
|
||||
|
||||
|
|
|
@ -19,14 +19,10 @@
|
|||
|
||||
#include "mediapipe/framework/calculator_base.h"
|
||||
|
||||
#define REGISTER_CALCULATOR(name) \
|
||||
REGISTER_FACTORY_FUNCTION_QUALIFIED(mediapipe::CalculatorBaseRegistry, \
|
||||
calculator_registration, name, \
|
||||
absl::make_unique<name>); \
|
||||
REGISTER_FACTORY_FUNCTION_QUALIFIED( \
|
||||
mediapipe::internal::StaticAccessToCalculatorBaseRegistry, \
|
||||
access_registration, name, \
|
||||
absl::make_unique< \
|
||||
mediapipe::internal::StaticAccessToCalculatorBaseTyped<name>>)
|
||||
// Macro for registering calculators.
|
||||
#define REGISTER_CALCULATOR(name) \
|
||||
REGISTER_FACTORY_FUNCTION_QUALIFIED( \
|
||||
mediapipe::CalculatorBaseRegistry, calculator_registration, name, \
|
||||
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<name>>)
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_H_
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_registry_util.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/collection.h"
|
||||
#include "mediapipe/framework/collection_item_id.h"
|
||||
#include "mediapipe/framework/packet_set.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
bool IsLegacyCalculator(const std::string& package_name,
|
||||
const std::string& node_class) {
|
||||
return false;
|
||||
}
|
||||
|
||||
mediapipe::Status VerifyCalculatorWithContract(const std::string& package_name,
|
||||
const std::string& node_class,
|
||||
CalculatorContract* contract) {
|
||||
// A number of calculators use the non-CC methods on GlCalculatorHelper
|
||||
// even though they are CalculatorBase-based.
|
||||
ASSIGN_OR_RETURN(
|
||||
auto static_access_to_calculator_base,
|
||||
internal::StaticAccessToCalculatorBaseRegistry::CreateByNameInNamespace(
|
||||
package_name, node_class),
|
||||
_ << "Unable to find Calculator \"" << node_class << "\"");
|
||||
MP_RETURN_IF_ERROR(static_access_to_calculator_base->GetContract(contract))
|
||||
.SetPrepend()
|
||||
<< node_class << ": ";
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::StatusOr<std::unique_ptr<CalculatorBase>> CreateCalculator(
|
||||
const std::shared_ptr<tool::TagMap>& input_tag_map,
|
||||
const std::shared_ptr<tool::TagMap>& output_tag_map,
|
||||
const std::string& package_name, CalculatorState* calculator_state,
|
||||
CalculatorContext* calculator_context) {
|
||||
std::unique_ptr<CalculatorBase> calculator;
|
||||
ASSIGN_OR_RETURN(calculator,
|
||||
CalculatorBaseRegistry::CreateByNameInNamespace(
|
||||
package_name, calculator_state->CalculatorType()));
|
||||
return std::move(calculator);
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -1,46 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mediapipe/framework/calculator_base.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_state.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "mediapipe/framework/tool/tag_map.h"
|
||||
|
||||
// Calculator registry util functions that supports both legacy Calculator API
|
||||
// and CalculatorBase.
|
||||
namespace mediapipe {
|
||||
|
||||
bool IsLegacyCalculator(const std::string& package_name,
|
||||
const std::string& node_class);
|
||||
|
||||
mediapipe::Status VerifyCalculatorWithContract(const std::string& package_name,
|
||||
const std::string& node_class,
|
||||
CalculatorContract* contract);
|
||||
|
||||
mediapipe::StatusOr<std::unique_ptr<CalculatorBase>> CreateCalculator(
|
||||
const std::shared_ptr<tool::TagMap>& input_tag_map,
|
||||
const std::shared_ptr<tool::TagMap>& output_tag_map,
|
||||
const std::string& package_name, CalculatorState* calculator_state,
|
||||
CalculatorContext* calculator_context);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_
|
|
@ -33,8 +33,6 @@ CalculatorState::CalculatorState(
|
|||
calculator_type_(calculator_type),
|
||||
node_config_(node_config),
|
||||
profiling_context_(profiling_context),
|
||||
input_streams_(nullptr),
|
||||
output_streams_(nullptr),
|
||||
counter_factory_(nullptr) {
|
||||
options_.Initialize(node_config);
|
||||
ResetBetweenRuns();
|
||||
|
@ -42,20 +40,8 @@ CalculatorState::CalculatorState(
|
|||
|
||||
CalculatorState::~CalculatorState() {}
|
||||
|
||||
void CalculatorState::SetInputStreamSet(InputStreamSet* input_stream_set) {
|
||||
CHECK(input_stream_set);
|
||||
input_streams_ = input_stream_set;
|
||||
}
|
||||
|
||||
void CalculatorState::SetOutputStreamSet(OutputStreamSet* output_stream_set) {
|
||||
CHECK(output_stream_set);
|
||||
output_streams_ = output_stream_set;
|
||||
}
|
||||
|
||||
void CalculatorState::ResetBetweenRuns() {
|
||||
input_side_packets_ = nullptr;
|
||||
input_streams_ = nullptr;
|
||||
output_streams_ = nullptr;
|
||||
counter_factory_ = nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -52,14 +52,6 @@ class CalculatorState {
|
|||
CalculatorState& operator=(const CalculatorState&) = delete;
|
||||
~CalculatorState();
|
||||
|
||||
// Sets the pointer to the InputStreamSet. The function is invoked by
|
||||
// CalculatorNode::PrepareForRun.
|
||||
void SetInputStreamSet(InputStreamSet* input_stream_set);
|
||||
|
||||
// Sets the pointer to the OutputStreamSet. The function is invoked by
|
||||
// CalculatorNode::PrepareForRun.
|
||||
void SetOutputStreamSet(OutputStreamSet* output_stream_set);
|
||||
|
||||
// Called before every call to Calculator::Open() (during the PrepareForRun
|
||||
// phase).
|
||||
void ResetBetweenRuns();
|
||||
|
@ -79,8 +71,6 @@ class CalculatorState {
|
|||
////////////////////////////////////////
|
||||
// Interface for Calculator.
|
||||
////////////////////////////////////////
|
||||
const InputStreamSet& InputStreams() const { return *input_streams_; }
|
||||
const OutputStreamSet& OutputStreams() const { return *output_streams_; }
|
||||
const PacketSet& InputSidePackets() const { return *input_side_packets_; }
|
||||
OutputSidePacketSet& OutputSidePackets() { return *output_side_packets_; }
|
||||
|
||||
|
@ -139,12 +129,6 @@ class CalculatorState {
|
|||
////////////////////////////////////////
|
||||
// Variables which ARE cleared by ResetBetweenRuns().
|
||||
////////////////////////////////////////
|
||||
// The InputStreamSet object is owned by the CalculatorNode.
|
||||
// CalculatorState obtains its pointer in CalculatorNode::PrepareForRun.
|
||||
InputStreamSet* input_streams_;
|
||||
// The OutputStreamSet object is owned by the CalculatorNode.
|
||||
// CalculatorState obtains its pointer in CalculatorNode::PrepareForRun.
|
||||
OutputStreamSet* output_streams_;
|
||||
// The set of input side packets set by CalculatorNode::PrepareForRun().
|
||||
// ResetBetweenRuns() clears this PacketSet pointer.
|
||||
const PacketSet* input_side_packets_;
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mediapipe {
|
||||
namespace file {
|
||||
mediapipe::Status GetContents(absl::string_view file_name, std::string* output,
|
||||
bool read_as_binary = false);
|
||||
bool read_as_binary = true);
|
||||
|
||||
mediapipe::Status SetContents(absl::string_view file_name,
|
||||
absl::string_view content);
|
||||
|
|
|
@ -454,12 +454,12 @@ cc_library(
|
|||
":status_util",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:legacy_calculator_support",
|
||||
"//mediapipe/framework:packet_generator",
|
||||
"//mediapipe/framework:packet_generator_cc_proto",
|
||||
"//mediapipe/framework:packet_set",
|
||||
"//mediapipe/framework:packet_type",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:map_util",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/legacy_calculator_support.h"
|
||||
#include "mediapipe/framework/packet_generator.h"
|
||||
#include "mediapipe/framework/packet_generator.pb.h"
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_base.h"
|
||||
#include "mediapipe/framework/calculator_registry_util.h"
|
||||
#include "mediapipe/framework/legacy_calculator_support.h"
|
||||
#include "mediapipe/framework/packet_generator.h"
|
||||
#include "mediapipe/framework/packet_generator.pb.h"
|
||||
|
@ -236,8 +235,14 @@ mediapipe::Status NodeTypeInfo::Initialize(
|
|||
}
|
||||
#endif
|
||||
LegacyCalculatorSupport::Scoped<CalculatorContract> s(&contract_);
|
||||
MP_RETURN_IF_ERROR(VerifyCalculatorWithContract(validated_graph.Package(),
|
||||
node_class, &contract_));
|
||||
// A number of calculators use the non-CC methods on GlCalculatorHelper
|
||||
// even though they are CalculatorBase-based.
|
||||
ASSIGN_OR_RETURN(auto calculator_factory,
|
||||
CalculatorBaseRegistry::CreateByNameInNamespace(
|
||||
validated_graph.Package(), node_class),
|
||||
_ << "Unable to find Calculator \"" << node_class << "\"");
|
||||
MP_RETURN_IF_ERROR(calculator_factory->GetContract(&contract_)).SetPrepend()
|
||||
<< node_class << ": ";
|
||||
|
||||
// Validate result of FillExpectations or GetContract.
|
||||
std::vector<mediapipe::Status> statuses;
|
||||
|
@ -261,10 +266,7 @@ mediapipe::Status NodeTypeInfo::Initialize(
|
|||
}
|
||||
if (!statuses.empty()) {
|
||||
return tool::CombinedStatus(
|
||||
absl::StrCat(node_class,
|
||||
IsLegacyCalculator(validated_graph.Package(), node_class)
|
||||
? "::FillExpectations"
|
||||
: "::GetContract",
|
||||
absl::StrCat(node_class, "::", calculator_factory->ContractMethodName(),
|
||||
" failed to validate: "),
|
||||
statuses);
|
||||
}
|
||||
|
|
|
@ -55,6 +55,7 @@ cc_library(
|
|||
name = "desktop_cpu_calculators",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_model_calculator",
|
||||
"//mediapipe/calculators/util:local_file_contents_calculator",
|
||||
"//mediapipe/calculators/video:opencv_video_decoder_calculator",
|
||||
|
|
|
@ -635,12 +635,15 @@ void GlAnimationOverlayCalculator::LoadModelMatrices(
|
|||
|
||||
// Process model matrices, if any are being streamed in, and update our
|
||||
// list.
|
||||
current_model_matrices_.clear();
|
||||
if (has_model_matrix_stream_ &&
|
||||
!cc->Inputs().Tag("MODEL_MATRICES").IsEmpty()) {
|
||||
const TimedModelMatrixProtoList &model_matrices =
|
||||
cc->Inputs().Tag("MODEL_MATRICES").Get<TimedModelMatrixProtoList>();
|
||||
LoadModelMatrices(model_matrices, ¤t_model_matrices_);
|
||||
}
|
||||
|
||||
current_mask_model_matrices_.clear();
|
||||
if (has_mask_model_matrix_stream_ &&
|
||||
!cc->Inputs().Tag("MASK_MODEL_MATRICES").IsEmpty()) {
|
||||
const TimedModelMatrixProtoList &model_matrices =
|
||||
|
|
|
@ -8,6 +8,7 @@ input_side_packet: "LABELS_CSV:allowed_labels"
|
|||
input_side_packet: "MODEL_SCALE:model_scale"
|
||||
input_side_packet: "MODEL_TRANSFORMATION:model_transformation"
|
||||
input_side_packet: "TEXTURE:box_texture"
|
||||
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
|
||||
input_side_packet: "ANIMATION_ASSET:box_asset_name"
|
||||
input_side_packet: "MASK_TEXTURE:obj_texture"
|
||||
input_side_packet: "MASK_ASSET:obj_asset_name"
|
||||
|
@ -26,7 +27,7 @@ output_stream: "output_video"
|
|||
node {
|
||||
calculator: "FlowLimiterCalculator"
|
||||
input_stream: "input_video"
|
||||
input_stream: "FINISHED:lifted_objects"
|
||||
input_stream: "FINISHED:output_video"
|
||||
input_stream_info: {
|
||||
tag_index: "FINISHED"
|
||||
back_edge: true
|
||||
|
@ -52,6 +53,7 @@ node {
|
|||
calculator: "ObjectronGpuSubgraph"
|
||||
input_stream: "IMAGE_GPU:throttled_input_video_3x4"
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
|
||||
output_stream: "FRAME_ANNOTATION:lifted_objects"
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,16 @@ input_side_packet: "FILE_PATH:0:box_landmark_model_path"
|
|||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
input_side_packet: "OUTPUT_FILE_PATH:output_video_path"
|
||||
|
||||
# Generates side packet with max number of objects to detect/track.
|
||||
node {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:max_num_objects"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: {
|
||||
packet { int_value: 5 }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Decodes an input video file into images and a video header.
|
||||
node {
|
||||
|
@ -30,8 +40,9 @@ node {
|
|||
input_stream: "IMAGE:input_video"
|
||||
input_side_packet: "MODEL:box_landmark_model"
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
output_stream: "LANDMARKS:box_landmarks"
|
||||
output_stream: "NORM_RECT:box_rect"
|
||||
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
|
||||
output_stream: "MULTI_LANDMARKS:box_landmarks"
|
||||
output_stream: "NORM_RECTS:box_rect"
|
||||
}
|
||||
|
||||
# Subgraph that renders annotations and overlays them on top of the input
|
||||
|
@ -39,8 +50,8 @@ node {
|
|||
node {
|
||||
calculator: "RendererSubgraph"
|
||||
input_stream: "IMAGE:input_video"
|
||||
input_stream: "LANDMARKS:box_landmarks"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
input_stream: "MULTI_LANDMARKS:box_landmarks"
|
||||
input_stream: "NORM_RECTS:box_rect"
|
||||
output_stream: "IMAGE:output_video"
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,8 @@ mediapipe_simple_subgraph(
|
|||
register_as = "RendererSubgraph",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:begin_loop_calculator",
|
||||
"//mediapipe/calculators/core:end_loop_calculator",
|
||||
"//mediapipe/calculators/util:annotation_overlay_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_render_data_calculator",
|
||||
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
|
||||
|
|
|
@ -3,15 +3,26 @@
|
|||
type: "RendererSubgraph"
|
||||
|
||||
input_stream: "IMAGE:input_image"
|
||||
input_stream: "LANDMARKS:landmarks"
|
||||
input_stream: "NORM_RECT:rect"
|
||||
input_stream: "MULTI_LANDMARKS:multi_landmarks"
|
||||
input_stream: "NORM_RECTS:multi_rect"
|
||||
output_stream: "IMAGE:output_image"
|
||||
|
||||
# Outputs each element of multi_landmarks at a fake timestamp for the rest
|
||||
# of the graph to process. At the end of the loop, outputs the BATCH_END
|
||||
# timestamp for downstream calculators to inform them that all elements in the
|
||||
# vector have been processed.
|
||||
node {
|
||||
calculator: "BeginLoopNormalizedLandmarkListVectorCalculator"
|
||||
input_stream: "ITERABLE:multi_landmarks"
|
||||
output_stream: "ITEM:single_landmarks"
|
||||
output_stream: "BATCH_END:landmark_timestamp"
|
||||
}
|
||||
|
||||
# Converts landmarks to drawing primitives for annotation overlay.
|
||||
node {
|
||||
calculator: "LandmarksToRenderDataCalculator"
|
||||
input_stream: "NORM_LANDMARKS:landmarks"
|
||||
output_stream: "RENDER_DATA:landmark_render_data"
|
||||
input_stream: "NORM_LANDMARKS:single_landmarks"
|
||||
output_stream: "RENDER_DATA:single_landmark_render_data"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
||||
landmark_connections: [1, 2] # edge 1-2
|
||||
|
@ -33,11 +44,18 @@ node {
|
|||
}
|
||||
}
|
||||
|
||||
node {
|
||||
calculator: "EndLoopRenderDataCalculator"
|
||||
input_stream: "ITEM:single_landmark_render_data"
|
||||
input_stream: "BATCH_END:landmark_timestamp"
|
||||
output_stream: "ITERABLE:multi_landmarks_render_data"
|
||||
}
|
||||
|
||||
# Converts normalized rects to drawing primitives for annotation overlay.
|
||||
node {
|
||||
calculator: "RectToRenderDataCalculator"
|
||||
input_stream: "NORM_RECT:rect"
|
||||
output_stream: "RENDER_DATA:rect_render_data"
|
||||
input_stream: "NORM_RECTS:multi_rect"
|
||||
output_stream: "RENDER_DATA:multi_rect_render_data"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] {
|
||||
filled: false
|
||||
|
@ -51,7 +69,7 @@ node {
|
|||
node {
|
||||
calculator: "AnnotationOverlayCalculator"
|
||||
input_stream: "IMAGE:input_image"
|
||||
input_stream: "landmark_render_data"
|
||||
input_stream: "rect_render_data"
|
||||
input_stream: "VECTOR:multi_landmarks_render_data"
|
||||
input_stream: "multi_rect_render_data"
|
||||
output_stream: "IMAGE:output_image"
|
||||
}
|
||||
|
|
|
@ -430,7 +430,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
|||
framesInUse--;
|
||||
int keep = max(framesToKeep - framesInUse, 0);
|
||||
while (framesAvailable.size() > keep) {
|
||||
teardownFrame(framesAvailable.remove());
|
||||
PoolTextureFrame textureFrameToRemove = framesAvailable.remove();
|
||||
handler.post(() -> teardownFrame(textureFrameToRemove));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ node {
|
|||
options: {
|
||||
[mediapipe.InferenceCalculatorOptions.ext] {
|
||||
model_path: "mediapipe/modules/holistic_landmark/hand_recrop.tflite"
|
||||
delegate { xnnpack {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,25 +56,19 @@ mediapipe_simple_subgraph(
|
|||
graph = "box_landmark_gpu.pbtxt",
|
||||
register_as = "BoxLandmarkSubgraph",
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/image:image_cropping_calculator",
|
||||
"//mediapipe/calculators/image:image_properties_calculator",
|
||||
"//mediapipe/calculators/image:image_transformation_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_converter_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_inference_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||
"//mediapipe/calculators/tensor:inference_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_floats_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
|
||||
"//mediapipe/calculators/util:landmark_projection_calculator",
|
||||
"//mediapipe/calculators/util:landmarks_smoothing_calculator",
|
||||
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator",
|
||||
"//mediapipe/calculators/util:thresholding_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -83,24 +77,19 @@ mediapipe_simple_subgraph(
|
|||
graph = "box_landmark_cpu.pbtxt",
|
||||
register_as = "BoxLandmarkSubgraph",
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/image:image_cropping_calculator",
|
||||
"//mediapipe/calculators/image:image_properties_calculator",
|
||||
"//mediapipe/calculators/image:image_transformation_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_converter_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_inference_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||
"//mediapipe/calculators/tensor:inference_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_floats_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
|
||||
"//mediapipe/calculators/util:landmark_projection_calculator",
|
||||
"//mediapipe/calculators/util:landmarks_smoothing_calculator",
|
||||
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator",
|
||||
"//mediapipe/calculators/util:thresholding_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -115,9 +104,7 @@ mediapipe_simple_subgraph(
|
|||
"//mediapipe/calculators/tflite:tflite_inference_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator",
|
||||
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:non_max_suppression_calculator",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:filter_detection_calculator",
|
||||
],
|
||||
)
|
||||
|
@ -133,9 +120,7 @@ mediapipe_simple_subgraph(
|
|||
"//mediapipe/calculators/tflite:tflite_inference_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator",
|
||||
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:non_max_suppression_calculator",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:filter_detection_calculator",
|
||||
],
|
||||
)
|
||||
|
@ -147,9 +132,19 @@ mediapipe_simple_subgraph(
|
|||
deps = [
|
||||
":box_landmark_cpu",
|
||||
":object_detection_oid_v4_cpu",
|
||||
"//mediapipe/calculators/core:begin_loop_calculator",
|
||||
"//mediapipe/calculators/core:clip_vector_size_calculator",
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||
"//mediapipe/calculators/core:end_loop_calculator",
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:merge_calculator",
|
||||
"//mediapipe/calculators/core:previous_loopback_calculator",
|
||||
"//mediapipe/calculators/image:image_properties_calculator",
|
||||
"//mediapipe/calculators/util:association_norm_rect_calculator",
|
||||
"//mediapipe/calculators/util:collection_has_min_size_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -160,9 +155,18 @@ mediapipe_simple_subgraph(
|
|||
deps = [
|
||||
":box_landmark_gpu",
|
||||
":object_detection_oid_v4_gpu",
|
||||
"//mediapipe/calculators/core:begin_loop_calculator",
|
||||
"//mediapipe/calculators/core:clip_vector_size_calculator",
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||
"//mediapipe/calculators/core:end_loop_calculator",
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:merge_calculator",
|
||||
"//mediapipe/calculators/core:previous_loopback_calculator",
|
||||
"//mediapipe/calculators/image:image_cropping_calculator",
|
||||
"//mediapipe/calculators/image:image_properties_calculator",
|
||||
"//mediapipe/calculators/util:association_norm_rect_calculator",
|
||||
"//mediapipe/calculators/util:collection_has_min_size_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator",
|
||||
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -2,82 +2,75 @@
|
|||
|
||||
type: "BoxLandmarkSubgraph"
|
||||
|
||||
input_stream: "IMAGE:input_video"
|
||||
input_stream: "IMAGE:image"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
input_side_packet: "MODEL:model"
|
||||
output_stream: "LANDMARKS:box_landmarks_filtered"
|
||||
output_stream: "NORM_RECT:box_rect_for_next_frame"
|
||||
output_stream: "PRESENCE:box_presence"
|
||||
output_stream: "NORM_LANDMARKS:box_landmarks"
|
||||
|
||||
# Crops the rectangle that contains a box from the input image.
|
||||
# Extracts image size from the input images.
|
||||
node {
|
||||
calculator: "ImageCroppingCalculator"
|
||||
input_stream: "IMAGE:input_video"
|
||||
calculator: "ImagePropertiesCalculator"
|
||||
input_stream: "IMAGE:image"
|
||||
output_stream: "SIZE:image_size"
|
||||
}
|
||||
|
||||
# Expands the rectangle that contain the box so that it's likely to cover the
|
||||
# entire box.
|
||||
node {
|
||||
calculator: "RectTransformationCalculator"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
output_stream: "IMAGE:box_image"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "box_rect_scaled"
|
||||
options: {
|
||||
[mediapipe.ImageCroppingCalculatorOptions.ext] {
|
||||
[mediapipe.RectTransformationCalculatorOptions.ext] {
|
||||
scale_x: 1.5
|
||||
scale_y: 1.5
|
||||
square_long: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Crops, resizes, and converts the input video into tensor.
|
||||
# Preserves aspect ratio of the images.
|
||||
node {
|
||||
calculator: "ImageToTensorCalculator"
|
||||
input_stream: "IMAGE:image"
|
||||
input_stream: "NORM_RECT:box_rect_scaled"
|
||||
output_stream: "TENSORS:image_tensor"
|
||||
output_stream: "LETTERBOX_PADDING:letterbox_padding"
|
||||
options {
|
||||
[mediapipe.ImageToTensorCalculatorOptions.ext] {
|
||||
output_tensor_width: 224
|
||||
output_tensor_height: 224
|
||||
keep_aspect_ratio: true
|
||||
output_tensor_float_range {
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
}
|
||||
gpu_origin: TOP_LEFT
|
||||
border_mode: BORDER_REPLICATE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Transforms the input image to a 224x224 image. To scale the input
|
||||
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
|
||||
# resulting in potential letterboxing in the transformed image.
|
||||
node: {
|
||||
calculator: "ImageTransformationCalculator"
|
||||
input_stream: "IMAGE:box_image"
|
||||
output_stream: "IMAGE:transformed_box_image"
|
||||
output_stream: "LETTERBOX_PADDING:letterbox_padding"
|
||||
options: {
|
||||
[mediapipe.ImageTransformationCalculatorOptions.ext] {
|
||||
output_width: 224
|
||||
output_height: 224
|
||||
scale_mode: FIT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Converts the transformed input image into an image tensor stored as a
|
||||
# TfLiteTensor.
|
||||
node {
|
||||
calculator: "TfLiteConverterCalculator"
|
||||
input_stream: "IMAGE:transformed_box_image"
|
||||
output_stream: "TENSORS:image_tensor"
|
||||
options: {
|
||||
[mediapipe.TfLiteConverterCalculatorOptions.ext] {
|
||||
zero_center: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Generates a single side packet containing a TensorFlow Lite op resolver that
|
||||
# supports custom ops needed by the model used in this graph.
|
||||
node {
|
||||
calculator: "TfLiteCustomOpResolverCalculator"
|
||||
output_side_packet: "opresolver"
|
||||
}
|
||||
|
||||
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
|
||||
# vector of tensors representing, for instance, detection boxes/keypoints and
|
||||
# scores.
|
||||
node {
|
||||
calculator: "TfLiteInferenceCalculator"
|
||||
calculator: "InferenceCalculator"
|
||||
input_stream: "TENSORS:image_tensor"
|
||||
output_stream: "TENSORS:output_tensors"
|
||||
input_side_packet: "CUSTOM_OP_RESOLVER:opresolver"
|
||||
input_side_packet: "MODEL:model"
|
||||
output_stream: "TENSORS:output_tensors"
|
||||
options: {
|
||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
use_gpu: false
|
||||
[mediapipe.InferenceCalculatorOptions.ext] {
|
||||
delegate { xnnpack {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Splits a vector of tensors into multiple vectors.
|
||||
node {
|
||||
calculator: "SplitTfLiteTensorVectorCalculator"
|
||||
calculator: "SplitTensorVectorCalculator"
|
||||
input_stream: "output_tensors"
|
||||
output_stream: "landmark_tensors"
|
||||
output_stream: "box_flag_tensor"
|
||||
|
@ -92,7 +85,7 @@ node {
|
|||
# Converts the box-flag tensor into a float that represents the confidence
|
||||
# score of box presence.
|
||||
node {
|
||||
calculator: "TfLiteTensorsToFloatsCalculator"
|
||||
calculator: "TensorsToFloatsCalculator"
|
||||
input_stream: "TENSORS:box_flag_tensor"
|
||||
output_stream: "FLOAT:box_presence_score"
|
||||
}
|
||||
|
@ -105,19 +98,27 @@ node {
|
|||
output_stream: "FLAG:box_presence"
|
||||
options: {
|
||||
[mediapipe.ThresholdingCalculatorOptions.ext] {
|
||||
threshold: 0.1
|
||||
threshold: 0.99
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Drops landmarks tensors if box is not present.
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "landmark_tensors"
|
||||
input_stream: "ALLOW:box_presence"
|
||||
output_stream: "gated_landmark_tensors"
|
||||
}
|
||||
|
||||
# Decodes the landmark tensors into a list of landmarks, where the landmark
|
||||
# coordinates are normalized by the size of the input image to the model.
|
||||
node {
|
||||
calculator: "TfLiteTensorsToLandmarksCalculator"
|
||||
input_stream: "TENSORS:landmark_tensors"
|
||||
calculator: "TensorsToLandmarksCalculator"
|
||||
input_stream: "TENSORS:gated_landmark_tensors"
|
||||
output_stream: "NORM_LANDMARKS:landmarks"
|
||||
options: {
|
||||
[mediapipe.TfLiteTensorsToLandmarksCalculatorOptions.ext] {
|
||||
[mediapipe.TensorsToLandmarksCalculatorOptions.ext] {
|
||||
num_landmarks: 9
|
||||
input_image_width: 224
|
||||
input_image_height: 224
|
||||
|
@ -141,66 +142,6 @@ node {
|
|||
node {
|
||||
calculator: "LandmarkProjectionCalculator"
|
||||
input_stream: "NORM_LANDMARKS:scaled_landmarks"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
input_stream: "NORM_RECT:box_rect_scaled"
|
||||
output_stream: "NORM_LANDMARKS:box_landmarks"
|
||||
}
|
||||
|
||||
# Extracts image size from the input images.
|
||||
node {
|
||||
calculator: "ImagePropertiesCalculator"
|
||||
input_stream: "IMAGE:input_video"
|
||||
output_stream: "SIZE:image_size"
|
||||
}
|
||||
|
||||
# Smooth predicted landmarks coordinates.
|
||||
node {
|
||||
calculator: "LandmarksSmoothingCalculator"
|
||||
input_stream: "NORM_LANDMARKS:box_landmarks"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "NORM_FILTERED_LANDMARKS:box_landmarks_filtered"
|
||||
options: {
|
||||
[mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
|
||||
velocity_filter: {
|
||||
window_size: 10
|
||||
velocity_scale: 7.5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Convert box landmarks to frame annotation.
|
||||
node {
|
||||
calculator: "LandmarksToFrameAnnotationCalculator"
|
||||
input_stream: "LANDMARKS:box_landmarks_filtered"
|
||||
output_stream: "FRAME_ANNOTATION:box_annotation"
|
||||
}
|
||||
|
||||
# Lift the 2D landmarks to 3D using EPnP algorithm.
|
||||
node {
|
||||
calculator: "Lift2DFrameAnnotationTo3DCalculator"
|
||||
input_stream: "FRAME_ANNOTATION:box_annotation"
|
||||
output_stream: "LIFTED_FRAME_ANNOTATION:lifted_box"
|
||||
}
|
||||
|
||||
# Get rotated rectangle from lifted box.
|
||||
node {
|
||||
calculator: "FrameAnnotationToRectCalculator"
|
||||
input_stream: "FRAME_ANNOTATION:lifted_box"
|
||||
output_stream: "NORM_RECT:rect_from_box"
|
||||
}
|
||||
|
||||
# Expands the box rectangle so that in the next video frame it's likely to
|
||||
# still contain the box even with some motion.
|
||||
node {
|
||||
calculator: "RectTransformationCalculator"
|
||||
input_stream: "NORM_RECT:rect_from_box"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "box_rect_for_next_frame"
|
||||
options: {
|
||||
[mediapipe.RectTransformationCalculatorOptions.ext] {
|
||||
scale_x: 1.5
|
||||
scale_y: 1.5
|
||||
square_long: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,81 +2,75 @@
|
|||
|
||||
type: "BoxLandmarkSubgraph"
|
||||
|
||||
input_stream: "IMAGE:input_video"
|
||||
input_stream: "IMAGE:image"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
output_stream: "FRAME_ANNOTATION:lifted_box"
|
||||
output_stream: "NORM_RECT:box_rect_for_next_frame"
|
||||
output_stream: "PRESENCE:box_presence"
|
||||
output_stream: "NORM_LANDMARKS:box_landmarks"
|
||||
|
||||
# Crops the rectangle that contains a box from the input image.
|
||||
# Extracts image size from the input images.
|
||||
node {
|
||||
calculator: "ImageCroppingCalculator"
|
||||
input_stream: "IMAGE_GPU:input_video"
|
||||
calculator: "ImagePropertiesCalculator"
|
||||
input_stream: "IMAGE_GPU:image"
|
||||
output_stream: "SIZE:image_size"
|
||||
}
|
||||
|
||||
# Expands the rectangle that contain the box so that it's likely to cover the
|
||||
# entire box.
|
||||
node {
|
||||
calculator: "RectTransformationCalculator"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
output_stream: "IMAGE_GPU:box_image"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "box_rect_scaled"
|
||||
options: {
|
||||
[mediapipe.ImageCroppingCalculatorOptions.ext] {
|
||||
[mediapipe.RectTransformationCalculatorOptions.ext] {
|
||||
scale_x: 1.5
|
||||
scale_y: 1.5
|
||||
square_long: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Crops, resizes, and converts the input video into tensor.
|
||||
# Preserves aspect ratio of the images.
|
||||
node {
|
||||
calculator: "ImageToTensorCalculator"
|
||||
input_stream: "IMAGE_GPU:image"
|
||||
input_stream: "NORM_RECT:box_rect_scaled"
|
||||
output_stream: "TENSORS:image_tensor"
|
||||
output_stream: "LETTERBOX_PADDING:letterbox_padding"
|
||||
options {
|
||||
[mediapipe.ImageToTensorCalculatorOptions.ext] {
|
||||
output_tensor_width: 224
|
||||
output_tensor_height: 224
|
||||
keep_aspect_ratio: true
|
||||
output_tensor_float_range {
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
}
|
||||
gpu_origin: TOP_LEFT
|
||||
border_mode: BORDER_REPLICATE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Transforms the input image on GPU to a 224x224 image. To scale the input
|
||||
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
|
||||
# resulting in potential letterboxing in the transformed image.
|
||||
node: {
|
||||
calculator: "ImageTransformationCalculator"
|
||||
input_stream: "IMAGE_GPU:box_image"
|
||||
output_stream: "IMAGE_GPU:transformed_box_image"
|
||||
output_stream: "LETTERBOX_PADDING:letterbox_padding"
|
||||
options: {
|
||||
[mediapipe.ImageTransformationCalculatorOptions.ext] {
|
||||
output_width: 224
|
||||
output_height: 224
|
||||
scale_mode: FIT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Converts the transformed input image on GPU into an image tensor stored as a
|
||||
# TfLiteTensor.
|
||||
node {
|
||||
calculator: "TfLiteConverterCalculator"
|
||||
input_stream: "IMAGE_GPU:transformed_box_image"
|
||||
output_stream: "TENSORS_GPU:image_tensor"
|
||||
options: {
|
||||
[mediapipe.TfLiteConverterCalculatorOptions.ext] {
|
||||
zero_center: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Generates a single side packet containing a TensorFlow Lite op resolver that
|
||||
# supports custom ops needed by the model used in this graph.
|
||||
node {
|
||||
calculator: "TfLiteCustomOpResolverCalculator"
|
||||
output_side_packet: "opresolver"
|
||||
}
|
||||
|
||||
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
|
||||
# vector of tensors representing, for instance, detection boxes/keypoints and
|
||||
# scores.
|
||||
node {
|
||||
calculator: "TfLiteInferenceCalculator"
|
||||
input_stream: "TENSORS_GPU:image_tensor"
|
||||
calculator: "InferenceCalculator"
|
||||
input_stream: "TENSORS:image_tensor"
|
||||
output_stream: "TENSORS:output_tensors"
|
||||
input_side_packet: "CUSTOM_OP_RESOLVER:opresolver"
|
||||
options: {
|
||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
[mediapipe.InferenceCalculatorOptions.ext] {
|
||||
model_path: "object_detection_3d.tflite"
|
||||
use_gpu: true
|
||||
delegate { gpu {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Splits a vector of tensors into multiple vectors.
|
||||
# Splits a vector of tensors to multiple vectors according to the ranges
|
||||
# specified in option.
|
||||
node {
|
||||
calculator: "SplitTfLiteTensorVectorCalculator"
|
||||
calculator: "SplitTensorVectorCalculator"
|
||||
input_stream: "output_tensors"
|
||||
output_stream: "landmark_tensors"
|
||||
output_stream: "box_flag_tensor"
|
||||
|
@ -91,7 +85,7 @@ node {
|
|||
# Converts the box-flag tensor into a float that represents the confidence
|
||||
# score of box presence.
|
||||
node {
|
||||
calculator: "TfLiteTensorsToFloatsCalculator"
|
||||
calculator: "TensorsToFloatsCalculator"
|
||||
input_stream: "TENSORS:box_flag_tensor"
|
||||
output_stream: "FLOAT:box_presence_score"
|
||||
}
|
||||
|
@ -109,14 +103,22 @@ node {
|
|||
}
|
||||
}
|
||||
|
||||
# Drops landmarks tensors if box is not present.
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "landmark_tensors"
|
||||
input_stream: "ALLOW:box_presence"
|
||||
output_stream: "gated_landmark_tensors"
|
||||
}
|
||||
|
||||
# Decodes the landmark tensors into a list of landmarks, where the landmark
|
||||
# coordinates are normalized by the size of the input image to the model.
|
||||
node {
|
||||
calculator: "TfLiteTensorsToLandmarksCalculator"
|
||||
input_stream: "TENSORS:landmark_tensors"
|
||||
calculator: "TensorsToLandmarksCalculator"
|
||||
input_stream: "TENSORS:gated_landmark_tensors"
|
||||
output_stream: "NORM_LANDMARKS:landmarks"
|
||||
options: {
|
||||
[mediapipe.TfLiteTensorsToLandmarksCalculatorOptions.ext] {
|
||||
[mediapipe.TensorsToLandmarksCalculatorOptions.ext] {
|
||||
num_landmarks: 9
|
||||
input_image_width: 224
|
||||
input_image_height: 224
|
||||
|
@ -140,66 +142,6 @@ node {
|
|||
node {
|
||||
calculator: "LandmarkProjectionCalculator"
|
||||
input_stream: "NORM_LANDMARKS:scaled_landmarks"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
input_stream: "NORM_RECT:box_rect_scaled"
|
||||
output_stream: "NORM_LANDMARKS:box_landmarks"
|
||||
}
|
||||
|
||||
# Extracts image size from the input images.
|
||||
node {
|
||||
calculator: "ImagePropertiesCalculator"
|
||||
input_stream: "IMAGE_GPU:input_video"
|
||||
output_stream: "SIZE:image_size"
|
||||
}
|
||||
|
||||
# Smooth predicted landmarks coordinates.
|
||||
node {
|
||||
calculator: "LandmarksSmoothingCalculator"
|
||||
input_stream: "NORM_LANDMARKS:box_landmarks"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "NORM_FILTERED_LANDMARKS:box_landmarks_filtered"
|
||||
options: {
|
||||
[mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
|
||||
velocity_filter: {
|
||||
window_size: 10
|
||||
velocity_scale: 7.5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Convert box landmarks to frame annotation.
|
||||
node {
|
||||
calculator: "LandmarksToFrameAnnotationCalculator"
|
||||
input_stream: "LANDMARKS:box_landmarks_filtered"
|
||||
output_stream: "FRAME_ANNOTATION:box_annotation"
|
||||
}
|
||||
|
||||
# Lift the 2D landmarks to 3D using EPnP algorithm.
|
||||
node {
|
||||
calculator: "Lift2DFrameAnnotationTo3DCalculator"
|
||||
input_stream: "FRAME_ANNOTATION:box_annotation"
|
||||
output_stream: "LIFTED_FRAME_ANNOTATION:lifted_box"
|
||||
}
|
||||
|
||||
# Get rotated rectangle from lifted box.
|
||||
node {
|
||||
calculator: "FrameAnnotationToRectCalculator"
|
||||
input_stream: "FRAME_ANNOTATION:lifted_box"
|
||||
output_stream: "NORM_RECT:rect_from_box"
|
||||
}
|
||||
|
||||
# Expands the box rectangle so that in the next video frame it's likely to
|
||||
# still contain the box even with some motion.
|
||||
node {
|
||||
calculator: "RectTransformationCalculator"
|
||||
input_stream: "NORM_RECT:rect_from_box"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "box_rect_for_next_frame"
|
||||
options: {
|
||||
[mediapipe.RectTransformationCalculatorOptions.ext] {
|
||||
scale_x: 1.5
|
||||
scale_y: 1.5
|
||||
square_long: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Dense"
|
||||
#include "absl/memory/memory.h"
|
||||
|
@ -32,7 +33,7 @@ using Eigen::Vector3f;
|
|||
namespace {
|
||||
|
||||
constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION";
|
||||
constexpr char kOutputNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kOutputNormRectsTag[] = "NORM_RECTS";
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -47,14 +48,14 @@ class FrameAnnotationToRectCalculator : public CalculatorBase {
|
|||
TOP_VIEW_OFF,
|
||||
};
|
||||
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
void AnnotationToRect(const FrameAnnotation& annotation,
|
||||
NormalizedRect* rect);
|
||||
float RotationAngleFromAnnotation(const FrameAnnotation& annotation);
|
||||
void AddAnnotationToRect(const ObjectAnnotation& annotation,
|
||||
std::vector<NormalizedRect>* rect);
|
||||
float RotationAngleFromAnnotation(const ObjectAnnotation& annotation);
|
||||
|
||||
float RotationAngleFromPose(const Matrix3fRM& rotation,
|
||||
const Vector3f& translation, const Vector3f& vec);
|
||||
|
@ -64,17 +65,7 @@ class FrameAnnotationToRectCalculator : public CalculatorBase {
|
|||
};
|
||||
REGISTER_CALCULATOR(FrameAnnotationToRectCalculator);
|
||||
|
||||
::mediapipe::Status FrameAnnotationToRectCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
status_ = TOP_VIEW_OFF;
|
||||
const auto& options = cc->Options<FrameAnnotationToRectCalculatorOptions>();
|
||||
off_threshold_ = options.off_threshold();
|
||||
on_threshold_ = options.on_threshold();
|
||||
RET_CHECK(off_threshold_ <= on_threshold_);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status FrameAnnotationToRectCalculator::GetContract(
|
||||
mediapipe::Status FrameAnnotationToRectCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(!cc->Inputs().GetTags().empty());
|
||||
RET_CHECK(!cc->Outputs().GetTags().empty());
|
||||
|
@ -83,57 +74,69 @@ REGISTER_CALCULATOR(FrameAnnotationToRectCalculator);
|
|||
cc->Inputs().Tag(kInputFrameAnnotationTag).Set<FrameAnnotation>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag(kOutputNormRectTag)) {
|
||||
cc->Outputs().Tag(kOutputNormRectTag).Set<NormalizedRect>();
|
||||
if (cc->Outputs().HasTag(kOutputNormRectsTag)) {
|
||||
cc->Outputs().Tag(kOutputNormRectsTag).Set<std::vector<NormalizedRect>>();
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status FrameAnnotationToRectCalculator::Process(
|
||||
mediapipe::Status FrameAnnotationToRectCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
status_ = TOP_VIEW_OFF;
|
||||
const auto& options = cc->Options<FrameAnnotationToRectCalculatorOptions>();
|
||||
off_threshold_ = options.off_threshold();
|
||||
on_threshold_ = options.on_threshold();
|
||||
RET_CHECK(off_threshold_ <= on_threshold_);
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status FrameAnnotationToRectCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
if (cc->Inputs().Tag(kInputFrameAnnotationTag).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
auto output_rects = absl::make_unique<std::vector<NormalizedRect>>();
|
||||
const auto& frame_annotation =
|
||||
cc->Inputs().Tag(kInputFrameAnnotationTag).Get<FrameAnnotation>();
|
||||
for (const auto& object_annotation : frame_annotation.annotations()) {
|
||||
AddAnnotationToRect(object_annotation, output_rects.get());
|
||||
}
|
||||
auto output_rect = absl::make_unique<NormalizedRect>();
|
||||
AnnotationToRect(
|
||||
cc->Inputs().Tag(kInputFrameAnnotationTag).Get<FrameAnnotation>(),
|
||||
output_rect.get());
|
||||
|
||||
// Output
|
||||
// Output.
|
||||
cc->Outputs()
|
||||
.Tag(kOutputNormRectTag)
|
||||
.Add(output_rect.release(), cc->InputTimestamp());
|
||||
return ::mediapipe::OkStatus();
|
||||
.Tag(kOutputNormRectsTag)
|
||||
.Add(output_rects.release(), cc->InputTimestamp());
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
void FrameAnnotationToRectCalculator::AnnotationToRect(
|
||||
const FrameAnnotation& annotation, NormalizedRect* rect) {
|
||||
void FrameAnnotationToRectCalculator::AddAnnotationToRect(
|
||||
const ObjectAnnotation& annotation, std::vector<NormalizedRect>* rects) {
|
||||
float x_min = std::numeric_limits<float>::max();
|
||||
float x_max = std::numeric_limits<float>::min();
|
||||
float y_min = std::numeric_limits<float>::max();
|
||||
float y_max = std::numeric_limits<float>::min();
|
||||
const auto& object = annotation.annotations(0);
|
||||
for (const auto& keypoint : object.keypoints()) {
|
||||
for (const auto& keypoint : annotation.keypoints()) {
|
||||
const auto& point_2d = keypoint.point_2d();
|
||||
x_min = std::min(x_min, point_2d.x());
|
||||
x_max = std::max(x_max, point_2d.x());
|
||||
y_min = std::min(y_min, point_2d.y());
|
||||
y_max = std::max(y_max, point_2d.y());
|
||||
}
|
||||
rect->set_x_center((x_min + x_max) / 2);
|
||||
rect->set_y_center((y_min + y_max) / 2);
|
||||
rect->set_width(x_max - x_min);
|
||||
rect->set_height(y_max - y_min);
|
||||
rect->set_rotation(RotationAngleFromAnnotation(annotation));
|
||||
NormalizedRect new_rect;
|
||||
new_rect.set_x_center((x_min + x_max) / 2);
|
||||
new_rect.set_y_center((y_min + y_max) / 2);
|
||||
new_rect.set_width(x_max - x_min);
|
||||
new_rect.set_height(y_max - y_min);
|
||||
new_rect.set_rotation(RotationAngleFromAnnotation(annotation));
|
||||
rects->push_back(new_rect);
|
||||
}
|
||||
|
||||
float FrameAnnotationToRectCalculator::RotationAngleFromAnnotation(
|
||||
const FrameAnnotation& annotation) {
|
||||
const auto& object = annotation.annotations(0);
|
||||
const ObjectAnnotation& annotation) {
|
||||
Box box("category");
|
||||
std::vector<Vector3f> vertices_3d;
|
||||
std::vector<Vector2f> vertices_2d;
|
||||
for (const auto& keypoint : object.keypoints()) {
|
||||
for (const auto& keypoint : annotation.keypoints()) {
|
||||
const auto& point_3d = keypoint.point_3d();
|
||||
const auto& point_2d = keypoint.point_2d();
|
||||
vertices_3d.emplace_back(
|
||||
|
|
|
@ -23,6 +23,7 @@ namespace mediapipe {
|
|||
namespace {
|
||||
|
||||
constexpr char kInputLandmarksTag[] = "LANDMARKS";
|
||||
constexpr char kInputMultiLandmarksTag[] = "MULTI_LANDMARKS";
|
||||
constexpr char kOutputFrameAnnotationTag[] = "FRAME_ANNOTATION";
|
||||
|
||||
} // namespace
|
||||
|
@ -30,12 +31,17 @@ constexpr char kOutputFrameAnnotationTag[] = "FRAME_ANNOTATION";
|
|||
// A calculator that converts NormalizedLandmarkList to FrameAnnotation proto.
|
||||
class LandmarksToFrameAnnotationCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
void AddLandmarksToFrameAnnotation(const NormalizedLandmarkList& landmarks,
|
||||
FrameAnnotation* frame_annotation);
|
||||
};
|
||||
REGISTER_CALCULATOR(LandmarksToFrameAnnotationCalculator);
|
||||
|
||||
::mediapipe::Status LandmarksToFrameAnnotationCalculator::GetContract(
|
||||
mediapipe::Status LandmarksToFrameAnnotationCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(!cc->Inputs().GetTags().empty());
|
||||
RET_CHECK(!cc->Outputs().GetTags().empty());
|
||||
|
@ -43,34 +49,65 @@ REGISTER_CALCULATOR(LandmarksToFrameAnnotationCalculator);
|
|||
if (cc->Inputs().HasTag(kInputLandmarksTag)) {
|
||||
cc->Inputs().Tag(kInputLandmarksTag).Set<NormalizedLandmarkList>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag(kInputMultiLandmarksTag)) {
|
||||
cc->Inputs()
|
||||
.Tag(kInputMultiLandmarksTag)
|
||||
.Set<std::vector<NormalizedLandmarkList>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kOutputFrameAnnotationTag)) {
|
||||
cc->Outputs().Tag(kOutputFrameAnnotationTag).Set<FrameAnnotation>();
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status LandmarksToFrameAnnotationCalculator::Process(
|
||||
mediapipe::Status LandmarksToFrameAnnotationCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status LandmarksToFrameAnnotationCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
auto frame_annotation = absl::make_unique<FrameAnnotation>();
|
||||
auto* box_annotation = frame_annotation->add_annotations();
|
||||
|
||||
const auto& landmarks =
|
||||
cc->Inputs().Tag(kInputLandmarksTag).Get<NormalizedLandmarkList>();
|
||||
RET_CHECK_GT(landmarks.landmark_size(), 0)
|
||||
<< "Input landmark vector is empty.";
|
||||
for (int i = 0; i < landmarks.landmark_size(); ++i) {
|
||||
auto* point2d = box_annotation->add_keypoints()->mutable_point_2d();
|
||||
point2d->set_x(landmarks.landmark(i).x());
|
||||
point2d->set_y(landmarks.landmark(i).y());
|
||||
// Handle the case when input has only one NormalizedLandmarkList.
|
||||
if (cc->Inputs().HasTag(kInputLandmarksTag) &&
|
||||
!cc->Inputs().Tag(kInputLandmarksTag).IsEmpty()) {
|
||||
const auto& landmarks =
|
||||
cc->Inputs().Tag(kInputMultiLandmarksTag).Get<NormalizedLandmarkList>();
|
||||
AddLandmarksToFrameAnnotation(landmarks, frame_annotation.get());
|
||||
}
|
||||
|
||||
// Handle the case when input has muliple NormalizedLandmarkList.
|
||||
if (cc->Inputs().HasTag(kInputMultiLandmarksTag) &&
|
||||
!cc->Inputs().Tag(kInputMultiLandmarksTag).IsEmpty()) {
|
||||
const auto& landmarks_list =
|
||||
cc->Inputs()
|
||||
.Tag(kInputMultiLandmarksTag)
|
||||
.Get<std::vector<NormalizedLandmarkList>>();
|
||||
for (const auto& landmarks : landmarks_list) {
|
||||
AddLandmarksToFrameAnnotation(landmarks, frame_annotation.get());
|
||||
}
|
||||
}
|
||||
|
||||
// Output
|
||||
if (cc->Outputs().HasTag(kOutputFrameAnnotationTag)) {
|
||||
cc->Outputs()
|
||||
.Tag(kOutputFrameAnnotationTag)
|
||||
.Add(frame_annotation.release(), cc->InputTimestamp());
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
void LandmarksToFrameAnnotationCalculator::AddLandmarksToFrameAnnotation(
|
||||
const NormalizedLandmarkList& landmarks,
|
||||
FrameAnnotation* frame_annotation) {
|
||||
auto* new_annotation = frame_annotation->add_annotations();
|
||||
for (const auto& landmark : landmarks.landmark()) {
|
||||
auto* point2d = new_annotation->add_keypoints()->mutable_point_2d();
|
||||
point2d->set_x(landmark.x());
|
||||
point2d->set_y(landmark.y());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -55,16 +55,16 @@ namespace mediapipe {
|
|||
// }
|
||||
class Lift2DFrameAnnotationTo3DCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
static mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Close(CalculatorContext* cc) override;
|
||||
mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
mediapipe::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
::mediapipe::Status ProcessCPU(CalculatorContext* cc,
|
||||
FrameAnnotation* output_objects);
|
||||
::mediapipe::Status LoadOptions(CalculatorContext* cc);
|
||||
mediapipe::Status ProcessCPU(CalculatorContext* cc,
|
||||
FrameAnnotation* output_objects);
|
||||
mediapipe::Status LoadOptions(CalculatorContext* cc);
|
||||
|
||||
// Increment and assign object ID for each detected object.
|
||||
// In a single MediaPipe session, the IDs are unique.
|
||||
|
@ -73,23 +73,24 @@ class Lift2DFrameAnnotationTo3DCalculator : public CalculatorBase {
|
|||
void AssignObjectIdAndTimestamp(int64 timestamp_us,
|
||||
FrameAnnotation* annotation);
|
||||
std::unique_ptr<Decoder> decoder_;
|
||||
::mediapipe::Lift2DFrameAnnotationTo3DCalculatorOptions options_;
|
||||
Lift2DFrameAnnotationTo3DCalculatorOptions options_;
|
||||
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> projection_matrix_;
|
||||
};
|
||||
REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator);
|
||||
|
||||
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::GetContract(
|
||||
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kInputStreamTag));
|
||||
RET_CHECK(cc->Outputs().HasTag(kOutputStreamTag));
|
||||
cc->Inputs().Tag(kInputStreamTag).Set<FrameAnnotation>();
|
||||
cc->Outputs().Tag(kOutputStreamTag).Set<FrameAnnotation>();
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Open(
|
||||
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||
// clang-format off
|
||||
projection_matrix_ <<
|
||||
|
@ -101,13 +102,13 @@ REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator);
|
|||
|
||||
decoder_ = absl::make_unique<Decoder>(
|
||||
BeliefDecoderConfig(options_.decoder_config()));
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Process(
|
||||
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
if (cc->Inputs().Tag(kInputStreamTag).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
auto output_objects = absl::make_unique<FrameAnnotation>();
|
||||
|
@ -121,10 +122,10 @@ REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator);
|
|||
.Add(output_objects.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::ProcessCPU(
|
||||
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::ProcessCPU(
|
||||
CalculatorContext* cc, FrameAnnotation* output_objects) {
|
||||
const auto& input_frame_annotations =
|
||||
cc->Inputs().Tag(kInputStreamTag).Get<FrameAnnotation>();
|
||||
|
@ -140,21 +141,20 @@ REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator);
|
|||
AssignObjectIdAndTimestamp(cc->InputTimestamp().Microseconds(),
|
||||
output_objects);
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Close(
|
||||
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Close(
|
||||
CalculatorContext* cc) {
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::LoadOptions(
|
||||
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::LoadOptions(
|
||||
CalculatorContext* cc) {
|
||||
// Get calculator options specified in the graph.
|
||||
options_ =
|
||||
cc->Options<::mediapipe::Lift2DFrameAnnotationTo3DCalculatorOptions>();
|
||||
options_ = cc->Options<Lift2DFrameAnnotationTo3DCalculatorOptions>();
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
void Lift2DFrameAnnotationTo3DCalculator::AssignObjectIdAndTimestamp(
|
||||
|
|
|
@ -4,9 +4,9 @@ type: "ObjectDetectionOidV4Subgraph"
|
|||
|
||||
input_stream: "IMAGE:input_video"
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
output_stream: "NORM_RECT:box_rect_from_object_detections"
|
||||
output_stream: "DETECTIONS:detections"
|
||||
|
||||
# Transforms the input image on GPU to a 300x300 image. To scale the image, by
|
||||
# Transforms the input image on CPU to a 300x300 image. To scale the image, by
|
||||
# default it uses the STRETCH scale mode that maps the entire input image to the
|
||||
# entire transformed image. As a result, image aspect ratio may be changed and
|
||||
# objects in the image may be deformed (stretched or squeezed), but the object
|
||||
|
@ -23,7 +23,7 @@ node: {
|
|||
}
|
||||
}
|
||||
|
||||
# Converts the transformed input image on GPU into an image tensor stored as a
|
||||
# Converts the transformed input image on CPU into an image tensor stored as a
|
||||
# TfLiteTensor.
|
||||
node {
|
||||
calculator: "TfLiteConverterCalculator"
|
||||
|
@ -31,7 +31,7 @@ node {
|
|||
output_stream: "TENSORS:image_tensor"
|
||||
}
|
||||
|
||||
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
|
||||
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
|
||||
# vector of tensors representing, for instance, detection boxes/keypoints and
|
||||
# scores.
|
||||
node {
|
||||
|
@ -82,7 +82,7 @@ node {
|
|||
calculator: "TfLiteTensorsToDetectionsCalculator"
|
||||
input_stream: "TENSORS:detection_tensors"
|
||||
input_side_packet: "ANCHORS:anchors"
|
||||
output_stream: "DETECTIONS:detections"
|
||||
output_stream: "DETECTIONS:all_detections"
|
||||
options: {
|
||||
[mediapipe.TfLiteTensorsToDetectionsCalculatorOptions.ext] {
|
||||
num_classes: 195
|
||||
|
@ -95,22 +95,7 @@ node {
|
|||
y_scale: 10.0
|
||||
h_scale: 5.0
|
||||
w_scale: 5.0
|
||||
min_score_thresh: 0.6
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Performs non-max suppression to remove excessive detections.
|
||||
node {
|
||||
calculator: "NonMaxSuppressionCalculator"
|
||||
input_stream: "detections"
|
||||
output_stream: "suppressed_detections"
|
||||
options: {
|
||||
[mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
|
||||
min_suppression_threshold: 0.4
|
||||
max_num_detections: 1
|
||||
overlap_type: INTERSECTION_OVER_UNION
|
||||
return_empty_detections: true
|
||||
min_score_thresh: 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -119,7 +104,7 @@ node {
|
|||
# provided in the label_map_path option.
|
||||
node {
|
||||
calculator: "DetectionLabelIdToTextCalculator"
|
||||
input_stream: "suppressed_detections"
|
||||
input_stream: "all_detections"
|
||||
output_stream: "labeled_detections"
|
||||
options: {
|
||||
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
||||
|
@ -128,50 +113,26 @@ node {
|
|||
}
|
||||
}
|
||||
|
||||
# Filters the detections to only those with valid scores
|
||||
# for the specified allowed labels.
|
||||
node {
|
||||
calculator: "FilterDetectionCalculator"
|
||||
input_stream: "DETECTIONS:labeled_detections"
|
||||
output_stream: "DETECTIONS:filtered_detections"
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
}
|
||||
|
||||
# Performs non-max suppression to remove excessive detections.
|
||||
node {
|
||||
calculator: "NonMaxSuppressionCalculator"
|
||||
input_stream: "filtered_detections"
|
||||
output_stream: "detections"
|
||||
options: {
|
||||
[mediapipe.FilterDetectionCalculatorOptions.ext]: {
|
||||
min_score: 0.4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Extracts image size from the input images.
|
||||
node {
|
||||
calculator: "ImagePropertiesCalculator"
|
||||
input_stream: "IMAGE:input_video"
|
||||
output_stream: "SIZE:image_size"
|
||||
}
|
||||
|
||||
# Converts results of box detection into a rectangle (normalized by image size)
|
||||
# that encloses the box.
|
||||
node {
|
||||
calculator: "DetectionsToRectsCalculator"
|
||||
input_stream: "DETECTIONS:filtered_detections"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "NORM_RECT:box_rect"
|
||||
options: {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
output_zero_rect_for_empty_detections: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Expands the rectangle that contains the box so that it's likely to cover the
|
||||
# entire box.
|
||||
node {
|
||||
calculator: "RectTransformationCalculator"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "box_rect_from_object_detections"
|
||||
options: {
|
||||
[mediapipe.RectTransformationCalculatorOptions.ext] {
|
||||
scale_x: 1.5
|
||||
scale_y: 1.5
|
||||
[mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
|
||||
min_suppression_threshold: 0.5
|
||||
max_num_detections: 100
|
||||
overlap_type: INTERSECTION_OVER_UNION
|
||||
return_empty_detections: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ type: "ObjectDetectionOidV4Subgraph"
|
|||
|
||||
input_stream: "IMAGE_GPU:input_video"
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
output_stream: "NORM_RECT:box_rect_from_object_detections"
|
||||
output_stream: "DETECTIONS:detections"
|
||||
|
||||
# Transforms the input image on GPU to a 300x300 image. To scale the image, by
|
||||
# default it uses the STRETCH scale mode that maps the entire input image to the
|
||||
|
@ -40,7 +40,7 @@ node {
|
|||
output_stream: "TENSORS_GPU:detection_tensors"
|
||||
options: {
|
||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
model_path: "mediapipe/models/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite"
|
||||
model_path: "object_detection_ssd_mobilenetv2_oidv4_fp16.tflite"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -82,7 +82,7 @@ node {
|
|||
calculator: "TfLiteTensorsToDetectionsCalculator"
|
||||
input_stream: "TENSORS_GPU:detection_tensors"
|
||||
input_side_packet: "ANCHORS:anchors"
|
||||
output_stream: "DETECTIONS:detections"
|
||||
output_stream: "DETECTIONS:all_detections"
|
||||
options: {
|
||||
[mediapipe.TfLiteTensorsToDetectionsCalculatorOptions.ext] {
|
||||
num_classes: 195
|
||||
|
@ -95,22 +95,7 @@ node {
|
|||
y_scale: 10.0
|
||||
h_scale: 5.0
|
||||
w_scale: 5.0
|
||||
min_score_thresh: 0.6
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Performs non-max suppression to remove excessive detections.
|
||||
node {
|
||||
calculator: "NonMaxSuppressionCalculator"
|
||||
input_stream: "detections"
|
||||
output_stream: "suppressed_detections"
|
||||
options: {
|
||||
[mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
|
||||
min_suppression_threshold: 0.4
|
||||
max_num_detections: 1
|
||||
overlap_type: INTERSECTION_OVER_UNION
|
||||
return_empty_detections: true
|
||||
min_score_thresh: 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -119,59 +104,35 @@ node {
|
|||
# provided in the label_map_path option.
|
||||
node {
|
||||
calculator: "DetectionLabelIdToTextCalculator"
|
||||
input_stream: "suppressed_detections"
|
||||
input_stream: "all_detections"
|
||||
output_stream: "labeled_detections"
|
||||
options: {
|
||||
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
||||
label_map_path: "mediapipe/models/object_detection_oidv4_labelmap.pbtxt"
|
||||
label_map_path: "object_detection_oidv4_labelmap.pbtxt"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Filters the detections to only those with valid scores
|
||||
# for the specified allowed labels.
|
||||
node {
|
||||
calculator: "FilterDetectionCalculator"
|
||||
input_stream: "DETECTIONS:labeled_detections"
|
||||
output_stream: "DETECTIONS:filtered_detections"
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
}
|
||||
|
||||
# Performs non-max suppression to remove excessive detections.
|
||||
node {
|
||||
calculator: "NonMaxSuppressionCalculator"
|
||||
input_stream: "filtered_detections"
|
||||
output_stream: "detections"
|
||||
options: {
|
||||
[mediapipe.FilterDetectionCalculatorOptions.ext]: {
|
||||
min_score: 0.4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Extracts image size from the input images.
|
||||
node {
|
||||
calculator: "ImagePropertiesCalculator"
|
||||
input_stream: "IMAGE_GPU:input_video"
|
||||
output_stream: "SIZE:image_size"
|
||||
}
|
||||
|
||||
# Converts results of box detection into a rectangle (normalized by image size)
|
||||
# that encloses the box.
|
||||
node {
|
||||
calculator: "DetectionsToRectsCalculator"
|
||||
input_stream: "DETECTIONS:filtered_detections"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "NORM_RECT:box_rect"
|
||||
options: {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
output_zero_rect_for_empty_detections: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Expands the rectangle that contains the box so that it's likely to cover the
|
||||
# entire box.
|
||||
node {
|
||||
calculator: "RectTransformationCalculator"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "box_rect_from_object_detections"
|
||||
options: {
|
||||
[mediapipe.RectTransformationCalculatorOptions.ext] {
|
||||
scale_x: 1.5
|
||||
scale_y: 1.5
|
||||
[mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
|
||||
min_suppression_threshold: 0.5
|
||||
max_num_detections: 100
|
||||
overlap_type: INTERSECTION_OVER_UNION
|
||||
return_empty_detections: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ input_stream: "IMAGE:input_video"
|
|||
input_side_packet: "MODEL:box_landmark_model"
|
||||
# Allowed category labels, e.g. Footwear, Coffee cup, Mug, Chair, Camera
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
# Max number of objects to detect/track. (int)
|
||||
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
|
||||
# Bounding box landmarks topology definition.
|
||||
# The numbers are indices in the box_landmarks list.
|
||||
#
|
||||
|
@ -22,36 +24,47 @@ input_side_packet: "LABELS_CSV:allowed_labels"
|
|||
# \+ \+
|
||||
# 2 + + + + + + + + 6
|
||||
#
|
||||
output_stream: "LANDMARKS:box_landmarks"
|
||||
# Crop rectangle derived from bounding box landmarks.
|
||||
output_stream: "NORM_RECT:box_rect"
|
||||
output_stream: "MULTI_LANDMARKS:multi_box_landmarks"
|
||||
# Crop rectangles derived from bounding box landmarks.
|
||||
output_stream: "NORM_RECTS:multi_box_rects"
|
||||
|
||||
|
||||
# Caches a box-presence decision fed back from boxLandmarkSubgraph, and upon
|
||||
# the arrival of the next input image sends out the cached decision with the
|
||||
# timestamp replaced by that of the input image, essentially generating a packet
|
||||
# that carries the previous box-presence decision. Note that upon the arrival
|
||||
# of the very first input image, an empty packet is sent out to jump start the
|
||||
# feedback loop.
|
||||
# Defines whether landmarks from the previous video frame should be used to help
|
||||
# predict landmarks on the current video frame.
|
||||
node {
|
||||
calculator: "PreviousLoopbackCalculator"
|
||||
input_stream: "MAIN:input_video"
|
||||
input_stream: "LOOP:box_presence"
|
||||
input_stream_info: {
|
||||
tag_index: "LOOP"
|
||||
back_edge: true
|
||||
name: "ConstantSidePacketCalculator"
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:use_prev_landmarks"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet { bool_value: true }
|
||||
}
|
||||
}
|
||||
output_stream: "PREV_LOOP:prev_box_presence"
|
||||
}
|
||||
|
||||
# Drops the incoming image if boxLandmarkSubgraph was able to identify box
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_side_packet: "ALLOW:use_prev_landmarks"
|
||||
input_stream: "prev_box_rects_from_landmarks"
|
||||
output_stream: "gated_prev_box_rects_from_landmarks"
|
||||
}
|
||||
|
||||
# Determines if an input vector of NormalizedRect has a size greater than or
|
||||
# equal to the provided max_num_objects.
|
||||
node {
|
||||
calculator: "NormalizedRectVectorHasMinSizeCalculator"
|
||||
input_stream: "ITERABLE:gated_prev_box_rects_from_landmarks"
|
||||
input_side_packet: "max_num_objects"
|
||||
output_stream: "prev_has_enough_objects"
|
||||
}
|
||||
|
||||
# Drops the incoming image if BoxLandmarkSubgraph was able to identify box
|
||||
# presence in the previous image. Otherwise, passes the incoming image through
|
||||
# to trigger a new round of box detection in boxDetectionSubgraph.
|
||||
# to trigger a new round of box detection in ObjectDetectionOidV4Subgraph.
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "input_video"
|
||||
input_stream: "DISALLOW:prev_box_presence"
|
||||
output_stream: "box_detection_input_video"
|
||||
input_stream: "DISALLOW:prev_has_enough_objects"
|
||||
output_stream: "detection_input_video"
|
||||
|
||||
options: {
|
||||
[mediapipe.GateCalculatorOptions.ext] {
|
||||
|
@ -60,23 +73,112 @@ node {
|
|||
}
|
||||
}
|
||||
|
||||
# Subgraph that detections boxs (see object_detection_oid_v4_cpu.pbtxt).
|
||||
# Subgraph that performs 2D object detection.
|
||||
node {
|
||||
calculator: "ObjectDetectionOidV4Subgraph"
|
||||
input_stream: "IMAGE:box_detection_input_video"
|
||||
input_stream: "IMAGE:detection_input_video"
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
output_stream: "NORM_RECT:box_rect_from_object_detections"
|
||||
output_stream: "DETECTIONS:raw_detections"
|
||||
}
|
||||
|
||||
# Subgraph that localizes box landmarks (see box_landmark_gpu.pbtxt).
|
||||
# Makes sure there are no more detections than provided max_num_objects.
|
||||
node {
|
||||
calculator: "ClipDetectionVectorSizeCalculator"
|
||||
input_stream: "raw_detections"
|
||||
output_stream: "detections"
|
||||
input_side_packet: "max_num_objects"
|
||||
|
||||
}
|
||||
|
||||
# Extracts image size from the input images.
|
||||
node {
|
||||
calculator: "ImagePropertiesCalculator"
|
||||
input_stream: "IMAGE:input_video"
|
||||
output_stream: "SIZE:image_size"
|
||||
}
|
||||
|
||||
# Converts results of box detection into rectangles (normalized by image size)
|
||||
# that encloses the box.
|
||||
node {
|
||||
calculator: "DetectionsToRectsCalculator"
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "NORM_RECTS:box_rects_from_detections"
|
||||
options: {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
output_zero_rect_for_empty_detections: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Performs association between NormalizedRect vector elements from previous
|
||||
# image and rects based on object detections from the current image. This
|
||||
# calculator ensures that the output box_rects vector doesn't contain
|
||||
# overlapping regions based on the specified min_similarity_threshold.
|
||||
node {
|
||||
calculator: "AssociationNormRectCalculator"
|
||||
input_stream: "box_rects_from_detections"
|
||||
input_stream: "gated_prev_box_rects_from_landmarks"
|
||||
output_stream: "multi_box_rects"
|
||||
options: {
|
||||
[mediapipe.AssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Outputs each element of box_rects at a fake timestamp for the rest of the
|
||||
# graph to process. Clones image and image size packets for each
|
||||
# single_box_rect at the fake timestamp. At the end of the loop, outputs the
|
||||
# BATCH_END timestamp for downstream calculators to inform them that all
|
||||
# elements in the vector have been processed.
|
||||
node {
|
||||
calculator: "BeginLoopNormalizedRectCalculator"
|
||||
input_stream: "ITERABLE:multi_box_rects"
|
||||
input_stream: "CLONE:input_video"
|
||||
output_stream: "ITEM:single_box_rect"
|
||||
output_stream: "CLONE:landmarks_input_video"
|
||||
output_stream: "BATCH_END:box_rects_timestamp"
|
||||
}
|
||||
|
||||
# Subgraph that localizes box landmarks.
|
||||
node {
|
||||
calculator: "BoxLandmarkSubgraph"
|
||||
input_stream: "IMAGE:input_video"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
input_stream: "IMAGE:landmarks_input_video"
|
||||
input_side_packet: "MODEL:box_landmark_model"
|
||||
output_stream: "LANDMARKS:box_landmarks"
|
||||
output_stream: "NORM_RECT:box_rect_from_landmarks"
|
||||
output_stream: "PRESENCE:box_presence"
|
||||
input_stream: "NORM_RECT:single_box_rect"
|
||||
output_stream: "NORM_LANDMARKS:single_box_landmarks"
|
||||
}
|
||||
|
||||
# Collects a set of landmarks for each hand into a vector. Upon receiving the
|
||||
# BATCH_END timestamp, outputs the vector of landmarks at the BATCH_END
|
||||
# timestamp.
|
||||
node {
|
||||
calculator: "EndLoopNormalizedLandmarkListVectorCalculator"
|
||||
input_stream: "ITEM:single_box_landmarks"
|
||||
input_stream: "BATCH_END:box_rects_timestamp"
|
||||
output_stream: "ITERABLE:multi_box_landmarks"
|
||||
}
|
||||
|
||||
# Convert box landmarks to frame annotations.
|
||||
node {
|
||||
calculator: "LandmarksToFrameAnnotationCalculator"
|
||||
input_stream: "MULTI_LANDMARKS:multi_box_landmarks"
|
||||
output_stream: "FRAME_ANNOTATION:box_annotations"
|
||||
}
|
||||
|
||||
# Lift the 2D landmarks to 3D using EPnP algorithm.
|
||||
node {
|
||||
calculator: "Lift2DFrameAnnotationTo3DCalculator"
|
||||
input_stream: "FRAME_ANNOTATION:box_annotations"
|
||||
output_stream: "LIFTED_FRAME_ANNOTATION:lifted_objects"
|
||||
}
|
||||
|
||||
# Get rotated rectangle from lifted box.
|
||||
node {
|
||||
calculator: "FrameAnnotationToRectCalculator"
|
||||
input_stream: "FRAME_ANNOTATION:lifted_objects"
|
||||
output_stream: "NORM_RECTS:box_rects_from_landmarks"
|
||||
}
|
||||
|
||||
# Caches a box rectangle fed back from boxLandmarkSubgraph, and upon the
|
||||
|
@ -88,25 +190,10 @@ node {
|
|||
node {
|
||||
calculator: "PreviousLoopbackCalculator"
|
||||
input_stream: "MAIN:input_video"
|
||||
input_stream: "LOOP:box_rect_from_landmarks"
|
||||
input_stream: "LOOP:box_rects_from_landmarks"
|
||||
input_stream_info: {
|
||||
tag_index: "LOOP"
|
||||
back_edge: true
|
||||
}
|
||||
output_stream: "PREV_LOOP:prev_box_rect_from_landmarks"
|
||||
}
|
||||
|
||||
# Merges a stream of box rectangles generated by ObjectDetectionSubgraph and that
|
||||
# generated by BoxLandmarkSubgraph into a single output stream by selecting
|
||||
# between one of the two streams. The former is selected if the incoming packet
|
||||
# is not empty, i.e., box detection is performed on the current image by
|
||||
# BoxDetectionSubgraph (because BoxLandmarkSubgraph could not identify box
|
||||
# presence in the previous image). Otherwise, the latter is selected, which is
|
||||
# never empty because BoxLandmarkSubgraphs processes all images (that went
|
||||
# through FlowLimiterCaculator).
|
||||
node {
|
||||
calculator: "MergeCalculator"
|
||||
input_stream: "box_rect_from_object_detections"
|
||||
input_stream: "prev_box_rect_from_landmarks"
|
||||
output_stream: "box_rect"
|
||||
output_stream: "PREV_LOOP:prev_box_rects_from_landmarks"
|
||||
}
|
||||
|
|
|
@ -5,33 +5,48 @@
|
|||
input_stream: "IMAGE_GPU:input_video"
|
||||
# Allowed category labels, e.g. Footwear, Coffee cup, Mug, Chair, Camera
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
# Max number of objects to detect/track. (int)
|
||||
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
|
||||
|
||||
# Collection of detected 3D objects, represented as a FrameAnnotation.
|
||||
output_stream: "FRAME_ANNOTATION:lifted_objects"
|
||||
|
||||
|
||||
# Caches a box-presence decision fed back from boxLandmarkSubgraph, and upon
|
||||
# the arrival of the next input image sends out the cached decision with the
|
||||
# timestamp replaced by that of the input image, essentially generating a packet
|
||||
# that carries the previous box-presence decision. Note that upon the arrival
|
||||
# of the very first input image, an empty packet is sent out to jump start the
|
||||
# feedback loop.
|
||||
# Defines whether landmarks from the previous video frame should be used to help
|
||||
# predict landmarks on the current video frame.
|
||||
node {
|
||||
calculator: "PreviousLoopbackCalculator"
|
||||
input_stream: "MAIN:input_video"
|
||||
input_stream: "LOOP:box_presence"
|
||||
input_stream_info: {
|
||||
tag_index: "LOOP"
|
||||
back_edge: true
|
||||
name: "ConstantSidePacketCalculator"
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:use_prev_landmarks"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet { bool_value: true }
|
||||
}
|
||||
}
|
||||
output_stream: "PREV_LOOP:prev_box_presence"
|
||||
}
|
||||
|
||||
# Drops the incoming image if boxLandmarkSubgraph was able to identify box
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_side_packet: "ALLOW:use_prev_landmarks"
|
||||
input_stream: "prev_box_rects_from_landmarks"
|
||||
output_stream: "gated_prev_box_rects_from_landmarks"
|
||||
}
|
||||
|
||||
# Determines if an input vector of NormalizedRect has a size greater than or
|
||||
# equal to the provided max_num_objects.
|
||||
node {
|
||||
calculator: "NormalizedRectVectorHasMinSizeCalculator"
|
||||
input_stream: "ITERABLE:gated_prev_box_rects_from_landmarks"
|
||||
input_side_packet: "max_num_objects"
|
||||
output_stream: "prev_has_enough_objects"
|
||||
}
|
||||
|
||||
# Drops the incoming image if BoxLandmarkSubgraph was able to identify box
|
||||
# presence in the previous image. Otherwise, passes the incoming image through
|
||||
# to trigger a new round of box detection in boxDetectionSubgraph.
|
||||
# to trigger a new round of box detection in ObjectDetectionOidV4Subgraph.
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "input_video"
|
||||
input_stream: "DISALLOW:prev_box_presence"
|
||||
input_stream: "DISALLOW:prev_has_enough_objects"
|
||||
output_stream: "detection_input_video"
|
||||
|
||||
options: {
|
||||
|
@ -46,17 +61,106 @@ node {
|
|||
calculator: "ObjectDetectionOidV4Subgraph"
|
||||
input_stream: "IMAGE_GPU:detection_input_video"
|
||||
input_side_packet: "LABELS_CSV:allowed_labels"
|
||||
output_stream: "NORM_RECT:box_rect_from_object_detections"
|
||||
output_stream: "DETECTIONS:raw_detections"
|
||||
}
|
||||
|
||||
# Makes sure there are no more detections than provided max_num_objects.
|
||||
node {
|
||||
calculator: "ClipDetectionVectorSizeCalculator"
|
||||
input_stream: "raw_detections"
|
||||
output_stream: "detections"
|
||||
input_side_packet: "max_num_objects"
|
||||
|
||||
}
|
||||
|
||||
# Extracts image size from the input images.
|
||||
node {
|
||||
calculator: "ImagePropertiesCalculator"
|
||||
input_stream: "IMAGE_GPU:input_video"
|
||||
output_stream: "SIZE:image_size"
|
||||
}
|
||||
|
||||
# Converts results of box detection into rectangles (normalized by image size)
|
||||
# that encloses the box.
|
||||
node {
|
||||
calculator: "DetectionsToRectsCalculator"
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "NORM_RECTS:box_rects_from_detections"
|
||||
options: {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
output_zero_rect_for_empty_detections: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Performs association between NormalizedRect vector elements from previous
|
||||
# image and rects based on object detections from the current image. This
|
||||
# calculator ensures that the output box_rects vector doesn't contain
|
||||
# overlapping regions based on the specified min_similarity_threshold.
|
||||
node {
|
||||
calculator: "AssociationNormRectCalculator"
|
||||
input_stream: "box_rects_from_detections"
|
||||
input_stream: "gated_prev_box_rects_from_landmarks"
|
||||
output_stream: "box_rects"
|
||||
options: {
|
||||
[mediapipe.AssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Outputs each element of box_rects at a fake timestamp for the rest of the
|
||||
# graph to process. Clones image and image size packets for each
|
||||
# single_box_rect at the fake timestamp. At the end of the loop, outputs the
|
||||
# BATCH_END timestamp for downstream calculators to inform them that all
|
||||
# elements in the vector have been processed.
|
||||
node {
|
||||
calculator: "BeginLoopNormalizedRectCalculator"
|
||||
input_stream: "ITERABLE:box_rects"
|
||||
input_stream: "CLONE:input_video"
|
||||
output_stream: "ITEM:single_box_rect"
|
||||
output_stream: "CLONE:landmarks_input_video"
|
||||
output_stream: "BATCH_END:box_rects_timestamp"
|
||||
}
|
||||
|
||||
# Subgraph that localizes box landmarks.
|
||||
node {
|
||||
calculator: "BoxLandmarkSubgraph"
|
||||
input_stream: "IMAGE:input_video"
|
||||
input_stream: "NORM_RECT:box_rect"
|
||||
output_stream: "FRAME_ANNOTATION:lifted_objects"
|
||||
output_stream: "NORM_RECT:box_rect_from_landmarks"
|
||||
output_stream: "PRESENCE:box_presence"
|
||||
input_stream: "IMAGE:landmarks_input_video"
|
||||
input_stream: "NORM_RECT:single_box_rect"
|
||||
output_stream: "NORM_LANDMARKS:single_box_landmarks"
|
||||
}
|
||||
|
||||
# Collects a set of landmarks for each hand into a vector. Upon receiving the
|
||||
# BATCH_END timestamp, outputs the vector of landmarks at the BATCH_END
|
||||
# timestamp.
|
||||
node {
|
||||
calculator: "EndLoopNormalizedLandmarkListVectorCalculator"
|
||||
input_stream: "ITEM:single_box_landmarks"
|
||||
input_stream: "BATCH_END:box_rects_timestamp"
|
||||
output_stream: "ITERABLE:multi_box_landmarks"
|
||||
}
|
||||
|
||||
# Convert box landmarks to frame annotations.
|
||||
node {
|
||||
calculator: "LandmarksToFrameAnnotationCalculator"
|
||||
input_stream: "MULTI_LANDMARKS:multi_box_landmarks"
|
||||
output_stream: "FRAME_ANNOTATION:box_annotations"
|
||||
}
|
||||
|
||||
# Lift the 2D landmarks to 3D using EPnP algorithm.
|
||||
node {
|
||||
calculator: "Lift2DFrameAnnotationTo3DCalculator"
|
||||
input_stream: "FRAME_ANNOTATION:box_annotations"
|
||||
output_stream: "LIFTED_FRAME_ANNOTATION:lifted_objects"
|
||||
}
|
||||
|
||||
# Get rotated rectangle from lifted box.
|
||||
node {
|
||||
calculator: "FrameAnnotationToRectCalculator"
|
||||
input_stream: "FRAME_ANNOTATION:lifted_objects"
|
||||
output_stream: "NORM_RECTS:box_rects_from_landmarks"
|
||||
}
|
||||
|
||||
# Caches a box rectangle fed back from boxLandmarkSubgraph, and upon the
|
||||
|
@ -68,25 +172,10 @@ node {
|
|||
node {
|
||||
calculator: "PreviousLoopbackCalculator"
|
||||
input_stream: "MAIN:input_video"
|
||||
input_stream: "LOOP:box_rect_from_landmarks"
|
||||
input_stream: "LOOP:box_rects_from_landmarks"
|
||||
input_stream_info: {
|
||||
tag_index: "LOOP"
|
||||
back_edge: true
|
||||
}
|
||||
output_stream: "PREV_LOOP:prev_box_rect_from_landmarks"
|
||||
}
|
||||
|
||||
# Merges a stream of box rectangles generated by boxDetectionSubgraph and that
|
||||
# generated by boxLandmarkSubgraph into a single output stream by selecting
|
||||
# between one of the two streams. The former is selected if the incoming packet
|
||||
# is not empty, i.e., box detection is performed on the current image by
|
||||
# boxDetectionSubgraph (because boxLandmarkSubgraph could not identify box
|
||||
# presence in the previous image). Otherwise, the latter is selected, which is
|
||||
# never empty because boxLandmarkSubgraphs processes all images (that went
|
||||
# through FlowLimiterCaculator).
|
||||
node {
|
||||
calculator: "MergeCalculator"
|
||||
input_stream: "box_rect_from_object_detections"
|
||||
input_stream: "prev_box_rect_from_landmarks"
|
||||
output_stream: "box_rect"
|
||||
output_stream: "PREV_LOOP:prev_box_rects_from_landmarks"
|
||||
}
|
||||
|
|
|
@ -46,6 +46,11 @@ node {
|
|||
calculator: "LocalFileContentsCalculator"
|
||||
input_side_packet: "FILE_PATH:model_path"
|
||||
output_side_packet: "CONTENTS:model_blob"
|
||||
options: {
|
||||
[mediapipe.LocalFileContentsCalculatorOptions.ext]: {
|
||||
read_as_binary: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Converts the input blob into a TF Lite model.
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user