Adding BypassCalculator for use with SwitchContainer.

PiperOrigin-RevId: 482030395
This commit is contained in:
Hadon Nash 2022-10-18 14:57:41 -07:00 committed by Copybara-Service
parent e86cd39521
commit 7785603fbe
4 changed files with 536 additions and 0 deletions

View File

@ -1410,3 +1410,45 @@ cc_library(
],
alwayslink = 1,
)
mediapipe_proto_library(
name = "bypass_calculator_proto",
srcs = ["bypass_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "bypass_calculator",
srcs = ["bypass_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":bypass_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "bypass_calculator_test",
srcs = ["bypass_calculator_test.cc"],
deps = [
":bypass_calculator",
":pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:switch_container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -0,0 +1,161 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/bypass_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
using mediapipe::BypassCalculatorOptions;
// Defines a "bypass" channel to use in place of a disabled feature subgraph.
// By default, all inputs are discarded and all outputs are ignored.
// Certain input streams can be passed to corresponding output streams
// by specifying them in "pass_input_stream" and "pass_output_stream" options.
// All output streams are updated with timestamp bounds indicating completed
// output.
//
// Note that this calculator is designed for use as a contained_node in a
// SwitchContainer. For this reason, any input and output tags are accepted,
// and stream semantics are specified through BypassCalculatorOptions.
//
// Example config:
// node {
// calculator: "BypassCalculator"
// input_stream: "APPEARANCES:appearances_post_facenet"
// input_stream: "VIDEO:video_frame"
// input_stream: "FEATURE_CONFIG:feature_config"
// input_stream: "ENABLE:gaze_enabled"
// output_stream: "APPEARANCES:analyzed_appearances"
// output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
// node_options: {
// [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
// pass_input_stream: "APPEARANCES"
// pass_output_stream: "APPEARANCES"
// }
// }
// }
//
class BypassCalculator : public Node {
public:
static constexpr mediapipe::api2::Input<int>::Optional kNotNeeded{"N_N_"};
MEDIAPIPE_NODE_CONTRACT(kNotNeeded);
using IdMap = std::map<CollectionItemId, CollectionItemId>;
// Returns the map of passthrough input and output stream ids.
static absl::StatusOr<IdMap> GetPassMap(
const BypassCalculatorOptions& options, const tool::TagMap& input_map,
const tool::TagMap& output_map) {
IdMap result;
auto& input_streams = options.pass_input_stream();
auto& output_streams = options.pass_output_stream();
int size = std::min(input_streams.size(), output_streams.size());
for (int i = 0; i < size; ++i) {
std::pair<std::string, int> in_tag, out_tag;
MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_input_stream(i),
&in_tag.first, &in_tag.second));
MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_output_stream(i),
&out_tag.first, &out_tag.second));
auto input_id = input_map.GetId(in_tag.first, in_tag.second);
auto output_id = output_map.GetId(out_tag.first, out_tag.second);
result[input_id] = output_id;
}
return result;
}
// Identifies all specified streams as "Any" packet type.
// Identifies passthrough streams as "Same" packet type.
static absl::Status UpdateContract(CalculatorContract* cc) {
auto options = cc->Options<BypassCalculatorOptions>();
RET_CHECK_EQ(options.pass_input_stream().size(),
options.pass_output_stream().size());
ASSIGN_OR_RETURN(
auto pass_streams,
GetPassMap(options, *cc->Inputs().TagMap(), *cc->Outputs().TagMap()));
std::set<CollectionItemId> pass_out;
for (auto entry : pass_streams) {
pass_out.insert(entry.second);
cc->Inputs().Get(entry.first).SetAny();
cc->Outputs().Get(entry.second).SetSameAs(&cc->Inputs().Get(entry.first));
}
for (auto id = cc->Inputs().BeginId(); id != cc->Inputs().EndId(); ++id) {
if (pass_streams.count(id) == 0) {
cc->Inputs().Get(id).SetAny();
}
}
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
if (pass_out.count(id) == 0) {
cc->Outputs().Get(id).SetAny();
}
}
return absl::OkStatus();
}
// Saves the map of passthrough input and output stream ids.
absl::Status Open(CalculatorContext* cc) override {
auto options = cc->Options<BypassCalculatorOptions>();
ASSIGN_OR_RETURN(pass_streams_, GetPassMap(options, *cc->Inputs().TagMap(),
*cc->Outputs().TagMap()));
return absl::OkStatus();
}
// Copies packets between passthrough input and output streams.
// Updates timestamp bounds on all output streams.
absl::Status Process(CalculatorContext* cc) override {
std::set<CollectionItemId> pass_out;
for (auto entry : pass_streams_) {
pass_out.insert(entry.second);
auto& packet = cc->Inputs().Get(entry.first).Value();
if (packet.Timestamp() == cc->InputTimestamp()) {
cc->Outputs().Get(entry.first).AddPacket(packet);
}
}
Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
if (pass_out.count(id) == 0) {
cc->Outputs().Get(id).SetNextTimestampBound(
std::max(cc->Outputs().Get(id).NextTimestampBound(), bound));
}
}
return absl::OkStatus();
}
// Close all output streams.
absl::Status Close(CalculatorContext* cc) override {
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
cc->Outputs().Get(id).Close();
}
return absl::OkStatus();
}
private:
IdMap pass_streams_;
};
MEDIAPIPE_REGISTER_NODE(BypassCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,31 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message BypassCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional BypassCalculatorOptions ext = 481259677;
}
// Names an input stream or streams to pass through, by "TAG:index".
repeated string pass_input_stream = 1;
// Names an output stream or streams to pass through, by "TAG:index".
repeated string pass_output_stream = 2;
}

