Support proto3 node_option in api2 graph_builder
PiperOrigin-RevId: 523587324
This commit is contained in:
parent
b6a19ea9e8
commit
c9aa24b0e7
|
@ -17,8 +17,11 @@ cc_library(
|
|||
":port",
|
||||
"//mediapipe/framework:calculator_base",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
"//mediapipe/framework/port:any_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/container:btree",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_protobuf//:protobuf",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -35,6 +38,8 @@ cc_test(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/testdata:night_light_calculator_cc_proto",
|
||||
"//mediapipe/framework/testdata:sky_light_calculator_cc_proto",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -1,16 +1,23 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/btree_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "google/protobuf/message_lite.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_base.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/port/any_proto.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
@ -380,20 +387,47 @@ class NodeBase {
|
|||
|
||||
SideDestination<> SideIn(int index) { return SideIn("")[index]; }
|
||||
|
||||
template <typename T>
|
||||
T& GetOptions() {
|
||||
return GetOptions(T::ext);
|
||||
// Get mutable node options of type Options.
|
||||
template <
|
||||
typename OptionsT,
|
||||
typename std::enable_if<std::is_base_of<
|
||||
google::protobuf::MessageLite, OptionsT>::value>::type* = nullptr>
|
||||
OptionsT& GetOptions() {
|
||||
return GetOptionsInternal<OptionsT>(nullptr);
|
||||
}
|
||||
|
||||
// Use this API when the proto extension does not follow the "ext" naming
|
||||
// convention.
|
||||
template <typename E>
|
||||
auto& GetOptions(const E& extension) {
|
||||
options_used_ = true;
|
||||
return *options_.MutableExtension(extension);
|
||||
template <typename ExtensionT>
|
||||
auto& GetOptions(const ExtensionT& ext) {
|
||||
if (!calculator_option_.has_value()) {
|
||||
calculator_option_ = CalculatorOptions();
|
||||
}
|
||||
return *calculator_option_->MutableExtension(ext);
|
||||
}
|
||||
|
||||
protected:
|
||||
// GetOptionsInternal resolutes the overload greedily, which finds the first
|
||||
// match then succeed (template specialization tries all matches, thus could
|
||||
// be ambiguous)
|
||||
template <typename OptionsT>
|
||||
OptionsT& GetOptionsInternal(decltype(&OptionsT::ext) /*unused*/) {
|
||||
return GetOptions(OptionsT::ext);
|
||||
}
|
||||
template <typename OptionsT>
|
||||
OptionsT& GetOptionsInternal(...) {
|
||||
if (node_options_.count(kTypeId<OptionsT>)) {
|
||||
return *static_cast<OptionsT*>(
|
||||
node_options_[kTypeId<OptionsT>].message.get());
|
||||
}
|
||||
auto option = std::make_unique<OptionsT>();
|
||||
OptionsT* option_ptr = option.get();
|
||||
node_options_[kTypeId<OptionsT>] = {
|
||||
std::move(option),
|
||||
[option_ptr](protobuf::Any& any) { return any.PackFrom(*option_ptr); }};
|
||||
return *option_ptr;
|
||||
}
|
||||
|
||||
NodeBase(std::string type) : type_(std::move(type)) {}
|
||||
|
||||
std::string type_;
|
||||
|
@ -401,9 +435,14 @@ class NodeBase {
|
|||
TagIndexMap<SourceBase> out_streams_;
|
||||
TagIndexMap<DestinationBase> in_sides_;
|
||||
TagIndexMap<SourceBase> out_sides_;
|
||||
CalculatorOptions options_;
|
||||
// ideally we'd just check if any extensions are set on options_
|
||||
bool options_used_ = false;
|
||||
std::optional<CalculatorOptions> calculator_option_;
|
||||
// Stores real proto config, and lambda for packing config into Any.
|
||||
// We need the lambda because PackFrom() does not work with MessageLite.
|
||||
struct MessageAndPacker {
|
||||
std::unique_ptr<google::protobuf::MessageLite> message;
|
||||
std::function<bool(protobuf::Any&)> packer;
|
||||
};
|
||||
std::map<TypeId, MessageAndPacker> node_options_;
|
||||
friend class Graph;
|
||||
};
|
||||
|
||||
|
@ -749,8 +788,11 @@ class Graph {
|
|||
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
|
||||
config->add_output_side_packet(TaggedName(loc, endpoint.name_));
|
||||
});
|
||||
if (node.options_used_) {
|
||||
*config->mutable_options() = node.options_;
|
||||
if (node.calculator_option_.has_value()) {
|
||||
*config->mutable_options() = *node.calculator_option_;
|
||||
}
|
||||
for (auto& [type_id, message_and_packer] : node.node_options_) {
|
||||
RET_CHECK(message_and_packer.packer(*config->add_node_options()));
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/testdata/night_light_calculator.pb.h"
|
||||
#include "mediapipe/framework/testdata/sky_light_calculator.pb.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
namespace {
|
||||
|
@ -612,5 +614,134 @@ TEST(BuilderTest, TestSidePacketEqualsNotEqualsOperators) {
|
|||
EXPECT_FALSE(side_input0.Cast<float>() != side_input1.Cast<float>());
|
||||
}
|
||||
|
||||
TEST(GetOptionsTest, AddProto3Options) {
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
|
||||
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
foo.GetOptions<mediapipe::SkyLightCalculatorOptions>();
|
||||
base >> foo.In("BASE");
|
||||
side >> foo.SideIn("SIDE");
|
||||
Stream<AnyType> foo_out = foo.Out("OUT");
|
||||
|
||||
auto& bar = graph.AddNode("Bar");
|
||||
foo_out >> bar.In("IN");
|
||||
Stream<AnyType> bar_out = bar.Out("OUT");
|
||||
|
||||
// Graph outputs.
|
||||
bar_out.SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "IN:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:out"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:__stream_0"
|
||||
node_options {
|
||||
[type.googleapis.com/mediapipe.SkyLightCalculatorOptions] {}
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:__stream_0"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(GetOptionsTest, AddProto2Options) {
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
|
||||
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
foo.GetOptions<mediapipe::NightLightCalculatorOptions>();
|
||||
base >> foo.In("BASE");
|
||||
side >> foo.SideIn("SIDE");
|
||||
Stream<AnyType> foo_out = foo.Out("OUT");
|
||||
|
||||
auto& bar = graph.AddNode("Bar");
|
||||
foo_out >> bar.In("IN");
|
||||
Stream<AnyType> bar_out = bar.Out("OUT");
|
||||
|
||||
// Graph outputs.
|
||||
bar_out.SetName("out") >> graph.Out("OUT");
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "IN:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:out"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:__stream_0"
|
||||
options {
|
||||
[mediapipe.NightLightCalculatorOptions.ext] {}
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:__stream_0"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(GetOptionsTest, AddBothProto23Options) {
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
|
||||
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
foo.GetOptions<mediapipe::SkyLightCalculatorOptions>();
|
||||
foo.GetOptions<mediapipe::NightLightCalculatorOptions>();
|
||||
base >> foo.In("BASE");
|
||||
side >> foo.SideIn("SIDE");
|
||||
Stream<AnyType> foo_out = foo.Out("OUT");
|
||||
|
||||
auto& bar = graph.AddNode("Bar");
|
||||
foo_out >> bar.In("IN");
|
||||
Stream<AnyType> bar_out = bar.Out("OUT");
|
||||
|
||||
// Graph outputs.
|
||||
bar_out.SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "IN:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:out"
|
||||
node {
|
||||
calculator: "Foo"
|
||||
input_stream: "BASE:base"
|
||||
input_side_packet: "SIDE:side"
|
||||
output_stream: "OUT:__stream_0"
|
||||
options {
|
||||
[mediapipe.NightLightCalculatorOptions.ext] {}
|
||||
}
|
||||
node_options {
|
||||
[type.googleapis.com/mediapipe.SkyLightCalculatorOptions] {}
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "Bar"
|
||||
input_stream: "IN:__stream_0"
|
||||
output_stream: "OUT:out"
|
||||
}
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::api2::builder
|
||||
|
|
Loading…
Reference in New Issue
Block a user