Add functions for checking the existence of options in subgraphs and calculators.

PiperOrigin-RevId: 513689742
This commit is contained in:
MediaPipe Team 2023-03-02 17:59:15 -08:00 committed by Copybara-Service
parent 91d53cd181
commit 3837c92fd5
11 changed files with 224 additions and 0 deletions

View File

@ -1677,6 +1677,7 @@ cc_test(
":subgraph",
":test_calculators",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",

View File

@ -69,6 +69,11 @@ class CalculatorContext {
return calculator_state_->Options<T>();
}
template <class T>
bool HasOptions() const {
return calculator_state_->HasOptions<T>();
}
// Returns a counter using the graph's counter factory. The counter's name is
// the passed-in name, prefixed by the calculator node's name (if present) or
// the calculator's type (if not).

View File

@ -131,6 +131,13 @@ TEST(CalculatorTest, GetOptions) {
auto calculator_state_3 = MakeCalculatorState(config.node(3), 3);
auto cc_3 = MakeCalculatorContext(&*calculator_state_3);
EXPECT_TRUE(cc_0->HasOptions<NightLightCalculatorOptions>());
EXPECT_FALSE(cc_0->HasOptions<SkyLightCalculatorOptions>());
EXPECT_TRUE(cc_1->HasOptions<NightLightCalculatorOptions>());
EXPECT_FALSE(cc_1->HasOptions<SkyLightCalculatorOptions>());
EXPECT_FALSE(cc_3->HasOptions<NightLightCalculatorOptions>());
EXPECT_TRUE(cc_3->HasOptions<SkyLightCalculatorOptions>());
// Get a google::protobuf options extension from Node::options.
EXPECT_EQ(cc_0->Options<NightLightCalculatorOptions>().jitter(), 0.123);

View File

@ -69,6 +69,11 @@ class CalculatorContract {
return options_.Get<T>();
}
template <class T>
bool HasOptions() const {
return options_.Has<T>();
}
// Returns the PacketTypeSet for the input streams.
PacketTypeSet& Inputs() { return *inputs_; }
const PacketTypeSet& Inputs() const { return *inputs_; }

View File

@ -41,6 +41,7 @@ TEST(CalculatorContractTest, Calculator) {
)pb");
CalculatorContract contract;
MP_EXPECT_OK(contract.Initialize(node));
EXPECT_FALSE(contract.HasOptions<CalculatorContractTestOptions>());
EXPECT_EQ(contract.Inputs().NumEntries(), 4);
EXPECT_EQ(contract.Outputs().NumEntries(), 1);
EXPECT_EQ(contract.InputSidePackets().NumEntries(), 1);
@ -60,6 +61,7 @@ TEST(CalculatorContractTest, CalculatorOptions) {
})pb");
CalculatorContract contract;
MP_EXPECT_OK(contract.Initialize(node));
ASSERT_TRUE(contract.HasOptions<CalculatorContractTestOptions>());
const auto& test_options =
contract.Options().GetExtension(CalculatorContractTestOptions::ext);
EXPECT_EQ(test_options.test_field(), 1.0);

View File

@ -66,6 +66,10 @@ class CalculatorState {
const T& Options() const {
return options_.Get<T>();
}
template <class T>
bool HasOptions() const {
return options_.Has<T>();
}
const std::string& NodeName() const { return node_name_; }
const int& NodeId() const { return node_id_; }

View File

@ -62,6 +62,11 @@ class SubgraphContext {
return options_map_.GetMutable<T>();
}
template <typename T>
bool HasOptions() {
return options_map_.Has<T>();
}
const CalculatorGraphConfig::Node& OriginalNode() const {
return original_node_;
}
@ -119,6 +124,11 @@ class Subgraph {
return tool::OptionsMap().Initialize(supgraph_options).Get<T>();
}
template <typename T>
static bool HasOptions(const Subgraph::SubgraphOptions& supgraph_options) {
return tool::OptionsMap().Initialize(supgraph_options).Has<T>();
}
// Returns the CalculatorGraphConfig::Node specifying the subgraph.
// This provides to Subgraphs the same graph information that GetContract
// provides to Calculators.

View File

@ -17,6 +17,7 @@
#include <string>
#include "absl/strings/str_format.h"
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/port/gmock.h"
@ -133,5 +134,59 @@ TEST(SubgraphServicesTest, EmitStringFromTestService) {
EXPECT_EQ(side_string.Get<std::string>(), "Expected STRING");
}
class OptionsCheckingSubgraph : public Subgraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override {
std::string subgraph_side_packet_val;
if (sc->HasOptions<ConstantSidePacketCalculatorOptions>()) {
subgraph_side_packet_val =
sc->Options<ConstantSidePacketCalculatorOptions>()
.packet(0)
.string_value();
}
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::StrFormat(R"(
output_side_packet: "string"
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:string"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { string_value: "%s" }
}
}
}
)",
subgraph_side_packet_val));
return config;
}
};
REGISTER_MEDIAPIPE_GRAPH(OptionsCheckingSubgraph);
TEST_F(SubgraphTest, CheckSubgraphOptionsPassedIn) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
output_side_packet: "str"
node {
calculator: "OptionsCheckingSubgraph"
output_side_packet: "str"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { string_value: "test" }
}
}
}
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilDone());
auto packet = graph.GetOutputSidePacket("str");
MP_ASSERT_OK(packet);
EXPECT_EQ(packet.value().Get<std::string>(), "test");
}
} // namespace
} // namespace mediapipe

