462 lines
17 KiB
C++
462 lines
17 KiB
C++
// Copyright 2019 The MediaPipe Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "mediapipe/framework/calculator.pb.h"
|
|
#include "mediapipe/framework/calculator_framework.h"
|
|
#include "mediapipe/framework/deps/message_matchers.h"
|
|
#include "mediapipe/framework/port/gmock.h"
|
|
#include "mediapipe/framework/port/gtest.h"
|
|
#include "mediapipe/framework/port/logging.h"
|
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
|
#include "mediapipe/framework/port/proto_ns.h"
|
|
#include "mediapipe/framework/port/ret_check.h"
|
|
#include "mediapipe/framework/port/status.h"
|
|
#include "mediapipe/framework/port/status_matchers.h"
|
|
#include "mediapipe/framework/subgraph.h"
|
|
#include "mediapipe/framework/tool/node_chain_subgraph.pb.h"
|
|
#include "mediapipe/framework/tool/subgraph_expansion.h"
|
|
|
|
namespace mediapipe {
|
|
namespace {
|
|
|
|
// A Calculator that outputs thrice the value of its input packet (an int).
|
|
// It also accepts a side packet tagged "TIMEZONE", but doesn't use it.
|
|
class TripleIntCalculator : public CalculatorBase {
|
|
public:
|
|
static absl::Status GetContract(CalculatorContract* cc) {
|
|
cc->Inputs().Index(0).Set<int>().Optional();
|
|
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)).Optional();
|
|
cc->InputSidePackets().Index(0).Set<int>().Optional();
|
|
cc->OutputSidePackets()
|
|
.Index(0)
|
|
.SetSameAs(&cc->InputSidePackets().Index(0))
|
|
.Optional();
|
|
cc->InputSidePackets().Tag("TIMEZONE").Set<int>().Optional();
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Open(CalculatorContext* cc) final {
|
|
cc->SetOffset(TimestampDiff(0));
|
|
if (cc->OutputSidePackets().HasTag("")) {
|
|
cc->OutputSidePackets().Index(0).Set(
|
|
MakePacket<int>(cc->InputSidePackets().Index(0).Get<int>() * 3));
|
|
}
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Process(CalculatorContext* cc) final {
|
|
int value = cc->Inputs().Index(0).Value().Get<int>();
|
|
cc->Outputs().Index(0).Add(new int(3 * value), cc->InputTimestamp());
|
|
return absl::OkStatus();
|
|
}
|
|
};
|
|
REGISTER_CALCULATOR(TripleIntCalculator);
|
|
|
|
// A testing example of a SwitchContainer containing two subnodes.
|
|
// Note that the input and output tags supplied to the container node,
|
|
// must match the input and output tags required by the subnodes.
|
|
CalculatorGraphConfig SubnodeContainerExample() {
|
|
return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
|
input_stream: "foo"
|
|
input_stream: "enable"
|
|
input_side_packet: "timezone"
|
|
node {
|
|
calculator: "SwitchContainer"
|
|
input_stream: "ENABLE:enable"
|
|
input_stream: "foo"
|
|
output_stream: "bar"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {
|
|
contained_node: { calculator: "TripleIntCalculator" }
|
|
contained_node: { calculator: "PassThroughCalculator" }
|
|
}
|
|
}
|
|
}
|
|
node {
|
|
calculator: "PassThroughCalculator"
|
|
input_stream: "foo"
|
|
input_stream: "bar"
|
|
output_stream: "output_foo"
|
|
output_stream: "output_bar"
|
|
}
|
|
)pb");
|
|
}
|
|
|
|
// A testing example of a SwitchContainer containing two subnodes.
|
|
// Note that the side-input and side-output tags supplied to the container node,
|
|
// must match the side-input and side-output tags required by the subnodes.
|
|
CalculatorGraphConfig SideSubnodeContainerExample() {
|
|
return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
|
input_side_packet: "foo"
|
|
input_side_packet: "enable"
|
|
output_side_packet: "output_bar"
|
|
node {
|
|
calculator: "SwitchContainer"
|
|
input_side_packet: "ENABLE:enable"
|
|
input_side_packet: "foo"
|
|
output_side_packet: "bar"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {
|
|
contained_node: { calculator: "TripleIntCalculator" }
|
|
contained_node: { calculator: "PassThroughCalculator" }
|
|
}
|
|
}
|
|
}
|
|
node {
|
|
calculator: "PassThroughCalculator"
|
|
input_side_packet: "foo"
|
|
input_side_packet: "bar"
|
|
output_side_packet: "output_foo"
|
|
output_side_packet: "output_bar"
|
|
}
|
|
)pb");
|
|
}
|
|
|
|
// Runs the test container graph with a few input packets.
|
|
void RunTestContainer(CalculatorGraphConfig supergraph) {
|
|
CalculatorGraph graph;
|
|
std::vector<Packet> out_foo, out_bar;
|
|
tool::AddVectorSink("output_foo", &supergraph, &out_foo);
|
|
tool::AddVectorSink("output_bar", &supergraph, &out_bar);
|
|
MP_ASSERT_OK(graph.Initialize(supergraph, {}));
|
|
MP_ASSERT_OK(graph.StartRun({{"timezone", MakePacket<int>(3)}}));
|
|
|
|
// Send enable == true signal at 5000 us.
|
|
const int64 enable_ts = 5000;
|
|
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
|
"enable", MakePacket<bool>(true).At(Timestamp(enable_ts))));
|
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
|
|
|
const int packet_count = 10;
|
|
// Send int value packets at {10K, 20K, 30K, ..., 100K}.
|
|
for (uint64 t = 1; t <= packet_count; ++t) {
|
|
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
|
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
|
// The inputs are sent to the input stream "foo", they should pass through.
|
|
EXPECT_EQ(out_foo.size(), t);
|
|
// Since "enable == true" for ts 10K...100K us, the second contained graph
|
|
// i.e. the one containing the PassThroughCalculator should output the
|
|
// input values without changing them.
|
|
EXPECT_EQ(out_bar.size(), t);
|
|
EXPECT_EQ(out_bar.back().Get<int>(), t);
|
|
}
|
|
|
|
// Send enable == false signal at 105K us.
|
|
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
|
"enable", MakePacket<bool>(false).At(Timestamp(105000))));
|
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
|
|
|
// Send int value packets at {110K, 120K, ..., 200K}.
|
|
for (uint64 t = 11; t <= packet_count * 2; ++t) {
|
|
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
|
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
|
// The inputs are sent to the input stream "foo", they should pass through.
|
|
EXPECT_EQ(out_foo.size(), t);
|
|
// Since "enable == false" for ts 110K...200K us, the first contained graph
|
|
// i.e. the one containing the TripleIntCalculator should output the values
|
|
// after tripling them.
|
|
EXPECT_EQ(out_bar.size(), t);
|
|
EXPECT_EQ(out_bar.back().Get<int>(), t * 3);
|
|
}
|
|
|
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
|
EXPECT_EQ(out_foo.size(), packet_count * 2);
|
|
EXPECT_EQ(out_bar.size(), packet_count * 2);
|
|
}
|
|
|
|
// Runs the test side-packet container graph with input side-packets.
|
|
void RunTestSideContainer(CalculatorGraphConfig supergraph) {
|
|
CalculatorGraph graph;
|
|
MP_ASSERT_OK(graph.Initialize(supergraph, {}));
|
|
MP_ASSERT_OK(graph.StartRun({
|
|
{"enable", MakePacket<bool>(false)},
|
|
{"foo", MakePacket<int>(4)},
|
|
}));
|
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
|
Packet side_output = graph.GetOutputSidePacket("output_bar").value();
|
|
EXPECT_EQ(side_output.Get<int>(), 12);
|
|
|
|
MP_ASSERT_OK(graph.StartRun({
|
|
{"enable", MakePacket<bool>(true)},
|
|
{"foo", MakePacket<int>(4)},
|
|
}));
|
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
|
side_output = graph.GetOutputSidePacket("output_bar").value();
|
|
EXPECT_EQ(side_output.Get<int>(), 4);
|
|
}
|
|
|
|
// Rearrange the Node messages within a CalculatorGraphConfig message.
|
|
CalculatorGraphConfig OrderNodes(const CalculatorGraphConfig& config,
|
|
std::vector<int> order) {
|
|
auto result = config;
|
|
result.clear_node();
|
|
for (int i = 0; i < order.size(); ++i) {
|
|
*result.add_node() = config.node(order[i]);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// Shows the SwitchContainer container applied to a pair of simple subnodes.
|
|
TEST(SwitchContainerTest, ApplyToSubnodes) {
|
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
|
CalculatorGraphConfig supergraph = SubnodeContainerExample();
|
|
CalculatorGraphConfig expected_graph =
|
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
|
node {
|
|
name: "switchcontainer__SwitchDemuxCalculator"
|
|
calculator: "SwitchDemuxCalculator"
|
|
input_stream: "ENABLE:enable"
|
|
input_stream: "foo"
|
|
output_stream: "C0__:switchcontainer__c0__foo"
|
|
output_stream: "C1__:switchcontainer__c1__foo"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {}
|
|
}
|
|
input_stream_handler {
|
|
input_stream_handler: "ImmediateInputStreamHandler"
|
|
}
|
|
}
|
|
node {
|
|
name: "switchcontainer__TripleIntCalculator"
|
|
calculator: "TripleIntCalculator"
|
|
input_stream: "switchcontainer__c0__foo"
|
|
output_stream: "switchcontainer__c0__bar"
|
|
}
|
|
node {
|
|
name: "switchcontainer__PassThroughCalculator"
|
|
calculator: "PassThroughCalculator"
|
|
input_stream: "switchcontainer__c1__foo"
|
|
output_stream: "switchcontainer__c1__bar"
|
|
}
|
|
node {
|
|
name: "switchcontainer__SwitchMuxCalculator"
|
|
calculator: "SwitchMuxCalculator"
|
|
input_stream: "ENABLE:enable"
|
|
input_stream: "C0__:switchcontainer__c0__bar"
|
|
input_stream: "C1__:switchcontainer__c1__bar"
|
|
output_stream: "bar"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {}
|
|
}
|
|
input_stream_handler {
|
|
input_stream_handler: "ImmediateInputStreamHandler"
|
|
}
|
|
}
|
|
node {
|
|
calculator: "PassThroughCalculator"
|
|
input_stream: "foo"
|
|
input_stream: "bar"
|
|
output_stream: "output_foo"
|
|
output_stream: "output_bar"
|
|
}
|
|
input_stream: "foo"
|
|
input_stream: "enable"
|
|
input_side_packet: "timezone"
|
|
)pb");
|
|
expected_graph = OrderNodes(expected_graph, {4, 0, 3, 1, 2});
|
|
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
|
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
|
}
|
|
|
|
// Shows the SwitchContainer container runs with a pair of simple subnodes.
|
|
TEST(SwitchContainerTest, RunsWithSubnodes) {
|
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
|
CalculatorGraphConfig supergraph = SubnodeContainerExample();
|
|
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
|
RunTestContainer(supergraph);
|
|
}
|
|
|
|
// Shows the SwitchContainer does not allow input_stream_handler overwrite.
|
|
TEST(SwitchContainerTest, ValidateInputStreamHandler) {
|
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
|
CalculatorGraph graph;
|
|
CalculatorGraphConfig supergraph = SideSubnodeContainerExample();
|
|
*supergraph.mutable_input_stream_handler()->mutable_input_stream_handler() =
|
|
"DefaultInputStreamHandler";
|
|
MP_ASSERT_OK(graph.Initialize(supergraph, {}));
|
|
CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie<
|
|
CalculatorGraphConfig>(R"pb(
|
|
node {
|
|
name: "switchcontainer__SwitchDemuxCalculator"
|
|
calculator: "SwitchDemuxCalculator"
|
|
input_side_packet: "ENABLE:enable"
|
|
input_side_packet: "foo"
|
|
output_side_packet: "C0__:switchcontainer__c0__foo"
|
|
output_side_packet: "C1__:switchcontainer__c1__foo"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {}
|
|
}
|
|
input_stream_handler {
|
|
input_stream_handler: "ImmediateInputStreamHandler"
|
|
}
|
|
}
|
|
node {
|
|
name: "switchcontainer__TripleIntCalculator"
|
|
calculator: "TripleIntCalculator"
|
|
input_side_packet: "switchcontainer__c0__foo"
|
|
output_side_packet: "switchcontainer__c0__bar"
|
|
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
|
|
}
|
|
node {
|
|
name: "switchcontainer__PassThroughCalculator"
|
|
calculator: "PassThroughCalculator"
|
|
input_side_packet: "switchcontainer__c1__foo"
|
|
output_side_packet: "switchcontainer__c1__bar"
|
|
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
|
|
}
|
|
node {
|
|
name: "switchcontainer__SwitchMuxCalculator"
|
|
calculator: "SwitchMuxCalculator"
|
|
input_side_packet: "ENABLE:enable"
|
|
input_side_packet: "C0__:switchcontainer__c0__bar"
|
|
input_side_packet: "C1__:switchcontainer__c1__bar"
|
|
output_side_packet: "bar"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {}
|
|
}
|
|
input_stream_handler {
|
|
input_stream_handler: "ImmediateInputStreamHandler"
|
|
}
|
|
}
|
|
node {
|
|
calculator: "PassThroughCalculator"
|
|
input_side_packet: "foo"
|
|
input_side_packet: "bar"
|
|
output_side_packet: "output_foo"
|
|
output_side_packet: "output_bar"
|
|
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
|
|
}
|
|
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
|
|
executor {}
|
|
input_side_packet: "foo"
|
|
input_side_packet: "enable"
|
|
output_side_packet: "output_bar"
|
|
)pb");
|
|
EXPECT_THAT(graph.Config(), mediapipe::EqualsProto(expected_graph));
|
|
}
|
|
|
|
// Shows the SwitchContainer container applied to a pair of simple subnodes.
|
|
TEST(SwitchContainerTest, ApplyToSideSubnodes) {
|
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
|
CalculatorGraphConfig supergraph = SideSubnodeContainerExample();
|
|
CalculatorGraphConfig expected_graph =
|
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
|
input_side_packet: "foo"
|
|
input_side_packet: "enable"
|
|
output_side_packet: "output_bar"
|
|
node {
|
|
name: "switchcontainer__SwitchDemuxCalculator"
|
|
calculator: "SwitchDemuxCalculator"
|
|
input_side_packet: "ENABLE:enable"
|
|
input_side_packet: "foo"
|
|
output_side_packet: "C0__:switchcontainer__c0__foo"
|
|
output_side_packet: "C1__:switchcontainer__c1__foo"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {}
|
|
}
|
|
input_stream_handler {
|
|
input_stream_handler: "ImmediateInputStreamHandler"
|
|
}
|
|
}
|
|
node {
|
|
name: "switchcontainer__TripleIntCalculator"
|
|
calculator: "TripleIntCalculator"
|
|
input_side_packet: "switchcontainer__c0__foo"
|
|
output_side_packet: "switchcontainer__c0__bar"
|
|
}
|
|
node {
|
|
name: "switchcontainer__PassThroughCalculator"
|
|
calculator: "PassThroughCalculator"
|
|
input_side_packet: "switchcontainer__c1__foo"
|
|
output_side_packet: "switchcontainer__c1__bar"
|
|
}
|
|
node {
|
|
name: "switchcontainer__SwitchMuxCalculator"
|
|
calculator: "SwitchMuxCalculator"
|
|
input_side_packet: "ENABLE:enable"
|
|
input_side_packet: "C0__:switchcontainer__c0__bar"
|
|
input_side_packet: "C1__:switchcontainer__c1__bar"
|
|
output_side_packet: "bar"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {}
|
|
}
|
|
input_stream_handler {
|
|
input_stream_handler: "ImmediateInputStreamHandler"
|
|
}
|
|
}
|
|
node {
|
|
calculator: "PassThroughCalculator"
|
|
input_side_packet: "foo"
|
|
input_side_packet: "bar"
|
|
output_side_packet: "output_foo"
|
|
output_side_packet: "output_bar"
|
|
}
|
|
)pb");
|
|
expected_graph = OrderNodes(expected_graph, {4, 0, 3, 1, 2});
|
|
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
|
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
|
}
|
|
|
|
// Shows the SwitchContainer container runs with a pair of simple subnodes.
|
|
TEST(SwitchContainerTest, RunWithSideSubnodes) {
|
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
|
CalculatorGraphConfig supergraph = SideSubnodeContainerExample();
|
|
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
|
RunTestSideContainer(supergraph);
|
|
}
|
|
|
|
// Shows validation of SwitchContainer container side inputs.
|
|
TEST(SwitchContainerTest, ValidateSideInputs) {
|
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
|
CalculatorGraphConfig supergraph =
|
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
|
input_side_packet: "foo"
|
|
input_side_packet: "enable"
|
|
output_side_packet: "output_bar"
|
|
node {
|
|
calculator: "SwitchContainer"
|
|
input_side_packet: "ENABLE:enable"
|
|
input_side_packet: "SELECT:enable"
|
|
input_side_packet: "foo"
|
|
output_side_packet: "bar"
|
|
options {
|
|
[mediapipe.SwitchContainerOptions.ext] {
|
|
contained_node: { calculator: "TripleIntCalculator" }
|
|
contained_node: { calculator: "PassThroughCalculator" }
|
|
}
|
|
}
|
|
}
|
|
node {
|
|
calculator: "PassThroughCalculator"
|
|
input_side_packet: "foo"
|
|
input_side_packet: "bar"
|
|
output_side_packet: "output_foo"
|
|
output_side_packet: "output_bar"
|
|
}
|
|
)pb");
|
|
auto status = tool::ExpandSubgraphs(&supergraph);
|
|
EXPECT_EQ(std::pair(status.code(), std::string(status.message())),
|
|
std::pair(absl::StatusCode::kInvalidArgument,
|
|
std::string("Only one of SwitchContainer inputs "
|
|
"'ENABLE' and 'SELECT' can be specified")));
|
|
}
|
|
|
|
} // namespace
|
|
} // namespace mediapipe
|