mediapipe/mediapipe/framework/tool/subgraph_expansion_test.cc
MediaPipe Team 38f838513a Enable defining and using internal executor within a subgraph
PiperOrigin-RevId: 522449982
2023-04-06 15:18:56 -07:00

806 lines
28 KiB
C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/tool/subgraph_expansion.h"
#include <memory>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/status_handler.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/framework/tool/node_chain_subgraph.pb.h"
namespace mediapipe {
namespace {
// Trivial calculator used by the test configs in this file. Its contract types
// every input stream, output stream and input side packet as int, and
// Process() does no work and always succeeds.
class SimpleTestCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    for (PacketType& input : cc->Inputs()) input.Set<int>();
    for (PacketType& output : cc->Outputs()) output.Set<int>();
    for (PacketType& side_input : cc->InputSidePackets()) side_input.Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(SimpleTestCalculator);
// Aliases of SimpleTestCalculator, each registered under its own calculator
// name so that test configs can refer to nodes by role.
using SomeSourceCalculator = SimpleTestCalculator;
using SomeSinkCalculator = SimpleTestCalculator;
using SomeRegularCalculator = SimpleTestCalculator;
using SomeAggregator = SimpleTestCalculator;

REGISTER_CALCULATOR(SomeSourceCalculator);
REGISTER_CALCULATOR(SomeSinkCalculator);
REGISTER_CALCULATOR(SomeRegularCalculator);
REGISTER_CALCULATOR(SomeAggregator);
// Subgraph whose DATA input feeds "regular_node", which in turn feeds
// "simple_sink". The regular node also reads side packet "side", produced by a
// packet generator.
class TestSubgraph : public Subgraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& /*options*/) override {
    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
      input_stream: "DATA:input_1"
      node {
        name: "regular_node"
        calculator: "SomeRegularCalculator"
        input_stream: "input_1"
        output_stream: "stream_a"
        input_side_packet: "side"
      }
      node {
        name: "simple_sink"
        calculator: "SomeSinkCalculator"
        input_stream: "stream_a"
      }
      packet_generator {
        packet_generator: "SomePacketGenerator"
        output_side_packet: "side"
      }
    )pb");
  }
};
REGISTER_MEDIAPIPE_GRAPH(TestSubgraph);
// Same topology as TestSubgraph, except that side packet "side" comes from a
// packet_factory rather than a packet_generator.
class PacketFactoryTestSubgraph : public Subgraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& /*options*/) override {
    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
      input_stream: "DATA:input_1"
      node {
        name: "regular_node"
        calculator: "SomeRegularCalculator"
        input_stream: "input_1"
        output_stream: "stream_a"
        input_side_packet: "side"
      }
      node {
        name: "simple_sink"
        calculator: "SomeSinkCalculator"
        input_stream: "stream_a"
      }
      packet_factory {
        packet_factory: "SomePacketFactory"
        output_side_packet: "side"
      }
    )pb");
  }
};
REGISTER_MEDIAPIPE_GRAPH(PacketFactoryTestSubgraph);
// Subgraph that connects `chain_length` copies of calculator `node_type` in
// series. Both values come from NodeChainSubgraphOptions. The subgraph input
// is "INPUT:stream_0". Copy k (1-based) reads "stream_<k-1>" and writes
// "stream_<k>", and the last stream is exposed as "OUTPUT".
class NodeChainSubgraph : public Subgraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& options) override {
    const auto chain_options =
        Subgraph::GetOptions<mediapipe::NodeChainSubgraphOptions>(options);
    const ProtoString& node_type = chain_options.node_type();
    const int chain_length = chain_options.chain_length();
    RET_CHECK(!node_type.empty());
    RET_CHECK_GT(chain_length, 0);

