diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 56ca0dc65..fbdcf8c9e 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -749,6 +749,7 @@ cc_test( ":node_chain_subgraph_cc_proto", ":node_chain_subgraph_options_lib", ":subgraph_expansion", + "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:graph_service_manager", diff --git a/mediapipe/framework/tool/subgraph_expansion_test.cc b/mediapipe/framework/tool/subgraph_expansion_test.cc index b4f58a42e..b6d9950a1 100644 --- a/mediapipe/framework/tool/subgraph_expansion_test.cc +++ b/mediapipe/framework/tool/subgraph_expansion_test.cc @@ -37,6 +37,8 @@ namespace mediapipe { namespace { +using ::testing::HasSubstr; + class SimpleTestCalculator : public CalculatorBase { public: absl::Status Process(CalculatorContext* cc) override { @@ -743,5 +745,222 @@ TEST(SubgraphExpansionTest, SimpleSubgraphOptionsTwice) { EXPECT_THAT(sky_graph, mediapipe::EqualsProto(expected_graph)); } +// A subgraph that defines and uses an internal executor with name "xyz". +class InternalExecutorSubgraph : public Subgraph { + public: + absl::StatusOr GetConfig( + const SubgraphOptions& options) override { + return mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "IN:foo" + output_stream: "OUT:bar" + executor { + name: "xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: "PassThroughCalculator" + executor: "xyz" + input_stream: "foo" + output_stream: "bar" + } + )pb"); + } +}; +REGISTER_MEDIAPIPE_GRAPH(InternalExecutorSubgraph); + +// This test confirms that none of existing subgraphs can actually create an +// executor when used as subgraphs and not like a final graph. +TEST(SubgraphExpansionTest, SubgraphExecutorIsIgnored) { + CalculatorGraphConfig supergraph = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input" + node { + calculator: "InternalExecutorSubgraph" + input_stream: "IN:input" + output_stream: "OUT:output" + } + )pb"); + CalculatorGraphConfig expected_graph = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input" + node { + name: "internalexecutorsubgraph__PassThroughCalculator" + calculator: "PassThroughCalculator" + input_stream: "input" + output_stream: "output" + executor: "xyz" + } + )pb"); + MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); + EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); + + CalculatorGraph calculator_graph; + EXPECT_THAT(calculator_graph.Initialize(supergraph), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("The executor \"xyz\" is " + "not declared in an ExecutorConfig."))); +} + +class NestedInternalExecutorsSubgraph : public Subgraph { + public: + absl::StatusOr GetConfig( + const SubgraphOptions& options) override { + return mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "IN:foo" + output_stream: "OUT:bar" + node { + calculator: "InternalExecutorSubgraph" + input_stream: "IN:foo" + output_stream: "OUT:bar_0" + } + executor { + name: "xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: "PassThroughCalculator" + executor: "xyz" + input_stream: "bar_0" + output_stream: "bar_1" + } + executor { + name: "abc" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: "PassThroughCalculator" + executor: "abc" + input_stream: "bar_1" + output_stream: "bar" + } + )pb"); + } +}; +REGISTER_MEDIAPIPE_GRAPH(NestedInternalExecutorsSubgraph); + +TEST(SubgraphExpansionTest, NestedSubgraphExecutorsAreIgnored) { + CalculatorGraphConfig supergraph = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input" + node { + calculator: "NestedInternalExecutorsSubgraph" + input_stream: "IN:input" + output_stream: "OUT:output" + } + )pb"); + CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie< + CalculatorGraphConfig>(R"pb( + node { + name: "nestedinternalexecutorssubgraph__PassThroughCalculator_1" + calculator: "PassThroughCalculator" + input_stream: "nestedinternalexecutorssubgraph__bar_0" + output_stream: "nestedinternalexecutorssubgraph__bar_1" + executor: "xyz" + } + node { + name: "nestedinternalexecutorssubgraph__PassThroughCalculator_2" + calculator: "PassThroughCalculator" + input_stream: "nestedinternalexecutorssubgraph__bar_1" + output_stream: "output" + executor: "abc" + } + node { + name: "nestedinternalexecutorssubgraph__internalexecutorsubgraph__PassThroughCalculator" + calculator: "PassThroughCalculator" + input_stream: "input" + output_stream: "nestedinternalexecutorssubgraph__bar_0" + executor: "xyz" + } + input_stream: "input" + )pb"); + MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); + EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); + + CalculatorGraph calculator_graph; + EXPECT_THAT(calculator_graph.Initialize(supergraph), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("The executor \"xyz\" is " + "not declared in an ExecutorConfig."))); +} + +TEST(SubgraphExpansionTest, GraphExecutorsSubstituteSubgraphExecutors) { + CalculatorGraphConfig supergraph = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input" + executor { + name: "xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + executor { + name: "abc" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: "NestedInternalExecutorsSubgraph" + input_stream: "IN:input" + output_stream: "OUT:output" + } + )pb"); + CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie< + CalculatorGraphConfig>(R"pb( + node { + name: "nestedinternalexecutorssubgraph__PassThroughCalculator_1" + calculator: "PassThroughCalculator" + input_stream: "nestedinternalexecutorssubgraph__bar_0" + output_stream: "nestedinternalexecutorssubgraph__bar_1" + executor: "xyz" + } + node { + name: "nestedinternalexecutorssubgraph__PassThroughCalculator_2" + calculator: "PassThroughCalculator" + input_stream: "nestedinternalexecutorssubgraph__bar_1" + output_stream: "output" + executor: "abc" + } + node { + name: "nestedinternalexecutorssubgraph__internalexecutorsubgraph__PassThroughCalculator" + calculator: "PassThroughCalculator" + input_stream: "input" + output_stream: "nestedinternalexecutorssubgraph__bar_0" + executor: "xyz" + } + input_stream: "input" + executor { + name: "xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + executor { + name: "abc" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + )pb"); + MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); + EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(supergraph)); +} + } // namespace } // namespace mediapipe