Support proto3 node_option in api2 graph_builder

PiperOrigin-RevId: 523587324
This commit is contained in:
MediaPipe Team 2023-04-11 20:47:03 -07:00 committed by Copybara-Service
parent b6a19ea9e8
commit c9aa24b0e7
3 changed files with 190 additions and 12 deletions

View File

@ -17,8 +17,11 @@ cc_library(
":port", ":port",
"//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_contract",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:btree",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf",
], ],
) )
@ -35,6 +38,8 @@ cc_test(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/testdata:night_light_calculator_cc_proto",
"//mediapipe/framework/testdata:sky_light_calculator_cc_proto",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -1,16 +1,23 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_ #ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
#define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_ #define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
#include <functional>
#include <map>
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility>
#include <vector> #include <vector>
#include "absl/container/btree_map.h" #include "absl/container/btree_map.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "google/protobuf/message_lite.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/port/any_proto.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -380,20 +387,47 @@ class NodeBase {
SideDestination<> SideIn(int index) { return SideIn("")[index]; } SideDestination<> SideIn(int index) { return SideIn("")[index]; }
template <typename T> // Get mutable node options of type Options.
T& GetOptions() { template <
return GetOptions(T::ext); typename OptionsT,
typename std::enable_if<std::is_base_of<
google::protobuf::MessageLite, OptionsT>::value>::type* = nullptr>
OptionsT& GetOptions() {
return GetOptionsInternal<OptionsT>(nullptr);
} }
// Use this API when the proto extension does not follow the "ext" naming // Use this API when the proto extension does not follow the "ext" naming
// convention. // convention.
template <typename E> template <typename ExtensionT>
auto& GetOptions(const E& extension) { auto& GetOptions(const ExtensionT& ext) {
options_used_ = true; if (!calculator_option_.has_value()) {
return *options_.MutableExtension(extension); calculator_option_ = CalculatorOptions();
}
return *calculator_option_->MutableExtension(ext);
} }
protected: protected:
// GetOptionsInternal resolutes the overload greedily, which finds the first
// match then succeed (template specialization tries all matches, thus could
// be ambiguous)
template <typename OptionsT>
OptionsT& GetOptionsInternal(decltype(&OptionsT::ext) /*unused*/) {
return GetOptions(OptionsT::ext);
}
template <typename OptionsT>
OptionsT& GetOptionsInternal(...) {
if (node_options_.count(kTypeId<OptionsT>)) {
return *static_cast<OptionsT*>(
node_options_[kTypeId<OptionsT>].message.get());
}
auto option = std::make_unique<OptionsT>();
OptionsT* option_ptr = option.get();
node_options_[kTypeId<OptionsT>] = {
std::move(option),
[option_ptr](protobuf::Any& any) { return any.PackFrom(*option_ptr); }};
return *option_ptr;
}
NodeBase(std::string type) : type_(std::move(type)) {} NodeBase(std::string type) : type_(std::move(type)) {}
std::string type_; std::string type_;
@ -401,9 +435,14 @@ class NodeBase {
TagIndexMap<SourceBase> out_streams_; TagIndexMap<SourceBase> out_streams_;
TagIndexMap<DestinationBase> in_sides_; TagIndexMap<DestinationBase> in_sides_;
TagIndexMap<SourceBase> out_sides_; TagIndexMap<SourceBase> out_sides_;
CalculatorOptions options_; std::optional<CalculatorOptions> calculator_option_;
// ideally we'd just check if any extensions are set on options_ // Stores real proto config, and lambda for packing config into Any.
bool options_used_ = false; // We need the lambda because PackFrom() does not work with MessageLite.
struct MessageAndPacker {
std::unique_ptr<google::protobuf::MessageLite> message;
std::function<bool(protobuf::Any&)> packer;
};
std::map<TypeId, MessageAndPacker> node_options_;
friend class Graph; friend class Graph;
}; };
@ -749,8 +788,11 @@ class Graph {
[&](const TagIndexLocation& loc, const SourceBase& endpoint) { [&](const TagIndexLocation& loc, const SourceBase& endpoint) {
config->add_output_side_packet(TaggedName(loc, endpoint.name_)); config->add_output_side_packet(TaggedName(loc, endpoint.name_));
}); });
if (node.options_used_) { if (node.calculator_option_.has_value()) {
*config->mutable_options() = node.options_; *config->mutable_options() = *node.calculator_option_;
}
for (auto& [type_id, message_and_packer] : node.node_options_) {
RET_CHECK(message_and_packer.packer(*config->add_node_options()));
} }
return {}; return {};
} }

