Enable defining and using internal executor within a subgraph
PiperOrigin-RevId: 522449982
This commit is contained in:
		
							parent
							
								
									81a405af1b
								
							
						
					
					
						commit
						38f838513a
					
				| 
						 | 
					@ -1670,6 +1670,7 @@ cc_test(
 | 
				
			||||||
    name = "subgraph_test",
 | 
					    name = "subgraph_test",
 | 
				
			||||||
    srcs = ["subgraph_test.cc"],
 | 
					    srcs = ["subgraph_test.cc"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":calculator_cc_proto",
 | 
				
			||||||
        ":calculator_framework",
 | 
					        ":calculator_framework",
 | 
				
			||||||
        ":graph_service_manager",
 | 
					        ":graph_service_manager",
 | 
				
			||||||
        ":subgraph",
 | 
					        ":subgraph",
 | 
				
			||||||
| 
						 | 
					@ -1681,6 +1682,7 @@ cc_test(
 | 
				
			||||||
        "//mediapipe/framework/port:parse_text_proto",
 | 
					        "//mediapipe/framework/port:parse_text_proto",
 | 
				
			||||||
        "//mediapipe/framework/tool:sink",
 | 
					        "//mediapipe/framework/tool:sink",
 | 
				
			||||||
        "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph",
 | 
					        "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/status:statusor",
 | 
				
			||||||
        "@com_google_absl//absl/strings:str_format",
 | 
					        "@com_google_absl//absl/strings:str_format",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,8 +16,10 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <string>
 | 
					#include <string>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "absl/status/statusor.h"
 | 
				
			||||||
#include "absl/strings/str_format.h"
 | 
					#include "absl/strings/str_format.h"
 | 
				
			||||||
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
 | 
					#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator.pb.h"
 | 
				
			||||||
#include "mediapipe/framework/calculator_framework.h"
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
#include "mediapipe/framework/graph_service_manager.h"
 | 
					#include "mediapipe/framework/graph_service_manager.h"
 | 
				
			||||||
#include "mediapipe/framework/port/gmock.h"
 | 
					#include "mediapipe/framework/port/gmock.h"
 | 
				
			||||||
| 
						 | 
					@ -188,5 +190,52 @@ TEST_F(SubgraphTest, CheckSubgraphOptionsPassedIn) {
 | 
				
			||||||
  EXPECT_EQ(packet.value().Get<std::string>(), "test");
 | 
					  EXPECT_EQ(packet.value().Get<std::string>(), "test");
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SubgraphUsingInternalExecutor : public Subgraph {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
				
			||||||
 | 
					      mediapipe::SubgraphContext* sc) override {
 | 
				
			||||||
 | 
					    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
 | 
				
			||||||
 | 
					        R"pb(
 | 
				
			||||||
 | 
					          output_side_packet: "string"
 | 
				
			||||||
 | 
					          executor {
 | 
				
			||||||
 | 
					            name: "xyz"
 | 
				
			||||||
 | 
					            type: "ThreadPoolExecutor"
 | 
				
			||||||
 | 
					            options {
 | 
				
			||||||
 | 
					              [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          node {
 | 
				
			||||||
 | 
					            calculator: "ConstantSidePacketCalculator"
 | 
				
			||||||
 | 
					            executor: "xyz"
 | 
				
			||||||
 | 
					            output_side_packet: "PACKET:string"
 | 
				
			||||||
 | 
					            options: {
 | 
				
			||||||
 | 
					              [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
 | 
				
			||||||
 | 
					                packet { string_value: "passed" }
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        )pb");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					REGISTER_MEDIAPIPE_GRAPH(SubgraphUsingInternalExecutor);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(SubgraphExecutorTest, SubgraphCanDefineAndUseInternalExecutor) {
 | 
				
			||||||
 | 
					  CalculatorGraphConfig config =
 | 
				
			||||||
 | 
					      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					        output_side_packet: "str"
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "SubgraphUsingInternalExecutor"
 | 
				
			||||||
 | 
					          output_side_packet: "str"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      )pb");
 | 
				
			||||||
 | 
					  CalculatorGraph graph;
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.Initialize(config));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.StartRun({}));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
				
			||||||
 | 
					  auto packet = graph.GetOutputSidePacket("str");
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(packet);
 | 
				
			||||||
 | 
					  EXPECT_EQ(packet->Get<std::string>(), "passed");
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -319,6 +319,9 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
 | 
				
			||||||
                subgraph.status_handler().end(),
 | 
					                subgraph.status_handler().end(),
 | 
				
			||||||
                proto_ns::RepeatedPtrFieldBackInserter(
 | 
					                proto_ns::RepeatedPtrFieldBackInserter(
 | 
				
			||||||
                    config->mutable_status_handler()));
 | 
					                    config->mutable_status_handler()));
 | 
				
			||||||
 | 
					      std::copy(
 | 
				
			||||||
 | 
					          subgraph.executor().begin(), subgraph.executor().end(),
 | 
				
			||||||
 | 
					          proto_ns::RepeatedPtrFieldBackInserter(config->mutable_executor()));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return absl::OkStatus();
 | 
					  return absl::OkStatus();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -527,6 +527,64 @@ TEST(SubgraphExpansionTest, ExecutorFieldOfNodeInSubgraphPreserved) {
 | 
				
			||||||
  EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
 | 
					  EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// A subgraph that defines and uses an internal executor.
 | 
				
			||||||
 | 
					class NodeWithInternalExecutorSubgraph : public Subgraph {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
				
			||||||
 | 
					      const SubgraphOptions& options) override {
 | 
				
			||||||
 | 
					    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					      input_stream: "IN:foo"
 | 
				
			||||||
 | 
					      output_stream: "OUT:bar"
 | 
				
			||||||
 | 
					      executor {
 | 
				
			||||||
 | 
					        name: "xyz"
 | 
				
			||||||
 | 
					        type: "ThreadPoolExecutor"
 | 
				
			||||||
 | 
					        options {
 | 
				
			||||||
 | 
					          [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      node {
 | 
				
			||||||
 | 
					        calculator: "PassThroughCalculator"
 | 
				
			||||||
 | 
					        executor: "xyz"
 | 
				
			||||||
 | 
					        input_stream: "foo"
 | 
				
			||||||
 | 
					        output_stream: "bar"
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    )pb");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					REGISTER_MEDIAPIPE_GRAPH(NodeWithInternalExecutorSubgraph);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(SubgraphExpansionTest, ExecutorCanDefinedAndUsedWithinSubgraph) {
 | 
				
			||||||
 | 
					  CalculatorGraphConfig supergraph =
 | 
				
			||||||
 | 
					      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					        input_stream: "input"
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "NodeWithInternalExecutorSubgraph"
 | 
				
			||||||
 | 
					          input_stream: "IN:input"
 | 
				
			||||||
 | 
					          output_stream: "OUT:output"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      )pb");
 | 
				
			||||||
 | 
					  CalculatorGraphConfig expected_graph =
 | 
				
			||||||
 | 
					      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					        input_stream: "input"
 | 
				
			||||||
 | 
					        executor {
 | 
				
			||||||
 | 
					          name: "xyz"
 | 
				
			||||||
 | 
					          type: "ThreadPoolExecutor"
 | 
				
			||||||
 | 
					          options {
 | 
				
			||||||
 | 
					            [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "PassThroughCalculator"
 | 
				
			||||||
 | 
					          name: "nodewithinternalexecutorsubgraph__PassThroughCalculator"
 | 
				
			||||||
 | 
					          input_stream: "input"
 | 
				
			||||||
 | 
					          output_stream: "output"
 | 
				
			||||||
 | 
					          executor: "xyz"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      )pb");
 | 
				
			||||||
 | 
					  MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
 | 
				
			||||||
 | 
					  EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const mediapipe::GraphService<std::string> kStringTestService{
 | 
					const mediapipe::GraphService<std::string> kStringTestService{
 | 
				
			||||||
    "mediapipe::StringTestService"};
 | 
					    "mediapipe::StringTestService"};
 | 
				
			||||||
class GraphServicesClientTestSubgraph : public Subgraph {
 | 
					class GraphServicesClientTestSubgraph : public Subgraph {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user