diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index b13dba9b9..8899c89fc 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -427,7 +427,6 @@ cc_library( ":tag_map", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:graph_service_manager", - "//mediapipe/framework:packet_generator", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:port", "//mediapipe/framework:status_handler_cc_proto", @@ -437,8 +436,12 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -782,6 +785,7 @@ cc_test( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph", "//mediapipe/framework/tool/testdata:nested_test_subgraph", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/tool/subgraph_expansion.cc b/mediapipe/framework/tool/subgraph_expansion.cc index dcd055f59..a05aef894 100644 --- a/mediapipe/framework/tool/subgraph_expansion.cc +++ b/mediapipe/framework/tool/subgraph_expansion.cc @@ -23,8 +23,13 @@ #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/port.h" @@ -123,6 +128,19 @@ absl::Status TransformNames( MP_RETURN_IF_ERROR(TransformStreamNames( status_handler.mutable_input_side_packet(), transform)); } + // Prefix executor names, but only those defined in the current graph. + absl::flat_hash_set local_executor_names; + for (auto& executor : *config->mutable_executor()) { + if (!executor.name().empty()) { + local_executor_names.insert(executor.name()); + *executor.mutable_name() = transform(executor.name()); + } + } + for (auto& node : *config->mutable_node()) { + if (local_executor_names.contains(node.executor())) { + *node.mutable_executor() = transform(node.executor()); + } + } return absl::OkStatus(); } @@ -273,6 +291,41 @@ absl::Status ConnectSubgraphStreams( return absl::OkStatus(); } +absl::Status RemoveDuplicateExecutors( + const absl::flat_hash_set& seen_executors, + CalculatorGraphConfig* config) { + auto* mutable_executors = config->mutable_executor(); + auto unique_executors_it = std::remove_if( + mutable_executors->begin(), mutable_executors->end(), + [&seen_executors](const mediapipe::ExecutorConfig& executor_config) { + bool is_duplicate = seen_executors.contains(executor_config.name()); + // This can happen in the following situation: you define an + // executor at the top-level-graph and one or more of your + // subgraphs declare executors with the same name as well. + // + // Historically, executors defined in subgraphs were ignored + // (unless you use your subgraph as a top-level-graph). + // + // Now executors can be defined in subgraphs (their names are + // automatically updated to be prefixed with subgraph name). To be + // backward compatible, MediaPipe will ignore (remove) executors + // defined in subgraphs if they have the same names as one of + // top-level-graph defined executors. + // + // NOTE: If you see this warning, you may want to verify if you + // actually use the same executors and consider removing one or + // another. + if (is_duplicate) { + ABSL_LOG(WARNING) << absl::StrFormat( + "Removing a duplicate of top-level-graph executor: %s", + executor_config.name()); + } + return is_duplicate; + }); + mutable_executors->erase(unique_executors_it, mutable_executors->end()); + return absl::OkStatus(); +} + absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, const GraphRegistry* graph_registry, const Subgraph::SubgraphOptions* graph_options, @@ -283,6 +336,12 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions( graph_options ? *graph_options : CalculatorGraphConfig::Node(), config)); + + absl::flat_hash_set seen_executors; + for (int i = 0; i < config->executor_size(); ++i) { + seen_executors.insert(config->executor(i).name()); + } + auto* nodes = config->mutable_node(); while (1) { auto subgraph_nodes_start = std::stable_partition( @@ -303,6 +362,7 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, config->package(), node.calculator(), &subgraph_context)); MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(node, &subgraph)); + MP_RETURN_IF_ERROR(RemoveDuplicateExecutors(seen_executors, &subgraph)); MP_RETURN_IF_ERROR(PrefixNames(node_name, &subgraph)); MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph)); subgraphs.push_back(subgraph); @@ -319,6 +379,9 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, subgraph.status_handler().end(), proto_ns::RepeatedPtrFieldBackInserter( config->mutable_status_handler())); + std::copy( + subgraph.executor().begin(), subgraph.executor().end(), + proto_ns::RepeatedPtrFieldBackInserter(config->mutable_executor())); } } return absl::OkStatus(); diff --git a/mediapipe/framework/tool/subgraph_expansion_test.cc b/mediapipe/framework/tool/subgraph_expansion_test.cc index b6d9950a1..f6988c56a 100644 --- a/mediapipe/framework/tool/subgraph_expansion_test.cc +++ b/mediapipe/framework/tool/subgraph_expansion_test.cc @@ -15,6 +15,7 @@ #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -771,9 +772,7 @@ class InternalExecutorSubgraph : public Subgraph { }; REGISTER_MEDIAPIPE_GRAPH(InternalExecutorSubgraph); -// This test confirms that none of existing subgraphs can actually create an -// executor when used as subgraphs and not like a final graph. -TEST(SubgraphExpansionTest, SubgraphExecutorIsIgnored) { +TEST(SubgraphExpansionTest, SubgraphExecutorWorks) { CalculatorGraphConfig supergraph = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: "input" @@ -785,23 +784,27 @@ TEST(SubgraphExpansionTest, SubgraphExecutorIsIgnored) { )pb"); CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie(R"pb( - input_stream: "input" node { name: "internalexecutorsubgraph__PassThroughCalculator" calculator: "PassThroughCalculator" input_stream: "input" output_stream: "output" - executor: "xyz" + executor: "internalexecutorsubgraph__xyz" + } + input_stream: "input" + executor { + name: "internalexecutorsubgraph__xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } } )pb"); MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); CalculatorGraph calculator_graph; - EXPECT_THAT(calculator_graph.Initialize(supergraph), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("The executor \"xyz\" is " - "not declared in an ExecutorConfig."))); + MP_EXPECT_OK(calculator_graph.Initialize(supergraph)); } class NestedInternalExecutorsSubgraph : public Subgraph { @@ -847,7 +850,7 @@ class NestedInternalExecutorsSubgraph : public Subgraph { }; REGISTER_MEDIAPIPE_GRAPH(NestedInternalExecutorsSubgraph); -TEST(SubgraphExpansionTest, NestedSubgraphExecutorsAreIgnored) { +TEST(SubgraphExpansionTest, NestedSubgraphExecutorsWork) { CalculatorGraphConfig supergraph = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: "input" @@ -864,35 +867,55 @@ TEST(SubgraphExpansionTest, NestedSubgraphExecutorsAreIgnored) { calculator: "PassThroughCalculator" input_stream: "nestedinternalexecutorssubgraph__bar_0" output_stream: "nestedinternalexecutorssubgraph__bar_1" - executor: "xyz" + executor: "nestedinternalexecutorssubgraph__xyz" } node { name: "nestedinternalexecutorssubgraph__PassThroughCalculator_2" calculator: "PassThroughCalculator" input_stream: "nestedinternalexecutorssubgraph__bar_1" output_stream: "output" - executor: "abc" + executor: "nestedinternalexecutorssubgraph__abc" } node { name: "nestedinternalexecutorssubgraph__internalexecutorsubgraph__PassThroughCalculator" calculator: "PassThroughCalculator" input_stream: "input" output_stream: "nestedinternalexecutorssubgraph__bar_0" - executor: "xyz" + executor: "nestedinternalexecutorssubgraph__internalexecutorsubgraph__xyz" } input_stream: "input" + executor { + name: "nestedinternalexecutorssubgraph__xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + executor { + name: "nestedinternalexecutorssubgraph__abc" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } + executor { + name: "nestedinternalexecutorssubgraph__internalexecutorsubgraph__xyz" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } + } + } )pb"); MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); CalculatorGraph calculator_graph; - EXPECT_THAT(calculator_graph.Initialize(supergraph), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("The executor \"xyz\" is " - "not declared in an ExecutorConfig."))); + MP_EXPECT_OK(calculator_graph.Initialize(supergraph)); } -TEST(SubgraphExpansionTest, GraphExecutorsSubstituteSubgraphExecutors) { +// For backward compatibility. +TEST(SubgraphExpansionTest, + TopLevelGraphExecutorsCauseSameNamedSubgraphExecutorsToBeRemoved) { CalculatorGraphConfig supergraph = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: "input"