Internal change

PiperOrigin-RevId: 530562491
This commit is contained in:
MediaPipe Team 2023-05-09 03:46:38 -07:00 committed by Copybara-Service
parent 10776ef86f
commit 9de52a4a30
3 changed files with 49 additions and 15 deletions

View File

@ -26,19 +26,15 @@ constexpr char kStateChangeTag[] = "STATE_CHANGE";
constexpr char kDisallowTag[] = "DISALLOW"; constexpr char kDisallowTag[] = "DISALLOW";
constexpr char kAllowTag[] = "ALLOW"; constexpr char kAllowTag[] = "ALLOW";
enum GateState { std::string ToString(GateCalculatorOptions::GateState state) {
GATE_UNINITIALIZED,
GATE_ALLOW,
GATE_DISALLOW,
};
std::string ToString(GateState state) {
switch (state) { switch (state) {
case GATE_UNINITIALIZED: case GateCalculatorOptions::UNSPECIFIED:
return "UNSPECIFIED";
case GateCalculatorOptions::GATE_UNINITIALIZED:
return "UNINITIALIZED"; return "UNINITIALIZED";
case GATE_ALLOW: case GateCalculatorOptions::GATE_ALLOW:
return "ALLOW"; return "ALLOW";
case GATE_DISALLOW: case GateCalculatorOptions::GATE_DISALLOW:
return "DISALLOW"; return "DISALLOW";
} }
DLOG(FATAL) << "Unknown GateState"; DLOG(FATAL) << "Unknown GateState";
@ -153,10 +149,12 @@ class GateCalculator : public CalculatorBase {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
num_data_streams_ = cc->Inputs().NumEntries(""); num_data_streams_ = cc->Inputs().NumEntries("");
last_gate_state_ = GATE_UNINITIALIZED;
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>(); const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
last_gate_state_ = options.initial_gate_state();
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
empty_packets_as_allow_ = options.empty_packets_as_allow(); empty_packets_as_allow_ = options.empty_packets_as_allow();
if (!use_side_packet_for_allow_disallow_ && if (!use_side_packet_for_allow_disallow_ &&
@ -184,10 +182,12 @@ class GateCalculator : public CalculatorBase {
allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>(); allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
} }
} }
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW; const GateCalculatorOptions::GateState new_gate_state =
allow ? GateCalculatorOptions::GATE_ALLOW
: GateCalculatorOptions::GATE_DISALLOW;
if (cc->Outputs().HasTag(kStateChangeTag)) { if (cc->Outputs().HasTag(kStateChangeTag)) {
if (last_gate_state_ != GATE_UNINITIALIZED && if (last_gate_state_ != GateCalculatorOptions::GATE_UNINITIALIZED &&
last_gate_state_ != new_gate_state) { last_gate_state_ != new_gate_state) {
VLOG(2) << "State transition in " << cc->NodeName() << " @ " VLOG(2) << "State transition in " << cc->NodeName() << " @ "
<< cc->InputTimestamp().Value() << " from " << cc->InputTimestamp().Value() << " from "
@ -223,7 +223,8 @@ class GateCalculator : public CalculatorBase {
} }
private: private:
GateState last_gate_state_ = GATE_UNINITIALIZED; GateCalculatorOptions::GateState last_gate_state_ =
GateCalculatorOptions::GATE_UNINITIALIZED;
int num_data_streams_; int num_data_streams_;
bool empty_packets_as_allow_; bool empty_packets_as_allow_;
bool use_side_packet_for_allow_disallow_ = false; bool use_side_packet_for_allow_disallow_ = false;

View File

@ -31,4 +31,13 @@ message GateCalculatorOptions {
// Whether to allow or disallow the input streams to pass when no // Whether to allow or disallow the input streams to pass when no
// ALLOW/DISALLOW input or side input is specified. // ALLOW/DISALLOW input or side input is specified.
optional bool allow = 2 [default = false]; optional bool allow = 2 [default = false];
enum GateState {
UNSPECIFIED = 0;
GATE_UNINITIALIZED = 1;
GATE_ALLOW = 2;
GATE_DISALLOW = 3;
}
optional GateState initial_gate_state = 3 [default = GATE_UNINITIALIZED];
} }

View File

@ -458,5 +458,29 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
ASSERT_EQ(0, output.size()); ASSERT_EQ(0, output.size());
} }
// Must detect allow value for first timestamp as a state change when the
// initial state is set to GATE_DISALLOW.
TEST_F(GateCalculatorTest, StateChangeTriggeredWithInitialGateStateOption) {
SetRunner(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:allow"
output_stream: "test_output"
output_stream: "STATE_CHANGE:state_change"
options: {
[mediapipe.GateCalculatorOptions.ext] {
initial_gate_state: GATE_DISALLOW
}
}
)");
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", true);
const std::vector<Packet>& output =
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
ASSERT_EQ(1, output.size());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe