Enable defining and using internal executor within a subgraph
PiperOrigin-RevId: 522449982
This commit is contained in:
parent
81a405af1b
commit
38f838513a
|
@ -1670,6 +1670,7 @@ cc_test(
|
||||||
name = "subgraph_test",
|
name = "subgraph_test",
|
||||||
srcs = ["subgraph_test.cc"],
|
srcs = ["subgraph_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":calculator_cc_proto",
|
||||||
":calculator_framework",
|
":calculator_framework",
|
||||||
":graph_service_manager",
|
":graph_service_manager",
|
||||||
":subgraph",
|
":subgraph",
|
||||||
|
@ -1681,6 +1682,7 @@ cc_test(
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"//mediapipe/framework/tool:sink",
|
"//mediapipe/framework/tool:sink",
|
||||||
"//mediapipe/framework/tool/testdata:dub_quad_test_subgraph",
|
"//mediapipe/framework/tool/testdata:dub_quad_test_subgraph",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,8 +16,10 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
|
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/graph_service_manager.h"
|
#include "mediapipe/framework/graph_service_manager.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
@ -188,5 +190,52 @@ TEST_F(SubgraphTest, CheckSubgraphOptionsPassedIn) {
|
||||||
EXPECT_EQ(packet.value().Get<std::string>(), "test");
|
EXPECT_EQ(packet.value().Get<std::string>(), "test");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class SubgraphUsingInternalExecutor : public Subgraph {
|
||||||
|
public:
|
||||||
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
|
mediapipe::SubgraphContext* sc) override {
|
||||||
|
return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
R"pb(
|
||||||
|
output_side_packet: "string"
|
||||||
|
executor {
|
||||||
|
name: "xyz"
|
||||||
|
type: "ThreadPoolExecutor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "ConstantSidePacketCalculator"
|
||||||
|
executor: "xyz"
|
||||||
|
output_side_packet: "PACKET:string"
|
||||||
|
options: {
|
||||||
|
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||||
|
packet { string_value: "passed" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
REGISTER_MEDIAPIPE_GRAPH(SubgraphUsingInternalExecutor);
|
||||||
|
|
||||||
|
TEST(SubgraphExecutorTest, SubgraphCanDefineAndUseInternalExecutor) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
output_side_packet: "str"
|
||||||
|
node {
|
||||||
|
calculator: "SubgraphUsingInternalExecutor"
|
||||||
|
output_side_packet: "str"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
auto packet = graph.GetOutputSidePacket("str");
|
||||||
|
MP_ASSERT_OK(packet);
|
||||||
|
EXPECT_EQ(packet->Get<std::string>(), "passed");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -319,6 +319,9 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
|
||||||
subgraph.status_handler().end(),
|
subgraph.status_handler().end(),
|
||||||
proto_ns::RepeatedPtrFieldBackInserter(
|
proto_ns::RepeatedPtrFieldBackInserter(
|
||||||
config->mutable_status_handler()));
|
config->mutable_status_handler()));
|
||||||
|
std::copy(
|
||||||
|
subgraph.executor().begin(), subgraph.executor().end(),
|
||||||
|
proto_ns::RepeatedPtrFieldBackInserter(config->mutable_executor()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -527,6 +527,64 @@ TEST(SubgraphExpansionTest, ExecutorFieldOfNodeInSubgraphPreserved) {
|
||||||
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A subgraph that defines and uses an internal executor.
|
||||||
|
class NodeWithInternalExecutorSubgraph : public Subgraph {
|
||||||
|
public:
|
||||||
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
|
const SubgraphOptions& options) override {
|
||||||
|
return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "IN:foo"
|
||||||
|
output_stream: "OUT:bar"
|
||||||
|
executor {
|
||||||
|
name: "xyz"
|
||||||
|
type: "ThreadPoolExecutor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "PassThroughCalculator"
|
||||||
|
executor: "xyz"
|
||||||
|
input_stream: "foo"
|
||||||
|
output_stream: "bar"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
REGISTER_MEDIAPIPE_GRAPH(NodeWithInternalExecutorSubgraph);
|
||||||
|
|
||||||
|
TEST(SubgraphExpansionTest, ExecutorCanDefinedAndUsedWithinSubgraph) {
|
||||||
|
CalculatorGraphConfig supergraph =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
node {
|
||||||
|
calculator: "NodeWithInternalExecutorSubgraph"
|
||||||
|
input_stream: "IN:input"
|
||||||
|
output_stream: "OUT:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraphConfig expected_graph =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
executor {
|
||||||
|
name: "xyz"
|
||||||
|
type: "ThreadPoolExecutor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "PassThroughCalculator"
|
||||||
|
name: "nodewithinternalexecutorsubgraph__PassThroughCalculator"
|
||||||
|
input_stream: "input"
|
||||||
|
output_stream: "output"
|
||||||
|
executor: "xyz"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||||
|
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
||||||
|
}
|
||||||
|
|
||||||
const mediapipe::GraphService<std::string> kStringTestService{
|
const mediapipe::GraphService<std::string> kStringTestService{
|
||||||
"mediapipe::StringTestService"};
|
"mediapipe::StringTestService"};
|
||||||
class GraphServicesClientTestSubgraph : public Subgraph {
|
class GraphServicesClientTestSubgraph : public Subgraph {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user