Internal change
PiperOrigin-RevId: 530562491
This commit is contained in:
parent
10776ef86f
commit
9de52a4a30
|
@ -26,19 +26,15 @@ constexpr char kStateChangeTag[] = "STATE_CHANGE";
|
|||
constexpr char kDisallowTag[] = "DISALLOW";
|
||||
constexpr char kAllowTag[] = "ALLOW";
|
||||
|
||||
enum GateState {
|
||||
GATE_UNINITIALIZED,
|
||||
GATE_ALLOW,
|
||||
GATE_DISALLOW,
|
||||
};
|
||||
|
||||
std::string ToString(GateState state) {
|
||||
std::string ToString(GateCalculatorOptions::GateState state) {
|
||||
switch (state) {
|
||||
case GATE_UNINITIALIZED:
|
||||
case GateCalculatorOptions::UNSPECIFIED:
|
||||
return "UNSPECIFIED";
|
||||
case GateCalculatorOptions::GATE_UNINITIALIZED:
|
||||
return "UNINITIALIZED";
|
||||
case GATE_ALLOW:
|
||||
case GateCalculatorOptions::GATE_ALLOW:
|
||||
return "ALLOW";
|
||||
case GATE_DISALLOW:
|
||||
case GateCalculatorOptions::GATE_DISALLOW:
|
||||
return "DISALLOW";
|
||||
}
|
||||
DLOG(FATAL) << "Unknown GateState";
|
||||
|
@ -153,10 +149,12 @@ class GateCalculator : public CalculatorBase {
|
|||
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
num_data_streams_ = cc->Inputs().NumEntries("");
|
||||
last_gate_state_ = GATE_UNINITIALIZED;
|
||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
|
||||
|
||||
const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
|
||||
last_gate_state_ = options.initial_gate_state();
|
||||
|
||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
|
||||
|
||||
empty_packets_as_allow_ = options.empty_packets_as_allow();
|
||||
|
||||
if (!use_side_packet_for_allow_disallow_ &&
|
||||
|
@ -184,10 +182,12 @@ class GateCalculator : public CalculatorBase {
|
|||
allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
|
||||
}
|
||||
}
|
||||
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
|
||||
const GateCalculatorOptions::GateState new_gate_state =
|
||||
allow ? GateCalculatorOptions::GATE_ALLOW
|
||||
: GateCalculatorOptions::GATE_DISALLOW;
|
||||
|
||||
if (cc->Outputs().HasTag(kStateChangeTag)) {
|
||||
if (last_gate_state_ != GATE_UNINITIALIZED &&
|
||||
if (last_gate_state_ != GateCalculatorOptions::GATE_UNINITIALIZED &&
|
||||
last_gate_state_ != new_gate_state) {
|
||||
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
|
||||
<< cc->InputTimestamp().Value() << " from "
|
||||
|
@ -223,7 +223,8 @@ class GateCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
private:
|
||||
GateState last_gate_state_ = GATE_UNINITIALIZED;
|
||||
GateCalculatorOptions::GateState last_gate_state_ =
|
||||
GateCalculatorOptions::GATE_UNINITIALIZED;
|
||||
int num_data_streams_;
|
||||
bool empty_packets_as_allow_;
|
||||
bool use_side_packet_for_allow_disallow_ = false;
|
||||
|
|
|
@ -31,4 +31,13 @@ message GateCalculatorOptions {
|
|||
// Whether to allow or disallow the input streams to pass when no
|
||||
// ALLOW/DISALLOW input or side input is specified.
|
||||
optional bool allow = 2 [default = false];
|
||||
|
||||
enum GateState {
|
||||
UNSPECIFIED = 0;
|
||||
GATE_UNINITIALIZED = 1;
|
||||
GATE_ALLOW = 2;
|
||||
GATE_DISALLOW = 3;
|
||||
}
|
||||
|
||||
optional GateState initial_gate_state = 3 [default = GATE_UNINITIALIZED];
|
||||
}
|
||||
|
|
|
@ -458,5 +458,29 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
|
|||
ASSERT_EQ(0, output.size());
|
||||
}
|
||||
|
||||
// Must detect allow value for first timestamp as a state change when the
|
||||
// initial state is set to GATE_DISALLOW.
|
||||
TEST_F(GateCalculatorTest, StateChangeTriggeredWithInitialGateStateOption) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:allow"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_change"
|
||||
options: {
|
||||
[mediapipe.GateCalculatorOptions.ext] {
|
||||
initial_gate_state: GATE_DISALLOW
|
||||
}
|
||||
}
|
||||
)");
|
||||
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user