Internal change

PiperOrigin-RevId: 530562491
This commit is contained in:
MediaPipe Team 2023-05-09 03:46:38 -07:00 committed by Copybara-Service
parent 10776ef86f
commit 9de52a4a30
3 changed files with 49 additions and 15 deletions

View File

@ -26,19 +26,15 @@ constexpr char kStateChangeTag[] = "STATE_CHANGE";
constexpr char kDisallowTag[] = "DISALLOW";
constexpr char kAllowTag[] = "ALLOW";
enum GateState {
GATE_UNINITIALIZED,
GATE_ALLOW,
GATE_DISALLOW,
};
std::string ToString(GateState state) {
std::string ToString(GateCalculatorOptions::GateState state) {
switch (state) {
case GATE_UNINITIALIZED:
case GateCalculatorOptions::UNSPECIFIED:
return "UNSPECIFIED";
case GateCalculatorOptions::GATE_UNINITIALIZED:
return "UNINITIALIZED";
case GATE_ALLOW:
case GateCalculatorOptions::GATE_ALLOW:
return "ALLOW";
case GATE_DISALLOW:
case GateCalculatorOptions::GATE_DISALLOW:
return "DISALLOW";
}
DLOG(FATAL) << "Unknown GateState";
@ -153,10 +149,12 @@ class GateCalculator : public CalculatorBase {
cc->SetOffset(TimestampDiff(0));
num_data_streams_ = cc->Inputs().NumEntries("");
last_gate_state_ = GATE_UNINITIALIZED;
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
last_gate_state_ = options.initial_gate_state();
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
empty_packets_as_allow_ = options.empty_packets_as_allow();
if (!use_side_packet_for_allow_disallow_ &&
@ -184,10 +182,12 @@ class GateCalculator : public CalculatorBase {
allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
}
}
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
const GateCalculatorOptions::GateState new_gate_state =
allow ? GateCalculatorOptions::GATE_ALLOW
: GateCalculatorOptions::GATE_DISALLOW;
if (cc->Outputs().HasTag(kStateChangeTag)) {
if (last_gate_state_ != GATE_UNINITIALIZED &&
if (last_gate_state_ != GateCalculatorOptions::GATE_UNINITIALIZED &&
last_gate_state_ != new_gate_state) {
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
<< cc->InputTimestamp().Value() << " from "
@ -223,7 +223,8 @@ class GateCalculator : public CalculatorBase {
}
private:
GateState last_gate_state_ = GATE_UNINITIALIZED;
GateCalculatorOptions::GateState last_gate_state_ =
GateCalculatorOptions::GATE_UNINITIALIZED;
int num_data_streams_;
bool empty_packets_as_allow_;
bool use_side_packet_for_allow_disallow_ = false;

View File

@ -31,4 +31,13 @@ message GateCalculatorOptions {
// Whether to allow or disallow the input streams to pass when no
// ALLOW/DISALLOW input or side input is specified.
optional bool allow = 2 [default = false];
enum GateState {
UNSPECIFIED = 0;
GATE_UNINITIALIZED = 1;
GATE_ALLOW = 2;
GATE_DISALLOW = 3;
}
optional GateState initial_gate_state = 3 [default = GATE_UNINITIALIZED];
}

View File

@ -458,5 +458,29 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
ASSERT_EQ(0, output.size());
}
// Must detect allow value for first timestamp as a state change when the
// initial state is set to GATE_DISALLOW.
TEST_F(GateCalculatorTest, StateChangeTriggeredWithInitialGateStateOption) {
SetRunner(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:allow"
output_stream: "test_output"
output_stream: "STATE_CHANGE:state_change"
options: {
[mediapipe.GateCalculatorOptions.ext] {
initial_gate_state: GATE_DISALLOW
}
}
)");
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", true);
const std::vector<Packet>& output =
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
ASSERT_EQ(1, output.size());
}
} // namespace
} // namespace mediapipe