View File

@ -14,6 +14,8 @@
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/testdata/night_light_calculator.pb.h"
#include "mediapipe/framework/testdata/sky_light_calculator.pb.h"
namespace mediapipe::api2::builder { namespace mediapipe::api2::builder {
namespace { namespace {
@ -612,5 +614,134 @@ TEST(BuilderTest, TestSidePacketEqualsNotEqualsOperators) {
EXPECT_FALSE(side_input0.Cast<float>() != side_input1.Cast<float>()); EXPECT_FALSE(side_input0.Cast<float>() != side_input1.Cast<float>());
} }
TEST(GetOptionsTest, AddProto3Options) {
Graph graph;
// Graph inputs.
Stream<AnyType> base = graph.In("IN").SetName("base");
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
auto& foo = graph.AddNode("Foo");
foo.GetOptions<mediapipe::SkyLightCalculatorOptions>();
base >> foo.In("BASE");
side >> foo.SideIn("SIDE");
Stream<AnyType> foo_out = foo.Out("OUT");
auto& bar = graph.AddNode("Bar");
foo_out >> bar.In("IN");
Stream<AnyType> bar_out = bar.Out("OUT");
// Graph outputs.
bar_out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "IN:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:out"
node {
calculator: "Foo"
input_stream: "BASE:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:__stream_0"
node_options {
[type.googleapis.com/mediapipe.SkyLightCalculatorOptions] {}
}
}
node {
calculator: "Bar"
input_stream: "IN:__stream_0"
output_stream: "OUT:out"
}
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(GetOptionsTest, AddProto2Options) {
Graph graph;
// Graph inputs.
Stream<AnyType> base = graph.In("IN").SetName("base");
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
auto& foo = graph.AddNode("Foo");
foo.GetOptions<mediapipe::NightLightCalculatorOptions>();
base >> foo.In("BASE");
side >> foo.SideIn("SIDE");
Stream<AnyType> foo_out = foo.Out("OUT");
auto& bar = graph.AddNode("Bar");
foo_out >> bar.In("IN");
Stream<AnyType> bar_out = bar.Out("OUT");
// Graph outputs.
bar_out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "IN:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:out"
node {
calculator: "Foo"
input_stream: "BASE:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:__stream_0"
options {
[mediapipe.NightLightCalculatorOptions.ext] {}
}
}
node {
calculator: "Bar"
input_stream: "IN:__stream_0"
output_stream: "OUT:out"
}
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(GetOptionsTest, AddBothProto23Options) {
Graph graph;
// Graph inputs.
Stream<AnyType> base = graph.In("IN").SetName("base");
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
auto& foo = graph.AddNode("Foo");
foo.GetOptions<mediapipe::SkyLightCalculatorOptions>();
foo.GetOptions<mediapipe::NightLightCalculatorOptions>();
base >> foo.In("BASE");
side >> foo.SideIn("SIDE");
Stream<AnyType> foo_out = foo.Out("OUT");
auto& bar = graph.AddNode("Bar");
foo_out >> bar.In("IN");
Stream<AnyType> bar_out = bar.Out("OUT");
// Graph outputs.
bar_out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "IN:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:out"
node {
calculator: "Foo"
input_stream: "BASE:base"
input_side_packet: "SIDE:side"
output_stream: "OUT:__stream_0"
options {
[mediapipe.NightLightCalculatorOptions.ext] {}
}
node_options {
[type.googleapis.com/mediapipe.SkyLightCalculatorOptions] {}
}
}
node {
calculator: "Bar"
input_stream: "IN:__stream_0"
output_stream: "OUT:out"
}
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
} // namespace } // namespace
} // namespace mediapipe::api2::builder } // namespace mediapipe::api2::builder