Support proto3 node_option in api2 graph_builder
PiperOrigin-RevId: 523587324
This commit is contained in:
parent
b6a19ea9e8
commit
c9aa24b0e7
|
@ -17,8 +17,11 @@ cc_library(
|
||||||
":port",
|
":port",
|
||||||
"//mediapipe/framework:calculator_base",
|
"//mediapipe/framework:calculator_base",
|
||||||
"//mediapipe/framework:calculator_contract",
|
"//mediapipe/framework:calculator_contract",
|
||||||
|
"//mediapipe/framework/port:any_proto",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
"@com_google_absl//absl/container:btree",
|
"@com_google_absl//absl/container:btree",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_protobuf//:protobuf",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,6 +38,8 @@ cc_test(
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/testdata:night_light_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework/testdata:sky_light_calculator_cc_proto",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,16 +1,23 @@
|
||||||
#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
||||||
#define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
#define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/btree_map.h"
|
#include "absl/container/btree_map.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "google/protobuf/message_lite.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/calculator_base.h"
|
#include "mediapipe/framework/calculator_base.h"
|
||||||
#include "mediapipe/framework/calculator_contract.h"
|
#include "mediapipe/framework/calculator_contract.h"
|
||||||
|
#include "mediapipe/framework/port/any_proto.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
@ -380,20 +387,47 @@ class NodeBase {
|
||||||
|
|
||||||
SideDestination<> SideIn(int index) { return SideIn("")[index]; }
|
SideDestination<> SideIn(int index) { return SideIn("")[index]; }
|
||||||
|
|
||||||
template <typename T>
|
// Get mutable node options of type Options.
|
||||||
T& GetOptions() {
|
template <
|
||||||
return GetOptions(T::ext);
|
typename OptionsT,
|
||||||
|
typename std::enable_if<std::is_base_of<
|
||||||
|
google::protobuf::MessageLite, OptionsT>::value>::type* = nullptr>
|
||||||
|
OptionsT& GetOptions() {
|
||||||
|
return GetOptionsInternal<OptionsT>(nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use this API when the proto extension does not follow the "ext" naming
|
// Use this API when the proto extension does not follow the "ext" naming
|
||||||
// convention.
|
// convention.
|
||||||
template <typename E>
|
template <typename ExtensionT>
|
||||||
auto& GetOptions(const E& extension) {
|
auto& GetOptions(const ExtensionT& ext) {
|
||||||
options_used_ = true;
|
if (!calculator_option_.has_value()) {
|
||||||
return *options_.MutableExtension(extension);
|
calculator_option_ = CalculatorOptions();
|
||||||
|
}
|
||||||
|
return *calculator_option_->MutableExtension(ext);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
// GetOptionsInternal resolutes the overload greedily, which finds the first
|
||||||
|
// match then succeed (template specialization tries all matches, thus could
|
||||||
|
// be ambiguous)
|
||||||
|
template <typename OptionsT>
|
||||||
|
OptionsT& GetOptionsInternal(decltype(&OptionsT::ext) /*unused*/) {
|
||||||
|
return GetOptions(OptionsT::ext);
|
||||||
|
}
|
||||||
|
template <typename OptionsT>
|
||||||
|
OptionsT& GetOptionsInternal(...) {
|
||||||
|
if (node_options_.count(kTypeId<OptionsT>)) {
|
||||||
|
return *static_cast<OptionsT*>(
|
||||||
|
node_options_[kTypeId<OptionsT>].message.get());
|
||||||
|
}
|
||||||
|
auto option = std::make_unique<OptionsT>();
|
||||||
|
OptionsT* option_ptr = option.get();
|
||||||
|
node_options_[kTypeId<OptionsT>] = {
|
||||||
|
std::move(option),
|
||||||
|
[option_ptr](protobuf::Any& any) { return any.PackFrom(*option_ptr); }};
|
||||||
|
return *option_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
NodeBase(std::string type) : type_(std::move(type)) {}
|
NodeBase(std::string type) : type_(std::move(type)) {}
|
||||||
|
|
||||||
std::string type_;
|
std::string type_;
|
||||||
|
@ -401,9 +435,14 @@ class NodeBase {
|
||||||
TagIndexMap<SourceBase> out_streams_;
|
TagIndexMap<SourceBase> out_streams_;
|
||||||
TagIndexMap<DestinationBase> in_sides_;
|
TagIndexMap<DestinationBase> in_sides_;
|
||||||
TagIndexMap<SourceBase> out_sides_;
|
TagIndexMap<SourceBase> out_sides_;
|
||||||
CalculatorOptions options_;
|
std::optional<CalculatorOptions> calculator_option_;
|
||||||
// ideally we'd just check if any extensions are set on options_
|
// Stores real proto config, and lambda for packing config into Any.
|
||||||
bool options_used_ = false;
|
// We need the lambda because PackFrom() does not work with MessageLite.
|
||||||
|
struct MessageAndPacker {
|
||||||
|
std::unique_ptr<google::protobuf::MessageLite> message;
|
||||||
|
std::function<bool(protobuf::Any&)> packer;
|
||||||
|
};
|
||||||
|
std::map<TypeId, MessageAndPacker> node_options_;
|
||||||
friend class Graph;
|
friend class Graph;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -749,8 +788,11 @@ class Graph {
|
||||||
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
|
[&](const TagIndexLocation& loc, const SourceBase& endpoint) {
|
||||||
config->add_output_side_packet(TaggedName(loc, endpoint.name_));
|
config->add_output_side_packet(TaggedName(loc, endpoint.name_));
|
||||||
});
|
});
|
||||||
if (node.options_used_) {
|
if (node.calculator_option_.has_value()) {
|
||||||
*config->mutable_options() = node.options_;
|
*config->mutable_options() = *node.calculator_option_;
|
||||||
|
}
|
||||||
|
for (auto& [type_id, message_and_packer] : node.node_options_) {
|
||||||
|
RET_CHECK(message_and_packer.packer(*config->add_node_options()));
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/testdata/night_light_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/testdata/sky_light_calculator.pb.h"
|
||||||
|
|
||||||
namespace mediapipe::api2::builder {
|
namespace mediapipe::api2::builder {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -612,5 +614,134 @@ TEST(BuilderTest, TestSidePacketEqualsNotEqualsOperators) {
|
||||||
EXPECT_FALSE(side_input0.Cast<float>() != side_input1.Cast<float>());
|
EXPECT_FALSE(side_input0.Cast<float>() != side_input1.Cast<float>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(GetOptionsTest, AddProto3Options) {
|
||||||
|
Graph graph;
|
||||||
|
// Graph inputs.
|
||||||
|
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||||
|
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
|
||||||
|
|
||||||
|
auto& foo = graph.AddNode("Foo");
|
||||||
|
foo.GetOptions<mediapipe::SkyLightCalculatorOptions>();
|
||||||
|
base >> foo.In("BASE");
|
||||||
|
side >> foo.SideIn("SIDE");
|
||||||
|
Stream<AnyType> foo_out = foo.Out("OUT");
|
||||||
|
|
||||||
|
auto& bar = graph.AddNode("Bar");
|
||||||
|
foo_out >> bar.In("IN");
|
||||||
|
Stream<AnyType> bar_out = bar.Out("OUT");
|
||||||
|
|
||||||
|
// Graph outputs.
|
||||||
|
bar_out.SetName("out") >> graph.Out("OUT");
|
||||||
|
|
||||||
|
CalculatorGraphConfig expected =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "IN:base"
|
||||||
|
input_side_packet: "SIDE:side"
|
||||||
|
output_stream: "OUT:out"
|
||||||
|
node {
|
||||||
|
calculator: "Foo"
|
||||||
|
input_stream: "BASE:base"
|
||||||
|
input_side_packet: "SIDE:side"
|
||||||
|
output_stream: "OUT:__stream_0"
|
||||||
|
node_options {
|
||||||
|
[type.googleapis.com/mediapipe.SkyLightCalculatorOptions] {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "Bar"
|
||||||
|
input_stream: "IN:__stream_0"
|
||||||
|
output_stream: "OUT:out"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GetOptionsTest, AddProto2Options) {
|
||||||
|
Graph graph;
|
||||||
|
// Graph inputs.
|
||||||
|
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||||
|
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
|
||||||
|
|
||||||
|
auto& foo = graph.AddNode("Foo");
|
||||||
|
foo.GetOptions<mediapipe::NightLightCalculatorOptions>();
|
||||||
|
base >> foo.In("BASE");
|
||||||
|
side >> foo.SideIn("SIDE");
|
||||||
|
Stream<AnyType> foo_out = foo.Out("OUT");
|
||||||
|
|
||||||
|
auto& bar = graph.AddNode("Bar");
|
||||||
|
foo_out >> bar.In("IN");
|
||||||
|
Stream<AnyType> bar_out = bar.Out("OUT");
|
||||||
|
|
||||||
|
// Graph outputs.
|
||||||
|
bar_out.SetName("out") >> graph.Out("OUT");
|
||||||
|
CalculatorGraphConfig expected =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "IN:base"
|
||||||
|
input_side_packet: "SIDE:side"
|
||||||
|
output_stream: "OUT:out"
|
||||||
|
node {
|
||||||
|
calculator: "Foo"
|
||||||
|
input_stream: "BASE:base"
|
||||||
|
input_side_packet: "SIDE:side"
|
||||||
|
output_stream: "OUT:__stream_0"
|
||||||
|
options {
|
||||||
|
[mediapipe.NightLightCalculatorOptions.ext] {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "Bar"
|
||||||
|
input_stream: "IN:__stream_0"
|
||||||
|
output_stream: "OUT:out"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GetOptionsTest, AddBothProto23Options) {
|
||||||
|
Graph graph;
|
||||||
|
// Graph inputs.
|
||||||
|
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||||
|
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
|
||||||
|
|
||||||
|
auto& foo = graph.AddNode("Foo");
|
||||||
|
foo.GetOptions<mediapipe::SkyLightCalculatorOptions>();
|
||||||
|
foo.GetOptions<mediapipe::NightLightCalculatorOptions>();
|
||||||
|
base >> foo.In("BASE");
|
||||||
|
side >> foo.SideIn("SIDE");
|
||||||
|
Stream<AnyType> foo_out = foo.Out("OUT");
|
||||||
|
|
||||||
|
auto& bar = graph.AddNode("Bar");
|
||||||
|
foo_out >> bar.In("IN");
|
||||||
|
Stream<AnyType> bar_out = bar.Out("OUT");
|
||||||
|
|
||||||
|
// Graph outputs.
|
||||||
|
bar_out.SetName("out") >> graph.Out("OUT");
|
||||||
|
|
||||||
|
CalculatorGraphConfig expected =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "IN:base"
|
||||||
|
input_side_packet: "SIDE:side"
|
||||||
|
output_stream: "OUT:out"
|
||||||
|
node {
|
||||||
|
calculator: "Foo"
|
||||||
|
input_stream: "BASE:base"
|
||||||
|
input_side_packet: "SIDE:side"
|
||||||
|
output_stream: "OUT:__stream_0"
|
||||||
|
options {
|
||||||
|
[mediapipe.NightLightCalculatorOptions.ext] {}
|
||||||
|
}
|
||||||
|
node_options {
|
||||||
|
[type.googleapis.com/mediapipe.SkyLightCalculatorOptions] {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "Bar"
|
||||||
|
input_stream: "IN:__stream_0"
|
||||||
|
output_stream: "OUT:out"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe::api2::builder
|
} // namespace mediapipe::api2::builder
|
||||||
|
|
Loading…
Reference in New Issue
Block a user