392 lines
15 KiB
C++
392 lines
15 KiB
C++
// Copyright 2019 The MediaPipe Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// Forked from mediapipe/framework/tool/source.proto.
|
|
// The forked proto must remain identical to the original proto and should be
|
|
// ONLY used by mediapipe open source project.
|
|
|
|
#include "mediapipe/framework/tool/sink.h"
|
|
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "absl/strings/str_cat.h"
|
|
#include "absl/strings/str_format.h"
|
|
#include "mediapipe/calculators/internal/callback_packet_calculator.pb.h"
|
|
#include "mediapipe/framework/calculator.pb.h"
|
|
#include "mediapipe/framework/calculator_base.h"
|
|
#include "mediapipe/framework/calculator_graph.h"
|
|
#include "mediapipe/framework/calculator_registry.h"
|
|
#include "mediapipe/framework/input_stream.h"
|
|
#include "mediapipe/framework/packet.h"
|
|
#include "mediapipe/framework/packet_type.h"
|
|
#include "mediapipe/framework/port/logging.h"
|
|
#include "mediapipe/framework/port/source_location.h"
|
|
#include "mediapipe/framework/port/status_builder.h"
|
|
#include "mediapipe/framework/tool/name_util.h"
|
|
|
|
namespace mediapipe {
|
|
|
|
namespace tool {
|
|
namespace {
|
|
// Produces an output packet with the PostStream timestamp containing the
|
|
// input side packet.
|
|
class MediaPipeInternalSidePacketToPacketStreamCalculator
|
|
: public CalculatorBase {
|
|
public:
|
|
static absl::Status GetContract(CalculatorContract* cc) {
|
|
cc->InputSidePackets().Index(0).SetAny();
|
|
cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0));
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Open(CalculatorContext* cc) final {
|
|
cc->Outputs().Index(0).AddPacket(
|
|
cc->InputSidePackets().Index(0).At(Timestamp::PostStream()));
|
|
cc->Outputs().Index(0).Close();
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Process(CalculatorContext* cc) final {
|
|
// The framework treats this calculator as a source calculator.
|
|
return mediapipe::tool::StatusStop();
|
|
}
|
|
};
|
|
REGISTER_CALCULATOR(MediaPipeInternalSidePacketToPacketStreamCalculator);
|
|
} // namespace
|
|
|
|
void AddVectorSink(const std::string& stream_name, //
|
|
CalculatorGraphConfig* config, //
|
|
std::vector<Packet>* dumped_data) {
|
|
CHECK(config);
|
|
CHECK(dumped_data);
|
|
|
|
std::string input_side_packet_name;
|
|
tool::AddCallbackCalculator(stream_name, config, &input_side_packet_name,
|
|
/*use_std_function=*/true);
|
|
|
|
auto* node = config->add_node();
|
|
node->set_name(GetUnusedNodeName(
|
|
*config, absl::StrCat("callback_packet_calculator_that_generators_",
|
|
input_side_packet_name)));
|
|
node->set_calculator("CallbackPacketCalculator");
|
|
node->add_output_side_packet(input_side_packet_name);
|
|
CallbackPacketCalculatorOptions* options =
|
|
node->mutable_options()->MutableExtension(
|
|
CallbackPacketCalculatorOptions::ext);
|
|
options->set_type(CallbackPacketCalculatorOptions::VECTOR_PACKET);
|
|
// Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended.
|
|
char address[19];
|
|
int written = snprintf(address, sizeof(address), "%p", dumped_data);
|
|
CHECK(written > 0 && written < sizeof(address));
|
|
options->set_pointer(address);
|
|
}
|
|
|
|
void AddPostStreamPacketSink(const std::string& stream_name,
|
|
CalculatorGraphConfig* config,
|
|
Packet* post_stream_packet) {
|
|
CHECK(config);
|
|
CHECK(post_stream_packet);
|
|
|
|
std::string input_side_packet_name;
|
|
tool::AddCallbackCalculator(stream_name, config, &input_side_packet_name,
|
|
/*use_std_function=*/true);
|
|
auto* node = config->add_node();
|
|
node->set_name(GetUnusedNodeName(
|
|
*config, absl::StrCat("callback_packet_calculator_that_generators_",
|
|
input_side_packet_name)));
|
|
node->set_calculator("CallbackPacketCalculator");
|
|
node->add_output_side_packet(input_side_packet_name);
|
|
CallbackPacketCalculatorOptions* options =
|
|
node->mutable_options()->MutableExtension(
|
|
CallbackPacketCalculatorOptions::ext);
|
|
options->set_type(CallbackPacketCalculatorOptions::POST_STREAM_PACKET);
|
|
// Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended.
|
|
char address[19];
|
|
int written = snprintf(address, sizeof(address), "%p", post_stream_packet);
|
|
CHECK(written > 0 && written < sizeof(address));
|
|
options->set_pointer(address);
|
|
}
|
|
|
|
void AddSidePacketSink(const std::string& side_packet_name,
|
|
CalculatorGraphConfig* config, Packet* dumped_packet) {
|
|
CHECK(config);
|
|
CHECK(dumped_packet);
|
|
|
|
CalculatorGraphConfig::Node* conversion_node = config->add_node();
|
|
const std::string node_name = GetUnusedNodeName(
|
|
*config,
|
|
absl::StrCat("calculator_converts_side_packet_", side_packet_name));
|
|
conversion_node->set_name(node_name);
|
|
conversion_node->set_calculator(
|
|
"MediaPipeInternalSidePacketToPacketStreamCalculator");
|
|
conversion_node->add_input_side_packet(
|
|
GetUnusedSidePacketName(*config, side_packet_name));
|
|
|
|
const std::string output_stream_name =
|
|
absl::StrCat(node_name, "_output_stream");
|
|
conversion_node->add_output_stream(output_stream_name);
|
|
AddPostStreamPacketSink(output_stream_name, config, dumped_packet);
|
|
}
|
|
|
|
void AddCallbackCalculator(const std::string& stream_name,
|
|
CalculatorGraphConfig* config,
|
|
std::string* callback_side_packet_name,
|
|
bool use_std_function) {
|
|
CHECK(config);
|
|
CHECK(callback_side_packet_name);
|
|
CalculatorGraphConfig::Node* sink_node = config->add_node();
|
|
sink_node->set_name(GetUnusedNodeName(
|
|
*config,
|
|
absl::StrCat("callback_calculator_that_collects_stream_", stream_name)));
|
|
sink_node->set_calculator("CallbackCalculator");
|
|
sink_node->add_input_stream(stream_name);
|
|
|
|
const std::string input_side_packet_name =
|
|
GetUnusedSidePacketName(*config, absl::StrCat(stream_name, "_callback"));
|
|
*callback_side_packet_name = input_side_packet_name;
|
|
if (use_std_function) {
|
|
// Uses tag "CALLBACK" if the input side packet contains a std::function.
|
|
sink_node->add_input_side_packet(
|
|
absl::StrCat("CALLBACK:", input_side_packet_name));
|
|
} else {
|
|
LOG(FATAL) << "AddCallbackCalculator must use std::function";
|
|
}
|
|
}
|
|
|
|
void AddMultiStreamCallback(
|
|
const std::vector<std::string>& streams,
|
|
std::function<void(const std::vector<Packet>&)> callback,
|
|
CalculatorGraphConfig* config,
|
|
std::pair<std::string, Packet>* side_packet) {
|
|
std::map<std::string, Packet> side_packets;
|
|
AddMultiStreamCallback(streams, callback, config, &side_packets,
|
|
/*observe_timestamp_bounds=*/false);
|
|
*side_packet = *side_packets.begin();
|
|
}
|
|
|
|
void AddMultiStreamCallback(
|
|
const std::vector<std::string>& streams,
|
|
std::function<void(const std::vector<Packet>&)> callback,
|
|
CalculatorGraphConfig* config, std::map<std::string, Packet>* side_packets,
|
|
bool observe_timestamp_bounds) {
|
|
CHECK(config);
|
|
CHECK(side_packets);
|
|
CalculatorGraphConfig::Node* sink_node = config->add_node();
|
|
const std::string name = GetUnusedNodeName(
|
|
*config, absl::StrCat("multi_callback_", absl::StrJoin(streams, "_")));
|
|
sink_node->set_name(name);
|
|
sink_node->set_calculator("CallbackCalculator");
|
|
for (const auto& stream_name : streams) {
|
|
sink_node->add_input_stream(stream_name);
|
|
}
|
|
|
|
if (observe_timestamp_bounds) {
|
|
const std::string observe_ts_bounds_packet_name = GetUnusedSidePacketName(
|
|
*config, absl::StrCat(name, "_observe_ts_bounds"));
|
|
sink_node->add_input_side_packet(absl::StrCat(
|
|
"OBSERVE_TIMESTAMP_BOUNDS:", observe_ts_bounds_packet_name));
|
|
InsertIfNotPresent(side_packets, observe_ts_bounds_packet_name,
|
|
MakePacket<bool>(true));
|
|
}
|
|
const std::string input_side_packet_name =
|
|
GetUnusedSidePacketName(*config, absl::StrCat(name, "_callback"));
|
|
sink_node->add_input_side_packet(
|
|
absl::StrCat("VECTOR_CALLBACK:", input_side_packet_name));
|
|
|
|
InsertIfNotPresent(
|
|
side_packets, input_side_packet_name,
|
|
MakePacket<std::function<void(const std::vector<Packet>&)>>(
|
|
std::move(callback)));
|
|
}
|
|
|
|
void AddCallbackWithHeaderCalculator(const std::string& stream_name,
|
|
const std::string& stream_header,
|
|
CalculatorGraphConfig* config,
|
|
std::string* callback_side_packet_name,
|
|
bool use_std_function) {
|
|
CHECK(config);
|
|
CHECK(callback_side_packet_name);
|
|
CalculatorGraphConfig::Node* sink_node = config->add_node();
|
|
sink_node->set_name(GetUnusedNodeName(
|
|
*config,
|
|
absl::StrCat("callback_calculator_that_collects_stream_and_header_",
|
|
stream_name, "_", stream_header)));
|
|
sink_node->set_calculator("CallbackWithHeaderCalculator");
|
|
sink_node->add_input_stream(absl::StrCat("INPUT:", stream_name));
|
|
sink_node->add_input_stream(absl::StrCat("HEADER:", stream_header));
|
|
|
|
const std::string input_side_packet_name = GetUnusedSidePacketName(
|
|
*config, absl::StrCat(stream_name, "_", stream_header, "_callback"));
|
|
*callback_side_packet_name = input_side_packet_name;
|
|
|
|
if (use_std_function) {
|
|
// Uses tag "CALLBACK" if the input side packet contains a std::function.
|
|
sink_node->add_input_side_packet(
|
|
absl::StrCat("CALLBACK:", input_side_packet_name));
|
|
} else {
|
|
LOG(FATAL) << "AddCallbackWithHeaderCalculator must use std::function";
|
|
}
|
|
}
|
|
|
|
// CallbackCalculator
|
|
|
|
// static
|
|
absl::Status CallbackCalculator::GetContract(CalculatorContract* cc) {
|
|
bool allow_multiple_streams = false;
|
|
// If the input side packet is specified using tag "CALLBACK" it must contain
|
|
// a std::function, which may be generated by CallbackPacketCalculator.
|
|
|
|
if (cc->InputSidePackets().HasTag("CALLBACK")) {
|
|
cc->InputSidePackets()
|
|
.Tag("CALLBACK")
|
|
.Set<std::function<void(const Packet&)>>();
|
|
} else if (cc->InputSidePackets().HasTag("VECTOR_CALLBACK")) {
|
|
cc->InputSidePackets()
|
|
.Tag("VECTOR_CALLBACK")
|
|
.Set<std::function<void(const std::vector<Packet>&)>>();
|
|
allow_multiple_streams = true;
|
|
} else {
|
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
|
<< "InputSidePackets must use tags.";
|
|
}
|
|
if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS")) {
|
|
cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Set<bool>();
|
|
cc->SetProcessTimestampBounds(true);
|
|
}
|
|
|
|
int count = allow_multiple_streams ? cc->Inputs().NumEntries("") : 1;
|
|
for (int i = 0; i < count; ++i) {
|
|
cc->Inputs().Index(i).SetAny();
|
|
}
|
|
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
// Fetches the callback from the "CALLBACK" or "VECTOR_CALLBACK" input side
// packet and validates the configuration.
//
// Returns InvalidArgument when:
//  - neither callback tag is present,
//  - the supplied std::function is empty, or
//  - "OBSERVE_TIMESTAMP_BOUNDS" is present but holds false.
absl::Status CallbackCalculator::Open(CalculatorContext* cc) {
  const auto& side_packets = cc->InputSidePackets();

  if (side_packets.HasTag("CALLBACK")) {
    callback_ =
        side_packets.Tag("CALLBACK").Get<std::function<void(const Packet&)>>();
  } else if (side_packets.HasTag("VECTOR_CALLBACK")) {
    vector_callback_ =
        side_packets.Tag("VECTOR_CALLBACK")
            .Get<std::function<void(const std::vector<Packet>&)>>();
  } else {
    // Report the misconfiguration through the Status return, matching
    // GetContract(), instead of aborting the whole process.
    return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "InputSidePackets must use tags.";
  }

  // A side packet may hold an empty std::function; reject it here so that
  // Process() never invokes a null callback.
  if (callback_ == nullptr && vector_callback_ == nullptr) {
    return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "missing callback.";
  }

  // The mere presence of the tag enables bound processing in GetContract(), so
  // a false value would contradict the contract.
  if (side_packets.HasTag("OBSERVE_TIMESTAMP_BOUNDS") &&
      !side_packets.Tag("OBSERVE_TIMESTAMP_BOUNDS").Get<bool>()) {
    return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "The value of the OBSERVE_TIMESTAMP_BOUNDS input side packet "
              "must be set to true";
  }
  return absl::OkStatus();
}
|
|
|
|
// Forwards the current input to the user callback: the single-packet callback
// receives the value of input stream 0, while the vector callback receives the
// values of all untagged input streams in index order. Always returns OK.
absl::Status CallbackCalculator::Process(CalculatorContext* cc) {
  if (callback_) {
    callback_(cc->Inputs().Index(0).Value());
    return absl::OkStatus();
  }

  if (vector_callback_) {
    const int num_streams = cc->Inputs().NumEntries("");
    std::vector<Packet> packets;
    packets.reserve(num_streams);
    for (int index = 0; index < num_streams; ++index) {
      packets.push_back(cc->Inputs().Index(index).Value());
    }
    vector_callback_(packets);
  }
  return absl::OkStatus();
}
|
|
|
|
// Registers CallbackCalculator; the helpers above refer to it by the name
// "CallbackCalculator" via set_calculator().
REGISTER_CALCULATOR(CallbackCalculator);
|
|
|
|
// CallbackWithHeaderCalculator
|
|
|
|
// static
|
|
absl::Status CallbackWithHeaderCalculator::GetContract(CalculatorContract* cc) {
|
|
cc->Inputs().Tag("INPUT").SetAny();
|
|
cc->Inputs().Tag("HEADER").SetAny();
|
|
|
|
if (cc->InputSidePackets().UsesTags()) {
|
|
CHECK(cc->InputSidePackets().HasTag("CALLBACK"));
|
|
cc->InputSidePackets()
|
|
.Tag("CALLBACK")
|
|
.Set<std::function<void(const Packet&, const Packet&)>>();
|
|
} else {
|
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
|
<< "InputSidePackets must use tags.";
|
|
}
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
// Fetches the callback from the "CALLBACK" input side packet, validates that
// both the "INPUT" and "HEADER" streams are connected, and captures the header
// carried by the "INPUT" stream if it has one.
//
// Returns InvalidArgument when the side packets are untagged, the "CALLBACK"
// tag is missing, the callback is empty, or either stream is not connected.
absl::Status CallbackWithHeaderCalculator::Open(CalculatorContext* cc) {
  const auto& side_packets = cc->InputSidePackets();
  // Configuration errors are reported through the Status return, like every
  // other failure in this function, instead of aborting the process.
  if (!side_packets.UsesTags()) {
    return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "InputSidePackets must use tags.";
  }
  // Check for the tag before reading it, so a mis-tagged config yields an
  // error instead of an access to a tag that does not exist.
  if (!side_packets.HasTag("CALLBACK")) {
    return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "InputSidePackets must contain the CALLBACK tag.";
  }
  callback_ = side_packets.Tag("CALLBACK")
                  .Get<std::function<void(const Packet&, const Packet&)>>();
  if (callback_ == nullptr) {
    return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "callback is nullptr.";
  }

  if (!cc->Inputs().HasTag("INPUT")) {
    return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "No input stream connected.";
  }
  if (!cc->Inputs().HasTag("HEADER")) {
    // Note: for the current MediaPipe header implementation, we just need to
    // connect the output stream to both of the two inputs: INPUT and HEADER.
    return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "No header stream connected.";
  }

  // If the input stream has the header, just use it as the header. Otherwise,
  // assume the header is coming from HEADER stream (picked up in Process()).
  if (!cc->Inputs().Tag("INPUT").Header().IsEmpty()) {
    header_packet_ = cc->Inputs().Tag("INPUT").Header();
  }
  return absl::OkStatus();
}
|
|
|
|
// Invokes the callback with each non-empty "INPUT" packet together with the
// stored header packet. While no header has been stored yet, a non-empty
// "HEADER" packet is captured as the header.
//
// Returns an Unknown error if an "INPUT" packet arrives while no header is
// stored; otherwise OK.
absl::Status CallbackWithHeaderCalculator::Process(CalculatorContext* cc) {
  const Packet& input_packet = cc->Inputs().Tag("INPUT").Value();
  const bool has_input = !input_packet.IsEmpty();

  // Header packet should be available before we receive any normal input
  // stream packet.
  if (has_input && header_packet_.IsEmpty()) {
    return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
           << "Header not available!";
  }

  // Only the first non-empty HEADER packet is kept.
  if (header_packet_.IsEmpty()) {
    const Packet& header_candidate = cc->Inputs().Tag("HEADER").Value();
    if (!header_candidate.IsEmpty()) {
      header_packet_ = header_candidate;
    }
  }

  if (has_input) {
    callback_(input_packet, header_packet_);
  }
  return absl::OkStatus();
}
|
|
|
|
// Registers CallbackWithHeaderCalculator; AddCallbackWithHeaderCalculator()
// refers to it by the name "CallbackWithHeaderCalculator" via set_calculator().
REGISTER_CALCULATOR(CallbackWithHeaderCalculator);
|
|
|
|
} // namespace tool
|
|
} // namespace mediapipe
|