mediapipe/mediapipe/framework/packet_generator_graph.cc
MediaPipe Team 50c92c6623 Project import generated by Copybara.
GitOrigin-RevId: 27c70b5fe62ab71189d358ca122ee4b19c817a8f
2021-07-27 19:36:32 -04:00

422 lines
16 KiB
C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/packet_generator_graph.h"
#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/delegating_executor.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/packet_generator.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/thread_pool_executor.h"
#include "mediapipe/framework/tool/status_util.h"
namespace mediapipe {
namespace {
// Fills *input_side_packet_set with the side packets that the generator at
// `generator_index` (in the canonical config) declares as inputs, looking each
// one up by name in `side_packets` and type-checking it.
//
// *unrunnable is set to true if at least one declared input is missing from
// `side_packets`, and to false otherwise. Packets that are present are still
// copied in and validated even when others are missing, so type errors are
// reported as early as possible. On a non-OK return (all type errors combined
// into one status) the contents of *unrunnable and *input_side_packet_set are
// unspecified.
absl::Status CreateInputsForGenerator(
    const ValidatedGraphConfig& validated_graph, int generator_index,
    const std::map<std::string, Packet>& side_packets,
    PacketSet* input_side_packet_set, bool* unrunnable) {
  const auto& input_types =
      validated_graph.GeneratorInfos()[generator_index].InputSidePacketTypes();
  const auto& generator_name = validated_graph.Config()
                                   .packet_generator(generator_index)
                                   .packet_generator();
  *unrunnable = false;
  std::vector<absl::Status> type_errors;
  for (CollectionItemId id = input_types.BeginId(); id < input_types.EndId();
       ++id) {
    const std::string& packet_name =
        input_types.TagMap()->Names()[id.value()];
    const auto found = side_packets.find(packet_name);
    if (found == side_packets.end()) {
      // Missing dependency: keep scanning so that type errors in the inputs
      // that are available are still collected.
      *unrunnable = true;
      continue;
    }
    input_side_packet_set->Get(id) = found->second;
    absl::Status type_status =
        input_types.Get(id).Validate(input_side_packet_set->Get(id));
    if (type_status.ok()) {
      continue;
    }
    type_errors.push_back(tool::AddStatusPrefix(
        absl::StrCat("Input side packet \"", packet_name,
                     "\" for PacketGenerator \"", generator_name,
                     "\" is not of the correct type: "),
        type_status));
  }
  if (type_errors.empty()) {
    return absl::OkStatus();
  }
  return tool::CombinedStatus(
      absl::StrCat(generator_name, " had invalid configuration."),
      type_errors);
}
// Runs the PacketGenerator at `generator_index` on `input_side_packet_set`,
// writing its outputs into *output_side_packet_set, and then checks that the
// outputs match the generator's declared output side packet types.
// Returns an error (prefixed with the generator's registered name) if the
// generator is not registered, if its Generate() call fails, or if the
// produced packets have the wrong types.
absl::Status Generate(const ValidatedGraphConfig& validated_graph,
                      int generator_index,
                      const PacketSet& input_side_packet_set,
                      PacketSet* output_side_packet_set) {
  const PacketGeneratorConfig& config =
      validated_graph.Config().packet_generator(generator_index);
  const auto& registered_name = config.packet_generator();
  // Look the generator up in the registry of the graph's package.
  ASSIGN_OR_RETURN(
      auto generator_access,
      internal::StaticAccessToGeneratorRegistry::CreateByNameInNamespace(
          validated_graph.Package(), registered_name),
      _ << registered_name << " is not a valid PacketGenerator.");
  MP_RETURN_IF_ERROR(generator_access->Generate(config.options(),
                                                input_side_packet_set,
                                                output_side_packet_set))
          .SetPrepend()
      << registered_name << "::Generate() failed. ";
  // Verify the generated packets against the declared output types.
  const NodeTypeInfo& type_info =
      validated_graph.GeneratorInfos()[generator_index];
  MP_RETURN_IF_ERROR(ValidatePacketSet(type_info.OutputSidePacketTypes(),
                                       *output_side_packet_set))
          .SetPrepend()
      << registered_name
      << "::Generate() output packets were of incorrect type: ";
  return absl::OkStatus();
}
// GeneratorScheduler schedules the packet generators in a validated graph for
// execution on an executor.
class GeneratorScheduler {
public:
// If "executor" is null, a DelegatingExecutor will be created internally.
// "initial" must be set to true for the first pass and false for subsequent
// passes. If "initial" is false, non_base_generators contains the non-base
// PacketGenerators (those not run at initialize time due to missing
// dependencies).
GeneratorScheduler(const ValidatedGraphConfig* validated_graph,
mediapipe::Executor* executor,
const std::vector<int>& non_base_generators, bool initial);
// Run a PacketGenerator on a given executor on the provided input
// side packets. After running the generator, schedule any generators
// which became runnable.
void GenerateAndScheduleNext(int generator_index,
std::map<std::string, Packet>* side_packets,
std::unique_ptr<PacketSet> input_side_packet_set)
ABSL_LOCKS_EXCLUDED(mutex_);
// Iterate through all generators in the config, scheduling any that
// are runnable (and haven't been scheduled yet).
void ScheduleAllRunnableGenerators(
std::map<std::string, Packet>* side_packets) ABSL_LOCKS_EXCLUDED(mutex_);
// Waits until there are no pending tasks.
void WaitUntilIdle() ABSL_LOCKS_EXCLUDED(mutex_);
// Stores the indexes of the packet generators that were not scheduled (or
// rather, not executed) in non_scheduled_generators. Returns the combined
// error status if there were errors while running the packet generators.
// NOTE: This method should only be called when there are no pending tasks.
absl::Status GetNonScheduledGenerators(
std::vector<int>* non_scheduled_generators) const;
private:
// Called by delegating_executor_ to add a task.
void AddApplicationThreadTask(std::function<void()> task);
// Run all the application thread tasks (which are kept track of in
// app_thread_tasks_).
void RunApplicationThreadTasks() ABSL_LOCKS_EXCLUDED(app_thread_mutex_);
const ValidatedGraphConfig* const validated_graph_;
mediapipe::Executor* executor_;
mutable absl::Mutex mutex_;
// The number of pending tasks.
int num_tasks_ ABSL_GUARDED_BY(mutex_) = 0;
// This condition variable is signaled when num_tasks_ becomes 0.
absl::CondVar idle_condvar_;
// Accumulates the error statuses while running the packet generators.
std::vector<absl::Status> statuses_ ABSL_GUARDED_BY(mutex_);
// scheduled_generators_[i] is true if the packet generator with index i was
// scheduled (or rather, executed).
std::vector<bool> scheduled_generators_ ABSL_GUARDED_BY(mutex_);
absl::Mutex app_thread_mutex_;
// Tasks to be executed on the application thread.
std::deque<std::function<void()>> app_thread_tasks_
ABSL_GUARDED_BY(app_thread_mutex_);
std::unique_ptr<internal::DelegatingExecutor> delegating_executor_;
};
// Sets up the per-generator "already scheduled" flags and, if no executor was
// given, an internal DelegatingExecutor that queues tasks for the calling
// thread.
//
// On the initial pass every generator starts out unscheduled. On later passes
// every generator starts out marked as scheduled, except those listed in
// non_base_generators, so only those are considered for execution.
GeneratorScheduler::GeneratorScheduler(
    const ValidatedGraphConfig* validated_graph, mediapipe::Executor* executor,
    const std::vector<int>& non_base_generators, bool initial)
    : validated_graph_(validated_graph),
      executor_(executor),
      scheduled_generators_(validated_graph_->Config().packet_generator_size(),
                            !initial) {
  if (!initial) {
    for (const int pending_index : non_base_generators) {
      scheduled_generators_[pending_index] = false;
    }
  }
  if (executor_ != nullptr) {
    return;
  }
  // No executor supplied: queue tasks so WaitUntilIdle() can run them on the
  // application thread.
  delegating_executor_ = absl::make_unique<internal::DelegatingExecutor>(
      [this](std::function<void()> task) {
        AddApplicationThreadTask(std::move(task));
      });
  executor_ = delegating_executor_.get();
}
void GeneratorScheduler::GenerateAndScheduleNext(
int generator_index, std::map<std::string, Packet>* side_packets,
std::unique_ptr<PacketSet> input_side_packet_set) {
{
absl::MutexLock lock(&mutex_);
if (!statuses_.empty()) {
// Return early, don't run the generator if we already have errors.
return;
}
}
PacketSet output_side_packet_set(
validated_graph_->GeneratorInfos()[generator_index]
.OutputSidePacketTypes()
.TagMap());
VLOG(1) << "Running generator " << generator_index;
absl::Status status =
Generate(*validated_graph_, generator_index, *input_side_packet_set,
&output_side_packet_set);
{
absl::MutexLock lock(&mutex_);
if (!status.ok()) {
statuses_.push_back(std::move(status));
return;
}
// Add packets to side_packets .
for (CollectionItemId id = output_side_packet_set.BeginId();
id < output_side_packet_set.EndId(); ++id) {
const auto& name = output_side_packet_set.TagMap()->Names()[id.value()];
auto item = side_packets->emplace(name, output_side_packet_set.Get(id));
if (!item.second) {
statuses_.push_back(absl::AlreadyExistsError(
absl::StrCat("Side packet \"", name, "\" was defined twice.")));
}
}
if (!statuses_.empty()) {
return;
}
}
// Check all generators and schedule any that have become runnable.
// TODO Instead of checking all of them, only check ones
// that have input side packets which we have just produced.
ScheduleAllRunnableGenerators(side_packets);
}
// Scans every packet generator in config order and schedules each one that is
// not yet marked as scheduled and whose input side packets are all present in
// *side_packets. Errors returned by CreateInputsForGenerator are appended to
// statuses_ and the scan continues with the next generator.
//
// mutex_ is held for the whole scan except around executor_->Schedule().
// NOTE(review): presumably it is dropped there so that a scheduled task (run
// on another thread, or inline by the executor) can itself acquire mutex_ in
// GenerateAndScheduleNext(). While it is released, other tasks may update
// *side_packets and scheduled_generators_.
void GeneratorScheduler::ScheduleAllRunnableGenerators(
    std::map<std::string, Packet>* side_packets) {
  absl::MutexLock lock(&mutex_);
  const auto& generators = validated_graph_->Config().packet_generator();
  for (int index = 0; index < generators.size(); ++index) {
    // Skip generators already scheduled. On a non-initial pass this also
    // skips generators that were not in the non-base list (see constructor).
    if (scheduled_generators_[index]) {
      continue;
    }
    bool is_unrunnable = false;
    // TODO Input side packet set should only be created once.
    auto input_side_packet_set =
        absl::make_unique<PacketSet>(validated_graph_->GeneratorInfos()[index]
                                         .InputSidePacketTypes()
                                         .TagMap());
    absl::Status status =
        CreateInputsForGenerator(*validated_graph_, index, *side_packets,
                                 input_side_packet_set.get(), &is_unrunnable);
    if (!status.ok()) {
      statuses_.push_back(std::move(status));
      continue;
    }
    // Some inputs are still missing; a later call may schedule it.
    if (is_unrunnable) {
      continue;
    }
    // The Generator is runnable, schedule a callback to run it.
    scheduled_generators_[index] = true;
    VLOG(1) << "Scheduling generator " << index;
    // Get around the fact that we can't capture a unique_ptr (this
    // means a memory leak will result if the lambda is not run).
    PacketSet* input_side_packet_set_ptr = input_side_packet_set.release();
    // Count the task before releasing the lock so WaitUntilIdle() cannot
    // observe zero pending tasks while this one is still outstanding.
    ++num_tasks_;
    // Release mutex_ around Schedule(); it is re-acquired below so that the
    // unlock performed by the MutexLock destructor stays balanced.
    mutex_.Unlock();
    executor_->Schedule(
        [this, index, side_packets, input_side_packet_set_ptr]() {
          // Ownership of the PacketSet is taken back here.
          GenerateAndScheduleNext(
              index, side_packets,
              std::unique_ptr<PacketSet>(input_side_packet_set_ptr));
          {
            // Mark this task as done; the last one wakes a thread blocked in
            // WaitUntilIdle().
            absl::MutexLock lock(&mutex_);
            --num_tasks_;
            if (num_tasks_ == 0) {
              idle_condvar_.Signal();
            }
          }
        });
    mutex_.Lock();
  }
}
// Blocks until every scheduled task has finished.
// With the internal DelegatingExecutor the queued tasks are run right here on
// the calling thread; with an external executor this waits on idle_condvar_
// until num_tasks_ drops to zero.
void GeneratorScheduler::WaitUntilIdle() {
  const bool runs_on_application_thread =
      executor_ == delegating_executor_.get();
  if (runs_on_application_thread) {
    RunApplicationThreadTasks();
    return;
  }
  absl::MutexLock lock(&mutex_);
  while (num_tasks_ != 0) {
    idle_condvar_.Wait(&mutex_);
  }
}
// Clears *non_scheduled_generators and, unless errors were recorded, fills it
// with the indexes of generators that never got scheduled (in ascending
// order). If any errors were recorded while running generators, returns them
// combined into one status and leaves the output empty.
// Should only be called once there are no pending tasks.
absl::Status GeneratorScheduler::GetNonScheduledGenerators(
    std::vector<int>* non_scheduled_generators) const {
  non_scheduled_generators->clear();
  absl::MutexLock lock(&mutex_);
  if (!statuses_.empty()) {
    return tool::CombinedStatus("PacketGeneratorGraph failed.", statuses_);
  }
  // Compare int against int: generator indexes are ints throughout this file,
  // and comparing them directly with size() mixes signed and unsigned.
  const int num_generators = static_cast<int>(scheduled_generators_.size());
  for (int i = 0; i < num_generators; ++i) {
    if (!scheduled_generators_[i]) {
      non_scheduled_generators->push_back(i);
    }
  }
  return absl::OkStatus();
}
// Appends `task` to the application-thread queue. The task is not run here;
// RunApplicationThreadTasks() drains the queue later.
void GeneratorScheduler::AddApplicationThreadTask(std::function<void()> task) {
  absl::MutexLock queue_lock(&app_thread_mutex_);
  app_thread_tasks_.emplace_back(std::move(task));
}
// Runs queued application-thread tasks in FIFO order until the queue is empty,
// including tasks that earlier tasks enqueue while running.
void GeneratorScheduler::RunApplicationThreadTasks() {
  for (;;) {
    std::function<void()> next_task;
    {
      absl::MutexLock queue_lock(&app_thread_mutex_);
      if (app_thread_tasks_.empty()) {
        return;
      }
      next_task = std::move(app_thread_tasks_.front());
      app_thread_tasks_.pop_front();
    }
    // Invoked without holding app_thread_mutex_: the task may enqueue more
    // work through AddApplicationThreadTask().
    next_task();
  }
}
} // namespace
// Out-of-line defaulted destructor; no resources beyond the members to release.
PacketGeneratorGraph::~PacketGeneratorGraph() = default;
// Stores the graph, executor and base side packets, checks that the graph can
// accept `input_side_packets`, and then runs every generator whose inputs are
// already available. Generators that cannot run yet are recorded in
// non_base_generators_ for later RunGraphSetup() calls.
absl::Status PacketGeneratorGraph::Initialize(
    const ValidatedGraphConfig* validated_graph, mediapipe::Executor* executor,
    const std::map<std::string, Packet>& input_side_packets) {
  validated_graph_ = validated_graph;
  executor_ = executor;
  base_packets_ = input_side_packets;
  MP_RETURN_IF_ERROR(
      validated_graph_->CanAcceptSidePackets(input_side_packets));
  constexpr bool kInitialPass = true;
  return ExecuteGenerators(&base_packets_, &non_base_generators_,
                           kInitialPass);
}
// Builds the complete side packet map for one graph run.
// *output_side_packets starts as the base packets from Initialize() merged
// with `input_side_packets`; a name present in both is an AlreadyExists error.
// The remaining (non-base) generators are then run. Indexes of generators that
// still could not run are stored in *non_scheduled_generators, which may be
// null if the caller does not need them.
absl::Status PacketGeneratorGraph::RunGraphSetup(
    const std::map<std::string, Packet>& input_side_packets,
    std::map<std::string, Packet>* output_side_packets,
    std::vector<int>* non_scheduled_generators) const {
  *output_side_packets = base_packets_;
  for (const std::pair<const std::string, Packet>& item : input_side_packets) {
    // A single insert() both looks up and inserts; it leaves the map unchanged
    // when the key is already present, which is exactly the duplicate case.
    if (!output_side_packets->insert(item).second) {
      return absl::AlreadyExistsError(
          absl::StrCat("Side packet \"", item.first, "\" was defined twice."));
    }
  }
  std::vector<int> non_scheduled_generators_local;
  if (!non_scheduled_generators)
    non_scheduled_generators = &non_scheduled_generators_local;
  MP_RETURN_IF_ERROR(
      validated_graph_->CanAcceptSidePackets(input_side_packets));
  // This type check on the required side packets is redundant with
  // error checking in ExecuteGenerators, but we do it now to fail early.
  MP_RETURN_IF_ERROR(
      validated_graph_->ValidateRequiredSidePackets(*output_side_packets));
  MP_RETURN_IF_ERROR(ExecuteGenerators(
      output_side_packets, non_scheduled_generators, /*initial=*/false));
  return absl::OkStatus();
}
// Runs as many packet generators as possible, adding their outputs to
// *output_side_packets. Generators whose input side packets never become
// available are reported through *non_scheduled_generators. The
// ValidatedGraphConfig is expected to already list the generators in
// topological order.
//
// `initial` selects the first pass (all generators are candidates) versus a
// later pass (only non_base_generators_ are candidates).
absl::Status PacketGeneratorGraph::ExecuteGenerators(
    std::map<std::string, Packet>* output_side_packets,
    std::vector<int>* non_scheduled_generators, bool initial) const {
  VLOG(1) << "ExecuteGenerators initial == " << initial;
  GeneratorScheduler scheduler(validated_graph_, executor_,
                               non_base_generators_, initial);
  scheduler.ScheduleAllRunnableGenerators(output_side_packets);
  // Always drain the scheduler, even after an error: the scheduled tasks own
  // heap-allocated input packet sets and capture the stack-allocated
  // scheduler, so they must all finish before it goes out of scope.
  scheduler.WaitUntilIdle();
  return scheduler.GetNonScheduledGenerators(non_scheduled_generators);
}
} // namespace mediapipe