// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Defines CalculatorState.

#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_STATE_H_
#define MEDIAPIPE_FRAMEWORK_CALCULATOR_STATE_H_

#include <map>
#include <memory>
#include <string>
#include <utility>

// TODO: Move protos in another CL after the C++ code migration.
#include "absl/base/macros.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/counter.h"
#include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/any_proto.h"
#include "mediapipe/framework/tool/options_map.h"

namespace mediapipe {

class ProfilingContext;

// Holds data that the Calculator needs access to. This data is not
// stored in Calculator directly since Calculator will be destroyed after
// every CalculatorGraph::Run() . It is not stored in CalculatorNode
// because Calculator should not depend on CalculatorNode. All
// information conveyed in this class is flowing from the CalculatorNode
// to the Calculator.
class CalculatorState { public: CalculatorState(const std::string& node_name, int node_id, const std::string& calculator_type, const CalculatorGraphConfig::Node& node_config, std::shared_ptr profiling_context); CalculatorState(const CalculatorState&) = delete; CalculatorState& operator=(const CalculatorState&) = delete; ~CalculatorState(); // Called before every call to Calculator::Open() (during the PrepareForRun // phase). void ResetBetweenRuns(); const std::string& CalculatorType() const { return calculator_type_; } const CalculatorOptions& Options() const { return node_config_.options(); } // Returns the options given to this calculator. Template argument T must // be the type of the protobuf extension message or the protobuf::Any // message containing the options. template const T& Options() const { return options_.Get(); } const std::string& NodeName() const { return node_name_; } const int& NodeId() const { return node_id_; } //////////////////////////////////////// // Interface for Calculator. //////////////////////////////////////// const PacketSet& InputSidePackets() const { return *input_side_packets_; } OutputSidePacketSet& OutputSidePackets() { return *output_side_packets_; } // Returns a counter using the graph's counter factory. The counter's // name is the passed-in name, prefixed by the calculator NodeName. Counter* GetCounter(const std::string& name); // Returns a counter set, which can be passed to other classes, to generate // counters. NOTE: This differs from GetCounter, in that the counters // created by this counter set do not have the NodeName prefix. CounterFactory* GetCounterFactory(); std::shared_ptr GetSharedProfilingContext() const { return profiling_context_; } //////////////////////////////////////// // Interface for CalculatorNode. //////////////////////////////////////// // Sets the input side packets. void SetInputSidePackets(const PacketSet* input_side_packets); // Sets the output side packets. 
void SetOutputSidePackets(OutputSidePacketSet* output_side_packets); // Sets the counter factory. void SetCounterFactory(CounterFactory* counter_factory) { counter_factory_ = counter_factory; } absl::Status SetServicePacket(const GraphServiceBase& service, Packet packet) { return graph_service_manager_.SetServicePacket(service, packet); } template std::shared_ptr GetServiceObject(const GraphService& service) { return graph_service_manager_.GetServiceObject(service); } private: //////////////////////////////////////// // Persistent variables that are not cleared by ResetBetweenRuns(). //////////////////////////////////////// // The name associated with this calculator's node. const std::string node_name_; // The ID associated with this calculator's node. const int node_id_; // The registered type name of the Calculator. const std::string calculator_type_; // The Node protobuf containing the options for the calculator. const CalculatorGraphConfig::Node node_config_; // The unpacked protobuf options for the calculator. tool::OptionsMap options_; // The graph tracing and profiling interface. std::shared_ptr profiling_context_; GraphServiceManager graph_service_manager_; //////////////////////////////////////// // Variables which ARE cleared by ResetBetweenRuns(). //////////////////////////////////////// // The set of input side packets set by CalculatorNode::PrepareForRun(). // ResetBetweenRuns() clears this PacketSet pointer. const PacketSet* input_side_packets_; // The OutputSidePacketSet object is owned by the CalculatorNode. // CalculatorState obtains its pointer in CalculatorNode::PrepareForRun. OutputSidePacketSet* output_side_packets_; CounterFactory* counter_factory_; }; } // namespace mediapipe #endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_STATE_H_