View File

@ -0,0 +1,302 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// A graph with using a BypassCalculator to pass through and ignore
// most of its inputs and outputs.
constexpr char kTestGraphConfig1[] = R"pb(
type: "AppearancesPassThroughSubgraph"
input_stream: "APPEARANCES:appearances"
input_stream: "VIDEO:video_frame"
input_stream: "FEATURE_CONFIG:feature_config"
output_stream: "APPEARANCES:passthrough_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:passthrough_federated_gaze_output"
node {
calculator: "BypassCalculator"
input_stream: "PASS:appearances"
input_stream: "TRUNCATE:0:video_frame"
input_stream: "TRUNCATE:1:feature_config"
output_stream: "PASS:passthrough_appearances"
output_stream: "TRUNCATE:passthrough_federated_gaze_output"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "PASS"
pass_output_stream: "PASS"
}
}
}
)pb";
// A graph with using AppearancesPassThroughSubgraph as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig2[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "ENABLE:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
}
}
}
)pb";
// A graph with using BypassCalculator as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig3[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "ENABLE:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: {
calculator: "BypassCalculator"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "APPEARANCES"
pass_output_stream: "APPEARANCES"
}
}
}
}
}
}
)pb";
// A graph with using BypassCalculator as a disabled-gate
// for input frames and appearances.
constexpr char kTestGraphConfig4[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "ENABLE:gaze_enabled"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
output_stream: "VIDEO:video_frame_out"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEATURE_CONFIG:feature_config_out"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: { calculator: "BypassCalculator" }
contained_node: { calculator: "PassThroughCalculator" }
}
}
}
)pb";
// Reports packet timestamp and string contents, or "<empty>"".
std::string DebugString(Packet p) {
return absl::StrCat(p.Timestamp().DebugString(), ":",
p.IsEmpty() ? "<empty>" : p.Get<std::string>());
}
// Shows a bypass subgraph that passes through one stream.
TEST(BypassCalculatorTest, SubgraphChannel) {
CalculatorGraphConfig config_1 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig1);
CalculatorGraphConfig config_2 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig2);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_1, config_2}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> federated_gaze_output;
MP_ASSERT_OK(graph.ObserveOutputStream(
"federated_gaze_output",
[&](const Packet& p) {
federated_gaze_output.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows a BypassCalculator that passes through one stream.
TEST(BypassCalculatorTest, CalculatorChannel) {
CalculatorGraphConfig config_3 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig3);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_3}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> federated_gaze_output;
MP_ASSERT_OK(graph.ObserveOutputStream(
"federated_gaze_output",
[&](const Packet& p) {
federated_gaze_output.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows a BypassCalculator that discards all inputs when ENABLED is false.
TEST(BypassCalculatorTest, GatedChannel) {
CalculatorGraphConfig config_3 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig4);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_3}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> video_frame;
MP_ASSERT_OK(graph.ObserveOutputStream(
"video_frame_out",
[&](const Packet& p) {
video_frame.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
// Close the gate.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"gaze_enabled", MakePacket<bool>(false).At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send packets at timestamp 200.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Only timestamps arrive from the BypassCalculator.
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:<empty>"));
EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>"));
// Open the gate.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"gaze_enabled", MakePacket<bool>(true).At(Timestamp(300))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send packets at timestamp 300.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a2").At(Timestamp(300))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v2").At(Timestamp(300))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f2").At(Timestamp(300))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Packets arrive from the PassThroughCalculator.
EXPECT_THAT(analyzed_appearances,
testing::ElementsAre("200:<empty>", "300:a2"));
EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>", "300:v2"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe