Internal change
PiperOrigin-RevId: 503035081
This commit is contained in:
parent
66634bbef8
commit
97af47ebf5
|
@ -32,3 +32,29 @@ cc_library(
|
|||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_processor",
|
||||
srcs = ["graph_processor.cc"],
|
||||
hdrs = ["graph_processor.h"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
":packet_processor",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:collection_item_id",
|
||||
"//mediapipe/framework:input_stream_shard",
|
||||
"//mediapipe/framework:output_stream_shard",
|
||||
"//mediapipe/framework:validated_graph_config",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
110
mediapipe/framework/tool/switch/graph_processor.cc
Normal file
110
mediapipe/framework/tool/switch/graph_processor.cc
Normal file
|
@ -0,0 +1,110 @@
|
|||
#include "mediapipe/framework/tool/switch/graph_processor.h"
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// TODO: add support for input and output side packets.
|
||||
absl::Status GraphProcessor::Initialize(CalculatorGraphConfig graph_config) {
|
||||
graph_config_ = graph_config;
|
||||
|
||||
ASSIGN_OR_RETURN(graph_input_map_,
|
||||
tool::TagMap::Create(graph_config_.input_stream()));
|
||||
ASSIGN_OR_RETURN(graph_output_map_,
|
||||
tool::TagMap::Create(graph_config_.output_stream()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status GraphProcessor::AddPacket(CollectionItemId id, Packet packet) {
|
||||
absl::MutexLock lock(&graph_mutex_);
|
||||
const std::string& stream_name = graph_input_map_->Names().at(id.value());
|
||||
return graph_->AddPacketToInputStream(stream_name, packet);
|
||||
}
|
||||
|
||||
std::shared_ptr<tool::TagMap> GraphProcessor::InputTags() {
|
||||
return graph_input_map_;
|
||||
}
|
||||
|
||||
absl::Status GraphProcessor::SendPacket(CollectionItemId id, Packet packet) {
|
||||
MP_RETURN_IF_ERROR(WaitUntilInitialized());
|
||||
auto it = consumer_ids_.find(id);
|
||||
if (it == consumer_ids_.end()) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("Consumer stream not found: ", id.value()));
|
||||
}
|
||||
return consumer_->AddPacket(it->second, packet);
|
||||
}
|
||||
|
||||
void GraphProcessor::SetConsumer(PacketConsumer* consumer) {
|
||||
absl::MutexLock lock(&graph_mutex_);
|
||||
consumer_ = consumer;
|
||||
auto input_map = consumer_->InputTags();
|
||||
for (auto id = input_map->BeginId(); id != input_map->EndId(); ++id) {
|
||||
auto tag_index = input_map->TagAndIndexFromId(id);
|
||||
auto stream_id = graph_input_map_->GetId(tag_index.first, tag_index.second);
|
||||
consumer_ids_[stream_id] = id;
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status GraphProcessor::ObserveGraph() {
|
||||
for (auto id = graph_output_map_->BeginId(); id != graph_output_map_->EndId();
|
||||
++id) {
|
||||
std::string stream_name = graph_output_map_->Names().at(id.value());
|
||||
MP_RETURN_IF_ERROR(graph_->ObserveOutputStream(
|
||||
stream_name,
|
||||
[this, id](const Packet& packet) { return SendPacket(id, packet); },
|
||||
true));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status GraphProcessor::WaitUntilInitialized() {
|
||||
absl::MutexLock lock(&graph_mutex_);
|
||||
auto is_initialized = [this]() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) {
|
||||
return graph_ != nullptr && consumer_ != nullptr;
|
||||
};
|
||||
graph_mutex_.AwaitWithTimeout(absl::Condition(&is_initialized),
|
||||
absl::Seconds(4));
|
||||
RET_CHECK(is_initialized()) << "GraphProcessor initialization timed out.";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status GraphProcessor::Start() {
|
||||
absl::MutexLock lock(&graph_mutex_);
|
||||
graph_ = std::make_unique<CalculatorGraph>();
|
||||
|
||||
// The graph is validated here with its specified inputs and output.
|
||||
MP_RETURN_IF_ERROR(graph_->Initialize(graph_config_, side_packets_));
|
||||
MP_RETURN_IF_ERROR(ObserveGraph());
|
||||
MP_RETURN_IF_ERROR(graph_->StartRun({}));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status GraphProcessor::Shutdown() {
|
||||
absl::MutexLock lock(&graph_mutex_);
|
||||
if (!graph_) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
MP_RETURN_IF_ERROR(graph_->CloseAllPacketSources());
|
||||
MP_RETURN_IF_ERROR(graph_->WaitUntilDone());
|
||||
graph_ = nullptr;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status GraphProcessor::WaitUntilIdle() {
|
||||
absl::MutexLock lock(&graph_mutex_);
|
||||
return graph_->WaitUntilIdle();
|
||||
}
|
||||
|
||||
// TODO
|
||||
absl::Status GraphProcessor::SetSidePacket(CollectionItemId id, Packet packet) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// TODO
|
||||
std::shared_ptr<tool::TagMap> GraphProcessor::SideInputTags() {
|
||||
return nullptr;
|
||||
}
|
||||
// TODO
|
||||
void GraphProcessor::SetSideConsumer(SidePacketConsumer* consumer) {}
|
||||
|
||||
} // namespace mediapipe
|
59
mediapipe/framework/tool/switch/graph_processor.h
Normal file
59
mediapipe/framework/tool/switch/graph_processor.h
Normal file
|
@ -0,0 +1,59 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/collection_item_id.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/tool/switch/packet_processor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Processes MediaPipe Packets using a MediaPipe CalculatorGraph.
|
||||
class GraphProcessor : public PacketProcessor {
|
||||
public:
|
||||
GraphProcessor() = default;
|
||||
|
||||
// Configures this GraphProcessor to create a run a CalculatorGraph.
|
||||
absl::Status Initialize(CalculatorGraphConfig graph_config);
|
||||
|
||||
public:
|
||||
// The PacketProcessor interface.
|
||||
absl::Status AddPacket(CollectionItemId id, Packet packet) override;
|
||||
std::shared_ptr<tool::TagMap> InputTags() override;
|
||||
absl::Status SetSidePacket(CollectionItemId id, Packet packet) override;
|
||||
std::shared_ptr<tool::TagMap> SideInputTags() override;
|
||||
void SetConsumer(PacketConsumer* consumer) override;
|
||||
void SetSideConsumer(SidePacketConsumer* consumer) override;
|
||||
absl::Status Start() override;
|
||||
absl::Status Shutdown() override;
|
||||
absl::Status WaitUntilIdle() override;
|
||||
|
||||
private:
|
||||
// Sends a tagged output packet.
|
||||
absl::Status SendPacket(CollectionItemId id, Packet packet);
|
||||
|
||||
// Observes output packets from the calculator graph.
|
||||
absl::Status ObserveGraph() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_);
|
||||
|
||||
// Blocks until this GraphProcessor is initialized.
|
||||
absl::Status WaitUntilInitialized();
|
||||
|
||||
private:
|
||||
CalculatorGraphConfig graph_config_;
|
||||
std::shared_ptr<tool::TagMap> graph_input_map_;
|
||||
std::shared_ptr<tool::TagMap> graph_output_map_;
|
||||
std::map<CollectionItemId, CollectionItemId> consumer_ids_;
|
||||
|
||||
PacketConsumer* consumer_ = nullptr;
|
||||
std::map<std::string, Packet> side_packets_;
|
||||
std::unique_ptr<CalculatorGraph> graph_ ABSL_GUARDED_BY(graph_mutex_) =
|
||||
nullptr;
|
||||
absl::Mutex graph_mutex_;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
|
|
@ -56,7 +56,7 @@ class SidePacketConsumer {
|
|||
virtual std::shared_ptr<tool::TagMap> SideInputTags() = 0;
|
||||
};
|
||||
|
||||
// SidePacketProducer deleivers several tagged constant packets.
|
||||
// SidePacketProducer delivers several tagged constant packets.
|
||||
class SidePacketProducer {
|
||||
public:
|
||||
virtual ~SidePacketProducer() = default;
|
||||
|
|
Loading…
Reference in New Issue
Block a user