Internal change
PiperOrigin-RevId: 530562491
This commit is contained in:
parent
10776ef86f
commit
9de52a4a30
|
@ -26,19 +26,15 @@ constexpr char kStateChangeTag[] = "STATE_CHANGE";
|
||||||
constexpr char kDisallowTag[] = "DISALLOW";
|
constexpr char kDisallowTag[] = "DISALLOW";
|
||||||
constexpr char kAllowTag[] = "ALLOW";
|
constexpr char kAllowTag[] = "ALLOW";
|
||||||
|
|
||||||
enum GateState {
|
std::string ToString(GateCalculatorOptions::GateState state) {
|
||||||
GATE_UNINITIALIZED,
|
|
||||||
GATE_ALLOW,
|
|
||||||
GATE_DISALLOW,
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string ToString(GateState state) {
|
|
||||||
switch (state) {
|
switch (state) {
|
||||||
case GATE_UNINITIALIZED:
|
case GateCalculatorOptions::UNSPECIFIED:
|
||||||
|
return "UNSPECIFIED";
|
||||||
|
case GateCalculatorOptions::GATE_UNINITIALIZED:
|
||||||
return "UNINITIALIZED";
|
return "UNINITIALIZED";
|
||||||
case GATE_ALLOW:
|
case GateCalculatorOptions::GATE_ALLOW:
|
||||||
return "ALLOW";
|
return "ALLOW";
|
||||||
case GATE_DISALLOW:
|
case GateCalculatorOptions::GATE_DISALLOW:
|
||||||
return "DISALLOW";
|
return "DISALLOW";
|
||||||
}
|
}
|
||||||
DLOG(FATAL) << "Unknown GateState";
|
DLOG(FATAL) << "Unknown GateState";
|
||||||
|
@ -153,10 +149,12 @@ class GateCalculator : public CalculatorBase {
|
||||||
|
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
num_data_streams_ = cc->Inputs().NumEntries("");
|
num_data_streams_ = cc->Inputs().NumEntries("");
|
||||||
last_gate_state_ = GATE_UNINITIALIZED;
|
|
||||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
|
|
||||||
|
|
||||||
const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
|
const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
|
||||||
|
last_gate_state_ = options.initial_gate_state();
|
||||||
|
|
||||||
|
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
|
||||||
|
|
||||||
empty_packets_as_allow_ = options.empty_packets_as_allow();
|
empty_packets_as_allow_ = options.empty_packets_as_allow();
|
||||||
|
|
||||||
if (!use_side_packet_for_allow_disallow_ &&
|
if (!use_side_packet_for_allow_disallow_ &&
|
||||||
|
@ -184,10 +182,12 @@ class GateCalculator : public CalculatorBase {
|
||||||
allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
|
allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
|
const GateCalculatorOptions::GateState new_gate_state =
|
||||||
|
allow ? GateCalculatorOptions::GATE_ALLOW
|
||||||
|
: GateCalculatorOptions::GATE_DISALLOW;
|
||||||
|
|
||||||
if (cc->Outputs().HasTag(kStateChangeTag)) {
|
if (cc->Outputs().HasTag(kStateChangeTag)) {
|
||||||
if (last_gate_state_ != GATE_UNINITIALIZED &&
|
if (last_gate_state_ != GateCalculatorOptions::GATE_UNINITIALIZED &&
|
||||||
last_gate_state_ != new_gate_state) {
|
last_gate_state_ != new_gate_state) {
|
||||||
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
|
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
|
||||||
<< cc->InputTimestamp().Value() << " from "
|
<< cc->InputTimestamp().Value() << " from "
|
||||||
|
@ -223,7 +223,8 @@ class GateCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GateState last_gate_state_ = GATE_UNINITIALIZED;
|
GateCalculatorOptions::GateState last_gate_state_ =
|
||||||
|
GateCalculatorOptions::GATE_UNINITIALIZED;
|
||||||
int num_data_streams_;
|
int num_data_streams_;
|
||||||
bool empty_packets_as_allow_;
|
bool empty_packets_as_allow_;
|
||||||
bool use_side_packet_for_allow_disallow_ = false;
|
bool use_side_packet_for_allow_disallow_ = false;
|
||||||
|
|
|
@ -31,4 +31,13 @@ message GateCalculatorOptions {
|
||||||
// Whether to allow or disallow the input streams to pass when no
|
// Whether to allow or disallow the input streams to pass when no
|
||||||
// ALLOW/DISALLOW input or side input is specified.
|
// ALLOW/DISALLOW input or side input is specified.
|
||||||
optional bool allow = 2 [default = false];
|
optional bool allow = 2 [default = false];
|
||||||
|
|
||||||
|
enum GateState {
|
||||||
|
UNSPECIFIED = 0;
|
||||||
|
GATE_UNINITIALIZED = 1;
|
||||||
|
GATE_ALLOW = 2;
|
||||||
|
GATE_DISALLOW = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
optional GateState initial_gate_state = 3 [default = GATE_UNINITIALIZED];
|
||||||
}
|
}
|
||||||
|
|
|
@ -458,5 +458,29 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
|
||||||
ASSERT_EQ(0, output.size());
|
ASSERT_EQ(0, output.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Must detect allow value for first timestamp as a state change when the
|
||||||
|
// initial state is set to GATE_DISALLOW.
|
||||||
|
TEST_F(GateCalculatorTest, StateChangeTriggeredWithInitialGateStateOption) {
|
||||||
|
SetRunner(R"(
|
||||||
|
calculator: "GateCalculator"
|
||||||
|
input_stream: "test_input"
|
||||||
|
input_stream: "ALLOW:allow"
|
||||||
|
output_stream: "test_output"
|
||||||
|
output_stream: "STATE_CHANGE:state_change"
|
||||||
|
options: {
|
||||||
|
[mediapipe.GateCalculatorOptions.ext] {
|
||||||
|
initial_gate_state: GATE_DISALLOW
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)");
|
||||||
|
|
||||||
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
|
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||||
|
|
||||||
|
const std::vector<Packet>& output =
|
||||||
|
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||||
|
ASSERT_EQ(1, output.size());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
Loading…
Reference in New Issue
Block a user