diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 19fa6f469..6353a113e 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1670,6 +1670,7 @@ cc_test( name = "subgraph_test", srcs = ["subgraph_test.cc"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":graph_service_manager", ":subgraph", @@ -1681,6 +1682,7 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:sink", "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", ], ) diff --git a/mediapipe/framework/subgraph_test.cc b/mediapipe/framework/subgraph_test.cc index 2789aa683..df5554850 100644 --- a/mediapipe/framework/subgraph_test.cc +++ b/mediapipe/framework/subgraph_test.cc @@ -16,8 +16,10 @@ #include +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/port/gmock.h" @@ -188,5 +190,52 @@ TEST_F(SubgraphTest, CheckSubgraphOptionsPassedIn) { EXPECT_EQ(packet.value().Get(), "test"); } +class SubgraphUsingInternalExecutor : public Subgraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* sc) override { + return mediapipe::ParseTextProtoOrDie( + R"pb( + output_side_packet: "string" + executor { + name: "xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: "ConstantSidePacketCalculator" + executor: "xyz" + output_side_packet: "PACKET:string" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { string_value: "passed" } + } + } + } + )pb"); + } +}; +REGISTER_MEDIAPIPE_GRAPH(SubgraphUsingInternalExecutor); + +TEST(SubgraphExecutorTest, SubgraphCanDefineAndUseInternalExecutor) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(R"pb( + output_side_packet: "str" + node { + calculator: "SubgraphUsingInternalExecutor" + output_side_packet: "str" + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilDone()); + auto packet = graph.GetOutputSidePacket("str"); + MP_ASSERT_OK(packet); + EXPECT_EQ(packet->Get(), "passed"); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/tool/subgraph_expansion.cc b/mediapipe/framework/tool/subgraph_expansion.cc index dcd055f59..1f9e4a535 100644 --- a/mediapipe/framework/tool/subgraph_expansion.cc +++ b/mediapipe/framework/tool/subgraph_expansion.cc @@ -319,6 +319,9 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, subgraph.status_handler().end(), proto_ns::RepeatedPtrFieldBackInserter( config->mutable_status_handler())); + std::copy( + subgraph.executor().begin(), subgraph.executor().end(), + proto_ns::RepeatedPtrFieldBackInserter(config->mutable_executor())); } } return absl::OkStatus(); diff --git a/mediapipe/framework/tool/subgraph_expansion_test.cc b/mediapipe/framework/tool/subgraph_expansion_test.cc index b4f58a42e..3c9ff3574 100644 --- a/mediapipe/framework/tool/subgraph_expansion_test.cc +++ b/mediapipe/framework/tool/subgraph_expansion_test.cc @@ -527,6 +527,64 @@ TEST(SubgraphExpansionTest, ExecutorFieldOfNodeInSubgraphPreserved) { EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); } +// A subgraph that defines and uses an internal executor. +class NodeWithInternalExecutorSubgraph : public Subgraph { + public: + absl::StatusOr GetConfig( + const SubgraphOptions& options) override { + return mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "IN:foo" + output_stream: "OUT:bar" + executor { + name: "xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: "PassThroughCalculator" + executor: "xyz" + input_stream: "foo" + output_stream: "bar" + } + )pb"); + } +}; +REGISTER_MEDIAPIPE_GRAPH(NodeWithInternalExecutorSubgraph); + +TEST(SubgraphExpansionTest, ExecutorCanDefinedAndUsedWithinSubgraph) { + CalculatorGraphConfig supergraph = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input" + node { + calculator: "NodeWithInternalExecutorSubgraph" + input_stream: "IN:input" + output_stream: "OUT:output" + } + )pb"); + CalculatorGraphConfig expected_graph = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input" + executor { + name: "xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + node { + calculator: "PassThroughCalculator" + name: "nodewithinternalexecutorsubgraph__PassThroughCalculator" + input_stream: "input" + output_stream: "output" + executor: "xyz" + } + )pb"); + MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); + EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); +} + const mediapipe::GraphService kStringTestService{ "mediapipe::StringTestService"}; class GraphServicesClientTestSubgraph : public Subgraph {