422 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			422 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright 2019 The MediaPipe Authors.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //      http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| #include "mediapipe/framework/packet_generator_graph.h"
 | |
| 
 | |
| #include <deque>
 | |
| #include <functional>
 | |
| #include <memory>
 | |
| #include <utility>
 | |
| 
 | |
| #include "absl/strings/str_cat.h"
 | |
| #include "absl/strings/str_join.h"
 | |
| #include "absl/synchronization/mutex.h"
 | |
| #include "mediapipe/framework/delegating_executor.h"
 | |
| #include "mediapipe/framework/executor.h"
 | |
| #include "mediapipe/framework/packet_generator.h"
 | |
| #include "mediapipe/framework/packet_type.h"
 | |
| #include "mediapipe/framework/port.h"
 | |
| #include "mediapipe/framework/port/canonical_errors.h"
 | |
| #include "mediapipe/framework/port/logging.h"
 | |
| #include "mediapipe/framework/port/proto_ns.h"
 | |
| #include "mediapipe/framework/port/ret_check.h"
 | |
| #include "mediapipe/framework/port/status_builder.h"
 | |
| #include "mediapipe/framework/thread_pool_executor.h"
 | |
| #include "mediapipe/framework/tool/status_util.h"
 | |
| 
 | |
| namespace mediapipe {
 | |
| 
 | |
| namespace {
 | |
| 
 | |
| // Create the input side packet set for a generator (provided by
 | |
| // index in the canonical config).  unrunnable is set to true if the
 | |
| // generator cannot be run given the currently available side packets
 | |
| // (and false otherwise).  If an error occurs then unrunnable and
 | |
| // input_side_packet_set are undefined.
 | |
| absl::Status CreateInputsForGenerator(
 | |
|     const ValidatedGraphConfig& validated_graph, int generator_index,
 | |
|     const std::map<std::string, Packet>& side_packets,
 | |
|     PacketSet* input_side_packet_set, bool* unrunnable) {
 | |
|   const NodeTypeInfo& node_type_info =
 | |
|       validated_graph.GeneratorInfos()[generator_index];
 | |
|   const auto& generator_name = validated_graph.Config()
 | |
|                                    .packet_generator(generator_index)
 | |
|                                    .packet_generator();
 | |
|   // Fill the PacketSet (if possible).
 | |
|   *unrunnable = false;
 | |
|   std::vector<absl::Status> statuses;
 | |
|   for (CollectionItemId id = node_type_info.InputSidePacketTypes().BeginId();
 | |
|        id < node_type_info.InputSidePacketTypes().EndId(); ++id) {
 | |
|     const std::string& name =
 | |
|         node_type_info.InputSidePacketTypes().TagMap()->Names()[id.value()];
 | |
| 
 | |
|     std::map<std::string, Packet>::const_iterator it = side_packets.find(name);
 | |
|     if (it == side_packets.end()) {
 | |
|       *unrunnable = true;
 | |
|       continue;
 | |
|     }
 | |
|     input_side_packet_set->Get(id) = it->second;
 | |
|     absl::Status status =
 | |
|         node_type_info.InputSidePacketTypes().Get(id).Validate(
 | |
|             input_side_packet_set->Get(id));
 | |
|     if (!status.ok()) {
 | |
|       statuses.push_back(tool::AddStatusPrefix(
 | |
|           absl::StrCat("Input side packet \"", name,
 | |
|                        "\" for PacketGenerator \"", generator_name,
 | |
|                        "\" is not of the correct type: "),
 | |
|           status));
 | |
|     }
 | |
|   }
 | |
|   if (!statuses.empty()) {
 | |
|     return tool::CombinedStatus(
 | |
|         absl::StrCat(generator_name, " had invalid configuration."), statuses);
 | |
|   }
 | |
|   return absl::OkStatus();
 | |
| }
 | |
| 
 | |
| // Generate the packets from a PacketGenerator, place them in
 | |
| // output_side_packet_set, and validate their types.
 | |
| absl::Status Generate(const ValidatedGraphConfig& validated_graph,
 | |
|                       int generator_index,
 | |
|                       const PacketSet& input_side_packet_set,
 | |
|                       PacketSet* output_side_packet_set) {
 | |
|   const NodeTypeInfo& node_type_info =
 | |
|       validated_graph.GeneratorInfos()[generator_index];
 | |
|   const PacketGeneratorConfig& generator_config =
 | |
|       validated_graph.Config().packet_generator(generator_index);
 | |
|   const auto& generator_name = generator_config.packet_generator();
 | |
| 
 | |
|   ASSIGN_OR_RETURN(
 | |
|       auto static_access,
 | |
|       internal::StaticAccessToGeneratorRegistry::CreateByNameInNamespace(
 | |
|           validated_graph.Package(), generator_name),
 | |
|       _ << generator_name << " is not a valid PacketGenerator.");
 | |
|   MP_RETURN_IF_ERROR(static_access->Generate(generator_config.options(),
 | |
|                                              input_side_packet_set,
 | |
|                                              output_side_packet_set))
 | |
|           .SetPrepend()
 | |
|       << generator_name << "::Generate() failed. ";
 | |
| 
 | |
|   MP_RETURN_IF_ERROR(ValidatePacketSet(node_type_info.OutputSidePacketTypes(),
 | |
|                                        *output_side_packet_set))
 | |
|           .SetPrepend()
 | |
|       << generator_name
 | |
|       << "::Generate() output packets were of incorrect type: ";
 | |
|   return absl::OkStatus();
 | |
| }
 | |
| 
 | |
| // GeneratorScheduler schedules the packet generators in a validated graph for
 | |
| // execution on an executor.
 | |
| class GeneratorScheduler {
 | |
|  public:
 | |
|   // If "executor" is null, a DelegatingExecutor will be created internally.
 | |
|   // "initial" must be set to true for the first pass and false for subsequent
 | |
|   // passes. If "initial" is false, non_base_generators contains the non-base
 | |
|   // PacketGenerators (those not run at initialize time due to missing
 | |
|   // dependencies).
 | |
|   GeneratorScheduler(const ValidatedGraphConfig* validated_graph,
 | |
|                      mediapipe::Executor* executor,
 | |
|                      const std::vector<int>& non_base_generators, bool initial);
 | |
| 
 | |
|   // Run a PacketGenerator on a given executor on the provided input
 | |
|   // side packets.  After running the generator, schedule any generators
 | |
|   // which became runnable.
 | |
|   void GenerateAndScheduleNext(int generator_index,
 | |
|                                std::map<std::string, Packet>* side_packets,
 | |
|                                std::unique_ptr<PacketSet> input_side_packet_set)
 | |
|       ABSL_LOCKS_EXCLUDED(mutex_);
 | |
| 
 | |
|   // Iterate through all generators in the config, scheduling any that
 | |
|   // are runnable (and haven't been scheduled yet).
 | |
|   void ScheduleAllRunnableGenerators(
 | |
|       std::map<std::string, Packet>* side_packets) ABSL_LOCKS_EXCLUDED(mutex_);
 | |
| 
 | |
|   // Waits until there are no pending tasks.
 | |
|   void WaitUntilIdle() ABSL_LOCKS_EXCLUDED(mutex_);
 | |
| 
 | |
|   // Stores the indexes of the packet generators that were not scheduled (or
 | |
|   // rather, not executed) in non_scheduled_generators. Returns the combined
 | |
|   // error status if there were errors while running the packet generators.
 | |
|   // NOTE: This method should only be called when there are no pending tasks.
 | |
|   absl::Status GetNonScheduledGenerators(
 | |
|       std::vector<int>* non_scheduled_generators) const;
 | |
| 
 | |
|  private:
 | |
|   // Called by delegating_executor_ to add a task.
 | |
|   void AddApplicationThreadTask(std::function<void()> task);
 | |
| 
 | |
|   // Run all the application thread tasks (which are kept track of in
 | |
|   // app_thread_tasks_).
 | |
|   void RunApplicationThreadTasks() ABSL_LOCKS_EXCLUDED(app_thread_mutex_);
 | |
| 
 | |
|   const ValidatedGraphConfig* const validated_graph_;
 | |
|   mediapipe::Executor* executor_;
 | |
| 
 | |
|   mutable absl::Mutex mutex_;
 | |
|   // The number of pending tasks.
 | |
|   int num_tasks_ ABSL_GUARDED_BY(mutex_) = 0;
 | |
|   // This condition variable is signaled when num_tasks_ becomes 0.
 | |
|   absl::CondVar idle_condvar_;
 | |
|   // Accumulates the error statuses while running the packet generators.
 | |
|   std::vector<absl::Status> statuses_ ABSL_GUARDED_BY(mutex_);
 | |
|   // scheduled_generators_[i] is true if the packet generator with index i was
 | |
|   // scheduled (or rather, executed).
 | |
|   std::vector<bool> scheduled_generators_ ABSL_GUARDED_BY(mutex_);
 | |
| 
 | |
|   absl::Mutex app_thread_mutex_;
 | |
|   // Tasks to be executed on the application thread.
 | |
|   std::deque<std::function<void()>> app_thread_tasks_
 | |
|       ABSL_GUARDED_BY(app_thread_mutex_);
 | |
|   std::unique_ptr<internal::DelegatingExecutor> delegating_executor_;
 | |
| };
 | |
| 
 | |
| GeneratorScheduler::GeneratorScheduler(
 | |
|     const ValidatedGraphConfig* validated_graph, mediapipe::Executor* executor,
 | |
|     const std::vector<int>& non_base_generators, bool initial)
 | |
|     : validated_graph_(validated_graph),
 | |
|       executor_(executor),
 | |
|       scheduled_generators_(validated_graph_->Config().packet_generator_size(),
 | |
|                             !initial) {
 | |
|   if (!executor_) {
 | |
|     // Run on the application thread.
 | |
|     delegating_executor_ = absl::make_unique<internal::DelegatingExecutor>(
 | |
|         std::bind(&GeneratorScheduler::AddApplicationThreadTask, this,
 | |
|                   std::placeholders::_1));
 | |
|     executor_ = delegating_executor_.get();
 | |
|   }
 | |
| 
 | |
|   if (!initial) {
 | |
|     // Only schedule the non-base generators.
 | |
|     for (int generator_index : non_base_generators) {
 | |
|       scheduled_generators_[generator_index] = false;
 | |
|     }
 | |
|   }
 | |
| }
 | |
| 
 | |
| void GeneratorScheduler::GenerateAndScheduleNext(
 | |
|     int generator_index, std::map<std::string, Packet>* side_packets,
 | |
|     std::unique_ptr<PacketSet> input_side_packet_set) {
 | |
|   {
 | |
|     absl::MutexLock lock(&mutex_);
 | |
|     if (!statuses_.empty()) {
 | |
|       // Return early, don't run the generator if we already have errors.
 | |
|       return;
 | |
|     }
 | |
|   }
 | |
|   PacketSet output_side_packet_set(
 | |
|       validated_graph_->GeneratorInfos()[generator_index]
 | |
|           .OutputSidePacketTypes()
 | |
|           .TagMap());
 | |
|   VLOG(1) << "Running generator " << generator_index;
 | |
|   absl::Status status =
 | |
|       Generate(*validated_graph_, generator_index, *input_side_packet_set,
 | |
|                &output_side_packet_set);
 | |
| 
 | |
|   {
 | |
|     absl::MutexLock lock(&mutex_);
 | |
|     if (!status.ok()) {
 | |
|       statuses_.push_back(std::move(status));
 | |
|       return;
 | |
|     }
 | |
|     // Add packets to side_packets .
 | |
|     for (CollectionItemId id = output_side_packet_set.BeginId();
 | |
|          id < output_side_packet_set.EndId(); ++id) {
 | |
|       const auto& name = output_side_packet_set.TagMap()->Names()[id.value()];
 | |
|       auto item = side_packets->emplace(name, output_side_packet_set.Get(id));
 | |
|       if (!item.second) {
 | |
|         statuses_.push_back(absl::AlreadyExistsError(
 | |
|             absl::StrCat("Side packet \"", name, "\" was defined twice.")));
 | |
|       }
 | |
|     }
 | |
|     if (!statuses_.empty()) {
 | |
|       return;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // Check all generators and schedule any that have become runnable.
 | |
|   // TODO Instead of checking all of them, only check ones
 | |
|   // that have input side packets which we have just produced.
 | |
|   ScheduleAllRunnableGenerators(side_packets);
 | |
| }
 | |
| 
 | |
| void GeneratorScheduler::ScheduleAllRunnableGenerators(
 | |
|     std::map<std::string, Packet>* side_packets) {
 | |
|   absl::MutexLock lock(&mutex_);
 | |
|   const auto& generators = validated_graph_->Config().packet_generator();
 | |
| 
 | |
|   for (int index = 0; index < generators.size(); ++index) {
 | |
|     if (scheduled_generators_[index]) {
 | |
|       continue;
 | |
|     }
 | |
|     bool is_unrunnable = false;
 | |
|     // TODO Input side packet set should only be created once.
 | |
|     auto input_side_packet_set =
 | |
|         absl::make_unique<PacketSet>(validated_graph_->GeneratorInfos()[index]
 | |
|                                          .InputSidePacketTypes()
 | |
|                                          .TagMap());
 | |
| 
 | |
|     absl::Status status =
 | |
|         CreateInputsForGenerator(*validated_graph_, index, *side_packets,
 | |
|                                  input_side_packet_set.get(), &is_unrunnable);
 | |
|     if (!status.ok()) {
 | |
|       statuses_.push_back(std::move(status));
 | |
|       continue;
 | |
|     }
 | |
|     if (is_unrunnable) {
 | |
|       continue;
 | |
|     }
 | |
|     // The Generator is runnable, schedule a callback to run it.
 | |
|     scheduled_generators_[index] = true;
 | |
|     VLOG(1) << "Scheduling generator " << index;
 | |
|     // Get around the fact that we can't capture a unique_ptr (this
 | |
|     // means a memory leak will result if the lambda is not run).
 | |
|     PacketSet* input_side_packet_set_ptr = input_side_packet_set.release();
 | |
|     ++num_tasks_;
 | |
|     mutex_.Unlock();
 | |
|     executor_->Schedule(
 | |
|         [this, index, side_packets, input_side_packet_set_ptr]() {
 | |
|           GenerateAndScheduleNext(
 | |
|               index, side_packets,
 | |
|               std::unique_ptr<PacketSet>(input_side_packet_set_ptr));
 | |
|           {
 | |
|             absl::MutexLock lock(&mutex_);
 | |
|             --num_tasks_;
 | |
|             if (num_tasks_ == 0) {
 | |
|               idle_condvar_.Signal();
 | |
|             }
 | |
|           }
 | |
|         });
 | |
|     mutex_.Lock();
 | |
|   }
 | |
| }
 | |
| 
 | |
| void GeneratorScheduler::WaitUntilIdle() {
 | |
|   if (executor_ == delegating_executor_.get()) {
 | |
|     // Run the tasks on the application thread.
 | |
|     RunApplicationThreadTasks();
 | |
|   } else {
 | |
|     absl::MutexLock lock(&mutex_);
 | |
|     while (num_tasks_ != 0) {
 | |
|       idle_condvar_.Wait(&mutex_);
 | |
|     }
 | |
|   }
 | |
| }
 | |
| 
 | |
| absl::Status GeneratorScheduler::GetNonScheduledGenerators(
 | |
|     std::vector<int>* non_scheduled_generators) const {
 | |
|   non_scheduled_generators->clear();
 | |
| 
 | |
|   absl::MutexLock lock(&mutex_);
 | |
|   if (!statuses_.empty()) {
 | |
|     return tool::CombinedStatus("PacketGeneratorGraph failed.", statuses_);
 | |
|   }
 | |
|   for (int i = 0; i < scheduled_generators_.size(); ++i) {
 | |
|     if (!scheduled_generators_[i]) {
 | |
|       non_scheduled_generators->push_back(i);
 | |
|     }
 | |
|   }
 | |
|   return absl::OkStatus();
 | |
| }
 | |
| 
 | |
| void GeneratorScheduler::AddApplicationThreadTask(std::function<void()> task) {
 | |
|   absl::MutexLock lock(&app_thread_mutex_);
 | |
|   app_thread_tasks_.push_back(std::move(task));
 | |
| }
 | |
| 
 | |
| void GeneratorScheduler::RunApplicationThreadTasks() {
 | |
|   while (true) {
 | |
|     std::function<void()> task_callback;
 | |
|     {
 | |
|       // Get the next task.
 | |
|       absl::MutexLock lock(&app_thread_mutex_);
 | |
|       if (app_thread_tasks_.empty()) {
 | |
|         break;
 | |
|       }
 | |
|       task_callback = std::move(app_thread_tasks_.front());
 | |
|       app_thread_tasks_.pop_front();
 | |
|     }
 | |
|     // Run the next task.  Don't hold any lock, since this task could
 | |
|     // schedule further tasks to be run on the application thread.
 | |
|     task_callback();
 | |
|   }
 | |
| }
 | |
| 
 | |
| }  // namespace
 | |
| 
 | |
| PacketGeneratorGraph::~PacketGeneratorGraph() {}
 | |
| 
 | |
| absl::Status PacketGeneratorGraph::Initialize(
 | |
|     const ValidatedGraphConfig* validated_graph, mediapipe::Executor* executor,
 | |
|     const std::map<std::string, Packet>& input_side_packets) {
 | |
|   validated_graph_ = validated_graph;
 | |
|   executor_ = executor;
 | |
|   base_packets_ = input_side_packets;
 | |
|   MP_RETURN_IF_ERROR(
 | |
|       validated_graph_->CanAcceptSidePackets(input_side_packets));
 | |
|   return ExecuteGenerators(&base_packets_, &non_base_generators_,
 | |
|                            /*initial=*/true);
 | |
| }
 | |
| 
 | |
| absl::Status PacketGeneratorGraph::RunGraphSetup(
 | |
|     const std::map<std::string, Packet>& input_side_packets,
 | |
|     std::map<std::string, Packet>* output_side_packets,
 | |
|     std::vector<int>* non_scheduled_generators) const {
 | |
|   *output_side_packets = base_packets_;
 | |
|   for (const std::pair<const std::string, Packet>& item : input_side_packets) {
 | |
|     auto iter = output_side_packets->find(item.first);
 | |
|     if (iter != output_side_packets->end()) {
 | |
|       return absl::AlreadyExistsError(
 | |
|           absl::StrCat("Side packet \"", iter->first, "\" was defined twice."));
 | |
|     }
 | |
|     output_side_packets->insert(iter, item);
 | |
|   }
 | |
|   std::vector<int> non_scheduled_generators_local;
 | |
|   if (!non_scheduled_generators)
 | |
|     non_scheduled_generators = &non_scheduled_generators_local;
 | |
| 
 | |
|   MP_RETURN_IF_ERROR(
 | |
|       validated_graph_->CanAcceptSidePackets(input_side_packets));
 | |
|   // This type check on the required side packets is redundant with
 | |
|   // error checking in ExecuteGenerators, but we do it now to fail early.
 | |
|   MP_RETURN_IF_ERROR(
 | |
|       validated_graph_->ValidateRequiredSidePackets(*output_side_packets));
 | |
|   MP_RETURN_IF_ERROR(ExecuteGenerators(
 | |
|       output_side_packets, non_scheduled_generators, /*initial=*/false));
 | |
|   return absl::OkStatus();
 | |
| }
 | |
| 
 | |
| absl::Status PacketGeneratorGraph::ExecuteGenerators(
 | |
|     std::map<std::string, Packet>* output_side_packets,
 | |
|     std::vector<int>* non_scheduled_generators, bool initial) const {
 | |
|   VLOG(1) << "ExecuteGenerators initial == " << initial;
 | |
| 
 | |
|   // Iterate through the generators and produce as many output
 | |
|   // side packets as we can. The generators that don't have all the
 | |
|   // required input side packets are put into non_scheduled_generators.
 | |
|   // The ValidatedGraphConfig object is expected to already have sorted
 | |
|   // generators in topological order.
 | |
|   GeneratorScheduler scheduler(validated_graph_, executor_,
 | |
|                                non_base_generators_, initial);
 | |
|   scheduler.ScheduleAllRunnableGenerators(output_side_packets);
 | |
|   // Do not return early if scheduler encountered an error.  The lambdas
 | |
|   // in the executor must run in order to free resources.
 | |
| 
 | |
|   scheduler.WaitUntilIdle();
 | |
| 
 | |
|   // It is safe to return now, since all the tasks have run.
 | |
|   return scheduler.GetNonScheduledGenerators(non_scheduled_generators);
 | |
| }
 | |
| 
 | |
| }  // namespace mediapipe
 |