Project import generated by Copybara.

GitOrigin-RevId: ea8d45731f5a052f79745e35bfd8240d6ac568d2
This commit is contained in:
MediaPipe Team 2020-12-15 23:29:11 -05:00 committed by chuoling
parent 38be2ec58f
commit 39309bedba
109 changed files with 5803 additions and 1500 deletions

View File

@ -102,6 +102,8 @@ run code search using
## Publications
* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
in Google AI Blog
* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html)
in Google AI Blog
* [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html)

View File

@ -364,9 +364,9 @@ http_archive(
)
#Tensorflow repo should always go after the other external dependencies.
# 2020-12-07
_TENSORFLOW_GIT_COMMIT = "f556709f4df005ad57fd24d5eaa0d9380128d3ba"
_TENSORFLOW_SHA256= "9e157d4723921b48a974f645f70d07c8fd3c363569a0ac6ee85fec114d6459ea"
# 2020-12-09
_TENSORFLOW_GIT_COMMIT = "0eadbb13cef1226b1bae17c941f7870734d97f8a"
_TENSORFLOW_SHA256= "4ae06daa5b09c62f31b7bc1f781fd59053f286dd64355830d8c2ac601b795ef0"
http_archive(
name = "org_tensorflow",
urls = [

View File

@ -94,7 +94,8 @@ for app in ${apps}; do
else
graph_name="${target_name}/${target_name}"
fi
if [[ ${target_name} == "iris_tracking" ||
if [[ ${target_name} == "holistic_tracking" ||
${target_name} == "iris_tracking" ||
${target_name} == "pose_tracking" ||
${target_name} == "upper_body_pose_tracking" ]]; then
graph_suffix="cpu"

View File

@ -31,16 +31,16 @@ install --user six`.
[Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
to install Bazel 3.4 or higher.
For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu only, Bazel needs
For Nvidia Jetson and Raspberry Pi devices with aarch64 Linux, Bazel needs
to be built from source:
```bash
# For Bazel 3.4.0
mkdir $HOME/bazel-3.4.0
cd $HOME/bazel-3.4.0
wget https://github.com/bazelbuild/bazel/releases/download/3.4.0/bazel-3.4.0-dist.zip
# For Bazel 3.4.1
mkdir $HOME/bazel-3.4.1
cd $HOME/bazel-3.4.1
wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-dist.zip
sudo apt-get install build-essential openjdk-8-jdk python zip unzip
unzip bazel-3.4.0-dist.zip
unzip bazel-3.4.1-dist.zip
env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
sudo cp output/bazel /usr/local/bin/
```
@ -338,14 +338,7 @@ build issues.
2. Install Bazel.
Option 1. Use package manager tool to install Bazel
```bash
$ brew install bazel
# Run 'bazel version' to check version of bazel
```
Option 2. Follow the official
Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
to install Bazel 3.4 or higher.
@ -604,14 +597,14 @@ cameras. Alternatively, you use a video file as input.
```bash
username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \
https://storage.googleapis.com/bazel/3.4.0/release/bazel-3.4.0-installer-linux-x86_64.sh && \
sudo mkdir -p /usr/local/bazel/3.4.0 && \
chmod 755 bazel-3.4.0-installer-linux-x86_64.sh && \
sudo ./bazel-3.4.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.0 && \
source /usr/local/bazel/3.4.0/lib/bazel/bin/bazel-complete.bash
https://storage.googleapis.com/bazel/3.4.1/release/bazel-3.4.1-installer-linux-x86_64.sh && \
sudo mkdir -p /usr/local/bazel/3.4.1 && \
chmod 755 bazel-3.4.1-installer-linux-x86_64.sh && \
sudo ./bazel-3.4.1-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.1 && \
source /usr/local/bazel/3.4.1/lib/bazel/bin/bazel-complete.bash
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.0/lib/bazel/bin/bazel version && \
alias bazel='/usr/local/bazel/3.4.0/lib/bazel/bin/bazel'
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.1/lib/bazel/bin/bazel version && \
alias bazel='/usr/local/bazel/3.4.1/lib/bazel/bin/bazel'
```
6. Checkout MediaPipe repository.

View File

@ -70,6 +70,11 @@ Python package from source. Otherwise, we strongly encourage our users to simply
run `pip install mediapipe` to use the ready-to-use solutions, more convenient
and much faster.
MediaPipe PyPI currently doesn't provide aarch64 Python wheel
files. For building and using MediaPipe Python on aarch64 Linux systems such as
Nvidia Jetson and Raspberry Pi, please read
[here](https://github.com/jiuqiant/mediapipe-python-aarch64).
1. Make sure that Bazel and OpenCV are correctly installed and configured for
MediaPipe. Please see [Installation](./install.md) for how to setup Bazel
and OpenCV for MediaPipe on Linux and macOS.
@ -82,12 +87,18 @@ and much faster.
$ sudo apt install python3-dev
$ sudo apt install python3-venv
$ sudo apt install -y protobuf-compiler
# If you need to build opencv from source.
$ sudo apt install cmake
```
macOS:
```bash
$ brew install protobuf
# If you need to build opencv from source.
$ brew install cmake
```
Windows:
@ -118,3 +129,10 @@ and much faster.
(mp_env)mediapipe$ python3 setup.py gen_protos
(mp_env)mediapipe$ python3 setup.py install --link-opencv
```
or
```bash
(mp_env)mediapipe$ python3 setup.py gen_protos
(mp_env)mediapipe$ python3 setup.py bdist_wheel
```

View File

@ -102,6 +102,8 @@ run code search using
## Publications
* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
in Google AI Blog
* [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html)
in Google AI Blog
* [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html)

View File

@ -405,7 +405,7 @@ on how to build MediaPipe examples.
## Resources
* Google AI Blog:
[MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction on Device](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
[MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html)
* [Models and model cards](./models.md#holistic)
[Colab]:https://mediapipe.page.link/holistic_py_colab

View File

@ -2,33 +2,33 @@
"additionalFilePaths" : [
"/BUILD",
"mediapipe/BUILD",
"mediapipe/objc/BUILD",
"mediapipe/framework/BUILD",
"mediapipe/gpu/BUILD",
"mediapipe/objc/testing/app/BUILD",
"mediapipe/examples/ios/common/BUILD",
"mediapipe/examples/ios/helloworld/BUILD",
"mediapipe/examples/ios/facedetectioncpu/BUILD",
"mediapipe/examples/ios/facedetectiongpu/BUILD",
"mediapipe/examples/ios/faceeffect/BUILD",
"mediapipe/examples/ios/facemeshgpu/BUILD",
"mediapipe/examples/ios/handdetectiongpu/BUILD",
"mediapipe/examples/ios/handtrackinggpu/BUILD",
"mediapipe/examples/ios/helloworld/BUILD",
"mediapipe/examples/ios/holistictrackinggpu/BUILD",
"mediapipe/examples/ios/iristrackinggpu/BUILD",
"mediapipe/examples/ios/objectdetectioncpu/BUILD",
"mediapipe/examples/ios/objectdetectiongpu/BUILD",
"mediapipe/examples/ios/posetrackinggpu/BUILD",
"mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD"
"mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD",
"mediapipe/framework/BUILD",
"mediapipe/gpu/BUILD",
"mediapipe/objc/BUILD",
"mediapipe/objc/testing/app/BUILD"
],
"buildTargets" : [
"//mediapipe/examples/ios/helloworld:HelloWorldApp",
"//mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp",
"//mediapipe/examples/ios/facedetectiongpu:FaceDetectionGpuApp",
"//mediapipe/examples/ios/faceeffect:FaceEffectApp",
"//mediapipe/examples/ios/facemeshgpu:FaceMeshGpuApp",
"//mediapipe/examples/ios/handdetectiongpu:HandDetectionGpuApp",
"//mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp",
"//mediapipe/examples/ios/helloworld:HelloWorldApp",
"//mediapipe/examples/ios/holistictrackinggpu:HolisticTrackingGpuApp",
"//mediapipe/examples/ios/iristrackinggpu:IrisTrackingGpuApp",
"//mediapipe/examples/ios/objectdetectioncpu:ObjectDetectionCpuApp",
@ -91,13 +91,13 @@
"mediapipe/examples/ios",
"mediapipe/examples/ios/common",
"mediapipe/examples/ios/common/Base.lproj",
"mediapipe/examples/ios/helloworld",
"mediapipe/examples/ios/facedetectioncpu",
"mediapipe/examples/ios/facedetectiongpu",
"mediapipe/examples/ios/faceeffect",
"mediapipe/examples/ios/faceeffect/Base.lproj",
"mediapipe/examples/ios/handdetectiongpu",
"mediapipe/examples/ios/handtrackinggpu",
"mediapipe/examples/ios/helloworld",
"mediapipe/examples/ios/holistictrackinggpu",
"mediapipe/examples/ios/iristrackinggpu",
"mediapipe/examples/ios/objectdetectioncpu",

View File

@ -9,7 +9,6 @@
"packages" : [
"",
"mediapipe",
"mediapipe/objc",
"mediapipe/examples/ios",
"mediapipe/examples/ios/facedetectioncpu",
"mediapipe/examples/ios/facedetectiongpu",
@ -22,7 +21,8 @@
"mediapipe/examples/ios/objectdetectioncpu",
"mediapipe/examples/ios/objectdetectiongpu",
"mediapipe/examples/ios/posetrackinggpu",
"mediapipe/examples/ios/upperbodyposetrackinggpu"
"mediapipe/examples/ios/upperbodyposetrackinggpu",
"mediapipe/objc"
],
"projectName" : "Mediapipe",
"workspaceRoot" : "../.."

View File

@ -146,6 +146,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
],
@ -286,6 +287,7 @@ cc_library(
deps = [
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -393,6 +395,7 @@ cc_library(
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
@ -406,7 +409,7 @@ cc_library(
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status",
"@eigen_archive//:eigen",
@ -422,7 +425,7 @@ cc_library(
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status",
"@eigen_archive//:eigen",
@ -438,6 +441,7 @@ cc_library(
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:mux_input_stream_handler",
],
@ -606,6 +610,7 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
@ -958,6 +963,7 @@ cc_library(
deps = [
":sequence_shift_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
@ -1007,6 +1013,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
@ -1042,6 +1049,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],

View File

@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
namespace mediapipe {
namespace api2 {
// Attach the header from a stream or side input to another stream.
//
@ -42,49 +44,40 @@ namespace mediapipe {
// output_stream: "audio_with_header"
// }
//
class AddHeaderCalculator : public CalculatorBase {
class AddHeaderCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc) {
bool has_side_input = false;
bool has_header_stream = false;
if (cc->InputSidePackets().HasTag("HEADER")) {
cc->InputSidePackets().Tag("HEADER").SetAny();
has_side_input = true;
}
if (cc->Inputs().HasTag("HEADER")) {
cc->Inputs().Tag("HEADER").SetNone();
has_header_stream = true;
}
if (has_side_input == has_header_stream) {
static constexpr Input<NoneType>::Optional kHeader{"HEADER"};
static constexpr SideInput<AnyType>::Optional kHeaderSide{"HEADER"};
static constexpr Input<AnyType> kData{"DATA"};
static constexpr Output<SameType<kData>> kOut{""};
MEDIAPIPE_NODE_CONTRACT(kHeader, kHeaderSide, kData, kOut);
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
if (kHeader(cc).IsConnected() == kHeaderSide(cc).IsConnected()) {
return mediapipe::InvalidArgumentError(
"Header must be provided via exactly one of side input and input "
"stream");
}
cc->Inputs().Tag("DATA").SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Tag("DATA"));
return mediapipe::OkStatus();
}
mediapipe::Status Open(CalculatorContext* cc) override {
Packet header;
if (cc->InputSidePackets().HasTag("HEADER")) {
header = cc->InputSidePackets().Tag("HEADER");
}
if (cc->Inputs().HasTag("HEADER")) {
header = cc->Inputs().Tag("HEADER").Header();
}
const PacketBase& header =
kHeader(cc).IsConnected() ? kHeader(cc).Header() : kHeaderSide(cc);
if (!header.IsEmpty()) {
cc->Outputs().Index(0).SetHeader(header);
kOut(cc).SetHeader(header);
}
cc->SetOffset(TimestampDiff(0));
return mediapipe::OkStatus();
}
mediapipe::Status Process(CalculatorContext* cc) override {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Tag("DATA").Value());
kOut(cc).Send(kData(cc).packet());
return mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(AddHeaderCalculator);
MEDIAPIPE_REGISTER_NODE(AddHeaderCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -16,6 +16,7 @@
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
@ -23,27 +24,24 @@
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Concatenates several NormalizedLandmarkList protos following stream index
// order. This class assumes that every input stream contains a
// NormalizedLandmarkList proto object.
class ConcatenateNormalizedLandmarkListCalculator : public CalculatorBase {
class ConcatenateNormalizedLandmarkListCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().NumEntries() != 0);
RET_CHECK(cc->Outputs().NumEntries() == 1);
static constexpr Input<NormalizedLandmarkList>::Multiple kIn{""};
static constexpr Output<NormalizedLandmarkList> kOut{""};
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).Set<NormalizedLandmarkList>();
}
cc->Outputs().Index(0).Set<NormalizedLandmarkList>();
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
RET_CHECK_GE(kIn(cc).Count(), 1);
return mediapipe::OkStatus();
}
mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
only_emit_if_all_present_ =
cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>()
.only_emit_if_all_present();
@ -52,32 +50,29 @@ class ConcatenateNormalizedLandmarkListCalculator : public CalculatorBase {
mediapipe::Status Process(CalculatorContext* cc) override {
if (only_emit_if_all_present_) {
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
if (cc->Inputs().Index(i).IsEmpty()) return mediapipe::OkStatus();
for (int i = 0; i < kIn(cc).Count(); ++i) {
if (kIn(cc)[i].IsEmpty()) return mediapipe::OkStatus();
}
}
NormalizedLandmarkList output;
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
if (cc->Inputs().Index(i).IsEmpty()) continue;
const NormalizedLandmarkList& input =
cc->Inputs().Index(i).Get<NormalizedLandmarkList>();
for (int i = 0; i < kIn(cc).Count(); ++i) {
if (kIn(cc)[i].IsEmpty()) continue;
const NormalizedLandmarkList& input = *kIn(cc)[i];
for (int j = 0; j < input.landmark_size(); ++j) {
const NormalizedLandmark& input_landmark = input.landmark(j);
*output.add_landmark() = input_landmark;
*output.add_landmark() = input.landmark(j);
}
}
cc->Outputs().Index(0).AddPacket(
MakePacket<NormalizedLandmarkList>(output).At(cc->InputTimestamp()));
kOut(cc).Send(std::move(output));
return mediapipe::OkStatus();
}
private:
bool only_emit_if_all_present_;
};
MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListCalculator);
REGISTER_CALCULATOR(ConcatenateNormalizedLandmarkListCalculator);
} // namespace api2
} // namespace mediapipe
// NOLINTNEXTLINE

View File

@ -15,10 +15,12 @@
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Given two input streams (A, B), output a single stream containing a pair<A,
// B>.
@ -30,32 +32,27 @@ namespace mediapipe {
// input_stream: "packet_b"
// output_stream: "output_pair_a_b"
// }
class MakePairCalculator : public CalculatorBase {
class MakePairCalculator : public Node {
public:
MakePairCalculator() {}
~MakePairCalculator() override {}
static constexpr Input<AnyType>::Multiple kIn{""};
// Note that currently api2::Packet is a different type from mediapipe::Packet
static constexpr Output<std::pair<mediapipe::Packet, mediapipe::Packet>>
kPair{""};
static mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Inputs().Index(1).SetAny();
cc->Outputs().Index(0).Set<std::pair<Packet, Packet>>();
return mediapipe::OkStatus();
}
MEDIAPIPE_NODE_CONTRACT(kIn, kPair);
mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
RET_CHECK_EQ(kIn(cc).Count(), 2);
return mediapipe::OkStatus();
}
mediapipe::Status Process(CalculatorContext* cc) override {
cc->Outputs().Index(0).Add(
new std::pair<Packet, Packet>(cc->Inputs().Index(0).Value(),
cc->Inputs().Index(1).Value()),
cc->InputTimestamp());
kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()});
return mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(MakePairCalculator);
MEDIAPIPE_REGISTER_NODE(MakePairCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -13,11 +13,13 @@
// limitations under the License.
#include "Eigen/Core"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Perform a (left) matrix multiply. Meaning (output = A * input)
// where A is the matrix which is provided as an input side packet.
//
@ -28,39 +30,22 @@ namespace mediapipe {
// output_stream: "multiplied_samples"
// input_side_packet: "multiplication_matrix"
// }
class MatrixMultiplyCalculator : public CalculatorBase {
class MatrixMultiplyCalculator : public Node {
public:
MatrixMultiplyCalculator() {}
~MatrixMultiplyCalculator() override {}
static constexpr Input<Matrix> kIn{""};
static constexpr Output<Matrix> kOut{""};
static constexpr SideInput<Matrix> kSide{""};
static mediapipe::Status GetContract(CalculatorContract* cc);
MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kSide);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(MatrixMultiplyCalculator);
// static
mediapipe::Status MatrixMultiplyCalculator::GetContract(
CalculatorContract* cc) {
cc->Inputs().Index(0).Set<Matrix>();
cc->Outputs().Index(0).Set<Matrix>();
cc->InputSidePackets().Index(0).Set<Matrix>();
return mediapipe::OkStatus();
}
mediapipe::Status MatrixMultiplyCalculator::Open(CalculatorContext* cc) {
// The output is at the same timestamp as the input.
cc->SetOffset(TimestampDiff(0));
return mediapipe::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(MatrixMultiplyCalculator);
mediapipe::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) {
Matrix* multiplied = new Matrix();
*multiplied = cc->InputSidePackets().Index(0).Get<Matrix>() *
cc->Inputs().Index(0).Get<Matrix>();
cc->Outputs().Index(0).Add(multiplied, cc->InputTimestamp());
kOut(cc).Send(*kSide(cc) * *kIn(cc));
return mediapipe::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -13,11 +13,13 @@
// limitations under the License.
#include "Eigen/Core"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Subtract input matrix from the side input matrix and vice versa. The matrices
// must have the same dimension.
@ -41,83 +43,40 @@ namespace mediapipe {
// input_side_packet: "MINUEND:side_matrix"
// output_stream: "output_matrix"
// }
class MatrixSubtractCalculator : public CalculatorBase {
class MatrixSubtractCalculator : public Node {
public:
MatrixSubtractCalculator() {}
~MatrixSubtractCalculator() override {}
static constexpr Input<Matrix>::SideFallback kMinuend{"MINUEND"};
static constexpr Input<Matrix>::SideFallback kSubtrahend{"SUBTRAHEND"};
static constexpr Output<Matrix> kOut{""};
static mediapipe::Status GetContract(CalculatorContract* cc);
MEDIAPIPE_NODE_CONTRACT(kMinuend, kSubtrahend, kOut);
static mediapipe::Status UpdateContract(CalculatorContract* cc);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
private:
bool subtract_from_input_ = false;
};
REGISTER_CALCULATOR(MatrixSubtractCalculator);
MEDIAPIPE_REGISTER_NODE(MatrixSubtractCalculator);
// static
mediapipe::Status MatrixSubtractCalculator::GetContract(
mediapipe::Status MatrixSubtractCalculator::UpdateContract(
CalculatorContract* cc) {
if (cc->Inputs().NumEntries() != 1 ||
cc->InputSidePackets().NumEntries() != 1) {
return mediapipe::InvalidArgumentError(
"MatrixSubtractCalculator only accepts exactly one input stream and "
"one "
"input side packet");
}
if (cc->Inputs().HasTag("MINUEND") &&
cc->InputSidePackets().HasTag("SUBTRAHEND")) {
cc->Inputs().Tag("MINUEND").Set<Matrix>();
cc->InputSidePackets().Tag("SUBTRAHEND").Set<Matrix>();
} else if (cc->Inputs().HasTag("SUBTRAHEND") &&
cc->InputSidePackets().HasTag("MINUEND")) {
cc->Inputs().Tag("SUBTRAHEND").Set<Matrix>();
cc->InputSidePackets().Tag("MINUEND").Set<Matrix>();
} else {
return mediapipe::InvalidArgumentError(
"Must specify exactly one minuend and one subtrahend.");
}
cc->Outputs().Index(0).Set<Matrix>();
return mediapipe::OkStatus();
}
mediapipe::Status MatrixSubtractCalculator::Open(CalculatorContext* cc) {
// The output is at the same timestamp as the input.
cc->SetOffset(TimestampDiff(0));
if (cc->Inputs().HasTag("MINUEND")) {
subtract_from_input_ = true;
}
// TODO: the next restriction could be relaxed.
RET_CHECK(kMinuend(cc).IsStream() ^ kSubtrahend(cc).IsStream())
<< "MatrixSubtractCalculator only accepts exactly one input stream and "
"one input side packet";
return mediapipe::OkStatus();
}
mediapipe::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) {
Matrix* subtracted = new Matrix();
if (subtract_from_input_) {
const Matrix& input_matrix = cc->Inputs().Tag("MINUEND").Get<Matrix>();
const Matrix& side_input_matrix =
cc->InputSidePackets().Tag("SUBTRAHEND").Get<Matrix>();
if (input_matrix.rows() != side_input_matrix.rows() ||
input_matrix.cols() != side_input_matrix.cols()) {
const Matrix& minuend = *kMinuend(cc);
const Matrix& subtrahend = *kSubtrahend(cc);
if (minuend.rows() != subtrahend.rows() ||
minuend.cols() != subtrahend.cols()) {
return mediapipe::InvalidArgumentError(
"Input matrix and the input side matrix must have the same "
"dimension.");
"Minuend and subtrahend must have the same dimensions.");
}
*subtracted = input_matrix - side_input_matrix;
} else {
const Matrix& input_matrix = cc->Inputs().Tag("SUBTRAHEND").Get<Matrix>();
const Matrix& side_input_matrix =
cc->InputSidePackets().Tag("MINUEND").Get<Matrix>();
if (input_matrix.rows() != side_input_matrix.rows() ||
input_matrix.cols() != side_input_matrix.cols()) {
return mediapipe::InvalidArgumentError(
"Input matrix and the input side matrix must have the same "
"dimension.");
}
*subtracted = side_input_matrix - input_matrix;
}
cc->Outputs().Index(0).Add(subtracted, cc->InputTimestamp());
kOut(cc).Send(minuend - subtrahend);
return mediapipe::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -89,9 +89,8 @@ TEST(MatrixSubtractCalculatorTest, WrongConfig2) {
)");
CalculatorRunner runner(node_config);
auto status = runner.Run();
EXPECT_THAT(
status.message(),
testing::HasSubstr("specify exactly one minuend and one subtrahend."));
EXPECT_THAT(status.message(), testing::HasSubstr("must be connected"));
EXPECT_THAT(status.message(), testing::HasSubstr("not both"));
}
TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {

View File

@ -21,6 +21,7 @@
#include "Eigen/Core"
#include "absl/memory/memory.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/integral_types.h"
@ -30,6 +31,7 @@
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
namespace api2 {
// A calculator that converts a Matrix M to a vector containing all the
// entries of M in column-major order.
@ -40,33 +42,20 @@ namespace mediapipe {
// input_stream: "input_matrix"
// output_stream: "column_major_vector"
// }
class MatrixToVectorCalculator : public CalculatorBase {
class MatrixToVectorCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<Matrix>(
// Input Packet containing a Matrix.
);
cc->Outputs().Index(0).Set<std::vector<float>>(
// Output Packet containing a vector, one for each input Packet.
);
return mediapipe::OkStatus();
}
static constexpr Input<Matrix> kIn{""};
static constexpr Output<std::vector<float>> kOut{""};
mediapipe::Status Open(CalculatorContext* cc) override;
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
// Outputs a packet containing a vector for each input packet.
mediapipe::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(MatrixToVectorCalculator);
mediapipe::Status MatrixToVectorCalculator::Open(CalculatorContext* cc) {
// Inform the framework that we don't alter timestamps.
cc->SetOffset(mediapipe::TimestampDiff(0));
return mediapipe::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(MatrixToVectorCalculator);
mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) {
const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
const Matrix& input = *kIn(cc);
auto output = absl::make_unique<std::vector<float>>();
// The following lines work to convert the Matrix to a vector because Matrix
@ -76,8 +65,9 @@ mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) {
Eigen::Map<Matrix>(output->data(), input.rows(), input.cols());
output_as_matrix = input;
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
kOut(cc).Send(std::move(output));
return mediapipe::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// This calculator takes a set of input streams and combines them into a single
// output stream. The packets from different streams do not need to contain the
@ -41,40 +43,31 @@ namespace mediapipe {
// output_stream: "merged_shot_infos"
// }
//
class MergeCalculator : public CalculatorBase {
class MergeCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK_GT(cc->Inputs().NumEntries(), 0)
<< "Needs at least one input stream";
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
if (cc->Inputs().NumEntries() == 1) {
static constexpr Input<AnyType>::Multiple kIn{""};
static constexpr Output<AnyType> kOut{""};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream";
if (kIn(cc).Count() == 1) {
LOG(WARNING)
<< "MergeCalculator expects multiple input streams to merge but is "
"receiving only one. Make sure the calculator is configured "
"correctly or consider removing this calculator to reduce "
"unnecessary overhead.";
}
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).SetAny();
}
cc->Outputs().Index(0).SetAny();
return mediapipe::OkStatus();
}
mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
return mediapipe::OkStatus();
}
mediapipe::Status Process(CalculatorContext* cc) final {
// Output the packet from the first input stream with a packet ready at this
// timestamp.
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
if (!cc->Inputs().Index(i).IsEmpty()) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(i).Value());
for (int i = 0; i < kIn(cc).Count(); ++i) {
if (!kIn(cc)[i].IsEmpty()) {
kOut(cc).Send(kIn(cc)[i].packet());
return mediapipe::OkStatus();
}
}
@ -86,6 +79,7 @@ class MergeCalculator : public CalculatorBase {
}
};
REGISTER_CALCULATOR(MergeCalculator);
MEDIAPIPE_REGISTER_NODE(MergeCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -12,97 +12,45 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace {
constexpr char kSelectTag[] = "SELECT";
constexpr char kInputTag[] = "INPUT";
} // namespace
namespace api2 {
// A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ...,
// using the integer value (0, 1, ...) in the packet on the kSelectTag input
// using the integer value (0, 1, ...) in the packet on the "SELECT" input
// stream, and passes the packet on the selected input stream to the "OUTPUT"
// output stream.
// The kSelectTag input can also be passed in as an input side packet, instead
// of as an input stream. Either of input stream or input side packet must be
// specified but not both.
//
// Note that this calculator defaults to use MuxInputStreamHandler, which is
// required for this calculator. However, it can be overridden to work with
// other InputStreamHandlers. Check out the unit tests on for an example usage
// with DefaultInputStreamHandler.
class MuxCalculator : public CalculatorBase {
// TODO: why would you need to use DefaultISH? Perhaps b/167596925?
class MuxCalculator : public Node {
public:
static mediapipe::Status CheckAndInitAllowDisallowInputs(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kSelectTag) ^
cc->InputSidePackets().HasTag(kSelectTag));
if (cc->Inputs().HasTag(kSelectTag)) {
cc->Inputs().Tag(kSelectTag).Set<int>();
} else {
cc->InputSidePackets().Tag(kSelectTag).Set<int>();
}
return mediapipe::OkStatus();
}
static constexpr Input<int>::SideFallback kSelect{"SELECT"};
// TODO: this currently sets them all to Any independently, instead
// of the first being Any and the others being SameAs.
static constexpr Input<AnyType>::Multiple kIn{"INPUT"};
static constexpr Output<SameType<kIn>> kOut{"OUTPUT"};
static mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc));
CollectionItemId data_input_id = cc->Inputs().BeginId(kInputTag);
PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
data_input0->SetAny();
++data_input_id;
for (; data_input_id < cc->Inputs().EndId(kInputTag); ++data_input_id) {
cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
}
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0);
cc->SetInputStreamHandler("MuxInputStreamHandler");
MediaPipeOptions options;
cc->SetInputStreamHandlerOptions(options);
return mediapipe::OkStatus();
}
mediapipe::Status Open(CalculatorContext* cc) final {
use_side_packet_select_ = false;
if (cc->InputSidePackets().HasTag(kSelectTag)) {
use_side_packet_select_ = true;
selected_index_ = cc->InputSidePackets().Tag(kSelectTag).Get<int>();
} else {
select_input_ = cc->Inputs().GetId(kSelectTag, 0);
}
data_input_base_ = cc->Inputs().GetId(kInputTag, 0);
num_data_inputs_ = cc->Inputs().NumEntries(kInputTag);
output_ = cc->Outputs().GetId("OUTPUT", 0);
cc->SetOffset(TimestampDiff(0));
return mediapipe::OkStatus();
}
MEDIAPIPE_NODE_CONTRACT(kSelect, kIn, kOut,
StreamHandler("MuxInputStreamHandler"));
mediapipe::Status Process(CalculatorContext* cc) final {
int select = use_side_packet_select_
? selected_index_
: cc->Inputs().Get(select_input_).Get<int>();
RET_CHECK(0 <= select && select < num_data_inputs_);
if (!cc->Inputs().Get(data_input_base_ + select).IsEmpty()) {
cc->Outputs().Get(output_).AddPacket(
cc->Inputs().Get(data_input_base_ + select).Value());
int select = *kSelect(cc);
RET_CHECK(0 <= select && select < kIn(cc).Count());
if (!kIn(cc)[select].IsEmpty()) {
kOut(cc).Send(kIn(cc)[select].packet());
}
return mediapipe::OkStatus();
}
private:
CollectionItemId select_input_;
CollectionItemId data_input_base_;
int num_data_inputs_ = 0;
CollectionItemId output_;
bool use_side_packet_select_;
int selected_index_;
};
REGISTER_CALCULATOR(MuxCalculator);
MEDIAPIPE_REGISTER_NODE(MuxCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -14,12 +14,14 @@
#include <deque>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
namespace api2 {
// PreviousLoopbackCalculator is useful when a graph needs to process an input
// together with some previous output.
@ -51,15 +53,19 @@ namespace mediapipe {
// input_stream: "PREV_TRACK:prev_output"
// output_stream: "TRACK:output"
// }
class PreviousLoopbackCalculator : public CalculatorBase {
class PreviousLoopbackCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Get("MAIN", 0).SetAny();
cc->Inputs().Get("LOOP", 0).SetAny();
cc->Outputs().Get("PREV_LOOP", 0).SetSameAs(&(cc->Inputs().Get("LOOP", 0)));
static constexpr Input<AnyType> kMain{"MAIN"};
static constexpr Input<AnyType> kLoop{"LOOP"};
static constexpr Output<SameType<kLoop>> kPrevLoop{"PREV_LOOP"};
// TODO: an optional PREV_TIMESTAMP output could be added to
// carry the original timestamp of the packet on PREV_LOOP.
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop,
StreamHandler("ImmediateInputStreamHandler"),
TimestampChange::Arbitrary());
static mediapipe::Status UpdateContract(CalculatorContract* cc) {
// Process() function is invoked in response to MAIN/LOOP stream timestamp
// bound updates.
cc->SetProcessTimestampBounds(true);
@ -67,12 +73,7 @@ class PreviousLoopbackCalculator : public CalculatorBase {
}
mediapipe::Status Open(CalculatorContext* cc) final {
main_id_ = cc->Inputs().GetId("MAIN", 0);
loop_id_ = cc->Inputs().GetId("LOOP", 0);
prev_loop_id_ = cc->Outputs().GetId("PREV_LOOP", 0);
cc->Outputs()
.Get(prev_loop_id_)
.SetHeader(cc->Inputs().Get(loop_id_).Header());
kPrevLoop(cc).SetHeader(kLoop(cc).Header());
return mediapipe::OkStatus();
}
@ -82,48 +83,47 @@ class PreviousLoopbackCalculator : public CalculatorBase {
// packets within the same stream. Calculator tracks and operates on such
// packets.
const Packet& main_packet = cc->Inputs().Get(main_id_).Value();
if (prev_main_ts_ < main_packet.Timestamp()) {
const PacketBase& main_packet = kMain(cc).packet();
if (prev_main_ts_ < main_packet.timestamp()) {
Timestamp loop_timestamp;
if (!main_packet.IsEmpty()) {
loop_timestamp = prev_non_empty_main_ts_;
prev_non_empty_main_ts_ = main_packet.Timestamp();
prev_non_empty_main_ts_ = main_packet.timestamp();
} else {
// Calculator advances PREV_LOOP timestamp bound in response to empty
// MAIN packet, hence not caring about corresponding loop packet.
loop_timestamp = Timestamp::Unset();
}
main_packet_specs_.push_back({main_packet.Timestamp(), loop_timestamp});
prev_main_ts_ = main_packet.Timestamp();
main_packet_specs_.push_back({main_packet.timestamp(), loop_timestamp});
prev_main_ts_ = main_packet.timestamp();
}
const Packet& loop_packet = cc->Inputs().Get(loop_id_).Value();
if (prev_loop_ts_ < loop_packet.Timestamp()) {
const PacketBase& loop_packet = kLoop(cc).packet();
if (prev_loop_ts_ < loop_packet.timestamp()) {
loop_packets_.push_back(loop_packet);
prev_loop_ts_ = loop_packet.Timestamp();
prev_loop_ts_ = loop_packet.timestamp();
}
auto& prev_loop = cc->Outputs().Get(prev_loop_id_);
while (!main_packet_specs_.empty() && !loop_packets_.empty()) {
// The earliest MAIN packet.
const MainPacketSpec& main_spec = main_packet_specs_.front();
// The earliest LOOP packet.
const Packet& loop_candidate = loop_packets_.front();
const PacketBase& loop_candidate = loop_packets_.front();
// Match LOOP and MAIN packets.
if (main_spec.loop_timestamp < loop_candidate.Timestamp()) {
if (main_spec.loop_timestamp < loop_candidate.timestamp()) {
// No LOOP packet can match the MAIN packet under review.
prev_loop.SetNextTimestampBound(main_spec.timestamp + 1);
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
main_packet_specs_.pop_front();
} else if (main_spec.loop_timestamp > loop_candidate.Timestamp()) {
} else if (main_spec.loop_timestamp > loop_candidate.timestamp()) {
// No MAIN packet can match the LOOP packet under review.
loop_packets_.pop_front();
} else {
// Exact match found.
if (loop_candidate.IsEmpty()) {
// However, LOOP packet is empty.
prev_loop.SetNextTimestampBound(main_spec.timestamp + 1);
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
} else {
prev_loop.AddPacket(loop_candidate.At(main_spec.timestamp));
kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
}
loop_packets_.pop_front();
main_packet_specs_.pop_front();
@ -135,7 +135,7 @@ class PreviousLoopbackCalculator : public CalculatorBase {
// b) Empty MAIN packet has been received with Timestamp::Max() indicating
// MAIN is done.
if (main_spec.timestamp == Timestamp::Done().PreviousAllowedInStream()) {
prev_loop.Close();
kPrevLoop(cc).Close();
}
}
@ -150,10 +150,6 @@ class PreviousLoopbackCalculator : public CalculatorBase {
Timestamp loop_timestamp;
};
CollectionItemId main_id_;
CollectionItemId loop_id_;
CollectionItemId prev_loop_id_;
// Contains specs for MAIN packets which only can be:
// - non-empty packets
// - empty packets indicating timestamp bound updates
@ -169,12 +165,13 @@ class PreviousLoopbackCalculator : public CalculatorBase {
// - empty packets indicating timestamp bound updates
//
// Sorted according to packet timestamps.
std::deque<Packet> loop_packets_;
std::deque<PacketBase> loop_packets_;
// Using "Timestamp::Unset" instead of "Timestamp::Unstarted" in order to
// allow addition of the very first empty packet (which doesn't indicate
// timestamp bound change necessarily).
Timestamp prev_loop_ts_ = Timestamp::Unset();
};
REGISTER_CALCULATOR(PreviousLoopbackCalculator);
MEDIAPIPE_REGISTER_NODE(PreviousLoopbackCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -15,9 +15,11 @@
#include <deque>
#include "mediapipe/calculators/core/sequence_shift_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
namespace mediapipe {
namespace api2 {
// A Calculator that shifts the timestamps of packets along a stream. Packets on
// the input stream are output with a timestamp of the packet given by packet
@ -28,24 +30,19 @@ namespace mediapipe {
// of -1, the first packet on the stream will be dropped, the second will be
// output with the timestamp of the first, the third with the timestamp of the
// second, and so on.
class SequenceShiftCalculator : public CalculatorBase {
class SequenceShiftCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
if (cc->InputSidePackets().HasTag(kPacketOffsetTag)) {
cc->InputSidePackets().Tag(kPacketOffsetTag).Set<int>();
}
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return mediapipe::OkStatus();
}
static constexpr Input<AnyType> kIn{""};
static constexpr SideInput<int>::Optional kOffset{"PACKET_OFFSET"};
static constexpr Output<SameType<kIn>> kOut{""};
MEDIAPIPE_NODE_CONTRACT(kIn, kOffset, kOut, TimestampChange::Arbitrary());
// Reads from options to set cache_size_ and packet_offset_.
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
private:
static constexpr const char* kPacketOffsetTag = "PACKET_OFFSET";
// A positive offset means we want a packet to be output with the timestamp of
// a later packet. Stores packets waiting for their output timestamps and
// outputs a single packet when the cache fills.
@ -58,7 +55,7 @@ class SequenceShiftCalculator : public CalculatorBase {
// Storage for packets waiting to be output when packet_offset > 0. When cache
// is full, oldest packet is output with current timestamp.
std::deque<Packet> packet_cache_;
std::deque<PacketBase> packet_cache_;
// Storage for previous timestamps used when packet_offset < 0. When cache is
// full, oldest timestamp is used for current packet.
@ -70,14 +67,11 @@ class SequenceShiftCalculator : public CalculatorBase {
// the timestamp of packet[i + packet_offset]; equal to abs(packet_offset).
int cache_size_;
};
REGISTER_CALCULATOR(SequenceShiftCalculator);
MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator);
mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) {
packet_offset_ =
cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset();
if (cc->InputSidePackets().HasTag(kPacketOffsetTag)) {
packet_offset_ = cc->InputSidePackets().Tag(kPacketOffsetTag).Get<int>();
}
packet_offset_ = kOffset(cc).GetOr(
cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset());
cache_size_ = abs(packet_offset_);
// An offset of zero is a no-op, but someone might still request it.
if (packet_offset_ == 0) {
@ -92,7 +86,7 @@ mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) {
} else if (packet_offset_ < 0) {
ProcessNegativeOffset(cc);
} else {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
kOut(cc).Send(kIn(cc).packet());
}
return mediapipe::OkStatus();
}
@ -100,23 +94,22 @@ mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) {
void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
if (packet_cache_.size() >= cache_size_) {
// Ready to output oldest packet with current timestamp.
cc->Outputs().Index(0).AddPacket(
packet_cache_.front().At(cc->InputTimestamp()));
kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp()));
packet_cache_.pop_front();
}
// Store current packet for later output.
packet_cache_.push_back(cc->Inputs().Index(0).Value());
packet_cache_.push_back(kIn(cc).packet());
}
void SequenceShiftCalculator::ProcessNegativeOffset(CalculatorContext* cc) {
if (timestamp_cache_.size() >= cache_size_) {
// Ready to output current packet with oldest timestamp.
cc->Outputs().Index(0).AddPacket(
cc->Inputs().Index(0).Value().At(timestamp_cache_.front()));
kOut(cc).Send(kIn(cc).packet().At(timestamp_cache_.front()));
timestamp_cache_.pop_front();
}
// Store current timestamp for use by a future packet.
timestamp_cache_.push_back(cc->InputTimestamp());
}
} // namespace api2
} // namespace mediapipe

View File

@ -61,7 +61,7 @@ class StringToIntCalculatorTemplate : public CalculatorBase {
using StringToIntCalculator = StringToIntCalculatorTemplate<int>;
REGISTER_CALCULATOR(StringToIntCalculator);
using StringToUintCalculator = StringToIntCalculatorTemplate<uint>;
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
REGISTER_CALCULATOR(StringToUintCalculator);
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;

View File

@ -59,6 +59,7 @@ cc_library(
deps = [
":inference_calculator_cc_proto",
"@com_google_absl//absl/memory",
"//mediapipe/framework/api2:node",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/util/tflite:tflite_model_loader",
@ -234,6 +235,7 @@ cc_library(
"//mediapipe/framework/formats:detection_cc_proto",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//mediapipe/framework/api2:node",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:port",
"//mediapipe/framework/deps:file_path",
@ -286,6 +288,7 @@ cc_library(
deps = [
":tensors_to_landmarks_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
@ -317,6 +320,7 @@ cc_library(
deps = [
":tensors_to_floats_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
],
@ -355,6 +359,7 @@ cc_library(
":tensors_to_classification_calculator_cc_proto",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:location",
@ -421,6 +426,7 @@ cc_library(
":image_to_tensor_converter",
":image_to_tensor_converter_opencv",
":image_to_tensor_utils",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",

View File

@ -20,6 +20,7 @@
#include "mediapipe/calculators/tensor/image_to_tensor_converter.h"
#include "mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h"
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/rect.pb.h"
@ -45,16 +46,15 @@
#endif // !MEDIAPIPE_DISABLE_GPU
namespace {
constexpr char kInputCpu[] = "IMAGE";
constexpr char kInputGpu[] = "IMAGE_GPU";
constexpr char kOutputMatrix[] = "MATRIX";
constexpr char kOutput[] = "TENSORS";
constexpr char kInputNormRect[] = "NORM_RECT";
constexpr char kOutputLetterboxPadding[] = "LETTERBOX_PADDING";
} // namespace
namespace mediapipe {
namespace api2 {
#if MEDIAPIPE_DISABLE_GPU
// Just a placeholder to not have to depend on mediapipe::GpuBuffer.
using GpuBuffer = AnyType;
#else
using GpuBuffer = mediapipe::GpuBuffer;
#endif // MEDIAPIPE_DISABLE_GPU
// Converts image into Tensor, possibly with cropping, resizing and
// normalization, according to specified inputs and options.
@ -110,9 +110,21 @@ namespace mediapipe {
// }
// }
// }
class ImageToTensorCalculator : public CalculatorBase {
class ImageToTensorCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc) {
static constexpr Input<mediapipe::ImageFrame>::Optional kInCpu{"IMAGE"};
static constexpr Input<GpuBuffer>::Optional kInGpu{"IMAGE_GPU"};
static constexpr Input<mediapipe::NormalizedRect>::Optional kInNormRect{
"NORM_RECT"};
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
static constexpr Output<std::array<float, 4>>::Optional kOutLetterboxPadding{
"LETTERBOX_PADDING"};
static constexpr Output<std::array<float, 16>>::Optional kOutMatrix{"MATRIX"};
MEDIAPIPE_NODE_CONTRACT(kInCpu, kInGpu, kInNormRect, kOutTensors,
kOutLetterboxPadding, kOutMatrix);
static ::mediapipe::Status UpdateContract(CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
@ -126,24 +138,10 @@ class ImageToTensorCalculator : public CalculatorBase {
RET_CHECK_GT(options.output_tensor_height(), 0)
<< "Valid output tensor height is required.";
if (cc->Inputs().HasTag(kInputNormRect)) {
cc->Inputs().Tag(kInputNormRect).Set<mediapipe::NormalizedRect>();
}
if (cc->Outputs().HasTag(kOutputLetterboxPadding)) {
cc->Outputs().Tag(kOutputLetterboxPadding).Set<std::array<float, 4>>();
}
if (cc->Outputs().HasTag(kOutputMatrix)) {
cc->Outputs().Tag(kOutputMatrix).Set<std::array<float, 16>>();
}
RET_CHECK(kInCpu(cc).IsConnected() ^ kInGpu(cc).IsConnected())
<< "One and only one of CPU or GPU input is expected.";
const bool has_cpu_input = cc->Inputs().HasTag(kInputCpu);
const bool has_gpu_input = cc->Inputs().HasTag(kInputGpu);
RET_CHECK_EQ((has_cpu_input ? 1 : 0) + (has_gpu_input ? 1 : 0), 1)
<< "Either CPU or GPU input is expected, not both.";
if (has_cpu_input) {
cc->Inputs().Tag(kInputCpu).Set<mediapipe::ImageFrame>();
} else if (has_gpu_input) {
if (kInGpu(cc).IsConnected()) {
#if MEDIAPIPE_DISABLE_GPU
return mediapipe::UnimplementedError("GPU processing is disabled");
#else
@ -153,25 +151,20 @@ class ImageToTensorCalculator : public CalculatorBase {
#else
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // MEDIAPIPE_METAL_ENABLED
cc->Inputs().Tag(kInputGpu).Set<mediapipe::GpuBuffer>();
#endif // MEDIAPIPE_DISABLE_GPU
}
cc->Outputs().Tag(kOutput).Set<std::vector<Tensor>>();
return mediapipe::OkStatus();
}
mediapipe::Status Open(CalculatorContext* cc) {
// Makes sure outputs' next timestamp bound update is handled automatically
// by the framework.
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
output_width_ = options_.output_tensor_width();
output_height_ = options_.output_tensor_height();
range_min_ = options_.output_tensor_float_range().min();
range_max_ = options_.output_tensor_float_range().max();
if (cc->Inputs().HasTag(kInputCpu)) {
if (kInCpu(cc).IsConnected()) {
ASSIGN_OR_RETURN(converter_, CreateOpenCvConverter(cc, GetBorderMode()));
} else {
#if MEDIAPIPE_DISABLE_GPU
@ -196,21 +189,20 @@ class ImageToTensorCalculator : public CalculatorBase {
}
mediapipe::Status Process(CalculatorContext* cc) {
const InputStreamShard& input = cc->Inputs().Tag(
cc->Inputs().HasTag(kInputCpu) ? kInputCpu : kInputGpu);
if (input.IsEmpty()) {
const PacketBase& image_packet =
kInCpu(cc).IsConnected() ? kInCpu(cc).packet() : kInGpu(cc).packet();
if (image_packet.IsEmpty()) {
// Timestamp bound update happens automatically. (See Open().)
return mediapipe::OkStatus();
}
absl::optional<mediapipe::NormalizedRect> norm_rect;
if (cc->Inputs().HasTag(kInputNormRect)) {
if (cc->Inputs().Tag(kInputNormRect).IsEmpty()) {
if (kInNormRect(cc).IsConnected()) {
if (kInNormRect(cc).IsEmpty()) {
// Timestamp bound update happens automatically. (See Open().)
return mediapipe::OkStatus();
}
norm_rect =
cc->Inputs().Tag(kInputNormRect).Get<mediapipe::NormalizedRect>();
norm_rect = *kInNormRect(cc);
if (norm_rect->width() == 0 && norm_rect->height() == 0) {
// WORKAROUND: some existing graphs may use sentinel rects {width=0,
// height=0, ...} quite often and calculator has to handle them
@ -223,27 +215,20 @@ class ImageToTensorCalculator : public CalculatorBase {
}
}
const Packet& image_packet = input.Value();
const Size& size = converter_->GetImageSize(image_packet);
RotatedRect roi = GetRoi(size.width, size.height, norm_rect);
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
options_.output_tensor_height(),
options_.keep_aspect_ratio(), &roi));
if (cc->Outputs().HasTag(kOutputLetterboxPadding)) {
cc->Outputs()
.Tag(kOutputLetterboxPadding)
.AddPacket(MakePacket<std::array<float, 4>>(padding).At(
cc->InputTimestamp()));
if (kOutLetterboxPadding(cc).IsConnected()) {
kOutLetterboxPadding(cc).Send(padding);
}
if (cc->Outputs().HasTag(kOutputMatrix)) {
if (kOutMatrix(cc).IsConnected()) {
std::array<float, 16> matrix;
GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
/*flip_horizontaly=*/false,
&matrix);
cc->Outputs()
.Tag(kOutputMatrix)
.AddPacket(MakePacket<std::array<float, 16>>(std::move(matrix))
.At(cc->InputTimestamp()));
kOutMatrix(cc).Send(std::move(matrix));
}
ASSIGN_OR_RETURN(
@ -251,11 +236,9 @@ class ImageToTensorCalculator : public CalculatorBase {
converter_->Convert(image_packet, roi, {output_width_, output_height_},
range_min_, range_max_));
std::vector<Tensor> result;
result.push_back(std::move(tensor));
cc->Outputs().Tag(kOutput).AddPacket(
MakePacket<std::vector<Tensor>>(std::move(result))
.At(cc->InputTimestamp()));
auto result = std::make_unique<std::vector<Tensor>>();
result->push_back(std::move(tensor));
kOutTensors(cc).Send(std::move(result));
return mediapipe::OkStatus();
}
@ -286,6 +269,7 @@ class ImageToTensorCalculator : public CalculatorBase {
float range_max_ = 1.0f;
};
REGISTER_CALCULATOR(ImageToTensorCalculator);
MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -19,6 +19,7 @@
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
@ -88,7 +89,6 @@ bool ShouldUseGpu(const mediapipe::InferenceCalculatorOptions& options) {
}
constexpr char kTensorsTag[] = "TENSORS";
} // namespace
#if defined(MEDIAPIPE_EDGE_TPU)
#include "edgetpu.h"
@ -112,7 +112,10 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
}
#endif // MEDIAPIPE_EDGE_TPU
} // namespace
namespace mediapipe {
namespace api2 {
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
namespace {
@ -224,12 +227,19 @@ int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) {
// Tensors are assumed to be ordered correctly (sequentially added to model).
// Input tensors are assumed to be of the correct size and already normalized.
class InferenceCalculator : public CalculatorBase {
class InferenceCalculator : public Node {
public:
using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
static mediapipe::Status GetContract(CalculatorContract* cc);
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
kOutTensors);
static mediapipe::Status UpdateContract(CalculatorContract* cc);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
@ -239,11 +249,12 @@ class InferenceCalculator : public CalculatorBase {
mediapipe::Status ReadKernelsFromFile();
mediapipe::Status WriteKernelsToFile();
mediapipe::Status LoadModel(CalculatorContext* cc);
mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
mediapipe::StatusOr<mediapipe::Packet> GetModelAsPacket(
const CalculatorContext& cc);
mediapipe::Status LoadDelegate(CalculatorContext* cc);
mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
Packet model_packet_;
mediapipe::Packet model_packet_;
std::unique_ptr<tflite::Interpreter> interpreter_;
TfLiteDelegatePtr delegate_;
@ -277,28 +288,13 @@ class InferenceCalculator : public CalculatorBase {
std::string cached_kernel_filename_;
};
REGISTER_CALCULATOR(InferenceCalculator);
mediapipe::Status InferenceCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
cc->Inputs().Tag(kTensorsTag).Set<std::vector<Tensor>>();
RET_CHECK(cc->Outputs().HasTag(kTensorsTag));
cc->Outputs().Tag(kTensorsTag).Set<std::vector<Tensor>>();
MEDIAPIPE_REGISTER_NODE(InferenceCalculator);
mediapipe::Status InferenceCalculator::UpdateContract(CalculatorContract* cc) {
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
RET_CHECK(!options.model_path().empty() ^
cc->InputSidePackets().HasTag("MODEL"))
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
<< "Either model as side packet or model path in options is required.";
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
cc->InputSidePackets()
.Tag("CUSTOM_OP_RESOLVER")
.Set<tflite::ops::builtin::BuiltinOpResolver>();
}
if (cc->InputSidePackets().HasTag("MODEL")) {
cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
}
if (ShouldUseGpu(options)) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
@ -310,8 +306,6 @@ mediapipe::Status InferenceCalculator::GetContract(CalculatorContract* cc) {
}
mediapipe::Status InferenceCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
#if MEDIAPIPE_TFLITE_GL_INFERENCE || MEDIAPIPE_TFLITE_METAL_INFERENCE
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
if (ShouldUseGpu(options)) {
@ -361,11 +355,10 @@ mediapipe::Status InferenceCalculator::Open(CalculatorContext* cc) {
}
mediapipe::Status InferenceCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
if (kInTensors(cc).IsEmpty()) {
return mediapipe::OkStatus();
}
const auto& input_tensors =
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
const auto& input_tensors = *kInTensors(cc);
RET_CHECK(!input_tensors.empty());
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
@ -509,9 +502,7 @@ mediapipe::Status InferenceCalculator::Process(CalculatorContext* cc) {
output_tensors->back().bytes());
}
}
cc->Outputs()
.Tag(kTensorsTag)
.Add(output_tensors.release(), cc->InputTimestamp());
kOutTensors(cc).Send(std::move(output_tensors));
return mediapipe::OkStatus();
}
@ -575,12 +566,9 @@ mediapipe::Status InferenceCalculator::InitTFLiteGPURunner(
#if MEDIAPIPE_TFLITE_GL_INFERENCE
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
tflite::ops::builtin::BuiltinOpResolver op_resolver;
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
op_resolver = cc->InputSidePackets()
.Tag("CUSTOM_OP_RESOLVER")
.Get<tflite::ops::builtin::BuiltinOpResolver>();
}
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolver());
// Create runner
tflite::gpu::InferenceOptions options;
@ -629,12 +617,9 @@ mediapipe::Status InferenceCalculator::InitTFLiteGPURunner(
mediapipe::Status InferenceCalculator::LoadModel(CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
tflite::ops::builtin::BuiltinOpResolver op_resolver;
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
op_resolver = cc->InputSidePackets()
.Tag("CUSTOM_OP_RESOLVER")
.Get<tflite::ops::builtin::BuiltinOpResolver>();
}
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolver());
#if defined(MEDIAPIPE_EDGE_TPU)
interpreter_ =
@ -659,7 +644,7 @@ mediapipe::Status InferenceCalculator::LoadModel(CalculatorContext* cc) {
return mediapipe::OkStatus();
}
mediapipe::StatusOr<Packet> InferenceCalculator::GetModelAsPacket(
mediapipe::StatusOr<mediapipe::Packet> InferenceCalculator::GetModelAsPacket(
const CalculatorContext& cc) {
const auto& options = cc.Options<mediapipe::InferenceCalculatorOptions>();
if (!options.model_path().empty()) {
@ -845,4 +830,5 @@ mediapipe::Status InferenceCalculator::LoadDelegate(CalculatorContext* cc) {
return mediapipe::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -19,6 +19,7 @@
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/tensor.h"
@ -32,6 +33,7 @@
#endif
namespace mediapipe {
namespace api2 {
// Convert result tensors from classification models into MediaPipe
// classifications.
@ -57,9 +59,12 @@ namespace mediapipe {
// }
// }
// }
class TensorsToClassificationCalculator : public CalculatorBase {
class TensorsToClassificationCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc);
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
static constexpr Output<ClassificationList> kOutClassificationList{
"CLASSIFICATIONS"};
MEDIAPIPE_NODE_CONTRACT(kInTensors, kOutClassificationList);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
@ -71,28 +76,10 @@ class TensorsToClassificationCalculator : public CalculatorBase {
std::unordered_map<int, std::string> label_map_;
bool label_map_loaded_ = false;
};
REGISTER_CALCULATOR(TensorsToClassificationCalculator);
mediapipe::Status TensorsToClassificationCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty());
if (cc->Inputs().HasTag("TENSORS")) {
cc->Inputs().Tag("TENSORS").Set<std::vector<Tensor>>();
}
if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
}
return mediapipe::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator);
mediapipe::Status TensorsToClassificationCalculator::Open(
CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
options_ =
cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>();
@ -118,9 +105,7 @@ mediapipe::Status TensorsToClassificationCalculator::Open(
mediapipe::Status TensorsToClassificationCalculator::Process(
CalculatorContext* cc) {
const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<Tensor>>();
const auto& input_tensors = *kInTensors(cc);
RET_CHECK_EQ(input_tensors.size(), 1);
int num_classes = input_tensors[0].shape().num_elements();
@ -182,10 +167,7 @@ mediapipe::Status TensorsToClassificationCalculator::Process(
raw_classification_list->DeleteSubrange(
top_k_, raw_classification_list->size() - top_k_);
}
cc->Outputs()
.Tag("CLASSIFICATIONS")
.Add(classification_list.release(), cc->InputTimestamp());
kOutClassificationList(cc).Send(std::move(classification_list));
return mediapipe::OkStatus();
}
@ -194,4 +176,5 @@ mediapipe::Status TensorsToClassificationCalculator::Close(
return mediapipe::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -18,6 +18,7 @@
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/detection.pb.h"
@ -47,9 +48,6 @@
namespace {
constexpr int kNumInputTensorsWithAnchors = 3;
constexpr int kNumCoordsPerBox = 4;
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kAnchorsTag[] = "ANCHORS";
bool CanUseGpu() {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || MEDIAPIPE_METAL_ENABLED
@ -63,6 +61,7 @@ bool CanUseGpu() {
} // namespace
namespace mediapipe {
namespace api2 {
namespace {
@ -128,9 +127,14 @@ void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
// }
// }
// }
class TensorsToDetectionsCalculator : public CalculatorBase {
class TensorsToDetectionsCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc);
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
static constexpr SideInput<std::vector<Anchor>>::Optional kInAnchors{
"ANCHORS"};
static constexpr Output<std::vector<Detection>> kOutDetections{"DETECTIONS"};
MEDIAPIPE_NODE_CONTRACT(kInTensors, kInAnchors, kOutDetections);
static mediapipe::Status UpdateContract(CalculatorContract* cc);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
@ -161,7 +165,6 @@ class TensorsToDetectionsCalculator : public CalculatorBase {
::mediapipe::TensorsToDetectionsCalculatorOptions options_;
std::vector<Anchor> anchors_;
bool side_packet_anchors_{};
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
mediapipe::GlCalculatorHelper gpu_helper_;
@ -179,22 +182,10 @@ class TensorsToDetectionsCalculator : public CalculatorBase {
bool gpu_input_ = false;
bool anchors_init_ = false;
};
REGISTER_CALCULATOR(TensorsToDetectionsCalculator);
MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);
mediapipe::Status TensorsToDetectionsCalculator::GetContract(
mediapipe::Status TensorsToDetectionsCalculator::UpdateContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
cc->Inputs().Tag(kTensorsTag).Set<std::vector<Tensor>>();
RET_CHECK(cc->Outputs().HasTag(kDetectionsTag));
cc->Outputs().Tag(kDetectionsTag).Set<std::vector<Detection>>();
if (cc->InputSidePackets().UsesTags()) {
if (cc->InputSidePackets().HasTag(kAnchorsTag)) {
cc->InputSidePackets().Tag(kAnchorsTag).Set<std::vector<Anchor>>();
}
}
if (CanUseGpu()) {
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
@ -207,8 +198,6 @@ mediapipe::Status TensorsToDetectionsCalculator::GetContract(
}
mediapipe::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
side_packet_anchors_ = cc->InputSidePackets().HasTag(kAnchorsTag);
MP_RETURN_IF_ERROR(LoadOptions(cc));
if (CanUseGpu()) {
@ -226,18 +215,12 @@ mediapipe::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
mediapipe::Status TensorsToDetectionsCalculator::Process(
CalculatorContext* cc) {
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
return mediapipe::OkStatus();
}
auto output_detections = absl::make_unique<std::vector<Detection>>();
bool gpu_processing = false;
if (CanUseGpu()) {
// Use GPU processing only if at least one input tensor is already on GPU
// (to avoid CPU->GPU overhead).
for (const auto& tensor :
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>()) {
for (const auto& tensor : *kInTensors(cc)) {
if (tensor.ready_on_gpu()) {
gpu_processing = true;
break;
@ -251,18 +234,13 @@ mediapipe::Status TensorsToDetectionsCalculator::Process(
MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
}
// Output
cc->Outputs()
.Tag(kDetectionsTag)
.Add(output_detections.release(), cc->InputTimestamp());
kOutDetections(cc).Send(std::move(output_detections));
return mediapipe::OkStatus();
}
mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) {
const auto& input_tensors =
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
const auto& input_tensors = *kInTensors(cc);
if (input_tensors.size() == 2 ||
input_tensors.size() == kNumInputTensorsWithAnchors) {
@ -294,10 +272,8 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU(
auto anchor_view = anchor_tensor->GetCpuReadView();
auto raw_anchors = anchor_view.buffer<float>();
ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_);
} else if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
anchors_ =
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
} else if (!kInAnchors(cc).IsEmpty()) {
anchors_ = *kInAnchors(cc);
} else {
return mediapipe::UnavailableError("No anchor data available.");
}
@ -391,8 +367,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU(
mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) {
const auto& input_tensors =
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
const auto& input_tensors = *kInTensors(cc);
RET_CHECK_GE(input_tensors.size(), 2);
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
@ -400,21 +375,20 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU(
&output_detections]()
-> mediapipe::Status {
if (!anchors_init_) {
if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag(kAnchorsTag).IsEmpty());
const auto& anchors =
cc->InputSidePackets().Tag(kAnchorsTag).Get<std::vector<Anchor>>();
auto anchors_view = raw_anchors_buffer_->GetCpuWriteView();
auto raw_anchors = anchors_view.buffer<float>();
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors);
} else {
CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
auto read_view = input_tensors[2].GetOpenGlBufferReadView();
glBindBuffer(GL_COPY_READ_BUFFER, read_view.name());
auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView();
glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name());
glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
input_tensors[2].bytes());
} else if (!kInAnchors(cc).IsEmpty()) {
const auto& anchors = *kInAnchors(cc);
auto anchors_view = raw_anchors_buffer_->GetCpuWriteView();
auto raw_anchors = anchors_view.buffer<float>();
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors);
} else {
return mediapipe::UnavailableError("No anchor data available.");
}
anchors_init_ = true;
}
@ -464,14 +438,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU(
#elif MEDIAPIPE_METAL_ENABLED
id<MTLDevice> device = gpu_helper_.mtlDevice;
if (!anchors_init_) {
if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag(kAnchorsTag).IsEmpty());
const auto& anchors =
cc->InputSidePackets().Tag(kAnchorsTag).Get<std::vector<Anchor>>();
auto raw_anchors_view = raw_anchors_buffer_->GetCpuWriteView();
ConvertAnchorsToRawValues(anchors, num_boxes_,
raw_anchors_view.buffer<float>());
} else {
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
auto command_buffer = [gpu_helper_ commandBuffer];
auto src_buffer = input_tensors[2].GetMtlBufferReadView(command_buffer);
@ -486,6 +453,13 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU(
size:input_tensors[2].bytes()];
[blit_command endEncoding];
[command_buffer commit];
} else if (!kInAnchors(cc).IsEmpty()) {
const auto& anchors = *kInAnchors(cc);
auto raw_anchors_view = raw_anchors_buffer_->GetCpuWriteView();
ConvertAnchorsToRawValues(anchors, num_boxes_,
raw_anchors_view.buffer<float>());
} else {
return mediapipe::UnavailableError("No anchor data available.");
}
anchors_init_ = true;
}
@ -1157,4 +1131,5 @@ kernel void scoreKernel(
return mediapipe::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "mediapipe/calculators/tensor/tensors_to_floats_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
@ -43,47 +44,39 @@ inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); }
// input_stream: "TENSORS:tensors"
// output_stream: "FLOATS:floats"
// }
class TensorsToFloatsCalculator : public CalculatorBase {
namespace api2 {
class TensorsToFloatsCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc);
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
static constexpr Output<float>::Optional kOutFloat{"FLOAT"};
static constexpr Output<std::vector<float>>::Optional kOutFloats{"FLOATS"};
MEDIAPIPE_NODE_INTERFACE(TensorsToFloatsCalculator, kInTensors, kOutFloat,
kOutFloats);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
static mediapipe::Status UpdateContract(CalculatorContract* cc);
mediapipe::Status Open(CalculatorContext* cc) final;
mediapipe::Status Process(CalculatorContext* cc) final;
private:
::mediapipe::TensorsToFloatsCalculatorOptions options_;
};
REGISTER_CALCULATOR(TensorsToFloatsCalculator);
MEDIAPIPE_REGISTER_NODE(TensorsToFloatsCalculator);
mediapipe::Status TensorsToFloatsCalculator::GetContract(
mediapipe::Status TensorsToFloatsCalculator::UpdateContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("TENSORS"));
RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT"));
cc->Inputs().Tag("TENSORS").Set<std::vector<Tensor>>();
if (cc->Outputs().HasTag("FLOATS")) {
cc->Outputs().Tag("FLOATS").Set<std::vector<float>>();
}
if (cc->Outputs().HasTag("FLOAT")) {
cc->Outputs().Tag("FLOAT").Set<float>();
}
// Only exactly a single output allowed.
RET_CHECK(kOutFloat(cc).IsConnected() ^ kOutFloats(cc).IsConnected());
return mediapipe::OkStatus();
}
mediapipe::Status TensorsToFloatsCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<::mediapipe::TensorsToFloatsCalculatorOptions>();
return mediapipe::OkStatus();
}
mediapipe::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) {
RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty());
const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<Tensor>>();
const auto& input_tensors = *kInTensors(cc);
RET_CHECK(!input_tensors.empty());
// TODO: Add option to specify which tensor to take from.
auto view = input_tensors[0].GetCpuReadView();
auto raw_floats = view.buffer<float>();
@ -100,18 +93,15 @@ mediapipe::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) {
break;
}
if (cc->Outputs().HasTag("FLOAT")) {
// TODO: Could add an index in the option to specifiy returning one
// value of a float array.
if (kOutFloat(cc).IsConnected()) {
// TODO: Could add an index in the option to specifiy returning
// one value of a float array.
RET_CHECK_EQ(num_values, 1);
cc->Outputs().Tag("FLOAT").AddPacket(
MakePacket<float>(output_floats->at(0)).At(cc->InputTimestamp()));
kOutFloat(cc).Send(output_floats->at(0));
} else {
kOutFloats(cc).Send(std::move(output_floats));
}
if (cc->Outputs().HasTag("FLOATS")) {
cc->Outputs().Tag("FLOATS").Add(output_floats.release(),
cc->InputTimestamp());
}
return mediapipe::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -13,12 +13,14 @@
// limitations under the License.
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace api2 {
namespace {
@ -85,9 +87,18 @@ float ApplyActivation(
// }
// }
// }
class TensorsToLandmarksCalculator : public CalculatorBase {
class TensorsToLandmarksCalculator : public Node {
public:
static mediapipe::Status GetContract(CalculatorContract* cc);
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
static constexpr Input<bool>::SideFallback::Optional kFlipHorizontally{
"FLIP_HORIZONTALLY"};
static constexpr Input<bool>::SideFallback::Optional kFlipVertically{
"FLIP_VERTICALLY"};
static constexpr Output<LandmarkList>::Optional kOutLandmarkList{"LANDMARKS"};
static constexpr Output<NormalizedLandmarkList>::Optional
kOutNormalizedLandmarkList{"NORM_LANDMARKS"};
MEDIAPIPE_NODE_CONTRACT(kInTensors, kFlipHorizontally, kFlipVertically,
kOutLandmarkList, kOutNormalizedLandmarkList);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
@ -95,100 +106,39 @@ class TensorsToLandmarksCalculator : public CalculatorBase {
private:
mediapipe::Status LoadOptions(CalculatorContext* cc);
int num_landmarks_ = 0;
bool flip_vertically_ = false;
bool flip_horizontally_ = false;
::mediapipe::TensorsToLandmarksCalculatorOptions options_;
};
REGISTER_CALCULATOR(TensorsToLandmarksCalculator);
mediapipe::Status TensorsToLandmarksCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty());
if (cc->Inputs().HasTag("TENSORS")) {
cc->Inputs().Tag("TENSORS").Set<std::vector<Tensor>>();
}
if (cc->Inputs().HasTag("FLIP_HORIZONTALLY")) {
cc->Inputs().Tag("FLIP_HORIZONTALLY").Set<bool>();
}
if (cc->Inputs().HasTag("FLIP_VERTICALLY")) {
cc->Inputs().Tag("FLIP_VERTICALLY").Set<bool>();
}
if (cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY")) {
cc->InputSidePackets().Tag("FLIP_HORIZONTALLY").Set<bool>();
}
if (cc->InputSidePackets().HasTag("FLIP_VERTICALLY")) {
cc->InputSidePackets().Tag("FLIP_VERTICALLY").Set<bool>();
}
if (cc->Outputs().HasTag("LANDMARKS")) {
cc->Outputs().Tag("LANDMARKS").Set<LandmarkList>();
}
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
cc->Outputs().Tag("NORM_LANDMARKS").Set<NormalizedLandmarkList>();
}
return mediapipe::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(TensorsToLandmarksCalculator);
mediapipe::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
MP_RETURN_IF_ERROR(LoadOptions(cc));
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
if (kOutNormalizedLandmarkList(cc).IsConnected()) {
RET_CHECK(options_.has_input_image_height() &&
options_.has_input_image_width())
<< "Must provide input with/height for getting normalized landmarks.";
}
if (cc->Outputs().HasTag("LANDMARKS") &&
(options_.flip_vertically() || options_.flip_horizontally() ||
cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY") ||
cc->InputSidePackets().HasTag("FLIP_VERTICALLY"))) {
if (kOutLandmarkList(cc).IsConnected() &&
(options_.flip_horizontally() || options_.flip_vertically() ||
kFlipHorizontally(cc).IsConnected() ||
kFlipVertically(cc).IsConnected())) {
RET_CHECK(options_.has_input_image_height() &&
options_.has_input_image_width())
<< "Must provide input with/height for using flip_vertically option "
"when outputing landmarks in absolute coordinates.";
<< "Must provide input with/height for using flipping when outputing "
"landmarks in absolute coordinates.";
}
flip_horizontally_ =
cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY")
? cc->InputSidePackets().Tag("FLIP_HORIZONTALLY").Get<bool>()
: options_.flip_horizontally();
flip_vertically_ =
cc->InputSidePackets().HasTag("FLIP_VERTICALLY")
? cc->InputSidePackets().Tag("FLIP_VERTICALLY").Get<bool>()
: options_.flip_vertically();
return mediapipe::OkStatus();
}
mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
// Override values if specified so.
if (cc->Inputs().HasTag("FLIP_HORIZONTALLY") &&
!cc->Inputs().Tag("FLIP_HORIZONTALLY").IsEmpty()) {
flip_horizontally_ = cc->Inputs().Tag("FLIP_HORIZONTALLY").Get<bool>();
}
if (cc->Inputs().HasTag("FLIP_VERTICALLY") &&
!cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) {
flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get<bool>();
}
if (cc->Inputs().Tag("TENSORS").IsEmpty()) {
if (kInTensors(cc).IsEmpty()) {
return mediapipe::OkStatus();
}
bool flip_horizontally =
kFlipHorizontally(cc).GetOr(options_.flip_horizontally());
bool flip_vertically = kFlipVertically(cc).GetOr(options_.flip_vertically());
const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<Tensor>>();
const auto& input_tensors = *kInTensors(cc);
int num_values = input_tensors[0].shape().num_elements();
const int num_dimensions = num_values / num_landmarks_;
CHECK_GT(num_dimensions, 0);
@ -202,13 +152,13 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
const int offset = ld * num_dimensions;
Landmark* landmark = output_landmarks.add_landmark();
if (flip_horizontally_) {
if (flip_horizontally) {
landmark->set_x(options_.input_image_width() - raw_landmarks[offset]);
} else {
landmark->set_x(raw_landmarks[offset]);
}
if (num_dimensions > 1) {
if (flip_vertically_) {
if (flip_vertically) {
landmark->set_y(options_.input_image_height() -
raw_landmarks[offset + 1]);
} else {
@ -229,7 +179,7 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
}
// Output normalized landmarks if required.
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
if (kOutNormalizedLandmarkList(cc).IsConnected()) {
NormalizedLandmarkList output_norm_landmarks;
for (int i = 0; i < output_landmarks.landmark_size(); ++i) {
const Landmark& landmark = output_landmarks.landmark(i);
@ -246,18 +196,12 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
norm_landmark->set_presence(landmark.presence());
}
}
cc->Outputs()
.Tag("NORM_LANDMARKS")
.AddPacket(MakePacket<NormalizedLandmarkList>(output_norm_landmarks)
.At(cc->InputTimestamp()));
kOutNormalizedLandmarkList(cc).Send(std::move(output_norm_landmarks));
}
// Output absolute landmarks.
if (cc->Outputs().HasTag("LANDMARKS")) {
cc->Outputs()
.Tag("LANDMARKS")
.AddPacket(MakePacket<LandmarkList>(output_landmarks)
.At(cc->InputTimestamp()));
if (kOutLandmarkList(cc).IsConnected()) {
kOutLandmarkList(cc).Send(std::move(output_landmarks));
}
return mediapipe::OkStatus();
@ -272,4 +216,5 @@ mediapipe::Status TensorsToLandmarksCalculator::LoadOptions(
return mediapipe::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -396,6 +396,7 @@ cc_library(
"//mediapipe/framework/port:status",
"//mediapipe/util/sequence:media_sequence",
"//mediapipe/util/sequence:media_sequence_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
@ -841,6 +842,7 @@ cc_test(
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",

View File

@ -15,6 +15,7 @@
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
@ -57,7 +58,7 @@ namespace mpms = mediapipe::mediasequence;
// bounding boxes from vector<Detections>, and streams with the
// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector<float>'s
// associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints
// from unordered_map<std::string, vector<pair<float, float>>>. "IMAGE_${NAME}",
// from flat_hash_map<std::string, vector<pair<float, float>>>. "IMAGE_${NAME}",
// "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of
// each stream, which allows for multiple image streams to be included. However,
// the default names are suppored by more tools.
@ -131,7 +132,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
cc->Inputs()
.Tag(tag)
.Set<std::unordered_map<std::string,
.Set<absl::flat_hash_map<std::string,
std::vector<std::pair<float, float>>>>();
}
if (absl::StartsWith(tag, kBBoxTag)) {
@ -348,7 +349,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
const auto& keypoints =
cc->Inputs()
.Tag(tag)
.Get<std::unordered_map<
.Get<absl::flat_hash_map<
std::string, std::vector<std::pair<float, float>>>>();
for (const auto& pair : keypoints) {
std::string prefix = mpms::merge_prefix(key, pair.first);

View File

@ -14,6 +14,7 @@
#include <algorithm>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
@ -537,8 +538,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
std::unordered_map<std::string, std::vector<std::pair<float, float>>> points =
{{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>>
points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
runner_->MutableInputs()
->Tag("KEYPOINTS_TEST")
.packets.push_back(PointToForeign(&points).At(Timestamp(0)));

View File

@ -450,7 +450,9 @@ mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
}
#else
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
if (gpu_inference_) {
// Metal delegate supports external command encoder only if all input and
// output buffers are on GPU.
if (gpu_inference_ && gpu_input_ && gpu_output_) {
RET_CHECK(
TFLGpuDelegateSetCommandEncoder(delegate_.get(), compute_encoder));
}

View File

@ -934,11 +934,22 @@ cc_test(
],
)
mediapipe_proto_library(
name = "local_file_contents_calculator_proto",
srcs = ["local_file_contents_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "local_file_contents_calculator",
srcs = ["local_file_contents_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":local_file_contents_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",

View File

@ -15,6 +15,7 @@
#include <memory>
#include <string>
#include "mediapipe/calculators/util/local_file_contents_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/resource_util.h"
@ -78,6 +79,8 @@ class LocalFileContentsCalculator : public CalculatorBase {
mediapipe::Status Open(CalculatorContext* cc) override {
CollectionItemId input_id = cc->InputSidePackets().BeginId(kFilePathTag);
CollectionItemId output_id = cc->OutputSidePackets().BeginId(kContentsTag);
auto options = cc->Options<mediapipe::LocalFileContentsCalculatorOptions>();
// Number of inputs and outpus is the same according to the contract.
for (; input_id != cc->InputSidePackets().EndId(kFilePathTag);
++input_id, ++output_id) {
@ -86,7 +89,8 @@ class LocalFileContentsCalculator : public CalculatorBase {
ASSIGN_OR_RETURN(file_path, PathToResourceAsFile(file_path));
std::string contents;
MP_RETURN_IF_ERROR(GetResourceContents(file_path, &contents));
MP_RETURN_IF_ERROR(
GetResourceContents(file_path, &contents, options.read_as_binary()));
cc->OutputSidePackets().Get(output_id).Set(
MakePacket<std::string>(std::move(contents)));
}

View File

@ -0,0 +1,28 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message LocalFileContentsCalculatorOptions {
extend CalculatorOptions {
optional LocalFileContentsCalculatorOptions ext = 346849340;
}
// If true, set the file open mode to 'rb'. Otherwise, set the mode to 'r'.
optional bool read_as_binary = 1 [default = true];
}

View File

@ -56,8 +56,10 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
} catch (NameNotFoundException e) {
Log.e(TAG, "Cannot find application info: " + e);
}
// Get allowed object category.
String categoryName = applicationInfo.metaData.getString("categoryName");
// Get maximum allowed number of objects.
int maxNumObjects = applicationInfo.metaData.getInt("maxNumObjects");
float[] modelScale = parseFloatArrayFromString(
applicationInfo.metaData.getString("modelScale"));
float[] modelTransform = parseFloatArrayFromString(
@ -70,6 +72,7 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
inputSidePackets.put("obj_texture", packetCreator.createRgbaImageFrame(objTexture));
inputSidePackets.put("box_texture", packetCreator.createRgbaImageFrame(boxTexture));
inputSidePackets.put("allowed_labels", packetCreator.createString(categoryName));
inputSidePackets.put("max_num_objects", packetCreator.createInt32(maxNumObjects));
inputSidePackets.put("model_scale", packetCreator.createFloat32Array(modelScale));
inputSidePackets.put("model_transformation", packetCreator.createFloat32Array(modelTransform));
processor.setInputSidePackets(inputSidePackets);
@ -118,8 +121,8 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
} catch (RuntimeException e) {
Log.e(
TAG,
"MediaPipeException encountered adding packets to width and height"
+ " input streams.");
"MediaPipeException encountered adding packets to input_width and input_height"
+ " input streams.", e);
}
widthPacket.release();
heightPacket.release();

View File

@ -8,6 +8,7 @@
<application>
<meta-data android:name="categoryName" android:value="Camera"/>
<meta-data android:name="maxNumObjects" android:value="5"/>
<meta-data android:name="modelScale" android:value="250, 250, 250"/>
<meta-data android:name="modelTransformation" android:value="1.0, 0.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,

View File

@ -8,6 +8,7 @@
<application>
<meta-data android:name="categoryName" android:value="Chair"/>
<meta-data android:name="maxNumObjects" android:value="5"/>
<meta-data android:name="modelScale" android:value="0.1, 0.05, 0.1"/>
<meta-data android:name="modelTransformation" android:value="1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, -10.0,

View File

@ -8,6 +8,7 @@
<application>
<meta-data android:name="categoryName" android:value="Coffee cup,Mug"/>
<meta-data android:name="maxNumObjects" android:value="5"/>
<meta-data android:name="modelScale" android:value="500, 500, 500"/>
<meta-data android:name="modelTransformation" android:value="1.0, 0.0, 0.0, 0.0,
0.0, 0.0, 1.0, -0.001,

View File

@ -8,6 +8,7 @@
<application>
<meta-data android:name="categoryName" android:value="Footwear"/>
<meta-data android:name="maxNumObjects" android:value="5"/>
<meta-data android:name="modelScale" android:value="0.25, 0.25, 0.12"/>
<meta-data android:name="modelTransformation" android:value="1.0, 0.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,

View File

@ -52,8 +52,9 @@ mediapipe::Status KinematicPathSolver::AddObservation(int position,
}
}
double delta_degs = (Median(raw_positions_at_time_) - current_position_px_) /
pixels_per_degree_;
int filtered_position = Median(raw_positions_at_time_);
double delta_degs =
(filtered_position - current_position_px_) / pixels_per_degree_;
// If the motion is smaller than the min_motion_to_reframe and camera is
// stationary, don't use the update.
@ -68,14 +69,14 @@ mediapipe::Status KinematicPathSolver::AddObservation(int position,
} else if (delta_degs > 0) {
// Apply new position, less the reframe window size.
target_position_px_ =
position - pixels_per_degree_ * options_.reframe_window();
filtered_position - pixels_per_degree_ * options_.reframe_window();
delta_degs =
(target_position_px_ - current_position_px_) / pixels_per_degree_;
motion_state_ = true;
} else {
// Apply new position, plus the reframe window size.
target_position_px_ =
position + pixels_per_degree_ * options_.reframe_window();
filtered_position + pixels_per_degree_ * options_.reframe_window();
delta_degs =
(target_position_px_ - current_position_px_) / pixels_per_degree_;
motion_state_ = true;

View File

@ -351,7 +351,6 @@ cc_library(
":calculator_base",
":calculator_context",
":calculator_context_manager",
":calculator_registry_util",
":calculator_state",
":counter_factory",
":input_side_packet_handler",
@ -405,27 +404,6 @@ cc_library(
],
)
cc_library(
name = "calculator_registry_util",
srcs = ["calculator_registry_util.cc"],
hdrs = ["calculator_registry_util.h"],
visibility = [
":mediapipe_internal",
],
deps = [
":calculator_base",
":calculator_context",
":calculator_state",
":collection",
":collection_item_id",
":packet_set",
":timestamp",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework/tool:tag_map",
],
)
cc_library(
name = "calculator_runner",
testonly = 1,
@ -1128,7 +1106,6 @@ cc_library(
deps = [
":calculator_base",
":calculator_contract",
":calculator_registry_util",
":legacy_calculator_support",
":packet",
":packet_generator",

View File

@ -0,0 +1,231 @@
package(
default_visibility = [":preview_users"],
features = ["-use_header_modules"],
)
# API2 is in preview mode. Internal clients are welcome and encouraged to try
# it out, but be aware that there may be more changes before release. Please
# add your package to this list and reach out to the MediaPipe team (use
# camillol@ as the CL reviewer).
package_group(
name = "preview_users",
packages = [
"//mediapipe/...",
"//video/content_analysis/...",
],
)
licenses(["notice"])
cc_library(
name = "const_str",
hdrs = ["const_str.h"],
)
cc_library(
name = "builder",
hdrs = ["builder.h"],
deps = [
":const_str",
":contract",
":node",
":packet",
":port",
"//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_contract",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_test(
name = "builder_test",
srcs = ["builder_test.cc"],
deps = [
":builder",
":node",
":packet",
":port",
":tag",
":test_contracts",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "contract",
hdrs = ["contract.h"],
deps = [
":const_str",
":packet",
":port",
":tag",
":tuple",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:output_side_packet",
"//mediapipe/framework/port:logging",
],
)
cc_test(
name = "contract_test",
srcs = ["contract_test.cc"],
deps = [
":contract",
":port",
":tag",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)
cc_library(
name = "node",
srcs = ["node.cc"],
hdrs = ["node.h"],
deps = [
":const_str",
":contract",
":packet",
":port",
"//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:subgraph",
"//mediapipe/framework/deps:no_destructor",
],
)
cc_library(
name = "test_contracts",
testonly = 1,
hdrs = ["test_contracts.h"],
deps = [
":node",
],
)
cc_test(
name = "node_test",
srcs = ["node_test.cc"],
deps = [
":node",
":packet",
":port",
":test_contracts",
":tuple",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)
cc_library(
name = "packet",
srcs = ["packet.cc"],
hdrs = ["packet.h"],
deps = [
":tuple",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:logging",
],
)
cc_test(
name = "packet_test",
size = "small",
srcs = [
"packet_test.cc",
],
deps = [
":packet",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "port",
hdrs = ["port.h"],
deps = [
":const_str",
":packet",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:output_side_packet",
"//mediapipe/framework/port:logging",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "port_test",
size = "small",
srcs = [
"port_test.cc",
],
deps = [
":port",
"//mediapipe/framework/port:gtest_main",
],
)
cc_test(
name = "subgraph_test",
srcs = ["subgraph_test.cc"],
deps = [
":builder",
":node",
":packet",
":port",
":test_contracts",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:subgraph_expansion",
],
)
cc_library(
name = "tag",
hdrs = ["tag.h"],
deps = [":const_str"],
)
cc_test(
name = "tag_test",
size = "small",
srcs = [
"tag_test.cc",
],
deps = [
":tag",
"//mediapipe/framework/port:gtest_main",
],
)
cc_library(
name = "tuple",
hdrs = ["tuple.h"],
deps = ["@com_google_absl//absl/meta:type_traits"],
)
cc_test(
name = "tuple_test",
size = "small",
srcs = [
"tuple_test.cc",
],
deps = [
":tuple",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",
],
)

View File

@ -0,0 +1,111 @@
# Experimental new APIs
This directory defines new APIs for MediaPipe:
- Node API, an update to the Calculator API for defining MediaPipe components.
- Builder API, for assembling CalculatorGraphConfigs with C++, as an alternative
to using the proto API directly.
The code is working, and the new APIs interoperate fully with the existing
framework code. They are considered a work in progress, but are being released
now so we can begin adopting them in our calculators.
Developers are welcome to try out these APIs as early adopters, but should
expect breaking changes. The placement of this code under the `mediapipe::api2`
namespace is not final.
## Node API
This API can be used to define calculators. It is designed to be more type-safe
and less verbose than the original API.
Input/output ports (streams and side packets) can now be declared as typed
constants, instead of using plain strings for access.
For example, instead of
```
constexpr char kSelectTag[] = "SELECT";
if (cc->Inputs().HasTag(kSelectTag)) {
cc->Inputs().Tag(kSelectTag).Set<int>();
}
```
you can write
```
static constexpr Input<int>::Optional kSelect{"SELECT"};
```
Instead of setting up the contract procedurally in `GetContract`, add ports to
the contract declaratively, as follows:
```
MEDIAPIPE_NODE_CONTRACT(kInput, kOutput);
```
To access an input in Process, instead of
```
int select = cc->Inputs().Tag(kSelectTag).Get<int>();
```
write
```
int select = kSelectTag(cc).Get(); // alternative: *kSelectTag(cc)
```
Sets of multiple ports can be declared with `::Multiple`. Note, also, that a tag
string must always be provided when declaring a port; use `""` for untagged
ports. For example:
```
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).SetAny();
}
```
becomes
```
static constexpr Input<AnyType>::Multiple kIn{""};
```
For output ports, the payload can be passed directly to the `Send` method. For
example, instead of
```
cc->Outputs().Index(0).Add(
new std::pair<Packet, Packet>(cc->Inputs().Index(0).Value(),
cc->Inputs().Index(1).Value()),
cc->InputTimestamp());
```
you can write
```
kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()});
```
The input timestamp is propagated to the outputs by default. If your calculator
wants to alter timestamps, it must add a `TimestampChange` entry to its contract
declaration. For example:
```
MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop,
StreamHandler("ImmediateInputStreamHandler"),
TimestampChange::Arbitrary());
```
Several calculators in
[`calculators/core`](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core) and
[`calculators/tensor`](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/tensor)
have been updated to use this API. Reference them for more examples.
More complete documentation will be provided in the future.
## Builder API
Documentation will be provided in the future.

View File

@ -0,0 +1,576 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
#define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
#include <string>
#include <type_traits>
#include "absl/container/flat_hash_map.h"
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/contract.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_contract.h"
namespace mediapipe {
namespace api2 {
namespace builder {
template <typename T>
T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, int index) {
auto& vec = *vecp;
if (vec.size() <= index) {
vec.resize(index + 1);
}
if (vec[index] == nullptr) {
vec[index] = absl::make_unique<T>();
}
return *vec[index];
}
struct TagIndexLocation {
const std::string& tag;
std::size_t index;
std::size_t count;
};
template <typename T>
class TagIndexMap {
public:
std::vector<std::unique_ptr<T>>& operator[](const std::string& tag) {
return map_[tag];
}
void Visit(std::function<void(const TagIndexLocation&, const T&)> fun) const {
for (const auto& tagged : map_) {
TagIndexLocation loc{tagged.first, 0, tagged.second.size()};
for (const auto& item : tagged.second) {
fun(loc, *item);
++loc.index;
}
}
}
void Visit(std::function<void(const TagIndexLocation&, T*)> fun) {
for (auto& tagged : map_) {
TagIndexLocation loc{tagged.first, 0, tagged.second.size()};
for (auto& item : tagged.second) {
fun(loc, item.get());
++loc.index;
}
}
}
// Note: entries are held by a unique_ptr to ensure pointers remain valid.
// Should use absl::flat_hash_map but ordering keys for now.
std::map<std::string, std::vector<std::unique_ptr<T>>> map_;
};
// These structs are used internally to store information about the endpoints
// of a connection.
struct SourceBase;
struct DestinationBase {
SourceBase* source = nullptr;
};
struct SourceBase {
std::vector<DestinationBase*> dests_;
std::string name_;
};
// Following existing GraphConfig usage, we allow using a multiport as a single
// port as well. This is necessary for generic nodes, since we have no
// information about which ports are meant to be multiports or not, but it is
// also convenient with typed nodes.
template <typename Single>
class MultiPort : public Single {
public:
using Base = typename Single::Base;
explicit MultiPort(std::vector<std::unique_ptr<Base>>* vec)
: Single(vec), vec_(*vec) {}
Single operator[](int index) {
CHECK_GE(index, 0);
return Single{&GetWithAutoGrow(&vec_, index)};
}
private:
std::vector<std::unique_ptr<Base>>& vec_;
};
// These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API.
template <bool AllowMultiple, bool IsSide, typename T = internal::Generic>
class DestinationImpl {
public:
using Base = DestinationBase;
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
DestinationBase& base_;
};
template <bool IsSide, typename T>
class DestinationImpl<true, IsSide, T>
: public MultiPort<DestinationImpl<false, IsSide, T>> {
public:
using MultiPort<DestinationImpl<false, IsSide, T>>::MultiPort;
};
template <bool AllowMultiple, bool IsSide, typename T = internal::Generic>
class SourceImpl {
public:
using Base = SourceBase;
// Src is used as the return type of fluent methods below. Since these are
// single-port methods, it is desirable to always decay to a reference to the
// single-port superclass, even if they are called on a multiport.
using Src = SourceImpl<false, IsSide, T>;
template <typename U>
using Dst = DestinationImpl<false, IsSide, U>;
// clang-format off
template <typename U>
struct AllowConnection : public std::integral_constant<bool,
std::is_same<T, U>{} || std::is_same<T, internal::Generic>{} ||
std::is_same<U, internal::Generic>{}> {};
// clang-format on
explicit SourceImpl(std::vector<std::unique_ptr<Base>>* vec)
: SourceImpl(&GetWithAutoGrow(vec, 0)) {}
explicit SourceImpl(SourceBase* base) : base_(*base) {}
template <typename U,
typename std::enable_if<AllowConnection<U>{}, int>::type = 0>
Src& AddTarget(const Dst<U>& dest) {
CHECK(dest.base_.source == nullptr);
dest.base_.source = &base_;
base_.dests_.emplace_back(&dest.base_);
return *this;
}
Src& SetName(std::string name) {
base_.name_ = std::move(name);
return *this;
}
template <typename U>
Src& operator>>(const Dst<U>& dest) {
return AddTarget(dest);
}
private:
SourceBase& base_;
};
template <bool IsSide, typename T>
class SourceImpl<true, IsSide, T>
: public MultiPort<SourceImpl<false, IsSide, T>> {
public:
using MultiPort<SourceImpl<false, IsSide, T>>::MultiPort;
};
// A source and a destination correspond to an output/input stream on a node,
// and a side source and side destination correspond to an output/input side
// packet.
// For graph inputs/outputs, however, the inputs are sources, and the outputs
// are destinations. This is because graph ports are connected "from inside"
// when building the graph.
template <bool AllowMultiple = false, typename T = internal::Generic>
using Source = SourceImpl<AllowMultiple, false, T>;
template <bool AllowMultiple = false, typename T = internal::Generic>
using SideSource = SourceImpl<AllowMultiple, true, T>;
template <bool AllowMultiple = false, typename T = internal::Generic>
using Destination = DestinationImpl<AllowMultiple, false, T>;
template <bool AllowMultiple = false, typename T = internal::Generic>
using SideDestination = DestinationImpl<AllowMultiple, true, T>;
class NodeBase {
public:
// TODO: right now access to an indexed port is made directly by
// specifying both a tag and an index. It would be better to represent this
// as a two-step lookup, first getting a multi-port, and then accessing one
// of its entries by index. However, for nodes without visible contracts we
// can't know whether a tag is indexable or not, so we would need the
// multi-port to also be usable as a port directly (representing index 0).
Source<true> Out(const std::string& tag) {
return Source<true>(&out_streams_[tag]);
}
Destination<true> In(const std::string& tag) {
return Destination<true>(&in_streams_[tag]);
}
SideSource<true> SideOut(const std::string& tag) {
return SideSource<true>(&out_sides_[tag]);
}
SideDestination<true> SideIn(const std::string& tag) {
return SideDestination<true>(&in_sides_[tag]);
}
// Convenience methods for accessing purely index-based ports.
Source<false> Out(int index) { return Out("")[index]; }
Destination<false> In(int index) { return In("")[index]; }
SideSource<false> SideOut(int index) { return SideOut("")[index]; }
SideDestination<false> SideIn(int index) { return SideIn("")[index]; }
template <typename T>
T& GetOptions() {
options_used_ = true;
return *options_.MutableExtension(T::ext);
}
protected:
NodeBase(std::string type) : type_(std::move(type)) {}
std::string type_;
TagIndexMap<DestinationBase> in_streams_;
TagIndexMap<SourceBase> out_streams_;
TagIndexMap<DestinationBase> in_sides_;
TagIndexMap<SourceBase> out_sides_;
CalculatorOptions options_;
// ideally we'd just check if any extensions are set on options_
bool options_used_ = false;
friend class Graph;
};
template <class Calc = internal::Generic>
class Node;
#if __cplusplus >= 201703L
// Deduction guide to silence -Wctad-maybe-unsupported.
explicit Node()->Node<internal::Generic>;
#endif // C++17
template <>
class Node<internal::Generic> : public NodeBase {
public:
Node(std::string type) : NodeBase(std::move(type)) {}
};
using GenericNode = Node<internal::Generic>;
template <template <bool, class> class BP, class Port, class TagIndexMapT>
auto MakeBuilderPort(const Port& port, TagIndexMapT& streams) {
return BP<Port::kMultiple, typename Port::PayloadT>(&streams[port.Tag()]);
}
template <class Calc>
class Node : public NodeBase {
public:
Node() : NodeBase(Calc::kCalculatorName) {}
// Overrides the built-in calculator type std::string with the provided
// argument. Can be used to create nodes from pure interfaces.
// TODO: only use this for pure interfaces
Node(const std::string& type_override) : NodeBase(type_override) {}
// These methods only allow access to ports declared in the contract.
// The argument must be a tag object created with the MPP_TAG macro.
// These objects encode the tag in their type, which allows us to return
// a result with the appropriate payload type depending on the tag.
template <class Tag>
auto Out(Tag tag) {
constexpr auto& port = Calc::Contract::TaggedOutputs::get(tag);
return MakeBuilderPort<Source>(port, out_streams_);
}
template <class Tag>
auto In(Tag tag) {
constexpr auto& port = Calc::Contract::TaggedInputs::get(tag);
return MakeBuilderPort<Destination>(port, in_streams_);
}
template <class Tag>
auto SideOut(Tag tag) {
constexpr auto& port = Calc::Contract::TaggedSideOutputs::get(tag);
return MakeBuilderPort<SideSource>(port, out_sides_);
}
template <class Tag>
auto SideIn(Tag tag) {
constexpr auto& port = Calc::Contract::TaggedSideInputs::get(tag);
return MakeBuilderPort<SideDestination>(port, in_sides_);
}
// We could allow using the non-checked versions with typed nodes too, but
// we don't.
// using NodeBase::Out;
// using NodeBase::In;
// using NodeBase::SideOut;
// using NodeBase::SideIn;
};
// For legacy PacketGenerators.
class PacketGenerator {
public:
PacketGenerator(std::string type) : type_(std::move(type)) {}
SideSource<true> SideOut(const std::string& tag) {
return SideSource<true>(&out_sides_[tag]);
}
SideDestination<true> SideIn(const std::string& tag) {
return SideDestination<true>(&in_sides_[tag]);
}
// Convenience methods for accessing purely index-based ports.
SideSource<false> SideOut(int index) { return SideOut("")[index]; }
SideDestination<false> SideIn(int index) { return SideIn("")[index]; }
template <typename T>
T& GetOptions() {
options_used_ = true;
return *options_.MutableExtension(T::ext);
}
private:
std::string type_;
TagIndexMap<DestinationBase> in_sides_;
TagIndexMap<SourceBase> out_sides_;
mediapipe::PacketGeneratorOptions options_;
// ideally we'd just check if any extensions are set on options_
bool options_used_ = false;
friend class Graph;
};
class Graph {
public:
void SetType(std::string type) { type_ = std::move(type); }
// Creates a node of a specific type. Should be used for calculators whose
// contract is available.
template <class Calc>
Node<Calc>& AddNode() {
auto node = std::make_unique<Node<Calc>>();
auto node_p = node.get();
nodes_.emplace_back(std::move(node));
return *node_p;
}
// Creates a node of a specific type. Should be used for pure interfaces,
// which do not have a built-in type std::string.
template <class Calc>
Node<Calc>& AddNode(const std::string& type) {
auto node = std::make_unique<Node<Calc>>(type);
auto node_p = node.get();
nodes_.emplace_back(std::move(node));
return *node_p;
}
// Creates a generic node, with no compile-time checking of inputs and
// outputs. This can be used for calculators whose contract is not visible.
GenericNode& AddNode(const std::string& type) {
auto node = std::make_unique<GenericNode>(type);
auto node_p = node.get();
nodes_.emplace_back(std::move(node));
return *node_p;
}
// For legacy PacketGenerators.
PacketGenerator& AddPacketGenerator(const std::string& type) {
auto node = std::make_unique<PacketGenerator>(type);
auto node_p = node.get();
packet_gens_.emplace_back(std::move(node));
return *node_p;
}
// Graph ports, non-typed.
Source<true> In(const std::string& graph_input) {
return graph_boundary_.Out(graph_input);
}
Destination<true> Out(const std::string& graph_output) {
return graph_boundary_.In(graph_output);
}
SideSource<true> SideIn(const std::string& graph_input) {
return graph_boundary_.SideOut(graph_input);
}
SideDestination<true> SideOut(const std::string& graph_output) {
return graph_boundary_.SideIn(graph_output);
}
// Convenience methods for accessing purely index-based ports.
Source<false> In(int index) { return In("")[0]; }
Destination<false> Out(int index) { return Out("")[0]; }
SideSource<false> SideIn(int index) { return SideIn("")[0]; }
SideDestination<false> SideOut(int index) { return SideOut("")[0]; }
// Graph ports, typed.
// TODO: make graph_boundary_ a typed node!
template <class PortT, class Payload = typename PortT::PayloadT,
class Src = Source<PortT::kMultiple, Payload>>
Src In(const PortT& graph_input) {
return Src(&graph_boundary_.out_streams_[graph_input.Tag()]);
}
template <class PortT, class Payload = typename PortT::PayloadT,
class Dst = Destination<PortT::kMultiple, Payload>>
Dst Out(const PortT& graph_output) {
return Dst(&graph_boundary_.in_streams_[graph_output.Tag()]);
}
template <class PortT, class Payload = typename PortT::PayloadT,
class Src = SideSource<PortT::kMultiple, Payload>>
Src SideIn(const PortT& graph_input) {
return Src(&graph_boundary_.out_sides_[graph_input.Tag()]);
}
template <class PortT, class Payload = typename PortT::PayloadT,
class Dst = SideDestination<PortT::kMultiple, Payload>>
Dst SideOut(const PortT& graph_output) {
return Dst(&graph_boundary_.in_sides_[graph_output.Tag()]);
}
// Returns the graph config. This can be used to instantiate and run the
// graph.
CalculatorGraphConfig GetConfig() {
CalculatorGraphConfig config;
if (!type_.empty()) {
config.set_type(type_);
}
FixUnnamedConnections();
CHECK_OK(UpdateBoundaryConfig(&config));
for (const std::unique_ptr<NodeBase>& node : nodes_) {
auto* out_node = config.add_node();
CHECK_OK(UpdateNodeConfig(*node, out_node));
}
for (const std::unique_ptr<PacketGenerator>& node : packet_gens_) {
auto* out_node = config.add_packet_generator();
CHECK_OK(UpdateNodeConfig(*node, out_node));
}
return config;
}
private:
void FixUnnamedConnections(NodeBase* node, int* unnamed_count) {
node->out_streams_.Visit([&](const TagIndexLocation&, SourceBase* source) {
if (source->name_.empty()) {
source->name_ = absl::StrCat("__stream_", (*unnamed_count)++);
}
});
node->out_sides_.Visit([&](const TagIndexLocation&, SourceBase* source) {
if (source->name_.empty()) {
source->name_ = absl::StrCat("__side_packet_", (*unnamed_count)++);
}
});
}
void FixUnnamedConnections() {
int unnamed_count = 0;
FixUnnamedConnections(&graph_boundary_, &unnamed_count);
for (std::unique_ptr<NodeBase>& node : nodes_) {
FixUnnamedConnections(node.get(), &unnamed_count);
}
for (std::unique_ptr<PacketGenerator>& node : packet_gens_) {
node->out_sides_.Visit([&](const TagIndexLocation&, SourceBase* source) {
if (source->name_.empty()) {
source->name_ = absl::StrCat("__side_packet_", unnamed_count++);
}
});
}
}
std::string TaggedName(const TagIndexLocation& loc, const std::string& name) {
if (loc.tag.empty()) {
// ParseTagIndexName does not allow using explicit indices without tags,
// while ParseTagIndex does. There is no explanation for this discrepancy
// in the CLs that introduced them (cl/143209019, cl/156499931).
// TODO: decide whether we should just allow it.
return name;
} else {
if (loc.count <= 1) {
return absl::StrCat(loc.tag, ":", name);
} else {
return absl::StrCat(loc.tag, ":", loc.index, ":", name);
}
}
}
mediapipe::Status UpdateNodeConfig(const NodeBase& node,
CalculatorGraphConfig::Node* config) {
config->set_calculator(node.type_);
node.in_streams_.Visit(
[&](const TagIndexLocation& loc, const DestinationBase& endpoint) {
CHECK(endpoint.source != nullptr);
config->add_input_stream(TaggedName(loc, endpoint.source->name_));
});
node.out_streams_.Visit(
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
config->add_output_stream(TaggedName(loc, endpoint.name_));
});
node.in_sides_.Visit([&](const TagIndexLocation& loc,
const DestinationBase& endpoint) {
CHECK(endpoint.source != nullptr);
config->add_input_side_packet(TaggedName(loc, endpoint.source->name_));
});
node.out_sides_.Visit(
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
config->add_output_side_packet(TaggedName(loc, endpoint.name_));
});
if (node.options_used_) {
*config->mutable_options() = node.options_;
}
return {};
}
mediapipe::Status UpdateNodeConfig(const PacketGenerator& node,
PacketGeneratorConfig* config) {
config->set_packet_generator(node.type_);
node.in_sides_.Visit([&](const TagIndexLocation& loc,
const DestinationBase& endpoint) {
CHECK(endpoint.source != nullptr);
config->add_input_side_packet(TaggedName(loc, endpoint.source->name_));
});
node.out_sides_.Visit(
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
config->add_output_side_packet(TaggedName(loc, endpoint.name_));
});
if (node.options_used_) {
*config->mutable_options() = node.options_;
}
return {};
}
// For special boundary node.
mediapipe::Status UpdateBoundaryConfig(CalculatorGraphConfig* config) {
graph_boundary_.in_streams_.Visit(
[&](const TagIndexLocation& loc, const DestinationBase& endpoint) {
CHECK(endpoint.source != nullptr);
config->add_output_stream(TaggedName(loc, endpoint.source->name_));
});
graph_boundary_.out_streams_.Visit(
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
config->add_input_stream(TaggedName(loc, endpoint.name_));
});
graph_boundary_.in_sides_.Visit([&](const TagIndexLocation& loc,
const DestinationBase& endpoint) {
CHECK(endpoint.source != nullptr);
config->add_output_side_packet(TaggedName(loc, endpoint.source->name_));
});
graph_boundary_.out_sides_.Visit(
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
config->add_input_side_packet(TaggedName(loc, endpoint.name_));
});
return {};
}
std::string type_;
std::vector<std::unique_ptr<NodeBase>> nodes_;
std::vector<std::unique_ptr<PacketGenerator>> packet_gens_;
// Special node representing graph inputs and outputs.
NodeBase graph_boundary_{"__GRAPH__"};
};
} // namespace builder
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_

View File

@ -0,0 +1,190 @@
#include "mediapipe/framework/api2/builder.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/api2/tag.h"
#include "mediapipe/framework/api2/test_contracts.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace api2 {
namespace test {
TEST(BuilderTest, BuildGraph) {
builder::Graph graph;
auto& foo = graph.AddNode("Foo");
auto& bar = graph.AddNode("Bar");
graph.In("IN").SetName("base") >> foo.In("BASE");
graph.SideIn("SIDE").SetName("side") >> foo.SideIn("SIDE");
foo.Out("OUT") >> bar.In("IN");
bar.Out("OUT").SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "IN:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:out"
node {
calculator: "Foo"
input_stream: "BASE:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:__stream_0"
}
node {
calculator: "Bar"
input_stream: "IN:__stream_0"
output_stream: "OUT:out"
}
)");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
template <class FooT>
void BuildGraphTypedTest() {
builder::Graph graph;
auto& foo = graph.AddNode<FooT>();
auto& bar = graph.AddNode<Bar>();
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
graph.SideIn("SIDE").SetName("side") >> foo.SideIn(MPP_TAG("BIAS"));
foo.Out(MPP_TAG("OUT")) >> bar.In(MPP_TAG("IN"));
bar.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "IN:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:out"
node {
calculator: "$0"
input_stream: "BASE:base"
input_side_packet: "BIAS:side"
output_stream: "OUT:__stream_0"
}
node {
calculator: "Bar"
input_stream: "IN:__stream_0"
output_stream: "OUT:out"
}
)",
FooT::kCalculatorName));
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest<Foo>(); }
TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<Foo2>(); }
TEST(BuilderTest, FanOut) {
builder::Graph graph;
auto& foo = graph.AddNode("Foo");
auto& adder = graph.AddNode("FloatAdder");
graph.In("IN").SetName("base") >> foo.In("BASE");
foo.Out("OUT") >> adder.In("IN")[0];
foo.Out("OUT") >> adder.In("IN")[1];
adder.Out("OUT").SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "IN:base"
output_stream: "OUT:out"
node {
calculator: "Foo"
input_stream: "BASE:base"
output_stream: "OUT:__stream_0"
}
node {
calculator: "FloatAdder"
input_stream: "IN:0:__stream_0"
input_stream: "IN:1:__stream_0"
output_stream: "OUT:out"
}
)");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, TypedMultiple) {
builder::Graph graph;
auto& foo = graph.AddNode<Foo>();
auto& adder = graph.AddNode<FloatAdder>();
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0];
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1];
adder.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "IN:base"
output_stream: "OUT:out"
node {
calculator: "Foo"
input_stream: "BASE:base"
output_stream: "OUT:__stream_0"
}
node {
calculator: "FloatAdder"
input_stream: "IN:0:__stream_0"
input_stream: "IN:1:__stream_0"
output_stream: "OUT:out"
}
)");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, PacketGenerator) {
builder::Graph graph;
auto& generator = graph.AddPacketGenerator("FloatGenerator");
graph.SideIn("IN") >> generator.SideIn("IN");
generator.SideOut("OUT") >> graph.SideOut("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_side_packet: "IN:__side_packet_0"
output_side_packet: "OUT:__side_packet_1"
packet_generator {
packet_generator: "FloatGenerator"
input_side_packet: "IN:__side_packet_0"
output_side_packet: "OUT:__side_packet_1"
}
)");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, EmptyTag) {
builder::Graph graph;
auto& foo = graph.AddNode("Foo");
graph.In("A").SetName("a") >> foo.In("")[0];
graph.In("C").SetName("c") >> foo.In("")[2];
graph.In("B").SetName("b") >> foo.In("")[1];
foo.Out("")[0].SetName("x") >> graph.Out("ONE");
foo.Out("")[1].SetName("y") >> graph.Out("TWO");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "A:a"
input_stream: "B:b"
input_stream: "C:c"
output_stream: "ONE:x"
output_stream: "TWO:y"
node {
calculator: "Foo"
input_stream: "a"
input_stream: "b"
input_stream: "c"
output_stream: "x"
output_stream: "y"
}
)");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
} // namespace test
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,43 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_CONST_STR_H_
#define MEDIAPIPE_FRAMEWORK_API2_CONST_STR_H_
#include <string>
namespace mediapipe {
namespace api2 {
// This class stores a constant std::string that can be inspected at compile
// time in constexpr code.
class const_str {
public:
constexpr const_str(std::size_t size, const char* data)
: len_(size - 1), data_(data) {}
template <std::size_t N>
explicit constexpr const_str(const char (&str)[N]) : const_str(N, str) {}
constexpr std::size_t len() const { return len_; }
constexpr const char* data() const { return data_; }
constexpr bool operator==(const const_str& other) const {
return len_ == other.len_ && equal(len_, data_, other.data_);
}
constexpr char operator[](const std::size_t idx) const {
return idx <= len_ ? data_[idx] : '\0';
}
private:
static constexpr bool equal(std::size_t len, const char* const p,
const char* const q) {
return len == 0 || (*p == *q && equal(len - 1, p + 1, q + 1));
}
const std::size_t len_;
const char* const data_;
};
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_CONST_STR_H_

View File

@ -0,0 +1,387 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_CONTRACT_H_
#define MEDIAPIPE_FRAMEWORK_API2_CONTRACT_H_
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/api2/tag.h"
#include "mediapipe/framework/api2/tuple.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/output_side_packet.h"
#include "mediapipe/framework/port/logging.h"
namespace mediapipe {
namespace api2 {
class StreamHandler {
public:
template <std::size_t N>
explicit constexpr StreamHandler(const char (&name)[N]) : name_(N, name) {}
const const_str& name() { return name_; }
mediapipe::Status AddToContract(CalculatorContract* cc) const {
cc->SetInputStreamHandler(name_.data());
return {};
}
private:
const const_str name_;
};
class TimestampChange {
public:
// Note: we don't use TimestampDiff as an argument because it's not constexpr.
static constexpr TimestampChange Offset(int64_t offset) {
return TimestampChange(offset);
}
static constexpr TimestampChange Arbitrary() {
// Same value as used for Timestamp::Unset.
return TimestampChange(kUnset);
}
mediapipe::Status AddToContract(CalculatorContract* cc) const {
if (offset_ != kUnset) cc->SetTimestampOffset(offset_);
return {};
}
private:
constexpr TimestampChange(int64_t offset) : offset_(offset) {}
static constexpr int64_t kUnset = std::numeric_limits<int64_t>::min();
int64_t offset_;
};
namespace internal {
template <class Base>
struct IsSubclass {
template <class T>
using pred = std::is_base_of<Base, std::decay_t<T>>;
};
template <class T, class = void>
struct HasProcessMethod : std::false_type {};
template <class T>
struct HasProcessMethod<
T, std::void_t<decltype(mediapipe::Status(
std::declval<std::decay_t<T>>().Process(
std::declval<mediapipe::CalculatorContext*>())))>>
: std::true_type {};
template <class T, class = void>
struct HasNestedItems : std::false_type {};
template <class T>
struct HasNestedItems<
T, std::void_t<decltype(std::declval<std::decay_t<T>>().nested_items())>>
: std::true_type {};
// Helper to construct a tuple of Tag types (see tag.h) from a tuple of ports.
template <class TupleRef>
struct TagTuple {
template <std::size_t J>
struct S {
const const_str tag{std::get<J>(TupleRef::get()).tag_};
};
template <std::size_t... I>
static constexpr auto Make(std::index_sequence<I...> indices) {
return std::make_tuple(mediapipe::api2::internal::tag_build(S<I>{})...);
}
static constexpr auto Make() {
using TupleT = decltype(TupleRef::get());
return Make(internal::tuple_index_sequence<TupleT>());
}
};
// Helper to access a tuple of ports by static tag. Attempts to look up a
// missing tag will not compile.
template <class TupleRef>
struct TaggedAccess {
// This is not functionally necessary (we could do the tag search directly
// on the port tuple), but it gives a more readable error message when the
// static_assert below fails.
static constexpr auto kTagTuple = TagTuple<TupleRef>::Make();
template <class Tag>
static constexpr auto& get(Tag tag) {
constexpr auto i =
internal::tuple_find([tag](auto x) { return x == tag; }, kTagTuple);
static_assert(i < std::tuple_size_v<decltype(kTagTuple)>, "tag not found");
return std::get<i>(TupleRef::get());
}
};
template <class... T>
constexpr auto ExtractNestedItems(std::tuple<T...> tuple) {
return internal::flatten_tuple(internal::map_tuple(
[](auto&& item) {
if constexpr (HasNestedItems<decltype(item)>{}) {
return std::tuple_cat(std::make_tuple(item), item.nested_items());
} else {
return std::make_tuple(item);
}
},
tuple));
}
// Internal contract type. Takes a list of ports or other contract items.
template <typename... T>
class Contract {
public:
constexpr Contract(std::tuple<T...> tuple) : items(tuple) {}
constexpr Contract(T&&... args)
: Contract(std::tuple<T...>{std::move(args)...}) {}
mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) const {
std::vector<mediapipe::Status> statuses;
auto store_status = [&statuses](mediapipe::Status status) {
if (!status.ok()) statuses.push_back(std::move(status));
};
internal::tuple_for_each(
[cc, &store_status](auto&& item) {
store_status(item.AddToContract(cc));
},
all_items);
if (timestamp_change_count() == 0) {
// Default to SetOffset(0);
store_status(TimestampChange::Offset(0).AddToContract(cc));
}
if (statuses.empty()) return {};
if (statuses.size() == 1) return statuses[0];
return tool::CombinedStatus("Multiple errors", statuses);
}
std::tuple<T...> items;
// TODO: when forwarding nested items (e.g. ports), check for conflicts.
decltype(ExtractNestedItems(items)) all_items{ExtractNestedItems(items)};
constexpr auto inputs() const {
return internal::filter_tuple<IsSubclass<InputBase>::pred>(all_items);
}
constexpr auto outputs() const {
return internal::filter_tuple<IsSubclass<OutputBase>::pred>(all_items);
}
constexpr auto side_inputs() const {
return internal::filter_tuple<IsSubclass<SideInputBase>::pred>(all_items);
}
constexpr auto side_outputs() const {
return internal::filter_tuple<IsSubclass<SideOutputBase>::pred>(all_items);
}
constexpr auto timestamp_change_count() const {
return internal::filtered_tuple_indices<IsSubclass<TimestampChange>::pred>(
all_items)
.size();
}
constexpr auto process_items() const {
return internal::filter_tuple<HasProcessMethod>(all_items);
}
};
// Helpers to construct a Contract.
template <typename... T>
constexpr auto MakeContract(T&&... args) {
return Contract<T...>(std::forward<T>(args)...);
}
template <typename... T>
constexpr auto MakeContract(const std::tuple<T...>& tuple) {
return Contract<T...>(tuple);
}
// Helper for accessing the ports of a Contract by static tags.
template <typename C2T, const C2T& c2>
class TaggedContract {
public:
constexpr TaggedContract() = default;
static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
return c2.GetContract(cc);
}
template <class Tuple, Tuple (C2T::*member)() const>
struct GetMember {
static constexpr const auto get() { return (c2.*member)(); }
};
using TaggedInputs =
TaggedAccess<GetMember<decltype(c2.inputs()), &C2T::inputs>>;
using TaggedOutputs =
TaggedAccess<GetMember<decltype(c2.outputs()), &C2T::outputs>>;
using TaggedSideInputs =
TaggedAccess<GetMember<decltype(c2.side_inputs()), &C2T::side_inputs>>;
using TaggedSideOutputs =
TaggedAccess<GetMember<decltype(c2.side_outputs()), &C2T::side_outputs>>;
};
// Support for function-based Process.
template <class T>
struct IsInputPort
: std::bool_constant<std::is_base_of<InputBase, std::decay_t<T>>{} ||
std::is_base_of<SideInputBase, std::decay_t<T>>{}> {};
template <class T>
struct IsOutputPort
: std::bool_constant<std::is_base_of<OutputBase, std::decay_t<T>>{} ||
std::is_base_of<SideOutputBase, std::decay_t<T>>{}> {};
// Helper class that converts a port specification into a function argument.
template <class P>
class PortArg {
public:
PortArg(CalculatorContext* cc, const P& port) : cc_(cc), port_(port) {}
using PayloadT = typename P::PayloadT;
operator const PayloadT&() { return port_(cc_).Get(); }
operator Packet<typename P::value_t>() { return port_(cc_); }
operator PacketBase() { return port_(cc_).packet(); }
private:
CalculatorContext* cc_;
const P& port_;
};
template <class P>
auto MakePortArg(CalculatorContext* cc, const P& port) {
return PortArg<P>(cc, port);
}
// Helper class that takes a function result and sends it into outputs.
template <class... P>
class OutputSender {
public:
OutputSender(P&&... args) : outputs_(args...) {}
OutputSender(std::tuple<P...>&& args) : outputs_(args) {}
template <class R, std::enable_if_t<sizeof...(P) == 1, int> = 0>
mediapipe::Status operator()(CalculatorContext* cc,
mediapipe::StatusOr<R>&& result) {
if (result.ok()) {
return this(cc, result.ValueOrDie());
} else {
return result.status();
}
}
template <class R, std::enable_if_t<sizeof...(P) == 1, int> = 0>
mediapipe::Status operator()(CalculatorContext* cc, R&& result) {
std::get<0>(outputs_)(cc).Send(std::forward<R>(result));
return {};
}
template <class... R>
mediapipe::Status operator()(CalculatorContext* cc,
mediapipe::StatusOr<std::tuple<R...>>&& result) {
if (result.ok()) {
return this(cc, result.ValueOrDie());
} else {
return result.status();
}
}
template <class... R>
mediapipe::Status operator()(CalculatorContext* cc,
std::tuple<R...>&& result) {
static_assert(sizeof...(P) == sizeof...(R), "");
internal::tuple_for_each(
[cc, &result](const auto& port, auto i_const) {
constexpr std::size_t i = decltype(i_const)::value;
port(cc).Send(std::get<i>(result));
},
outputs_);
return {};
}
std::tuple<P...> outputs_;
};
template <class... P>
auto MakeOutputSender(P&&... args) {
return OutputSender<P...>(std::forward<P>(args)...);
}
template <class... P>
auto MakeOutputSender(std::tuple<P...>&& args) {
return OutputSender<P...>(std::forward<std::tuple<P...>>(args));
}
// Contract item that specifies that certain I/O ports are handled by invoking
// a specific function.
template <class F, class... P>
class FunCaller {
public:
constexpr FunCaller(F&& f, P&&... args) : f_(f), args_(args...) {}
auto operator()(CalculatorContext* cc) const {
auto output_sender = MakeOutputSender(outputs());
// tuple_apply gives better error messages than std::apply if the argument
// types don't match.
return output_sender(
cc, internal::tuple_apply(f_, internal::map_tuple(
[cc](const auto& port) {
return MakePortArg(cc, port);
},
inputs())));
}
auto inputs() const { return internal::filter_tuple<IsInputPort>(args_); }
auto outputs() const { return internal::filter_tuple<IsOutputPort>(args_); }
mediapipe::Status AddToContract(CalculatorContract* cc) const { return {}; }
mediapipe::Status Process(CalculatorContext* cc) const { return (*this)(cc); }
constexpr std::tuple<P...> nested_items() const { return args_; }
F f_;
std::tuple<P...> args_;
};
// Helper function to invoke function callers in Process.
// TODO: implement multiple callers for syncsets.
template <class... T>
mediapipe::Status ProcessFnCallers(CalculatorContext* cc,
std::tuple<T...> callers);
inline mediapipe::Status ProcessFnCallers(CalculatorContext* cc, std::tuple<>) {
return mediapipe::InternalError("Process unimplemented");
}
template <class T>
mediapipe::Status ProcessFnCallers(CalculatorContext* cc,
std::tuple<T> callers) {
return std::get<0>(callers).Process(cc);
}
} // namespace internal
// Function used to add a process function to a calculator contract.
template <class F, class... P>
constexpr auto ProcessFn(F&& f, P&&... args) {
return internal::FunCaller<F, P...>(std::forward<F>(f),
std::forward<P>(args)...);
}
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_CONTRACT_H_

View File

@ -0,0 +1,73 @@
#include "mediapipe/framework/api2/contract.h"
#include <tuple>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe {
namespace api2 {
namespace {
struct ProcessItem {
mediapipe::Status Process(CalculatorContext* cc) { return {}; }
};
struct ItemWithNested {
constexpr auto nested_items() { return std::make_tuple(Input<char>{"FWD"}); }
};
static constexpr auto kTestContract = internal::MakeContract(
Input<int>{"BASE"}, Input<float>::Optional{"SCALE"}, Output<float>{"OUT"},
SideInput<float>::Optional{"BIAS"}, SideOutput<char>{"SIDE"},
ProcessItem{});
static_assert(std::tuple_size_v<decltype(kTestContract.inputs())> == 2, "");
static_assert(std::tuple_size_v<decltype(kTestContract.outputs())> == 1, "");
static_assert(std::tuple_size_v<decltype(kTestContract.side_inputs())> == 1,
"");
static_assert(std::tuple_size_v<decltype(kTestContract.side_outputs())> == 1,
"");
static_assert(internal::HasProcessMethod<ProcessItem>{}, "");
static_assert(!internal::HasProcessMethod<Input<int>>{}, "");
static_assert(std::tuple_size_v<decltype(kTestContract.process_items())> == 1,
"");
static constexpr auto kExtractNested1 = internal::ExtractNestedItems(
std::make_tuple(Input<int>{"BASE"}, Input<float>::Optional{"SCALE"},
Output<float>{"OUT"}));
static_assert(std::tuple_size_v<decltype(kExtractNested1)> == 3, "");
static constexpr auto kExtractNested2 = internal::ExtractNestedItems(
std::make_tuple(Input<int>{"BASE"}, Input<float>::Optional{"SCALE"},
Output<float>{"OUT"}, ItemWithNested{}));
static_assert(std::tuple_size_v<decltype(kExtractNested2)> == 5, "");
using TaggedTestContract =
internal::TaggedContract<decltype(kTestContract), kTestContract>;
static constexpr auto kBASE = MPP_TAG("BASE");
static constexpr auto kSCALE = MPP_TAG("SCALE");
static constexpr auto kBIAS = MPP_TAG("BIAS");
static constexpr auto kOUT = MPP_TAG("OUT");
static constexpr auto kSIDE = MPP_TAG("SIDE");
static_assert(TaggedTestContract::TaggedInputs::get(kBASE).tag_ == kBASE.kStr,
"");
static_assert(TaggedTestContract::TaggedInputs::get(kSCALE).tag_ == kSCALE.kStr,
"");
static_assert(TaggedTestContract::TaggedOutputs::get(kOUT).tag_ == kOUT.kStr,
"");
static_assert(TaggedTestContract::TaggedSideInputs::get(kBIAS).tag_ ==
kBIAS.kStr,
"");
static_assert(TaggedTestContract::TaggedSideOutputs::get(kSIDE).tag_ ==
kSIDE.kStr,
"");
} // namespace
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,9 @@
#include "mediapipe/framework/api2/node.h"
namespace mediapipe {
namespace api2 {
Node::~Node() {}
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,248 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_NODE_H_
#define MEDIAPIPE_FRAMEWORK_API2_NODE_H_
#include <functional>
#include <string>
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/contract.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/deps/no_destructor.h"
#include "mediapipe/framework/subgraph.h"
namespace mediapipe {
namespace api2 {
class NodeIntf {};
class Node : public CalculatorBase {
public:
virtual ~Node();
};
} // namespace api2
namespace internal {
template <class T>
class CalculatorBaseFactoryFor<
T,
typename std::enable_if<std::is_base_of<mediapipe::api2::Node, T>{}>::type>
: public CalculatorBaseFactory {
public:
mediapipe::Status GetContract(CalculatorContract* cc) final {
auto status = T::Contract::GetContract(cc);
if (status.ok()) {
status = UpdateContract<T>(cc);
}
return status;
}
std::unique_ptr<CalculatorBase> CreateCalculator(
CalculatorContext* calculator_context) final {
return absl::make_unique<T>();
}
private:
template <typename U>
auto UpdateContract(CalculatorContract* cc)
-> decltype(U::UpdateContract(cc)) {
return U::UpdateContract(cc);
}
template <typename U>
mediapipe::Status UpdateContract(...) {
return {};
}
};
} // namespace internal
namespace api2 {
namespace internal {
// Defining a member of this type causes P to be ODR-used, which forces its
// instantiation if it's a static member of a template.
// Previously we depended on the pointer's value to determine whether the size
// of a character array is 0 or 1, forcing it to be instantiated so the
// compiler can determine the object's layout. But using it as a template
// argument is more compact.
template <auto* P>
struct ForceStaticInstantiation {
#ifdef _MSC_VER
// Just having it as the template argument does not count as a use for
// MSVC.
static constexpr bool Use() { return P != nullptr; }
char force_static[Use()];
#endif // _MSC_VER
};
// Helper template for forcing the definition of a static registration token.
template <typename T>
struct NodeRegistrationStatic {
static NoDestructor<mediapipe::RegistrationToken> registration;
static mediapipe::RegistrationToken Make() {
return mediapipe::CalculatorBaseRegistry::Register(
T::kCalculatorName,
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>);
}
using RequireStatics = ForceStaticInstantiation<&registration>;
};
// Static members of template classes can be defined in the header.
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
NodeRegistrationStatic<T>::registration(NodeRegistrationStatic<T>::Make());
template <typename T>
struct SubgraphRegistrationImpl {
static NoDestructor<mediapipe::RegistrationToken> registration;
static mediapipe::RegistrationToken Make() {
return mediapipe::SubgraphRegistry::Register(T::kCalculatorName,
absl::make_unique<T>);
}
using RequireStatics = ForceStaticInstantiation<&registration>;
};
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
SubgraphRegistrationImpl<T>::registration(
SubgraphRegistrationImpl<T>::Make());
} // namespace internal
// By passing the Impl parameter, registration is done automatically. No need
// to use MEDIAPIPE_NODE_IMPLEMENTATION.
// For backward compatibility, Impl can be omitted; use
// MEDIAPIPE_NODE_IMPLEMENTATION with this.
// TODO: migrate and remove.
template <class Impl = void>
class RegisteredNode;
template <class Impl>
class RegisteredNode : public Node {
private:
// The member below triggers instantiation of the registration static.
// Note that the constructor of calculator subclasses is only invoked through
// the registration token, and so we cannot simply use the static in the
// constructor.
typename internal::NodeRegistrationStatic<Impl>::RequireStatics register_;
};
// No-op version for backwards compatibility.
template <>
class RegisteredNode<void> : public Node {};
template <class Impl>
struct FunctionNode : public RegisteredNode<Impl> {
mediapipe::Status Process(CalculatorContext* cc) override {
return internal::ProcessFnCallers(cc, Impl::kContract.process_items());
}
};
template <class Intf, class Impl = void>
class NodeImpl : public RegisteredNode<Impl>, public Intf {
protected:
// These methods allow accessing a node's ports by tag. This can be useful in
// a few cases, e.g. if the port is not available as a named constant.
// They parallel the corresponding methods on builder nodes.
template <class Tag>
static constexpr auto Out(Tag t) {
return Intf::Contract::TaggedOutputs::get(t);
}
template <class Tag>
static constexpr auto In(Tag t) {
return Intf::Contract::TaggedInputs::get(t);
}
template <class Tag>
static constexpr auto SideOut(Tag t) {
return Intf::Contract::TaggedSideOutputs::get(t);
}
template <class Tag>
static constexpr auto SideIn(Tag t) {
return Intf::Contract::TaggedSideInputs::get(t);
}
// Convenience.
template <class Tag, class CC>
static auto Out(Tag t, CC cc) {
return Out(t)(cc);
}
template <class Tag, class CC>
static auto In(Tag t, CC cc) {
return In(t)(cc);
}
template <class Tag, class CC>
static auto SideOut(Tag t, CC cc) {
return SideOut(t)(cc);
}
template <class Tag, class CC>
static auto SideIn(Tag t, CC cc) {
return SideIn(t)(cc);
}
};
// This macro is used to define the contract, without also giving the
// node a type name. It can be used directly in pure interfaces.
#define MEDIAPIPE_NODE_CONTRACT(...) \
static constexpr auto kContract = \
mediapipe::api2::internal::MakeContract(__VA_ARGS__); \
using Contract = \
typename mediapipe::api2::internal::TaggedContract<decltype(kContract), \
kContract>;
// This macro is used to define the contract and the type name of a node.
// This saves the name of the calculator, making it available to the
// implementation too, and to the registration macro for it. The reason is
// that the name must be available with the contract (so that it can be used
// to build a graph config, for instance); however, it is the implementation
// that needs to be registered.
// TODO: rename to MEDIAPIPE_NODE_DECLARATION?
// TODO: more detailed explanation.
#define MEDIAPIPE_NODE_INTERFACE(name, ...) \
static constexpr char kCalculatorName[] = #name; \
MEDIAPIPE_NODE_CONTRACT(__VA_ARGS__)
// TODO: verify that the subgraph config fully implements the
// declared interface.
template <class Intf, class Impl>
class SubgraphImpl : public Subgraph, public Intf {
private:
typename internal::SubgraphRegistrationImpl<Impl>::RequireStatics register_;
};
// This macro is used to register a calculator that does not use automatic
// registration. Deprecated.
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(calculator_registration, __LINE__)( \
mediapipe::CalculatorBaseRegistry::Register( \
Impl::kCalculatorName, \
absl::make_unique< \
mediapipe::internal::CalculatorBaseFactoryFor<Impl>>))
// This macro is used to register a non-split-contract calculator. Deprecated.
#define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)
// This macro is used to define a subgraph that does not use automatic
// registration. Deprecated.
#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(subgraph_registration, \
__LINE__)(mediapipe::SubgraphRegistry::Register( \
Impl::kCalculatorName, absl::make_unique<Impl>))
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_NODE_H_

View File

@ -0,0 +1,527 @@
#include "mediapipe/framework/api2/node.h"
#include <tuple>
#include <utility>
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/api2/test_contracts.h"
#include "mediapipe/framework/api2/tuple.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace api2 {
namespace test {
using testing::ElementsAre;
// Returns the packet values for a vector of Packets.
template <typename T>
std::vector<T> PacketValues(const std::vector<mediapipe::Packet>& packets) {
std::vector<T> result;
for (const auto& packet : packets) {
result.push_back(packet.Get<T>());
}
return result;
}
class FooImpl : public NodeImpl<Foo, FooImpl> {
public:
mediapipe::Status Process(CalculatorContext* cc) override {
float bias = kBias(cc).GetOr(0.0);
float scale = kScale(cc).GetOr(1.0);
kOut(cc).Send(*kBase(cc) * scale + bias);
return {};
}
};
class Foo3 : public FunctionNode<Foo3> {
public:
static constexpr Input<int> kBase{"BASE"};
static constexpr Input<float>::Optional kScale{"SCALE"};
static constexpr Output<float> kOut{"OUT"};
static constexpr SideInput<float>::Optional kBias{"BIAS"};
static float foo(int base, Packet<float> bias, Packet<float> scale) {
return base * scale.GetOr(1.0) + bias.GetOr(0.0);
}
// TODO: add support for methods.
MEDIAPIPE_NODE_INTERFACE(Foo3, ProcessFn(&foo, kBase, kBias, kScale, kOut));
};
class Foo4 : public FunctionNode<Foo4> {
public:
static float foo(int base, Packet<float> bias, Packet<float> scale) {
return base * scale.GetOr(1.0) + bias.GetOr(0.0);
}
MEDIAPIPE_NODE_INTERFACE(Foo4, ProcessFn(&foo, Input<int>{"BASE"},
SideInput<float>::Optional{"BIAS"},
Input<float>::Optional{"SCALE"},
Output<float>{"OUT"}));
};
class Foo5 : public FunctionNode<Foo5> {
public:
MEDIAPIPE_NODE_INTERFACE(
Foo5, ProcessFn(
[](int base, Packet<float> bias, Packet<float> scale) {
return base * scale.GetOr(1.0) + bias.GetOr(0.0);
},
Input<int>{"BASE"}, SideInput<float>::Optional{"BIAS"},
Input<float>::Optional{"SCALE"}, Output<float>{"OUT"}));
};
class Foo2Impl : public NodeImpl<Foo2, Foo2Impl> {
public:
mediapipe::Status Process(CalculatorContext* cc) override {
float bias = SideIn(MPP_TAG("BIAS"), cc).GetOr(0.0);
float scale = In(MPP_TAG("SCALE"), cc).GetOr(1.0);
Out(MPP_TAG("OUT"), cc).Send(*In(MPP_TAG("BASE"), cc) * scale + bias);
return {};
}
};
class BarImpl : public NodeImpl<Bar, BarImpl> {
public:
mediapipe::Status Process(CalculatorContext* cc) override {
Packet p = kIn(cc);
kOut(cc).Send(p);
return {};
}
};
class BazImpl : public NodeImpl<Baz> {
public:
static mediapipe::Status UpdateContract(CalculatorContract* cc) { return {}; }
mediapipe::Status Process(CalculatorContext* cc) override {
for (int i = 0; i < kData(cc).Count(); ++i) {
kDataOut(cc)[i].Send(kData(cc)[i]);
}
return {};
}
};
MEDIAPIPE_NODE_IMPLEMENTATION(BazImpl);
class IntForwarderImpl : public NodeImpl<IntForwarder, IntForwarderImpl> {
public:
mediapipe::Status Process(CalculatorContext* cc) override {
kOut(cc).Send(*kIn(cc));
return {};
}
};
class ToFloatImpl : public NodeImpl<ToFloat, ToFloatImpl> {
public:
mediapipe::Status Process(CalculatorContext* cc) override {
kIn(cc).Visit([cc](auto x) { kOut(cc).Send(x); });
return {};
}
};
TEST(NodeTest, GetContract) {
// In the old API, contracts are defined "backwards"; first you fill it in
// with what you have in the graph, then you let the calculator fill it in
// with what it expects, and then you see if they match.
const CalculatorGraphConfig::Node node_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "Foo"
input_stream: "BASE:base"
input_stream: "SCALE:scale"
output_stream: "OUT:out"
)");
mediapipe::CalculatorContract contract;
MP_EXPECT_OK(contract.Initialize(node_config));
MP_EXPECT_OK(Foo::Contract::GetContract(&contract));
MP_EXPECT_OK(ValidatePacketTypeSet(contract.Inputs()));
MP_EXPECT_OK(ValidatePacketTypeSet(contract.Outputs()));
}
TEST(NodeTest, GetContractMulti) {
const CalculatorGraphConfig::Node node_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "Baz"
input_stream: "DATA:0:b"
input_stream: "DATA:1:c"
output_stream: "DATA:0:d"
output_stream: "DATA:1:e"
)");
mediapipe::CalculatorContract contract;
MP_EXPECT_OK(contract.Initialize(node_config));
MP_EXPECT_OK(Baz::Contract::GetContract(&contract));
MP_EXPECT_OK(ValidatePacketTypeSet(contract.Inputs()));
MP_EXPECT_OK(ValidatePacketTypeSet(contract.Outputs()));
}
TEST(NodeTest, CreateByName) {
MP_EXPECT_OK(CalculatorBaseRegistry::CreateByName("Foo"));
}
void RunFooCalculatorInGraph(const std::string& foo_name) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "base"
input_stream: "scale"
output_stream: "out"
node {
calculator: "$0"
input_stream: "BASE:base"
input_stream: "SCALE:scale"
output_stream: "OUT:out"
}
)",
foo_name));
std::vector<mediapipe::Packet> out_packets;
tool::AddVectorSink("out", &config, &out_packets);
mediapipe::CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config, {}));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"base", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"scale", mediapipe::MakePacket<float>(2.0).At(Timestamp(1))));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_THAT(PacketValues<float>(out_packets), testing::ElementsAre(20.0));
}
TEST(NodeTest, RunInGraph) { RunFooCalculatorInGraph("Foo"); }
TEST(NodeTest, RunInGraph3) { RunFooCalculatorInGraph("Foo3"); }
TEST(NodeTest, RunInGraph4) { RunFooCalculatorInGraph("Foo4"); }
TEST(NodeTest, RunInGraph5) { RunFooCalculatorInGraph("Foo5"); }
TEST(NodeTest, OptionalStream) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "base"
input_side_packet: "bias"
output_stream: "out"
node {
calculator: "Foo"
input_stream: "BASE:base"
input_side_packet: "BIAS:bias"
output_stream: "OUT:out"
}
)");
std::vector<mediapipe::Packet> out_packets;
tool::AddVectorSink("out", &config, &out_packets);
mediapipe::CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config, {}));
MP_EXPECT_OK(graph.StartRun({{"bias", mediapipe::MakePacket<float>(30.0)}}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"base", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_THAT(PacketValues<float>(out_packets), testing::ElementsAre(40.0));
}
TEST(NodeTest, DynamicTypes) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in"
output_stream: "out"
node {
calculator: "Bar"
input_stream: "IN:in"
output_stream: "OUT:bar"
}
node {
calculator: "IntForwarder"
input_stream: "IN:bar"
output_stream: "OUT:out"
}
)");
std::vector<mediapipe::Packet> out_packets;
tool::AddVectorSink("out", &config, &out_packets);
mediapipe::CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config, {}));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_THAT(PacketValues<int>(out_packets), testing::ElementsAre(10));
}
TEST(NodeTest, MultiPort) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in0"
input_stream: "in1"
output_stream: "out0"
output_stream: "out1"
node {
calculator: "Baz"
input_stream: "DATA:0:in0"
input_stream: "DATA:1:in1"
output_stream: "DATA:0:baz0"
output_stream: "DATA:1:baz1"
}
node {
calculator: "IntForwarder"
input_stream: "IN:baz0"
output_stream: "OUT:out0"
}
node {
calculator: "IntForwarder"
input_stream: "IN:baz1"
output_stream: "OUT:out1"
}
)");
std::vector<mediapipe::Packet> out0_packets;
std::vector<mediapipe::Packet> out1_packets;
tool::AddVectorSink("out0", &config, &out0_packets);
tool::AddVectorSink("out1", &config, &out1_packets);
mediapipe::CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config, {}));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in0", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in1", mediapipe::MakePacket<int>(5).At(Timestamp(1))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in0", mediapipe::MakePacket<int>(15).At(Timestamp(2))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in1", mediapipe::MakePacket<int>(7).At(Timestamp(2))));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
std::vector<int> out0_values;
std::vector<int> out1_values;
for (auto& packet : out0_packets) {
out0_values.push_back(packet.Get<int>());
}
for (auto& packet : out1_packets) {
out1_values.push_back(packet.Get<int>());
}
EXPECT_EQ(out0_values, (std::vector<int>{10, 15}));
EXPECT_EQ(out1_values, (std::vector<int>{5, 7}));
}
struct SideFallback : public Node {
static constexpr Input<int> kIn{"IN"};
static constexpr Input<int>::SideFallback kFactor{"FACTOR"};
static constexpr Output<int> kOut{"OUT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kFactor, kOut);
mediapipe::Status Process(CalculatorContext* cc) override {
kOut(cc).Send(kIn(cc).Get() * kFactor(cc).Get());
return {};
}
};
MEDIAPIPE_REGISTER_NODE(SideFallback);
TEST(NodeTest, SideFallbackWithStream) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in"
input_stream: "factor"
output_stream: "out"
node {
calculator: "SideFallback"
input_stream: "IN:in"
input_stream: "FACTOR:factor"
output_stream: "OUT:out"
}
)");
std::vector<int> outputs;
mediapipe::CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config, {}));
MP_EXPECT_OK(
graph.ObserveOutputStream("out", [&outputs](const mediapipe::Packet& p) {
outputs.push_back(p.Get<int>());
return mediapipe::OkStatus();
}));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", mediapipe::MakePacket<int>(10).At(Timestamp(0))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"factor", mediapipe::MakePacket<int>(2).At(Timestamp(0))));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_EQ(outputs, std::vector<int>{20});
}
TEST(NodeTest, SideFallbackWithSide) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in"
input_side_packet: "factor"
output_stream: "out"
node {
calculator: "SideFallback"
input_stream: "IN:in"
input_side_packet: "FACTOR:factor"
output_stream: "OUT:out"
}
)");
std::vector<int> outputs;
mediapipe::CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config, {}));
MP_EXPECT_OK(
graph.ObserveOutputStream("out", [&outputs](const mediapipe::Packet& p) {
outputs.push_back(p.Get<int>());
return mediapipe::OkStatus();
}));
MP_EXPECT_OK(graph.StartRun({{"factor", mediapipe::MakePacket<int>(2)}}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", mediapipe::MakePacket<int>(10).At(Timestamp(0))));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_EQ(outputs, std::vector<int>{20});
}
TEST(NodeTest, SideFallbackWithNone) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in"
output_stream: "out"
node {
calculator: "SideFallback"
input_stream: "IN:in"
output_stream: "OUT:out"
}
)");
std::vector<int> outputs;
mediapipe::CalculatorGraph graph;
auto status = graph.Initialize(config, {});
EXPECT_THAT(status.message(), testing::HasSubstr("must be connected"));
}
TEST(NodeTest, SideFallbackWithBoth) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in"
input_stream: "factor"
input_side_packet: "factor_side"
output_stream: "out"
node {
calculator: "SideFallback"
input_stream: "IN:in"
input_stream: "FACTOR:factor"
input_side_packet: "FACTOR:factor_side"
output_stream: "OUT:out"
}
)");
std::vector<int> outputs;
mediapipe::CalculatorGraph graph;
auto status = graph.Initialize(config, {});
EXPECT_THAT(status.message(), testing::HasSubstr("not both"));
}
TEST(NodeTest, OneOf) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in"
output_stream: "out"
node {
calculator: "ToFloat"
input_stream: "IN:in"
output_stream: "OUT:out"
}
)");
std::vector<mediapipe::Packet> out_packets;
tool::AddVectorSink("out", &config, &out_packets);
mediapipe::CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config, {}));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", mediapipe::MakePacket<int>(10).At(Timestamp(1))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", mediapipe::MakePacket<float>(5.0).At(Timestamp(2))));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_THAT(PacketValues<float>(out_packets), testing::ElementsAre(10, 5.0));
}
struct DropEvenTimestamps : public Node {
static constexpr Input<AnyType> kIn{"IN"};
static constexpr Output<SameType<kIn>> kOut{"OUT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
mediapipe::Status Process(CalculatorContext* cc) override {
if (cc->InputTimestamp().Value() % 2) {
kOut(cc).Send(kIn(cc));
}
return {};
}
};
MEDIAPIPE_REGISTER_NODE(DropEvenTimestamps);
struct ListIntPackets : public Node {
static constexpr Input<int>::Multiple kIn{"INT"};
static constexpr Output<std::string> kOut{"STR"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
mediapipe::Status Process(CalculatorContext* cc) override {
std::string result = absl::StrCat(cc->InputTimestamp().DebugString(), ":");
for (int i = 0; i < kIn(cc).Count(); ++i) {
if (kIn(cc)[i].IsEmpty()) {
absl::StrAppend(&result, " empty");
} else {
absl::StrAppend(&result, " ", *kIn(cc)[i]);
}
}
kOut(cc).Send(std::move(result));
return {};
}
};
MEDIAPIPE_REGISTER_NODE(ListIntPackets);
TEST(NodeTest, DefaultTimestampChange0) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "a"
input_stream: "b"
output_stream: "out"
node {
calculator: "DropEvenTimestamps"
input_stream: "IN:a"
output_stream: "OUT:a2"
}
node {
calculator: "IntForwarder"
input_stream: "IN:a2"
output_stream: "OUT:a3"
}
node {
calculator: "ListIntPackets"
input_stream: "INT:0:a3"
input_stream: "INT:1:b"
output_stream: "STR:out"
}
)");
std::vector<mediapipe::Packet> out_packets;
tool::AddVectorSink("out", &config, &out_packets);
mediapipe::CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(config, {}));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"a", mediapipe::MakePacket<int>(10).At(Timestamp(2))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"b", mediapipe::MakePacket<int>(10).At(Timestamp(2))));
MP_EXPECT_OK(graph.WaitUntilIdle());
// The packet sent to a should have been dropped, but the timestamp bound
// should be forwarded by IntForwarder, and ListIntPackets should have run.
EXPECT_THAT(PacketValues<std::string>(out_packets),
testing::ElementsAre("2: empty 10"));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
}
} // namespace test
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,15 @@
#include "mediapipe/framework/api2/packet.h"
namespace mediapipe {
namespace api2 {
PacketBase FromOldPacket(const mediapipe::Packet& op) {
return PacketBase(packet_internal::GetHolderShared(op)).At(op.Timestamp());
}
mediapipe::Packet ToOldPacket(const PacketBase& p) {
return mediapipe::packet_internal::Create(p.payload_, p.timestamp_);
}
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,353 @@
// This file defines a typed Packet type. It fully interoperates with the older
// mediapipe::Packet; creating an api::Packet<T> that refers to an existing
// Packet (or vice versa) is cheap, just like copying a Packet. Ownership of
// the payload is shared. Consider this as a typed view into the same data.
//
// Conversion is currently done explicitly with the FromOldPacket and
// ToOldPacket functions, but calculator code does not need to concern itself
// with it.
#ifndef MEDIAPIPE_FRAMEWORK_API2_PACKET_H_
#define MEDIAPIPE_FRAMEWORK_API2_PACKET_H_
#include <functional>
#include <type_traits>
#include "mediapipe/framework/api2/tuple.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/logging.h"
namespace mediapipe {
namespace api2 {
using Timestamp = mediapipe::Timestamp;
using HolderBase = mediapipe::packet_internal::HolderBase;
template <typename T>
class Packet;
// Type-erased packet.
class PacketBase {
public:
// Empty.
PacketBase() = default;
// Copy.
PacketBase(const PacketBase&) = default;
PacketBase& operator=(const PacketBase&) = default;
// Move.
PacketBase(PacketBase&&) = default;
PacketBase& operator=(PacketBase&&) = default;
// Get timestamp.
Timestamp timestamp() const { return timestamp_; }
// The original API has a Timestamp method, but it shadows the Timestamp
// type within this class, which is annoying.
// Timestamp Timestamp() const { return timestamp_; }
PacketBase At(Timestamp timestamp) const&;
PacketBase At(Timestamp timestamp) &&;
bool IsEmpty() const { return payload_ == nullptr; }
template <typename T>
Packet<T> As() const;
// Returns the reference to the object of type T if it contains
// one, crashes otherwise.
template <typename T>
const T& Get() const;
// Conversion to old Packet type.
operator mediapipe::Packet() const { return ToOldPacket(*this); }
protected:
explicit PacketBase(std::shared_ptr<HolderBase> payload)
: payload_(std::move(payload)) {}
std::shared_ptr<HolderBase> payload_;
Timestamp timestamp_;
template <typename T>
friend PacketBase PacketBaseAdopting(const T* ptr);
friend PacketBase FromOldPacket(const mediapipe::Packet& op);
friend mediapipe::Packet ToOldPacket(const PacketBase& p);
};
PacketBase FromOldPacket(const mediapipe::Packet& op);
mediapipe::Packet ToOldPacket(const PacketBase& p);
template <typename T>
inline const T& PacketBase::Get() const {
CHECK(payload_);
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
CHECK(typed_payload) << absl::StrCat(
"The Packet stores \"", payload_->DebugTypeName(), "\", but \"",
MediaPipeTypeStringOrDemangled<T>(), "\" was requested.");
return typed_payload->data();
}
// This is used to indicate that the packet could be holding one of a set of
// types, e.g. Packet<OneOf<A, B>>.
//
// A Packet<OneOf<T...>> has an interface similar to std::variant<T...>.
// However, we cannot use std::variant directly, since it requires that the
// contained object be stored in place within the variant.
// Suppose we have a stream that accepts an Image or an ImageFrame, and it
// receives a Packet<ImageFrame>. To present it as a
// std::variant<Image, ImageFrame> we would have to move the ImageFrame into
// the variant (or copy it), but that is not compatible with Packet's existing
// ownership model.
// We could have Get() return a std::variant<std::reference_wrapper<Image>,
// std::reference_wrapper<ImageFrame>>, but that would just make user code more
// convoluted.
//
// TODO: should we just use Packet<T...>?
template <class... T>
struct OneOf {};
namespace internal {
template <class T>
inline void CheckCompatibleType(const HolderBase& holder, internal::Wrap<T>) {
const packet_internal::Holder<T>* typed_payload = holder.As<T>();
CHECK(typed_payload) << absl::StrCat(
"The Packet stores \"", holder.DebugTypeName(), "\", but \"",
MediaPipeTypeStringOrDemangled<T>(), "\" was requested.");
// CHECK(payload_->has_type<T>());
}
template <class... T>
inline void CheckCompatibleType(const HolderBase& holder,
internal::Wrap<OneOf<T...>>) {
bool compatible = (holder.As<T>() || ...);
CHECK(compatible)
<< "The Packet stores \"" << holder.DebugTypeName() << "\", but one of "
<< absl::StrJoin(
{absl::StrCat("\"", MediaPipeTypeStringOrDemangled<T>(), "\"")...},
", ")
<< " was requested.";
}
struct Generic {
Generic() = delete;
};
}; // namespace internal
template <typename T>
inline Packet<T> PacketBase::As() const {
if (!payload_) return Packet<T>().At(timestamp_);
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
internal::CheckCompatibleType(*payload_, internal::Wrap<T>{});
return Packet<T>(payload_).At(timestamp_);
}
template <>
inline Packet<internal::Generic> PacketBase::As<internal::Generic>() const;
template <typename T = internal::Generic>
class Packet;
#if __cplusplus >= 201703L
// Deduction guide to silence -Wctad-maybe-unsupported.
explicit Packet()->Packet<internal::Generic>;
#endif // C++17
template <>
class Packet<internal::Generic> : public PacketBase {
public:
Packet() = default;
Packet<internal::Generic> At(Timestamp timestamp) const&;
Packet<internal::Generic> At(Timestamp timestamp) &&;
protected:
explicit Packet(std::shared_ptr<HolderBase> payload)
: PacketBase(std::move(payload)) {}
friend PacketBase;
};
// Having Packet<T> subclass Packet<Generic> will require hiding some methods
// like As. May be better not to subclass, and allow implicit conversion
// instead.
template <typename T>
class Packet : public Packet<internal::Generic> {
public:
Packet() = default;
Packet<T> At(Timestamp timestamp) const&;
Packet<T> At(Timestamp timestamp) &&;
const T& Get() const {
CHECK(payload_);
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
CHECK(typed_payload);
return typed_payload->data();
}
const T& operator*() const { return Get(); }
template <typename U>
T GetOr(U&& v) const {
return IsEmpty() ? static_cast<T>(absl::forward<U>(v)) : **this;
}
private:
explicit Packet(std::shared_ptr<HolderBase> payload)
: Packet<internal::Generic>(std::move(payload)) {}
friend PacketBase;
template <typename U, typename... Args>
friend Packet<U> MakePacket(Args&&... args);
template <typename U>
friend Packet<U> PacketAdopting(const U* ptr);
template <typename U>
friend Packet<U> PacketAdopting(std::unique_ptr<U> ptr);
};
namespace internal {
template <class... F>
struct Overload : F... {
using F::operator()...;
};
template <class... F>
explicit Overload(F...) -> Overload<F...>;
template <class T, class... U>
struct First {
using type = T;
};
} // namespace internal
template <class... T>
class Packet<OneOf<T...>> : public PacketBase {
public:
Packet() = default;
template <class U>
using AllowedType = std::enable_if_t<(std::is_same_v<U, T> || ...)>;
template <class U, class = AllowedType<U>>
Packet(const Packet<U>& p) : PacketBase(p) {}
template <class U, class = AllowedType<U>>
Packet<OneOf<T...>>& operator=(const Packet<U>& p) {
PacketBase::operator=(p);
return *this;
}
template <class U, class = AllowedType<U>>
Packet(Packet<U>&& p) : PacketBase(std::move(p)) {}
template <class U, class = AllowedType<U>>
Packet<OneOf<T...>>& operator=(Packet<U>&& p) {
PacketBase::operator=(std::move(p));
return *this;
}
Packet<OneOf<T...>> At(Timestamp timestamp) const& {
return Packet<OneOf<T...>>(*this).At(timestamp);
}
Packet<OneOf<T...>> At(Timestamp timestamp) && {
timestamp_ = timestamp;
return std::move(*this);
}
template <class U, class = AllowedType<U>>
const U& Get() const {
CHECK(payload_);
packet_internal::Holder<U>* typed_payload = payload_->As<U>();
CHECK(typed_payload);
return typed_payload->data();
}
template <class U, class = AllowedType<U>>
bool Has() const {
return payload_ && payload_->As<U>();
}
template <class... F>
auto Visit(const F&... args) const {
CHECK(payload_);
auto f = internal::Overload{args...};
using FirstT = typename internal::First<T...>::type;
using ResultType = absl::result_of_t<decltype(f)(const FirstT&)>;
static_assert(
(std::is_same_v<ResultType, absl::result_of_t<decltype(f)(const T&)>> &&
...),
"All visitor overloads must have the same return type");
return Invoke<decltype(f), T...>(f);
}
protected:
explicit Packet(std::shared_ptr<HolderBase> payload)
: PacketBase(std::move(payload)) {}
friend PacketBase;
private:
template <class F, class U>
auto Invoke(const F& f) const {
return f(Get<U>());
}
template <class F, class U, class V, class... W>
auto Invoke(const F& f) const {
return Has<U>() ? f(Get<U>()) : Invoke<F, V, W...>(f);
}
};
template <>
inline Packet<internal::Generic> PacketBase::As<internal::Generic>() const {
if (!payload_) return Packet<internal::Generic>().At(timestamp_);
return Packet<internal::Generic>(payload_).At(timestamp_);
}
inline PacketBase PacketBase::At(Timestamp timestamp) const& {
return PacketBase(*this).At(timestamp);
}
inline PacketBase PacketBase::At(Timestamp timestamp) && {
timestamp_ = timestamp;
return std::move(*this);
}
template <typename T>
inline Packet<T> Packet<T>::At(Timestamp timestamp) const& {
return Packet<T>(*this).At(timestamp);
}
template <typename T>
inline Packet<T> Packet<T>::At(Timestamp timestamp) && {
timestamp_ = timestamp;
return std::move(*this);
}
inline Packet<internal::Generic> Packet<internal::Generic>::At(
Timestamp timestamp) const& {
return Packet<internal::Generic>(*this).At(timestamp);
}
inline Packet<internal::Generic> Packet<internal::Generic>::At(
Timestamp timestamp) && {
timestamp_ = timestamp;
return std::move(*this);
}
template <typename T, typename... Args>
Packet<T> MakePacket(Args&&... args) {
return Packet<T>(std::make_shared<packet_internal::Holder<T>>(
new T(std::forward<Args>(args)...)));
}
template <typename T>
Packet<T> PacketAdopting(const T* ptr) {
return Packet<T>(std::make_shared<packet_internal::Holder<T>>(ptr));
}
template <typename T>
Packet<T> PacketAdopting(std::unique_ptr<T> ptr) {
return Packet<T>(std::make_shared<packet_internal::Holder<T>>(ptr.release()));
}
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_PACKET_H_

View File

@ -0,0 +1,16 @@
#include "mediapipe/framework/api2/packet.h"
namespace api2 {
namespace {
#if defined(TEST_NO_ASSIGN_WRONG_PACKET_TYPE)
void AssignWrongPacketType() { Packet<int> p = MakePacket<float>(1.0); }
#elif defined(TEST_NO_ASSIGN_GENERIC_TO_SPECIFIC)
void AssignWrongPacketType() {
Packet<> p = MakePacket<float>(1.0);
Packet<int> p2 = p;
}
#endif
} // namespace
}; // namespace api2

View File

@ -0,0 +1,195 @@
#include "mediapipe/framework/api2/packet.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe {
namespace api2 {
namespace {
class LiveCheck {
public:
explicit LiveCheck(bool* alive) : alive_(*alive) { alive_ = true; }
~LiveCheck() { alive_ = false; }
private:
bool& alive_;
};
TEST(PacketTest, PacketBaseDefault) {
PacketBase p;
EXPECT_TRUE(p.IsEmpty());
}
TEST(PacketTest, PacketBaseNonEmpty) {
PacketBase p = PacketAdopting(new int(5));
EXPECT_FALSE(p.IsEmpty());
}
TEST(PacketTest, PacketBaseRefCount) {
bool alive = false;
PacketBase p = PacketAdopting(new LiveCheck(&alive));
EXPECT_TRUE(alive);
PacketBase p2 = p;
p = {};
EXPECT_TRUE(alive);
p2 = {};
EXPECT_FALSE(alive);
}
TEST(PacketTest, PacketBaseSame) {
int* ip = new int(5);
PacketBase p = PacketAdopting(ip);
PacketBase p2 = p;
EXPECT_EQ(&p2.Get<int>(), ip);
}
TEST(PacketTest, PacketNonEmpty) {
Packet<int> p = MakePacket<int>(5);
EXPECT_FALSE(p.IsEmpty());
}
TEST(PacketTest, Get) {
Packet<int> p = MakePacket<int>(5);
EXPECT_EQ(*p, 5);
EXPECT_EQ(p.Get(), 5);
}
TEST(PacketTest, GetOr) {
Packet<int> p_0 = MakePacket<int>(0);
Packet<int> p_5 = MakePacket<int>(5);
Packet<int> p_empty;
EXPECT_EQ(p_0.GetOr(1), 0);
EXPECT_EQ(p_5.GetOr(1), 5);
EXPECT_EQ(p_empty.GetOr(1), 1);
}
// This show how GetOr can be used with a lambda that is only called if the "or"
// case is needed. Can be useful when generating the fallback value is
// expensive.
// We could also add an overload to GetOr for types which are not convertible to
// T, but are callable and return T.
// TODO: consider adding it to make things easier.
template <typename F>
struct Lazy {
F f;
using ValueT = decltype(f());
Lazy(F fun) : f(fun) {}
operator ValueT() const { return f(); }
};
template <typename F>
Lazy(F f) -> Lazy<F>;
TEST(PacketTest, GetOrLazy) {
int expensive_call_count = 0;
auto expensive_string_generation = [&expensive_call_count] {
++expensive_call_count;
return "an expensive fallback";
};
auto p_hello = MakePacket<std::string>("hello");
Packet<std::string> p_empty;
EXPECT_EQ(p_hello.GetOr(Lazy(expensive_string_generation)), "hello");
EXPECT_EQ(expensive_call_count, 0);
EXPECT_EQ(p_empty.GetOr(Lazy(expensive_string_generation)),
"an expensive fallback");
EXPECT_EQ(expensive_call_count, 1);
}
TEST(PacketTest, OneOf) {
Packet<OneOf<std::string, int>> p = MakePacket<std::string>("hi");
EXPECT_TRUE(p.Has<std::string>());
EXPECT_FALSE(p.Has<int>());
EXPECT_EQ(p.Get<std::string>(), "hi");
std::string out =
p.Visit([](std::string s) { return absl::StrCat("string: ", s); },
[](int i) { return absl::StrCat("int: ", i); });
EXPECT_EQ(out, "string: hi");
p = MakePacket<int>(2);
EXPECT_FALSE(p.Has<std::string>());
EXPECT_TRUE(p.Has<int>());
EXPECT_EQ(p.Get<int>(), 2);
out = p.Visit([](std::string s) { return absl::StrCat("string: ", s); },
[](int i) { return absl::StrCat("int: ", i); });
EXPECT_EQ(out, "int: 2");
}
TEST(PacketTest, PacketRefCount) {
bool alive = false;
auto p = MakePacket<LiveCheck>(&alive);
EXPECT_TRUE(alive);
auto p2 = p;
p = {};
EXPECT_TRUE(alive);
p2 = {};
EXPECT_FALSE(alive);
}
TEST(PacketTest, PacketTimestamp) {
auto p = MakePacket<int>(5);
EXPECT_EQ(p.timestamp(), Timestamp::Unset());
auto p2 = p.At(Timestamp(1));
EXPECT_EQ(p.timestamp(), Timestamp::Unset());
EXPECT_EQ(p2.timestamp(), Timestamp(1));
auto p3 = std::move(p2).At(Timestamp(3));
EXPECT_EQ(p3.timestamp(), Timestamp(3));
}
TEST(PacketTest, PacketFromGeneric) {
Packet<> pb = PacketAdopting(new int(5));
Packet<int> p = pb.As<int>();
EXPECT_EQ(p.Get(), 5);
}
TEST(PacketTest, PacketAdopting) {
Packet<float> p = PacketAdopting(new float(1.0));
EXPECT_FALSE(p.IsEmpty());
}
TEST(PacketTest, PacketGeneric) {
// With C++17, Packet<> could be written simply as Packet.
Packet<> p = PacketAdopting(new float(1.0));
EXPECT_FALSE(p.IsEmpty());
}
TEST(PacketTest, PacketGenericTimestamp) {
Packet<> p = MakePacket<int>(5);
EXPECT_EQ(p.timestamp(), mediapipe::Timestamp::Unset());
auto p2 = p.At(Timestamp(1));
EXPECT_EQ(p.timestamp(), mediapipe::Timestamp::Unset());
EXPECT_EQ(p2.timestamp(), Timestamp(1));
auto p3 = std::move(p2).At(Timestamp(3));
EXPECT_EQ(p3.timestamp(), Timestamp(3));
}
TEST(PacketTest, FromOldPacket) {
mediapipe::Packet op = mediapipe::MakePacket<int>(7);
Packet<int> p = FromOldPacket(op).As<int>();
EXPECT_EQ(p.Get(), 7);
}
TEST(PacketTest, ToOldPacket) {
auto p = MakePacket<int>(7);
mediapipe::Packet op = ToOldPacket(p);
EXPECT_EQ(op.Get<int>(), 7);
}
TEST(PacketTest, OldRefCounting) {
bool alive = false;
PacketBase p = PacketAdopting(new LiveCheck(&alive));
EXPECT_TRUE(alive);
mediapipe::Packet op = ToOldPacket(p);
p = {};
EXPECT_TRUE(alive);
PacketBase p2 = FromOldPacket(op);
op = {};
EXPECT_TRUE(alive);
p2 = {};
EXPECT_FALSE(alive);
}
} // namespace
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,628 @@
// This file defines an API to define a node's ports in a concise, type-safe
// way. Example usage in a node:
//
// static constexpr Input<int> kBase("IN");
// static constexpr Output<float> kOut("OUT");
// static constexpr SideInput<float>::Optional kDelta("DELTA");
// static constexpr SideOutput<float> kForward("FORWARD");
//
// Pass a CalculatorContext to a port to access the inputs or outputs in the
// context. For example:
//
// kBase(cc) yields an InputShardAccess<int>
// kOut(cc) yields an OutputShardAccess<float>
// kDelta(cc) yields an InputSidePacketAccess<float>
// kForward(cc) yields an OutputSidePacketAccess<float>
#ifndef MEDIAPIPE_FRAMEWORK_API2_PORT_H_
#define MEDIAPIPE_FRAMEWORK_API2_PORT_H_
#include <type_traits>
#include <utility>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/output_side_packet.h"
#include "mediapipe/framework/port/logging.h"
namespace mediapipe {
namespace api2 {
// typeid is not constexpr, but a pointer to this is.
template <typename T>
size_t get_type_hash() {
return typeid(T).hash_code();
}
using type_id_fptr = size_t (*)();
// This is a base class for various types of port. It is not meant to be used
// directly by node code.
class PortBase {
public:
constexpr PortBase(std::size_t tag_size, const char* tag,
type_id_fptr get_type_id, bool optional, bool multiple)
: tag_(tag_size, tag),
optional_(optional),
multiple_(multiple),
type_id_getter_(get_type_id) {}
bool IsOptional() const { return optional_; }
bool IsMultiple() const { return multiple_; }
const char* Tag() const { return tag_.data(); }
size_t type_id() const { return type_id_getter_(); }
const const_str tag_;
const bool optional_;
const bool multiple_;
protected:
type_id_fptr type_id_getter_;
};
// These four base classes are used to distinguish between ports of different
// kinds. They are not meant to be used directly by node code.
class InputBase : public PortBase {
using PortBase::PortBase;
};
class OutputBase : public PortBase {
using PortBase::PortBase;
};
class SideInputBase : public PortBase {
using PortBase::PortBase;
};
class SideOutputBase : public PortBase {
using PortBase::PortBase;
};
struct NoneType {
private:
NoneType() = delete;
};
struct DynamicType {};
struct AnyType : public DynamicType {};
template <auto& P>
class SameType : public DynamicType {
public:
static constexpr const decltype(P)& kPort = P;
};
class PacketTypeAccess;
class PacketTypeAccessFallback;
template <typename T>
class InputShardAccess;
template <typename T>
class OutputShardAccess;
template <typename T>
class InputSidePacketAccess;
template <typename T>
class OutputSidePacketAccess;
template <typename T>
class InputShardOrSideAccess;
namespace internal {
// Forward declaration for AddToContract friend.
template <typename...>
class Contract;
template <class CC>
auto GetCollection(CC* cc, const InputBase& port) -> decltype(cc->Inputs()) {
return cc->Inputs();
}
template <class CC>
auto GetCollection(CC* cc, const SideInputBase& port)
-> decltype(cc->InputSidePackets()) {
return cc->InputSidePackets();
}
template <class CC>
auto GetCollection(CC* cc, const OutputBase& port) -> decltype(cc->Outputs()) {
return cc->Outputs();
}
template <class CC>
auto GetCollection(CC* cc, const SideOutputBase& port)
-> decltype(cc->OutputSidePackets()) {
return cc->OutputSidePackets();
}
template <class Collection>
auto GetOrNull(Collection& collection, const std::string& tag, int index)
-> decltype(&collection.Get(std::declval<CollectionItemId>())) {
CollectionItemId id = collection.GetId(tag, index);
return id.IsValid() ? &collection.Get(id) : nullptr;
}
template <class T>
struct IsOneOf : std::false_type {};
template <class... T>
struct IsOneOf<OneOf<T...>> : std::true_type {};
template <typename T, typename std::enable_if<
!std::is_base_of<DynamicType, T>{} && !IsOneOf<T>{},
int>::type = 0>
inline void SetType(CalculatorContract* cc, PacketType& pt) {
pt.Set<T>();
}
template <typename T, typename std::enable_if<std::is_base_of<DynamicType, T>{},
int>::type = 0>
inline void SetType(CalculatorContract* cc, PacketType& pt) {
pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag()));
}
template <>
inline void SetType<AnyType>(CalculatorContract* cc, PacketType& pt) {
pt.SetAny();
}
template <>
inline void SetType<NoneType>(CalculatorContract* cc, PacketType& pt) {
// This is used for header-only streams. Should it be removed?
pt.SetNone();
}
template <typename T, typename std::enable_if<IsOneOf<T>{}, int>::type = 0>
inline void SetType(CalculatorContract* cc, PacketType& pt) {
pt.SetAny();
}
template <typename ValueT>
InputShardAccess<ValueT> SinglePortAccess(mediapipe::CalculatorContext* cc,
const InputStreamShard* stream) {
return InputShardAccess<ValueT>(*cc, stream);
}
template <typename ValueT>
OutputShardAccess<ValueT> SinglePortAccess(mediapipe::CalculatorContext* cc,
OutputStreamShard* stream) {
return OutputShardAccess<ValueT>(*cc, stream);
}
template <typename ValueT>
InputSidePacketAccess<ValueT> SinglePortAccess(
mediapipe::CalculatorContext* cc, const mediapipe::Packet* packet) {
return InputSidePacketAccess<ValueT>(packet);
}
template <typename ValueT>
OutputSidePacketAccess<ValueT> SinglePortAccess(
mediapipe::CalculatorContext* cc, OutputSidePacket* osp) {
return OutputSidePacketAccess<ValueT>(osp);
}
template <typename ValueT>
InputShardOrSideAccess<ValueT> SinglePortAccess(
mediapipe::CalculatorContext* cc, const InputStreamShard* stream,
const mediapipe::Packet* packet) {
return InputShardOrSideAccess<ValueT>(*cc, stream, packet);
}
template <typename ValueT>
PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc,
PacketType* pt);
template <typename ValueT>
PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc,
PacketType* pt, bool is_stream);
template <typename ValueT, typename PortT, class CC>
auto AccessPort(std::false_type, const PortT& port, CC* cc) {
auto& collection = GetCollection(cc, port);
return SinglePortAccess<ValueT>(
cc, internal::GetOrNull(collection, port.Tag(), 0));
}
template <typename ValueT, typename X, class CC>
class MultiplePortAccess {
public:
MultiplePortAccess(CC* cc, X* first, int count)
: cc_(cc), first_(first), count_(count) {}
// TODO: maybe this should be size(), like in a standard C++
// container?
int Count() { return count_; }
auto operator[](int pos) {
CHECK_GE(pos, 0);
CHECK_LT(pos, count_);
return SinglePortAccess<ValueT>(cc_, &first_[pos]);
}
// TODO: add begin/end.
private:
CC* cc_;
X* first_;
int count_;
};
template <typename ValueT, typename PortT, class CC>
auto AccessPort(std::true_type, const PortT& port, CC* cc) {
auto& collection = GetCollection(cc, port);
auto* first = internal::GetOrNull(collection, port.Tag(), 0);
using EntryT = typename std::remove_pointer<decltype(first)>::type;
return MultiplePortAccess<ValueT, EntryT, CC>(
cc, first, collection.NumEntries(port.Tag()));
}
template <class Base>
struct SideBase;
template <>
struct SideBase<InputBase> {
using type = SideInputBase;
};
} // namespace internal
// TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
template <typename T, typename std::enable_if<
!std::is_base_of<DynamicType, T>{}, int>::type = 0>
auto ActualValueT(T) -> T;
auto ActualValueT(DynamicType) -> internal::Generic;
template <typename Base, typename ValueT, bool IsOptional = false,
bool IsMultiple = false>
class SideFallbackT;
// This template is used to define a port. Nodes should use it through one
// of the aliases below (Input, Output, SideInput, SideOutput).
template <typename Base, typename ValueT, bool IsOptionalV = false,
bool IsMultipleV = false>
class PortCommon : public Base {
public:
using value_t = ValueT;
static constexpr bool kOptional = IsOptionalV;
static constexpr bool kMultiple = IsMultipleV;
using Optional = PortCommon<Base, ValueT, true, IsMultipleV>;
using Multiple = PortCommon<Base, ValueT, IsOptionalV, true>;
using SideFallback = SideFallbackT<Base, ValueT, IsOptionalV, IsMultipleV>;
template <std::size_t N>
explicit constexpr PortCommon(const char (&tag)[N])
: Base(N, tag, &get_type_hash<ValueT>, IsOptionalV, IsMultipleV) {}
using PayloadT = decltype(ActualValueT(std::declval<ValueT>()));
auto operator()(CalculatorContext* cc) const {
return internal::AccessPort<PayloadT>(
std::integral_constant<bool, IsMultipleV>{}, *this, cc);
}
auto operator()(CalculatorContract* cc) const {
return internal::AccessPort<PayloadT>(
std::integral_constant<bool, IsMultipleV>{}, *this, cc);
}
private:
mediapipe::Status AddToContract(CalculatorContract* cc) const {
if (kMultiple) {
AddMultiple(cc);
} else {
auto& pt = internal::GetCollection(cc, *this).Tag(this->Tag());
internal::SetType<value_t>(cc, pt);
if (kOptional) {
pt.Optional();
}
}
return {};
}
void AddMultiple(CalculatorContract* cc) const {
auto& collection = internal::GetCollection(cc, *this);
int count = collection.NumEntries(this->Tag());
for (int i = 0; i < count; ++i) {
internal::SetType<value_t>(cc, collection.Get(this->Tag(), i));
}
}
template <typename...>
friend class internal::Contract;
template <typename B, typename VT, bool, bool>
friend class mediapipe::api2::SideFallbackT;
};
// Use one of these templates to define a port in node code.
template <typename T = internal::Generic>
using Input = PortCommon<InputBase, T>;
template <typename T = internal::Generic>
using Output = PortCommon<OutputBase, T>;
template <typename T = internal::Generic>
using SideInput = PortCommon<SideInputBase, T>;
template <typename T = internal::Generic>
using SideOutput = PortCommon<SideOutputBase, T>;
template <typename Base, typename ValueT, bool IsOptionalV, bool IsMultipleV>
class SideFallbackT : public Base {
public:
using value_t = ValueT;
static constexpr bool kOptional = IsOptionalV;
static constexpr bool kMultiple = IsMultipleV;
using Optional = SideFallbackT<Base, ValueT, true, IsMultipleV>;
using PayloadT = decltype(ActualValueT(std::declval<ValueT>()));
const char* Tag() const { return stream_port.Tag(); }
auto operator()(CalculatorContract* cc) const {
bool is_stream = true;
auto& stream_collection = internal::GetCollection(cc, stream_port);
auto* packet_type = internal::GetOrNull(stream_collection, Tag(), 0);
if (packet_type == nullptr) {
auto& side_collection = internal::GetCollection(cc, side_port);
packet_type = internal::GetOrNull(side_collection, Tag(), 0);
is_stream = false;
}
return internal::SinglePortAccess<PayloadT>(cc, packet_type, is_stream);
}
auto operator()(CalculatorContext* cc) const {
auto& stream_collection = internal::GetCollection(cc, stream_port);
auto& side_collection = internal::GetCollection(cc, side_port);
return internal::SinglePortAccess<PayloadT>(
cc, internal::GetOrNull(stream_collection, Tag(), 0),
internal::GetOrNull(side_collection, Tag(), 0));
}
template <std::size_t N>
explicit constexpr SideFallbackT(const char (&tag)[N])
: Base(N, tag, &get_type_hash<ValueT>, IsOptionalV, IsMultipleV),
stream_port(tag),
side_port(tag) {}
protected:
mediapipe::Status AddToContract(CalculatorContract* cc) const {
stream_port.AddToContract(cc);
side_port.AddToContract(cc);
int connected_count =
stream_port(cc).IsConnected() + side_port(cc).IsConnected();
if (connected_count > 1)
return mediapipe::InvalidArgumentError(absl::StrCat(
Tag(),
" can be connected as a stream or as a side packet, but not both"));
if (!IsOptionalV && connected_count == 0)
return mediapipe::InvalidArgumentError(
absl::StrCat(Tag(), " must be connected"));
return {};
}
using StreamPort = PortCommon<Base, ValueT, true, IsMultipleV>;
using SidePort = PortCommon<typename internal::SideBase<Base>::type, ValueT,
true, IsMultipleV>;
StreamPort stream_port;
SidePort side_port;
template <typename...>
friend class internal::Contract;
};
// An OutputShardAccess is returned when accessing an output stream within a
// CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to
// OutputStreamShard. Like that class, this class will not be usually named in
// calculator code, but used as a temporary object (e.g. kOut(cc).Send(...)).
class OutputShardAccessBase {
public:
OutputShardAccessBase(const CalculatorContext& cc, OutputStreamShard* output)
: context_(cc), output_(output) {}
void SetNextTimestampBound(Timestamp timestamp) {
if (output_) output_->SetNextTimestampBound(timestamp);
}
bool IsClosed() { return output_ ? output_->IsClosed() : true; }
void Close() {
if (output_) output_->Close();
}
bool IsConnected() { return output_ != nullptr; }
protected:
const CalculatorContext& context_;
OutputStreamShard* output_;
};
template <typename T>
class OutputShardAccess : public OutputShardAccessBase {
public:
void Send(Packet<T>&& packet) {
if (output_) output_->AddPacket(ToOldPacket(std::move(packet)));
}
void Send(const Packet<T>& packet) {
if (output_) output_->AddPacket(ToOldPacket(packet));
}
void Send(const T& payload, Timestamp time) {
Send(api2::MakePacket<T>(payload).At(time));
}
void Send(const T& payload) { Send(payload, context_.InputTimestamp()); }
void Send(std::unique_ptr<T> payload, Timestamp time) {
Send(api2::PacketAdopting(std::move(payload)).At(time));
}
void Send(std::unique_ptr<T> payload) {
Send(std::move(payload), context_.InputTimestamp());
}
private:
OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output)
: OutputShardAccessBase(cc, output) {}
friend OutputShardAccess<T> internal::SinglePortAccess<T>(
mediapipe::CalculatorContext*, OutputStreamShard*);
};
template <>
class OutputShardAccess<internal::Generic> : public OutputShardAccessBase {
public:
void Send(PacketBase&& packet) {
if (output_) output_->AddPacket(ToOldPacket(std::move(packet)));
}
void Send(const PacketBase& packet) {
if (output_) output_->AddPacket(ToOldPacket(packet));
}
void SetHeader(const PacketBase& header) {
if (output_) output_->SetHeader(ToOldPacket(header));
}
private:
OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output)
: OutputShardAccessBase(cc, output) {}
friend OutputShardAccess<internal::Generic>
internal::SinglePortAccess<internal::Generic>(mediapipe::CalculatorContext*,
OutputStreamShard*);
};
// Equivalent of OutputShardAccess, but for side packets.
template <typename T>
class OutputSidePacketAccess {
public:
void Set(Packet<T> packet) {
if (output_) output_->Set(ToOldPacket(std::move(packet)));
}
void Set(const T& payload) { Set(MakePacket<T>(payload)); }
private:
OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {}
OutputSidePacket* output_;
friend OutputSidePacketAccess<T> internal::SinglePortAccess<T>(
mediapipe::CalculatorContext*, OutputSidePacket*);
};
template <typename T>
class InputShardAccess : public Packet<T> {
public:
const PacketBase& packet() const& { return *this; }
// Since InputShardAccess is currently created as a temporary, this avoids
// easy mistakes with dangling references.
PacketBase packet() const&& { return *this; }
bool IsDone() const { return stream_->IsDone(); }
bool IsConnected() { return stream_ != nullptr; }
PacketBase Header() const { return FromOldPacket(stream_->Header()); }
private:
InputShardAccess(const CalculatorContext&, const InputStreamShard* stream)
: Packet<T>(stream ? FromOldPacket(stream->Value()).template As<T>()
: Packet<T>()),
stream_(stream) {}
const InputStreamShard* stream_;
friend InputShardAccess<T> internal::SinglePortAccess<T>(
mediapipe::CalculatorContext*, const InputStreamShard*);
};
template <typename T>
class InputSidePacketAccess : public Packet<T> {
public:
const PacketBase& packet() const& { return *this; }
PacketBase packet() const&& { return *this; }
bool IsConnected() { return connected_; }
private:
InputSidePacketAccess(const mediapipe::Packet* packet)
: Packet<T>(packet ? FromOldPacket(*packet).template As<T>()
: Packet<T>()),
connected_(packet != nullptr) {}
bool connected_;
friend InputSidePacketAccess<T> internal::SinglePortAccess<T>(
mediapipe::CalculatorContext*, const mediapipe::Packet*);
};
template <typename T>
class InputShardOrSideAccess : public Packet<T> {
public:
const PacketBase& packet() const& { return *this; }
PacketBase packet() const&& { return *this; }
bool IsDone() const { return stream_->IsDone(); }
bool IsConnected() { return connected_; }
bool IsStream() { return stream_ != nullptr; }
PacketBase Header() const { return FromOldPacket(stream_->Header()); }
private:
InputShardOrSideAccess(const CalculatorContext&,
const InputStreamShard* stream,
const mediapipe::Packet* packet)
: Packet<T>(stream ? FromOldPacket(stream->Value()).template As<T>()
: packet ? FromOldPacket(*packet).template As<T>()
: Packet<T>()),
stream_(stream),
connected_(stream_ != nullptr || packet != nullptr) {}
const InputStreamShard* stream_;
bool connected_;
friend InputShardOrSideAccess<T> internal::SinglePortAccess<T>(
mediapipe::CalculatorContext*, const InputStreamShard*,
const mediapipe::Packet*);
};
class PacketTypeAccess {
public:
bool IsConnected() { return packet_type_ != nullptr; }
protected:
PacketTypeAccess(PacketType* pt) : packet_type_(pt) {}
PacketType* packet_type_;
template <typename T>
friend PacketTypeAccess internal::SinglePortAccess(
mediapipe::CalculatorContract*, PacketType*);
};
class PacketTypeAccessFallback : public PacketTypeAccess {
public:
bool IsStream() { return is_stream_; }
private:
PacketTypeAccessFallback(PacketType* pt, bool is_stream)
: PacketTypeAccess(pt), is_stream_(is_stream) {}
bool is_stream_;
template <typename T>
friend PacketTypeAccessFallback internal::SinglePortAccess(
mediapipe::CalculatorContract*, PacketType*, bool);
};
namespace internal {
template <typename ValueT>
PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc,
PacketType* pt) {
return PacketTypeAccess(pt);
}
template <typename ValueT>
PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc,
PacketType* pt, bool is_stream) {
return PacketTypeAccessFallback(pt, is_stream);
}
} // namespace internal
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_PORT_H_

View File

@ -0,0 +1,26 @@
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe {
namespace api2 {
namespace {
TEST(PortTest, IntInput) {
static constexpr auto port = Input<int>("FOO");
EXPECT_EQ(port.type_id(), typeid(int).hash_code());
}
TEST(PortTest, OptionalInput) {
static constexpr auto port = Input<float>::Optional("BAR");
EXPECT_TRUE(port.IsOptional());
}
TEST(PortTest, Tag) {
static constexpr auto port = Input<int>("FOO");
EXPECT_EQ(std::string(port.Tag()), "FOO");
}
} // namespace
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,157 @@
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/api2/test_contracts.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/subgraph_expansion.h"
namespace mediapipe {
namespace api2 {
namespace test {
class FooBarImpl1 : public SubgraphImpl<FooBar1, FooBarImpl1> {
public:
mediapipe::StatusOr<CalculatorGraphConfig> GetConfig(
const SubgraphOptions& /*options*/) {
builder::Graph graph;
auto& foo = graph.AddNode("Foo");
auto& bar = graph.AddNode("Bar");
graph.In(kIn) >> foo.In("BASE");
foo.Out("OUT") >> bar.In("IN");
bar.Out("OUT") >> graph.Out(kOut);
return graph.GetConfig();
}
};
class FooBarImpl2 : public SubgraphImpl<FooBar2, FooBarImpl2> {
public:
mediapipe::StatusOr<CalculatorGraphConfig> GetConfig(
const SubgraphOptions& /*options*/) {
builder::Graph graph;
auto& foo = graph.AddNode<Foo>();
auto& bar = graph.AddNode<Bar>();
graph.In(kIn) >> foo.In(MPP_TAG("BASE"));
foo.Out(MPP_TAG("OUT")) >> bar.In(MPP_TAG("IN"));
bar.Out(MPP_TAG("OUT")) >> graph.Out(kOut);
return graph.GetConfig();
}
};
TEST(SubgraphTest, SubgraphConfig) {
CalculatorGraphConfig subgraph = FooBarImpl1().GetConfig({}).ValueOrDie();
const CalculatorGraphConfig expected_graph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "IN:__stream_0"
output_stream: "OUT:__stream_2"
node {
calculator: "Foo"
input_stream: "BASE:__stream_0"
output_stream: "OUT:__stream_1"
}
node {
calculator: "Bar"
input_stream: "IN:__stream_1"
output_stream: "OUT:__stream_2"
}
)");
EXPECT_THAT(subgraph, EqualsProto(expected_graph));
}
TEST(SubgraphTest, TypedSubgraphConfig) {
CalculatorGraphConfig subgraph = FooBarImpl2().GetConfig({}).ValueOrDie();
const CalculatorGraphConfig expected_graph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "IN:__stream_0"
output_stream: "OUT:__stream_2"
node {
calculator: "Foo"
input_stream: "BASE:__stream_0"
output_stream: "OUT:__stream_1"
}
node {
calculator: "Bar"
input_stream: "IN:__stream_1"
output_stream: "OUT:__stream_2"
}
)");
EXPECT_THAT(subgraph, EqualsProto(expected_graph));
}
TEST(SubgraphTest, ProtoApiConfig) {
CalculatorGraphConfig graph;
graph.add_input_stream("IN:__stream_0");
graph.add_output_stream("OUT:__stream_2");
auto* foo = graph.add_node();
foo->set_calculator("Foo");
foo->add_input_stream("BASE:__stream_0");
foo->add_output_stream("OUT:__stream_1");
auto* bar = graph.add_node();
bar->set_calculator("Bar");
bar->add_input_stream("IN:__stream_1");
bar->add_output_stream("OUT:__stream_2");
const CalculatorGraphConfig expected_graph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "IN:__stream_0"
output_stream: "OUT:__stream_2"
node {
calculator: "Foo"
input_stream: "BASE:__stream_0"
output_stream: "OUT:__stream_1"
}
node {
calculator: "Bar"
input_stream: "IN:__stream_1"
output_stream: "OUT:__stream_2"
}
)");
EXPECT_THAT(graph, EqualsProto(expected_graph));
}
TEST(SubgraphTest, ExpandSubgraphs) {
CalculatorGraphConfig supergraph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
node {
name: "simple_source"
calculator: "SomeSourceCalculator"
output_stream: "foo"
}
node {
calculator: "FooBar"
input_stream: "IN:foo"
output_stream: "OUT:output"
}
)");
const CalculatorGraphConfig expected_graph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
node {
name: "simple_source"
calculator: "SomeSourceCalculator"
output_stream: "foo"
}
node {
name: "foobar__Foo"
calculator: "Foo"
input_stream: "BASE:foo"
output_stream: "OUT:foobar____stream_1"
}
node {
name: "foobar__Bar"
calculator: "Bar"
input_stream: "IN:foobar____stream_1"
output_stream: "OUT:output"
}
)");
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
EXPECT_THAT(supergraph, EqualsProto(expected_graph));
}
} // namespace test
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,72 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_TAG_H_
#define MEDIAPIPE_FRAMEWORK_API2_TAG_H_
#include <utility>
#include "mediapipe/framework/api2/const_str.h"
namespace mediapipe {
namespace api2 {
// This template is used to define a separate type for each tag.
// This makes it possible to obtain results of different types depending on
// the tag. See MPP_TAG below for usage examples.
template <char... C>
struct Tag {
static constexpr char const kChars[sizeof...(C) + 1] = {C..., '\0'};
static constexpr const_str const kStr{kChars};
static const std::string str() {
return std::string(kStr.data(), kStr.len());
}
template <char... Q>
constexpr bool operator==(const Tag<Q...>& other) const {
return kStr == other.kStr;
}
template <char... Q>
constexpr bool operator!=(const Tag<Q...>& other) const {
return !(*this == other);
}
};
template <char... C>
constexpr bool is_tag(Tag<C...>) {
return true;
}
template <typename A>
constexpr bool is_tag(A) {
return false;
}
namespace internal {
template <typename S, std::size_t... I>
constexpr auto tag_build_impl(S, std::index_sequence<I...>)
-> Tag<S().tag[I]...> {
return {};
}
template <typename S>
constexpr auto tag_build(S) {
return tag_build_impl(S(), std::make_index_sequence<S().tag.len()>{});
}
} // namespace internal
// Use this to create typed tag objects.
// For example:
// auto kFOO = MPP_TAG(FOO);
// auto kBAR = MPP_TAG(BAR);
#define MPP_TAG(s) \
([] { \
struct S { \
const const_str tag{s}; \
}; \
return ::mediapipe::api2::internal::tag_build(S()); \
}())
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_TAG_H_

View File

@ -0,0 +1,48 @@
#include "mediapipe/framework/api2/tag.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe {
namespace api2 {
namespace {
template <typename A, typename B>
constexpr bool same_type(A, B) {
return false;
}
template <typename A>
constexpr bool same_type(A, A) {
return true;
}
auto kFOO = MPP_TAG("FOO");
auto kFOO2 = MPP_TAG("FOO");
auto kBAR = MPP_TAG("BAR");
TEST(TagTest, String) {
EXPECT_EQ(kFOO.str(), "FOO");
EXPECT_EQ(kBAR.str(), "BAR");
}
// Separate invocations of MPP_TAG with the same std::string produce objects of
// the same type.
TEST(TagTest, SameType) { EXPECT_TRUE(same_type(kFOO, kFOO2)); }
// Different tags have different types.
TEST(TagTest, DifferentType) { EXPECT_FALSE(same_type(kFOO, kBAR)); }
TEST(TagTest, Equal) {
EXPECT_EQ(kFOO, kFOO2);
EXPECT_NE(kFOO, kBAR);
}
TEST(TagTest, IsTag) {
EXPECT_TRUE(is_tag(kFOO));
EXPECT_FALSE(is_tag("FOO"));
}
} // namespace
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,87 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_TEST_CONTRACTS_H_
#define MEDIAPIPE_FRAMEWORK_API2_TEST_CONTRACTS_H_
#include "mediapipe/framework/api2/node.h"
namespace mediapipe {
namespace api2 {
namespace test {
struct Foo : public NodeIntf {
static constexpr Input<int> kBase{"BASE"};
static constexpr Input<float>::Optional kScale{"SCALE"};
static constexpr Output<float> kOut{"OUT"};
static constexpr SideInput<float>::Optional kBias{"BIAS"};
MEDIAPIPE_NODE_INTERFACE(Foo, kBase, kScale, kOut, kBias);
};
struct Foo2 : public NodeIntf {
// clang-format off
static constexpr auto kPorts = std::make_tuple(
Input<int>{"BASE"},
Input<float>::Optional{"SCALE"},
Output<float>{"OUT"},
SideInput<float>::Optional{"BIAS"}
);
// clang-format on
MEDIAPIPE_NODE_INTERFACE(Foo2, kPorts);
};
struct Bar : public NodeIntf {
static constexpr Input<AnyType> kIn{"IN"};
// Should all outputs be treated as optional by default?
static constexpr Output<SameType<kIn>>::Optional kOut{"OUT"};
MEDIAPIPE_NODE_INTERFACE(Bar, kIn, kOut);
};
struct Baz : public NodeIntf {
static constexpr Input<AnyType>::Multiple kData{"DATA"};
// Should all outputs be treated as optional by default?
static constexpr Output<SameType<kData>>::Multiple kDataOut{"DATA"};
MEDIAPIPE_NODE_INTERFACE(Baz, kData, kDataOut);
};
struct IntForwarder : public NodeIntf {
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"OUT"};
MEDIAPIPE_NODE_INTERFACE(IntForwarder, kIn, kOut);
};
struct FloatAdder : public NodeIntf {
static constexpr Input<float>::Multiple kIn{"IN"};
static constexpr Output<float> kOut{"OUT"};
MEDIAPIPE_NODE_INTERFACE(FloatAdder, kIn, kOut);
};
struct ToFloat : public NodeIntf {
static constexpr Input<OneOf<float, int>> kIn{"IN"};
static constexpr Output<float> kOut{"OUT"};
MEDIAPIPE_NODE_INTERFACE(ToFloat, kIn, kOut);
};
struct FooBar : public NodeIntf {
static constexpr Input<int> kIn{"IN"};
static constexpr Output<float> kOut{"OUT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
};
struct FooBar1 : public FooBar {
static constexpr char kCalculatorName[] = "FooBar";
};
struct FooBar2 : public FooBar {
static constexpr char kCalculatorName[] = "FooBar2";
};
} // namespace test
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_TEST_CONTRACTS_H_

View File

@ -0,0 +1,187 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_TUPLE_H_
#define MEDIAPIPE_FRAMEWORK_API2_TUPLE_H_
#include <tuple>
#include <type_traits>
#include <utility>
#include "absl/meta/type_traits.h"
// This file contains utilities for working with constexpr tuples.
namespace mediapipe {
namespace api2 {
namespace internal {
// Defines a std::index_sequence with indices for each item of the tuple.
template <class Tuple>
using tuple_index_sequence =
std::make_index_sequence<std::tuple_size_v<std::decay_t<Tuple>>>;
// Concatenates two std::index_sequences.
template <std::size_t... I, std::size_t... J>
constexpr auto index_sequence_cat(std::index_sequence<I...>,
std::index_sequence<J...>)
-> std::index_sequence<I..., J...> {
return {};
}
template <std::size_t... I, std::size_t... J, class... Tail>
constexpr auto index_sequence_cat(std::index_sequence<I...>,
std::index_sequence<J...>, Tail... tail) {
return index_sequence_cat(std::index_sequence<I..., J...>(), tail...);
}
template <template <typename...> class Pred, typename Tuple, std::size_t... I>
constexpr auto filtered_tuple_indices_impl(Tuple&& t,
std::index_sequence<I...>) {
return index_sequence_cat(
std::conditional_t<
Pred<std::tuple_element_t<I, std::decay_t<Tuple>>>::value,
std::index_sequence<I>, std::index_sequence<>>{}...);
}
// Returns a std::index_sequence with the indices of the tuple items whose
// type satisfied Pred.
template <template <typename...> class Pred, typename Tuple>
constexpr auto filtered_tuple_indices(Tuple&& tuple) {
return filtered_tuple_indices_impl<Pred>(tuple,
tuple_index_sequence<Tuple>());
}
// Convenience type to pass any type as a value.
template <typename T>
struct Wrap {
using type = T;
};
template <class F, typename Tuple, std::size_t... I>
constexpr auto filtered_tuple_indices_impl(Tuple&& t,
std::index_sequence<I...>) {
return index_sequence_cat(
std::conditional_t<
F{}(Wrap<std::tuple_element_t<I, std::decay_t<Tuple>>>{}),
std::index_sequence<I>, std::index_sequence<>>{}...);
}
// Returns a std::index_sequence with the indices of the tuple items for which
// F{}(Wrap<item_type>) returns true.
template <class F, typename Tuple>
constexpr auto filtered_tuple_indices(Tuple&& tuple) {
return filtered_tuple_indices_impl<F>(std::forward<Tuple>(tuple),
tuple_index_sequence<Tuple>());
}
// Returns a tuple of references to the tuple items with the specified indices.
template <typename Tuple, std::size_t... I>
constexpr auto select_tuple_indices(Tuple&& tuple, std::index_sequence<I...>) {
return std::forward_as_tuple(std::get<I>(std::forward<Tuple>(tuple))...);
}
// Returns a tuple of references to the tuple items whose types satisfy Pred.
template <template <typename...> class Pred, typename Tuple>
constexpr auto filter_tuple(Tuple&& t) {
return select_tuple_indices(std::forward<Tuple>(t),
filtered_tuple_indices<Pred>(t));
}
// Returns a tuple of references to the tuple items for which
// F{}(Wrap<item_type>) returns true.
template <typename F, typename Tuple>
constexpr auto filter_tuple(Tuple&& t) {
return select_tuple_indices(
std::forward<Tuple>(t),
filtered_tuple_indices<F>(std::forward<Tuple>(t)));
}
// TODO: ensure only one of these is enabled?
template <class F, class T, class I>
constexpr auto call_with_optional_index(F&& f, T&& t, I i)
-> absl::void_t<decltype(f(std::forward<T>(t), i))> {
return f(std::forward<T>(t), i);
}
template <class F, class T, class I>
constexpr auto call_with_optional_index(F&& f, T&& t, I i)
-> absl::void_t<decltype(f(std::forward<T>(t)))> {
return f(std::forward<T>(t));
}
template <class F, class Tuple, std::size_t... I>
constexpr void tuple_for_each_impl(F&& f, Tuple&& tuple,
std::index_sequence<I...>) {
int unpack[] = {
(call_with_optional_index(std::forward<F>(f),
std::get<I>(std::forward<Tuple>(tuple)),
std::integral_constant<std::size_t, I>{}),
0)...};
(void)unpack;
}
// Invokes f for each item in tuple.
// If f takes one argument, it will be called as f(item).
// If f takes two arguments, it will be called as
// f(item, std::integral_constant<std::size_t, index>{}).
template <class F, class Tuple>
constexpr void tuple_for_each(F&& f, Tuple&& tuple) {
return tuple_for_each_impl(std::forward<F>(f), std::forward<Tuple>(tuple),
tuple_index_sequence<Tuple>());
}
template <class F, class Tuple, std::size_t... I>
constexpr auto map_tuple_impl(F&& f, Tuple&& tuple, std::index_sequence<I...>) {
return std::make_tuple(f(std::get<I>(std::forward<Tuple>(tuple)))...);
}
// Returns a tuple where each item is the result of calling f on the
// corresponding item of the provided tuple.
template <class F, class Tuple>
constexpr auto map_tuple(F&& f, Tuple&& tuple) {
return map_tuple_impl(std::forward<F>(f), std::forward<Tuple>(tuple),
tuple_index_sequence<Tuple>());
}
template <class F, class Tuple, std::size_t... I>
constexpr auto tuple_apply_impl(F&& f, Tuple&& tuple,
std::index_sequence<I...>) {
return f(std::get<I>(std::forward<Tuple>(tuple))...);
}
// Invokes f passing the tuple's items as arguments.
template <class F, class Tuple>
constexpr auto tuple_apply(F&& f, Tuple&& tuple) {
return tuple_apply_impl(std::forward<F>(f), std::forward<Tuple>(tuple),
tuple_index_sequence<Tuple>());
}
// Returns the index [0, tuple_size) of the first item for which f returns true,
// or tuple_size if no such item is found.
template <class F, class Tuple, std::size_t i = 0>
constexpr std::enable_if_t<i == std::tuple_size_v<std::decay_t<Tuple>>,
std::size_t>
tuple_find(F&& f, Tuple&& tuple) {
return i;
}
template <class F, class Tuple, std::size_t i = 0>
constexpr std::enable_if_t<i != std::tuple_size_v<std::decay_t<Tuple>>,
std::size_t>
tuple_find(F&& f, Tuple&& tuple) {
if (f(std::get<i>(std::forward<Tuple>(tuple)))) {
return i;
}
return tuple_find<F, Tuple, i + 1>(std::forward<F>(f),
std::forward<Tuple>(tuple));
}
template <class Tuple>
constexpr auto flatten_tuple(Tuple&& tuple) {
return tuple_apply([](auto&&... args) { return std::tuple_cat(args...); },
tuple);
}
} // namespace internal
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_TUPLE_H_

View File

@ -0,0 +1,147 @@
#include "mediapipe/framework/api2/tuple.h"
#include <tuple>
#include <type_traits>
#include <utility>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe {
namespace api2 {
namespace internal {
namespace {
template <typename A, typename B>
constexpr bool same_type(A, B) {
return false;
}
template <typename A>
constexpr bool same_type(A, A) {
return true;
}
template <std::size_t... I>
using iseq = std::index_sequence<I...>;
TEST(TupleTest, IndexSeq) {
EXPECT_TRUE(
same_type(iseq<0, 1, 2>(), index_sequence_cat(iseq<0, 1>(), iseq<2>())));
EXPECT_TRUE(same_type(iseq<0, 1, 2>(),
index_sequence_cat(iseq<0, 1>(), iseq<>(), iseq<2>())));
}
TEST(TupleTest, FilteredIndices) {
EXPECT_TRUE(same_type(
filtered_tuple_indices<std::is_integral>(std::tuple<int, float, char>()),
iseq<0, 2>()));
}
TEST(TupleTest, SelectIndices) {
auto t = std::make_tuple(5.0, 10, "hi");
EXPECT_EQ((select_tuple_indices(t, iseq<0, 2>())),
(std::make_tuple(5.0, "hi")));
}
TEST(TupleTest, FilterTuple) {
auto t = std::make_tuple(5.0, 10, "hi");
EXPECT_EQ((filter_tuple<std::is_integral>(t)), (std::make_tuple(10)));
}
TEST(TupleTest, FilterTupleRefs) {
auto t = std::make_tuple(5.0, 10, "hi");
auto tr = filter_tuple<std::is_integral>(t);
int x;
EXPECT_TRUE(same_type(tr, std::tuple<int&>{x}));
EXPECT_FALSE(same_type(tr, std::tuple<int>{x}));
auto tr_copy =
std::apply([](auto&&... item) { return std::make_tuple(item...); },
filter_tuple<std::is_integral>(t));
EXPECT_TRUE(same_type(tr_copy, std::tuple<int>{x}));
}
struct is_integral {
template <class W>
constexpr bool operator()(W&&) {
return std::is_integral<typename W::type>{};
}
};
TEST(TupleTest, FilteredIndices2) {
EXPECT_TRUE(same_type(
filtered_tuple_indices<is_integral>(std::tuple<int, float, char>()),
iseq<0, 2>()));
}
// TEST(TupleTest, FilterTuple2) {
// auto t = std::make_tuple(5.0, 10, "hi");
// auto is_int = [](auto&& x) {
// return std::is_integral_v<decltype(x)>;
// };
// EXPECT_EQ((filter_tuple(is_int, t)), (std::make_tuple(10)));
// }
TEST(TupleTest, ForEach) {
auto t = std::make_tuple(5.0, 10, "hi");
std::vector<std::string> s;
tuple_for_each([&s](auto&& item) { s.push_back(absl::StrCat(item)); }, t);
EXPECT_EQ(s, (std::vector<std::string>{"5", "10", "hi"}));
}
TEST(TupleTest, ForEachWithIndex) {
auto t = std::make_tuple(5.0, 10, "hi");
std::vector<std::string> s;
tuple_for_each(
[&s](auto&& item, std::size_t i) {
s.push_back(absl::StrCat(i, ":", item));
},
t);
EXPECT_EQ(s, (std::vector<std::string>{"0:5", "1:10", "2:hi"}));
}
TEST(TupleTest, ForEachZip) {
auto t = std::make_tuple(5.0, 10, "hi");
auto u = std::make_tuple(2.0, 3, "lo");
std::vector<std::string> s;
tuple_for_each(
[&s, &u](auto&& item, auto i_const) {
constexpr std::size_t i = decltype(i_const)::value;
s.push_back(absl::StrCat(i, ":", item, ",", std::get<i>(u)));
},
t);
EXPECT_EQ(s, (std::vector<std::string>{"0:5,2", "1:10,3", "2:hi,lo"}));
}
TEST(TupleTest, Apply) {
auto t = std::make_tuple(5.0, 10, "hi");
std::string s = tuple_apply(
[](float f, int i, const char* s) { return absl::StrCat(f, i, s); }, t);
EXPECT_EQ(s, "510hi");
}
TEST(TupleTest, Map) {
auto t = std::make_tuple(5.0, 10, 2L);
auto t2 = map_tuple([](auto x) { return x * 2; }, t);
EXPECT_EQ(t2, std::make_tuple(10.0, 20, 4L));
}
TEST(TupleFind, Find) {
auto t = std::make_tuple(5.0, 10, 2L);
auto i = tuple_find([](auto x) { return x > 3; }, t);
EXPECT_EQ(i, 0);
}
TEST(TupleFind, Flatten) {
auto t1 = std::make_tuple(5.0, 10);
auto t2 = std::make_tuple(2L);
auto t = std::make_tuple(t1, t2);
auto tf = flatten_tuple(t);
EXPECT_EQ(tf, std::make_tuple(5.0, 10, 2L));
}
} // namespace
} // namespace internal
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,136 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_TYPE_LIST_H_
#define MEDIAPIPE_FRAMEWORK_API2_TYPE_LIST_H_
#include <string>
#include <type_traits>
#include <utility>
namespace mediapipe {
namespace api2 {
namespace types {
// A list of types. This allows us to store a template parameter pack.
template <typename... Args>
struct List {};
// Concatenate two lists.
template <typename... As, typename... Bs>
auto concat(List<As...>, List<Bs...>) -> List<As..., Bs...> {
return {};
}
// Filter a list using a predicate.
template <template <typename> class Pred, typename... Args>
auto filter(List<Args...>) -> List<Args...> {
return {};
}
template <template <typename> class Pred, typename Head, typename... Tail>
auto filter(List<Head, Tail...>) -> decltype(concat(
typename std::conditional<Pred<Head>::value, List<Head>, List<>>::type{},
filter<Pred>(List<Tail...>{}))) {
return {};
}
template <typename Pred>
auto filter(Pred, List<>) -> List<> {
return {};
}
template <typename Pred, typename Head, typename... Tail>
auto filter(Pred pred, List<Head, Tail...>) -> decltype(concat(
typename std::conditional<pred(Head{}), List<Head>, List<>>::type{},
filter(pred, List<Tail...>{}))) {
return {};
}
// Invoke a template using a list's types as parameters.
template <template <typename...> class T, typename... Args>
auto apply(List<Args...>) -> T<Args...> {
return {};
}
// Wraps a single type. The wrapper can always be instantiated as a value,
// even if T cannot.
template <typename T>
struct Wrap {
using type = T;
};
// Find first match for a predicate.
template <template <typename> class Pred, typename... Args>
auto find(List<Args...>) -> Wrap<void> {
return {};
}
template <template <typename> class Pred, typename Head, typename... Tail>
auto find(List<Head, Tail...>) ->
typename std::conditional<Pred<Head>::value, Wrap<Head>,
decltype(find<Pred>(List<Tail...>{}))>::type {
return {};
}
template <class Pred, typename... Args>
auto find(Pred, List<Args...>) -> Wrap<void> {
return {};
}
template <class Pred, typename Head, typename... Tail>
auto find(Pred pred, List<Head, Tail...>) ->
typename std::conditional<pred(Head{}), Wrap<Head>,
decltype(find(pred, List<Tail...>{}))>::type {
return {};
}
// Apply a function to each item in a list.
template <template <typename> class Fun, typename... Items>
auto map(List<Items...>) -> List<typename Fun<Items>::type...> {
return {};
}
// Get the list's head.
template <typename... Args>
constexpr auto head(List<Args...>) -> Wrap<void> {
return {};
}
template <typename H, typename... T>
constexpr auto head(List<H, T...>) -> Wrap<H> {
return {};
}
// Get the list's length.
template <typename... Args>
constexpr std::size_t length(List<Args...>) {
return 0;
}
template <typename H, typename... T>
constexpr std::size_t length(List<H, T...>) {
return length(List<T...>{}) + 1;
}
// Add indices.
template <std::size_t I, typename T>
struct IndexedType {
static constexpr std::size_t kIndex = I;
using type = T;
};
template <typename... Args, std::size_t... Is>
auto enumerate_impl(List<Args...>, std::index_sequence<Is...>)
-> List<IndexedType<Is, Args>...> {
return {};
}
template <typename... Args>
auto enumerate(List<Args...> a)
-> decltype(enumerate_impl(a, std::index_sequence_for<Args...>{})) {
return {};
}
} // namespace types
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_API2_TYPE_LIST_H_

View File

@ -0,0 +1,101 @@
#include "mediapipe/framework/api2/type_list.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe {
namespace api2 {
namespace types {
namespace {
template <typename A, typename B>
constexpr bool same_type(A, B) {
return false;
}
template <typename A>
constexpr bool same_type(A, A) {
return true;
}
struct Foo {};
struct Bar {};
struct Baz {};
TEST(TypeListFTest, SameType) {
EXPECT_FALSE(same_type(List<Foo>{}, List<>{}));
EXPECT_TRUE(same_type(List<Foo>{}, List<Foo>{}));
}
TEST(TypeListFTest, Length) {
EXPECT_EQ(length(List<float, int>{}), 2);
EXPECT_EQ(length(List<>{}), 0);
}
TEST(TypeListFTest, Head) {
using Empty = List<>;
using ListA = List<Foo, Bar>;
EXPECT_TRUE(same_type(Wrap<Foo>{}, head(ListA{})));
EXPECT_TRUE(same_type(Wrap<void>{}, head(Empty{})));
}
TEST(TypeListFTest, Concat) {
using Empty = List<>;
using ListA = List<Foo>;
EXPECT_TRUE(same_type(ListA{}, concat(ListA{}, Empty{})));
EXPECT_TRUE(same_type(concat(ListA{}, Empty{}), ListA{}));
using ListB = List<Bar, Baz>;
EXPECT_TRUE(same_type(concat(ListA{}, ListB{}), List<Foo, Bar, Baz>{}));
}
TEST(TypeListFTest, Filter) {
EXPECT_TRUE(same_type(filter<std::is_integral>(List<>{}), List<>{}));
EXPECT_TRUE(same_type(filter<std::is_integral>(List<int, float, char>{}),
List<int, char>{}));
}
TEST(TypeListFTest, Filter2) {
constexpr auto is_integral = [](auto x) {
return std::is_integral<decltype(x)>{};
};
auto x = filter(is_integral, List<>{});
EXPECT_TRUE(same_type(x, List<>{}));
auto y = filter(is_integral, List<int, float, char>{});
EXPECT_TRUE(same_type(y, List<int, char>{}));
auto z = filter([](auto x) { return std::is_integral<decltype(x)>{}; },
List<int, double>{});
EXPECT_TRUE(same_type(z, List<int>{}));
}
TEST(TypeListFTest, Find) {
EXPECT_TRUE(same_type(find<std::is_integral>(List<>{}), Wrap<void>{}));
EXPECT_TRUE(
same_type(find<std::is_integral>(List<float, int>{}), Wrap<int>()));
}
TEST(TypeListFTest, Find2) {
constexpr auto is_integral = [](auto x) {
return std::is_integral<decltype(x)>{};
};
EXPECT_TRUE(same_type(find(is_integral, List<>{}), Wrap<void>{}));
EXPECT_TRUE(same_type(find(is_integral, List<float, int>{}), Wrap<int>()));
}
TEST(TypeListFTest, Map) {
EXPECT_TRUE(
same_type(map<std::remove_cv>(List<const int, const float, const char>{}),
List<int, float, char>{}));
}
TEST(TypeListFTest, Enumerate) {
EXPECT_TRUE(same_type(enumerate(List<int, float, char>{}),
List<IndexedType<0, int>, IndexedType<1, float>,
IndexedType<2, char>>{}));
}
} // namespace
} // namespace types
} // namespace api2
} // namespace mediapipe

View File

@ -19,6 +19,7 @@
#include <type_traits>
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/deps/registration.h"
@ -168,22 +169,23 @@ class CalculatorBase {
virtual Timestamp SourceProcessOrder(const CalculatorContext* cc) const;
};
using CalculatorBaseRegistry =
GlobalFactoryRegistry<std::unique_ptr<CalculatorBase>>;
namespace api2 {
class Node;
} // namespace api2
namespace internal {
// Gives access to the static functions within subclasses of CalculatorBase.
// This adds functionality akin to virtual static functions.
class StaticAccessToCalculatorBase {
class CalculatorBaseFactory {
public:
virtual ~StaticAccessToCalculatorBase() {}
virtual ~CalculatorBaseFactory() {}
virtual mediapipe::Status GetContract(CalculatorContract* cc) = 0;
virtual std::unique_ptr<CalculatorBase> CreateCalculator(
CalculatorContext* calculator_context) = 0;
virtual std::string ContractMethodName() { return "GetContract"; }
};
using StaticAccessToCalculatorBaseRegistry =
GlobalFactoryRegistry<std::unique_ptr<StaticAccessToCalculatorBase>>;
// Functions for checking that the calculator has the required GetContract.
template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
@ -197,14 +199,21 @@ constexpr bool CalculatorHasGetContract(...) {
// Provides access to the static functions within a specific subclass
// of CalculatorBase.
template <typename CalculatorBaseSubclass>
class StaticAccessToCalculatorBaseTyped : public StaticAccessToCalculatorBase {
public:
static_assert(
std::is_base_of<mediapipe::CalculatorBase, CalculatorBaseSubclass>::value,
template <class T, class Enable = void>
class CalculatorBaseFactoryFor : public CalculatorBaseFactory {
static_assert(std::is_base_of<mediapipe::CalculatorBase, T>::value,
"Classes registered with REGISTER_CALCULATOR must be "
"subclasses of mediapipe::CalculatorBase.");
static_assert(CalculatorHasGetContract<CalculatorBaseSubclass>(nullptr),
};
template <class T>
class CalculatorBaseFactoryFor<
T,
typename std::enable_if<std::is_base_of<mediapipe::CalculatorBase, T>{} &&
!std::is_base_of<mediapipe::api2::Node, T>{}>::type>
: public CalculatorBaseFactory {
public:
static_assert(CalculatorHasGetContract<T>(nullptr),
"GetContract() must be defined with the correct signature in "
"every calculator.");
@ -213,12 +222,20 @@ class StaticAccessToCalculatorBaseTyped : public StaticAccessToCalculatorBase {
mediapipe::Status GetContract(CalculatorContract* cc) final {
// CalculatorBaseSubclass must implement this function, since it is not
// implemented in the parent class.
return CalculatorBaseSubclass::GetContract(cc);
return T::GetContract(cc);
}
std::unique_ptr<CalculatorBase> CreateCalculator(
CalculatorContext* calculator_context) final {
return absl::make_unique<T>();
}
};
} // namespace internal
using CalculatorBaseRegistry =
GlobalFactoryRegistry<std::unique_ptr<internal::CalculatorBaseFactory>>;
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_

View File

@ -128,7 +128,6 @@ TEST(CalculatorTest, SourceProcessOrder) {
CalculatorContext calculator_context(&calculator_state,
tool::CreateTagMap({}).ValueOrDie(),
output_stream_managers.TagMap());
InputStreamShardSet& input_set = calculator_context.Inputs();
OutputStreamShardSet& output_set = calculator_context.Outputs();
output_set.Index(0).SetSpec(output_stream_managers.Index(0).Spec());
output_set.Index(0).SetNextTimestampBound(Timestamp(10));
@ -137,19 +136,6 @@ TEST(CalculatorTest, SourceProcessOrder) {
CalculatorContextManager().PushInputTimestampToContext(
&calculator_context, Timestamp::Unstarted());
InputStreamSet input_streams(input_set.TagMap());
OutputStreamSet output_streams(output_set.TagMap());
for (CollectionItemId id = input_streams.BeginId();
id < input_streams.EndId(); ++id) {
input_streams.Get(id) = &input_set.Get(id);
}
for (CollectionItemId id = output_streams.BeginId();
id < output_streams.EndId(); ++id) {
output_streams.Get(id) = &output_set.Get(id);
}
calculator_state.SetInputStreamSet(&input_streams);
calculator_state.SetOutputStreamSet(&output_streams);
test_ns::DeadEndCalculator calculator;
EXPECT_EQ(Timestamp(10), calculator.SourceProcessOrder(&calculator_context));
output_set.Index(0).SetNextTimestampBound(Timestamp(100));
@ -202,7 +188,8 @@ TEST(CalculatorTest, CreateByNameWhitelisted) {
// Register a whitelisted calculator.
CalculatorBaseRegistry::Register(
"::mediapipe::test_ns::whitelisted_ns::DeadCalculator",
absl::make_unique<mediapipe::test_ns::whitelisted_ns::DeadCalculator>);
absl::make_unique<internal::CalculatorBaseFactoryFor<
mediapipe::test_ns::whitelisted_ns::DeadCalculator>>);
// A whitelisted calculator can be found in its own namespace.
MP_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( //

View File

@ -66,11 +66,26 @@ void CalculatorContext::SetOffset(TimestampDiff offset) {
}
const InputStreamSet& CalculatorContext::InputStreams() const {
return calculator_state_->InputStreams();
if (!input_streams_) {
input_streams_ = absl::make_unique<InputStreamSet>(inputs_.TagMap());
for (CollectionItemId id = input_streams_->BeginId();
id < input_streams_->EndId(); ++id) {
input_streams_->Get(id) = const_cast<InputStreamShard*>(&inputs_.Get(id));
}
}
return *input_streams_;
}
const OutputStreamSet& CalculatorContext::OutputStreams() const {
return calculator_state_->OutputStreams();
if (!output_streams_) {
output_streams_ = absl::make_unique<OutputStreamSet>(outputs_.TagMap());
for (CollectionItemId id = output_streams_->BeginId();
id < output_streams_->EndId(); ++id) {
output_streams_->Get(id) =
const_cast<OutputStreamShard*>(&outputs_.Get(id));
}
}
return *output_streams_;
}
} // namespace mediapipe

View File

@ -163,6 +163,10 @@ class CalculatorContext {
CalculatorState* calculator_state_;
InputStreamShardSet inputs_;
OutputStreamShardSet outputs_;
// Created on-demand when needed by legacy APIs. No synchronization needed
// because all possible callers are already serialized.
mutable std::unique_ptr<InputStreamSet> input_streams_;
mutable std::unique_ptr<OutputStreamSet> output_streams_;
// The queue of timestamp values to Process() in this calculator context.
std::queue<Timestamp> input_timestamps_;

View File

@ -27,7 +27,6 @@
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_registry_util.h"
#include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/input_stream_manager.h"
#include "mediapipe/framework/mediapipe_profiling.h"
@ -374,15 +373,12 @@ mediapipe::Status CalculatorNode::PrepareForRun(
MP_RETURN_IF_ERROR(calculator_context_manager_.PrepareForRun(std::bind(
&CalculatorNode::ConnectShardsToStreams, this, std::placeholders::_1)));
auto calculator_statusor = CreateCalculator(
input_stream_handler_->InputTagMap(),
output_stream_handler_->OutputTagMap(), validated_graph_->Package(),
calculator_state_.get(),
ASSIGN_OR_RETURN(
auto calculator_factory,
CalculatorBaseRegistry::CreateByNameInNamespace(
validated_graph_->Package(), calculator_state_->CalculatorType()));
calculator_ = calculator_factory->CreateCalculator(
calculator_context_manager_.GetDefaultCalculatorContext());
if (!calculator_statusor.ok()) {
return calculator_statusor.status();
}
calculator_ = std::move(calculator_statusor).ValueOrDie();
needs_to_close_ = false;

View File

@ -19,14 +19,10 @@
#include "mediapipe/framework/calculator_base.h"
// Macro for registering calculators.
#define REGISTER_CALCULATOR(name) \
REGISTER_FACTORY_FUNCTION_QUALIFIED(mediapipe::CalculatorBaseRegistry, \
calculator_registration, name, \
absl::make_unique<name>); \
REGISTER_FACTORY_FUNCTION_QUALIFIED( \
mediapipe::internal::StaticAccessToCalculatorBaseRegistry, \
access_registration, name, \
absl::make_unique< \
mediapipe::internal::StaticAccessToCalculatorBaseTyped<name>>)
mediapipe::CalculatorBaseRegistry, calculator_registration, name, \
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<name>>)
#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_H_

View File

@ -1,61 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_registry_util.h"
#include <algorithm>
#include <string>
#include "mediapipe/framework/collection.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
bool IsLegacyCalculator(const std::string& package_name,
const std::string& node_class) {
return false;
}
mediapipe::Status VerifyCalculatorWithContract(const std::string& package_name,
const std::string& node_class,
CalculatorContract* contract) {
// A number of calculators use the non-CC methods on GlCalculatorHelper
// even though they are CalculatorBase-based.
ASSIGN_OR_RETURN(
auto static_access_to_calculator_base,
internal::StaticAccessToCalculatorBaseRegistry::CreateByNameInNamespace(
package_name, node_class),
_ << "Unable to find Calculator \"" << node_class << "\"");
MP_RETURN_IF_ERROR(static_access_to_calculator_base->GetContract(contract))
.SetPrepend()
<< node_class << ": ";
return mediapipe::OkStatus();
}
mediapipe::StatusOr<std::unique_ptr<CalculatorBase>> CreateCalculator(
const std::shared_ptr<tool::TagMap>& input_tag_map,
const std::shared_ptr<tool::TagMap>& output_tag_map,
const std::string& package_name, CalculatorState* calculator_state,
CalculatorContext* calculator_context) {
std::unique_ptr<CalculatorBase> calculator;
ASSIGN_OR_RETURN(calculator,
CalculatorBaseRegistry::CreateByNameInNamespace(
package_name, calculator_state->CalculatorType()));
return std::move(calculator);
}
} // namespace mediapipe

View File

@ -1,46 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_
#define MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_
#include <memory>
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_state.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/tool/tag_map.h"
// Calculator registry util functions that supports both legacy Calculator API
// and CalculatorBase.
namespace mediapipe {
bool IsLegacyCalculator(const std::string& package_name,
const std::string& node_class);
mediapipe::Status VerifyCalculatorWithContract(const std::string& package_name,
const std::string& node_class,
CalculatorContract* contract);
mediapipe::StatusOr<std::unique_ptr<CalculatorBase>> CreateCalculator(
const std::shared_ptr<tool::TagMap>& input_tag_map,
const std::shared_ptr<tool::TagMap>& output_tag_map,
const std::string& package_name, CalculatorState* calculator_state,
CalculatorContext* calculator_context);
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_REGISTRY_UTIL_H_

View File

@ -33,8 +33,6 @@ CalculatorState::CalculatorState(
calculator_type_(calculator_type),
node_config_(node_config),
profiling_context_(profiling_context),
input_streams_(nullptr),
output_streams_(nullptr),
counter_factory_(nullptr) {
options_.Initialize(node_config);
ResetBetweenRuns();
@ -42,20 +40,8 @@ CalculatorState::CalculatorState(
CalculatorState::~CalculatorState() {}
void CalculatorState::SetInputStreamSet(InputStreamSet* input_stream_set) {
CHECK(input_stream_set);
input_streams_ = input_stream_set;
}
void CalculatorState::SetOutputStreamSet(OutputStreamSet* output_stream_set) {
CHECK(output_stream_set);
output_streams_ = output_stream_set;
}
void CalculatorState::ResetBetweenRuns() {
input_side_packets_ = nullptr;
input_streams_ = nullptr;
output_streams_ = nullptr;
counter_factory_ = nullptr;
}

View File

@ -52,14 +52,6 @@ class CalculatorState {
CalculatorState& operator=(const CalculatorState&) = delete;
~CalculatorState();
// Sets the pointer to the InputStreamSet. The function is invoked by
// CalculatorNode::PrepareForRun.
void SetInputStreamSet(InputStreamSet* input_stream_set);
// Sets the pointer to the OutputStreamSet. The function is invoked by
// CalculatorNode::PrepareForRun.
void SetOutputStreamSet(OutputStreamSet* output_stream_set);
// Called before every call to Calculator::Open() (during the PrepareForRun
// phase).
void ResetBetweenRuns();
@ -79,8 +71,6 @@ class CalculatorState {
////////////////////////////////////////
// Interface for Calculator.
////////////////////////////////////////
const InputStreamSet& InputStreams() const { return *input_streams_; }
const OutputStreamSet& OutputStreams() const { return *output_streams_; }
const PacketSet& InputSidePackets() const { return *input_side_packets_; }
OutputSidePacketSet& OutputSidePackets() { return *output_side_packets_; }
@ -139,12 +129,6 @@ class CalculatorState {
////////////////////////////////////////
// Variables which ARE cleared by ResetBetweenRuns().
////////////////////////////////////////
// The InputStreamSet object is owned by the CalculatorNode.
// CalculatorState obtains its pointer in CalculatorNode::PrepareForRun.
InputStreamSet* input_streams_;
// The OutputStreamSet object is owned by the CalculatorNode.
// CalculatorState obtains its pointer in CalculatorNode::PrepareForRun.
OutputStreamSet* output_streams_;
// The set of input side packets set by CalculatorNode::PrepareForRun().
// ResetBetweenRuns() clears this PacketSet pointer.
const PacketSet* input_side_packets_;

View File

@ -21,7 +21,7 @@
namespace mediapipe {
namespace file {
mediapipe::Status GetContents(absl::string_view file_name, std::string* output,
bool read_as_binary = false);
bool read_as_binary = true);
mediapipe::Status SetContents(absl::string_view file_name,
absl::string_view content);

View File

@ -454,12 +454,12 @@ cc_library(
":status_util",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:legacy_calculator_support",
"//mediapipe/framework:packet_generator",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:packet_set",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",

View File

@ -20,6 +20,7 @@
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/legacy_calculator_support.h"
#include "mediapipe/framework/packet_generator.h"
#include "mediapipe/framework/packet_generator.pb.h"

View File

@ -21,7 +21,6 @@
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_registry_util.h"
#include "mediapipe/framework/legacy_calculator_support.h"
#include "mediapipe/framework/packet_generator.h"
#include "mediapipe/framework/packet_generator.pb.h"
@ -236,8 +235,14 @@ mediapipe::Status NodeTypeInfo::Initialize(
}
#endif
LegacyCalculatorSupport::Scoped<CalculatorContract> s(&contract_);
MP_RETURN_IF_ERROR(VerifyCalculatorWithContract(validated_graph.Package(),
node_class, &contract_));
// A number of calculators use the non-CC methods on GlCalculatorHelper
// even though they are CalculatorBase-based.
ASSIGN_OR_RETURN(auto calculator_factory,
CalculatorBaseRegistry::CreateByNameInNamespace(
validated_graph.Package(), node_class),
_ << "Unable to find Calculator \"" << node_class << "\"");
MP_RETURN_IF_ERROR(calculator_factory->GetContract(&contract_)).SetPrepend()
<< node_class << ": ";
// Validate result of FillExpectations or GetContract.
std::vector<mediapipe::Status> statuses;
@ -261,10 +266,7 @@ mediapipe::Status NodeTypeInfo::Initialize(
}
if (!statuses.empty()) {
return tool::CombinedStatus(
absl::StrCat(node_class,
IsLegacyCalculator(validated_graph.Package(), node_class)
? "::FillExpectations"
: "::GetContract",
absl::StrCat(node_class, "::", calculator_factory->ContractMethodName(),
" failed to validate: "),
statuses);
}

View File

@ -55,6 +55,7 @@ cc_library(
name = "desktop_cpu_calculators",
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/tflite:tflite_model_calculator",
"//mediapipe/calculators/util:local_file_contents_calculator",
"//mediapipe/calculators/video:opencv_video_decoder_calculator",

View File

@ -635,12 +635,15 @@ void GlAnimationOverlayCalculator::LoadModelMatrices(
// Process model matrices, if any are being streamed in, and update our
// list.
current_model_matrices_.clear();
if (has_model_matrix_stream_ &&
!cc->Inputs().Tag("MODEL_MATRICES").IsEmpty()) {
const TimedModelMatrixProtoList &model_matrices =
cc->Inputs().Tag("MODEL_MATRICES").Get<TimedModelMatrixProtoList>();
LoadModelMatrices(model_matrices, &current_model_matrices_);
}
current_mask_model_matrices_.clear();
if (has_mask_model_matrix_stream_ &&
!cc->Inputs().Tag("MASK_MODEL_MATRICES").IsEmpty()) {
const TimedModelMatrixProtoList &model_matrices =

View File

@ -8,6 +8,7 @@ input_side_packet: "LABELS_CSV:allowed_labels"
input_side_packet: "MODEL_SCALE:model_scale"
input_side_packet: "MODEL_TRANSFORMATION:model_transformation"
input_side_packet: "TEXTURE:box_texture"
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
input_side_packet: "ANIMATION_ASSET:box_asset_name"
input_side_packet: "MASK_TEXTURE:obj_texture"
input_side_packet: "MASK_ASSET:obj_asset_name"
@ -26,7 +27,7 @@ output_stream: "output_video"
node {
calculator: "FlowLimiterCalculator"
input_stream: "input_video"
input_stream: "FINISHED:lifted_objects"
input_stream: "FINISHED:output_video"
input_stream_info: {
tag_index: "FINISHED"
back_edge: true
@ -52,6 +53,7 @@ node {
calculator: "ObjectronGpuSubgraph"
input_stream: "IMAGE_GPU:throttled_input_video_3x4"
input_side_packet: "LABELS_CSV:allowed_labels"
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
output_stream: "FRAME_ANNOTATION:lifted_objects"
}

View File

@ -4,6 +4,16 @@ input_side_packet: "FILE_PATH:0:box_landmark_model_path"
input_side_packet: "LABELS_CSV:allowed_labels"
input_side_packet: "OUTPUT_FILE_PATH:output_video_path"
# Generates side packet with max number of objects to detect/track.
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:max_num_objects"
node_options: {
[type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: {
packet { int_value: 5 }
}
}
}
# Decodes an input video file into images and a video header.
node {
@ -30,8 +40,9 @@ node {
input_stream: "IMAGE:input_video"
input_side_packet: "MODEL:box_landmark_model"
input_side_packet: "LABELS_CSV:allowed_labels"
output_stream: "LANDMARKS:box_landmarks"
output_stream: "NORM_RECT:box_rect"
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
output_stream: "MULTI_LANDMARKS:box_landmarks"
output_stream: "NORM_RECTS:box_rect"
}
# Subgraph that renders annotations and overlays them on top of the input
@ -39,8 +50,8 @@ node {
node {
calculator: "RendererSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "LANDMARKS:box_landmarks"
input_stream: "NORM_RECT:box_rect"
input_stream: "MULTI_LANDMARKS:box_landmarks"
input_stream: "NORM_RECTS:box_rect"
output_stream: "IMAGE:output_video"
}

View File

@ -27,6 +27,8 @@ mediapipe_simple_subgraph(
register_as = "RendererSubgraph",
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:detections_to_render_data_calculator",
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",

View File

@ -3,15 +3,26 @@
type: "RendererSubgraph"
input_stream: "IMAGE:input_image"
input_stream: "LANDMARKS:landmarks"
input_stream: "NORM_RECT:rect"
input_stream: "MULTI_LANDMARKS:multi_landmarks"
input_stream: "NORM_RECTS:multi_rect"
output_stream: "IMAGE:output_image"
# Outputs each element of multi_landmarks at a fake timestamp for the rest
# of the graph to process. At the end of the loop, outputs the BATCH_END
# timestamp for downstream calculators to inform them that all elements in the
# vector have been processed.
node {
calculator: "BeginLoopNormalizedLandmarkListVectorCalculator"
input_stream: "ITERABLE:multi_landmarks"
output_stream: "ITEM:single_landmarks"
output_stream: "BATCH_END:landmark_timestamp"
}
# Converts landmarks to drawing primitives for annotation overlay.
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:landmarks"
output_stream: "RENDER_DATA:landmark_render_data"
input_stream: "NORM_LANDMARKS:single_landmarks"
output_stream: "RENDER_DATA:single_landmark_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_connections: [1, 2] # edge 1-2
@ -33,11 +44,18 @@ node {
}
}
node {
calculator: "EndLoopRenderDataCalculator"
input_stream: "ITEM:single_landmark_render_data"
input_stream: "BATCH_END:landmark_timestamp"
output_stream: "ITERABLE:multi_landmarks_render_data"
}
# Converts normalized rects to drawing primitives for annotation overlay.
node {
calculator: "RectToRenderDataCalculator"
input_stream: "NORM_RECT:rect"
output_stream: "RENDER_DATA:rect_render_data"
input_stream: "NORM_RECTS:multi_rect"
output_stream: "RENDER_DATA:multi_rect_render_data"
node_options: {
[type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] {
filled: false
@ -51,7 +69,7 @@ node {
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "IMAGE:input_image"
input_stream: "landmark_render_data"
input_stream: "rect_render_data"
input_stream: "VECTOR:multi_landmarks_render_data"
input_stream: "multi_rect_render_data"
output_stream: "IMAGE:output_image"
}

View File

@ -430,7 +430,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
framesInUse--;
int keep = max(framesToKeep - framesInUse, 0);
while (framesAvailable.size() > keep) {
teardownFrame(framesAvailable.remove());
PoolTextureFrame textureFrameToRemove = framesAvailable.remove();
handler.post(() -> teardownFrame(textureFrameToRemove));
}
}

View File

@ -43,6 +43,7 @@ node {
options: {
[mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/holistic_landmark/hand_recrop.tflite"
delegate { xnnpack {} }
}
}
}

View File

@ -56,25 +56,19 @@ mediapipe_simple_subgraph(
graph = "box_landmark_gpu.pbtxt",
register_as = "BoxLandmarkSubgraph",
deps = [
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_floats_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
"//mediapipe/calculators/util:landmark_projection_calculator",
"//mediapipe/calculators/util:landmarks_smoothing_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:thresholding_calculator",
"//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator",
"//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator",
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator",
],
)
@ -83,24 +77,19 @@ mediapipe_simple_subgraph(
graph = "box_landmark_cpu.pbtxt",
register_as = "BoxLandmarkSubgraph",
deps = [
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_floats_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
"//mediapipe/calculators/util:landmark_projection_calculator",
"//mediapipe/calculators/util:landmarks_smoothing_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:thresholding_calculator",
"//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator",
"//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator",
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator",
],
)
@ -115,9 +104,7 @@ mediapipe_simple_subgraph(
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:non_max_suppression_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/modules/objectron/calculators:filter_detection_calculator",
],
)
@ -133,9 +120,7 @@ mediapipe_simple_subgraph(
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:non_max_suppression_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/modules/objectron/calculators:filter_detection_calculator",
],
)
@ -147,9 +132,19 @@ mediapipe_simple_subgraph(
deps = [
":box_landmark_cpu",
":object_detection_oid_v4_cpu",
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:clip_vector_size_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:merge_calculator",
"//mediapipe/calculators/core:previous_loopback_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/util:association_norm_rect_calculator",
"//mediapipe/calculators/util:collection_has_min_size_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator",
"//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator",
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator",
],
)
@ -160,9 +155,18 @@ mediapipe_simple_subgraph(
deps = [
":box_landmark_gpu",
":object_detection_oid_v4_gpu",
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:clip_vector_size_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:merge_calculator",
"//mediapipe/calculators/core:previous_loopback_calculator",
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/util:association_norm_rect_calculator",
"//mediapipe/calculators/util:collection_has_min_size_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator",
"//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator",
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator",
],
)

View File

@ -2,82 +2,75 @@
type: "BoxLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "IMAGE:image"
input_stream: "NORM_RECT:box_rect"
input_side_packet: "MODEL:model"
output_stream: "LANDMARKS:box_landmarks_filtered"
output_stream: "NORM_RECT:box_rect_for_next_frame"
output_stream: "PRESENCE:box_presence"
output_stream: "NORM_LANDMARKS:box_landmarks"
# Crops the rectangle that contains a box from the input image.
# Extracts image size from the input images.
node {
calculator: "ImageCroppingCalculator"
input_stream: "IMAGE:input_video"
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE:image"
output_stream: "SIZE:image_size"
}
# Expands the rectangle that contain the box so that it's likely to cover the
# entire box.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECT:box_rect"
output_stream: "IMAGE:box_image"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "box_rect_scaled"
options: {
[mediapipe.ImageCroppingCalculatorOptions.ext] {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 1.5
scale_y: 1.5
square_long: true
}
}
}
# Crops, resizes, and converts the input video into tensor.
# Preserves aspect ratio of the images.
node {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:image"
input_stream: "NORM_RECT:box_rect_scaled"
output_stream: "TENSORS:image_tensor"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
options {
[mediapipe.ImageToTensorCalculatorOptions.ext] {
output_tensor_width: 224
output_tensor_height: 224
keep_aspect_ratio: true
output_tensor_float_range {
min: 0.0
max: 1.0
}
gpu_origin: TOP_LEFT
border_mode: BORDER_REPLICATE
}
}
}
# Transforms the input image to a 224x224 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:box_image"
output_stream: "IMAGE:transformed_box_image"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
options: {
[mediapipe.ImageTransformationCalculatorOptions.ext] {
output_width: 224
output_height: 224
scale_mode: FIT
}
}
}
# Converts the transformed input image into an image tensor stored as a
# TfLiteTensor.
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE:transformed_box_image"
output_stream: "TENSORS:image_tensor"
options: {
[mediapipe.TfLiteConverterCalculatorOptions.ext] {
zero_center: false
}
}
}
# Generates a single side packet containing a TensorFlow Lite op resolver that
# supports custom ops needed by the model used in this graph.
node {
calculator: "TfLiteCustomOpResolverCalculator"
output_side_packet: "opresolver"
}
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "TfLiteInferenceCalculator"
calculator: "InferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:output_tensors"
input_side_packet: "CUSTOM_OP_RESOLVER:opresolver"
input_side_packet: "MODEL:model"
output_stream: "TENSORS:output_tensors"
options: {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
use_gpu: false
[mediapipe.InferenceCalculatorOptions.ext] {
delegate { xnnpack {} }
}
}
}
# Splits a vector of tensors into multiple vectors.
node {
calculator: "SplitTfLiteTensorVectorCalculator"
calculator: "SplitTensorVectorCalculator"
input_stream: "output_tensors"
output_stream: "landmark_tensors"
output_stream: "box_flag_tensor"
@ -92,7 +85,7 @@ node {
# Converts the box-flag tensor into a float that represents the confidence
# score of box presence.
node {
calculator: "TfLiteTensorsToFloatsCalculator"
calculator: "TensorsToFloatsCalculator"
input_stream: "TENSORS:box_flag_tensor"
output_stream: "FLOAT:box_presence_score"
}
@ -105,19 +98,27 @@ node {
output_stream: "FLAG:box_presence"
options: {
[mediapipe.ThresholdingCalculatorOptions.ext] {
threshold: 0.1
threshold: 0.99
}
}
}
# Drops landmarks tensors if box is not present.
node {
calculator: "GateCalculator"
input_stream: "landmark_tensors"
input_stream: "ALLOW:box_presence"
output_stream: "gated_landmark_tensors"
}
# Decodes the landmark tensors into a list of landmarks, where the landmark
# coordinates are normalized by the size of the input image to the model.
node {
calculator: "TfLiteTensorsToLandmarksCalculator"
input_stream: "TENSORS:landmark_tensors"
calculator: "TensorsToLandmarksCalculator"
input_stream: "TENSORS:gated_landmark_tensors"
output_stream: "NORM_LANDMARKS:landmarks"
options: {
[mediapipe.TfLiteTensorsToLandmarksCalculatorOptions.ext] {
[mediapipe.TensorsToLandmarksCalculatorOptions.ext] {
num_landmarks: 9
input_image_width: 224
input_image_height: 224
@ -141,66 +142,6 @@ node {
node {
calculator: "LandmarkProjectionCalculator"
input_stream: "NORM_LANDMARKS:scaled_landmarks"
input_stream: "NORM_RECT:box_rect"
input_stream: "NORM_RECT:box_rect_scaled"
output_stream: "NORM_LANDMARKS:box_landmarks"
}
# Extracts image size from the input images.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE:input_video"
output_stream: "SIZE:image_size"
}
# Smooth predicted landmarks coordinates.
node {
calculator: "LandmarksSmoothingCalculator"
input_stream: "NORM_LANDMARKS:box_landmarks"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_FILTERED_LANDMARKS:box_landmarks_filtered"
options: {
[mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
velocity_filter: {
window_size: 10
velocity_scale: 7.5
}
}
}
}
# Convert box landmarks to frame annotation.
node {
calculator: "LandmarksToFrameAnnotationCalculator"
input_stream: "LANDMARKS:box_landmarks_filtered"
output_stream: "FRAME_ANNOTATION:box_annotation"
}
# Lift the 2D landmarks to 3D using EPnP algorithm.
node {
calculator: "Lift2DFrameAnnotationTo3DCalculator"
input_stream: "FRAME_ANNOTATION:box_annotation"
output_stream: "LIFTED_FRAME_ANNOTATION:lifted_box"
}
# Get rotated rectangle from lifted box.
node {
calculator: "FrameAnnotationToRectCalculator"
input_stream: "FRAME_ANNOTATION:lifted_box"
output_stream: "NORM_RECT:rect_from_box"
}
# Expands the box rectangle so that in the next video frame it's likely to
# still contain the box even with some motion.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECT:rect_from_box"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "box_rect_for_next_frame"
options: {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 1.5
scale_y: 1.5
square_long: true
}
}
}

View File

@ -2,81 +2,75 @@
type: "BoxLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "IMAGE:image"
input_stream: "NORM_RECT:box_rect"
output_stream: "FRAME_ANNOTATION:lifted_box"
output_stream: "NORM_RECT:box_rect_for_next_frame"
output_stream: "PRESENCE:box_presence"
output_stream: "NORM_LANDMARKS:box_landmarks"
# Crops the rectangle that contains a box from the input image.
# Extracts image size from the input images.
node {
calculator: "ImageCroppingCalculator"
input_stream: "IMAGE_GPU:input_video"
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE_GPU:image"
output_stream: "SIZE:image_size"
}
# Expands the rectangle that contain the box so that it's likely to cover the
# entire box.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECT:box_rect"
output_stream: "IMAGE_GPU:box_image"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "box_rect_scaled"
options: {
[mediapipe.ImageCroppingCalculatorOptions.ext] {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 1.5
scale_y: 1.5
square_long: true
}
}
}
# Crops, resizes, and converts the input video into tensor.
# Preserves aspect ratio of the images.
node {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE_GPU:image"
input_stream: "NORM_RECT:box_rect_scaled"
output_stream: "TENSORS:image_tensor"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
options {
[mediapipe.ImageToTensorCalculatorOptions.ext] {
output_tensor_width: 224
output_tensor_height: 224
keep_aspect_ratio: true
output_tensor_float_range {
min: 0.0
max: 1.0
}
gpu_origin: TOP_LEFT
border_mode: BORDER_REPLICATE
}
}
}
# Transforms the input image on GPU to a 224x224 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE_GPU:box_image"
output_stream: "IMAGE_GPU:transformed_box_image"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
options: {
[mediapipe.ImageTransformationCalculatorOptions.ext] {
output_width: 224
output_height: 224
scale_mode: FIT
}
}
}
# Converts the transformed input image on GPU into an image tensor stored as a
# TfLiteTensor.
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE_GPU:transformed_box_image"
output_stream: "TENSORS_GPU:image_tensor"
options: {
[mediapipe.TfLiteConverterCalculatorOptions.ext] {
zero_center: false
}
}
}
# Generates a single side packet containing a TensorFlow Lite op resolver that
# supports custom ops needed by the model used in this graph.
node {
calculator: "TfLiteCustomOpResolverCalculator"
output_side_packet: "opresolver"
}
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS_GPU:image_tensor"
calculator: "InferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:output_tensors"
input_side_packet: "CUSTOM_OP_RESOLVER:opresolver"
options: {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
[mediapipe.InferenceCalculatorOptions.ext] {
model_path: "object_detection_3d.tflite"
use_gpu: true
delegate { gpu {} }
}
}
}
# Splits a vector of tensors into multiple vectors.
# Splits a vector of tensors to multiple vectors according to the ranges
# specified in option.
node {
calculator: "SplitTfLiteTensorVectorCalculator"
calculator: "SplitTensorVectorCalculator"
input_stream: "output_tensors"
output_stream: "landmark_tensors"
output_stream: "box_flag_tensor"
@ -91,7 +85,7 @@ node {
# Converts the box-flag tensor into a float that represents the confidence
# score of box presence.
node {
calculator: "TfLiteTensorsToFloatsCalculator"
calculator: "TensorsToFloatsCalculator"
input_stream: "TENSORS:box_flag_tensor"
output_stream: "FLOAT:box_presence_score"
}
@ -109,14 +103,22 @@ node {
}
}
# Drops landmarks tensors if box is not present.
node {
calculator: "GateCalculator"
input_stream: "landmark_tensors"
input_stream: "ALLOW:box_presence"
output_stream: "gated_landmark_tensors"
}
# Decodes the landmark tensors into a list of landmarks, where the landmark
# coordinates are normalized by the size of the input image to the model.
node {
calculator: "TfLiteTensorsToLandmarksCalculator"
input_stream: "TENSORS:landmark_tensors"
calculator: "TensorsToLandmarksCalculator"
input_stream: "TENSORS:gated_landmark_tensors"
output_stream: "NORM_LANDMARKS:landmarks"
options: {
[mediapipe.TfLiteTensorsToLandmarksCalculatorOptions.ext] {
[mediapipe.TensorsToLandmarksCalculatorOptions.ext] {
num_landmarks: 9
input_image_width: 224
input_image_height: 224
@ -140,66 +142,6 @@ node {
node {
calculator: "LandmarkProjectionCalculator"
input_stream: "NORM_LANDMARKS:scaled_landmarks"
input_stream: "NORM_RECT:box_rect"
input_stream: "NORM_RECT:box_rect_scaled"
output_stream: "NORM_LANDMARKS:box_landmarks"
}
# Extracts image size from the input images.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE_GPU:input_video"
output_stream: "SIZE:image_size"
}
# Smooth predicted landmarks coordinates.
node {
calculator: "LandmarksSmoothingCalculator"
input_stream: "NORM_LANDMARKS:box_landmarks"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_FILTERED_LANDMARKS:box_landmarks_filtered"
options: {
[mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
velocity_filter: {
window_size: 10
velocity_scale: 7.5
}
}
}
}
# Convert box landmarks to frame annotation.
node {
calculator: "LandmarksToFrameAnnotationCalculator"
input_stream: "LANDMARKS:box_landmarks_filtered"
output_stream: "FRAME_ANNOTATION:box_annotation"
}
# Lift the 2D landmarks to 3D using EPnP algorithm.
node {
calculator: "Lift2DFrameAnnotationTo3DCalculator"
input_stream: "FRAME_ANNOTATION:box_annotation"
output_stream: "LIFTED_FRAME_ANNOTATION:lifted_box"
}
# Get rotated rectangle from lifted box.
node {
calculator: "FrameAnnotationToRectCalculator"
input_stream: "FRAME_ANNOTATION:lifted_box"
output_stream: "NORM_RECT:rect_from_box"
}
# Expands the box rectangle so that in the next video frame it's likely to
# still contain the box even with some motion.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECT:rect_from_box"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "box_rect_for_next_frame"
options: {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 1.5
scale_y: 1.5
square_long: true
}
}
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
#include <cmath>
#include <vector>
#include "Eigen/Dense"
#include "absl/memory/memory.h"
@ -32,7 +33,7 @@ using Eigen::Vector3f;
namespace {
constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION";
constexpr char kOutputNormRectTag[] = "NORM_RECT";
constexpr char kOutputNormRectsTag[] = "NORM_RECTS";
} // namespace
@ -47,14 +48,14 @@ class FrameAnnotationToRectCalculator : public CalculatorBase {
TOP_VIEW_OFF,
};
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
static mediapipe::Status GetContract(CalculatorContract* cc);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
private:
void AnnotationToRect(const FrameAnnotation& annotation,
NormalizedRect* rect);
float RotationAngleFromAnnotation(const FrameAnnotation& annotation);
void AddAnnotationToRect(const ObjectAnnotation& annotation,
std::vector<NormalizedRect>* rect);
float RotationAngleFromAnnotation(const ObjectAnnotation& annotation);
float RotationAngleFromPose(const Matrix3fRM& rotation,
const Vector3f& translation, const Vector3f& vec);
@ -64,17 +65,7 @@ class FrameAnnotationToRectCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(FrameAnnotationToRectCalculator);
::mediapipe::Status FrameAnnotationToRectCalculator::Open(
CalculatorContext* cc) {
status_ = TOP_VIEW_OFF;
const auto& options = cc->Options<FrameAnnotationToRectCalculatorOptions>();
off_threshold_ = options.off_threshold();
on_threshold_ = options.on_threshold();
RET_CHECK(off_threshold_ <= on_threshold_);
return ::mediapipe::OkStatus();
}
::mediapipe::Status FrameAnnotationToRectCalculator::GetContract(
mediapipe::Status FrameAnnotationToRectCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty());
@ -83,57 +74,69 @@ REGISTER_CALCULATOR(FrameAnnotationToRectCalculator);
cc->Inputs().Tag(kInputFrameAnnotationTag).Set<FrameAnnotation>();
}
if (cc->Outputs().HasTag(kOutputNormRectTag)) {
cc->Outputs().Tag(kOutputNormRectTag).Set<NormalizedRect>();
if (cc->Outputs().HasTag(kOutputNormRectsTag)) {
cc->Outputs().Tag(kOutputNormRectsTag).Set<std::vector<NormalizedRect>>();
}
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
::mediapipe::Status FrameAnnotationToRectCalculator::Process(
mediapipe::Status FrameAnnotationToRectCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
status_ = TOP_VIEW_OFF;
const auto& options = cc->Options<FrameAnnotationToRectCalculatorOptions>();
off_threshold_ = options.off_threshold();
on_threshold_ = options.on_threshold();
RET_CHECK(off_threshold_ <= on_threshold_);
return mediapipe::OkStatus();
}
mediapipe::Status FrameAnnotationToRectCalculator::Process(
CalculatorContext* cc) {
if (cc->Inputs().Tag(kInputFrameAnnotationTag).IsEmpty()) {
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
auto output_rects = absl::make_unique<std::vector<NormalizedRect>>();
const auto& frame_annotation =
cc->Inputs().Tag(kInputFrameAnnotationTag).Get<FrameAnnotation>();
for (const auto& object_annotation : frame_annotation.annotations()) {
AddAnnotationToRect(object_annotation, output_rects.get());
}
auto output_rect = absl::make_unique<NormalizedRect>();
AnnotationToRect(
cc->Inputs().Tag(kInputFrameAnnotationTag).Get<FrameAnnotation>(),
output_rect.get());
// Output
// Output.
cc->Outputs()
.Tag(kOutputNormRectTag)
.Add(output_rect.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
.Tag(kOutputNormRectsTag)
.Add(output_rects.release(), cc->InputTimestamp());
return mediapipe::OkStatus();
}
void FrameAnnotationToRectCalculator::AnnotationToRect(
const FrameAnnotation& annotation, NormalizedRect* rect) {
void FrameAnnotationToRectCalculator::AddAnnotationToRect(
const ObjectAnnotation& annotation, std::vector<NormalizedRect>* rects) {
float x_min = std::numeric_limits<float>::max();
float x_max = std::numeric_limits<float>::min();
float y_min = std::numeric_limits<float>::max();
float y_max = std::numeric_limits<float>::min();
const auto& object = annotation.annotations(0);
for (const auto& keypoint : object.keypoints()) {
for (const auto& keypoint : annotation.keypoints()) {
const auto& point_2d = keypoint.point_2d();
x_min = std::min(x_min, point_2d.x());
x_max = std::max(x_max, point_2d.x());
y_min = std::min(y_min, point_2d.y());
y_max = std::max(y_max, point_2d.y());
}
rect->set_x_center((x_min + x_max) / 2);
rect->set_y_center((y_min + y_max) / 2);
rect->set_width(x_max - x_min);
rect->set_height(y_max - y_min);
rect->set_rotation(RotationAngleFromAnnotation(annotation));
NormalizedRect new_rect;
new_rect.set_x_center((x_min + x_max) / 2);
new_rect.set_y_center((y_min + y_max) / 2);
new_rect.set_width(x_max - x_min);
new_rect.set_height(y_max - y_min);
new_rect.set_rotation(RotationAngleFromAnnotation(annotation));
rects->push_back(new_rect);
}
float FrameAnnotationToRectCalculator::RotationAngleFromAnnotation(
const FrameAnnotation& annotation) {
const auto& object = annotation.annotations(0);
const ObjectAnnotation& annotation) {
Box box("category");
std::vector<Vector3f> vertices_3d;
std::vector<Vector2f> vertices_2d;
for (const auto& keypoint : object.keypoints()) {
for (const auto& keypoint : annotation.keypoints()) {
const auto& point_3d = keypoint.point_3d();
const auto& point_2d = keypoint.point_2d();
vertices_3d.emplace_back(

View File

@ -23,6 +23,7 @@ namespace mediapipe {
namespace {
constexpr char kInputLandmarksTag[] = "LANDMARKS";
constexpr char kInputMultiLandmarksTag[] = "MULTI_LANDMARKS";
constexpr char kOutputFrameAnnotationTag[] = "FRAME_ANNOTATION";
} // namespace
@ -30,12 +31,17 @@ constexpr char kOutputFrameAnnotationTag[] = "FRAME_ANNOTATION";
// A calculator that converts NormalizedLandmarkList to FrameAnnotation proto.
class LandmarksToFrameAnnotationCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Process(CalculatorContext* cc) override;
static mediapipe::Status GetContract(CalculatorContract* cc);
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
private:
void AddLandmarksToFrameAnnotation(const NormalizedLandmarkList& landmarks,
FrameAnnotation* frame_annotation);
};
REGISTER_CALCULATOR(LandmarksToFrameAnnotationCalculator);
::mediapipe::Status LandmarksToFrameAnnotationCalculator::GetContract(
mediapipe::Status LandmarksToFrameAnnotationCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty());
@ -43,34 +49,65 @@ REGISTER_CALCULATOR(LandmarksToFrameAnnotationCalculator);
if (cc->Inputs().HasTag(kInputLandmarksTag)) {
cc->Inputs().Tag(kInputLandmarksTag).Set<NormalizedLandmarkList>();
}
if (cc->Inputs().HasTag(kInputMultiLandmarksTag)) {
cc->Inputs()
.Tag(kInputMultiLandmarksTag)
.Set<std::vector<NormalizedLandmarkList>>();
}
if (cc->Outputs().HasTag(kOutputFrameAnnotationTag)) {
cc->Outputs().Tag(kOutputFrameAnnotationTag).Set<FrameAnnotation>();
}
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
::mediapipe::Status LandmarksToFrameAnnotationCalculator::Process(
mediapipe::Status LandmarksToFrameAnnotationCalculator::Open(
CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
return mediapipe::OkStatus();
}
mediapipe::Status LandmarksToFrameAnnotationCalculator::Process(
CalculatorContext* cc) {
auto frame_annotation = absl::make_unique<FrameAnnotation>();
auto* box_annotation = frame_annotation->add_annotations();
// Handle the case when input has only one NormalizedLandmarkList.
if (cc->Inputs().HasTag(kInputLandmarksTag) &&
!cc->Inputs().Tag(kInputLandmarksTag).IsEmpty()) {
const auto& landmarks =
cc->Inputs().Tag(kInputLandmarksTag).Get<NormalizedLandmarkList>();
RET_CHECK_GT(landmarks.landmark_size(), 0)
<< "Input landmark vector is empty.";
for (int i = 0; i < landmarks.landmark_size(); ++i) {
auto* point2d = box_annotation->add_keypoints()->mutable_point_2d();
point2d->set_x(landmarks.landmark(i).x());
point2d->set_y(landmarks.landmark(i).y());
cc->Inputs().Tag(kInputMultiLandmarksTag).Get<NormalizedLandmarkList>();
AddLandmarksToFrameAnnotation(landmarks, frame_annotation.get());
}
// Handle the case when input has muliple NormalizedLandmarkList.
if (cc->Inputs().HasTag(kInputMultiLandmarksTag) &&
!cc->Inputs().Tag(kInputMultiLandmarksTag).IsEmpty()) {
const auto& landmarks_list =
cc->Inputs()
.Tag(kInputMultiLandmarksTag)
.Get<std::vector<NormalizedLandmarkList>>();
for (const auto& landmarks : landmarks_list) {
AddLandmarksToFrameAnnotation(landmarks, frame_annotation.get());
}
}
// Output
if (cc->Outputs().HasTag(kOutputFrameAnnotationTag)) {
cc->Outputs()
.Tag(kOutputFrameAnnotationTag)
.Add(frame_annotation.release(), cc->InputTimestamp());
}
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
void LandmarksToFrameAnnotationCalculator::AddLandmarksToFrameAnnotation(
const NormalizedLandmarkList& landmarks,
FrameAnnotation* frame_annotation) {
auto* new_annotation = frame_annotation->add_annotations();
for (const auto& landmark : landmarks.landmark()) {
auto* point2d = new_annotation->add_keypoints()->mutable_point_2d();
point2d->set_x(landmark.x());
point2d->set_y(landmark.y());
}
}
} // namespace mediapipe

View File

@ -55,16 +55,16 @@ namespace mediapipe {
// }
class Lift2DFrameAnnotationTo3DCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
static mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
mediapipe::Status Open(CalculatorContext* cc) override;
mediapipe::Status Process(CalculatorContext* cc) override;
mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::Status ProcessCPU(CalculatorContext* cc,
mediapipe::Status ProcessCPU(CalculatorContext* cc,
FrameAnnotation* output_objects);
::mediapipe::Status LoadOptions(CalculatorContext* cc);
mediapipe::Status LoadOptions(CalculatorContext* cc);
// Increment and assign object ID for each detected object.
// In a single MediaPipe session, the IDs are unique.
@ -73,23 +73,24 @@ class Lift2DFrameAnnotationTo3DCalculator : public CalculatorBase {
void AssignObjectIdAndTimestamp(int64 timestamp_us,
FrameAnnotation* annotation);
std::unique_ptr<Decoder> decoder_;
::mediapipe::Lift2DFrameAnnotationTo3DCalculatorOptions options_;
Lift2DFrameAnnotationTo3DCalculatorOptions options_;
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> projection_matrix_;
};
REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator);
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::GetContract(
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kInputStreamTag));
RET_CHECK(cc->Outputs().HasTag(kOutputStreamTag));
cc->Inputs().Tag(kInputStreamTag).Set<FrameAnnotation>();
cc->Outputs().Tag(kOutputStreamTag).Set<FrameAnnotation>();
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Open(
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Open(
CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
MP_RETURN_IF_ERROR(LoadOptions(cc));
// clang-format off
projection_matrix_ <<
@ -101,13 +102,13 @@ REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator);
decoder_ = absl::make_unique<Decoder>(
BeliefDecoderConfig(options_.decoder_config()));
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Process(
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Process(
CalculatorContext* cc) {
if (cc->Inputs().Tag(kInputStreamTag).IsEmpty()) {
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
auto output_objects = absl::make_unique<FrameAnnotation>();
@ -121,10 +122,10 @@ REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator);
.Add(output_objects.release(), cc->InputTimestamp());
}
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::ProcessCPU(
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::ProcessCPU(
CalculatorContext* cc, FrameAnnotation* output_objects) {
const auto& input_frame_annotations =
cc->Inputs().Tag(kInputStreamTag).Get<FrameAnnotation>();
@ -140,21 +141,20 @@ REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator);
AssignObjectIdAndTimestamp(cc->InputTimestamp().Microseconds(),
output_objects);
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Close(
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Close(
CalculatorContext* cc) {
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
::mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::LoadOptions(
mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::LoadOptions(
CalculatorContext* cc) {
// Get calculator options specified in the graph.
options_ =
cc->Options<::mediapipe::Lift2DFrameAnnotationTo3DCalculatorOptions>();
options_ = cc->Options<Lift2DFrameAnnotationTo3DCalculatorOptions>();
return ::mediapipe::OkStatus();
return mediapipe::OkStatus();
}
void Lift2DFrameAnnotationTo3DCalculator::AssignObjectIdAndTimestamp(

View File

@ -4,9 +4,9 @@ type: "ObjectDetectionOidV4Subgraph"
input_stream: "IMAGE:input_video"
input_side_packet: "LABELS_CSV:allowed_labels"
output_stream: "NORM_RECT:box_rect_from_object_detections"
output_stream: "DETECTIONS:detections"
# Transforms the input image on GPU to a 300x300 image. To scale the image, by
# Transforms the input image on CPU to a 300x300 image. To scale the image, by
# default it uses the STRETCH scale mode that maps the entire input image to the
# entire transformed image. As a result, image aspect ratio may be changed and
# objects in the image may be deformed (stretched or squeezed), but the object
@ -23,7 +23,7 @@ node: {
}
}
# Converts the transformed input image on GPU into an image tensor stored as a
# Converts the transformed input image on CPU into an image tensor stored as a
# TfLiteTensor.
node {
calculator: "TfLiteConverterCalculator"
@ -31,7 +31,7 @@ node {
output_stream: "TENSORS:image_tensor"
}
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
@ -82,7 +82,7 @@ node {
calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS:detection_tensors"
input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections"
output_stream: "DETECTIONS:all_detections"
options: {
[mediapipe.TfLiteTensorsToDetectionsCalculatorOptions.ext] {
num_classes: 195
@ -95,22 +95,7 @@ node {
y_scale: 10.0
h_scale: 5.0
w_scale: 5.0
min_score_thresh: 0.6
}
}
}
# Performs non-max suppression to remove excessive detections.
node {
calculator: "NonMaxSuppressionCalculator"
input_stream: "detections"
output_stream: "suppressed_detections"
options: {
[mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
min_suppression_threshold: 0.4
max_num_detections: 1
overlap_type: INTERSECTION_OVER_UNION
return_empty_detections: true
min_score_thresh: 0.5
}
}
}
@ -119,7 +104,7 @@ node {
# provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "suppressed_detections"
input_stream: "all_detections"
output_stream: "labeled_detections"
options: {
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
@ -128,50 +113,26 @@ node {
}
}
# Filters the detections to only those with valid scores
# for the specified allowed labels.
node {
calculator: "FilterDetectionCalculator"
input_stream: "DETECTIONS:labeled_detections"
output_stream: "DETECTIONS:filtered_detections"
input_side_packet: "LABELS_CSV:allowed_labels"
options: {
[mediapipe.FilterDetectionCalculatorOptions.ext]: {
min_score: 0.4
}
}
}
# Extracts image size from the input images.
# Performs non-max suppression to remove excessive detections.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE:input_video"
output_stream: "SIZE:image_size"
}
# Converts results of box detection into a rectangle (normalized by image size)
# that encloses the box.
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTIONS:filtered_detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_RECT:box_rect"
calculator: "NonMaxSuppressionCalculator"
input_stream: "filtered_detections"
output_stream: "detections"
options: {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
output_zero_rect_for_empty_detections: true
}
}
}
# Expands the rectangle that contains the box so that it's likely to cover the
# entire box.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECT:box_rect"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "box_rect_from_object_detections"
options: {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 1.5
scale_y: 1.5
[mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
min_suppression_threshold: 0.5
max_num_detections: 100
overlap_type: INTERSECTION_OVER_UNION
return_empty_detections: true
}
}
}

View File

@ -4,7 +4,7 @@ type: "ObjectDetectionOidV4Subgraph"
input_stream: "IMAGE_GPU:input_video"
input_side_packet: "LABELS_CSV:allowed_labels"
output_stream: "NORM_RECT:box_rect_from_object_detections"
output_stream: "DETECTIONS:detections"
# Transforms the input image on GPU to a 300x300 image. To scale the image, by
# default it uses the STRETCH scale mode that maps the entire input image to the
@ -40,7 +40,7 @@ node {
output_stream: "TENSORS_GPU:detection_tensors"
options: {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
model_path: "mediapipe/models/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite"
model_path: "object_detection_ssd_mobilenetv2_oidv4_fp16.tflite"
}
}
}
@ -82,7 +82,7 @@ node {
calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS_GPU:detection_tensors"
input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections"
output_stream: "DETECTIONS:all_detections"
options: {
[mediapipe.TfLiteTensorsToDetectionsCalculatorOptions.ext] {
num_classes: 195
@ -95,22 +95,7 @@ node {
y_scale: 10.0
h_scale: 5.0
w_scale: 5.0
min_score_thresh: 0.6
}
}
}
# Performs non-max suppression to remove excessive detections.
node {
calculator: "NonMaxSuppressionCalculator"
input_stream: "detections"
output_stream: "suppressed_detections"
options: {
[mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
min_suppression_threshold: 0.4
max_num_detections: 1
overlap_type: INTERSECTION_OVER_UNION
return_empty_detections: true
min_score_thresh: 0.5
}
}
}
@ -119,59 +104,35 @@ node {
# provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "suppressed_detections"
input_stream: "all_detections"
output_stream: "labeled_detections"
options: {
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
label_map_path: "mediapipe/models/object_detection_oidv4_labelmap.pbtxt"
label_map_path: "object_detection_oidv4_labelmap.pbtxt"
}
}
}
# Filters the detections to only those with valid scores
# for the specified allowed labels.
node {
calculator: "FilterDetectionCalculator"
input_stream: "DETECTIONS:labeled_detections"
output_stream: "DETECTIONS:filtered_detections"
input_side_packet: "LABELS_CSV:allowed_labels"
options: {
[mediapipe.FilterDetectionCalculatorOptions.ext]: {
min_score: 0.4
}
}
}
# Extracts image size from the input images.
# Performs non-max suppression to remove excessive detections.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE_GPU:input_video"
output_stream: "SIZE:image_size"
}
# Converts results of box detection into a rectangle (normalized by image size)
# that encloses the box.
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTIONS:filtered_detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_RECT:box_rect"
calculator: "NonMaxSuppressionCalculator"
input_stream: "filtered_detections"
output_stream: "detections"
options: {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
output_zero_rect_for_empty_detections: true
}
}
}
# Expands the rectangle that contains the box so that it's likely to cover the
# entire box.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECT:box_rect"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "box_rect_from_object_detections"
options: {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 1.5
scale_y: 1.5
[mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
min_suppression_threshold: 0.5
max_num_detections: 100
overlap_type: INTERSECTION_OVER_UNION
return_empty_detections: true
}
}
}

View File

@ -4,6 +4,8 @@ input_stream: "IMAGE:input_video"
input_side_packet: "MODEL:box_landmark_model"
# Allowed category labels, e.g. Footwear, Coffee cup, Mug, Chair, Camera
input_side_packet: "LABELS_CSV:allowed_labels"
# Max number of objects to detect/track. (int)
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
# Bounding box landmarks topology definition.
# The numbers are indices in the box_landmarks list.
#
@ -22,36 +24,47 @@ input_side_packet: "LABELS_CSV:allowed_labels"
# \+ \+
# 2 + + + + + + + + 6
#
output_stream: "LANDMARKS:box_landmarks"
# Crop rectangle derived from bounding box landmarks.
output_stream: "NORM_RECT:box_rect"
output_stream: "MULTI_LANDMARKS:multi_box_landmarks"
# Crop rectangles derived from bounding box landmarks.
output_stream: "NORM_RECTS:multi_box_rects"
# Caches a box-presence decision fed back from boxLandmarkSubgraph, and upon
# the arrival of the next input image sends out the cached decision with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous box-presence decision. Note that upon the arrival
# of the very first input image, an empty packet is sent out to jump start the
# feedback loop.
# Defines whether landmarks from the previous video frame should be used to help
# predict landmarks on the current video frame.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:box_presence"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
name: "ConstantSidePacketCalculator"
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:use_prev_landmarks"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { bool_value: true }
}
}
output_stream: "PREV_LOOP:prev_box_presence"
}
# Drops the incoming image if boxLandmarkSubgraph was able to identify box
node {
calculator: "GateCalculator"
input_side_packet: "ALLOW:use_prev_landmarks"
input_stream: "prev_box_rects_from_landmarks"
output_stream: "gated_prev_box_rects_from_landmarks"
}
# Determines if an input vector of NormalizedRect has a size greater than or
# equal to the provided max_num_objects.
node {
calculator: "NormalizedRectVectorHasMinSizeCalculator"
input_stream: "ITERABLE:gated_prev_box_rects_from_landmarks"
input_side_packet: "max_num_objects"
output_stream: "prev_has_enough_objects"
}
# Drops the incoming image if BoxLandmarkSubgraph was able to identify box
# presence in the previous image. Otherwise, passes the incoming image through
# to trigger a new round of box detection in boxDetectionSubgraph.
# to trigger a new round of box detection in ObjectDetectionOidV4Subgraph.
node {
calculator: "GateCalculator"
input_stream: "input_video"
input_stream: "DISALLOW:prev_box_presence"
output_stream: "box_detection_input_video"
input_stream: "DISALLOW:prev_has_enough_objects"
output_stream: "detection_input_video"
options: {
[mediapipe.GateCalculatorOptions.ext] {
@ -60,23 +73,112 @@ node {
}
}
# Subgraph that detections boxs (see object_detection_oid_v4_cpu.pbtxt).
# Subgraph that performs 2D object detection.
node {
calculator: "ObjectDetectionOidV4Subgraph"
input_stream: "IMAGE:box_detection_input_video"
input_stream: "IMAGE:detection_input_video"
input_side_packet: "LABELS_CSV:allowed_labels"
output_stream: "NORM_RECT:box_rect_from_object_detections"
output_stream: "DETECTIONS:raw_detections"
}
# Subgraph that localizes box landmarks (see box_landmark_gpu.pbtxt).
# Makes sure there are no more detections than provided max_num_objects.
node {
calculator: "ClipDetectionVectorSizeCalculator"
input_stream: "raw_detections"
output_stream: "detections"
input_side_packet: "max_num_objects"
}
# Extracts image size from the input images.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE:input_video"
output_stream: "SIZE:image_size"
}
# Converts results of box detection into rectangles (normalized by image size)
# that encloses the box.
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTIONS:detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_RECTS:box_rects_from_detections"
options: {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
output_zero_rect_for_empty_detections: false
}
}
}
# Performs association between NormalizedRect vector elements from previous
# image and rects based on object detections from the current image. This
# calculator ensures that the output box_rects vector doesn't contain
# overlapping regions based on the specified min_similarity_threshold.
node {
calculator: "AssociationNormRectCalculator"
input_stream: "box_rects_from_detections"
input_stream: "gated_prev_box_rects_from_landmarks"
output_stream: "multi_box_rects"
options: {
[mediapipe.AssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.2
}
}
}
# Outputs each element of box_rects at a fake timestamp for the rest of the
# graph to process. Clones image and image size packets for each
# single_box_rect at the fake timestamp. At the end of the loop, outputs the
# BATCH_END timestamp for downstream calculators to inform them that all
# elements in the vector have been processed.
node {
calculator: "BeginLoopNormalizedRectCalculator"
input_stream: "ITERABLE:multi_box_rects"
input_stream: "CLONE:input_video"
output_stream: "ITEM:single_box_rect"
output_stream: "CLONE:landmarks_input_video"
output_stream: "BATCH_END:box_rects_timestamp"
}
# Subgraph that localizes box landmarks.
node {
calculator: "BoxLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "NORM_RECT:box_rect"
input_stream: "IMAGE:landmarks_input_video"
input_side_packet: "MODEL:box_landmark_model"
output_stream: "LANDMARKS:box_landmarks"
output_stream: "NORM_RECT:box_rect_from_landmarks"
output_stream: "PRESENCE:box_presence"
input_stream: "NORM_RECT:single_box_rect"
output_stream: "NORM_LANDMARKS:single_box_landmarks"
}
# Collects a set of landmarks for each hand into a vector. Upon receiving the
# BATCH_END timestamp, outputs the vector of landmarks at the BATCH_END
# timestamp.
node {
calculator: "EndLoopNormalizedLandmarkListVectorCalculator"
input_stream: "ITEM:single_box_landmarks"
input_stream: "BATCH_END:box_rects_timestamp"
output_stream: "ITERABLE:multi_box_landmarks"
}
# Convert box landmarks to frame annotations.
node {
calculator: "LandmarksToFrameAnnotationCalculator"
input_stream: "MULTI_LANDMARKS:multi_box_landmarks"
output_stream: "FRAME_ANNOTATION:box_annotations"
}
# Lift the 2D landmarks to 3D using EPnP algorithm.
node {
calculator: "Lift2DFrameAnnotationTo3DCalculator"
input_stream: "FRAME_ANNOTATION:box_annotations"
output_stream: "LIFTED_FRAME_ANNOTATION:lifted_objects"
}
# Get rotated rectangle from lifted box.
node {
calculator: "FrameAnnotationToRectCalculator"
input_stream: "FRAME_ANNOTATION:lifted_objects"
output_stream: "NORM_RECTS:box_rects_from_landmarks"
}
# Caches a box rectangle fed back from boxLandmarkSubgraph, and upon the
@ -88,25 +190,10 @@ node {
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:box_rect_from_landmarks"
input_stream: "LOOP:box_rects_from_landmarks"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_box_rect_from_landmarks"
}
# Merges a stream of box rectangles generated by ObjectDetectionSubgraph and that
# generated by BoxLandmarkSubgraph into a single output stream by selecting
# between one of the two streams. The former is selected if the incoming packet
# is not empty, i.e., box detection is performed on the current image by
# BoxDetectionSubgraph (because BoxLandmarkSubgraph could not identify box
# presence in the previous image). Otherwise, the latter is selected, which is
# never empty because BoxLandmarkSubgraphs processes all images (that went
# through FlowLimiterCaculator).
node {
calculator: "MergeCalculator"
input_stream: "box_rect_from_object_detections"
input_stream: "prev_box_rect_from_landmarks"
output_stream: "box_rect"
output_stream: "PREV_LOOP:prev_box_rects_from_landmarks"
}

View File

@ -5,33 +5,48 @@
input_stream: "IMAGE_GPU:input_video"
# Allowed category labels, e.g. Footwear, Coffee cup, Mug, Chair, Camera
input_side_packet: "LABELS_CSV:allowed_labels"
# Max number of objects to detect/track. (int)
input_side_packet: "MAX_NUM_OBJECTS:max_num_objects"
# Collection of detected 3D objects, represented as a FrameAnnotation.
output_stream: "FRAME_ANNOTATION:lifted_objects"
# Caches a box-presence decision fed back from boxLandmarkSubgraph, and upon
# the arrival of the next input image sends out the cached decision with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous box-presence decision. Note that upon the arrival
# of the very first input image, an empty packet is sent out to jump start the
# feedback loop.
# Defines whether landmarks from the previous video frame should be used to help
# predict landmarks on the current video frame.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:box_presence"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
name: "ConstantSidePacketCalculator"
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:use_prev_landmarks"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { bool_value: true }
}
}
output_stream: "PREV_LOOP:prev_box_presence"
}
# Drops the incoming image if boxLandmarkSubgraph was able to identify box
node {
calculator: "GateCalculator"
input_side_packet: "ALLOW:use_prev_landmarks"
input_stream: "prev_box_rects_from_landmarks"
output_stream: "gated_prev_box_rects_from_landmarks"
}
# Determines if an input vector of NormalizedRect has a size greater than or
# equal to the provided max_num_objects.
node {
calculator: "NormalizedRectVectorHasMinSizeCalculator"
input_stream: "ITERABLE:gated_prev_box_rects_from_landmarks"
input_side_packet: "max_num_objects"
output_stream: "prev_has_enough_objects"
}
# Drops the incoming image if BoxLandmarkSubgraph was able to identify box
# presence in the previous image. Otherwise, passes the incoming image through
# to trigger a new round of box detection in boxDetectionSubgraph.
# to trigger a new round of box detection in ObjectDetectionOidV4Subgraph.
node {
calculator: "GateCalculator"
input_stream: "input_video"
input_stream: "DISALLOW:prev_box_presence"
input_stream: "DISALLOW:prev_has_enough_objects"
output_stream: "detection_input_video"
options: {
@ -46,17 +61,106 @@ node {
calculator: "ObjectDetectionOidV4Subgraph"
input_stream: "IMAGE_GPU:detection_input_video"
input_side_packet: "LABELS_CSV:allowed_labels"
output_stream: "NORM_RECT:box_rect_from_object_detections"
output_stream: "DETECTIONS:raw_detections"
}
# Makes sure there are no more detections than provided max_num_objects.
node {
calculator: "ClipDetectionVectorSizeCalculator"
input_stream: "raw_detections"
output_stream: "detections"
input_side_packet: "max_num_objects"
}
# Extracts image size from the input images.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE_GPU:input_video"
output_stream: "SIZE:image_size"
}
# Converts results of box detection into rectangles (normalized by image size)
# that encloses the box.
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTIONS:detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_RECTS:box_rects_from_detections"
options: {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
output_zero_rect_for_empty_detections: false
}
}
}
# Performs association between NormalizedRect vector elements from previous
# image and rects based on object detections from the current image. This
# calculator ensures that the output box_rects vector doesn't contain
# overlapping regions based on the specified min_similarity_threshold.
node {
calculator: "AssociationNormRectCalculator"
input_stream: "box_rects_from_detections"
input_stream: "gated_prev_box_rects_from_landmarks"
output_stream: "box_rects"
options: {
[mediapipe.AssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.2
}
}
}
# Outputs each element of box_rects at a fake timestamp for the rest of the
# graph to process. Clones image and image size packets for each
# single_box_rect at the fake timestamp. At the end of the loop, outputs the
# BATCH_END timestamp for downstream calculators to inform them that all
# elements in the vector have been processed.
node {
calculator: "BeginLoopNormalizedRectCalculator"
input_stream: "ITERABLE:box_rects"
input_stream: "CLONE:input_video"
output_stream: "ITEM:single_box_rect"
output_stream: "CLONE:landmarks_input_video"
output_stream: "BATCH_END:box_rects_timestamp"
}
# Subgraph that localizes box landmarks.
node {
calculator: "BoxLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "NORM_RECT:box_rect"
output_stream: "FRAME_ANNOTATION:lifted_objects"
output_stream: "NORM_RECT:box_rect_from_landmarks"
output_stream: "PRESENCE:box_presence"
input_stream: "IMAGE:landmarks_input_video"
input_stream: "NORM_RECT:single_box_rect"
output_stream: "NORM_LANDMARKS:single_box_landmarks"
}
# Collects a set of landmarks for each hand into a vector. Upon receiving the
# BATCH_END timestamp, outputs the vector of landmarks at the BATCH_END
# timestamp.
node {
calculator: "EndLoopNormalizedLandmarkListVectorCalculator"
input_stream: "ITEM:single_box_landmarks"
input_stream: "BATCH_END:box_rects_timestamp"
output_stream: "ITERABLE:multi_box_landmarks"
}
# Convert box landmarks to frame annotations.
node {
calculator: "LandmarksToFrameAnnotationCalculator"
input_stream: "MULTI_LANDMARKS:multi_box_landmarks"
output_stream: "FRAME_ANNOTATION:box_annotations"
}
# Lift the 2D landmarks to 3D using EPnP algorithm.
node {
calculator: "Lift2DFrameAnnotationTo3DCalculator"
input_stream: "FRAME_ANNOTATION:box_annotations"
output_stream: "LIFTED_FRAME_ANNOTATION:lifted_objects"
}
# Get rotated rectangle from lifted box.
node {
calculator: "FrameAnnotationToRectCalculator"
input_stream: "FRAME_ANNOTATION:lifted_objects"
output_stream: "NORM_RECTS:box_rects_from_landmarks"
}
# Caches a box rectangle fed back from boxLandmarkSubgraph, and upon the
@ -68,25 +172,10 @@ node {
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:box_rect_from_landmarks"
input_stream: "LOOP:box_rects_from_landmarks"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_box_rect_from_landmarks"
}
# Merges a stream of box rectangles generated by boxDetectionSubgraph and that
# generated by boxLandmarkSubgraph into a single output stream by selecting
# between one of the two streams. The former is selected if the incoming packet
# is not empty, i.e., box detection is performed on the current image by
# boxDetectionSubgraph (because boxLandmarkSubgraph could not identify box
# presence in the previous image). Otherwise, the latter is selected, which is
# never empty because boxLandmarkSubgraphs processes all images (that went
# through FlowLimiterCaculator).
node {
calculator: "MergeCalculator"
input_stream: "box_rect_from_object_detections"
input_stream: "prev_box_rect_from_landmarks"
output_stream: "box_rect"
output_stream: "PREV_LOOP:prev_box_rects_from_landmarks"
}

View File

@ -46,6 +46,11 @@ node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:model_path"
output_side_packet: "CONTENTS:model_blob"
options: {
[mediapipe.LocalFileContentsCalculatorOptions.ext]: {
read_as_binary: true
}
}
}
# Converts the input blob into a TF Lite model.

Some files were not shown because too many files have changed in this diff Show More