137 lines
5.2 KiB
C++
137 lines
5.2 KiB
C++
// Copyright 2019 The MediaPipe Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include "absl/base/macros.h"
|
|
#include "absl/strings/string_view.h"
|
|
#include "mediapipe/framework/calculator_framework.h"
|
|
#include "mediapipe/framework/port/canonical_errors.h"
|
|
#include "mediapipe/framework/port/status.h"
|
|
#include "mediapipe/framework/tool/source.pb.h"
|
|
|
|
namespace mediapipe {
|
|
|
|
namespace tool {
|
|
|
|
// A calculator which takes N input side packets and passes them as
|
|
// N outputs. Each input side packet contains a vector of Packets, or a single
|
|
// Packet, as given in the options. The elements of the vector contained in
|
|
// the i-th input side packet are output as individual packets to the i-th
|
|
// output stream. Optionally, the packets can be timestamped, with either their
|
|
// index within the vector, or with Timestamp::PostStream(). No type
|
|
// checking is performed. It is only checked that the calculator receives 0
|
|
// inputs and the number of outputs equals the number of input side packets.
|
|
class SidePacketsToStreamsCalculator : public CalculatorBase {
|
|
public:
|
|
SidePacketsToStreamsCalculator() {}
|
|
SidePacketsToStreamsCalculator(const SidePacketsToStreamsCalculator&) =
|
|
delete;
|
|
SidePacketsToStreamsCalculator& operator=(
|
|
const SidePacketsToStreamsCalculator&) = delete;
|
|
~SidePacketsToStreamsCalculator() override {}
|
|
|
|
static absl::Status GetContract(CalculatorContract* cc) {
|
|
auto& options = cc->Options<SidePacketsToStreamsCalculatorOptions>();
|
|
if (options.has_num_inputs() &&
|
|
(options.num_inputs() != cc->InputSidePackets().NumEntries() ||
|
|
options.num_inputs() != cc->Outputs().NumEntries())) {
|
|
return absl::InvalidArgumentError(
|
|
"If num_inputs is specified it must be equal to the number of "
|
|
"input side packets and output streams.");
|
|
}
|
|
if (!options.vectors_of_packets() &&
|
|
options.set_timestamp() ==
|
|
SidePacketsToStreamsCalculatorOptions::NONE) {
|
|
return absl::InvalidArgumentError(
|
|
"If set_timestamp is NONE, vectors_of_packets must not be false.");
|
|
}
|
|
for (int i = 0; i < cc->InputSidePackets().NumEntries(); ++i) {
|
|
if (options.vectors_of_packets()) {
|
|
cc->InputSidePackets().Index(i).Set<std::vector<Packet>>();
|
|
} else {
|
|
cc->InputSidePackets().Index(i).SetAny();
|
|
}
|
|
}
|
|
for (int i = 0; i < cc->InputSidePackets().NumEntries(); ++i) {
|
|
if (options.vectors_of_packets()) {
|
|
cc->Outputs().Index(i).SetAny();
|
|
} else {
|
|
cc->Outputs().Index(i).SetSameAs(&cc->InputSidePackets().Index(i));
|
|
}
|
|
}
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Process(CalculatorContext* cc) final {
|
|
const auto& options = cc->Options<SidePacketsToStreamsCalculatorOptions>();
|
|
// The i-th input side packet contains a vector of packets corresponding
|
|
// to the values of this input for all batch elements.
|
|
|
|
int batch_size = -1;
|
|
if (options.vectors_of_packets()) {
|
|
// Verify the batch size is consistent.
|
|
for (const Packet& input_side_packet : cc->InputSidePackets()) {
|
|
const auto& packets = input_side_packet.Get<std::vector<Packet>>();
|
|
if (batch_size >= 0) {
|
|
if (packets.size() != batch_size) {
|
|
return absl::InvalidArgumentError(
|
|
"The specified input side packets contain vectors of different "
|
|
"sizes.");
|
|
}
|
|
} else {
|
|
batch_size = packets.size();
|
|
}
|
|
}
|
|
} else {
|
|
batch_size = 1;
|
|
}
|
|
|
|
for (int b = 0; b < batch_size; ++b) {
|
|
for (int i = 0; i < cc->InputSidePackets().NumEntries(); ++i) {
|
|
Packet packet;
|
|
if (options.vectors_of_packets()) {
|
|
const auto& packets =
|
|
cc->InputSidePackets().Index(i).Get<std::vector<Packet>>();
|
|
packet = packets[b];
|
|
} else {
|
|
packet = cc->InputSidePackets().Index(i);
|
|
}
|
|
switch (options.set_timestamp()) {
|
|
case SidePacketsToStreamsCalculatorOptions::VECTOR_INDEX:
|
|
cc->Outputs().Index(i).AddPacket(packet.At(Timestamp(b)));
|
|
break;
|
|
case SidePacketsToStreamsCalculatorOptions::WHOLE_STREAM:
|
|
cc->Outputs().Index(i).AddPacket(
|
|
packet.At(Timestamp::PostStream()));
|
|
break;
|
|
case SidePacketsToStreamsCalculatorOptions::PRE_STREAM:
|
|
cc->Outputs().Index(i).AddPacket(packet.At(Timestamp::PreStream()));
|
|
break;
|
|
default:
|
|
// SidePacketsToStreamsCalculatorOptions::NONE
|
|
cc->Outputs().Index(i).AddPacket(packet);
|
|
}
|
|
}
|
|
}
|
|
return tool::StatusStop();
|
|
}
|
|
};
|
|
|
|
// Invokes the REGISTER_CALCULATOR macro for SidePacketsToStreamsCalculator.
// NOTE(review): presumably this makes the calculator instantiable by name from
// a graph config -- confirm against the framework's registration docs.
REGISTER_CALCULATOR(SidePacketsToStreamsCalculator);
|
|
|
|
} // namespace tool
|
|
} // namespace mediapipe
|