diff --git a/mediapipe/calculators/core/gate_calculator.cc b/mediapipe/calculators/core/gate_calculator.cc index 448329b88..e5e87b69b 100644 --- a/mediapipe/calculators/core/gate_calculator.cc +++ b/mediapipe/calculators/core/gate_calculator.cc @@ -26,19 +26,15 @@ constexpr char kStateChangeTag[] = "STATE_CHANGE"; constexpr char kDisallowTag[] = "DISALLOW"; constexpr char kAllowTag[] = "ALLOW"; -enum GateState { - GATE_UNINITIALIZED, - GATE_ALLOW, - GATE_DISALLOW, -}; - -std::string ToString(GateState state) { +std::string ToString(GateCalculatorOptions::GateState state) { switch (state) { - case GATE_UNINITIALIZED: + case GateCalculatorOptions::UNSPECIFIED: + return "UNSPECIFIED"; + case GateCalculatorOptions::GATE_UNINITIALIZED: return "UNINITIALIZED"; - case GATE_ALLOW: + case GateCalculatorOptions::GATE_ALLOW: return "ALLOW"; - case GATE_DISALLOW: + case GateCalculatorOptions::GATE_DISALLOW: return "DISALLOW"; } DLOG(FATAL) << "Unknown GateState"; @@ -153,10 +149,12 @@ class GateCalculator : public CalculatorBase { cc->SetOffset(TimestampDiff(0)); num_data_streams_ = cc->Inputs().NumEntries(""); - last_gate_state_ = GATE_UNINITIALIZED; - RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs())); const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>(); + last_gate_state_ = options.initial_gate_state(); + + RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs())); + empty_packets_as_allow_ = options.empty_packets_as_allow(); if (!use_side_packet_for_allow_disallow_ && @@ -184,10 +182,12 @@ class GateCalculator : public CalculatorBase { allow = !cc->Inputs().Tag(kDisallowTag).Get(); } } - const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW; + const GateCalculatorOptions::GateState new_gate_state = + allow ? GateCalculatorOptions::GATE_ALLOW + : GateCalculatorOptions::GATE_DISALLOW; if (cc->Outputs().HasTag(kStateChangeTag)) { - if (last_gate_state_ != GATE_UNINITIALIZED && + if (last_gate_state_ != GateCalculatorOptions::GATE_UNINITIALIZED && last_gate_state_ != new_gate_state) { VLOG(2) << "State transition in " << cc->NodeName() << " @ " << cc->InputTimestamp().Value() << " from " @@ -223,7 +223,8 @@ class GateCalculator : public CalculatorBase { } private: - GateState last_gate_state_ = GATE_UNINITIALIZED; + GateCalculatorOptions::GateState last_gate_state_ = + GateCalculatorOptions::GATE_UNINITIALIZED; int num_data_streams_; bool empty_packets_as_allow_; bool use_side_packet_for_allow_disallow_ = false; diff --git a/mediapipe/calculators/core/gate_calculator.proto b/mediapipe/calculators/core/gate_calculator.proto index b7d597a63..4153d5f32 100644 --- a/mediapipe/calculators/core/gate_calculator.proto +++ b/mediapipe/calculators/core/gate_calculator.proto @@ -31,4 +31,13 @@ message GateCalculatorOptions { // Whether to allow or disallow the input streams to pass when no // ALLOW/DISALLOW input or side input is specified. optional bool allow = 2 [default = false]; + + enum GateState { + UNSPECIFIED = 0; + GATE_UNINITIALIZED = 1; + GATE_ALLOW = 2; + GATE_DISALLOW = 3; + } + + optional GateState initial_gate_state = 3 [default = GATE_UNINITIALIZED]; } diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc index 192019820..8875bd7e3 100644 --- a/mediapipe/calculators/core/gate_calculator_test.cc +++ b/mediapipe/calculators/core/gate_calculator_test.cc @@ -458,5 +458,29 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) { ASSERT_EQ(0, output.size()); } +// Must detect allow value for first timestamp as a state change when the +// initial state is set to GATE_DISALLOW. +TEST_F(GateCalculatorTest, StateChangeTriggeredWithInitialGateStateOption) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "ALLOW:allow" + output_stream: "test_output" + output_stream: "STATE_CHANGE:state_change" + options: { + [mediapipe.GateCalculatorOptions.ext] { + initial_gate_state: GATE_DISALLOW + } + } + )"); + + constexpr int64_t kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, "ALLOW", true); + + const std::vector& output = + runner()->Outputs().Get("STATE_CHANGE", 0).packets; + ASSERT_EQ(1, output.size()); +} + } // namespace } // namespace mediapipe