    CalculatorGraphConfig config;
    config.add_input_stream("INPUT:stream_0");
    config.add_output_stream(absl::StrCat("OUTPUT:stream_", chain_length));
    for (int link = 1; link <= chain_length; ++link) {
      CalculatorGraphConfig::Node& node = *config.add_node();
      node.set_calculator(node_type);
      node.add_input_stream(absl::StrCat("stream_", link - 1));
      node.add_output_stream(absl::StrCat("stream_", link));
    }
    return config;
  }
};
REGISTER_MEDIAPIPE_GRAPH(NodeChainSubgraph);
// A subgraph used in the ExecutorFieldOfNodeInSubgraphPreserved test. It holds
// a single PassThroughCalculator whose executor field is "custom_thread_pool".
// The subgraph itself does not declare that executor.
class NodeWithExecutorSubgraph : public Subgraph {
 public:
  // The options are not used. The parameter name is commented out, as in
  // TestSubgraph, to avoid an unused-parameter warning.
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& /*options*/) override {
    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
      input_stream: "INPUT:foo"
      output_stream: "OUTPUT:bar"
      node {
        calculator: "PassThroughCalculator"
        input_stream: "foo"
        output_stream: "bar"
        executor: "custom_thread_pool"
      }
    )pb");
  }
};
REGISTER_MEDIAPIPE_GRAPH(NodeWithExecutorSubgraph);
// A subgraph used in the ExecutorFieldOfNodeInSubgraphPreserved test. It wraps
// a single NodeWithExecutorSubgraph, so that subgraph is expanded inside
// another subgraph.
class EnclosingSubgraph : public Subgraph {
 public:
  // The options are not used. The parameter name is commented out, as in
  // TestSubgraph, to avoid an unused-parameter warning.
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& /*options*/) override {
    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
      input_stream: "IN:in"
      output_stream: "OUT:out"
      node {
        calculator: "NodeWithExecutorSubgraph"
        input_stream: "INPUT:in"
        output_stream: "OUTPUT:out"
      }
    )pb");
  }
};
REGISTER_MEDIAPIPE_GRAPH(EnclosingSubgraph);
TEST(SubgraphExpansionTest, TransformStreamNames) {
  // Only the stream name after any TAG: or TAG:INDEX: prefix is rewritten.
  auto append_foo = [](absl::string_view name) {
    return absl::StrCat(name, "_foo");
  };
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "SomeSinkCalculator"
          input_stream: "input_1"
          input_stream: "VIDEO:input_2"
          input_stream: "AUDIO:0:input_3"
          input_stream: "AUDIO:1:input_4"
        }
      )pb");
  CalculatorGraphConfig expected_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "SomeSinkCalculator"
          input_stream: "input_1_foo"
          input_stream: "VIDEO:input_2_foo"
          input_stream: "AUDIO:0:input_3_foo"
          input_stream: "AUDIO:1:input_4_foo"
        }
      )pb");
  MP_EXPECT_OK(tool::TransformStreamNames(
      config.mutable_node(0)->mutable_input_stream(), append_foo));
  EXPECT_THAT(config, mediapipe::EqualsProto(expected_config));
}
TEST(SubgraphExpansionTest, TransformNames) {
  // TransformNames rewrites graph-level streams, node names and node streams.
  // The unnamed second node gets a name derived from its calculator.
  auto prefix_sg0 = [](absl::string_view name) {
    return absl::StrCat("__sg0_", name);
  };
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "input_1"
        node {
          calculator: "SomeRegularCalculator"
          name: "bob"
          input_stream: "input_1"
          input_stream: "VIDEO:input_2"
          input_stream: "AUDIO:0:input_3"
          input_stream: "AUDIO:1:input_4"
          output_stream: "output_1"
        }
        node {
          calculator: "SomeRegularCalculator"
          input_stream: "output_1"
          output_stream: "output_2"
        }
      )pb");
  CalculatorGraphConfig expected_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "__sg0_input_1"
        node {
          calculator: "SomeRegularCalculator"
          name: "__sg0_bob"
          input_stream: "__sg0_input_1"
          input_stream: "VIDEO:__sg0_input_2"
          input_stream: "AUDIO:0:__sg0_input_3"
          input_stream: "AUDIO:1:__sg0_input_4"
          output_stream: "__sg0_output_1"
        }
        node {
          name: "__sg0_SomeRegularCalculator"
          calculator: "SomeRegularCalculator"
          input_stream: "__sg0_output_1"
          output_stream: "__sg0_output_2"
        }
      )pb");
  MP_EXPECT_OK(tool::TransformNames(&config, prefix_sg0));
  EXPECT_THAT(config, mediapipe::EqualsProto(expected_config));
}
TEST(SubgraphExpansionTest, FindCorrespondingStreams) {
  // Streams declared by the subgraph.
  CalculatorGraphConfig subgraph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "input_1"
        input_stream: "VIDEO:input_2"
        input_stream: "AUDIO:0:input_3"
        input_stream: "AUDIO:1:input_4"
      )pb");
  // Streams that the node using the subgraph binds to the same tags/indexes.
  CalculatorGraphConfig parent_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "SomeSubgraph"
          input_stream: "foo"
          input_stream: "VIDEO:bar"
          input_stream: "AUDIO:0:baz"
          input_stream: "AUDIO:1:qux"
        }
      )pb");
  const auto& node_streams = parent_config.node(0).input_stream();
  std::map<std::string, std::string> stream_map;
  MP_EXPECT_OK(tool::FindCorrespondingStreams(
      &stream_map, subgraph_config.input_stream(), node_streams));
  EXPECT_THAT(stream_map,
              testing::UnorderedElementsAre(testing::Pair("input_1", "foo"),
                                            testing::Pair("input_2", "bar"),
                                            testing::Pair("input_3", "baz"),
                                            testing::Pair("input_4", "qux")));
}
TEST(SubgraphExpansionTest, FindCorrespondingStreamsNonexistentTag) {
  // The VIDEO tag used by the node does not exist in the subgraph.
  CalculatorGraphConfig config1 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "input_1"
        input_stream: "AUDIO:0:input_3"
        input_stream: "AUDIO:1:input_4"
      )pb");
  CalculatorGraphConfig config2 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "SomeSubgraph"
          input_stream: "foo"
          input_stream: "VIDEO:bar"
          input_stream: "AUDIO:0:baz"
          input_stream: "AUDIO:1:qux"
        }
      )pb");
  std::map<std::string, std::string> stream_map;
  auto status = tool::FindCorrespondingStreams(
      &stream_map, config1.input_stream(), config2.node()[0].input_stream());
  // Check the failure explicitly. Otherwise an unexpected OK status would only
  // show up as an empty message that fails the substring matchers.
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.message(),
              testing::AllOf(
                  // Problematic tag.
                  testing::HasSubstr("VIDEO"),
                  // Error.
                  testing::HasSubstr("does not exist")));
}
TEST(SubgraphExpansionTest, FindCorrespondingStreamsTooFewIndexes) {
  // The AUDIO tag has too few indexes in the subgraph (1 vs. 2).
  CalculatorGraphConfig config1 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "input_1"
        input_stream: "VIDEO:input_2"
        input_stream: "AUDIO:0:input_3"
      )pb");
  CalculatorGraphConfig config2 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "SomeSubgraph"
          input_stream: "foo"
          input_stream: "VIDEO:bar"
          input_stream: "AUDIO:0:baz"
          input_stream: "AUDIO:1:qux"
        }
      )pb");
  std::map<std::string, std::string> stream_map;
  auto status = tool::FindCorrespondingStreams(
      &stream_map, config1.input_stream(), config2.node()[0].input_stream());
  // Check the failure explicitly. Otherwise an unexpected OK status would only
  // show up as an empty message that fails the substring matchers.
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.message(),
              testing::AllOf(
                  // Problematic tag.
                  testing::HasSubstr("AUDIO"),
                  // Error.
                  testing::HasSubstr(" 2 "), testing::HasSubstr(" 1 ")));
}
TEST(SubgraphExpansionTest, ConnectSubgraphStreams) {
  // The node that instantiates the subgraph binds each subgraph tag to a
  // stream or side packet in the enclosing graph.
  CalculatorGraphConfig parent_graph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "SomeSubgraph"
          input_stream: "A:foo"
          input_stream: "B:bar"
          output_stream: "O:foobar"
          input_side_packet: "SI:flip"
          output_side_packet: "SO:flop"
        }
      )pb");
  CalculatorGraphConfig subgraph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "A:input_1"
        input_stream: "B:input_2"
        output_stream: "O:output_2"
        input_side_packet: "SI:side_input"
        output_side_packet: "SO:side_output"
        node {
          calculator: "SomeRegularCalculator"
          input_stream: "input_1"
          input_stream: "VIDEO:input_2"
          input_side_packet: "side_input"
          output_stream: "output_1"
        }
        node {
          calculator: "SomeRegularCalculator"
          input_stream: "input_1"
          input_stream: "output_1"
          output_stream: "output_2"
        }
        packet_generator {
          packet_generator: "SomeGenerator"
          input_side_packet: "side_input"
          output_side_packet: "side_output"
        }
      )pb");
  // Names bound through the subgraph's tags are replaced inside its nodes and
  // packet generators. Purely internal names such as "output_1" are untouched.
  // The subgraph's own graph-level declarations are left as they are because
  // they are going to be discarded anyway.
  CalculatorGraphConfig expected_subgraph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "A:input_1"
        input_stream: "B:input_2"
        output_stream: "O:output_2"
        input_side_packet: "SI:side_input"
        output_side_packet: "SO:side_output"
        node {
          calculator: "SomeRegularCalculator"
          input_stream: "foo"
          input_stream: "VIDEO:bar"
          input_side_packet: "flip"
          output_stream: "output_1"
        }
        node {
          calculator: "SomeRegularCalculator"
          input_stream: "foo"
          input_stream: "output_1"
          output_stream: "foobar"
        }
        packet_generator {
          packet_generator: "SomeGenerator"
          input_side_packet: "flip"
          output_side_packet: "flop"
        }
      )pb");
  const CalculatorGraphConfig::Node& subgraph_node = parent_graph.node(0);
  MP_EXPECT_OK(tool::ConnectSubgraphStreams(subgraph_node, &subgraph_config));
  EXPECT_THAT(subgraph_config, mediapipe::EqualsProto(expected_subgraph));
}
TEST(SubgraphExpansionTest, ExpandSubgraphs) {
  // Expanding TestSubgraph inlines its nodes and packet generator. Internal
  // names get the "testsubgraph__" prefix and its DATA input is wired to "foo".
  CalculatorGraphConfig supergraph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          name: "simple_source"
          calculator: "SomeSourceCalculator"
          output_stream: "foo"
        }
        node {
          calculator: "TestSubgraph"
          input_stream: "DATA:foo"
        }
      )pb");
  const CalculatorGraphConfig expected_graph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          name: "simple_source"
          calculator: "SomeSourceCalculator"
          output_stream: "foo"
        }
        node {
          name: "testsubgraph__regular_node"
          calculator: "SomeRegularCalculator"
          input_stream: "foo"
          output_stream: "testsubgraph__stream_a"
          input_side_packet: "testsubgraph__side"
        }
        node {
          name: "testsubgraph__simple_sink"
          calculator: "SomeSinkCalculator"
          input_stream: "testsubgraph__stream_a"
        }
        packet_generator {
          packet_generator: "SomePacketGenerator"
          output_side_packet: "testsubgraph__side"
        }
      )pb");
  MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
  EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
}
TEST(SubgraphExpansionTest, ValidateSubgraphFields) {
  // buffer_size_hint only applies to calculators. Both direct validation and
  // full expansion must reject the subgraph node with kInvalidArgument and
  // name it in the error.
  CalculatorGraphConfig supergraph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          name: "simple_source"
          calculator: "SomeSourceCalculator"
          output_stream: "foo"
        }
        node {
          name: "foo_subgraph"
          calculator: "TestSubgraph"
          input_stream: "DATA:foo"
          buffer_size_hint: -1  # This field is only applicable to calculators.
        }
      )pb");

  absl::Status validate_status =
      tool::ValidateSubgraphFields(supergraph.node(1));
  EXPECT_EQ(validate_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(validate_status.message(), testing::HasSubstr("foo_subgraph"));

  absl::Status expand_status = tool::ExpandSubgraphs(&supergraph);
  EXPECT_EQ(expand_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(expand_status.message(), testing::HasSubstr("foo_subgraph"));
}
// Captures the use case of CL 191001940. A node inside a (nested) subgraph may
// name an executor that is declared by the top-level graph rather than by the
// subgraph. Expansion must keep that "executor" field verbatim, without
// mapping or mangling it. This guards that use case alongside subgraphs that
// declare their own executors.
TEST(SubgraphExpansionTest, ExecutorFieldOfNodeInSubgraphPreserved) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "input"
        executor {
          name: "custom_thread_pool"
          type: "ThreadPoolExecutor"
          options {
            [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 4 }
          }
        }
        node {
          calculator: "EnclosingSubgraph"
          input_stream: "IN:input"
          output_stream: "OUT:output"
        }
      )pb");
  const CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie<
      CalculatorGraphConfig>(R"pb(
    input_stream: "input"
    executor {
      name: "custom_thread_pool"
      type: "ThreadPoolExecutor"
      options {
        [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 4 }
      }
    }
    node {
      calculator: "PassThroughCalculator"
      name: "enclosingsubgraph__nodewithexecutorsubgraph__PassThroughCalculator"
      input_stream: "input"
      output_stream: "output"
      executor: "custom_thread_pool"
    }
  )pb");
  MP_EXPECT_OK(tool::ExpandSubgraphs(&graph_config));
  EXPECT_THAT(graph_config, mediapipe::EqualsProto(expected_graph));
}
// A subgraph that declares its own executor "xyz" (a one-thread
// ThreadPoolExecutor) and runs its only node, a PassThroughCalculator, on it.
class NodeWithInternalExecutorSubgraph : public Subgraph {
 public:
  // The options are not used. The parameter name is commented out, as in
  // TestSubgraph, to avoid an unused-parameter warning.
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& /*options*/) override {
    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
      input_stream: "IN:foo"
      output_stream: "OUT:bar"
      executor {
        name: "xyz"
        type: "ThreadPoolExecutor"
        options {
          [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
        }
      }
      node {
        calculator: "PassThroughCalculator"
        executor: "xyz"
        input_stream: "foo"
        output_stream: "bar"
      }
    )pb");
  }
};
REGISTER_MEDIAPIPE_GRAPH(NodeWithInternalExecutorSubgraph);
// An executor declared inside a subgraph is moved into the expanded top-level
// graph. The node that uses it keeps its executor reference "xyz" unchanged.
TEST(SubgraphExpansionTest, ExecutorCanDefinedAndUsedWithinSubgraph) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "input"
        node {
          calculator: "NodeWithInternalExecutorSubgraph"
          input_stream: "IN:input"
          output_stream: "OUT:output"
        }
      )pb");
  const CalculatorGraphConfig expected_graph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "input"
        executor {
          name: "xyz"
          type: "ThreadPoolExecutor"
          options {
            [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
          }
        }
        node {
          calculator: "PassThroughCalculator"
          name: "nodewithinternalexecutorsubgraph__PassThroughCalculator"
          input_stream: "input"
          output_stream: "output"
          executor: "xyz"
        }
      )pb");
  MP_EXPECT_OK(tool::ExpandSubgraphs(&graph_config));
  EXPECT_THAT(graph_config, mediapipe::EqualsProto(expected_graph));
}
// Test-only graph service that provides a std::string.
const mediapipe::GraphService<std::string> kStringTestService{
    "mediapipe::StringTestService"};

