No public description
PiperOrigin-RevId: 568953918
This commit is contained in:
parent
8f8c66430f
commit
983fda5d4e
|
@ -1390,3 +1390,26 @@ cc_test(
|
|||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "value_or_default_calculator",
|
||||
srcs = ["value_or_default_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "value_or_default_calculator_test",
|
||||
srcs = ["value_or_default_calculator_test.cc"],
|
||||
deps = [
|
||||
":value_or_default_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
90
mediapipe/calculators/core/value_or_default_calculator.cc
Normal file
90
mediapipe/calculators/core/value_or_default_calculator.cc
Normal file
|
@ -0,0 +1,90 @@
|
|||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kInputValueTag[] = "IN";
|
||||
constexpr char kTickerTag[] = "TICK";
|
||||
constexpr char kOutputTag[] = "OUT";
|
||||
constexpr char kIndicationTag[] = "FLAG";
|
||||
|
||||
} // namespace
|
||||
// For every packet received on the TICK stream, if the IN stream is not
|
||||
// empty - emit its value as is as OUT. Otherwise output a default packet.
|
||||
// FLAG outputs true every time the default value has been used. It does not
|
||||
// output anything when IN has a value.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "ValueOrDefaultCalculator"
|
||||
// input_stream: "IN:sometimes_missing_value"
|
||||
// input_stream: "TICK:clock"
|
||||
// output_stream: "OUT:value_or_default"
|
||||
// output_stream: "FLAG:used_default"
|
||||
// input_side_packet: "default"
|
||||
// }
|
||||
//
|
||||
// TODO: Consider adding an option for a default value as a input-stream
|
||||
// instead of a side-packet, so it will enable using standard calculators
|
||||
// instead of creating a new packet-generators. It will also allow a dynamic
|
||||
// default value.
|
||||
class ValueOrDefaultCalculator : public mediapipe::CalculatorBase {
|
||||
public:
|
||||
ValueOrDefaultCalculator() {}
|
||||
|
||||
ValueOrDefaultCalculator(const ValueOrDefaultCalculator&) = delete;
|
||||
ValueOrDefaultCalculator& operator=(const ValueOrDefaultCalculator&) = delete;
|
||||
|
||||
static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
|
||||
cc->Inputs().Tag(kInputValueTag).SetAny();
|
||||
cc->Inputs().Tag(kTickerTag).SetAny();
|
||||
cc->Outputs().Tag(kOutputTag).SetSameAs(&cc->Inputs().Tag(kInputValueTag));
|
||||
cc->Outputs().Tag(kIndicationTag).Set<bool>();
|
||||
cc->InputSidePackets().Index(0).SetSameAs(
|
||||
&cc->Inputs().Tag(kInputValueTag));
|
||||
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Open(mediapipe::CalculatorContext* cc) override {
|
||||
if (!cc->Inputs().Tag(kInputValueTag).Header().IsEmpty()) {
|
||||
cc->Outputs()
|
||||
.Tag(kOutputTag)
|
||||
.SetHeader(cc->Inputs().Tag(kInputValueTag).Header());
|
||||
}
|
||||
default_ = cc->InputSidePackets().Index(0);
|
||||
cc->SetOffset(mediapipe::TimestampDiff(0));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::Status Process(mediapipe::CalculatorContext* cc) override {
|
||||
// Output according to the TICK signal.
|
||||
if (cc->Inputs().Tag(kTickerTag).IsEmpty()) {
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
if (!cc->Inputs().Tag(kInputValueTag).IsEmpty()) {
|
||||
// Output the input as is:
|
||||
cc->Outputs()
|
||||
.Tag(kOutputTag)
|
||||
.AddPacket(cc->Inputs().Tag(kInputValueTag).Value());
|
||||
} else {
|
||||
// Output default:
|
||||
cc->Outputs()
|
||||
.Tag(kOutputTag)
|
||||
.AddPacket(default_.At(cc->InputTimestamp()));
|
||||
cc->Outputs()
|
||||
.Tag(kIndicationTag)
|
||||
.Add(new bool(true), cc->InputTimestamp());
|
||||
}
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
// The default value to replicate every time there is no new value.
|
||||
mediapipe::Packet default_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(ValueOrDefaultCalculator);
|
||||
|
||||
} // namespace mediapipe
|
240
mediapipe/calculators/core/value_or_default_calculator_test.cc
Normal file
240
mediapipe/calculators/core/value_or_default_calculator_test.cc
Normal file
|
@ -0,0 +1,240 @@
|
|||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::testing::AllOf;
|
||||
using ::testing::ContainerEq;
|
||||
using ::testing::Each;
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::SizeIs;
|
||||
using ::testing::Test;
|
||||
|
||||
const int kDefaultValue = 0;
|
||||
|
||||
// Utility to a create a mediapipe graph runner with the tested calculator and a
|
||||
// default value, for all the tests.
|
||||
class ValueOrDefaultRunner : public mediapipe::CalculatorRunner {
|
||||
public:
|
||||
ValueOrDefaultRunner()
|
||||
: mediapipe::CalculatorRunner(R"pb(
|
||||
calculator: "ValueOrDefaultCalculator"
|
||||
input_stream: "IN:in"
|
||||
input_stream: "TICK:tick"
|
||||
input_side_packet: "default"
|
||||
output_stream: "OUT:out"
|
||||
output_stream: "FLAG:used_default"
|
||||
)pb") {
|
||||
MutableSidePackets()->Index(0) = mediapipe::MakePacket<int>(kDefaultValue);
|
||||
}
|
||||
|
||||
// Utility to push inputs to the runner to the TICK stream, so we could easily
|
||||
// tick.
|
||||
void TickAt(int64_t time) {
|
||||
// The type or value of the stream isn't relevant, we use just a bool.
|
||||
MutableInputs()->Tag("TICK").packets.push_back(
|
||||
mediapipe::Adopt(new bool(false)).At(mediapipe::Timestamp(time)));
|
||||
}
|
||||
|
||||
// Utility to push the real inputs to the runner (IN stream).
|
||||
void ProvideInput(int64_t time, int value) {
|
||||
MutableInputs()->Tag("IN").packets.push_back(
|
||||
mediapipe::Adopt(new int(value)).At(mediapipe::Timestamp(time)));
|
||||
}
|
||||
|
||||
// Extracts the timestamps (as int64) of the output stream of the calculator.
|
||||
std::vector<int64_t> GetOutputTimestamps() const {
|
||||
std::vector<int64_t> timestamps;
|
||||
for (const mediapipe::Packet& packet : Outputs().Tag("OUT").packets) {
|
||||
timestamps.emplace_back(packet.Timestamp().Value());
|
||||
}
|
||||
return timestamps;
|
||||
}
|
||||
|
||||
// Extracts the values from the output stream of the calculator.
|
||||
std::vector<int> GetOutputValues() const {
|
||||
std::vector<int> values;
|
||||
for (const mediapipe::Packet& packet : Outputs().Tag("OUT").packets) {
|
||||
values.emplace_back(packet.Get<int>());
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
// Extracts the timestamps (as int64) of the flag stream, which indicates on
|
||||
// times without an input value (i.e. using the default value).
|
||||
std::vector<int64_t> GetFlagTimestamps() const {
|
||||
std::vector<int64_t> timestamps;
|
||||
for (const mediapipe::Packet& packet : Outputs().Tag("FLAG").packets) {
|
||||
timestamps.emplace_back(packet.Timestamp().Value());
|
||||
}
|
||||
return timestamps;
|
||||
}
|
||||
|
||||
// Extracts the output from the flags stream (which should always be true).
|
||||
std::vector<bool> GetFlagValues() const {
|
||||
std::vector<bool> flags;
|
||||
for (const mediapipe::Packet& packet : Outputs().Tag("FLAG").packets) {
|
||||
flags.emplace_back(packet.Get<bool>());
|
||||
}
|
||||
return flags;
|
||||
}
|
||||
};
|
||||
|
||||
// To be used as input values:
|
||||
std::vector<int> GetIntegersRange(int size) {
|
||||
std::vector<int> result;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
// We start with default-value+1 so it won't contain the default value.
|
||||
result.push_back(kDefaultValue + 1 + i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
TEST(ValueOrDefaultCalculatorTest, NoInputs) {
|
||||
// Check that when no real inputs are provided - we get the default value over
|
||||
// and over, with the correct timestamps.
|
||||
ValueOrDefaultRunner runner;
|
||||
const std::vector<int64_t> ticks = {0, 1, 2, 5, 8, 12, 33, 231};
|
||||
|
||||
for (int tick : ticks) {
|
||||
runner.TickAt(tick);
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
// Make sure we get the right timestamps:
|
||||
EXPECT_THAT(runner.GetOutputTimestamps(), ContainerEq(ticks));
|
||||
// All should be default value:
|
||||
EXPECT_THAT(runner.GetOutputValues(),
|
||||
AllOf(Each(kDefaultValue), SizeIs(ticks.size())));
|
||||
// We should get the default indication all the time:
|
||||
EXPECT_THAT(runner.GetFlagTimestamps(), ContainerEq(ticks));
|
||||
}
|
||||
|
||||
TEST(ValueOrDefaultCalculatorTest, NeverDefault) {
|
||||
// Check that when we provide the inputs on time - we get them as outputs.
|
||||
ValueOrDefaultRunner runner;
|
||||
const std::vector<int64_t> ticks = {0, 1, 2, 5, 8, 12, 33, 231};
|
||||
const std::vector<int> values = GetIntegersRange(ticks.size());
|
||||
|
||||
for (int i = 0; i < ticks.size(); ++i) {
|
||||
runner.TickAt(ticks[i]);
|
||||
runner.ProvideInput(ticks[i], values[i]);
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
// Make sure we get the right timestamps:
|
||||
EXPECT_THAT(runner.GetOutputTimestamps(), ContainerEq(ticks));
|
||||
// Should get the inputs values:
|
||||
EXPECT_THAT(runner.GetOutputValues(), ContainerEq(values));
|
||||
// We should never get the default indication:
|
||||
EXPECT_THAT(runner.GetFlagTimestamps(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(ValueOrDefaultCalculatorTest, DefaultAndValues) {
|
||||
// Check that when we provide inputs only part of the time - we get them, but
|
||||
// defaults at the missing times.
|
||||
// That's the usual use case for this calculator.
|
||||
ValueOrDefaultRunner runner;
|
||||
const std::vector<int64_t> ticks = {0, 1, 5, 8, 12, 231};
|
||||
// Provide inputs only part of the ticks.
|
||||
// Chosen so there will be defaults before the first input, between the
|
||||
// inputs and after the last input.
|
||||
const std::vector<int64_t> in_ticks = {/*0,*/ 1, 5, /*8,*/ 12, /*, 231*/};
|
||||
const std::vector<int> in_values = GetIntegersRange(in_ticks.size());
|
||||
|
||||
for (int tick : ticks) {
|
||||
runner.TickAt(tick);
|
||||
}
|
||||
for (int i = 0; i < in_ticks.size(); ++i) {
|
||||
runner.ProvideInput(in_ticks[i], in_values[i]);
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
// Make sure we get all the timestamps:
|
||||
EXPECT_THAT(runner.GetOutputTimestamps(), ContainerEq(ticks));
|
||||
// The timestamps of the flag should be exactly the ones not in in_ticks.
|
||||
EXPECT_THAT(runner.GetFlagTimestamps(), ElementsAre(0, 8, 231));
|
||||
// And the values are default in these times, and the input values for
|
||||
// in_ticks.
|
||||
EXPECT_THAT(
|
||||
runner.GetOutputValues(),
|
||||
ElementsAre(kDefaultValue, 1, 2, kDefaultValue, 3, kDefaultValue));
|
||||
}
|
||||
|
||||
TEST(ValueOrDefaultCalculatorTest, TimestampsMissmatch) {
|
||||
// Check that when we provide the inputs not on time - we don't get them.
|
||||
ValueOrDefaultRunner runner;
|
||||
const std::vector<int64_t> ticks = {1, 2, 5, 8, 12, 33, 231};
|
||||
// The timestamps chosen so it will be before the first tick, in between ticks
|
||||
// and after the last one. Also - more inputs than ticks.
|
||||
const std::vector<int64_t> in_ticks = {0, 3, 4, 6, 7, 9, 10,
|
||||
11, 13, 14, 15, 16, 232};
|
||||
const std::vector<int> in_values = GetIntegersRange(in_ticks.size());
|
||||
for (int tick : ticks) {
|
||||
runner.TickAt(tick);
|
||||
}
|
||||
for (int i = 0; i < in_ticks.size(); ++i) {
|
||||
runner.ProvideInput(in_ticks[i], in_values[i]);
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
// Non of the in_ticks should be inserted:
|
||||
EXPECT_THAT(runner.GetOutputTimestamps(), ContainerEq(ticks));
|
||||
EXPECT_THAT(runner.GetOutputValues(),
|
||||
AllOf(Each(kDefaultValue), SizeIs(ticks.size())));
|
||||
// All (and only) ticks should get the default.
|
||||
EXPECT_THAT(runner.GetFlagTimestamps(), ContainerEq(ticks));
|
||||
}
|
||||
|
||||
TEST(ValueOrDefaultCalculatorTest, FlagValue) {
|
||||
// Since we anyway suppose that the Flag is a bool - there is nothing
|
||||
// interesting to check, but we should check once that the value is the right
|
||||
// (true) one.
|
||||
ValueOrDefaultRunner runner;
|
||||
runner.TickAt(0);
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
EXPECT_THAT(runner.GetFlagValues(), ElementsAre(true));
|
||||
}
|
||||
|
||||
TEST(ValueOrDefaultCalculatorTest, FullTest) {
|
||||
// Make sure that nothing gets wrong with an input that have both right and
|
||||
// wrong timestamps, some defaults etc.
|
||||
ValueOrDefaultRunner runner;
|
||||
const std::vector<int64_t> ticks = {1, 2, 5, 8, 12, 33, 231};
|
||||
const std::vector<int64_t> in_ticks = {0, 2, 4, 6, 8, 9, 12, 33, 54, 232};
|
||||
const std::vector<int> in_values = GetIntegersRange(in_ticks.size());
|
||||
|
||||
for (int tick : ticks) {
|
||||
runner.TickAt(tick);
|
||||
}
|
||||
for (int i = 0; i < in_ticks.size(); ++i) {
|
||||
runner.ProvideInput(in_ticks[i], in_values[i]);
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
EXPECT_THAT(runner.GetOutputTimestamps(), ContainerEq(ticks));
|
||||
// Calculated by hand:
|
||||
EXPECT_THAT(
|
||||
runner.GetOutputValues(),
|
||||
ElementsAre(kDefaultValue, 2, kDefaultValue, 5, 7, 8, kDefaultValue));
|
||||
EXPECT_THAT(runner.GetFlagTimestamps(), ElementsAre(1, 5, 231));
|
||||
EXPECT_THAT(runner.GetFlagValues(), AllOf(Each(true), SizeIs(3)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user