mediapipe/mediapipe/framework/output_stream_handler.cc

170 lines
5.5 KiB
C++
Raw Normal View History

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/output_stream_handler.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/output_stream_shard.h"
namespace mediapipe {
// Points every entry of output_stream_managers_ at the corresponding element
// of the caller-provided flat array. The array element used for an entry is
// the one at index `id.value()` for that entry's CollectionItemId.
//
// flat_output_stream_managers: base pointer of the manager array; it is only
//   indexed here, never dereferenced.
// Returns absl::OkStatus() unconditionally.
absl::Status OutputStreamHandler::InitializeOutputStreamManagers(
    OutputStreamManager* flat_output_stream_managers) {
  CollectionItemId slot = output_stream_managers_.BeginId();
  while (slot < output_stream_managers_.EndId()) {
    auto& manager_entry = output_stream_managers_.Get(slot);
    manager_entry = flat_output_stream_managers + slot.value();
    ++slot;
  }
  return absl::OkStatus();
}
// Copies the spec of each output stream manager into the shard that has the
// same CollectionItemId in `output_shards`.
//
// output_shards: shard set to configure; must not be null (CHECK-enforced).
// Returns absl::OkStatus() unconditionally.
absl::Status OutputStreamHandler::SetupOutputShards(
    OutputStreamShardSet* output_shards) {
  CHECK(output_shards);
  for (CollectionItemId slot = output_stream_managers_.BeginId();
       slot < output_stream_managers_.EndId(); ++slot) {
    OutputStreamManager* const stream = output_stream_managers_.Get(slot);
    auto& shard = output_shards->Get(slot);
    shard.SetSpec(stream->Spec());
  }
  return absl::OkStatus();
}
void OutputStreamHandler::PrepareForRun(
const std::function<void(absl::Status)>& error_callback) {
for (auto& manager : output_stream_managers_) {
manager->PrepareForRun(error_callback);
}
absl::MutexLock lock(&timestamp_mutex_);
completed_input_timestamps_.clear();
task_timestamp_bound_ = Timestamp::Unset();
propagation_state_ = kIdle;
}
// Runs the open-time propagation: first pushes the contents of
// `output_shards` through PropagateOutputPackets() using
// Timestamp::Unstarted() as the input timestamp, then, for each manager,
// propagates its header and locks its intro data (in that order).
//
// output_shards: must not be null (CHECK-enforced).
void OutputStreamHandler::Open(OutputStreamShardSet* output_shards) {
  CHECK(output_shards);

  const Timestamp open_timestamp = Timestamp::Unstarted();
  PropagateOutputPackets(open_timestamp, output_shards);

  for (OutputStreamManager* stream : output_stream_managers_) {
    stream->PropagateHeader();
    stream->LockIntroData();
  }
}
// Asks each output stream manager to reset the shard that shares its
// CollectionItemId in `output_shards`.
//
// input_timestamp: part of the interface; not read by this implementation.
// output_shards: must not be null (CHECK-enforced).
void OutputStreamHandler::PrepareOutputs(Timestamp input_timestamp,
                                         OutputStreamShardSet* output_shards) {
  CHECK(output_shards);
  CollectionItemId slot = output_stream_managers_.BeginId();
  while (slot < output_stream_managers_.EndId()) {
    OutputStreamManager* const stream = output_stream_managers_.Get(slot);
    stream->ResetShard(&output_shards->Get(slot));
    ++slot;
  }
}
// Records a new task timestamp bound and triggers propagation if appropriate.
//
// When the calculator is not run in parallel, the bound is handed straight to
// TryPropagateTimestampBound(). Otherwise, with timestamp_mutex_ held:
//   * an unchanged bound is a no-op;
//   * a changed bound must be strictly greater than the stored one
//     (CHECK-enforced) and replaces it;
//   * if a bound propagation is in progress (kPropagatingBound), the state is
//     switched to kPropagationPending and nothing else is done here;
//   * if the handler is idle (kIdle), PropagationLoop() is run by this call;
//   * in any other state nothing further happens.
void OutputStreamHandler::UpdateTaskTimestampBound(Timestamp timestamp) {
  if (!calculator_run_in_parallel_) {
    TryPropagateTimestampBound(timestamp);
    return;
  }

  absl::MutexLock lock(&timestamp_mutex_);
  if (task_timestamp_bound_ == timestamp) {
    return;
  }
  CHECK_GT(timestamp, task_timestamp_bound_);
  task_timestamp_bound_ = timestamp;

  if (propagation_state_ == kPropagatingBound) {
    propagation_state_ = kPropagationPending;
  } else if (propagation_state_ == kIdle) {
    PropagationLoop();
  }
}
// Handles the completion of the invocation identified by `input_timestamp`.
//
// Parallel execution: with timestamp_mutex_ held, the timestamp is added to
// completed_input_timestamps_; then, if a bound propagation is in progress
// (kPropagatingBound) the state becomes kPropagationPending, if the handler
// is idle (kIdle) PropagationLoop() is run, and in any other state nothing
// further happens.
//
// Non-parallel execution: the outputs of the default calculator context are
// propagated immediately via PropagateOutputPackets().
void OutputStreamHandler::PostProcess(Timestamp input_timestamp) {
  if (calculator_run_in_parallel_) {
    absl::MutexLock lock(&timestamp_mutex_);
    completed_input_timestamps_.insert(input_timestamp);
    if (propagation_state_ == kPropagatingBound) {
      propagation_state_ = kPropagationPending;
    } else if (propagation_state_ == kIdle) {
      PropagationLoop();
    }
    return;
  }

  CalculatorContext* const context =
      calculator_context_manager_->GetDefaultCalculatorContext();
  OutputStreamShardSet* const default_outputs = &context->Outputs();
  PropagateOutputPackets(input_timestamp, default_outputs);
}
// Returns the name of the first output stream manager in iteration order, or
// an empty string when the handler has no output stream managers.
std::string OutputStreamHandler::FirstStreamName() const {
  const bool has_streams = output_stream_managers_.NumEntries() != 0;
  if (has_streams) {
    const auto first = output_stream_managers_.begin();
    return (*first)->Name();
  }
  return std::string();
}
// Attempts to advance the timestamp bound of the output streams from an input
// bound. Nothing happens unless `input_bound` is a range value.
//
// A stream is updated only if all of the following hold, checked in order:
//   1. its offset is enabled,
//   2. it is not closed,
//   3. `input_bound` plus its offset is greater than its current next
//      timestamp bound.
// The update is issued through PropagateUpdatesToMirrors() with a
// default-constructed shard as the shard argument.
void OutputStreamHandler::TryPropagateTimestampBound(Timestamp input_bound) {
  // TODO: Some non-range values, such as PostStream(), should also be
  // propagated.
  if (!input_bound.IsRangeValue()) {
    return;
  }

  OutputStreamShard placeholder_shard;
  for (OutputStreamManager* stream : output_stream_managers_) {
    if (!stream->OffsetEnabled() || stream->IsClosed()) {
      continue;
    }
    if (input_bound + stream->Offset() > stream->NextTimestampBound()) {
      stream->PropagateUpdatesToMirrors(input_bound + stream->Offset(),
                                        &placeholder_shard);
    }
  }
}
// Closes every output stream manager.
//
// output_shards: optional. When non-null, each manager first propagates the
//   matching shard to its mirrors with a bound of Timestamp::Done(); when
//   null, the managers are closed without that final propagation.
void OutputStreamHandler::Close(OutputStreamShardSet* output_shards) {
  CollectionItemId slot = output_stream_managers_.BeginId();
  while (slot < output_stream_managers_.EndId()) {
    OutputStreamManager* const stream = output_stream_managers_.Get(slot);
    if (output_shards != nullptr) {
      stream->PropagateUpdatesToMirrors(Timestamp::Done(),
                                        &output_shards->Get(slot));
    }
    stream->Close();
    ++slot;
  }
}
// Propagates the contents of each shard in `output_shards` through the
// manager with the same CollectionItemId. Managers that are already closed
// are skipped.
//
// For every remaining manager, the output timestamp bound is computed from
// the shard and `input_timestamp`, the shard is propagated to the mirrors
// with that bound, and the manager is closed if the shard reports itself as
// closed afterwards.
//
// output_shards: must not be null (CHECK-enforced).
void OutputStreamHandler::PropagateOutputPackets(
    Timestamp input_timestamp, OutputStreamShardSet* output_shards) {
  CHECK(output_shards);
  for (CollectionItemId slot = output_stream_managers_.BeginId();
       slot < output_stream_managers_.EndId(); ++slot) {
    OutputStreamManager* const stream = output_stream_managers_.Get(slot);
    if (!stream->IsClosed()) {
      OutputStreamShard& shard = output_shards->Get(slot);
      const Timestamp bound =
          stream->ComputeOutputTimestampBound(shard, input_timestamp);
      stream->PropagateUpdatesToMirrors(bound, &shard);
      if (shard.IsClosed()) {
        stream->Close();
      }
    }
  }
}
} // namespace mediapipe