Add functions for checking the existence of options in subgraphs and calculators.
PiperOrigin-RevId: 513689742
This commit is contained in:
parent
91d53cd181
commit
3837c92fd5
|
@ -1677,6 +1677,7 @@ cc_test(
|
||||||
":subgraph",
|
":subgraph",
|
||||||
":test_calculators",
|
":test_calculators",
|
||||||
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||||
|
"//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/core:pass_through_calculator",
|
"//mediapipe/calculators/core:pass_through_calculator",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
|
|
@ -69,6 +69,11 @@ class CalculatorContext {
|
||||||
return calculator_state_->Options<T>();
|
return calculator_state_->Options<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool HasOptions() const {
|
||||||
|
return calculator_state_->HasOptions<T>();
|
||||||
|
}
|
||||||
|
|
||||||
// Returns a counter using the graph's counter factory. The counter's name is
|
// Returns a counter using the graph's counter factory. The counter's name is
|
||||||
// the passed-in name, prefixed by the calculator node's name (if present) or
|
// the passed-in name, prefixed by the calculator node's name (if present) or
|
||||||
// the calculator's type (if not).
|
// the calculator's type (if not).
|
||||||
|
|
|
@ -131,6 +131,13 @@ TEST(CalculatorTest, GetOptions) {
|
||||||
auto calculator_state_3 = MakeCalculatorState(config.node(3), 3);
|
auto calculator_state_3 = MakeCalculatorState(config.node(3), 3);
|
||||||
auto cc_3 = MakeCalculatorContext(&*calculator_state_3);
|
auto cc_3 = MakeCalculatorContext(&*calculator_state_3);
|
||||||
|
|
||||||
|
EXPECT_TRUE(cc_0->HasOptions<NightLightCalculatorOptions>());
|
||||||
|
EXPECT_FALSE(cc_0->HasOptions<SkyLightCalculatorOptions>());
|
||||||
|
EXPECT_TRUE(cc_1->HasOptions<NightLightCalculatorOptions>());
|
||||||
|
EXPECT_FALSE(cc_1->HasOptions<SkyLightCalculatorOptions>());
|
||||||
|
EXPECT_FALSE(cc_3->HasOptions<NightLightCalculatorOptions>());
|
||||||
|
EXPECT_TRUE(cc_3->HasOptions<SkyLightCalculatorOptions>());
|
||||||
|
|
||||||
// Get a google::protobuf options extension from Node::options.
|
// Get a google::protobuf options extension from Node::options.
|
||||||
EXPECT_EQ(cc_0->Options<NightLightCalculatorOptions>().jitter(), 0.123);
|
EXPECT_EQ(cc_0->Options<NightLightCalculatorOptions>().jitter(), 0.123);
|
||||||
|
|
||||||
|
|
|
@ -69,6 +69,11 @@ class CalculatorContract {
|
||||||
return options_.Get<T>();
|
return options_.Get<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool HasOptions() const {
|
||||||
|
return options_.Has<T>();
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the PacketTypeSet for the input streams.
|
// Returns the PacketTypeSet for the input streams.
|
||||||
PacketTypeSet& Inputs() { return *inputs_; }
|
PacketTypeSet& Inputs() { return *inputs_; }
|
||||||
const PacketTypeSet& Inputs() const { return *inputs_; }
|
const PacketTypeSet& Inputs() const { return *inputs_; }
|
||||||
|
|
|
@ -41,6 +41,7 @@ TEST(CalculatorContractTest, Calculator) {
|
||||||
)pb");
|
)pb");
|
||||||
CalculatorContract contract;
|
CalculatorContract contract;
|
||||||
MP_EXPECT_OK(contract.Initialize(node));
|
MP_EXPECT_OK(contract.Initialize(node));
|
||||||
|
EXPECT_FALSE(contract.HasOptions<CalculatorContractTestOptions>());
|
||||||
EXPECT_EQ(contract.Inputs().NumEntries(), 4);
|
EXPECT_EQ(contract.Inputs().NumEntries(), 4);
|
||||||
EXPECT_EQ(contract.Outputs().NumEntries(), 1);
|
EXPECT_EQ(contract.Outputs().NumEntries(), 1);
|
||||||
EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1);
|
EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1);
|
||||||
|
@ -60,6 +61,7 @@ TEST(CalculatorContractTest, CalculatorOptions) {
|
||||||
})pb");
|
})pb");
|
||||||
CalculatorContract contract;
|
CalculatorContract contract;
|
||||||
MP_EXPECT_OK(contract.Initialize(node));
|
MP_EXPECT_OK(contract.Initialize(node));
|
||||||
|
ASSERT_TRUE(contract.HasOptions<CalculatorContractTestOptions>());
|
||||||
const auto& test_options =
|
const auto& test_options =
|
||||||
contract.Options().GetExtension(CalculatorContractTestOptions::ext);
|
contract.Options().GetExtension(CalculatorContractTestOptions::ext);
|
||||||
EXPECT_EQ(test_options.test_field(), 1.0);
|
EXPECT_EQ(test_options.test_field(), 1.0);
|
||||||
|
|
|
@ -66,6 +66,10 @@ class CalculatorState {
|
||||||
const T& Options() const {
|
const T& Options() const {
|
||||||
return options_.Get<T>();
|
return options_.Get<T>();
|
||||||
}
|
}
|
||||||
|
template <class T>
|
||||||
|
bool HasOptions() const {
|
||||||
|
return options_.Has<T>();
|
||||||
|
}
|
||||||
const std::string& NodeName() const { return node_name_; }
|
const std::string& NodeName() const { return node_name_; }
|
||||||
const int& NodeId() const { return node_id_; }
|
const int& NodeId() const { return node_id_; }
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,11 @@ class SubgraphContext {
|
||||||
return options_map_.GetMutable<T>();
|
return options_map_.GetMutable<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool HasOptions() {
|
||||||
|
return options_map_.Has<T>();
|
||||||
|
}
|
||||||
|
|
||||||
const CalculatorGraphConfig::Node& OriginalNode() const {
|
const CalculatorGraphConfig::Node& OriginalNode() const {
|
||||||
return original_node_;
|
return original_node_;
|
||||||
}
|
}
|
||||||
|
@ -119,6 +124,11 @@ class Subgraph {
|
||||||
return tool::OptionsMap().Initialize(supgraph_options).Get<T>();
|
return tool::OptionsMap().Initialize(supgraph_options).Get<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static bool HasOptions(const Subgraph::SubgraphOptions& supgraph_options) {
|
||||||
|
return tool::OptionsMap().Initialize(supgraph_options).Has<T>();
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the CalculatorGraphConfig::Node specifying the subgraph.
|
// Returns the CalculatorGraphConfig::Node specifying the subgraph.
|
||||||
// This provides to Subgraphs the same graph information that GetContract
|
// This provides to Subgraphs the same graph information that GetContract
|
||||||
// provides to Calculators.
|
// provides to Calculators.
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/graph_service_manager.h"
|
#include "mediapipe/framework/graph_service_manager.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
@ -133,5 +134,59 @@ TEST(SubgraphServicesTest, EmitStringFromTestService) {
|
||||||
EXPECT_EQ(side_string.Get<std::string>(), "Expected STRING");
|
EXPECT_EQ(side_string.Get<std::string>(), "Expected STRING");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class OptionsCheckingSubgraph : public Subgraph {
|
||||||
|
public:
|
||||||
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
|
mediapipe::SubgraphContext* sc) override {
|
||||||
|
std::string subgraph_side_packet_val;
|
||||||
|
if (sc->HasOptions<ConstantSidePacketCalculatorOptions>()) {
|
||||||
|
subgraph_side_packet_val =
|
||||||
|
sc->Options<ConstantSidePacketCalculatorOptions>()
|
||||||
|
.packet(0)
|
||||||
|
.string_value();
|
||||||
|
}
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
absl::StrFormat(R"(
|
||||||
|
output_side_packet: "string"
|
||||||
|
node {
|
||||||
|
calculator: "ConstantSidePacketCalculator"
|
||||||
|
output_side_packet: "PACKET:string"
|
||||||
|
options: {
|
||||||
|
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||||
|
packet { string_value: "%s" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)",
|
||||||
|
subgraph_side_packet_val));
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
REGISTER_MEDIAPIPE_GRAPH(OptionsCheckingSubgraph);
|
||||||
|
|
||||||
|
TEST_F(SubgraphTest, CheckSubgraphOptionsPassedIn) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
output_side_packet: "str"
|
||||||
|
node {
|
||||||
|
calculator: "OptionsCheckingSubgraph"
|
||||||
|
output_side_packet: "str"
|
||||||
|
options: {
|
||||||
|
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||||
|
packet { string_value: "test" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
auto packet = graph.GetOutputSidePacket("str");
|
||||||
|
MP_ASSERT_OK(packet);
|
||||||
|
EXPECT_EQ(packet.value().Get<std::string>(), "test");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -181,6 +181,20 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "options_map_test",
|
||||||
|
srcs = ["options_map_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":options_map",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/framework/testdata:night_light_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework/testdata:night_light_calculator_options_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "field_data_proto",
|
name = "field_data_proto",
|
||||||
srcs = ["field_data.proto"],
|
srcs = ["field_data.proto"],
|
||||||
|
|
|
@ -28,6 +28,18 @@ struct IsExtension {
|
||||||
static constexpr bool value = (sizeof(test<T>(0)) == sizeof(char));
|
static constexpr bool value = (sizeof(test<T>(0)) == sizeof(char));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class T,
|
||||||
|
typename std::enable_if<IsExtension<T>::value, int>::type = 0>
|
||||||
|
bool HasExtension(const CalculatorOptions& options) {
|
||||||
|
return options.HasExtension(T::ext);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T,
|
||||||
|
typename std::enable_if<!IsExtension<T>::value, int>::type = 0>
|
||||||
|
bool HasExtension(const CalculatorOptions& options) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
template <class T,
|
template <class T,
|
||||||
typename std::enable_if<IsExtension<T>::value, int>::type = 0>
|
typename std::enable_if<IsExtension<T>::value, int>::type = 0>
|
||||||
T* GetExtension(CalculatorOptions& options) {
|
T* GetExtension(CalculatorOptions& options) {
|
||||||
|
@ -124,6 +136,27 @@ class OptionsMap {
|
||||||
return *result;
|
return *result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool Has() const {
|
||||||
|
if (options_.Has<T>()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (node_config_->has_options()) {
|
||||||
|
return HasExtension<T>(node_config_->options());
|
||||||
|
}
|
||||||
|
#if defined(MEDIAPIPE_PROTO_LITE) && defined(MEDIAPIPE_PROTO_THIRD_PARTY)
|
||||||
|
// protobuf::Any is unavailable with third_party/protobuf:protobuf-lite.
|
||||||
|
#else
|
||||||
|
for (const mediapipe::protobuf::Any& options :
|
||||||
|
node_config_->node_options()) {
|
||||||
|
if (options.Is<T>()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
CalculatorGraphConfig::Node* node_config_;
|
CalculatorGraphConfig::Node* node_config_;
|
||||||
TypeMap options_;
|
TypeMap options_;
|
||||||
};
|
};
|
||||||
|
|
88
mediapipe/framework/tool/options_map_test.cc
Normal file
88
mediapipe/framework/tool/options_map_test.cc
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/framework/tool/options_map.h"
|
||||||
|
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
#include "mediapipe/framework/testdata/night_light_calculator.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tool {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(OptionsMapTest, QueryNotFound) {
|
||||||
|
CalculatorGraphConfig::Node node =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "NightLightCalculator"
|
||||||
|
input_side_packet: "input_value"
|
||||||
|
output_stream: "values"
|
||||||
|
)pb");
|
||||||
|
OptionsMap options;
|
||||||
|
options.Initialize(node);
|
||||||
|
EXPECT_FALSE(options.Has<mediapipe::NightLightCalculatorOptions>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OptionsMapTest, QueryFound) {
|
||||||
|
CalculatorGraphConfig::Node node =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "NightLightCalculator"
|
||||||
|
input_side_packet: "input_value"
|
||||||
|
output_stream: "values"
|
||||||
|
options {
|
||||||
|
[mediapipe.NightLightCalculatorOptions.ext] {
|
||||||
|
base_timestamp: 123
|
||||||
|
output_header: PASS_HEADER
|
||||||
|
jitter: 0.123
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
OptionsMap options;
|
||||||
|
options.Initialize(node);
|
||||||
|
EXPECT_TRUE(options.Has<mediapipe::NightLightCalculatorOptions>());
|
||||||
|
EXPECT_EQ(
|
||||||
|
options.Get<mediapipe::NightLightCalculatorOptions>().base_timestamp()[0],
|
||||||
|
123);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MutableOptionsMapTest, InsertAndQueryFound) {
|
||||||
|
CalculatorGraphConfig::Node node =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "NightLightCalculator"
|
||||||
|
input_side_packet: "input_value"
|
||||||
|
output_stream: "values"
|
||||||
|
)pb");
|
||||||
|
MutableOptionsMap options;
|
||||||
|
options.Initialize(node);
|
||||||
|
EXPECT_FALSE(options.Has<mediapipe::NightLightCalculatorOptions>());
|
||||||
|
mediapipe::NightLightCalculatorOptions night_light_options;
|
||||||
|
night_light_options.add_base_timestamp(123);
|
||||||
|
options.Set(night_light_options);
|
||||||
|
EXPECT_TRUE(options.Has<mediapipe::NightLightCalculatorOptions>());
|
||||||
|
EXPECT_EQ(
|
||||||
|
options.Get<mediapipe::NightLightCalculatorOptions>().base_timestamp()[0],
|
||||||
|
123);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tool
|
||||||
|
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user