// File: mediapipe/mediapipe/framework/calculator_node.cc
// Last commit: 7fb37c80e8 by MediaPipe Team - "Project import generated by Copybara."
// GitOrigin-RevId: 19a829ffd755edb43e54d20c0e7b9348512d5108
// Date: 2022-05-05 19:57:20 +00:00
//
// 919 lines
// 34 KiB
// C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_node.h"
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/input_stream_manager.h"
#include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/output_stream_manager.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/framework/tool/tag_map.h"
#include "mediapipe/framework/tool/validate_name.h"
namespace mediapipe {
namespace {
// Looks up the PacketType stored in `packet_type_set` for a stream.
// An empty `tag` selects the untagged entry at position `index`; a non-empty
// `tag` always selects index 0 of that tag (and `index` is ignored).
// CHECK-fails if the requested entry does not exist.
const PacketType* GetPacketType(const PacketTypeSet& packet_type_set,
                                const std::string& tag, const int index) {
  const CollectionItemId id = tag.empty() ? packet_type_set.GetId("", index)
                                          : packet_type_set.GetId(tag, 0);
  CHECK(id.IsValid()) << "Internal mediapipe error.";
  return &packet_type_set.Get(id);
}
// Copies a TagMap omitting entries with certain names.
//
// Returns a new TagMap built from `tag_map`'s canonical entries minus every
// entry whose name appears in `names`. Assumes CanonicalEntries() is indexed
// the same way as the TagMap's ids (the erase below relies on it).
// `.value()` aborts if TagMap::Create() rejects the reduced entry list.
std::shared_ptr<tool::TagMap> RemoveNames(const tool::TagMap& tag_map,
                                          const std::set<std::string>& names) {
  auto tag_index_names = tag_map.CanonicalEntries();
  // Visit ids from last to first so that erasing an entry never shifts the
  // position of an entry that still has to be examined.
  for (auto id = tag_map.EndId() - 1; id >= tag_map.BeginId(); --id) {
    // Query the set directly with the stored name instead of copying it into
    // a local string first.
    if (names.count(tag_map.Names()[id.value()]) > 0) {
      tag_index_names.erase(tag_index_names.begin() + id.value());
    }
  }
  return tool::TagMap::Create(tag_index_names).value();
}
// Copies matching entries from another Collection.
//
// For every id in `result`'s own tag map, looks up the entry with the same
// (tag, index) pair in `other`. If `other` has one, it is assigned into
// `result`. Entries of `result` with no counterpart in `other` are left
// untouched, and entries of `other` unknown to `result` are ignored.
template <class CollectionType>
void CopyCollection(const CollectionType& other, CollectionType* result) {
  const auto dest_map = result->TagMap();
  const auto end_id = dest_map->EndId();
  for (auto dest_id = dest_map->BeginId(); dest_id != end_id; ++dest_id) {
    const auto tag_and_index = dest_map->TagAndIndexFromId(dest_id);
    const auto source_id =
        other.GetId(tag_and_index.first, tag_and_index.second);
    if (!source_id.IsValid()) {
      continue;
    }
    result->Get(dest_id) = other.Get(source_id);
  }
}
// Copies packet types omitting entries that are optional and not provided.
//
// An entry is dropped only when all three hold:
//   * its PacketType is optional,
//   * `validated_graph` reports its name as an external side packet, and
//   * `all_side_packets` contains no packet under that name.
// All other entries are copied into the returned set.
std::unique_ptr<PacketTypeSet> RemoveOmittedPacketTypes(
    const PacketTypeSet& packet_types,
    const std::map<std::string, Packet>& all_side_packets,
    const ValidatedGraphConfig* validated_graph) {
  std::set<std::string> names_to_omit;
  const auto end_id = packet_types.EndId();
  for (auto id = packet_types.BeginId(); id != end_id; ++id) {
    const std::string name = packet_types.TagMap()->Names()[id.value()];
    const bool is_missing_optional_external =
        packet_types.Get(id).IsOptional() &&
        validated_graph->IsExternalSidePacket(name) &&
        all_side_packets.count(name) == 0;
    if (is_missing_optional_external) {
      names_to_omit.insert(name);
    }
  }
  auto result = std::make_unique<PacketTypeSet>(
      RemoveNames(*packet_types.TagMap(), names_to_omit));
  CopyCollection(packet_types, result.get());
  return result;
}
} // namespace
// Default construction does no work of its own; the node is wired into a
// graph later by Initialize().
CalculatorNode::CalculatorNode() = default;
// Forwards to the wrapped calculator's SourceProcessOrder(). Requires
// calculator_ to exist: it is created in PrepareForRun() and reset in
// CleanupAfterRun().
Timestamp CalculatorNode::SourceProcessOrder(
    const CalculatorContext* cc) const {
  const Timestamp order = calculator_->SourceProcessOrder(cc);
  return order;
}
// Wires this node into the graph described by `validated_graph`.
//
// `node_ref` selects either a calculator or a packet generator; any other type
// yields InvalidArgumentError. The function records the node's name, type
// info, max_in_flight_ (a configured 0 is treated as 1), executor and source
// layer, and then:
//   1. points the node's output side packets at its slice of
//      `output_side_packets` and registers input side packet mirrors;
//   2. creates the output stream handler and binds it to this node's slice of
//      `output_stream_managers`;
//   3. creates the CalculatorState and initializes the context manager;
//   4. creates the input stream handler. A calculator-declared handler is
//      used only when the graph config does not set one;
//   5. applies the contract's timestamp offset to every output stream and its
//      process-timestamp bounds to the input stream handler;
//   6. binds the input stream handler to `input_stream_managers` and registers
//      mirrors on the upstream output streams.
// `*buffer_size_hint` receives the node config's buffer_size_hint.
absl::Status CalculatorNode::Initialize(
    const ValidatedGraphConfig* validated_graph, NodeTypeInfo::NodeRef node_ref,
    InputStreamManager* input_stream_managers,
    OutputStreamManager* output_stream_managers,
    OutputSidePacketImpl* output_side_packets, int* buffer_size_hint,
    std::shared_ptr<ProfilingContext> profiling_context) {
  RET_CHECK(buffer_size_hint) << "buffer_size_hint is NULL";
  validated_graph_ = validated_graph;
  // `profiling_context` is a by-value sink parameter. Move it into the member
  // rather than copying it, which would cost an extra atomic refcount update.
  // Everything below uses the member, not the parameter.
  profiling_context_ = std::move(profiling_context);

  // Every branch below either assigns node_config or returns, so it is never
  // read while null.
  const CalculatorGraphConfig::Node* node_config = nullptr;
  if (node_ref.type == NodeTypeInfo::NodeType::CALCULATOR) {
    node_config = &validated_graph_->Config().node(node_ref.index);
    name_ = tool::CanonicalNodeName(validated_graph_->Config(), node_ref.index);
    node_type_info_ = &validated_graph_->CalculatorInfos()[node_ref.index];
  } else if (node_ref.type == NodeTypeInfo::NodeType::PACKET_GENERATOR) {
    const PacketGeneratorConfig& pg_config =
        validated_graph_->Config().packet_generator(node_ref.index);
    name_ = absl::StrCat("__pg_", node_ref.index, "_",
                         pg_config.packet_generator());
    node_type_info_ = &validated_graph_->GeneratorInfos()[node_ref.index];
    // For packet generators the node config comes from the contract's
    // wrapper config.
    node_config = &node_type_info_->Contract().GetWrapperConfig();
  } else {
    return absl::InvalidArgumentError(
        "node_ref is not a calculator or packet generator");
  }

  // A configured max_in_flight of 0 is treated as 1.
  max_in_flight_ = node_config->max_in_flight();
  max_in_flight_ = max_in_flight_ ? max_in_flight_ : 1;
  if (!node_config->executor().empty()) {
    executor_ = node_config->executor();
  }
  source_layer_ = node_config->source_layer();

  const CalculatorContract& contract = node_type_info_->Contract();

  // TODO Propagate types between calculators when SetAny is used.

  MP_RETURN_IF_ERROR(InitializeOutputSidePackets(
      node_type_info_->OutputSidePacketTypes(), output_side_packets));

  MP_RETURN_IF_ERROR(InitializeInputSidePackets(output_side_packets));

  MP_RETURN_IF_ERROR(
      InitializeOutputStreamHandler(node_config->output_stream_handler(),
                                    node_type_info_->OutputStreamTypes()));
  MP_RETURN_IF_ERROR(InitializeOutputStreams(output_stream_managers));

  calculator_state_ = absl::make_unique<CalculatorState>(
      name_, node_ref.index, node_config->calculator(), *node_config,
      profiling_context_);

  // Inform the scheduler that this node has buffering behavior and that the
  // maximum input queue size should be adjusted accordingly.
  *buffer_size_hint = node_config->buffer_size_hint();

  calculator_context_manager_.Initialize(
      calculator_state_.get(), node_type_info_->InputStreamTypes().TagMap(),
      node_type_info_->OutputStreamTypes().TagMap(),
      /*calculator_run_in_parallel=*/max_in_flight_ > 1);

  // The graph specified InputStreamHandler takes priority.
  const bool graph_specified =
      node_config->input_stream_handler().has_input_stream_handler();
  const bool calc_specified =
      !(node_type_info_->GetInputStreamHandler().empty());

  // Only use calculator ISH if available, and if the graph ISH is not set.
  InputStreamHandlerConfig handler_config;
  const bool use_calc_specified = calc_specified && !graph_specified;
  if (use_calc_specified) {
    *(handler_config.mutable_input_stream_handler()) =
        node_type_info_->GetInputStreamHandler();
    *(handler_config.mutable_options()) =
        node_type_info_->GetInputStreamHandlerOptions();
  }

  // Use calculator or graph specified InputStreamHandler, or the default ISH
  // already set from graph.
  MP_RETURN_IF_ERROR(InitializeInputStreamHandler(
      use_calc_specified ? handler_config : node_config->input_stream_handler(),
      node_type_info_->InputStreamTypes()));

  // The timestamp offset belongs to the contract and is the same for every
  // output stream, so query it once instead of once per stream.
  const auto timestamp_offset = contract.GetTimestampOffset();
  const bool offset_enabled = timestamp_offset != TimestampDiff::Unset();
  for (auto& stream : output_stream_handler_->OutputStreams()) {
    stream->Spec()->offset_enabled = offset_enabled;
    stream->Spec()->offset = timestamp_offset;
  }
  input_stream_handler_->SetProcessTimestampBounds(
      contract.GetProcessTimestampBounds());

  return InitializeInputStreams(input_stream_managers, output_stream_managers);
}
// Creates output_side_packets_ with the tag layout of
// `output_side_packet_types`. Each of its entries points into the graph-wide
// `output_side_packets` array, in the slice that starts at this node's
// OutputSidePacketBaseIndex(). Fails if that base index is negative.
absl::Status CalculatorNode::InitializeOutputSidePackets(
    const PacketTypeSet& output_side_packet_types,
    OutputSidePacketImpl* output_side_packets) {
  output_side_packets_ =
      absl::make_unique<OutputSidePacketSet>(output_side_packet_types.TagMap());
  const int base_index = node_type_info_->OutputSidePacketBaseIndex();
  RET_CHECK_LE(0, base_index);
  const auto end_id = output_side_packets_->EndId();
  for (auto id = output_side_packets_->BeginId(); id < end_id; ++id) {
    OutputSidePacketImpl* const slot =
        &output_side_packets[base_index + id.value()];
    output_side_packets_->GetPtr(id) = slot;
  }
  return absl::OkStatus();
}
// Registers this node's input side packet handler as a mirror of each
// upstream output side packet that feeds one of the node's input side
// packets. Input side packets with a negative upstream index are skipped:
// they are not produced by a graph node and come from the side packets
// supplied to the graph. Fails if the node's base index is negative.
absl::Status CalculatorNode::InitializeInputSidePackets(
    OutputSidePacketImpl* output_side_packets) {
  const int base_index = node_type_info_->InputSidePacketBaseIndex();
  RET_CHECK_LE(0, base_index);
  const PacketTypeSet& side_packet_types =
      node_type_info_->InputSidePacketTypes();
  for (auto id = side_packet_types.BeginId(); id < side_packet_types.EndId();
       ++id) {
    const int flat_index = base_index + id.value();
    const int output_side_packet_index =
        validated_graph_->InputSidePacketInfos()[flat_index].upstream;
    if (output_side_packet_index >= 0) {
      VLOG(2) << "Adding mirror for input side packet with id " << id.value()
              << " and flat index " << flat_index
              << " which will be connected to output side packet with flat index "
              << output_side_packet_index;
      output_side_packets[output_side_packet_index].AddMirror(
          &input_side_packet_handler_, id);
    }
  }
  return absl::OkStatus();
}
// Hands the output stream handler this node's slice of the graph-wide
// `output_stream_managers` array, which starts at OutputStreamBaseIndex().
// Fails if the array is null or the base index is negative.
absl::Status CalculatorNode::InitializeOutputStreams(
    OutputStreamManager* output_stream_managers) {
  RET_CHECK(output_stream_managers) << "output_stream_managers is NULL";
  RET_CHECK_LE(0, node_type_info_->OutputStreamBaseIndex());
  OutputStreamManager* const first_manager =
      output_stream_managers + node_type_info_->OutputStreamBaseIndex();
  return output_stream_handler_->InitializeOutputStreamManagers(first_manager);
}
// Binds the input stream handler to this node's slice of the graph-wide
// `input_stream_managers` array. Then, for every input stream, registers the
// handler as a mirror on the upstream output stream manager that feeds it.
// Fails on null arrays, a negative base index, or an input stream with no
// upstream output stream.
absl::Status CalculatorNode::InitializeInputStreams(
    InputStreamManager* input_stream_managers,
    OutputStreamManager* output_stream_managers) {
  RET_CHECK(input_stream_managers) << "input_stream_managers is NULL";
  RET_CHECK(output_stream_managers) << "output_stream_managers is NULL";
  RET_CHECK_LE(0, node_type_info_->InputStreamBaseIndex());
  const int base_index = node_type_info_->InputStreamBaseIndex();
  MP_RETURN_IF_ERROR(input_stream_handler_->InitializeInputStreamManagers(
      input_stream_managers + base_index));

  const PacketTypeSet& input_types = node_type_info_->InputStreamTypes();
  for (auto id = input_types.BeginId(); id < input_types.EndId(); ++id) {
    const int flat_index = base_index + id.value();
    const int output_stream_index =
        validated_graph_->InputStreamInfos()[flat_index].upstream;
    RET_CHECK_LE(0, output_stream_index);
    VLOG(2) << "Adding mirror for input stream with id " << id.value()
            << " and flat index " << flat_index
            << " which will be connected to output stream with flat index "
            << output_stream_index;
    output_stream_managers[output_stream_index].AddMirror(
        input_stream_handler_.get(), id);
  }
  return absl::OkStatus();
}
// Creates input_stream_handler_ by name through InputStreamHandlerRegistry,
// resolved within the graph's package. The handler receives:
//   * the input stream tag map,
//   * the calculator context manager,
//   * the handler options, and
//   * whether the node may run in parallel (max_in_flight_ > 1).
// max_in_flight_ must therefore already be set.
// Fails if the configured handler name is empty or creation fails.
absl::Status CalculatorNode::InitializeInputStreamHandler(
    const InputStreamHandlerConfig& handler_config,
    const PacketTypeSet& input_stream_types) {
  const ProtoString& input_stream_handler_name =
      handler_config.input_stream_handler();
  RET_CHECK(!input_stream_handler_name.empty());
  // NOTE(review): the "is not a registered" message is attached to any
  // creation failure reported by the registry, not only to unknown names.
  ASSIGN_OR_RETURN(input_stream_handler_,
                   InputStreamHandlerRegistry::CreateByNameInNamespace(
                       validated_graph_->Package(), input_stream_handler_name,
                       input_stream_types.TagMap(),
                       &calculator_context_manager_, handler_config.options(),
                       /*calculator_run_in_parallel=*/max_in_flight_ > 1),
                   _ << "\"" << input_stream_handler_name
                     << "\" is not a registered input stream handler.");

  return absl::OkStatus();
}
// Creates output_stream_handler_ by name through OutputStreamHandlerRegistry,
// resolved within the graph's package. The handler receives:
//   * the output stream tag map,
//   * the calculator context manager,
//   * the handler options, and
//   * whether the node may run in parallel (max_in_flight_ > 1).
// max_in_flight_ must therefore already be set.
// Fails if the configured handler name is empty or creation fails.
absl::Status CalculatorNode::InitializeOutputStreamHandler(
    const OutputStreamHandlerConfig& handler_config,
    const PacketTypeSet& output_stream_types) {
  const ProtoString& output_stream_handler_name =
      handler_config.output_stream_handler();
  RET_CHECK(!output_stream_handler_name.empty());
  // NOTE(review): the "is not a registered" message is attached to any
  // creation failure reported by the registry, not only to unknown names.
  ASSIGN_OR_RETURN(output_stream_handler_,
                   OutputStreamHandlerRegistry::CreateByNameInNamespace(
                       validated_graph_->Package(), output_stream_handler_name,
                       output_stream_types.TagMap(),
                       &calculator_context_manager_, handler_config.options(),
                       /*calculator_run_in_parallel=*/max_in_flight_ > 1),
                   _ << "\"" << output_stream_handler_name
                     << "\" is not a registered output stream handler.");
  return absl::OkStatus();
}
// Connects the input and output shards of `calculator_context` to this node's
// stream handlers, inputs first. PrepareForRun() hands this to the calculator
// context manager as its shard-setup callback.
absl::Status CalculatorNode::ConnectShardsToStreams(
    CalculatorContext* calculator_context) {
  RET_CHECK(calculator_context);
  MP_RETURN_IF_ERROR(
      input_stream_handler_->SetupInputShards(&calculator_context->Inputs()));
  OutputStreamShardSet& output_shards = calculator_context->Outputs();
  return output_stream_handler_->SetupOutputShards(&output_shards);
}
// Overrides the executor name. CHECK-fails if the node has already reached
// kStateOpened or any later state.
void CalculatorNode::SetExecutor(const std::string& executor) {
  absl::MutexLock lock(&status_mutex_);
  CHECK_LT(status_, kStateOpened);
  executor_ = executor;
}
// True once the node has reached kStatePrepared or any later state.
bool CalculatorNode::Prepared() const {
  absl::MutexLock lock(&status_mutex_);
  const bool reached = status_ >= kStatePrepared;
  return reached;
}
// True once the node has reached kStateOpened or any later state.
bool CalculatorNode::Opened() const {
  absl::MutexLock lock(&status_mutex_);
  const bool reached = status_ >= kStateOpened;
  return reached;
}
// True once the node has reached kStateActive or any later state.
bool CalculatorNode::Active() const {
  absl::MutexLock lock(&status_mutex_);
  const bool reached = status_ >= kStateActive;
  return reached;
}
// True once the node has reached kStateClosed or any later state.
bool CalculatorNode::Closed() const {
  absl::MutexLock lock(&status_mutex_);
  const bool reached = status_ >= kStateClosed;
  return reached;
}
// Forwards the maximum input queue size to the input stream handler, which
// must already exist (it is created in Initialize()).
void CalculatorNode::SetMaxInputStreamQueueSize(int max_queue_size) {
  CHECK(input_stream_handler_);
  auto* const handler = input_stream_handler_.get();
  handler->SetMaxQueueSize(max_queue_size);
}
// Resets the node for a new graph run.
//
// Stores the readiness callbacks and prepares the input/output stream
// handlers. Prepares the input side packet handler with the node's input side
// packet types, minus optional external side packets that were not supplied.
// Installs the side packets, counter factory and requested service packets
// into the CalculatorState, and creates a fresh calculator instance. Finally
// it resets the status and scheduling bookkeeping to kStatePrepared / kIdle.
//
// Fails if a required callback is null, a non-optional service is missing,
// or any sub-step fails.
absl::Status CalculatorNode::PrepareForRun(
    const std::map<std::string, Packet>& all_side_packets,
    const std::map<std::string, Packet>& service_packets,
    std::function<void()> ready_for_open_callback,
    std::function<void()> source_node_opened_callback,
    std::function<void(CalculatorContext*)> schedule_callback,
    std::function<void(absl::Status)> error_callback,
    CounterFactory* counter_factory) {
  RET_CHECK(ready_for_open_callback) << "ready_for_open_callback is NULL";
  RET_CHECK(schedule_callback) << "schedule_callback is NULL";
  RET_CHECK(error_callback) << "error_callback is NULL";
  calculator_state_->ResetBetweenRuns();

  ready_for_open_callback_ = std::move(ready_for_open_callback);
  source_node_opened_callback_ = std::move(source_node_opened_callback);
  // error_callback is copied into both stream handlers here and only moved
  // into the side packet handler further below.
  input_stream_handler_->PrepareForRun(
      [this]() { CalculatorNode::InputStreamHeadersReady(); },
      [this]() { CalculatorNode::CheckIfBecameReady(); },
      std::move(schedule_callback), error_callback);
  output_stream_handler_->PrepareForRun(error_callback);

  const auto& contract = Contract();
  input_side_packet_types_ = RemoveOmittedPacketTypes(
      contract.InputSidePackets(), all_side_packets, validated_graph_);
  MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun(
      input_side_packet_types_.get(), all_side_packets,
      [this]() { CalculatorNode::InputSidePacketsReady(); },
      std::move(error_callback)));
  calculator_state_->SetInputSidePackets(
      &input_side_packet_handler_.InputSidePackets());
  calculator_state_->SetOutputSidePackets(output_side_packets_.get());
  calculator_state_->SetCounterFactory(counter_factory);