View File

@ -181,6 +181,20 @@ cc_library(
],
)
cc_test(
name = "options_map_test",
srcs = ["options_map_test.cc"],
deps = [
":options_map",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/testdata:night_light_calculator_cc_proto",
"//mediapipe/framework/testdata:night_light_calculator_options_lib",
],
)
mediapipe_proto_library(
name = "field_data_proto",
srcs = ["field_data.proto"],

View File

@ -28,6 +28,18 @@ struct IsExtension {
static constexpr bool value = (sizeof(test<T>(0)) == sizeof(char));
};
template <class T,
typename std::enable_if<IsExtension<T>::value, int>::type = 0>
bool HasExtension(const CalculatorOptions& options) {
return options.HasExtension(T::ext);
}
template <class T,
typename std::enable_if<!IsExtension<T>::value, int>::type = 0>
bool HasExtension(const CalculatorOptions& options) {
return false;
}
template <class T,
typename std::enable_if<IsExtension<T>::value, int>::type = 0>
T* GetExtension(CalculatorOptions& options) {
@ -124,6 +136,27 @@ class OptionsMap {
return *result;
}
template <class T>
bool Has() const {
if (options_.Has<T>()) {
return true;
}
if (node_config_->has_options()) {
return HasExtension<T>(node_config_->options());
}
#if defined(MEDIAPIPE_PROTO_LITE) && defined(MEDIAPIPE_PROTO_THIRD_PARTY)
// protobuf::Any is unavailable with third_party/protobuf:protobuf-lite.
#else
for (const mediapipe::protobuf::Any& options :
node_config_->node_options()) {
if (options.Is<T>()) {
return true;
}
}
#endif
return false;
}
CalculatorGraphConfig::Node* node_config_;
TypeMap options_;
};

View File

@ -0,0 +1,88 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/tool/options_map.h"
#include <unistd.h>
#include <memory>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/testdata/night_light_calculator.pb.h"
namespace mediapipe {
namespace tool {
namespace {
TEST(OptionsMapTest, QueryNotFound) {
CalculatorGraphConfig::Node node =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "NightLightCalculator"
input_side_packet: "input_value"
output_stream: "values"
)pb");
OptionsMap options;
options.Initialize(node);
EXPECT_FALSE(options.Has<mediapipe::NightLightCalculatorOptions>());
}
TEST(OptionsMapTest, QueryFound) {
CalculatorGraphConfig::Node node =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "NightLightCalculator"
input_side_packet: "input_value"
output_stream: "values"
options {
[mediapipe.NightLightCalculatorOptions.ext] {
base_timestamp: 123
output_header: PASS_HEADER
jitter: 0.123
}
}
)pb");
OptionsMap options;
options.Initialize(node);
EXPECT_TRUE(options.Has<mediapipe::NightLightCalculatorOptions>());
EXPECT_EQ(
options.Get<mediapipe::NightLightCalculatorOptions>().base_timestamp()[0],
123);
}
TEST(MutableOptionsMapTest, InsertAndQueryFound) {
CalculatorGraphConfig::Node node =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "NightLightCalculator"
input_side_packet: "input_value"
output_stream: "values"
)pb");
MutableOptionsMap options;
options.Initialize(node);
EXPECT_FALSE(options.Has<mediapipe::NightLightCalculatorOptions>());
mediapipe::NightLightCalculatorOptions night_light_options;
night_light_options.add_base_timestamp(123);
options.Set(night_light_options);
EXPECT_TRUE(options.Has<mediapipe::NightLightCalculatorOptions>());
EXPECT_EQ(
options.Get<mediapipe::NightLightCalculatorOptions>().base_timestamp()[0],
123);
}
} // namespace
} // namespace tool
} // namespace mediapipe