Internal change
PiperOrigin-RevId: 503035081
This commit is contained in:
parent
66634bbef8
commit
97af47ebf5
|
@ -32,3 +32,29 @@ cc_library(
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "graph_processor",
|
||||||
|
srcs = ["graph_processor.cc"],
|
||||||
|
hdrs = ["graph_processor.h"],
|
||||||
|
visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":packet_processor",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:collection_item_id",
|
||||||
|
"//mediapipe/framework:input_stream_shard",
|
||||||
|
"//mediapipe/framework:output_stream_shard",
|
||||||
|
"//mediapipe/framework:validated_graph_config",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/port:integral_types",
|
||||||
|
"//mediapipe/framework/port:logging",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
110
mediapipe/framework/tool/switch/graph_processor.cc
Normal file
110
mediapipe/framework/tool/switch/graph_processor.cc
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
#include "mediapipe/framework/tool/switch/graph_processor.h"
|
||||||
|
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
// TODO: add support for input and output side packets.
|
||||||
|
absl::Status GraphProcessor::Initialize(CalculatorGraphConfig graph_config) {
|
||||||
|
graph_config_ = graph_config;
|
||||||
|
|
||||||
|
ASSIGN_OR_RETURN(graph_input_map_,
|
||||||
|
tool::TagMap::Create(graph_config_.input_stream()));
|
||||||
|
ASSIGN_OR_RETURN(graph_output_map_,
|
||||||
|
tool::TagMap::Create(graph_config_.output_stream()));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status GraphProcessor::AddPacket(CollectionItemId id, Packet packet) {
|
||||||
|
absl::MutexLock lock(&graph_mutex_);
|
||||||
|
const std::string& stream_name = graph_input_map_->Names().at(id.value());
|
||||||
|
return graph_->AddPacketToInputStream(stream_name, packet);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<tool::TagMap> GraphProcessor::InputTags() {
|
||||||
|
return graph_input_map_;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status GraphProcessor::SendPacket(CollectionItemId id, Packet packet) {
|
||||||
|
MP_RETURN_IF_ERROR(WaitUntilInitialized());
|
||||||
|
auto it = consumer_ids_.find(id);
|
||||||
|
if (it == consumer_ids_.end()) {
|
||||||
|
return absl::NotFoundError(
|
||||||
|
absl::StrCat("Consumer stream not found: ", id.value()));
|
||||||
|
}
|
||||||
|
return consumer_->AddPacket(it->second, packet);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GraphProcessor::SetConsumer(PacketConsumer* consumer) {
|
||||||
|
absl::MutexLock lock(&graph_mutex_);
|
||||||
|
consumer_ = consumer;
|
||||||
|
auto input_map = consumer_->InputTags();
|
||||||
|
for (auto id = input_map->BeginId(); id != input_map->EndId(); ++id) {
|
||||||
|
auto tag_index = input_map->TagAndIndexFromId(id);
|
||||||
|
auto stream_id = graph_input_map_->GetId(tag_index.first, tag_index.second);
|
||||||
|
consumer_ids_[stream_id] = id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status GraphProcessor::ObserveGraph() {
|
||||||
|
for (auto id = graph_output_map_->BeginId(); id != graph_output_map_->EndId();
|
||||||
|
++id) {
|
||||||
|
std::string stream_name = graph_output_map_->Names().at(id.value());
|
||||||
|
MP_RETURN_IF_ERROR(graph_->ObserveOutputStream(
|
||||||
|
stream_name,
|
||||||
|
[this, id](const Packet& packet) { return SendPacket(id, packet); },
|
||||||
|
true));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status GraphProcessor::WaitUntilInitialized() {
|
||||||
|
absl::MutexLock lock(&graph_mutex_);
|
||||||
|
auto is_initialized = [this]() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) {
|
||||||
|
return graph_ != nullptr && consumer_ != nullptr;
|
||||||
|
};
|
||||||
|
graph_mutex_.AwaitWithTimeout(absl::Condition(&is_initialized),
|
||||||
|
absl::Seconds(4));
|
||||||
|
RET_CHECK(is_initialized()) << "GraphProcessor initialization timed out.";
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status GraphProcessor::Start() {
|
||||||
|
absl::MutexLock lock(&graph_mutex_);
|
||||||
|
graph_ = std::make_unique<CalculatorGraph>();
|
||||||
|
|
||||||
|
// The graph is validated here with its specified inputs and output.
|
||||||
|
MP_RETURN_IF_ERROR(graph_->Initialize(graph_config_, side_packets_));
|
||||||
|
MP_RETURN_IF_ERROR(ObserveGraph());
|
||||||
|
MP_RETURN_IF_ERROR(graph_->StartRun({}));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status GraphProcessor::Shutdown() {
|
||||||
|
absl::MutexLock lock(&graph_mutex_);
|
||||||
|
if (!graph_) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
MP_RETURN_IF_ERROR(graph_->CloseAllPacketSources());
|
||||||
|
MP_RETURN_IF_ERROR(graph_->WaitUntilDone());
|
||||||
|
graph_ = nullptr;
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status GraphProcessor::WaitUntilIdle() {
|
||||||
|
absl::MutexLock lock(&graph_mutex_);
|
||||||
|
return graph_->WaitUntilIdle();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
absl::Status GraphProcessor::SetSidePacket(CollectionItemId id, Packet packet) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
// TODO
|
||||||
|
std::shared_ptr<tool::TagMap> GraphProcessor::SideInputTags() {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// TODO
|
||||||
|
void GraphProcessor::SetSideConsumer(SidePacketConsumer* consumer) {}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
59
mediapipe/framework/tool/switch/graph_processor.h
Normal file
59
mediapipe/framework/tool/switch/graph_processor.h
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
|
||||||
|
#define MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/collection_item_id.h"
|
||||||
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
#include "mediapipe/framework/tool/switch/packet_processor.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
// Processes MediaPipe Packets using a MediaPipe CalculatorGraph.
|
||||||
|
class GraphProcessor : public PacketProcessor {
|
||||||
|
public:
|
||||||
|
GraphProcessor() = default;
|
||||||
|
|
||||||
|
// Configures this GraphProcessor to create a run a CalculatorGraph.
|
||||||
|
absl::Status Initialize(CalculatorGraphConfig graph_config);
|
||||||
|
|
||||||
|
public:
|
||||||
|
// The PacketProcessor interface.
|
||||||
|
absl::Status AddPacket(CollectionItemId id, Packet packet) override;
|
||||||
|
std::shared_ptr<tool::TagMap> InputTags() override;
|
||||||
|
absl::Status SetSidePacket(CollectionItemId id, Packet packet) override;
|
||||||
|
std::shared_ptr<tool::TagMap> SideInputTags() override;
|
||||||
|
void SetConsumer(PacketConsumer* consumer) override;
|
||||||
|
void SetSideConsumer(SidePacketConsumer* consumer) override;
|
||||||
|
absl::Status Start() override;
|
||||||
|
absl::Status Shutdown() override;
|
||||||
|
absl::Status WaitUntilIdle() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Sends a tagged output packet.
|
||||||
|
absl::Status SendPacket(CollectionItemId id, Packet packet);
|
||||||
|
|
||||||
|
// Observes output packets from the calculator graph.
|
||||||
|
absl::Status ObserveGraph() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_);
|
||||||
|
|
||||||
|
// Blocks until this GraphProcessor is initialized.
|
||||||
|
absl::Status WaitUntilInitialized();
|
||||||
|
|
||||||
|
private:
|
||||||
|
CalculatorGraphConfig graph_config_;
|
||||||
|
std::shared_ptr<tool::TagMap> graph_input_map_;
|
||||||
|
std::shared_ptr<tool::TagMap> graph_output_map_;
|
||||||
|
std::map<CollectionItemId, CollectionItemId> consumer_ids_;
|
||||||
|
|
||||||
|
PacketConsumer* consumer_ = nullptr;
|
||||||
|
std::map<std::string, Packet> side_packets_;
|
||||||
|
std::unique_ptr<CalculatorGraph> graph_ ABSL_GUARDED_BY(graph_mutex_) =
|
||||||
|
nullptr;
|
||||||
|
absl::Mutex graph_mutex_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
|
|
@ -56,7 +56,7 @@ class SidePacketConsumer {
|
||||||
virtual std::shared_ptr<tool::TagMap> SideInputTags() = 0;
|
virtual std::shared_ptr<tool::TagMap> SideInputTags() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// SidePacketProducer deleivers several tagged constant packets.
|
// SidePacketProducer delivers several tagged constant packets.
|
||||||
class SidePacketProducer {
|
class SidePacketProducer {
|
||||||
public:
|
public:
|
||||||
virtual ~SidePacketProducer() = default;
|
virtual ~SidePacketProducer() = default;
|
||||||
|
|
Loading…
Reference in New Issue
Block a user