From 97af47ebf55e910b5c2125cba2f878e396be1b14 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Wed, 18 Jan 2023 18:51:17 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 503035081 --- mediapipe/framework/tool/switch/BUILD | 26 +++++ .../framework/tool/switch/graph_processor.cc | 110 ++++++++++++++++++ .../framework/tool/switch/graph_processor.h | 59 ++++++++++ .../framework/tool/switch/packet_processor.h | 2 +- 4 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 mediapipe/framework/tool/switch/graph_processor.cc create mode 100644 mediapipe/framework/tool/switch/graph_processor.h diff --git a/mediapipe/framework/tool/switch/BUILD b/mediapipe/framework/tool/switch/BUILD index 62f9095ef..e7a3ba741 100644 --- a/mediapipe/framework/tool/switch/BUILD +++ b/mediapipe/framework/tool/switch/BUILD @@ -32,3 +32,29 @@ cc_library( "//mediapipe/framework/port:status", ], ) + +cc_library( + name = "graph_processor", + srcs = ["graph_processor.cc"], + hdrs = ["graph_processor.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":packet_processor", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework:input_stream_shard", + "//mediapipe/framework:output_stream_shard", + "//mediapipe/framework:validated_graph_config", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) diff --git a/mediapipe/framework/tool/switch/graph_processor.cc b/mediapipe/framework/tool/switch/graph_processor.cc new file mode 100644 index 000000000..f35730761 --- /dev/null +++ b/mediapipe/framework/tool/switch/graph_processor.cc @@ -0,0 +1,110 @@ +#include "mediapipe/framework/tool/switch/graph_processor.h" + +#include "absl/synchronization/mutex.h" + +namespace mediapipe { + +// TODO: add support for input and output side packets. +absl::Status GraphProcessor::Initialize(CalculatorGraphConfig graph_config) { + graph_config_ = graph_config; + + ASSIGN_OR_RETURN(graph_input_map_, + tool::TagMap::Create(graph_config_.input_stream())); + ASSIGN_OR_RETURN(graph_output_map_, + tool::TagMap::Create(graph_config_.output_stream())); + return absl::OkStatus(); +} + +absl::Status GraphProcessor::AddPacket(CollectionItemId id, Packet packet) { + absl::MutexLock lock(&graph_mutex_); + const std::string& stream_name = graph_input_map_->Names().at(id.value()); + return graph_->AddPacketToInputStream(stream_name, packet); +} + +std::shared_ptr GraphProcessor::InputTags() { + return graph_input_map_; +} + +absl::Status GraphProcessor::SendPacket(CollectionItemId id, Packet packet) { + MP_RETURN_IF_ERROR(WaitUntilInitialized()); + auto it = consumer_ids_.find(id); + if (it == consumer_ids_.end()) { + return absl::NotFoundError( + absl::StrCat("Consumer stream not found: ", id.value())); + } + return consumer_->AddPacket(it->second, packet); +} + +void GraphProcessor::SetConsumer(PacketConsumer* consumer) { + absl::MutexLock lock(&graph_mutex_); + consumer_ = consumer; + auto input_map = consumer_->InputTags(); + for (auto id = input_map->BeginId(); id != input_map->EndId(); ++id) { + auto tag_index = input_map->TagAndIndexFromId(id); + auto stream_id = graph_input_map_->GetId(tag_index.first, tag_index.second); + consumer_ids_[stream_id] = id; + } +} + +absl::Status GraphProcessor::ObserveGraph() { + for (auto id = graph_output_map_->BeginId(); id != graph_output_map_->EndId(); + ++id) { + std::string stream_name = graph_output_map_->Names().at(id.value()); + MP_RETURN_IF_ERROR(graph_->ObserveOutputStream( + stream_name, + [this, id](const Packet& packet) { return SendPacket(id, packet); }, + true)); + } + return absl::OkStatus(); +} + +absl::Status GraphProcessor::WaitUntilInitialized() { + absl::MutexLock lock(&graph_mutex_); + auto is_initialized = [this]() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) { + return graph_ != nullptr && consumer_ != nullptr; + }; + graph_mutex_.AwaitWithTimeout(absl::Condition(&is_initialized), + absl::Seconds(4)); + RET_CHECK(is_initialized()) << "GraphProcessor initialization timed out."; + return absl::OkStatus(); +} + +absl::Status GraphProcessor::Start() { + absl::MutexLock lock(&graph_mutex_); + graph_ = std::make_unique(); + + // The graph is validated here with its specified inputs and output. + MP_RETURN_IF_ERROR(graph_->Initialize(graph_config_, side_packets_)); + MP_RETURN_IF_ERROR(ObserveGraph()); + MP_RETURN_IF_ERROR(graph_->StartRun({})); + return absl::OkStatus(); +} + +absl::Status GraphProcessor::Shutdown() { + absl::MutexLock lock(&graph_mutex_); + if (!graph_) { + return absl::OkStatus(); + } + MP_RETURN_IF_ERROR(graph_->CloseAllPacketSources()); + MP_RETURN_IF_ERROR(graph_->WaitUntilDone()); + graph_ = nullptr; + return absl::OkStatus(); +} + +absl::Status GraphProcessor::WaitUntilIdle() { + absl::MutexLock lock(&graph_mutex_); + return graph_->WaitUntilIdle(); +} + +// TODO +absl::Status GraphProcessor::SetSidePacket(CollectionItemId id, Packet packet) { + return absl::OkStatus(); +} +// TODO +std::shared_ptr GraphProcessor::SideInputTags() { + return nullptr; +} +// TODO +void GraphProcessor::SetSideConsumer(SidePacketConsumer* consumer) {} + +} // namespace mediapipe diff --git a/mediapipe/framework/tool/switch/graph_processor.h b/mediapipe/framework/tool/switch/graph_processor.h new file mode 100644 index 000000000..e2220b5dc --- /dev/null +++ b/mediapipe/framework/tool/switch/graph_processor.h @@ -0,0 +1,59 @@ +#ifndef MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ +#define MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/switch/packet_processor.h" + +namespace mediapipe { + +// Processes MediaPipe Packets using a MediaPipe CalculatorGraph. +class GraphProcessor : public PacketProcessor { + public: + GraphProcessor() = default; + + // Configures this GraphProcessor to create a run a CalculatorGraph. + absl::Status Initialize(CalculatorGraphConfig graph_config); + + public: + // The PacketProcessor interface. + absl::Status AddPacket(CollectionItemId id, Packet packet) override; + std::shared_ptr InputTags() override; + absl::Status SetSidePacket(CollectionItemId id, Packet packet) override; + std::shared_ptr SideInputTags() override; + void SetConsumer(PacketConsumer* consumer) override; + void SetSideConsumer(SidePacketConsumer* consumer) override; + absl::Status Start() override; + absl::Status Shutdown() override; + absl::Status WaitUntilIdle() override; + + private: + // Sends a tagged output packet. + absl::Status SendPacket(CollectionItemId id, Packet packet); + + // Observes output packets from the calculator graph. + absl::Status ObserveGraph() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_); + + // Blocks until this GraphProcessor is initialized. + absl::Status WaitUntilInitialized(); + + private: + CalculatorGraphConfig graph_config_; + std::shared_ptr graph_input_map_; + std::shared_ptr graph_output_map_; + std::map consumer_ids_; + + PacketConsumer* consumer_ = nullptr; + std::map side_packets_; + std::unique_ptr graph_ ABSL_GUARDED_BY(graph_mutex_) = + nullptr; + absl::Mutex graph_mutex_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ diff --git a/mediapipe/framework/tool/switch/packet_processor.h b/mediapipe/framework/tool/switch/packet_processor.h index 1789a46c5..d97883c53 100644 --- a/mediapipe/framework/tool/switch/packet_processor.h +++ b/mediapipe/framework/tool/switch/packet_processor.h @@ -56,7 +56,7 @@ class SidePacketConsumer { virtual std::shared_ptr SideInputTags() = 0; }; -// SidePacketProducer deleivers several tagged constant packets. +// SidePacketProducer delivers several tagged constant packets. class SidePacketProducer { public: virtual ~SidePacketProducer() = default;