mediapipe/mediapipe/python/pybind/validated_graph_config.cc
MediaPipe Team 350fbb2100 Project import generated by Copybara.
GitOrigin-RevId: d073f8e21be2fcc0e503cb97c6695078b6b75310
2021-02-27 03:30:05 -05:00

162 lines
5.9 KiB
C++

// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/python/pybind/validated_graph_config.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/validated_graph_config.h"
#include "mediapipe/python/pybind/util.h"
namespace mediapipe {
namespace python {
namespace py = pybind11;
void ValidatedGraphConfigSubmodule(pybind11::module* module) {
py::module m = module->def_submodule(
"validated_graph_config", "MediaPipe validated graph config module.");
// Validated Graph Config
py::class_<ValidatedGraphConfig> validated_graph_config(
m, "ValidatedGraphConfig",
R"doc(A class to validate and canonicalize a CalculatorGraphConfig.)doc");
validated_graph_config.def(py::init())
.def(
"initialize",
[](ValidatedGraphConfig* self, py::kwargs kwargs) {
bool init_with_binary_graph = false;
bool init_with_graph_proto = false;
CalculatorGraphConfig graph_config_proto;
for (const auto& kw : kwargs) {
const std::string& key = kw.first.cast<std::string>();
if (key == "binary_graph_path") {
init_with_binary_graph = true;
std::string file_name(kw.second.cast<py::object>().str());
graph_config_proto =
ReadCalculatorGraphConfigFromFile(file_name);
} else if (key == "graph_config") {
init_with_graph_proto = true;
if (!ParseTextProto<CalculatorGraphConfig>(
kw.second.cast<py::object>().str(),
&graph_config_proto)) {
throw RaisePyError(
PyExc_RuntimeError,
absl::StrCat(
"Failed to parse: ",
std::string(kw.second.cast<py::object>().str()))
.c_str());
}
} else {
throw RaisePyError(
PyExc_RuntimeError,
absl::StrCat("Unknown kwargs input argument: ", key)
.c_str());
}
}
if (!(init_with_binary_graph ^ init_with_graph_proto)) {
throw RaisePyError(
PyExc_ValueError,
"Please either provide \'binary_graph_path\' to initialize "
"a ValidatedGraphConfig object with a binary graph file or "
"\'graph_config\' to initialize a ValidatedGraphConfig "
"object with a graph config proto.");
}
RaisePyErrorIfNotOk(self->Initialize(graph_config_proto));
},
R"doc(Initialize ValidatedGraphConfig with a CalculatorGraphConfig.
Args:
binary_graph_path: The path to a binary mediapipe graph file (.binarypb).
graph_config: A single CalculatorGraphConfig proto message or its text proto
format.
Raises:
FileNotFoundError: If the binary graph file can't be found.
ValueError: If the input arguments prvoided are more than needed or the
graph validation process contains error.
Examples:
validated_graph_config = mp.ValidatedGraphConfig()
validated_graph_config.initialize(graph_config=text_config)
)doc");
validated_graph_config.def(
"registered_stream_type_name",
[](ValidatedGraphConfig& self, const std::string& stream_name) {
auto status_or_type_name = self.RegisteredStreamTypeName(stream_name);
RaisePyErrorIfNotOk(status_or_type_name.status());
return status_or_type_name.value();
},
R"doc(Return the registered type name of the specified stream if it can be determined.
Args:
stream_name: The input/output stream name.
Returns:
The registered packet type name of the input/output stream.
Raises:
ValueError: If the input/output stream cannot be found.
Examples:
validated_graph_config.registered_stream_type_name('stream_name')
)doc");
validated_graph_config.def(
"registered_side_packet_type_name",
[](ValidatedGraphConfig& self, const std::string& side_packet_name) {
auto status_or_type_name =
self.RegisteredSidePacketTypeName(side_packet_name);
RaisePyErrorIfNotOk(status_or_type_name.status());
return status_or_type_name.value();
},
R"doc(Return the registered type name of the specified side packet if it can be determined.
Args:
side_packet_name: The input/output side packet name.
Returns:
The registered packet type name of the input/output side packet.
Raises:
ValueError: If the input/output side packet cannot be found.
Examples:
validated_graph_config.registered_side_packet_type_name('side_packet')
)doc");
// TODO: Return a Python CalculatorGraphConfig instead.
validated_graph_config.def_property_readonly(
"text_config", [](const ValidatedGraphConfig& self) {
return self.Config().DebugString();
});
validated_graph_config.def_property_readonly(
"binary_config", [](const ValidatedGraphConfig& self) {
return py::bytes(self.Config().SerializeAsString());
});
validated_graph_config.def(
"initialized",
[](const ValidatedGraphConfig& self) { return self.Initialized(); },
R"doc(Indicate if ValidatedGraphConfig is initialized.)doc");
}
} // namespace python
} // namespace mediapipe