diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index e082ef2e6..518eb6b0e 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1677,6 +1677,7 @@ cc_test( ":subgraph", ":test_calculators", "//mediapipe/calculators/core:constant_side_packet_calculator", + "//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", diff --git a/mediapipe/framework/calculator_context.h b/mediapipe/framework/calculator_context.h index 34c2c2425..284226d92 100644 --- a/mediapipe/framework/calculator_context.h +++ b/mediapipe/framework/calculator_context.h @@ -69,6 +69,11 @@ class CalculatorContext { return calculator_state_->Options(); } + template + bool HasOptions() const { + return calculator_state_->HasOptions(); + } + // Returns a counter using the graph's counter factory. The counter's name is // the passed-in name, prefixed by the calculator node's name (if present) or // the calculator's type (if not). diff --git a/mediapipe/framework/calculator_context_test.cc b/mediapipe/framework/calculator_context_test.cc index be9103b4d..382c4c09e 100644 --- a/mediapipe/framework/calculator_context_test.cc +++ b/mediapipe/framework/calculator_context_test.cc @@ -131,6 +131,13 @@ TEST(CalculatorTest, GetOptions) { auto calculator_state_3 = MakeCalculatorState(config.node(3), 3); auto cc_3 = MakeCalculatorContext(&*calculator_state_3); + EXPECT_TRUE(cc_0->HasOptions()); + EXPECT_FALSE(cc_0->HasOptions()); + EXPECT_TRUE(cc_1->HasOptions()); + EXPECT_FALSE(cc_1->HasOptions()); + EXPECT_FALSE(cc_3->HasOptions()); + EXPECT_TRUE(cc_3->HasOptions()); + // Get a google::protobuf options extension from Node::options. EXPECT_EQ(cc_0->Options().jitter(), 0.123); diff --git a/mediapipe/framework/calculator_contract.h b/mediapipe/framework/calculator_contract.h index 2162f84e7..10fb2d049 100644 --- a/mediapipe/framework/calculator_contract.h +++ b/mediapipe/framework/calculator_contract.h @@ -69,6 +69,11 @@ class CalculatorContract { return options_.Get(); } + template + bool HasOptions() const { + return options_.Has(); + } + // Returns the PacketTypeSet for the input streams. PacketTypeSet& Inputs() { return *inputs_; } const PacketTypeSet& Inputs() const { return *inputs_; } diff --git a/mediapipe/framework/calculator_contract_test.cc b/mediapipe/framework/calculator_contract_test.cc index 694fc96fe..691a10dee 100644 --- a/mediapipe/framework/calculator_contract_test.cc +++ b/mediapipe/framework/calculator_contract_test.cc @@ -41,6 +41,7 @@ TEST(CalculatorContractTest, Calculator) { )pb"); CalculatorContract contract; MP_EXPECT_OK(contract.Initialize(node)); + EXPECT_FALSE(contract.HasOptions()); EXPECT_EQ(contract.Inputs().NumEntries(), 4); EXPECT_EQ(contract.Outputs().NumEntries(), 1); EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1); @@ -60,6 +61,7 @@ TEST(CalculatorContractTest, CalculatorOptions) { })pb"); CalculatorContract contract; MP_EXPECT_OK(contract.Initialize(node)); + ASSERT_TRUE(contract.HasOptions()); const auto& test_options = contract.Options().GetExtension(CalculatorContractTestOptions::ext); EXPECT_EQ(test_options.test_field(), 1.0); diff --git a/mediapipe/framework/calculator_state.h b/mediapipe/framework/calculator_state.h index f2af95725..33d8544eb 100644 --- a/mediapipe/framework/calculator_state.h +++ b/mediapipe/framework/calculator_state.h @@ -66,6 +66,10 @@ class CalculatorState { const T& Options() const { return options_.Get(); } + template + bool HasOptions() const { + return options_.Has(); + } const std::string& NodeName() const { return node_name_; } const int& NodeId() const { return node_id_; } diff --git a/mediapipe/framework/subgraph.h b/mediapipe/framework/subgraph.h index 5b1d9646a..cc5477a8f 100644 --- a/mediapipe/framework/subgraph.h +++ b/mediapipe/framework/subgraph.h @@ -62,6 +62,11 @@ class SubgraphContext { return options_map_.GetMutable(); } + template + bool HasOptions() { + return options_map_.Has(); + } + const CalculatorGraphConfig::Node& OriginalNode() const { return original_node_; } @@ -119,6 +124,11 @@ class Subgraph { return tool::OptionsMap().Initialize(supgraph_options).Get(); } + template + static bool HasOptions(const Subgraph::SubgraphOptions& supgraph_options) { + return tool::OptionsMap().Initialize(supgraph_options).Has(); + } + // Returns the CalculatorGraphConfig::Node specifying the subgraph. // This provides to Subgraphs the same graph information that GetContract // provides to Calculators. diff --git a/mediapipe/framework/subgraph_test.cc b/mediapipe/framework/subgraph_test.cc index 5a0dcce6f..2789aa683 100644 --- a/mediapipe/framework/subgraph_test.cc +++ b/mediapipe/framework/subgraph_test.cc @@ -17,6 +17,7 @@ #include #include "absl/strings/str_format.h" +#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/port/gmock.h" @@ -133,5 +134,59 @@ TEST(SubgraphServicesTest, EmitStringFromTestService) { EXPECT_EQ(side_string.Get(), "Expected STRING"); } +class OptionsCheckingSubgraph : public Subgraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* sc) override { + std::string subgraph_side_packet_val; + if (sc->HasOptions()) { + subgraph_side_packet_val = + sc->Options() + .packet(0) + .string_value(); + } + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + absl::StrFormat(R"( + output_side_packet: "string" + node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:string" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { string_value: "%s" } + } + } + } + )", + subgraph_side_packet_val)); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(OptionsCheckingSubgraph); + +TEST_F(SubgraphTest, CheckSubgraphOptionsPassedIn) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(R"pb( + output_side_packet: "str" + node { + calculator: "OptionsCheckingSubgraph" + output_side_packet: "str" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { string_value: "test" } + } + } + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilDone()); + auto packet = graph.GetOutputSidePacket("str"); + MP_ASSERT_OK(packet); + EXPECT_EQ(packet.value().Get(), "test"); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 193343a90..c34194fbf 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -181,6 +181,20 @@ cc_library( ], ) +cc_test( + name = "options_map_test", + srcs = ["options_map_test.cc"], + deps = [ + ":options_map", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/testdata:night_light_calculator_cc_proto", + "//mediapipe/framework/testdata:night_light_calculator_options_lib", + ], +) + mediapipe_proto_library( name = "field_data_proto", srcs = ["field_data.proto"], diff --git a/mediapipe/framework/tool/options_map.h b/mediapipe/framework/tool/options_map.h index 782d0f240..2b69f4fb6 100644 --- a/mediapipe/framework/tool/options_map.h +++ b/mediapipe/framework/tool/options_map.h @@ -28,6 +28,18 @@ struct IsExtension { static constexpr bool value = (sizeof(test(0)) == sizeof(char)); }; +template ::value, int>::type = 0> +bool HasExtension(const CalculatorOptions& options) { + return options.HasExtension(T::ext); +} + +template ::value, int>::type = 0> +bool HasExtension(const CalculatorOptions& options) { + return false; +} + template ::value, int>::type = 0> T* GetExtension(CalculatorOptions& options) { @@ -124,6 +136,27 @@ class OptionsMap { return *result; } + template + bool Has() const { + if (options_.Has()) { + return true; + } + if (node_config_->has_options()) { + return HasExtension(node_config_->options()); + } +#if defined(MEDIAPIPE_PROTO_LITE) && defined(MEDIAPIPE_PROTO_THIRD_PARTY) + // protobuf::Any is unavailable with third_party/protobuf:protobuf-lite. +#else + for (const mediapipe::protobuf::Any& options : + node_config_->node_options()) { + if (options.Is()) { + return true; + } + } +#endif + return false; + } + CalculatorGraphConfig::Node* node_config_; TypeMap options_; }; diff --git a/mediapipe/framework/tool/options_map_test.cc b/mediapipe/framework/tool/options_map_test.cc new file mode 100644 index 000000000..529fd5770 --- /dev/null +++ b/mediapipe/framework/tool/options_map_test.cc @@ -0,0 +1,88 @@ + +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/tool/options_map.h" + +#include + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/testdata/night_light_calculator.pb.h" + +namespace mediapipe { +namespace tool { +namespace { + +TEST(OptionsMapTest, QueryNotFound) { + CalculatorGraphConfig::Node node = + ParseTextProtoOrDie(R"pb( + calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values" + )pb"); + OptionsMap options; + options.Initialize(node); + EXPECT_FALSE(options.Has()); +} + +TEST(OptionsMapTest, QueryFound) { + CalculatorGraphConfig::Node node = + ParseTextProtoOrDie(R"pb( + calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values" + options { + [mediapipe.NightLightCalculatorOptions.ext] { + base_timestamp: 123 + output_header: PASS_HEADER + jitter: 0.123 + } + } + )pb"); + OptionsMap options; + options.Initialize(node); + EXPECT_TRUE(options.Has()); + EXPECT_EQ( + options.Get().base_timestamp()[0], + 123); +} + +TEST(MutableOptionsMapTest, InsertAndQueryFound) { + CalculatorGraphConfig::Node node = + ParseTextProtoOrDie(R"pb( + calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values" + )pb"); + MutableOptionsMap options; + options.Initialize(node); + EXPECT_FALSE(options.Has()); + mediapipe::NightLightCalculatorOptions night_light_options; + night_light_options.add_base_timestamp(123); + options.Set(night_light_options); + EXPECT_TRUE(options.Has()); + EXPECT_EQ( + options.Get().base_timestamp()[0], + 123); +} + +} // namespace +} // namespace tool +} // namespace mediapipe