// Copyright 2019 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Definitions for CalculatorRunner. #include "mediapipe/framework/calculator_runner.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { const char CalculatorRunner::kSourcePrefix[] = "source_for_"; const char CalculatorRunner::kSinkPrefix[] = "sink_for_"; namespace { // Calculator generating a stream with the given contents. // Inputs: none // Outputs: 1, with the contents provided via the input side packet. // Input side packets: 1, pointing to CalculatorRunner::StreamContents. class CalculatorRunnerSourceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets() .Index(0) .Set(); cc->Outputs().Index(0).SetAny(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { const auto* contents = cc->InputSidePackets() .Index(0) .Get(); // Set the header and packets of the output stream. cc->Outputs().Index(0).SetHeader(contents->header); for (const Packet& packet : contents->packets) { cc->Outputs().Index(0).AddPacket(packet); } return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { return tool::StatusStop(); } }; REGISTER_CALCULATOR(CalculatorRunnerSourceCalculator); // Calculator recording the contents of a stream. // Inputs: 1, with the contents written to the input side packet. // Outputs: none // Input side packets: 1, pointing to CalculatorRunner::StreamContents. class CalculatorRunnerSinkCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->InputSidePackets().Index(0).Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { contents_ = cc->InputSidePackets() .Index(0) .Get(); contents_->header = cc->Inputs().Index(0).Header(); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { contents_->packets.push_back(cc->Inputs().Index(0).Value()); return absl::OkStatus(); } private: CalculatorRunner::StreamContents* contents_ = nullptr; }; REGISTER_CALCULATOR(CalculatorRunnerSinkCalculator); } // namespace CalculatorRunner::CalculatorRunner( const CalculatorGraphConfig::Node& node_config) { MEDIAPIPE_CHECK_OK(InitializeFromNodeConfig(node_config)); } absl::Status CalculatorRunner::InitializeFromNodeConfig( const CalculatorGraphConfig::Node& node_config) { node_config_ = node_config; if (node_config_.external_input_size() > 0) { RET_CHECK_EQ(0, node_config_.input_side_packet_size()) << "Only one of input_side_packet or (deprecated) external_input can " "be set."; node_config_.mutable_external_input()->Swap( node_config_.mutable_input_side_packet()); } ASSIGN_OR_RETURN(auto input_map, tool::TagMap::Create(node_config_.input_stream())); inputs_ = absl::make_unique(input_map); ASSIGN_OR_RETURN(auto output_map, tool::TagMap::Create(node_config_.output_stream())); outputs_ = absl::make_unique(output_map); ASSIGN_OR_RETURN(auto input_side_map, tool::TagMap::Create(node_config_.input_side_packet())); input_side_packets_ = absl::make_unique(input_side_map); ASSIGN_OR_RETURN(auto output_side_map, tool::TagMap::Create(node_config_.output_side_packet())); output_side_packets_ = absl::make_unique(output_side_map); return absl::OkStatus(); } CalculatorRunner::CalculatorRunner(const std::string& calculator_type, const CalculatorOptions& options) { node_config_.set_calculator(calculator_type); *node_config_.mutable_options() = options; log_calculator_proto_ = true; } #if !defined(MEDIAPIPE_PROTO_LITE) CalculatorRunner::CalculatorRunner(const std::string& node_config_string) { CalculatorGraphConfig::Node node_config; CHECK( proto_ns::TextFormat::ParseFromString(node_config_string, &node_config)); MEDIAPIPE_CHECK_OK(InitializeFromNodeConfig(node_config)); } CalculatorRunner::CalculatorRunner(const std::string& calculator_type, const std::string& options_string, int num_inputs, int num_outputs, int num_side_packets) { node_config_.set_calculator(calculator_type); CHECK(proto_ns::TextFormat::ParseFromString(options_string, node_config_.mutable_options())); SetNumInputs(num_inputs); SetNumOutputs(num_outputs); SetNumInputSidePackets(num_side_packets); // Reset log_calculator_proto to false, since it was set to true by // SetNum*() calls above. This constructor is not deprecated but is // currently implemented in terms of deprecated functions. log_calculator_proto_ = false; } #endif CalculatorRunner::~CalculatorRunner() {} void CalculatorRunner::SetNumInputs(int n) { tool::TagAndNameInfo info; for (int i = 0; i < n; ++i) { info.names.push_back(absl::StrCat("input_", i)); } InitializeInputs(info); } void CalculatorRunner::SetNumOutputs(int n) { tool::TagAndNameInfo info; for (int i = 0; i < n; ++i) { info.names.push_back(absl::StrCat("output_", i)); } InitializeOutputs(info); } void CalculatorRunner::SetNumInputSidePackets(int n) { tool::TagAndNameInfo info; for (int i = 0; i < n; ++i) { info.names.push_back(absl::StrCat("side_packet_", i)); } InitializeInputSidePackets(info); } void CalculatorRunner::InitializeInputs(const tool::TagAndNameInfo& info) { CHECK(graph_ == nullptr); MEDIAPIPE_CHECK_OK( tool::SetFromTagAndNameInfo(info, node_config_.mutable_input_stream())); inputs_.reset(new StreamContentsSet(info)); log_calculator_proto_ = true; } void CalculatorRunner::InitializeOutputs(const tool::TagAndNameInfo& info) { CHECK(graph_ == nullptr); MEDIAPIPE_CHECK_OK( tool::SetFromTagAndNameInfo(info, node_config_.mutable_output_stream())); outputs_.reset(new StreamContentsSet(info)); log_calculator_proto_ = true; } void CalculatorRunner::InitializeInputSidePackets( const tool::TagAndNameInfo& info) { CHECK(graph_ == nullptr); MEDIAPIPE_CHECK_OK(tool::SetFromTagAndNameInfo( info, node_config_.mutable_input_side_packet())); input_side_packets_.reset(new PacketSet(info)); log_calculator_proto_ = true; } mediapipe::Counter* CalculatorRunner::GetCounter(const std::string& name) { return graph_->GetCounterFactory()->GetCounter(name); } std::map CalculatorRunner::GetCountersValues() { return graph_->GetCounterFactory()->GetCounterSet()->GetCountersValues(); } absl::Status CalculatorRunner::BuildGraph() { if (graph_ != nullptr) { // The graph was already built. return absl::OkStatus(); } RET_CHECK(inputs_) << "The inputs were not initialized."; RET_CHECK(outputs_) << "The outputs were not initialized."; RET_CHECK(input_side_packets_) << "The input side packets were not initialized."; CalculatorGraphConfig config; // Add the calculator node. *(config.add_node()) = node_config_; for (int i = 0; i < node_config_.input_stream_size(); ++i) { std::string name; std::string tag; int index; MP_RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.input_stream(i), &tag, &index, &name)); // Add a source for each input stream. auto* node = config.add_node(); node->set_calculator("CalculatorRunnerSourceCalculator"); node->add_output_stream(name); node->add_input_side_packet(absl::StrCat(kSourcePrefix, name)); } for (int i = 0; i < node_config_.output_stream_size(); ++i) { std::string name; std::string tag; int index; MP_RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.output_stream(i), &tag, &index, &name)); // Add a sink for each output stream. auto* node = config.add_node(); node->set_calculator("CalculatorRunnerSinkCalculator"); node->add_input_stream(name); node->add_input_side_packet(absl::StrCat(kSinkPrefix, name)); } config.set_num_threads(1); if (log_calculator_proto_) { #if defined(MEDIAPIPE_PROTO_LITE) LOG(INFO) << "Please initialize CalculatorRunner using the recommended " "constructor:\n CalculatorRunner runner(node_config);"; #else std::string config_string; proto_ns::TextFormat::Printer printer; printer.SetInitialIndentLevel(4); printer.PrintToString(node_config_, &config_string); LOG(INFO) << "Please initialize CalculatorRunner using the recommended " "constructor:\n CalculatorRunner runner(R\"(\n" << config_string << "\n )\");"; #endif } graph_ = absl::make_unique(); MP_RETURN_IF_ERROR(graph_->Initialize(config)); return absl::OkStatus(); } absl::Status CalculatorRunner::Run() { MP_RETURN_IF_ERROR(BuildGraph()); // Set the input side packets for the sources. std::map input_side_packets; int positional_index = -1; for (int i = 0; i < node_config_.input_stream_size(); ++i) { std::string name; std::string tag; int index; MP_RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.input_stream(i), &tag, &index, &name)); const CalculatorRunner::StreamContents* contents; if (index == -1) { // positional_index considers the case when the tag is empty, which is // always the case when index == -1. If we ever support indices for // non-empty tags ("ABC:input1" and "ABC:input2" with automatic indices), // this should be changed to use a map insted. contents = &inputs_->Get(tag, ++positional_index); } else { contents = &inputs_->Get(tag, index); } input_side_packets.emplace(absl::StrCat(kSourcePrefix, name), Adopt(new auto(contents))); } // Set the input side packets for the calculator. positional_index = -1; for (int i = 0; i < node_config_.input_side_packet_size(); ++i) { std::string name; std::string tag; int index; MP_RETURN_IF_ERROR(tool::ParseTagIndexName( node_config_.input_side_packet(i), &tag, &index, &name)); const Packet* packet; if (index == -1) { packet = &input_side_packets_->Get(tag, ++positional_index); } else { packet = &input_side_packets_->Get(tag, index); } input_side_packets.emplace(name, *packet); } // Set the input side packets for the sinks. positional_index = -1; for (int i = 0; i < node_config_.output_stream_size(); ++i) { std::string name; std::string tag; int index; MP_RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.output_stream(i), &tag, &index, &name)); CalculatorRunner::StreamContents* contents; if (index == -1) { contents = &outputs_->Get(tag, ++positional_index); } else { contents = &outputs_->Get(tag, index); } // Clear |contents| because Run() may be called multiple times. *contents = CalculatorRunner::StreamContents(); input_side_packets.emplace(absl::StrCat(kSinkPrefix, name), Adopt(new auto(contents))); } MP_RETURN_IF_ERROR(graph_->Run(input_side_packets)); positional_index = -1; for (int i = 0; i < node_config_.output_side_packet_size(); ++i) { std::string name; std::string tag; int index; MP_RETURN_IF_ERROR(tool::ParseTagIndexName( node_config_.output_side_packet(i), &tag, &index, &name)); Packet& contents = output_side_packets_->Get( tag, (index == -1) ? ++positional_index : index); ASSIGN_OR_RETURN(contents, graph_->GetOutputSidePacket(name)); } return absl::OkStatus(); } } // namespace mediapipe