Internal change

PiperOrigin-RevId: 503035081
This commit is contained in:
Hadon Nash 2023-01-18 18:51:17 -08:00 committed by Copybara-Service
parent 66634bbef8
commit 97af47ebf5
4 changed files with 196 additions and 1 deletions

View File

@ -32,3 +32,29 @@ cc_library(
"//mediapipe/framework/port:status",
],
)
cc_library(
name = "graph_processor",
srcs = ["graph_processor.cc"],
hdrs = ["graph_processor.h"],
visibility = [
"//visibility:public",
],
deps = [
":packet_processor",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:input_stream_shard",
"//mediapipe/framework:output_stream_shard",
"//mediapipe/framework:validated_graph_config",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"@com_google_absl//absl/synchronization",
],
alwayslink = 1,
)

View File

@ -0,0 +1,110 @@
#include "mediapipe/framework/tool/switch/graph_processor.h"
#include "absl/synchronization/mutex.h"
namespace mediapipe {
// TODO: add support for input and output side packets.
absl::Status GraphProcessor::Initialize(CalculatorGraphConfig graph_config) {
graph_config_ = graph_config;
ASSIGN_OR_RETURN(graph_input_map_,
tool::TagMap::Create(graph_config_.input_stream()));
ASSIGN_OR_RETURN(graph_output_map_,
tool::TagMap::Create(graph_config_.output_stream()));
return absl::OkStatus();
}
absl::Status GraphProcessor::AddPacket(CollectionItemId id, Packet packet) {
absl::MutexLock lock(&graph_mutex_);
const std::string& stream_name = graph_input_map_->Names().at(id.value());
return graph_->AddPacketToInputStream(stream_name, packet);
}
std::shared_ptr<tool::TagMap> GraphProcessor::InputTags() {
return graph_input_map_;
}
absl::Status GraphProcessor::SendPacket(CollectionItemId id, Packet packet) {
MP_RETURN_IF_ERROR(WaitUntilInitialized());
auto it = consumer_ids_.find(id);
if (it == consumer_ids_.end()) {
return absl::NotFoundError(
absl::StrCat("Consumer stream not found: ", id.value()));
}
return consumer_->AddPacket(it->second, packet);
}
void GraphProcessor::SetConsumer(PacketConsumer* consumer) {
absl::MutexLock lock(&graph_mutex_);
consumer_ = consumer;
auto input_map = consumer_->InputTags();
for (auto id = input_map->BeginId(); id != input_map->EndId(); ++id) {
auto tag_index = input_map->TagAndIndexFromId(id);
auto stream_id = graph_input_map_->GetId(tag_index.first, tag_index.second);
consumer_ids_[stream_id] = id;
}
}
absl::Status GraphProcessor::ObserveGraph() {
for (auto id = graph_output_map_->BeginId(); id != graph_output_map_->EndId();
++id) {
std::string stream_name = graph_output_map_->Names().at(id.value());
MP_RETURN_IF_ERROR(graph_->ObserveOutputStream(
stream_name,
[this, id](const Packet& packet) { return SendPacket(id, packet); },
true));
}
return absl::OkStatus();
}
absl::Status GraphProcessor::WaitUntilInitialized() {
absl::MutexLock lock(&graph_mutex_);
auto is_initialized = [this]() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) {
return graph_ != nullptr && consumer_ != nullptr;
};
graph_mutex_.AwaitWithTimeout(absl::Condition(&is_initialized),
absl::Seconds(4));
RET_CHECK(is_initialized()) << "GraphProcessor initialization timed out.";
return absl::OkStatus();
}
absl::Status GraphProcessor::Start() {
absl::MutexLock lock(&graph_mutex_);
graph_ = std::make_unique<CalculatorGraph>();
// The graph is validated here with its specified inputs and output.
MP_RETURN_IF_ERROR(graph_->Initialize(graph_config_, side_packets_));
MP_RETURN_IF_ERROR(ObserveGraph());
MP_RETURN_IF_ERROR(graph_->StartRun({}));
return absl::OkStatus();
}
absl::Status GraphProcessor::Shutdown() {
absl::MutexLock lock(&graph_mutex_);
if (!graph_) {
return absl::OkStatus();
}
MP_RETURN_IF_ERROR(graph_->CloseAllPacketSources());
MP_RETURN_IF_ERROR(graph_->WaitUntilDone());
graph_ = nullptr;
return absl::OkStatus();
}
absl::Status GraphProcessor::WaitUntilIdle() {
absl::MutexLock lock(&graph_mutex_);
return graph_->WaitUntilIdle();
}
// TODO
absl::Status GraphProcessor::SetSidePacket(CollectionItemId id, Packet packet) {
return absl::OkStatus();
}
// TODO
std::shared_ptr<tool::TagMap> GraphProcessor::SideInputTags() {
return nullptr;
}
// TODO
void GraphProcessor::SetSideConsumer(SidePacketConsumer* consumer) {}
} // namespace mediapipe

View File

@ -0,0 +1,59 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
#include <memory>
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/switch/packet_processor.h"
namespace mediapipe {
// Processes MediaPipe Packets using a MediaPipe CalculatorGraph.
class GraphProcessor : public PacketProcessor {
public:
GraphProcessor() = default;
// Configures this GraphProcessor to create a run a CalculatorGraph.
absl::Status Initialize(CalculatorGraphConfig graph_config);
public:
// The PacketProcessor interface.
absl::Status AddPacket(CollectionItemId id, Packet packet) override;
std::shared_ptr<tool::TagMap> InputTags() override;
absl::Status SetSidePacket(CollectionItemId id, Packet packet) override;
std::shared_ptr<tool::TagMap> SideInputTags() override;
void SetConsumer(PacketConsumer* consumer) override;
void SetSideConsumer(SidePacketConsumer* consumer) override;
absl::Status Start() override;
absl::Status Shutdown() override;
absl::Status WaitUntilIdle() override;
private:
// Sends a tagged output packet.
absl::Status SendPacket(CollectionItemId id, Packet packet);
// Observes output packets from the calculator graph.
absl::Status ObserveGraph() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_);
// Blocks until this GraphProcessor is initialized.
absl::Status WaitUntilInitialized();
private:
CalculatorGraphConfig graph_config_;
std::shared_ptr<tool::TagMap> graph_input_map_;
std::shared_ptr<tool::TagMap> graph_output_map_;
std::map<CollectionItemId, CollectionItemId> consumer_ids_;
PacketConsumer* consumer_ = nullptr;
std::map<std::string, Packet> side_packets_;
std::unique_ptr<CalculatorGraph> graph_ ABSL_GUARDED_BY(graph_mutex_) =
nullptr;
absl::Mutex graph_mutex_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_

View File

@ -56,7 +56,7 @@ class SidePacketConsumer {
virtual std::shared_ptr<tool::TagMap> SideInputTags() = 0;
};
// SidePacketProducer deleivers several tagged constant packets.
// SidePacketProducer delivers several tagged constant packets.
class SidePacketProducer {
public:
virtual ~SidePacketProducer() = default;