diff --git a/README.md b/README.md index 588ab69b7..444b2b1f6 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,8 @@ run code search using ## Publications +* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) + in Google AI Blog * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) in Google AI Blog * [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html) diff --git a/WORKSPACE b/WORKSPACE index b52e605cc..d88d8fc95 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -364,9 +364,9 @@ http_archive( ) #Tensorflow repo should always go after the other external dependencies. -# 2020-12-07 -_TENSORFLOW_GIT_COMMIT = "f556709f4df005ad57fd24d5eaa0d9380128d3ba" -_TENSORFLOW_SHA256= "9e157d4723921b48a974f645f70d07c8fd3c363569a0ac6ee85fec114d6459ea" +# 2020-12-09 +_TENSORFLOW_GIT_COMMIT = "0eadbb13cef1226b1bae17c941f7870734d97f8a" +_TENSORFLOW_SHA256= "4ae06daa5b09c62f31b7bc1f781fd59053f286dd64355830d8c2ac601b795ef0" http_archive( name = "org_tensorflow", urls = [ diff --git a/build_desktop_examples.sh b/build_desktop_examples.sh index a6b2b54f7..5e493e79c 100644 --- a/build_desktop_examples.sh +++ b/build_desktop_examples.sh @@ -94,7 +94,8 @@ for app in ${apps}; do else graph_name="${target_name}/${target_name}" fi - if [[ ${target_name} == "iris_tracking" || + if [[ ${target_name} == "holistic_tracking" || + ${target_name} == "iris_tracking" || ${target_name} == "pose_tracking" || ${target_name} == "upper_body_pose_tracking" ]]; then graph_suffix="cpu" diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index d8b3122c0..7a02def53 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -31,16 +31,16 @@ install --user six`. [Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html) to install Bazel 3.4 or higher. - For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu only, Bazel needs + For Nvidia Jetson and Raspberry Pi devices with aarch64 Linux, Bazel needs to be built from source: ```bash - # For Bazel 3.4.0 - mkdir $HOME/bazel-3.4.0 - cd $HOME/bazel-3.4.0 - wget https://github.com/bazelbuild/bazel/releases/download/3.4.0/bazel-3.4.0-dist.zip + # For Bazel 3.4.1 + mkdir $HOME/bazel-3.4.1 + cd $HOME/bazel-3.4.1 + wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-dist.zip sudo apt-get install build-essential openjdk-8-jdk python zip unzip - unzip bazel-3.4.0-dist.zip + unzip bazel-3.4.1-dist.zip env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh sudo cp output/bazel /usr/local/bin/ ``` @@ -338,14 +338,7 @@ build issues. 2. Install Bazel. - Option 1. Use package manager tool to install Bazel - - ```bash - $ brew install bazel - # Run 'bazel version' to check version of bazel - ``` - - Option 2. Follow the official + Follow the official [Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x) to install Bazel 3.4 or higher. @@ -604,14 +597,14 @@ cameras. Alternatively, you use a video file as input. 
```bash username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \ - https://storage.googleapis.com/bazel/3.4.0/release/bazel-3.4.0-installer-linux-x86_64.sh && \ - sudo mkdir -p /usr/local/bazel/3.4.0 && \ - chmod 755 bazel-3.4.0-installer-linux-x86_64.sh && \ - sudo ./bazel-3.4.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.0 && \ - source /usr/local/bazel/3.4.0/lib/bazel/bin/bazel-complete.bash + https://storage.googleapis.com/bazel/3.4.1/release/bazel-3.4.1-installer-linux-x86_64.sh && \ + sudo mkdir -p /usr/local/bazel/3.4.1 && \ + chmod 755 bazel-3.4.1-installer-linux-x86_64.sh && \ + sudo ./bazel-3.4.1-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.1 && \ + source /usr/local/bazel/3.4.1/lib/bazel/bin/bazel-complete.bash - username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.0/lib/bazel/bin/bazel version && \ - alias bazel='/usr/local/bazel/3.4.0/lib/bazel/bin/bazel' + username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.1/lib/bazel/bin/bazel version && \ + alias bazel='/usr/local/bazel/3.4.1/lib/bazel/bin/bazel' ``` 6. Checkout MediaPipe repository. diff --git a/docs/getting_started/python.md b/docs/getting_started/python.md index 4fd332630..c97f2a839 100644 --- a/docs/getting_started/python.md +++ b/docs/getting_started/python.md @@ -70,6 +70,11 @@ Python package from source. Otherwise, we strongly encourage our users to simply run `pip install mediapipe` to use the ready-to-use solutions, more convenient and much faster. +MediaPipe PyPI currently doesn't provide aarch64 Python wheel +files. For building and using MediaPipe Python on aarch64 Linux systems such as +Nvidia Jetson and Raspberry Pi, please read +[here](https://github.com/jiuqiant/mediapipe-python-aarch64). + 1. Make sure that Bazel and OpenCV are correctly installed and configured for MediaPipe. Please see [Installation](./install.md) for how to setup Bazel and OpenCV for MediaPipe on Linux and macOS. @@ -82,12 +87,18 @@ and much faster. $ sudo apt install python3-dev $ sudo apt install python3-venv $ sudo apt install -y protobuf-compiler + + # If you need to build opencv from source. + $ sudo apt install cmake ``` macOS: ```bash $ brew install protobuf + + # If you need to build opencv from source. + $ brew install cmake ``` Windows: @@ -118,3 +129,10 @@ and much faster. (mp_env)mediapipe$ python3 setup.py gen_protos (mp_env)mediapipe$ python3 setup.py install --link-opencv ``` + + or + + ```bash + (mp_env)mediapipe$ python3 setup.py gen_protos + (mp_env)mediapipe$ python3 setup.py bdist_wheel + ``` diff --git a/docs/index.md b/docs/index.md index 89e633e77..806e31c7f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -102,6 +102,8 @@ run code search using ## Publications +* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) + in Google AI Blog * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) in Google AI Blog * [MediaPipe 3D Face Transform](https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html) diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index e0d941e59..c2de9185a 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -405,7 +405,7 @@ on how to build MediaPipe examples. 
## Resources * Google AI Blog: - [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction on Device](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) + [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) * [Models and model cards](./models.md#holistic) [Colab]:https://mediapipe.page.link/holistic_py_colab diff --git a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen index 257bda2f5..862282f72 100644 --- a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen +++ b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen @@ -2,33 +2,33 @@ "additionalFilePaths" : [ "/BUILD", "mediapipe/BUILD", - "mediapipe/objc/BUILD", - "mediapipe/framework/BUILD", - "mediapipe/gpu/BUILD", - "mediapipe/objc/testing/app/BUILD", "mediapipe/examples/ios/common/BUILD", - "mediapipe/examples/ios/helloworld/BUILD", "mediapipe/examples/ios/facedetectioncpu/BUILD", "mediapipe/examples/ios/facedetectiongpu/BUILD", "mediapipe/examples/ios/faceeffect/BUILD", "mediapipe/examples/ios/facemeshgpu/BUILD", "mediapipe/examples/ios/handdetectiongpu/BUILD", "mediapipe/examples/ios/handtrackinggpu/BUILD", + "mediapipe/examples/ios/helloworld/BUILD", "mediapipe/examples/ios/holistictrackinggpu/BUILD", "mediapipe/examples/ios/iristrackinggpu/BUILD", "mediapipe/examples/ios/objectdetectioncpu/BUILD", "mediapipe/examples/ios/objectdetectiongpu/BUILD", "mediapipe/examples/ios/posetrackinggpu/BUILD", - "mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD" + "mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD", + "mediapipe/framework/BUILD", + "mediapipe/gpu/BUILD", + "mediapipe/objc/BUILD", + "mediapipe/objc/testing/app/BUILD" ], "buildTargets" : [ - "//mediapipe/examples/ios/helloworld:HelloWorldApp", "//mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp", "//mediapipe/examples/ios/facedetectiongpu:FaceDetectionGpuApp", "//mediapipe/examples/ios/faceeffect:FaceEffectApp", "//mediapipe/examples/ios/facemeshgpu:FaceMeshGpuApp", "//mediapipe/examples/ios/handdetectiongpu:HandDetectionGpuApp", "//mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp", + "//mediapipe/examples/ios/helloworld:HelloWorldApp", "//mediapipe/examples/ios/holistictrackinggpu:HolisticTrackingGpuApp", "//mediapipe/examples/ios/iristrackinggpu:IrisTrackingGpuApp", "//mediapipe/examples/ios/objectdetectioncpu:ObjectDetectionCpuApp", @@ -91,13 +91,13 @@ "mediapipe/examples/ios", "mediapipe/examples/ios/common", "mediapipe/examples/ios/common/Base.lproj", - "mediapipe/examples/ios/helloworld", "mediapipe/examples/ios/facedetectioncpu", "mediapipe/examples/ios/facedetectiongpu", "mediapipe/examples/ios/faceeffect", "mediapipe/examples/ios/faceeffect/Base.lproj", "mediapipe/examples/ios/handdetectiongpu", "mediapipe/examples/ios/handtrackinggpu", + "mediapipe/examples/ios/helloworld", "mediapipe/examples/ios/holistictrackinggpu", "mediapipe/examples/ios/iristrackinggpu", "mediapipe/examples/ios/objectdetectioncpu", diff --git a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf index c829f2706..0b0f6569c 100644 --- a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf +++ b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf @@ -9,7 +9,6 @@ "packages" : [ "", "mediapipe", - "mediapipe/objc", "mediapipe/examples/ios", "mediapipe/examples/ios/facedetectioncpu", "mediapipe/examples/ios/facedetectiongpu", @@ 
-22,7 +21,8 @@ "mediapipe/examples/ios/objectdetectioncpu", "mediapipe/examples/ios/objectdetectiongpu", "mediapipe/examples/ios/posetrackinggpu", - "mediapipe/examples/ios/upperbodyposetrackinggpu" + "mediapipe/examples/ios/upperbodyposetrackinggpu", + "mediapipe/objc" ], "projectName" : "Mediapipe", "workspaceRoot" : "../.." diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 2def194ea..af98fef3a 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -146,6 +146,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", ], @@ -286,6 +287,7 @@ cc_library( deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -393,6 +395,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:status", ], alwayslink = 1, @@ -406,7 +409,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", "@eigen_archive//:eigen", @@ -422,7 +425,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", "@eigen_archive//:eigen", @@ -438,6 +441,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/stream_handler:mux_input_stream_handler", ], @@ -606,6 +610,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", @@ -958,6 +963,7 @@ cc_library( deps = [ ":sequence_shift_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:status", ], alwayslink = 1, @@ -1007,6 +1013,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1042,6 +1049,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", ], diff --git a/mediapipe/calculators/core/add_header_calculator.cc b/mediapipe/calculators/core/add_header_calculator.cc index 918729ec0..dc0fa8aed 100644 --- a/mediapipe/calculators/core/add_header_calculator.cc +++ b/mediapipe/calculators/core/add_header_calculator.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
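The calculator hunks that follow all apply the same migration: from `CalculatorBase`, where a hand-written `GetContract()` wires inputs and outputs by string tag, to the `api2::Node` interface, where each port is a typed `constexpr` member and the wiring is declared once via `MEDIAPIPE_NODE_CONTRACT`. As a reading aid, here is a minimal sketch assembled from the constructs visible in these hunks (including the `AnyType` placeholder that also appears in the image-to-tensor change); the calculator name and stream tags are illustrative and are not part of this commit.

```c++
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {
namespace api2 {

// Illustrative only: forwards every input packet unchanged. SameType makes
// the output port inherit whatever type the input is bound to.
class ExamplePassThroughCalculator : public Node {
 public:
  static constexpr Input<AnyType> kIn{"IN"};
  static constexpr Output<SameType<kIn>> kOut{"OUT"};

  // Replaces the hand-written GetContract(); the default timestamp behavior
  // corresponds to the SetOffset(TimestampDiff(0)) calls these hunks delete.
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

  mediapipe::Status Process(CalculatorContext* cc) override {
    // kIn(cc) selects the port in the current context; packet() is the
    // type-erased packet, which Send() forwards at the input timestamp.
    kOut(cc).Send(kIn(cc).packet());
    return mediapipe::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(ExamplePassThroughCalculator);

}  // namespace api2
}  // namespace mediapipe
```

Calculators that need extra contract-time validation keep a static `UpdateContract(CalculatorContract*)` hook, as `AddHeaderCalculator` above does to insist that the header arrives via exactly one of the side input and the input stream.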
+#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { +namespace api2 { // Attach the header from a stream or side input to another stream. // @@ -42,49 +44,40 @@ namespace mediapipe { // output_stream: "audio_with_header" // } // -class AddHeaderCalculator : public CalculatorBase { +class AddHeaderCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - bool has_side_input = false; - bool has_header_stream = false; - if (cc->InputSidePackets().HasTag("HEADER")) { - cc->InputSidePackets().Tag("HEADER").SetAny(); - has_side_input = true; - } - if (cc->Inputs().HasTag("HEADER")) { - cc->Inputs().Tag("HEADER").SetNone(); - has_header_stream = true; - } - if (has_side_input == has_header_stream) { + static constexpr Input::Optional kHeader{"HEADER"}; + static constexpr SideInput::Optional kHeaderSide{"HEADER"}; + static constexpr Input kData{"DATA"}; + static constexpr Output> kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kHeader, kHeaderSide, kData, kOut); + + static mediapipe::Status UpdateContract(CalculatorContract* cc) { + if (kHeader(cc).IsConnected() == kHeaderSide(cc).IsConnected()) { return mediapipe::InvalidArgumentError( "Header must be provided via exactly one of side input and input " "stream"); } - cc->Inputs().Tag("DATA").SetAny(); - cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Tag("DATA")); return mediapipe::OkStatus(); } mediapipe::Status Open(CalculatorContext* cc) override { - Packet header; - if (cc->InputSidePackets().HasTag("HEADER")) { - header = cc->InputSidePackets().Tag("HEADER"); - } - if (cc->Inputs().HasTag("HEADER")) { - header = cc->Inputs().Tag("HEADER").Header(); - } + const PacketBase& header = + kHeader(cc).IsConnected() ? kHeader(cc).Header() : kHeaderSide(cc); if (!header.IsEmpty()) { - cc->Outputs().Index(0).SetHeader(header); + kOut(cc).SetHeader(header); } - cc->SetOffset(TimestampDiff(0)); return mediapipe::OkStatus(); } mediapipe::Status Process(CalculatorContext* cc) override { - cc->Outputs().Index(0).AddPacket(cc->Inputs().Tag("DATA").Value()); + kOut(cc).Send(kData(cc).packet()); return mediapipe::OkStatus(); } }; -REGISTER_CALCULATOR(AddHeaderCalculator); +MEDIAPIPE_REGISTER_NODE(AddHeaderCalculator); + +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc b/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc index e94ab5ae8..fb405533b 100644 --- a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc +++ b/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc @@ -16,6 +16,7 @@ #define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -23,27 +24,24 @@ #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // Concatenates several NormalizedLandmarkList protos following stream index // order. This class assumes that every input stream contains a // NormalizedLandmarkList proto object. 
-class ConcatenateNormalizedLandmarkListCalculator : public CalculatorBase { +class ConcatenateNormalizedLandmarkListCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().NumEntries() != 0); - RET_CHECK(cc->Outputs().NumEntries() == 1); + static constexpr Input::Multiple kIn{""}; + static constexpr Output kOut{""}; - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - cc->Inputs().Index(i).Set(); - } - - cc->Outputs().Index(0).Set(); + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + static mediapipe::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_GE(kIn(cc).Count(), 1); return mediapipe::OkStatus(); } mediapipe::Status Open(CalculatorContext* cc) override { - cc->SetOffset(TimestampDiff(0)); only_emit_if_all_present_ = cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>() .only_emit_if_all_present(); @@ -52,32 +50,29 @@ class ConcatenateNormalizedLandmarkListCalculator : public CalculatorBase { mediapipe::Status Process(CalculatorContext* cc) override { if (only_emit_if_all_present_) { - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - if (cc->Inputs().Index(i).IsEmpty()) return mediapipe::OkStatus(); + for (int i = 0; i < kIn(cc).Count(); ++i) { + if (kIn(cc)[i].IsEmpty()) return mediapipe::OkStatus(); } } NormalizedLandmarkList output; - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - if (cc->Inputs().Index(i).IsEmpty()) continue; - const NormalizedLandmarkList& input = - cc->Inputs().Index(i).Get(); + for (int i = 0; i < kIn(cc).Count(); ++i) { + if (kIn(cc)[i].IsEmpty()) continue; + const NormalizedLandmarkList& input = *kIn(cc)[i]; for (int j = 0; j < input.landmark_size(); ++j) { - const NormalizedLandmark& input_landmark = input.landmark(j); - *output.add_landmark() = input_landmark; + *output.add_landmark() = input.landmark(j); } } - cc->Outputs().Index(0).AddPacket( - MakePacket(output).At(cc->InputTimestamp())); + kOut(cc).Send(std::move(output)); return mediapipe::OkStatus(); } private: bool only_emit_if_all_present_; }; +MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListCalculator); -REGISTER_CALCULATOR(ConcatenateNormalizedLandmarkListCalculator); - +} // namespace api2 } // namespace mediapipe // NOLINTNEXTLINE diff --git a/mediapipe/calculators/core/make_pair_calculator.cc b/mediapipe/calculators/core/make_pair_calculator.cc index 58029ea6b..5d3cf1daf 100644 --- a/mediapipe/calculators/core/make_pair_calculator.cc +++ b/mediapipe/calculators/core/make_pair_calculator.cc @@ -15,10 +15,12 @@ #include #include +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // Given two input streams (A, B), output a single stream containing a pair. 
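In the concatenate hunk above, the repeated-input boilerplate (`NumEntries()` loops plus per-index `Set<...>()` calls) collapses into a single `::Multiple` port; `MakePairCalculator` below uses the same construct with a fixed count of two. A short sketch of how such a port is read, assuming the same includes and `mediapipe::api2` namespace wrapping as the hunks above plus the landmark proto header; the calculator name is hypothetical.

```c++
// Hypothetical example: counts landmarks across all connected input streams.
class ExampleCountLandmarksCalculator : public Node {
 public:
  static constexpr Input<NormalizedLandmarkList>::Multiple kIn{""};
  static constexpr Output<int> kOut{""};

  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

  mediapipe::Status Process(CalculatorContext* cc) override {
    int total = 0;
    // Count() is the number of streams connected to the repeated port;
    // kIn(cc)[i] is the i-th one, dereferenced to its typed payload.
    for (int i = 0; i < kIn(cc).Count(); ++i) {
      if (kIn(cc)[i].IsEmpty()) continue;
      const NormalizedLandmarkList& list = *kIn(cc)[i];
      total += list.landmark_size();
    }
    kOut(cc).Send(total);
    return mediapipe::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(ExampleCountLandmarksCalculator);
```

As the `RET_CHECK_GE(kIn(cc).Count(), 1)` above shows, any constraint on how many streams are connected still lives in the optional `UpdateContract()` hook.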
@@ -30,32 +32,27 @@ namespace mediapipe { // input_stream: "packet_b" // output_stream: "output_pair_a_b" // } -class MakePairCalculator : public CalculatorBase { +class MakePairCalculator : public Node { public: - MakePairCalculator() {} - ~MakePairCalculator() override {} + static constexpr Input::Multiple kIn{""}; + // Note that currently api2::Packet is a different type from mediapipe::Packet + static constexpr Output> + kPair{""}; - static mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - cc->Inputs().Index(1).SetAny(); - cc->Outputs().Index(0).Set>(); - return mediapipe::OkStatus(); - } + MEDIAPIPE_NODE_CONTRACT(kIn, kPair); - mediapipe::Status Open(CalculatorContext* cc) override { - cc->SetOffset(TimestampDiff(0)); + static mediapipe::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_EQ(kIn(cc).Count(), 2); return mediapipe::OkStatus(); } mediapipe::Status Process(CalculatorContext* cc) override { - cc->Outputs().Index(0).Add( - new std::pair(cc->Inputs().Index(0).Value(), - cc->Inputs().Index(1).Value()), - cc->InputTimestamp()); + kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()}); return mediapipe::OkStatus(); } }; -REGISTER_CALCULATOR(MakePairCalculator); +MEDIAPIPE_REGISTER_NODE(MakePairCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_multiply_calculator.cc b/mediapipe/calculators/core/matrix_multiply_calculator.cc index e5f479511..fbc18297b 100644 --- a/mediapipe/calculators/core/matrix_multiply_calculator.cc +++ b/mediapipe/calculators/core/matrix_multiply_calculator.cc @@ -13,11 +13,13 @@ // limitations under the License. #include "Eigen/Core" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // Perform a (left) matrix multiply. Meaning (output = A * input) // where A is the matrix which is provided as an input side packet. // @@ -28,39 +30,22 @@ namespace mediapipe { // output_stream: "multiplied_samples" // input_side_packet: "multiplication_matrix" // } -class MatrixMultiplyCalculator : public CalculatorBase { +class MatrixMultiplyCalculator : public Node { public: - MatrixMultiplyCalculator() {} - ~MatrixMultiplyCalculator() override {} + static constexpr Input kIn{""}; + static constexpr Output kOut{""}; + static constexpr SideInput kSide{""}; - static mediapipe::Status GetContract(CalculatorContract* cc); + MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kSide); - mediapipe::Status Open(CalculatorContext* cc) override; mediapipe::Status Process(CalculatorContext* cc) override; }; -REGISTER_CALCULATOR(MatrixMultiplyCalculator); - -// static -mediapipe::Status MatrixMultiplyCalculator::GetContract( - CalculatorContract* cc) { - cc->Inputs().Index(0).Set(); - cc->Outputs().Index(0).Set(); - cc->InputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); -} - -mediapipe::Status MatrixMultiplyCalculator::Open(CalculatorContext* cc) { - // The output is at the same timestamp as the input. 
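These hunks use two ways of reading a port: `kIn(cc).packet()` yields the type-erased packet (what `MakePairCalculator` bundles into its output pair), while dereferencing with `*kIn(cc)` yields the typed payload (what the matrix calculators operate on), and side inputs such as `kSide` are read the same way. A small sketch of the typed style under the same assumptions as the earlier sketches; the calculator, ports, and scale factor are purely illustrative.

```c++
// Hypothetical example: scales each incoming Matrix by a side-packet factor,
// mirroring the typed access style used in the matrix calculators here.
class ExampleScaleMatrixCalculator : public Node {
 public:
  static constexpr Input<Matrix> kIn{""};
  static constexpr SideInput<float> kFactor{""};
  static constexpr Output<Matrix> kOut{""};

  MEDIAPIPE_NODE_CONTRACT(kIn, kFactor, kOut);

  mediapipe::Status Process(CalculatorContext* cc) override {
    // Dereferencing a port access yields the typed payload: a const Matrix&
    // for the stream, a const float& for the side packet.
    kOut(cc).Send(*kIn(cc) * *kFactor(cc));
    return mediapipe::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(ExampleScaleMatrixCalculator);
```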
- cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); -} +MEDIAPIPE_REGISTER_NODE(MatrixMultiplyCalculator); mediapipe::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) { - Matrix* multiplied = new Matrix(); - *multiplied = cc->InputSidePackets().Index(0).Get() * - cc->Inputs().Index(0).Get(); - cc->Outputs().Index(0).Add(multiplied, cc->InputTimestamp()); + kOut(cc).Send(*kSide(cc) * *kIn(cc)); return mediapipe::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_subtract_calculator.cc b/mediapipe/calculators/core/matrix_subtract_calculator.cc index 4a9b18bbd..f526a0ceb 100644 --- a/mediapipe/calculators/core/matrix_subtract_calculator.cc +++ b/mediapipe/calculators/core/matrix_subtract_calculator.cc @@ -13,11 +13,13 @@ // limitations under the License. #include "Eigen/Core" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // Subtract input matrix from the side input matrix and vice versa. The matrices // must have the same dimension. @@ -41,83 +43,40 @@ namespace mediapipe { // input_side_packet: "MINUEND:side_matrix" // output_stream: "output_matrix" // } -class MatrixSubtractCalculator : public CalculatorBase { +class MatrixSubtractCalculator : public Node { public: - MatrixSubtractCalculator() {} - ~MatrixSubtractCalculator() override {} + static constexpr Input::SideFallback kMinuend{"MINUEND"}; + static constexpr Input::SideFallback kSubtrahend{"SUBTRAHEND"}; + static constexpr Output kOut{""}; - static mediapipe::Status GetContract(CalculatorContract* cc); + MEDIAPIPE_NODE_CONTRACT(kMinuend, kSubtrahend, kOut); + static mediapipe::Status UpdateContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; mediapipe::Status Process(CalculatorContext* cc) override; - - private: - bool subtract_from_input_ = false; }; -REGISTER_CALCULATOR(MatrixSubtractCalculator); +MEDIAPIPE_REGISTER_NODE(MatrixSubtractCalculator); // static -mediapipe::Status MatrixSubtractCalculator::GetContract( +mediapipe::Status MatrixSubtractCalculator::UpdateContract( CalculatorContract* cc) { - if (cc->Inputs().NumEntries() != 1 || - cc->InputSidePackets().NumEntries() != 1) { - return mediapipe::InvalidArgumentError( - "MatrixSubtractCalculator only accepts exactly one input stream and " - "one " - "input side packet"); - } - if (cc->Inputs().HasTag("MINUEND") && - cc->InputSidePackets().HasTag("SUBTRAHEND")) { - cc->Inputs().Tag("MINUEND").Set(); - cc->InputSidePackets().Tag("SUBTRAHEND").Set(); - } else if (cc->Inputs().HasTag("SUBTRAHEND") && - cc->InputSidePackets().HasTag("MINUEND")) { - cc->Inputs().Tag("SUBTRAHEND").Set(); - cc->InputSidePackets().Tag("MINUEND").Set(); - } else { - return mediapipe::InvalidArgumentError( - "Must specify exactly one minuend and one subtrahend."); - } - cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); -} - -mediapipe::Status MatrixSubtractCalculator::Open(CalculatorContext* cc) { - // The output is at the same timestamp as the input. - cc->SetOffset(TimestampDiff(0)); - if (cc->Inputs().HasTag("MINUEND")) { - subtract_from_input_ = true; - } + // TODO: the next restriction could be relaxed. 
+ RET_CHECK(kMinuend(cc).IsStream() ^ kSubtrahend(cc).IsStream()) + << "MatrixSubtractCalculator only accepts exactly one input stream and " + "one input side packet"; return mediapipe::OkStatus(); } mediapipe::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) { - Matrix* subtracted = new Matrix(); - if (subtract_from_input_) { - const Matrix& input_matrix = cc->Inputs().Tag("MINUEND").Get(); - const Matrix& side_input_matrix = - cc->InputSidePackets().Tag("SUBTRAHEND").Get(); - if (input_matrix.rows() != side_input_matrix.rows() || - input_matrix.cols() != side_input_matrix.cols()) { - return mediapipe::InvalidArgumentError( - "Input matrix and the input side matrix must have the same " - "dimension."); - } - *subtracted = input_matrix - side_input_matrix; - } else { - const Matrix& input_matrix = cc->Inputs().Tag("SUBTRAHEND").Get(); - const Matrix& side_input_matrix = - cc->InputSidePackets().Tag("MINUEND").Get(); - if (input_matrix.rows() != side_input_matrix.rows() || - input_matrix.cols() != side_input_matrix.cols()) { - return mediapipe::InvalidArgumentError( - "Input matrix and the input side matrix must have the same " - "dimension."); - } - *subtracted = side_input_matrix - input_matrix; + const Matrix& minuend = *kMinuend(cc); + const Matrix& subtrahend = *kSubtrahend(cc); + if (minuend.rows() != subtrahend.rows() || + minuend.cols() != subtrahend.cols()) { + return mediapipe::InvalidArgumentError( + "Minuend and subtrahend must have the same dimensions."); } - cc->Outputs().Index(0).Add(subtracted, cc->InputTimestamp()); + kOut(cc).Send(minuend - subtrahend); return mediapipe::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc index 162d10e0c..92291050d 100644 --- a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc +++ b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc @@ -89,9 +89,8 @@ TEST(MatrixSubtractCalculatorTest, WrongConfig2) { )"); CalculatorRunner runner(node_config); auto status = runner.Run(); - EXPECT_THAT( - status.message(), - testing::HasSubstr("specify exactly one minuend and one subtrahend.")); + EXPECT_THAT(status.message(), testing::HasSubstr("must be connected")); + EXPECT_THAT(status.message(), testing::HasSubstr("not both")); } TEST(MatrixSubtractCalculatorTest, SubtractFromInput) { diff --git a/mediapipe/calculators/core/matrix_to_vector_calculator.cc b/mediapipe/calculators/core/matrix_to_vector_calculator.cc index 889ab22fa..cd10d2668 100644 --- a/mediapipe/calculators/core/matrix_to_vector_calculator.cc +++ b/mediapipe/calculators/core/matrix_to_vector_calculator.cc @@ -21,6 +21,7 @@ #include "Eigen/Core" #include "absl/memory/memory.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/integral_types.h" @@ -30,6 +31,7 @@ #include "mediapipe/util/time_series_util.h" namespace mediapipe { +namespace api2 { // A calculator that converts a Matrix M to a vector containing all the // entries of M in column-major order. 
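The `::SideFallback` ports used for MINUEND/SUBTRAHEND above (and for the SELECT input of `MuxCalculator` below) let a single port be fed either by an input stream or by an input side packet, with the framework enforcing that exactly one of the two is connected; that is why the test expectations change to "must be connected" / "not both". Reading the port is identical either way, and `IsStream()` remains available when a calculator cares which flavor was supplied, as the `UpdateContract()` above does. A hypothetical sketch, same assumptions as before:

```c++
// Hypothetical example: adds a bias matrix that may arrive either as a
// stream or as a side packet, thanks to SideFallback.
class ExampleAddBiasCalculator : public Node {
 public:
  static constexpr Input<Matrix> kIn{"SAMPLES"};
  static constexpr Input<Matrix>::SideFallback kBias{"BIAS"};
  static constexpr Output<Matrix> kOut{""};

  MEDIAPIPE_NODE_CONTRACT(kIn, kBias, kOut);

  mediapipe::Status Process(CalculatorContext* cc) override {
    // *kBias(cc) works the same whether BIAS was connected as an input
    // stream or supplied as an input side packet.
    kOut(cc).Send(*kIn(cc) + *kBias(cc));
    return mediapipe::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(ExampleAddBiasCalculator);
```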
@@ -40,33 +42,20 @@ namespace mediapipe { // input_stream: "input_matrix" // output_stream: "column_major_vector" // } -class MatrixToVectorCalculator : public CalculatorBase { +class MatrixToVectorCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).Set( - // Input Packet containing a Matrix. - ); - cc->Outputs().Index(0).Set>( - // Output Packet containing a vector, one for each input Packet. - ); - return mediapipe::OkStatus(); - } + static constexpr Input kIn{""}; + static constexpr Output> kOut{""}; - mediapipe::Status Open(CalculatorContext* cc) override; + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); // Outputs a packet containing a vector for each input packet. mediapipe::Status Process(CalculatorContext* cc) override; }; -REGISTER_CALCULATOR(MatrixToVectorCalculator); - -mediapipe::Status MatrixToVectorCalculator::Open(CalculatorContext* cc) { - // Inform the framework that we don't alter timestamps. - cc->SetOffset(mediapipe::TimestampDiff(0)); - return mediapipe::OkStatus(); -} +MEDIAPIPE_REGISTER_NODE(MatrixToVectorCalculator); mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) { - const Matrix& input = cc->Inputs().Index(0).Get(); + const Matrix& input = *kIn(cc); auto output = absl::make_unique>(); // The following lines work to convert the Matrix to a vector because Matrix @@ -76,8 +65,9 @@ mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) { Eigen::Map(output->data(), input.rows(), input.cols()); output_as_matrix = input; - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + kOut(cc).Send(std::move(output)); return mediapipe::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/merge_calculator.cc b/mediapipe/calculators/core/merge_calculator.cc index 9d67f9068..96056b4e3 100644 --- a/mediapipe/calculators/core/merge_calculator.cc +++ b/mediapipe/calculators/core/merge_calculator.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +namespace api2 { // This calculator takes a set of input streams and combines them into a single // output stream. The packets from different streams do not need to contain the @@ -41,40 +43,31 @@ namespace mediapipe { // output_stream: "merged_shot_infos" // } // -class MergeCalculator : public CalculatorBase { +class MergeCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK_GT(cc->Inputs().NumEntries(), 0) - << "Needs at least one input stream"; - RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); - if (cc->Inputs().NumEntries() == 1) { + static constexpr Input::Multiple kIn{""}; + static constexpr Output kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + static mediapipe::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream"; + if (kIn(cc).Count() == 1) { LOG(WARNING) << "MergeCalculator expects multiple input streams to merge but is " "receiving only one. 
Make sure the calculator is configured " "correctly or consider removing this calculator to reduce " "unnecessary overhead."; } - - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - cc->Inputs().Index(i).SetAny(); - } - cc->Outputs().Index(0).SetAny(); - - return mediapipe::OkStatus(); - } - - mediapipe::Status Open(CalculatorContext* cc) final { - cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); } mediapipe::Status Process(CalculatorContext* cc) final { // Output the packet from the first input stream with a packet ready at this // timestamp. - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - if (!cc->Inputs().Index(i).IsEmpty()) { - cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(i).Value()); + for (int i = 0; i < kIn(cc).Count(); ++i) { + if (!kIn(cc)[i].IsEmpty()) { + kOut(cc).Send(kIn(cc)[i].packet()); return mediapipe::OkStatus(); } } @@ -86,6 +79,7 @@ class MergeCalculator : public CalculatorBase { } }; -REGISTER_CALCULATOR(MergeCalculator); +MEDIAPIPE_REGISTER_NODE(MergeCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index 0100d4ce8..f488a5c98 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -12,97 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" namespace mediapipe { - -namespace { -constexpr char kSelectTag[] = "SELECT"; -constexpr char kInputTag[] = "INPUT"; -} // namespace +namespace api2 { // A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ..., -// using the integer value (0, 1, ...) in the packet on the kSelectTag input +// using the integer value (0, 1, ...) in the packet on the "SELECT" input // stream, and passes the packet on the selected input stream to the "OUTPUT" // output stream. -// The kSelectTag input can also be passed in as an input side packet, instead -// of as an input stream. Either of input stream or input side packet must be -// specified but not both. // // Note that this calculator defaults to use MuxInputStreamHandler, which is // required for this calculator. However, it can be overridden to work with // other InputStreamHandlers. Check out the unit tests on for an example usage // with DefaultInputStreamHandler. -class MuxCalculator : public CalculatorBase { +// TODO: why would you need to use DefaultISH? Perhaps b/167596925? +class MuxCalculator : public Node { public: - static mediapipe::Status CheckAndInitAllowDisallowInputs( - CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag(kSelectTag) ^ - cc->InputSidePackets().HasTag(kSelectTag)); - if (cc->Inputs().HasTag(kSelectTag)) { - cc->Inputs().Tag(kSelectTag).Set(); - } else { - cc->InputSidePackets().Tag(kSelectTag).Set(); - } - return mediapipe::OkStatus(); - } + static constexpr Input::SideFallback kSelect{"SELECT"}; + // TODO: this currently sets them all to Any independently, instead + // of the first being Any and the others being SameAs. 
+ static constexpr Input::Multiple kIn{"INPUT"}; + static constexpr Output> kOut{"OUTPUT"}; - static mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc)); - CollectionItemId data_input_id = cc->Inputs().BeginId(kInputTag); - PacketType* data_input0 = &cc->Inputs().Get(data_input_id); - data_input0->SetAny(); - ++data_input_id; - for (; data_input_id < cc->Inputs().EndId(kInputTag); ++data_input_id) { - cc->Inputs().Get(data_input_id).SetSameAs(data_input0); - } - RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); - cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0); - - cc->SetInputStreamHandler("MuxInputStreamHandler"); - MediaPipeOptions options; - cc->SetInputStreamHandlerOptions(options); - - return mediapipe::OkStatus(); - } - - mediapipe::Status Open(CalculatorContext* cc) final { - use_side_packet_select_ = false; - if (cc->InputSidePackets().HasTag(kSelectTag)) { - use_side_packet_select_ = true; - selected_index_ = cc->InputSidePackets().Tag(kSelectTag).Get(); - } else { - select_input_ = cc->Inputs().GetId(kSelectTag, 0); - } - data_input_base_ = cc->Inputs().GetId(kInputTag, 0); - num_data_inputs_ = cc->Inputs().NumEntries(kInputTag); - output_ = cc->Outputs().GetId("OUTPUT", 0); - cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); - } + MEDIAPIPE_NODE_CONTRACT(kSelect, kIn, kOut, + StreamHandler("MuxInputStreamHandler")); mediapipe::Status Process(CalculatorContext* cc) final { - int select = use_side_packet_select_ - ? selected_index_ - : cc->Inputs().Get(select_input_).Get(); - RET_CHECK(0 <= select && select < num_data_inputs_); - if (!cc->Inputs().Get(data_input_base_ + select).IsEmpty()) { - cc->Outputs().Get(output_).AddPacket( - cc->Inputs().Get(data_input_base_ + select).Value()); + int select = *kSelect(cc); + RET_CHECK(0 <= select && select < kIn(cc).Count()); + if (!kIn(cc)[select].IsEmpty()) { + kOut(cc).Send(kIn(cc)[select].packet()); } return mediapipe::OkStatus(); } - - private: - CollectionItemId select_input_; - CollectionItemId data_input_base_; - int num_data_inputs_ = 0; - CollectionItemId output_; - bool use_side_packet_select_; - int selected_index_; }; -REGISTER_CALCULATOR(MuxCalculator); +MEDIAPIPE_REGISTER_NODE(MuxCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/previous_loopback_calculator.cc b/mediapipe/calculators/core/previous_loopback_calculator.cc index 46102d3ea..e42261c28 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator.cc @@ -14,12 +14,14 @@ #include +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/timestamp.h" namespace mediapipe { +namespace api2 { // PreviousLoopbackCalculator is useful when a graph needs to process an input // together with some previous output. 
@@ -51,15 +53,19 @@ namespace mediapipe { // input_stream: "PREV_TRACK:prev_output" // output_stream: "TRACK:output" // } -class PreviousLoopbackCalculator : public CalculatorBase { +class PreviousLoopbackCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Get("MAIN", 0).SetAny(); - cc->Inputs().Get("LOOP", 0).SetAny(); - cc->Outputs().Get("PREV_LOOP", 0).SetSameAs(&(cc->Inputs().Get("LOOP", 0))); - // TODO: an optional PREV_TIMESTAMP output could be added to - // carry the original timestamp of the packet on PREV_LOOP. - cc->SetInputStreamHandler("ImmediateInputStreamHandler"); + static constexpr Input kMain{"MAIN"}; + static constexpr Input kLoop{"LOOP"}; + static constexpr Output> kPrevLoop{"PREV_LOOP"}; + // TODO: an optional PREV_TIMESTAMP output could be added to + // carry the original timestamp of the packet on PREV_LOOP. + + MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop, + StreamHandler("ImmediateInputStreamHandler"), + TimestampChange::Arbitrary()); + + static mediapipe::Status UpdateContract(CalculatorContract* cc) { // Process() function is invoked in response to MAIN/LOOP stream timestamp // bound updates. cc->SetProcessTimestampBounds(true); @@ -67,12 +73,7 @@ class PreviousLoopbackCalculator : public CalculatorBase { } mediapipe::Status Open(CalculatorContext* cc) final { - main_id_ = cc->Inputs().GetId("MAIN", 0); - loop_id_ = cc->Inputs().GetId("LOOP", 0); - prev_loop_id_ = cc->Outputs().GetId("PREV_LOOP", 0); - cc->Outputs() - .Get(prev_loop_id_) - .SetHeader(cc->Inputs().Get(loop_id_).Header()); + kPrevLoop(cc).SetHeader(kLoop(cc).Header()); return mediapipe::OkStatus(); } @@ -82,48 +83,47 @@ class PreviousLoopbackCalculator : public CalculatorBase { // packets within the same stream. Calculator tracks and operates on such // packets. - const Packet& main_packet = cc->Inputs().Get(main_id_).Value(); - if (prev_main_ts_ < main_packet.Timestamp()) { + const PacketBase& main_packet = kMain(cc).packet(); + if (prev_main_ts_ < main_packet.timestamp()) { Timestamp loop_timestamp; if (!main_packet.IsEmpty()) { loop_timestamp = prev_non_empty_main_ts_; - prev_non_empty_main_ts_ = main_packet.Timestamp(); + prev_non_empty_main_ts_ = main_packet.timestamp(); } else { // Calculator advances PREV_LOOP timestamp bound in response to empty // MAIN packet, hence not caring about corresponding loop packet. loop_timestamp = Timestamp::Unset(); } - main_packet_specs_.push_back({main_packet.Timestamp(), loop_timestamp}); - prev_main_ts_ = main_packet.Timestamp(); + main_packet_specs_.push_back({main_packet.timestamp(), loop_timestamp}); + prev_main_ts_ = main_packet.timestamp(); } - const Packet& loop_packet = cc->Inputs().Get(loop_id_).Value(); - if (prev_loop_ts_ < loop_packet.Timestamp()) { + const PacketBase& loop_packet = kLoop(cc).packet(); + if (prev_loop_ts_ < loop_packet.timestamp()) { loop_packets_.push_back(loop_packet); - prev_loop_ts_ = loop_packet.Timestamp(); + prev_loop_ts_ = loop_packet.timestamp(); } - auto& prev_loop = cc->Outputs().Get(prev_loop_id_); while (!main_packet_specs_.empty() && !loop_packets_.empty()) { // The earliest MAIN packet. const MainPacketSpec& main_spec = main_packet_specs_.front(); // The earliest LOOP packet. - const Packet& loop_candidate = loop_packets_.front(); + const PacketBase& loop_candidate = loop_packets_.front(); // Match LOOP and MAIN packets. 
- if (main_spec.loop_timestamp < loop_candidate.Timestamp()) { + if (main_spec.loop_timestamp < loop_candidate.timestamp()) { // No LOOP packet can match the MAIN packet under review. - prev_loop.SetNextTimestampBound(main_spec.timestamp + 1); + kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1); main_packet_specs_.pop_front(); - } else if (main_spec.loop_timestamp > loop_candidate.Timestamp()) { + } else if (main_spec.loop_timestamp > loop_candidate.timestamp()) { // No MAIN packet can match the LOOP packet under review. loop_packets_.pop_front(); } else { // Exact match found. if (loop_candidate.IsEmpty()) { // However, LOOP packet is empty. - prev_loop.SetNextTimestampBound(main_spec.timestamp + 1); + kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1); } else { - prev_loop.AddPacket(loop_candidate.At(main_spec.timestamp)); + kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp)); } loop_packets_.pop_front(); main_packet_specs_.pop_front(); @@ -135,7 +135,7 @@ class PreviousLoopbackCalculator : public CalculatorBase { // b) Empty MAIN packet has been received with Timestamp::Max() indicating // MAIN is done. if (main_spec.timestamp == Timestamp::Done().PreviousAllowedInStream()) { - prev_loop.Close(); + kPrevLoop(cc).Close(); } } @@ -150,10 +150,6 @@ class PreviousLoopbackCalculator : public CalculatorBase { Timestamp loop_timestamp; }; - CollectionItemId main_id_; - CollectionItemId loop_id_; - CollectionItemId prev_loop_id_; - // Contains specs for MAIN packets which only can be: // - non-empty packets // - empty packets indicating timestamp bound updates @@ -169,12 +165,13 @@ class PreviousLoopbackCalculator : public CalculatorBase { // - empty packets indicating timestamp bound updates // // Sorted according to packet timestamps. - std::deque loop_packets_; + std::deque loop_packets_; // Using "Timestamp::Unset" instead of "Timestamp::Unstarted" in order to // allow addition of the very first empty packet (which doesn't indicate // timestamp bound change necessarily). Timestamp prev_loop_ts_ = Timestamp::Unset(); }; -REGISTER_CALCULATOR(PreviousLoopbackCalculator); +MEDIAPIPE_REGISTER_NODE(PreviousLoopbackCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc index e288128e1..c2c9b8e5e 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator.cc @@ -15,9 +15,11 @@ #include #include "mediapipe/calculators/core/sequence_shift_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" namespace mediapipe { +namespace api2 { // A Calculator that shifts the timestamps of packets along a stream. Packets on // the input stream are output with a timestamp of the packet given by packet @@ -28,24 +30,19 @@ namespace mediapipe { // of -1, the first packet on the stream will be dropped, the second will be // output with the timestamp of the first, the third with the timestamp of the // second, and so on. 
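The sequence-shift hunk that follows also illustrates optional side inputs: PACKET_OFFSET becomes a `SideInput<int>::Optional`, read in `Open()` with `GetOr(...)` so the options value serves as the default when nothing is connected (the inference calculator later uses the same pattern for its custom op resolver). A hypothetical sketch of that idiom, same assumptions as before; the threshold default of 0.5 is purely illustrative:

```c++
// Hypothetical example: an optional side input with a default taken from a
// constant, mirroring the GetOr() pattern used in the hunk below.
class ExampleThresholdCalculator : public Node {
 public:
  static constexpr Input<float> kIn{""};
  static constexpr SideInput<float>::Optional kThreshold{"THRESHOLD"};
  static constexpr Output<bool> kOut{""};

  MEDIAPIPE_NODE_CONTRACT(kIn, kThreshold, kOut);

  mediapipe::Status Open(CalculatorContext* cc) override {
    // GetOr() returns the side-packet value when THRESHOLD is connected,
    // otherwise the supplied default.
    threshold_ = kThreshold(cc).GetOr(0.5f);
    return mediapipe::OkStatus();
  }

  mediapipe::Status Process(CalculatorContext* cc) override {
    kOut(cc).Send(*kIn(cc) > threshold_);
    return mediapipe::OkStatus();
  }

 private:
  float threshold_ = 0.5f;
};
MEDIAPIPE_REGISTER_NODE(ExampleThresholdCalculator);
```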
-class SequenceShiftCalculator : public CalculatorBase { +class SequenceShiftCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - if (cc->InputSidePackets().HasTag(kPacketOffsetTag)) { - cc->InputSidePackets().Tag(kPacketOffsetTag).Set(); - } - cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); - } + static constexpr Input kIn{""}; + static constexpr SideInput::Optional kOffset{"PACKET_OFFSET"}; + static constexpr Output> kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOffset, kOut, TimestampChange::Arbitrary()); // Reads from options to set cache_size_ and packet_offset_. mediapipe::Status Open(CalculatorContext* cc) override; mediapipe::Status Process(CalculatorContext* cc) override; private: - static constexpr const char* kPacketOffsetTag = "PACKET_OFFSET"; - // A positive offset means we want a packet to be output with the timestamp of // a later packet. Stores packets waiting for their output timestamps and // outputs a single packet when the cache fills. @@ -58,7 +55,7 @@ class SequenceShiftCalculator : public CalculatorBase { // Storage for packets waiting to be output when packet_offset > 0. When cache // is full, oldest packet is output with current timestamp. - std::deque packet_cache_; + std::deque packet_cache_; // Storage for previous timestamps used when packet_offset < 0. When cache is // full, oldest timestamp is used for current packet. @@ -70,14 +67,11 @@ class SequenceShiftCalculator : public CalculatorBase { // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset). int cache_size_; }; -REGISTER_CALCULATOR(SequenceShiftCalculator); +MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator); mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { - packet_offset_ = - cc->Options().packet_offset(); - if (cc->InputSidePackets().HasTag(kPacketOffsetTag)) { - packet_offset_ = cc->InputSidePackets().Tag(kPacketOffsetTag).Get(); - } + packet_offset_ = kOffset(cc).GetOr( + cc->Options().packet_offset()); cache_size_ = abs(packet_offset_); // An offset of zero is a no-op, but someone might still request it. if (packet_offset_ == 0) { @@ -92,7 +86,7 @@ mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) { } else if (packet_offset_ < 0) { ProcessNegativeOffset(cc); } else { - cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + kOut(cc).Send(kIn(cc).packet()); } return mediapipe::OkStatus(); } @@ -100,23 +94,22 @@ mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) { void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) { if (packet_cache_.size() >= cache_size_) { // Ready to output oldest packet with current timestamp. - cc->Outputs().Index(0).AddPacket( - packet_cache_.front().At(cc->InputTimestamp())); + kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp())); packet_cache_.pop_front(); } // Store current packet for later output. - packet_cache_.push_back(cc->Inputs().Index(0).Value()); + packet_cache_.push_back(kIn(cc).packet()); } void SequenceShiftCalculator::ProcessNegativeOffset(CalculatorContext* cc) { if (timestamp_cache_.size() >= cache_size_) { // Ready to output current packet with oldest timestamp. - cc->Outputs().Index(0).AddPacket( - cc->Inputs().Index(0).Value().At(timestamp_cache_.front())); + kOut(cc).Send(kIn(cc).packet().At(timestamp_cache_.front())); timestamp_cache_.pop_front(); } // Store current timestamp for use by a future packet. 
timestamp_cache_.push_back(cc->InputTimestamp()); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/string_to_int_calculator.cc b/mediapipe/calculators/core/string_to_int_calculator.cc index 7dc558160..5f8a6e325 100644 --- a/mediapipe/calculators/core/string_to_int_calculator.cc +++ b/mediapipe/calculators/core/string_to_int_calculator.cc @@ -61,7 +61,7 @@ class StringToIntCalculatorTemplate : public CalculatorBase { using StringToIntCalculator = StringToIntCalculatorTemplate; REGISTER_CALCULATOR(StringToIntCalculator); -using StringToUintCalculator = StringToIntCalculatorTemplate; +using StringToUintCalculator = StringToIntCalculatorTemplate; REGISTER_CALCULATOR(StringToUintCalculator); using StringToInt32Calculator = StringToIntCalculatorTemplate; diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 94daac793..dc465d4cd 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -59,6 +59,7 @@ cc_library( deps = [ ":inference_calculator_cc_proto", "@com_google_absl//absl/memory", + "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:tensor", "//mediapipe/util/tflite:tflite_model_loader", @@ -234,6 +235,7 @@ cc_library( "//mediapipe/framework/formats:detection_cc_proto", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", "//mediapipe/framework/deps:file_path", @@ -286,6 +288,7 @@ cc_library( deps = [ ":tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", @@ -317,6 +320,7 @@ cc_library( deps = [ ":tensors_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", ], @@ -355,6 +359,7 @@ cc_library( ":tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:location", @@ -421,6 +426,7 @@ cc_library( ":image_to_tensor_converter", ":image_to_tensor_converter_opencv", ":image_to_tensor_utils", + "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index 2a93355c4..775e0e70b 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -20,6 +20,7 @@ #include "mediapipe/calculators/tensor/image_to_tensor_converter.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -45,16 +46,15 @@ #endif // 
!MEDIAPIPE_DISABLE_GPU -namespace { -constexpr char kInputCpu[] = "IMAGE"; -constexpr char kInputGpu[] = "IMAGE_GPU"; -constexpr char kOutputMatrix[] = "MATRIX"; -constexpr char kOutput[] = "TENSORS"; -constexpr char kInputNormRect[] = "NORM_RECT"; -constexpr char kOutputLetterboxPadding[] = "LETTERBOX_PADDING"; -} // namespace - namespace mediapipe { +namespace api2 { + +#if MEDIAPIPE_DISABLE_GPU +// Just a placeholder to not have to depend on mediapipe::GpuBuffer. +using GpuBuffer = AnyType; +#else +using GpuBuffer = mediapipe::GpuBuffer; +#endif // MEDIAPIPE_DISABLE_GPU // Converts image into Tensor, possibly with cropping, resizing and // normalization, according to specified inputs and options. @@ -110,9 +110,21 @@ namespace mediapipe { // } // } // } -class ImageToTensorCalculator : public CalculatorBase { +class ImageToTensorCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static constexpr Input::Optional kInCpu{"IMAGE"}; + static constexpr Input::Optional kInGpu{"IMAGE_GPU"}; + static constexpr Input::Optional kInNormRect{ + "NORM_RECT"}; + static constexpr Output> kOutTensors{"TENSORS"}; + static constexpr Output>::Optional kOutLetterboxPadding{ + "LETTERBOX_PADDING"}; + static constexpr Output>::Optional kOutMatrix{"MATRIX"}; + + MEDIAPIPE_NODE_CONTRACT(kInCpu, kInGpu, kInNormRect, kOutTensors, + kOutLetterboxPadding, kOutMatrix); + + static ::mediapipe::Status UpdateContract(CalculatorContract* cc) { const auto& options = cc->Options(); @@ -126,24 +138,10 @@ class ImageToTensorCalculator : public CalculatorBase { RET_CHECK_GT(options.output_tensor_height(), 0) << "Valid output tensor height is required."; - if (cc->Inputs().HasTag(kInputNormRect)) { - cc->Inputs().Tag(kInputNormRect).Set(); - } - if (cc->Outputs().HasTag(kOutputLetterboxPadding)) { - cc->Outputs().Tag(kOutputLetterboxPadding).Set>(); - } - if (cc->Outputs().HasTag(kOutputMatrix)) { - cc->Outputs().Tag(kOutputMatrix).Set>(); - } + RET_CHECK(kInCpu(cc).IsConnected() ^ kInGpu(cc).IsConnected()) + << "One and only one of CPU or GPU input is expected."; - const bool has_cpu_input = cc->Inputs().HasTag(kInputCpu); - const bool has_gpu_input = cc->Inputs().HasTag(kInputGpu); - RET_CHECK_EQ((has_cpu_input ? 1 : 0) + (has_gpu_input ? 1 : 0), 1) - << "Either CPU or GPU input is expected, not both."; - - if (has_cpu_input) { - cc->Inputs().Tag(kInputCpu).Set(); - } else if (has_gpu_input) { + if (kInGpu(cc).IsConnected()) { #if MEDIAPIPE_DISABLE_GPU return mediapipe::UnimplementedError("GPU processing is disabled"); #else @@ -153,25 +151,20 @@ class ImageToTensorCalculator : public CalculatorBase { #else MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #endif // MEDIAPIPE_METAL_ENABLED - cc->Inputs().Tag(kInputGpu).Set(); #endif // MEDIAPIPE_DISABLE_GPU } - cc->Outputs().Tag(kOutput).Set>(); return mediapipe::OkStatus(); } mediapipe::Status Open(CalculatorContext* cc) { - // Makes sure outputs' next timestamp bound update is handled automatically - // by the framework. 
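The image-to-tensor hunk declares several ports as `::Optional` and gates work on them with `IsConnected()`, both for the CPU/GPU input check in `UpdateContract()` above and for the LETTERBOX_PADDING / MATRIX outputs later in this hunk. A hypothetical sketch of that gating style, same assumptions as the earlier sketches:

```c++
// Hypothetical example: a required output plus an optional one that is only
// computed and sent when a downstream consumer is connected.
class ExampleRowSumCalculator : public Node {
 public:
  static constexpr Input<Matrix> kIn{""};
  static constexpr Output<Matrix> kSums{"SUMS"};
  static constexpr Output<int>::Optional kRowCount{"ROW_COUNT"};

  MEDIAPIPE_NODE_CONTRACT(kIn, kSums, kRowCount);

  mediapipe::Status Process(CalculatorContext* cc) override {
    const Matrix& m = *kIn(cc);
    kSums(cc).Send(m.rowwise().sum());
    if (kRowCount(cc).IsConnected()) {
      // Optional ports are simply skipped when nothing is wired to them.
      kRowCount(cc).Send(static_cast<int>(m.rows()));
    }
    return mediapipe::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(ExampleRowSumCalculator);
```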
- cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); output_width_ = options_.output_tensor_width(); output_height_ = options_.output_tensor_height(); range_min_ = options_.output_tensor_float_range().min(); range_max_ = options_.output_tensor_float_range().max(); - if (cc->Inputs().HasTag(kInputCpu)) { + if (kInCpu(cc).IsConnected()) { ASSIGN_OR_RETURN(converter_, CreateOpenCvConverter(cc, GetBorderMode())); } else { #if MEDIAPIPE_DISABLE_GPU @@ -196,21 +189,20 @@ class ImageToTensorCalculator : public CalculatorBase { } mediapipe::Status Process(CalculatorContext* cc) { - const InputStreamShard& input = cc->Inputs().Tag( - cc->Inputs().HasTag(kInputCpu) ? kInputCpu : kInputGpu); - if (input.IsEmpty()) { + const PacketBase& image_packet = + kInCpu(cc).IsConnected() ? kInCpu(cc).packet() : kInGpu(cc).packet(); + if (image_packet.IsEmpty()) { // Timestamp bound update happens automatically. (See Open().) return mediapipe::OkStatus(); } absl::optional norm_rect; - if (cc->Inputs().HasTag(kInputNormRect)) { - if (cc->Inputs().Tag(kInputNormRect).IsEmpty()) { + if (kInNormRect(cc).IsConnected()) { + if (kInNormRect(cc).IsEmpty()) { // Timestamp bound update happens automatically. (See Open().) return mediapipe::OkStatus(); } - norm_rect = - cc->Inputs().Tag(kInputNormRect).Get(); + norm_rect = *kInNormRect(cc); if (norm_rect->width() == 0 && norm_rect->height() == 0) { // WORKAROUND: some existing graphs may use sentinel rects {width=0, // height=0, ...} quite often and calculator has to handle them @@ -223,27 +215,20 @@ class ImageToTensorCalculator : public CalculatorBase { } } - const Packet& image_packet = input.Value(); const Size& size = converter_->GetImageSize(image_packet); RotatedRect roi = GetRoi(size.width, size.height, norm_rect); ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(), options_.output_tensor_height(), options_.keep_aspect_ratio(), &roi)); - if (cc->Outputs().HasTag(kOutputLetterboxPadding)) { - cc->Outputs() - .Tag(kOutputLetterboxPadding) - .AddPacket(MakePacket>(padding).At( - cc->InputTimestamp())); + if (kOutLetterboxPadding(cc).IsConnected()) { + kOutLetterboxPadding(cc).Send(padding); } - if (cc->Outputs().HasTag(kOutputMatrix)) { + if (kOutMatrix(cc).IsConnected()) { std::array matrix; GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height, /*flip_horizontaly=*/false, &matrix); - cc->Outputs() - .Tag(kOutputMatrix) - .AddPacket(MakePacket>(std::move(matrix)) - .At(cc->InputTimestamp())); + kOutMatrix(cc).Send(std::move(matrix)); } ASSIGN_OR_RETURN( @@ -251,11 +236,9 @@ class ImageToTensorCalculator : public CalculatorBase { converter_->Convert(image_packet, roi, {output_width_, output_height_}, range_min_, range_max_)); - std::vector result; - result.push_back(std::move(tensor)); - cc->Outputs().Tag(kOutput).AddPacket( - MakePacket>(std::move(result)) - .At(cc->InputTimestamp())); + auto result = std::make_unique>(); + result->push_back(std::move(tensor)); + kOutTensors(cc).Send(std::move(result)); return mediapipe::OkStatus(); } @@ -286,6 +269,7 @@ class ImageToTensorCalculator : public CalculatorBase { float range_max_ = 1.0f; }; -REGISTER_CALCULATOR(ImageToTensorCalculator); +MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index f675813b5..2c4e33fb3 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ 
b/mediapipe/calculators/tensor/inference_calculator.cc @@ -19,6 +19,7 @@ #include "absl/memory/memory.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/ret_check.h" @@ -88,7 +89,6 @@ bool ShouldUseGpu(const mediapipe::InferenceCalculatorOptions& options) { } constexpr char kTensorsTag[] = "TENSORS"; -} // namespace #if defined(MEDIAPIPE_EDGE_TPU) #include "edgetpu.h" @@ -112,7 +112,10 @@ std::unique_ptr BuildEdgeTpuInterpreter( } #endif // MEDIAPIPE_EDGE_TPU +} // namespace + namespace mediapipe { +namespace api2 { #if MEDIAPIPE_TFLITE_METAL_INFERENCE namespace { @@ -224,12 +227,19 @@ int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) { // Tensors are assumed to be ordered correctly (sequentially added to model). // Input tensors are assumed to be of the correct size and already normalized. -class InferenceCalculator : public CalculatorBase { +class InferenceCalculator : public Node { public: using TfLiteDelegatePtr = std::unique_ptr>; - static mediapipe::Status GetContract(CalculatorContract* cc); + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr SideInput::Optional + kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; + static constexpr SideInput::Optional kSideInModel{"MODEL"}; + static constexpr Output> kOutTensors{"TENSORS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel, + kOutTensors); + static mediapipe::Status UpdateContract(CalculatorContract* cc); mediapipe::Status Open(CalculatorContext* cc) override; mediapipe::Status Process(CalculatorContext* cc) override; @@ -239,11 +249,12 @@ class InferenceCalculator : public CalculatorBase { mediapipe::Status ReadKernelsFromFile(); mediapipe::Status WriteKernelsToFile(); mediapipe::Status LoadModel(CalculatorContext* cc); - mediapipe::StatusOr GetModelAsPacket(const CalculatorContext& cc); + mediapipe::StatusOr GetModelAsPacket( + const CalculatorContext& cc); mediapipe::Status LoadDelegate(CalculatorContext* cc); mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc); - Packet model_packet_; + mediapipe::Packet model_packet_; std::unique_ptr interpreter_; TfLiteDelegatePtr delegate_; @@ -277,28 +288,13 @@ class InferenceCalculator : public CalculatorBase { std::string cached_kernel_filename_; }; -REGISTER_CALCULATOR(InferenceCalculator); - -mediapipe::Status InferenceCalculator::GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag(kTensorsTag)); - cc->Inputs().Tag(kTensorsTag).Set>(); - RET_CHECK(cc->Outputs().HasTag(kTensorsTag)); - cc->Outputs().Tag(kTensorsTag).Set>(); +MEDIAPIPE_REGISTER_NODE(InferenceCalculator); +mediapipe::Status InferenceCalculator::UpdateContract(CalculatorContract* cc) { const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); - RET_CHECK(!options.model_path().empty() ^ - cc->InputSidePackets().HasTag("MODEL")) + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) << "Either model as side packet or model path in options is required."; - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { - cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") - .Set(); - } - if (cc->InputSidePackets().HasTag("MODEL")) { - cc->InputSidePackets().Tag("MODEL").Set(); - } - if (ShouldUseGpu(options)) { #if MEDIAPIPE_TFLITE_GL_INFERENCE 
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); @@ -310,8 +306,6 @@ mediapipe::Status InferenceCalculator::GetContract(CalculatorContract* cc) { } mediapipe::Status InferenceCalculator::Open(CalculatorContext* cc) { - cc->SetOffset(TimestampDiff(0)); - #if MEDIAPIPE_TFLITE_GL_INFERENCE || MEDIAPIPE_TFLITE_METAL_INFERENCE const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); if (ShouldUseGpu(options)) { @@ -361,11 +355,10 @@ mediapipe::Status InferenceCalculator::Open(CalculatorContext* cc) { } mediapipe::Status InferenceCalculator::Process(CalculatorContext* cc) { - if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { + if (kInTensors(cc).IsEmpty()) { return mediapipe::OkStatus(); } - const auto& input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); + const auto& input_tensors = *kInTensors(cc); RET_CHECK(!input_tensors.empty()); auto output_tensors = absl::make_unique>(); #if MEDIAPIPE_TFLITE_METAL_INFERENCE @@ -509,9 +502,7 @@ mediapipe::Status InferenceCalculator::Process(CalculatorContext* cc) { output_tensors->back().bytes()); } } - cc->Outputs() - .Tag(kTensorsTag) - .Add(output_tensors.release(), cc->InputTimestamp()); + kOutTensors(cc).Send(std::move(output_tensors)); return mediapipe::OkStatus(); } @@ -575,12 +566,9 @@ mediapipe::Status InferenceCalculator::InitTFLiteGPURunner( #if MEDIAPIPE_TFLITE_GL_INFERENCE ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver; - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { - op_resolver = cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") - .Get(); - } + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); // Create runner tflite::gpu::InferenceOptions options; @@ -629,12 +617,9 @@ mediapipe::Status InferenceCalculator::InitTFLiteGPURunner( mediapipe::Status InferenceCalculator::LoadModel(CalculatorContext* cc) { ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver; - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { - op_resolver = cc->InputSidePackets() - .Tag("CUSTOM_OP_RESOLVER") - .Get(); - } + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); #if defined(MEDIAPIPE_EDGE_TPU) interpreter_ = @@ -659,7 +644,7 @@ mediapipe::Status InferenceCalculator::LoadModel(CalculatorContext* cc) { return mediapipe::OkStatus(); } -mediapipe::StatusOr InferenceCalculator::GetModelAsPacket( +mediapipe::StatusOr InferenceCalculator::GetModelAsPacket( const CalculatorContext& cc) { const auto& options = cc.Options(); if (!options.model_path().empty()) { @@ -845,4 +830,5 @@ mediapipe::Status InferenceCalculator::LoadDelegate(CalculatorContext* cc) { return mediapipe::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc index b9a72ac1d..e076e2451 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc @@ -19,6 +19,7 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h" +#include 
"mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/tensor.h" @@ -32,6 +33,7 @@ #endif namespace mediapipe { +namespace api2 { // Convert result tensors from classification models into MediaPipe // classifications. @@ -57,9 +59,12 @@ namespace mediapipe { // } // } // } -class TensorsToClassificationCalculator : public CalculatorBase { +class TensorsToClassificationCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr Output kOutClassificationList{ + "CLASSIFICATIONS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kOutClassificationList); mediapipe::Status Open(CalculatorContext* cc) override; mediapipe::Status Process(CalculatorContext* cc) override; @@ -71,28 +76,10 @@ class TensorsToClassificationCalculator : public CalculatorBase { std::unordered_map label_map_; bool label_map_loaded_ = false; }; -REGISTER_CALCULATOR(TensorsToClassificationCalculator); - -mediapipe::Status TensorsToClassificationCalculator::GetContract( - CalculatorContract* cc) { - RET_CHECK(!cc->Inputs().GetTags().empty()); - RET_CHECK(!cc->Outputs().GetTags().empty()); - - if (cc->Inputs().HasTag("TENSORS")) { - cc->Inputs().Tag("TENSORS").Set>(); - } - - if (cc->Outputs().HasTag("CLASSIFICATIONS")) { - cc->Outputs().Tag("CLASSIFICATIONS").Set(); - } - - return mediapipe::OkStatus(); -} +MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator); mediapipe::Status TensorsToClassificationCalculator::Open( CalculatorContext* cc) { - cc->SetOffset(TimestampDiff(0)); - options_ = cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>(); @@ -118,9 +105,7 @@ mediapipe::Status TensorsToClassificationCalculator::Open( mediapipe::Status TensorsToClassificationCalculator::Process( CalculatorContext* cc) { - const auto& input_tensors = - cc->Inputs().Tag("TENSORS").Get>(); - + const auto& input_tensors = *kInTensors(cc); RET_CHECK_EQ(input_tensors.size(), 1); int num_classes = input_tensors[0].shape().num_elements(); @@ -182,10 +167,7 @@ mediapipe::Status TensorsToClassificationCalculator::Process( raw_classification_list->DeleteSubrange( top_k_, raw_classification_list->size() - top_k_); } - cc->Outputs() - .Tag("CLASSIFICATIONS") - .Add(classification_list.release(), cc->InputTimestamp()); - + kOutClassificationList(cc).Send(std::move(classification_list)); return mediapipe::OkStatus(); } @@ -194,4 +176,5 @@ mediapipe::Status TensorsToClassificationCalculator::Close( return mediapipe::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index c095ea8bb..52c2adb14 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -18,6 +18,7 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/detection.pb.h" @@ -47,9 +48,6 @@ namespace { constexpr int kNumInputTensorsWithAnchors = 3; constexpr int kNumCoordsPerBox = 4; -constexpr char kDetectionsTag[] = "DETECTIONS"; 
-constexpr char kTensorsTag[] = "TENSORS"; -constexpr char kAnchorsTag[] = "ANCHORS"; bool CanUseGpu() { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || MEDIAPIPE_METAL_ENABLED @@ -63,6 +61,7 @@ bool CanUseGpu() { } // namespace namespace mediapipe { +namespace api2 { namespace { @@ -128,9 +127,14 @@ void ConvertAnchorsToRawValues(const std::vector& anchors, // } // } // } -class TensorsToDetectionsCalculator : public CalculatorBase { +class TensorsToDetectionsCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr SideInput>::Optional kInAnchors{ + "ANCHORS"}; + static constexpr Output> kOutDetections{"DETECTIONS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kInAnchors, kOutDetections); + static mediapipe::Status UpdateContract(CalculatorContract* cc); mediapipe::Status Open(CalculatorContext* cc) override; mediapipe::Status Process(CalculatorContext* cc) override; @@ -161,7 +165,6 @@ class TensorsToDetectionsCalculator : public CalculatorBase { ::mediapipe::TensorsToDetectionsCalculatorOptions options_; std::vector anchors_; - bool side_packet_anchors_{}; #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE mediapipe::GlCalculatorHelper gpu_helper_; @@ -179,22 +182,10 @@ class TensorsToDetectionsCalculator : public CalculatorBase { bool gpu_input_ = false; bool anchors_init_ = false; }; -REGISTER_CALCULATOR(TensorsToDetectionsCalculator); +MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator); -mediapipe::Status TensorsToDetectionsCalculator::GetContract( +mediapipe::Status TensorsToDetectionsCalculator::UpdateContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag(kTensorsTag)); - cc->Inputs().Tag(kTensorsTag).Set>(); - - RET_CHECK(cc->Outputs().HasTag(kDetectionsTag)); - cc->Outputs().Tag(kDetectionsTag).Set>(); - - if (cc->InputSidePackets().UsesTags()) { - if (cc->InputSidePackets().HasTag(kAnchorsTag)) { - cc->InputSidePackets().Tag(kAnchorsTag).Set>(); - } - } - if (CanUseGpu()) { #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); @@ -207,8 +198,6 @@ mediapipe::Status TensorsToDetectionsCalculator::GetContract( } mediapipe::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) { - cc->SetOffset(TimestampDiff(0)); - side_packet_anchors_ = cc->InputSidePackets().HasTag(kAnchorsTag); MP_RETURN_IF_ERROR(LoadOptions(cc)); if (CanUseGpu()) { @@ -226,18 +215,12 @@ mediapipe::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) { mediapipe::Status TensorsToDetectionsCalculator::Process( CalculatorContext* cc) { - if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { - return mediapipe::OkStatus(); - } - auto output_detections = absl::make_unique>(); - bool gpu_processing = false; if (CanUseGpu()) { // Use GPU processing only if at least one input tensor is already on GPU // (to avoid CPU->GPU overhead). 
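The detections calculator above also drops the `side_packet_anchors_` bookkeeping: an optional api2 side input can simply be probed where it is consumed. A sketch, assuming the side packet type is `std::vector<Anchor>` as in the pre-api2 code:

```
static constexpr SideInput<std::vector<Anchor>>::Optional kInAnchors{"ANCHORS"};

// Inside Process(), when the model does not provide its own anchor tensor:
if (!kInAnchors(cc).IsEmpty()) {
  anchors_ = *kInAnchors(cc);  // copy anchors supplied as a side packet
} else {
  return mediapipe::UnavailableError("No anchor data available.");
}
```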
- for (const auto& tensor : - cc->Inputs().Tag(kTensorsTag).Get>()) { + for (const auto& tensor : *kInTensors(cc)) { if (tensor.ready_on_gpu()) { gpu_processing = true; break; @@ -251,18 +234,13 @@ mediapipe::Status TensorsToDetectionsCalculator::Process( MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get())); } - // Output - cc->Outputs() - .Tag(kDetectionsTag) - .Add(output_detections.release(), cc->InputTimestamp()); - + kOutDetections(cc).Send(std::move(output_detections)); return mediapipe::OkStatus(); } mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU( CalculatorContext* cc, std::vector* output_detections) { - const auto& input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); + const auto& input_tensors = *kInTensors(cc); if (input_tensors.size() == 2 || input_tensors.size() == kNumInputTensorsWithAnchors) { @@ -294,10 +272,8 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU( auto anchor_view = anchor_tensor->GetCpuReadView(); auto raw_anchors = anchor_view.buffer(); ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_); - } else if (side_packet_anchors_) { - CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); - anchors_ = - cc->InputSidePackets().Tag("ANCHORS").Get>(); + } else if (!kInAnchors(cc).IsEmpty()) { + anchors_ = *kInAnchors(cc); } else { return mediapipe::UnavailableError("No anchor data available."); } @@ -391,8 +367,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU( mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( CalculatorContext* cc, std::vector* output_detections) { - const auto& input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); + const auto& input_tensors = *kInTensors(cc); RET_CHECK_GE(input_tensors.size(), 2); #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE @@ -400,21 +375,20 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( &output_detections]() -> mediapipe::Status { if (!anchors_init_) { - if (side_packet_anchors_) { - CHECK(!cc->InputSidePackets().Tag(kAnchorsTag).IsEmpty()); - const auto& anchors = - cc->InputSidePackets().Tag(kAnchorsTag).Get>(); - auto anchors_view = raw_anchors_buffer_->GetCpuWriteView(); - auto raw_anchors = anchors_view.buffer(); - ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors); - } else { - CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); + if (input_tensors.size() == kNumInputTensorsWithAnchors) { auto read_view = input_tensors[2].GetOpenGlBufferReadView(); glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView(); glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, input_tensors[2].bytes()); + } else if (!kInAnchors(cc).IsEmpty()) { + const auto& anchors = *kInAnchors(cc); + auto anchors_view = raw_anchors_buffer_->GetCpuWriteView(); + auto raw_anchors = anchors_view.buffer(); + ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors); + } else { + return mediapipe::UnavailableError("No anchor data available."); } anchors_init_ = true; } @@ -464,14 +438,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( #elif MEDIAPIPE_METAL_ENABLED id device = gpu_helper_.mtlDevice; if (!anchors_init_) { - if (side_packet_anchors_) { - CHECK(!cc->InputSidePackets().Tag(kAnchorsTag).IsEmpty()); - const auto& anchors = - cc->InputSidePackets().Tag(kAnchorsTag).Get>(); - auto raw_anchors_view = raw_anchors_buffer_->GetCpuWriteView(); - ConvertAnchorsToRawValues(anchors, num_boxes_, - 
raw_anchors_view.buffer()); - } else { + if (input_tensors.size() == kNumInputTensorsWithAnchors) { RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); auto command_buffer = [gpu_helper_ commandBuffer]; auto src_buffer = input_tensors[2].GetMtlBufferReadView(command_buffer); @@ -486,6 +453,13 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( size:input_tensors[2].bytes()]; [blit_command endEncoding]; [command_buffer commit]; + } else if (!kInAnchors(cc).IsEmpty()) { + const auto& anchors = *kInAnchors(cc); + auto raw_anchors_view = raw_anchors_buffer_->GetCpuWriteView(); + ConvertAnchorsToRawValues(anchors, num_boxes_, + raw_anchors_view.buffer()); + } else { + return mediapipe::UnavailableError("No anchor data available."); } anchors_init_ = true; } @@ -1157,4 +1131,5 @@ kernel void scoreKernel( return mediapipe::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc b/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc index 8cb56f264..a95a9da8d 100644 --- a/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "mediapipe/calculators/tensor/tensors_to_floats_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/ret_check.h" @@ -43,47 +44,39 @@ inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); } // input_stream: "TENSORS:tensors" // output_stream: "FLOATS:floats" // } -class TensorsToFloatsCalculator : public CalculatorBase { +namespace api2 { +class TensorsToFloatsCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr Output::Optional kOutFloat{"FLOAT"}; + static constexpr Output>::Optional kOutFloats{"FLOATS"}; + MEDIAPIPE_NODE_INTERFACE(TensorsToFloatsCalculator, kInTensors, kOutFloat, + kOutFloats); - mediapipe::Status Open(CalculatorContext* cc) override; - - mediapipe::Status Process(CalculatorContext* cc) override; + static mediapipe::Status UpdateContract(CalculatorContract* cc); + mediapipe::Status Open(CalculatorContext* cc) final; + mediapipe::Status Process(CalculatorContext* cc) final; private: ::mediapipe::TensorsToFloatsCalculatorOptions options_; }; -REGISTER_CALCULATOR(TensorsToFloatsCalculator); +MEDIAPIPE_REGISTER_NODE(TensorsToFloatsCalculator); -mediapipe::Status TensorsToFloatsCalculator::GetContract( +mediapipe::Status TensorsToFloatsCalculator::UpdateContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("TENSORS")); - RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT")); - - cc->Inputs().Tag("TENSORS").Set>(); - if (cc->Outputs().HasTag("FLOATS")) { - cc->Outputs().Tag("FLOATS").Set>(); - } - if (cc->Outputs().HasTag("FLOAT")) { - cc->Outputs().Tag("FLOAT").Set(); - } - + // Only exactly a single output allowed. 
+ RET_CHECK(kOutFloat(cc).IsConnected() ^ kOutFloats(cc).IsConnected()); return mediapipe::OkStatus(); } mediapipe::Status TensorsToFloatsCalculator::Open(CalculatorContext* cc) { - cc->SetOffset(TimestampDiff(0)); options_ = cc->Options<::mediapipe::TensorsToFloatsCalculatorOptions>(); - return mediapipe::OkStatus(); } mediapipe::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) { - RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty()); - - const auto& input_tensors = - cc->Inputs().Tag("TENSORS").Get>(); + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); // TODO: Add option to specify which tensor to take from. auto view = input_tensors[0].GetCpuReadView(); auto raw_floats = view.buffer(); @@ -100,18 +93,15 @@ mediapipe::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) { break; } - if (cc->Outputs().HasTag("FLOAT")) { - // TODO: Could add an index in the option to specifiy returning one - // value of a float array. + if (kOutFloat(cc).IsConnected()) { + // TODO: Could add an index in the option to specifiy returning + // one value of a float array. RET_CHECK_EQ(num_values, 1); - cc->Outputs().Tag("FLOAT").AddPacket( - MakePacket(output_floats->at(0)).At(cc->InputTimestamp())); + kOutFloat(cc).Send(output_floats->at(0)); + } else { + kOutFloats(cc).Send(std::move(output_floats)); } - if (cc->Outputs().HasTag("FLOATS")) { - cc->Outputs().Tag("FLOATS").Add(output_floats.release(), - cc->InputTimestamp()); - } - return mediapipe::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc index dc4d26a36..ca69d1344 100644 --- a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc @@ -13,12 +13,14 @@ // limitations under the License. 
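Before the landmarks calculator below, note the output-selection pattern used by `TensorsToFloatsCalculator` above: both outputs are declared `::Optional`, the contract check requires exactly one of them to be connected, and `Send` accepts either a plain value or a `unique_ptr`. A sketch with the elided template arguments filled in as assumptions:

```
static constexpr Output<float>::Optional kOutFloat{"FLOAT"};
static constexpr Output<std::vector<float>>::Optional kOutFloats{"FLOATS"};

// In UpdateContract(): exactly one of the two outputs may be wired up.
RET_CHECK(kOutFloat(cc).IsConnected() ^ kOutFloats(cc).IsConnected());

// In Process(): emit a scalar or the whole vector, depending on the graph.
if (kOutFloat(cc).IsConnected()) {
  kOutFloat(cc).Send(output_floats->at(0));
} else {
  kOutFloats(cc).Send(std::move(output_floats));
}
```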
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/ret_check.h" namespace mediapipe { +namespace api2 { namespace { @@ -85,9 +87,18 @@ float ApplyActivation( // } // } // } -class TensorsToLandmarksCalculator : public CalculatorBase { +class TensorsToLandmarksCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr Input::SideFallback::Optional kFlipHorizontally{ + "FLIP_HORIZONTALLY"}; + static constexpr Input::SideFallback::Optional kFlipVertically{ + "FLIP_VERTICALLY"}; + static constexpr Output::Optional kOutLandmarkList{"LANDMARKS"}; + static constexpr Output::Optional + kOutNormalizedLandmarkList{"NORM_LANDMARKS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kFlipHorizontally, kFlipVertically, + kOutLandmarkList, kOutNormalizedLandmarkList); mediapipe::Status Open(CalculatorContext* cc) override; mediapipe::Status Process(CalculatorContext* cc) override; @@ -95,100 +106,39 @@ class TensorsToLandmarksCalculator : public CalculatorBase { private: mediapipe::Status LoadOptions(CalculatorContext* cc); int num_landmarks_ = 0; - bool flip_vertically_ = false; - bool flip_horizontally_ = false; - ::mediapipe::TensorsToLandmarksCalculatorOptions options_; }; -REGISTER_CALCULATOR(TensorsToLandmarksCalculator); - -mediapipe::Status TensorsToLandmarksCalculator::GetContract( - CalculatorContract* cc) { - RET_CHECK(!cc->Inputs().GetTags().empty()); - RET_CHECK(!cc->Outputs().GetTags().empty()); - - if (cc->Inputs().HasTag("TENSORS")) { - cc->Inputs().Tag("TENSORS").Set>(); - } - - if (cc->Inputs().HasTag("FLIP_HORIZONTALLY")) { - cc->Inputs().Tag("FLIP_HORIZONTALLY").Set(); - } - - if (cc->Inputs().HasTag("FLIP_VERTICALLY")) { - cc->Inputs().Tag("FLIP_VERTICALLY").Set(); - } - - if (cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY")) { - cc->InputSidePackets().Tag("FLIP_HORIZONTALLY").Set(); - } - - if (cc->InputSidePackets().HasTag("FLIP_VERTICALLY")) { - cc->InputSidePackets().Tag("FLIP_VERTICALLY").Set(); - } - - if (cc->Outputs().HasTag("LANDMARKS")) { - cc->Outputs().Tag("LANDMARKS").Set(); - } - - if (cc->Outputs().HasTag("NORM_LANDMARKS")) { - cc->Outputs().Tag("NORM_LANDMARKS").Set(); - } - - return mediapipe::OkStatus(); -} +MEDIAPIPE_REGISTER_NODE(TensorsToLandmarksCalculator); mediapipe::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) { - cc->SetOffset(TimestampDiff(0)); - MP_RETURN_IF_ERROR(LoadOptions(cc)); - if (cc->Outputs().HasTag("NORM_LANDMARKS")) { + if (kOutNormalizedLandmarkList(cc).IsConnected()) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) << "Must provide input with/height for getting normalized landmarks."; } - if (cc->Outputs().HasTag("LANDMARKS") && - (options_.flip_vertically() || options_.flip_horizontally() || - cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY") || - cc->InputSidePackets().HasTag("FLIP_VERTICALLY"))) { + if (kOutLandmarkList(cc).IsConnected() && + (options_.flip_horizontally() || options_.flip_vertically() || + kFlipHorizontally(cc).IsConnected() || + kFlipVertically(cc).IsConnected())) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) - << "Must provide input with/height for using flip_vertically option " 
- "when outputing landmarks in absolute coordinates."; + << "Must provide input with/height for using flipping when outputing " + "landmarks in absolute coordinates."; } - - flip_horizontally_ = - cc->InputSidePackets().HasTag("FLIP_HORIZONTALLY") - ? cc->InputSidePackets().Tag("FLIP_HORIZONTALLY").Get() - : options_.flip_horizontally(); - - flip_vertically_ = - cc->InputSidePackets().HasTag("FLIP_VERTICALLY") - ? cc->InputSidePackets().Tag("FLIP_VERTICALLY").Get() - : options_.flip_vertically(); - return mediapipe::OkStatus(); } mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) { - // Override values if specified so. - if (cc->Inputs().HasTag("FLIP_HORIZONTALLY") && - !cc->Inputs().Tag("FLIP_HORIZONTALLY").IsEmpty()) { - flip_horizontally_ = cc->Inputs().Tag("FLIP_HORIZONTALLY").Get(); - } - if (cc->Inputs().HasTag("FLIP_VERTICALLY") && - !cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) { - flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get(); - } - - if (cc->Inputs().Tag("TENSORS").IsEmpty()) { + if (kInTensors(cc).IsEmpty()) { return mediapipe::OkStatus(); } + bool flip_horizontally = + kFlipHorizontally(cc).GetOr(options_.flip_horizontally()); + bool flip_vertically = kFlipVertically(cc).GetOr(options_.flip_vertically()); - const auto& input_tensors = - cc->Inputs().Tag("TENSORS").Get>(); - + const auto& input_tensors = *kInTensors(cc); int num_values = input_tensors[0].shape().num_elements(); const int num_dimensions = num_values / num_landmarks_; CHECK_GT(num_dimensions, 0); @@ -202,13 +152,13 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) { const int offset = ld * num_dimensions; Landmark* landmark = output_landmarks.add_landmark(); - if (flip_horizontally_) { + if (flip_horizontally) { landmark->set_x(options_.input_image_width() - raw_landmarks[offset]); } else { landmark->set_x(raw_landmarks[offset]); } if (num_dimensions > 1) { - if (flip_vertically_) { + if (flip_vertically) { landmark->set_y(options_.input_image_height() - raw_landmarks[offset + 1]); } else { @@ -229,7 +179,7 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) { } // Output normalized landmarks if required. - if (cc->Outputs().HasTag("NORM_LANDMARKS")) { + if (kOutNormalizedLandmarkList(cc).IsConnected()) { NormalizedLandmarkList output_norm_landmarks; for (int i = 0; i < output_landmarks.landmark_size(); ++i) { const Landmark& landmark = output_landmarks.landmark(i); @@ -246,18 +196,12 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) { norm_landmark->set_presence(landmark.presence()); } } - cc->Outputs() - .Tag("NORM_LANDMARKS") - .AddPacket(MakePacket(output_norm_landmarks) - .At(cc->InputTimestamp())); + kOutNormalizedLandmarkList(cc).Send(std::move(output_norm_landmarks)); } // Output absolute landmarks. 
- if (cc->Outputs().HasTag("LANDMARKS")) { - cc->Outputs() - .Tag("LANDMARKS") - .AddPacket(MakePacket(output_landmarks) - .At(cc->InputTimestamp())); + if (kOutLandmarkList(cc).IsConnected()) { + kOutLandmarkList(cc).Send(std::move(output_landmarks)); } return mediapipe::OkStatus(); @@ -272,4 +216,5 @@ mediapipe::Status TensorsToLandmarksCalculator::LoadOptions( return mediapipe::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index fd07bbe34..4b6d244e3 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -396,6 +396,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence_util", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:protos_all_cc", ], @@ -841,6 +842,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/util/sequence:media_sequence", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:protos_all_cc", diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index d4c054681..aece59e5a 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -15,6 +15,7 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" @@ -57,7 +58,7 @@ namespace mpms = mediapipe::mediasequence; // bounding boxes from vector, and streams with the // "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector's // associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints -// from unordered_map>>. "IMAGE_${NAME}", +// from flat_hash_map>>. "IMAGE_${NAME}", // "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of // each stream, which allows for multiple image streams to be included. However, // the default names are suppored by more tools. 
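The `KEYPOINTS` streams described above switch from `std::unordered_map` to `absl::flat_hash_map`; the template arguments are stripped in this rendering, but judging from the updated test further below the payload is a map from keypoint name to a list of 2D points. A sketch of building such an input (values copied from the test):

```
#include "absl/container/flat_hash_map.h"

// Keypoint name -> list of (x, y) pairs, consumed via the KEYPOINTS tag.
absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>> points =
    {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
```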
@@ -131,8 +132,8 @@ class PackMediaSequenceCalculator : public CalculatorBase { } cc->Inputs() .Tag(tag) - .Set>>>(); + .Set>>>(); } if (absl::StartsWith(tag, kBBoxTag)) { std::string key = ""; @@ -348,7 +349,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { const auto& keypoints = cc->Inputs() .Tag(tag) - .Get>>>(); for (const auto& pair : keypoints) { std::string prefix = mpms::merge_prefix(key, pair.first); diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index c71c3173a..19d03ecde 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -14,6 +14,7 @@ #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" @@ -537,8 +538,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) { std::string test_video_id = "test_video_id"; mpms::SetClipMediaId(test_video_id, input_sequence.get()); - std::unordered_map>> points = - {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}}; + absl::flat_hash_map>> + points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}}; runner_->MutableInputs() ->Tag("KEYPOINTS_TEST") .packets.push_back(PointToForeign(&points).At(Timestamp(0))); diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index 8661e0744..2c956d63a 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -450,7 +450,9 @@ mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { } #else #if MEDIAPIPE_TFLITE_METAL_INFERENCE - if (gpu_inference_) { + // Metal delegate supports external command encoder only if all input and + // output buffers are on GPU. 
+ if (gpu_inference_ && gpu_input_ && gpu_output_) { RET_CHECK( TFLGpuDelegateSetCommandEncoder(delegate_.get(), compute_encoder)); } diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index fd81d4de9..e4bbc9145 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -934,11 +934,22 @@ cc_test( ], ) +mediapipe_proto_library( + name = "local_file_contents_calculator_proto", + srcs = ["local_file_contents_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + cc_library( name = "local_file_contents_calculator", srcs = ["local_file_contents_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":local_file_contents_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/calculators/util/local_file_contents_calculator.cc b/mediapipe/calculators/util/local_file_contents_calculator.cc index 2883a961a..b9ec9e496 100644 --- a/mediapipe/calculators/util/local_file_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_contents_calculator.cc @@ -15,6 +15,7 @@ #include #include +#include "mediapipe/calculators/util/local_file_contents_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/util/resource_util.h" @@ -78,6 +79,8 @@ class LocalFileContentsCalculator : public CalculatorBase { mediapipe::Status Open(CalculatorContext* cc) override { CollectionItemId input_id = cc->InputSidePackets().BeginId(kFilePathTag); CollectionItemId output_id = cc->OutputSidePackets().BeginId(kContentsTag); + auto options = cc->Options(); + // Number of inputs and outpus is the same according to the contract. for (; input_id != cc->InputSidePackets().EndId(kFilePathTag); ++input_id, ++output_id) { @@ -86,7 +89,8 @@ class LocalFileContentsCalculator : public CalculatorBase { ASSIGN_OR_RETURN(file_path, PathToResourceAsFile(file_path)); std::string contents; - MP_RETURN_IF_ERROR(GetResourceContents(file_path, &contents)); + MP_RETURN_IF_ERROR( + GetResourceContents(file_path, &contents, options.read_as_binary())); cc->OutputSidePackets().Get(output_id).Set( MakePacket(std::move(contents))); } diff --git a/mediapipe/calculators/util/local_file_contents_calculator.proto b/mediapipe/calculators/util/local_file_contents_calculator.proto new file mode 100644 index 000000000..ca700fc58 --- /dev/null +++ b/mediapipe/calculators/util/local_file_contents_calculator.proto @@ -0,0 +1,28 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
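For reference on the hunk above: the options read in `LocalFileContentsCalculator::Open()` has its template argument stripped in this rendering; presumably it names the new options message defined in the proto below, whose `read_as_binary` flag (default `true`) is forwarded to `GetResourceContents`. A sketch under that assumption:

```
// In LocalFileContentsCalculator::Open(); the template argument is assumed.
const auto& options =
    cc->Options<::mediapipe::LocalFileContentsCalculatorOptions>();
std::string contents;
MP_RETURN_IF_ERROR(
    GetResourceContents(file_path, &contents, options.read_as_binary()));
```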
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message LocalFileContentsCalculatorOptions { + extend CalculatorOptions { + optional LocalFileContentsCalculatorOptions ext = 346849340; + } + + // If true, set the file open mode to 'rb'. Otherwise, set the mode to 'r'. + optional bool read_as_binary = 1 [default = true]; +} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/MainActivity.java index cda1819f5..b3a6dfeea 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/MainActivity.java @@ -56,8 +56,10 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { } catch (NameNotFoundException e) { Log.e(TAG, "Cannot find application info: " + e); } - + // Get allowed object category. String categoryName = applicationInfo.metaData.getString("categoryName"); + // Get maximum allowed number of objects. + int maxNumObjects = applicationInfo.metaData.getInt("maxNumObjects"); float[] modelScale = parseFloatArrayFromString( applicationInfo.metaData.getString("modelScale")); float[] modelTransform = parseFloatArrayFromString( @@ -70,6 +72,7 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { inputSidePackets.put("obj_texture", packetCreator.createRgbaImageFrame(objTexture)); inputSidePackets.put("box_texture", packetCreator.createRgbaImageFrame(boxTexture)); inputSidePackets.put("allowed_labels", packetCreator.createString(categoryName)); + inputSidePackets.put("max_num_objects", packetCreator.createInt32(maxNumObjects)); inputSidePackets.put("model_scale", packetCreator.createFloat32Array(modelScale)); inputSidePackets.put("model_transformation", packetCreator.createFloat32Array(modelTransform)); processor.setInputSidePackets(inputSidePackets); @@ -118,8 +121,8 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { } catch (RuntimeException e) { Log.e( TAG, - "MediaPipeException encountered adding packets to width and height" - + " input streams."); + "MediaPipeException encountered adding packets to input_width and input_height" + + " input streams.", e); } widthPacket.release(); heightPacket.release(); diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCamera.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCamera.xml index 4c4a5b930..10f8492ef 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCamera.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/manifests/AndroidManifestCamera.xml @@ -8,6 +8,7 @@ + + + + Inputs().HasTag(kSelectTag)) { + cc->Inputs().Tag(kSelectTag).Set(); +} +``` + +you can write + +``` +static constexpr Input::Optional kSelect{"SELECT"}; +``` + +Instead of setting up the contract procedurally in `GetContract`, add ports to +the contract declaratively, as follows: + +``` +MEDIAPIPE_NODE_CONTRACT(kInput, kOutput); +``` + +To access an input in Process, instead of + +``` +int select = cc->Inputs().Tag(kSelectTag).Get(); +``` + +write + +``` +int select = kSelectTag(cc).Get(); // alternative: *kSelectTag(cc) +``` + +Sets of multiple 
ports can be declared with `::Multiple`. Note, also, that a tag +string must always be provided when declaring a port; use `""` for untagged +ports. For example: + + +``` +for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).SetAny(); +} +``` + +becomes + +``` +static constexpr Input::Multiple kIn{""}; +``` + +For output ports, the payload can be passed directly to the `Send` method. For +example, instead of + +``` +cc->Outputs().Index(0).Add( + new std::pair(cc->Inputs().Index(0).Value(), + cc->Inputs().Index(1).Value()), + cc->InputTimestamp()); +``` + +you can write + +``` +kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()}); +``` + +The input timestamp is propagated to the outputs by default. If your calculator +wants to alter timestamps, it must add a `TimestampChange` entry to its contract +declaration. For example: + +``` +MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop, + StreamHandler("ImmediateInputStreamHandler"), + TimestampChange::Arbitrary()); +``` + +Several calculators in +[`calculators/core`](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core) and +[`calculators/tensor`](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/tensor) +have been updated to use this API. Reference them for more examples. + +More complete documentation will be provided in the future. + +## Builder API + +Documentation will be provided in the future. diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h new file mode 100644 index 000000000..10ad555a3 --- /dev/null +++ b/mediapipe/framework/api2/builder.h @@ -0,0 +1,576 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_ +#define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "mediapipe/framework/api2/const_str.h" +#include "mediapipe/framework/api2/contract.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_contract.h" + +namespace mediapipe { +namespace api2 { +namespace builder { + +template +T& GetWithAutoGrow(std::vector>* vecp, int index) { + auto& vec = *vecp; + if (vec.size() <= index) { + vec.resize(index + 1); + } + if (vec[index] == nullptr) { + vec[index] = absl::make_unique(); + } + return *vec[index]; +} + +struct TagIndexLocation { + const std::string& tag; + std::size_t index; + std::size_t count; +}; + +template +class TagIndexMap { + public: + std::vector>& operator[](const std::string& tag) { + return map_[tag]; + } + + void Visit(std::function fun) const { + for (const auto& tagged : map_) { + TagIndexLocation loc{tagged.first, 0, tagged.second.size()}; + for (const auto& item : tagged.second) { + fun(loc, *item); + ++loc.index; + } + } + } + + void Visit(std::function fun) { + for (auto& tagged : map_) { + TagIndexLocation loc{tagged.first, 0, tagged.second.size()}; + for (auto& item : tagged.second) { + fun(loc, item.get()); + ++loc.index; + } + } + } + + // Note: entries are held by a unique_ptr to ensure pointers remain valid. + // Should use absl::flat_hash_map but ordering keys for now. + std::map>> map_; +}; + +// These structs are used internally to store information about the endpoints +// of a connection. 
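The Builder API section above defers its documentation, so here is a rough usage sketch based on the classes in this header. The `Graph` class and its `AddNode`/`GetConfig` methods are defined later in the file, beyond the excerpt shown here, so treat those exact names as assumptions; the connection direction for graph ports follows the comment further below (graph inputs act as sources, graph outputs as destinations).

```
// Hedged sketch of intended builder usage; Graph method names are assumptions.
using namespace mediapipe::api2::builder;

Graph graph;
auto& node = graph.AddNode("SomeCalculator");   // generic, untyped node
graph.In("IMAGE") >> node.In("IMAGE");          // graph input -> node input
node.Out("TENSORS") >> graph.Out("TENSORS");    // node output -> graph output
CalculatorGraphConfig config = graph.GetConfig();
```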
+struct SourceBase; +struct DestinationBase { + SourceBase* source = nullptr; +}; +struct SourceBase { + std::vector dests_; + std::string name_; +}; + +// Following existing GraphConfig usage, we allow using a multiport as a single +// port as well. This is necessary for generic nodes, since we have no +// information about which ports are meant to be multiports or not, but it is +// also convenient with typed nodes. +template +class MultiPort : public Single { + public: + using Base = typename Single::Base; + + explicit MultiPort(std::vector>* vec) + : Single(vec), vec_(*vec) {} + + Single operator[](int index) { + CHECK_GE(index, 0); + return Single{&GetWithAutoGrow(&vec_, index)}; + } + + private: + std::vector>& vec_; +}; + +// These classes wrap references to the underlying source/destination +// endpoints, adding type information and the user-visible API. +template +class DestinationImpl { + public: + using Base = DestinationBase; + + explicit DestinationImpl(std::vector>* vec) + : DestinationImpl(&GetWithAutoGrow(vec, 0)) {} + explicit DestinationImpl(DestinationBase* base) : base_(*base) {} + DestinationBase& base_; +}; + +template +class DestinationImpl + : public MultiPort> { + public: + using MultiPort>::MultiPort; +}; + +template +class SourceImpl { + public: + using Base = SourceBase; + + // Src is used as the return type of fluent methods below. Since these are + // single-port methods, it is desirable to always decay to a reference to the + // single-port superclass, even if they are called on a multiport. + using Src = SourceImpl; + template + using Dst = DestinationImpl; + + // clang-format off + template + struct AllowConnection : public std::integral_constant{} || std::is_same{} || + std::is_same{}> {}; + // clang-format on + + explicit SourceImpl(std::vector>* vec) + : SourceImpl(&GetWithAutoGrow(vec, 0)) {} + explicit SourceImpl(SourceBase* base) : base_(*base) {} + + template {}, int>::type = 0> + Src& AddTarget(const Dst& dest) { + CHECK(dest.base_.source == nullptr); + dest.base_.source = &base_; + base_.dests_.emplace_back(&dest.base_); + return *this; + } + Src& SetName(std::string name) { + base_.name_ = std::move(name); + return *this; + } + template + Src& operator>>(const Dst& dest) { + return AddTarget(dest); + } + + private: + SourceBase& base_; +}; + +template +class SourceImpl + : public MultiPort> { + public: + using MultiPort>::MultiPort; +}; + +// A source and a destination correspond to an output/input stream on a node, +// and a side source and side destination correspond to an output/input side +// packet. +// For graph inputs/outputs, however, the inputs are sources, and the outputs +// are destinations. This is because graph ports are connected "from inside" +// when building the graph. +template +using Source = SourceImpl; +template +using SideSource = SourceImpl; +template +using Destination = DestinationImpl; +template +using SideDestination = DestinationImpl; + +class NodeBase { + public: + // TODO: right now access to an indexed port is made directly by + // specifying both a tag and an index. It would be better to represent this + // as a two-step lookup, first getting a multi-port, and then accessing one + // of its entries by index. However, for nodes without visible contracts we + // can't know whether a tag is indexable or not, so we would need the + // multi-port to also be usable as a port directly (representing index 0). 
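The `SourceImpl` helpers above make connections fluent: `SetName()` returns the source itself, and `operator>>` forwards to `AddTarget()`. Combined with the `Out()`/`In()`/`SideOut()`/`SideIn()` accessors defined next, a connection between two builder nodes can be written as a single expression (the node variables here are hypothetical):

```
// Name the output stream, then route it into another node's input.
node.Out("TENSORS").SetName("tensors") >> other_node.In("TENSORS");

// Side packets connect the same way through the side-port accessors.
node.SideOut("MODEL") >> other_node.SideIn("MODEL");
```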
+ Source Out(const std::string& tag) { + return Source(&out_streams_[tag]); + } + + Destination In(const std::string& tag) { + return Destination(&in_streams_[tag]); + } + + SideSource SideOut(const std::string& tag) { + return SideSource(&out_sides_[tag]); + } + + SideDestination SideIn(const std::string& tag) { + return SideDestination(&in_sides_[tag]); + } + + // Convenience methods for accessing purely index-based ports. + Source Out(int index) { return Out("")[index]; } + + Destination In(int index) { return In("")[index]; } + + SideSource SideOut(int index) { return SideOut("")[index]; } + + SideDestination SideIn(int index) { return SideIn("")[index]; } + + template + T& GetOptions() { + options_used_ = true; + return *options_.MutableExtension(T::ext); + } + + protected: + NodeBase(std::string type) : type_(std::move(type)) {} + + std::string type_; + TagIndexMap in_streams_; + TagIndexMap out_streams_; + TagIndexMap in_sides_; + TagIndexMap out_sides_; + CalculatorOptions options_; + // ideally we'd just check if any extensions are set on options_ + bool options_used_ = false; + friend class Graph; +}; + +template +class Node; +#if __cplusplus >= 201703L +// Deduction guide to silence -Wctad-maybe-unsupported. +explicit Node()->Node; +#endif // C++17 + +template <> +class Node : public NodeBase { + public: + Node(std::string type) : NodeBase(std::move(type)) {} +}; + +using GenericNode = Node; + +template