  for (const auto& svc_req : contract.ServiceRequests()) {
    const auto& req = svc_req.second;
    auto it = service_packets.find(req.Service().key);
    if (it == service_packets.end()) {
      RET_CHECK(req.IsOptional())
          << "required service '" << req.Service().key << "' was not provided";
    } else {
      MP_RETURN_IF_ERROR(
          calculator_state_->SetServicePacket(req.Service(), it->second));
    }
  }

  // A lambda states the shard-setup callback more directly than std::bind
  // with a member-function pointer and a placeholder.
  auto setup_shards = [this](CalculatorContext* calculator_context) {
    return ConnectShardsToStreams(calculator_context);
  };
  MP_RETURN_IF_ERROR(
      calculator_context_manager_.PrepareForRun(std::move(setup_shards)));

  ASSIGN_OR_RETURN(
      auto calculator_factory,
      CalculatorBaseRegistry::CreateByNameInNamespace(
          validated_graph_->Package(), calculator_state_->CalculatorType()));
  calculator_ = calculator_factory->CreateCalculator(
      calculator_context_manager_.GetDefaultCalculatorContext());

  needs_to_close_ = false;

  {
    absl::MutexLock status_lock(&status_mutex_);
    status_ = kStatePrepared;
    scheduling_state_ = kIdle;
    current_in_flight_ = 0;
    input_stream_headers_ready_called_ = false;
    input_side_packets_ready_called_ = false;
    input_stream_headers_ready_ =
        (input_stream_handler_->UnsetHeaderCount() == 0);
    input_side_packets_ready_ =
        (input_side_packet_handler_.MissingInputSidePacketCount() == 0);
  }
  return absl::OkStatus();
}
namespace {
// Returns the Packet most recently sent to `out`. The result may be empty;
// callers check IsEmpty().
//
// Assumes `out` is really an OutputSidePacketImpl. `out` is a reference, so it
// always names an object and no null check is needed (a null check on its
// address could never fire). Returning a non-const value lets callers move
// from the result.
Packet GetPacket(const OutputSidePacket& out) {
  return static_cast<const OutputSidePacketImpl&>(out).GetPacket();
}
// Resends the output-side-packets from the previous graph run.
// Every non-empty output side packet held by `cc` is Set() again, which
// re-announces it to its mirrors. Always returns OkStatus.
absl::Status ResendSidePackets(CalculatorContext* cc) {
  auto& side_outputs = cc->OutputSidePackets();
  for (auto id = side_outputs.BeginId(); id < side_outputs.EndId(); ++id) {
    auto& side_output = side_outputs.Get(id);
    const Packet previous = GetPacket(side_output);
    if (previous.IsEmpty()) {
      continue;
    }
    side_output.Set(previous);
  }
  return absl::OkStatus();
}
} // namespace
// True when `cc` has neither input nor output streams and the input side
// packets report no change. OpenNode(), CloseNode() and ProcessNode() use this
// to skip calling into the calculator.
bool CalculatorNode::OutputsAreConstant(CalculatorContext* cc) {
  const bool has_streams =
      cc->Inputs().NumEntries() > 0 || cc->Outputs().NumEntries() > 0;
  return !has_streams && !input_side_packet_handler_.InputSidePacketsChanged();
}
// Runs the calculator's Open() on the default calculator context and moves the
// node to kStateOpened on success.
//
// When OutputsAreConstant() holds, Open() is not called; the output side
// packets are re-sent instead (ResendSidePackets()).
//
// If Open() returns tool::StatusStop(), the process aborts (LOG(FATAL)). Any
// other error is returned with a prefix naming the node. On that path
// needs_to_close_ and status_ are left unchanged, but the input-timestamp
// push/pop below has already happened.
absl::Status CalculatorNode::OpenNode() {
  VLOG(2) << "CalculatorNode::OpenNode() for " << DebugName();

  CalculatorContext* default_context =
      calculator_context_manager_.GetDefaultCalculatorContext();
  InputStreamShardSet* inputs = &default_context->Inputs();
  // Upstream calculators may set headers on their output streams during
  // Calculator::Open(), so refresh the header packets in the input stream
  // shards before opening this node.
  input_stream_handler_->UpdateInputShardHeaders(inputs);
  OutputStreamShardSet* outputs = &default_context->Outputs();
  output_stream_handler_->PrepareOutputs(Timestamp::Unstarted(), outputs);
  // Open() runs with Timestamp::Unstarted() as its input timestamp.
  calculator_context_manager_.PushInputTimestampToContext(
      default_context, Timestamp::Unstarted());

  absl::Status result;
  if (OutputsAreConstant(default_context)) {
    result = ResendSidePackets(default_context);
  } else {
    MEDIAPIPE_PROFILING(OPEN, default_context);
    LegacyCalculatorSupport::Scoped<CalculatorContext> s(default_context);
    result = calculator_->Open(default_context);
  }

  calculator_context_manager_.PopInputTimestampFromContext(default_context);
  if (IsSource()) {
    // A source node has a dummy input timestamp of 0 for Process(). This input
    // timestamp is not popped until Close() is called.
    calculator_context_manager_.PushInputTimestampToContext(default_context,
                                                            Timestamp(0));
  }

  LOG_IF(FATAL, result == tool::StatusStop()) << absl::Substitute(
      "Open() on node \"$0\" returned tool::StatusStop() which should only be "
      "used to signal that a source node is done producing data.",
      DebugName());
  MP_RETURN_IF_ERROR(result).SetPrepend() << absl::Substitute(
      "Calculator::Open() for node \"$0\" failed: ", DebugName());
  // Open() succeeded, so a later cleanup must call Close().
  needs_to_close_ = true;

  // Warn about timestamp offsets combined with multiple input sync sets.
  bool offset_enabled = false;
  for (auto& stream : output_stream_handler_->OutputStreams()) {
    offset_enabled = offset_enabled || stream->Spec()->offset_enabled;
  }
  if (offset_enabled && input_stream_handler_->SyncSetCount() > 1) {
    LOG(WARNING) << absl::Substitute(
        "Calculator node \"$0\" is configured with multiple input sync-sets "
        "and an output timestamp-offset, which will often conflict due to "
        "the order of packet arrival. With multiple input sync-sets, use "
        "SetProcessTimestampBounds in place of SetTimestampOffset.",
        DebugName());
  }

  output_stream_handler_->Open(outputs);

  {
    absl::MutexLock status_lock(&status_mutex_);
    status_ = kStateOpened;
  }

  return absl::OkStatus();
}
// Moves an opened node to kStateActive. CHECK-fails, naming the node, if the
// node is in any state other than kStateOpened.
void CalculatorNode::ActivateNode() {
  absl::MutexLock lock(&status_mutex_);
  CHECK_EQ(status_, kStateOpened) << DebugName();
  status_ = kStateActive;
}
// Closes the node's input streams through the input stream handler, unless
// the node is already kStateClosed. This clears the input queues and keeps
// upstream nodes from refilling them. ProcessNode() may still be called on
// the node afterwards.
void CalculatorNode::CloseInputStreams() {
  bool already_closed = false;
  {
    absl::MutexLock lock(&status_mutex_);
    already_closed = (status_ == kStateClosed);
  }
  if (already_closed) {
    return;
  }
  VLOG(2) << "Closing node " << DebugName() << " input streams.";
  input_stream_handler_->Close();
}
// Closes the node's output streams through the output stream handler, unless
// the node is already kStateClosed. `outputs` may be null (CleanupAfterRun()
// passes nullptr).
void CalculatorNode::CloseOutputStreams(OutputStreamShardSet* outputs) {
  bool already_closed = false;
  {
    absl::MutexLock lock(&status_mutex_);
    already_closed = (status_ == kStateClosed);
  }
  if (already_closed) {
    return;
  }
  VLOG(2) << "Closing node " << DebugName() << " output streams.";
  output_stream_handler_->Close(outputs);
}
// Closes the node: closes its input streams, then runs Close() on the default
// context at Timestamp::Done(), with `graph_status` made visible to the
// calculator.
//
// Close() is skipped when OutputsAreConstant() holds. Output streams are
// closed here only when `graph_run_ended` is false. The node is marked
// kStateClosed even if Close() fails; the Close() error is returned
// afterwards, prefixed with the node name. Returning tool::StatusStop() from
// Close() aborts the process. Calling CloseNode() on an already-closed node
// returns an error.
absl::Status CalculatorNode::CloseNode(const absl::Status& graph_status,
                                       bool graph_run_ended) {
  {
    absl::MutexLock status_lock(&status_mutex_);
    RET_CHECK_NE(status_, kStateClosed)
        << "CloseNode() must only be called once.";
  }

  CloseInputStreams();
  CalculatorContext* default_context =
      calculator_context_manager_.GetDefaultCalculatorContext();
  OutputStreamShardSet* outputs = &default_context->Outputs();
  output_stream_handler_->PrepareOutputs(Timestamp::Done(), outputs);
  if (IsSource()) {
    // Replace the dummy Timestamp(0) pushed by OpenNode() with Done().
    calculator_context_manager_.PopInputTimestampFromContext(default_context);
    calculator_context_manager_.PushInputTimestampToContext(default_context,
                                                            Timestamp::Done());
  }
  calculator_context_manager_.SetGraphStatusInContext(default_context,
                                                      graph_status);

  absl::Status result;

  if (OutputsAreConstant(default_context)) {
    // Do nothing.
    result = absl::OkStatus();
  } else {
    MEDIAPIPE_PROFILING(CLOSE, default_context);
    LegacyCalculatorSupport::Scoped<CalculatorContext> s(default_context);
    result = calculator_->Close(default_context);
  }
  needs_to_close_ = false;

  LOG_IF(FATAL, result == tool::StatusStop()) << absl::Substitute(
      "Close() on node \"$0\" returned tool::StatusStop() which should only be "
      "used to signal that a source node is done producing data.",
      DebugName());

  // If the graph run has ended, we are cleaning up after the run and don't
  // need to propagate updates to mirrors, so we can skip this
  // CloseOutputStreams() call. CleanupAfterRun() will close the output
  // streams.
  if (!graph_run_ended) {
    CloseOutputStreams(outputs);
  }

  {
    absl::MutexLock status_lock(&status_mutex_);
    status_ = kStateClosed;
  }

  MP_RETURN_IF_ERROR(result).SetPrepend() << absl::Substitute(
      "Calculator::Close() for node \"$0\" failed: ", DebugName());

  VLOG(2) << "Closed node " << DebugName();
  return absl::OkStatus();
}
// Tears the node down after a graph run.
//
// If Open() succeeded but the node was never closed, it is closed now with
// `graph_status`, and any Close() error is deliberately ignored. The
// calculator instance and all calculator contexts are then destroyed, the
// streams are closed, and the node returns to kStateUninitialized / kIdle with
// no invocations in flight.
void CalculatorNode::CleanupAfterRun(const absl::Status& graph_status) {
  if (needs_to_close_) {
    calculator_context_manager_.PushInputTimestampToContext(
        calculator_context_manager_.GetDefaultCalculatorContext(),
        Timestamp::Done());
    // Best effort: the run is over, so there is no one to report to.
    CloseNode(graph_status, /*graph_run_ended=*/true).IgnoreError();
  }

  calculator_ = nullptr;
  // All pending output packets are automatically dropped when calculator
  // context manager destroys all calculator context objects.
  calculator_context_manager_.CleanupAfterRun();

  // NOTE(review): if CloseNode() ran above, status_ is now kStateClosed, and
  // both Close*Streams() calls below return early without touching the
  // handlers. Confirm this is intended, given CloseNode()'s comment that
  // CleanupAfterRun() closes the output streams.
  CloseInputStreams();
  // All output stream shards have been destroyed by calculator context manager.
  CloseOutputStreams(/*outputs=*/nullptr);

  {
    absl::MutexLock lock(&status_mutex_);
    status_ = kStateUninitialized;
    scheduling_state_ = kIdle;
    current_in_flight_ = 0;
  }
}
// Dispatches ready invocations through the input stream handler.
//
// Callers in this file (CheckIfBecameReady(), EndScheduling()) set
// scheduling_state_ to kScheduling before calling. Each pass schedules at
// most as many invocations as there are free in-flight slots. If another
// thread set kSchedulingPending in the meantime and slots remain, the loop
// runs again; otherwise scheduling_state_ returns to kIdle. A closed node
// just resets to kIdle.
void CalculatorNode::SchedulingLoop() {
  int max_allowance = 0;
  {
    absl::MutexLock lock(&status_mutex_);
    if (status_ == kStateClosed) {
      scheduling_state_ = kIdle;
      return;
    }
    max_allowance = max_in_flight_ - current_in_flight_;
  }
  while (true) {
    Timestamp input_bound;
    // input_bound is set to a meaningful value iff the latest readiness of the
    // node is kNotReady when ScheduleInvocations() returns.
    input_stream_handler_->ScheduleInvocations(max_allowance, &input_bound);
    if (input_bound != Timestamp::Unset()) {
      // Updates the minimum timestamp for which a new packet could possibly
      // arrive.
      output_stream_handler_->UpdateTaskTimestampBound(input_bound);
    }

    {
      absl::MutexLock lock(&status_mutex_);
      if (scheduling_state_ == kSchedulingPending &&
          current_in_flight_ < max_in_flight_) {
        // Another thread asked for scheduling while we were busy; go again.
        max_allowance = max_in_flight_ - current_in_flight_;
        scheduling_state_ = kScheduling;
      } else {
        scheduling_state_ = kIdle;
        break;
      }
    }
  }
}
// True when both the input stream headers and the input side packets are
// ready, i.e. both prerequisites for opening the node are met.
bool CalculatorNode::ReadyForOpen() const {
  absl::MutexLock lock(&status_mutex_);
  if (!input_stream_headers_ready_) {
    return false;
  }
  return input_side_packets_ready_;
}
// Callback (installed in PrepareForRun()) signalling that all input stream
// headers are available. Records that fact; if the input side packets are also
// ready, invokes ready_for_open_callback_ after releasing the lock.
// CHECK-fails if the node is not kStatePrepared or if this is called twice.
void CalculatorNode::InputStreamHeadersReady() {
  bool side_packets_also_ready = false;
  {
    absl::MutexLock lock(&status_mutex_);
    CHECK_EQ(status_, kStatePrepared) << DebugName();
    CHECK(!input_stream_headers_ready_called_);
    input_stream_headers_ready_called_ = true;
    input_stream_headers_ready_ = true;
    side_packets_also_ready = input_side_packets_ready_;
  }
  if (!side_packets_also_ready) {
    return;
  }
  ready_for_open_callback_();
}
// Callback (installed in PrepareForRun()) signalling that all input side
// packets are available. Records that fact; if the input stream headers are
// also ready, invokes ready_for_open_callback_ after releasing the lock.
// CHECK-fails if the node is not kStatePrepared or if this is called twice.
void CalculatorNode::InputSidePacketsReady() {
  bool headers_also_ready = false;
  {
    absl::MutexLock lock(&status_mutex_);
    CHECK_EQ(status_, kStatePrepared) << DebugName();
    CHECK(!input_side_packets_ready_called_);
    input_side_packets_ready_called_ = true;
    input_side_packets_ready_ = true;
    headers_also_ready = input_stream_headers_ready_;
  }
  if (!headers_also_ready) {
    return;
  }
  ready_for_open_callback_();
}
// Called when the node may have become runnable. The input stream handler
// receives it as a callback in PrepareForRun(), and NodeOpened() also calls it.
//
// Acts only on an opened node:
//   * if it is idle with a free in-flight slot, this thread claims the
//     scheduling role and runs SchedulingLoop();
//   * if another thread is already scheduling, the state is set to
//     kSchedulingPending so that thread's loop runs another pass.
void CalculatorNode::CheckIfBecameReady() {
  {
    absl::MutexLock lock(&status_mutex_);
    // Doesn't check if status_ is kStateActive since the function can only be
    // invoked by non-source nodes.
    if (status_ != kStateOpened) {
      return;
    }
    if (scheduling_state_ == kIdle && current_in_flight_ < max_in_flight_) {
      scheduling_state_ = kScheduling;
    } else {
      if (scheduling_state_ == kScheduling) {
        // Changes the state to scheduling pending if another thread is doing
        // the scheduling.
        scheduling_state_ = kSchedulingPending;
      }
      return;
    }
  }
  SchedulingLoop();
}
// Notification that this node has been opened. Source nodes report through
// source_node_opened_callback_. Non-source nodes with input streams check
// whether packets produced by upstream Open()/Process() calls already make
// them runnable. Nodes without input streams do nothing.
void CalculatorNode::NodeOpened() {
  if (IsSource()) {
    source_node_opened_callback_();
    return;
  }
  if (input_stream_handler_->NumInputStreams() == 0) {
    return;
  }
  CheckIfBecameReady();
}
// Marks one in-flight invocation as finished (the counterpart of a successful
// TryToBeginScheduling()) and tries to schedule more work.
//
// Does nothing unless the node is opened or active. If another thread is
// already scheduling, it only flags kSchedulingPending (or leaves an existing
// pending flag alone); otherwise this thread claims the scheduling role and
// runs SchedulingLoop().
void CalculatorNode::EndScheduling() {
  {
    absl::MutexLock lock(&status_mutex_);
    if (status_ != kStateOpened && status_ != kStateActive) {
      return;
    }
    --current_in_flight_;
    CHECK_GE(current_in_flight_, 0);

    if (scheduling_state_ == kScheduling) {
      // Changes the state to scheduling pending if another thread is doing the
      // scheduling.
      scheduling_state_ = kSchedulingPending;
      return;
    } else if (scheduling_state_ == kSchedulingPending) {
      // Quits when another thread is already doing the scheduling.
      return;
    }
    scheduling_state_ = kScheduling;
  }
  SchedulingLoop();
}
// Reserves one in-flight slot if any is free. Returns false, leaving the
// counter unchanged, when max_in_flight_ invocations are already running.
bool CalculatorNode::TryToBeginScheduling() {
  absl::MutexLock lock(&status_mutex_);
  if (current_in_flight_ >= max_in_flight_) {
    return false;
  }
  ++current_in_flight_;
  return true;
}
// Returns the input stream handler's debug description of its stream names.
std::string CalculatorNode::DebugInputStreamNames() const {
  std::string names = input_stream_handler_->DebugStreamNames();
  return names;
}
// Returns the node name stored in the CalculatorState. calculator_state_ is
// created in Initialize(); its presence is only DCHECKed.
std::string CalculatorNode::DebugName() const {
  DCHECK(calculator_state_);
  const std::string& node_name = calculator_state_->NodeName();
  return node_name;
}
// Runs one scheduled invocation of the node on `calculator_context`.
//
// Source nodes:
//   * Do nothing once Closed().
//   * Otherwise call Process() once at the context's input timestamp.
//   * A tool::StatusStop() result closes the node after PostProcess().
//   * Any other error is returned, prefixed with the node name.
//
// Non-source nodes:
//   * Iterate over the input timestamps queued in the context. There can be
//     more than one only when max_in_flight_ <= 1.
//   * For each allowed timestamp: finalize the input set, call Process()
//     (unless OutputsAreConstant()), clear the consumed inputs and
//     post-process the outputs. StatusStop() is returned to the caller.
//   * A Timestamp::Done() input closes the node.
//   * If the context holds no timestamps, an InternalError is returned.
//
// TODO: Split this function.
absl::Status CalculatorNode::ProcessNode(
    CalculatorContext* calculator_context) {
  if (IsSource()) {
    // This is a source Calculator.
    if (Closed()) {
      return absl::OkStatus();
    }

    const Timestamp input_timestamp = calculator_context->InputTimestamp();

    OutputStreamShardSet* outputs = &calculator_context->Outputs();
    output_stream_handler_->PrepareOutputs(input_timestamp, outputs);

    VLOG(2) << "Calling Calculator::Process() for node: " << DebugName();
    absl::Status result;

    {
      MEDIAPIPE_PROFILING(PROCESS, calculator_context);
      LegacyCalculatorSupport::Scoped<CalculatorContext> s(calculator_context);
      result = calculator_->Process(calculator_context);
    }

    bool node_stopped = false;
    if (!result.ok()) {
      if (result == tool::StatusStop()) {
        // Needs to call CloseNode().
        node_stopped = true;
      } else {
        return mediapipe::StatusBuilder(result, MEDIAPIPE_LOC).SetPrepend()
               << absl::Substitute(
                      "Calculator::Process() for node \"$0\" failed: ",
                      DebugName());
      }
    }
    // Outputs are post-processed before a stopped source node is closed.
    output_stream_handler_->PostProcess(input_timestamp);
    if (node_stopped) {
      MP_RETURN_IF_ERROR(
          CloseNode(absl::OkStatus(), /*graph_run_ended=*/false));
    }
    return absl::OkStatus();
  } else {
    // This is not a source Calculator.
    InputStreamShardSet* const inputs = &calculator_context->Inputs();
    OutputStreamShardSet* const outputs = &calculator_context->Outputs();
    // Returned unchanged if the loop below runs zero times.
    absl::Status result =
        absl::InternalError("Calculator context has no input packets.");

    int num_invocations = calculator_context_manager_.NumberOfContextTimestamps(
        *calculator_context);
    // Several queued timestamps in one context are only allowed when the node
    // does not run invocations in parallel.
    RET_CHECK(num_invocations <= 1 || max_in_flight_ <= 1)
        << "num_invocations:" << num_invocations
        << ", max_in_flight_:" << max_in_flight_;
    for (int i = 0; i < num_invocations; ++i) {
      const Timestamp input_timestamp = calculator_context->InputTimestamp();
      // The node is ready for Process().
      if (input_timestamp.IsAllowedInStream()) {
        input_stream_handler_->FinalizeInputSet(input_timestamp, inputs);
        output_stream_handler_->PrepareOutputs(input_timestamp, outputs);

        VLOG(2) << "Calling Calculator::Process() for node: " << DebugName()
                << " timestamp: " << input_timestamp;

        if (OutputsAreConstant(calculator_context)) {
          // Do nothing.
          result = absl::OkStatus();
        } else {
          MEDIAPIPE_PROFILING(PROCESS, calculator_context);
          LegacyCalculatorSupport::Scoped<CalculatorContext> s(
              calculator_context);
          result = calculator_->Process(calculator_context);
        }

        VLOG(2) << "Called Calculator::Process() for node: " << DebugName()
                << " timestamp: " << input_timestamp;

        // Removes one packet from each shard and progresses to the next input
        // timestamp.
        input_stream_handler_->ClearCurrentInputs(calculator_context);

        // Nodes are allowed to return StatusStop() to cause the termination
        // of the graph. This is different from an error in that it will
        // ensure that all sources will be closed and that packets in input
        // streams will be processed before the graph is terminated.
        if (!result.ok() && result != tool::StatusStop()) {
          return mediapipe::StatusBuilder(result, MEDIAPIPE_LOC).SetPrepend()
                 << absl::Substitute(
                        "Calculator::Process() for node \"$0\" failed: ",
                        DebugName());
        }
        output_stream_handler_->PostProcess(input_timestamp);
        if (result == tool::StatusStop()) {
          return result;
        }
      } else if (input_timestamp == Timestamp::Done()) {
        // Some or all the input streams are closed and there are not enough
        // open input streams for Process(). So this node needs to be closed
        // too.
        // If the streams are closed, there shouldn't be more input.
        CHECK_EQ(calculator_context_manager_.NumberOfContextTimestamps(
                     *calculator_context),
                 1);
        return CloseNode(absl::OkStatus(), /*graph_run_ended=*/false);
      } else {
        RET_CHECK_FAIL()
            << "Invalid input timestamp in ProcessNode(). timestamp: "
            << input_timestamp;
      }
    }
    return result;
  }
}
// Installs the callbacks the input stream handler invokes when an input queue
// becomes full or stops being full. The handler must already exist (it is
// created in Initialize()).
void CalculatorNode::SetQueueSizeCallbacks(
    InputStreamManager::QueueSizeCallback becomes_full_callback,
    InputStreamManager::QueueSizeCallback becomes_not_full_callback) {
  CHECK(input_stream_handler_);
  auto* const handler = input_stream_handler_.get();
  handler->SetQueueSizeCallbacks(std::move(becomes_full_callback),
                                 std::move(becomes_not_full_callback));
}
} // namespace mediapipe