// Subgraph that expands to a single node. The node's calculator name is the
// string held by kStringTestService. GetConfig fails if the service is
// unavailable.
class GraphServicesClientTestSubgraph : public Subgraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    auto string_service = sc->Service(kStringTestService);
    RET_CHECK(string_service.IsAvailable()) << "Service not available";
    const std::string& calculator_name = string_service.GetObject();
    CalculatorGraphConfig config;
    CalculatorGraphConfig::Node* node = config.add_node();
    node->set_calculator(calculator_name);
    return config;
  }
};
REGISTER_MEDIAPIPE_GRAPH(GraphServicesClientTestSubgraph);
TEST(SubgraphExpansionTest, GraphServicesUsage) {
  // The service object is the calculator name the subgraph will emit.
  GraphServiceManager service_manager;
  MP_ASSERT_OK(service_manager.SetServiceObject(
      kStringTestService, std::make_shared<std::string>("ExpectedNode")));

  CalculatorGraphConfig supergraph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node { calculator: "GraphServicesClientTestSubgraph" }
      )pb");
  const CalculatorGraphConfig expected_graph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          name: "graphservicesclienttestsubgraph__ExpectedNode"
          calculator: "ExpectedNode"
        }
      )pb");
  MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph, /*graph_registry=*/nullptr,
                                     /*graph_options=*/nullptr,
                                     &service_manager));
  EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
}
// Shows SubgraphOptions being consumed by GraphRegistry::CreateByName.
TEST(SubgraphExpansionTest, SubgraphOptionsUsage) {
  EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
  GraphRegistry graph_registry;
  // CalculatorGraph::Initialize passes the SubgraphOptions into both
  //   (1) GraphRegistry::CreateByName("NodeChainSubgraph", options), and
  //   (2) tool::ExpandSubgraphs(&config, options).
  auto graph_options =
      mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
        options {
          [mediapipe.NodeChainSubgraphOptions.ext] {
            node_type: "DoubleIntCalculator"
            chain_length: 3
          }
        })pb");
  SubgraphContext context(&graph_options, /*service_manager=*/nullptr);

  // "NodeChainSubgraph" uses graph_options only in CreateByName. The later
  // ExpandSubgraphs call leaves the generated chain unchanged.
  auto created = graph_registry.CreateByName("", "NodeChainSubgraph", &context);
  MP_ASSERT_OK(created);
  auto subgraph = std::move(created).value();
  MP_ASSERT_OK(
      tool::ExpandSubgraphs(&subgraph, &graph_registry, &graph_options));

  const CalculatorGraphConfig expected_graph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "DoubleIntCalculator"
          input_stream: "stream_0"
          output_stream: "stream_1"
        }
        node {
          calculator: "DoubleIntCalculator"
          input_stream: "stream_1"
          output_stream: "stream_2"
        }
        node {
          calculator: "DoubleIntCalculator"
          input_stream: "stream_2"
          output_stream: "stream_3"
        }
        input_stream: "INPUT:stream_0"
        output_stream: "OUTPUT:stream_3"
      )pb");
  EXPECT_THAT(subgraph, mediapipe::EqualsProto(expected_graph));
}
// Shows SubgraphOptions being consumed by tool::ExpandSubgraphs.
TEST(SubgraphExpansionTest, SimpleSubgraphOptionsUsage) {
  EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
  GraphRegistry graph_registry;
  // Graph-level options supplied to ExpandSubgraphs.
  auto moon_options =
      mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
        options {
          [mediapipe.NodeChainSubgraphOptions.ext] {
            node_type: "DoubleIntCalculator"
            chain_length: 3
          }
        })pb");
  // A simple subgraph whose node pulls "chain_length" from the graph options
  // through an option_value expression.
  auto moon_subgraph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        type: "MoonSubgraph"
        graph_options: {
          [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
        }
        node: {
          calculator: "MoonCalculator"
          node_options: {
            [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
          }
          option_value: "chain_length:options/chain_length"
        }
      )pb");

  // moon_options are copied into the graph_options of moon_subgraph, and the
  // option_value expression is then evaluated and removed.
  MP_ASSERT_OK(
      tool::ExpandSubgraphs(&moon_subgraph, &graph_registry, &moon_options));

  // chain_length has been copied from moon_options into MoonCalculator.
  const CalculatorGraphConfig expected_graph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "MoonCalculator"
          node_options {
            [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {
              chain_length: 3
            }
          }
        }
        type: "MoonSubgraph"
        graph_options {
          [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
        }
      )pb");
  EXPECT_THAT(moon_subgraph, mediapipe::EqualsProto(expected_graph));
}
// Shows ExpandSubgraphs applied twice. "option_value" fields are evaluated
// and removed on the first ExpandSubgraphs call. If they were not removed,
// they would evaluate incorrectly on the second call and this test would fail
// on "expected_node_options".
TEST(SubgraphExpansionTest, SimpleSubgraphOptionsTwice) {
  GraphRegistry graph_registry;
  // Register a simple-subgraph that accepts graph options.
  auto moon_subgraph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        type: "MoonSubgraph"
        graph_options: {
          [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
        }
        node: {
          calculator: "MoonCalculator"
          node_options: {
            [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
          }
          option_value: "chain_length:options/chain_length"
        }
      )pb");
  graph_registry.Register("MoonSubgraph", moon_subgraph);

  // Invoke the simple-subgraph with graph options. The empty
  // NodeChainSubgraphOptions below would let the "option_value" fields on
  // "MoonCalculator" evaluate incorrectly if they were not removed.
  auto sky_graph = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
    graph_options: {
      [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
    }
    node: {
      calculator: "MoonSubgraph"
      options: {
        [mediapipe.NodeChainSubgraphOptions.ext] {
          node_type: "DoubleIntCalculator"
          chain_length: 3
        }
      }
    }
  )pb");

  // The first ExpandSubgraphs call evaluates and removes "option_value" fields.
  MP_ASSERT_OK(tool::ExpandSubgraphs(&sky_graph, &graph_registry));
  auto expanded_1 = sky_graph;
  // The second ExpandSubgraphs call has no effect on the expanded graph.
  MP_ASSERT_OK(tool::ExpandSubgraphs(&sky_graph, &graph_registry));

  // Validate the expected node_options for the "MoonSubgraph".
  // If the "option_value" fields are not removed during ExpandSubgraphs,
  // this test fails with an incorrect value for "chain_length".
  auto expected_node_options =
      mediapipe::ParseTextProtoOrDie<mediapipe::NodeChainSubgraphOptions>(
          "chain_length: 3");
  // Guard the indexed accesses below so that a missing node or node_options
  // entry fails the test cleanly instead of indexing out of range.
  ASSERT_EQ(sky_graph.node_size(), 1);
  ASSERT_EQ(sky_graph.node(0).node_options_size(), 1);
  mediapipe::NodeChainSubgraphOptions node_options;
  // UnpackTo reports whether the Any held a NodeChainSubgraphOptions.
  ASSERT_TRUE(sky_graph.node(0).node_options(0).UnpackTo(&node_options));
  ASSERT_THAT(node_options, mediapipe::EqualsProto(expected_node_options));

  // Validate the results from both ExpandSubgraphs() calls.
  CalculatorGraphConfig expected_graph =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        graph_options {
          [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
        }
        node {
          name: "moonsubgraph__MoonCalculator"
          calculator: "MoonCalculator"
          node_options {
            [type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {
              chain_length: 3
            }
          }
        }
      )pb");
  EXPECT_THAT(expanded_1, mediapipe::EqualsProto(expected_graph));
  EXPECT_THAT(sky_graph, mediapipe::EqualsProto(expected_graph));
}
} // namespace
} // namespace mediapipe