Cover the existing graph expansion behavior in regard to executors with unit tests.
PiperOrigin-RevId: 523192292
This commit is contained in:
parent
e5f28bc136
commit
02fed0b7d1
|
@ -749,6 +749,7 @@ cc_test(
|
|||
":node_chain_subgraph_cc_proto",
|
||||
":node_chain_subgraph_options_lib",
|
||||
":subgraph_expansion",
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:graph_service_manager",
|
||||
|
|
|
@ -37,6 +37,8 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
class SimpleTestCalculator : public CalculatorBase {
|
||||
public:
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
|
@ -743,5 +745,222 @@ TEST(SubgraphExpansionTest, SimpleSubgraphOptionsTwice) {
|
|||
EXPECT_THAT(sky_graph, mediapipe::EqualsProto(expected_graph));
|
||||
}
|
||||
|
||||
// A subgraph that defines and uses an internal executor with name "xyz".
|
||||
class InternalExecutorSubgraph : public Subgraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
const SubgraphOptions& options) override {
|
||||
return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "IN:foo"
|
||||
output_stream: "OUT:bar"
|
||||
executor {
|
||||
name: "xyz"
|
||||
type: "ThreadPoolExecutor"
|
||||
options {
|
||||
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "PassThroughCalculator"
|
||||
executor: "xyz"
|
||||
input_stream: "foo"
|
||||
output_stream: "bar"
|
||||
}
|
||||
)pb");
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(InternalExecutorSubgraph);
|
||||
|
||||
// This test confirms that none of existing subgraphs can actually create an
|
||||
// executor when used as subgraphs and not like a final graph.
|
||||
TEST(SubgraphExpansionTest, SubgraphExecutorIsIgnored) {
|
||||
CalculatorGraphConfig supergraph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input"
|
||||
node {
|
||||
calculator: "InternalExecutorSubgraph"
|
||||
input_stream: "IN:input"
|
||||
output_stream: "OUT:output"
|
||||
}
|
||||
)pb");
|
||||
CalculatorGraphConfig expected_graph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input"
|
||||
node {
|
||||
name: "internalexecutorsubgraph__PassThroughCalculator"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "input"
|
||||
output_stream: "output"
|
||||
executor: "xyz"
|
||||
}
|
||||
)pb");
|
||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
EXPECT_THAT(calculator_graph.Initialize(supergraph),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("The executor \"xyz\" is "
|
||||
"not declared in an ExecutorConfig.")));
|
||||
}
|
||||
|
||||
class NestedInternalExecutorsSubgraph : public Subgraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
const SubgraphOptions& options) override {
|
||||
return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "IN:foo"
|
||||
output_stream: "OUT:bar"
|
||||
node {
|
||||
calculator: "InternalExecutorSubgraph"
|
||||
input_stream: "IN:foo"
|
||||
output_stream: "OUT:bar_0"
|
||||
}
|
||||
executor {
|
||||
name: "xyz"
|
||||
type: "ThreadPoolExecutor"
|
||||
options {
|
||||
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "PassThroughCalculator"
|
||||
executor: "xyz"
|
||||
input_stream: "bar_0"
|
||||
output_stream: "bar_1"
|
||||
}
|
||||
executor {
|
||||
name: "abc"
|
||||
type: "ThreadPoolExecutor"
|
||||
options {
|
||||
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "PassThroughCalculator"
|
||||
executor: "abc"
|
||||
input_stream: "bar_1"
|
||||
output_stream: "bar"
|
||||
}
|
||||
)pb");
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(NestedInternalExecutorsSubgraph);
|
||||
|
||||
TEST(SubgraphExpansionTest, NestedSubgraphExecutorsAreIgnored) {
|
||||
CalculatorGraphConfig supergraph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input"
|
||||
node {
|
||||
calculator: "NestedInternalExecutorsSubgraph"
|
||||
input_stream: "IN:input"
|
||||
output_stream: "OUT:output"
|
||||
}
|
||||
)pb");
|
||||
CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie<
|
||||
CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
name: "nestedinternalexecutorssubgraph__PassThroughCalculator_1"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "nestedinternalexecutorssubgraph__bar_0"
|
||||
output_stream: "nestedinternalexecutorssubgraph__bar_1"
|
||||
executor: "xyz"
|
||||
}
|
||||
node {
|
||||
name: "nestedinternalexecutorssubgraph__PassThroughCalculator_2"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "nestedinternalexecutorssubgraph__bar_1"
|
||||
output_stream: "output"
|
||||
executor: "abc"
|
||||
}
|
||||
node {
|
||||
name: "nestedinternalexecutorssubgraph__internalexecutorsubgraph__PassThroughCalculator"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "input"
|
||||
output_stream: "nestedinternalexecutorssubgraph__bar_0"
|
||||
executor: "xyz"
|
||||
}
|
||||
input_stream: "input"
|
||||
)pb");
|
||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
EXPECT_THAT(calculator_graph.Initialize(supergraph),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("The executor \"xyz\" is "
|
||||
"not declared in an ExecutorConfig.")));
|
||||
}
|
||||
|
||||
TEST(SubgraphExpansionTest, GraphExecutorsSubstituteSubgraphExecutors) {
|
||||
CalculatorGraphConfig supergraph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input"
|
||||
executor {
|
||||
name: "xyz"
|
||||
type: "ThreadPoolExecutor"
|
||||
options {
|
||||
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||
}
|
||||
}
|
||||
executor {
|
||||
name: "abc"
|
||||
type: "ThreadPoolExecutor"
|
||||
options {
|
||||
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "NestedInternalExecutorsSubgraph"
|
||||
input_stream: "IN:input"
|
||||
output_stream: "OUT:output"
|
||||
}
|
||||
)pb");
|
||||
CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie<
|
||||
CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
name: "nestedinternalexecutorssubgraph__PassThroughCalculator_1"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "nestedinternalexecutorssubgraph__bar_0"
|
||||
output_stream: "nestedinternalexecutorssubgraph__bar_1"
|
||||
executor: "xyz"
|
||||
}
|
||||
node {
|
||||
name: "nestedinternalexecutorssubgraph__PassThroughCalculator_2"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "nestedinternalexecutorssubgraph__bar_1"
|
||||
output_stream: "output"
|
||||
executor: "abc"
|
||||
}
|
||||
node {
|
||||
name: "nestedinternalexecutorssubgraph__internalexecutorsubgraph__PassThroughCalculator"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "input"
|
||||
output_stream: "nestedinternalexecutorssubgraph__bar_0"
|
||||
executor: "xyz"
|
||||
}
|
||||
input_stream: "input"
|
||||
executor {
|
||||
name: "xyz"
|
||||
type: "ThreadPoolExecutor"
|
||||
options {
|
||||
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||
}
|
||||
}
|
||||
executor {
|
||||
name: "abc"
|
||||
type: "ThreadPoolExecutor"
|
||||
options {
|
||||
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(supergraph));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user