mediapipe/mediapipe/framework/tool/template_expander.cc
MediaPipe Team 259b48e082 Project import generated by Copybara.
GitOrigin-RevId: b137378673f7d66d41bcd46e4fc3a0d9ef254894
2019-10-25 14:29:15 -07:00

659 lines
24 KiB
C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/tool/template_expander.h"
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/numbers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/calculator_graph_template.pb.h"
#include "mediapipe/framework/tool/proto_util_lite.h"
namespace mediapipe {
namespace tool {
using mediapipe::proto_ns::MessageLite;
using mediapipe::tool::ProtoUtilLite;
using WireFormatLite = ProtoUtilLite::WireFormatLite;
using FieldValue = ProtoUtilLite::FieldValue;
using FieldType = ProtoUtilLite::FieldType;
using ProtoPath = ProtoUtilLite::ProtoPath;
namespace {
// Looks up a template argument by parameter name.
// Returns a mutable pointer to the value of the first parameter in |args|
// whose key equals |name|, or nullptr when no parameter has that key.
TemplateArgument* GetItem(TemplateDict* args, const std::string& name) {
  auto* params = args->mutable_arg();
  auto match = std::find_if(
      params->begin(), params->end(),
      [&name](const TemplateDict::Parameter& p) { return p.key() == name; });
  if (match == params->end()) {
    return nullptr;
  }
  return match->mutable_value();
}
// Binds, rebinds, or unbinds a template parameter.
// With a non-null |value|, copies it into the last parameter named |name|,
// appending a new parameter when none exists. With a null |value|, removes
// the last parameter named |name| if there is one.
void PutItem(TemplateDict* args, const std::string& name,
             const TemplateArgument* value) {
  // Search from the back; only the last matching entry is touched.
  int found = -1;
  for (int pos = args->arg_size(); pos-- > 0;) {
    if (args->arg()[pos].key() == name) {
      found = pos;
      break;
    }
  }
  if (found < 0) {
    // No existing binding: nothing to remove, or append a fresh one.
    if (value == nullptr) {
      return;
    }
    TemplateDict::Parameter* added = args->add_arg();
    *added->mutable_key() = name;
    *added->mutable_value() = *value;
    return;
  }
  if (value == nullptr) {
    args->mutable_arg()->erase(args->mutable_arg()->begin() + found);
    return;
  }
  *args->mutable_arg(found)->mutable_value() = *value;
}
// Creates a deep copy of a message.
std::unique_ptr<MessageLite> CloneMessage(const MessageLite& message) {
std::unique_ptr<MessageLite> result(message.New());
result->CheckTypeAndMergeFrom(message);
return result;
}
// Parses a textual field path into (tag, index) pairs appended to |result|.
// For example, "/1[1]/2[1]/3[1]" yields {{1, 1}, {2, 1}, {3, 1}}.
// Empty segments (such as the one before a leading '/') are skipped.
// A segment whose tag or index is not an integer makes the returned status an
// InvalidArgumentError carrying |path|; an entry is still appended for it.
mediapipe::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
  mediapipe::Status status;
  const std::vector<std::string> segments = absl::StrSplit(path, '/');
  for (const std::string& segment : segments) {
    if (segment.empty()) {
      continue;
    }
    // "tag[index]" splits on either bracket into the tag and the index.
    const std::pair<std::string, std::string> parts =
        absl::StrSplit(segment, absl::ByAnyChar("[]"));
    int tag = 0;
    int index = 0;
    const bool parsed = absl::SimpleAtoi(parts.first, &tag) &&
                        absl::SimpleAtoi(parts.second, &index);
    if (!parsed) {
      status.Update(::mediapipe::InvalidArgumentError(path));
    }
    result->push_back(std::make_pair(tag, index));
  }
  return status;
}
// Returns true if |path| begins with |prefix|.
// The comparison is character-wise on the textual paths; an empty |prefix|
// matches every path. Implemented with std::string so this helper does not
// depend on a header the file never includes directly.
bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) {
  // compare() clamps the substring to |path|'s length, so a |path| shorter
  // than |prefix| simply compares unequal.
  return path.compare(0, prefix.size(), prefix) == 0;
}
// Strips |base_path| from the front of |field_path| and returns the rest.
// CHECK-fails if |field_path| does not start with |base_path|.
std::string ProtoPathRelative(const std::string& field_path,
                              const std::string& base_path) {
  CHECK(ProtoPathStartsWith(field_path, base_path));
  const std::string::size_type skip = base_path.length();
  return std::string(field_path, skip);
}
// Converts the field_type stored in |rule| to a ProtoUtilLite::FieldType.
FieldType GetFieldType(const TemplateExpression& rule) {
  const auto declared_type = rule.field_type();
  return static_cast<FieldType>(declared_type);
}
// Returns the count of field values at a ProtoPath.
int FieldCount(const FieldValue& base, ProtoPath field_path,
FieldType field_type) {
int field_id, index;
std::tie(field_id, index) = field_path.back();
field_path.pop_back();
std::vector<FieldValue> parent;
if (field_path.empty()) {
parent.push_back(base);
} else {
MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange(
base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
}
ProtoUtilLite::FieldAccess access(field_id, field_type);
MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0]));
return access.mutable_field_values()->size();
}
} // namespace
// The default implementation for the mediapipe template rule interpreter.
class TemplateExpanderImpl {
public:
explicit TemplateExpanderImpl(std::vector<::mediapipe::Status>* errors)
: errors_(errors) {}
// Applies the rules specified in a CalculatorGraphTemplate to a
// CalculatorGraphConfig. Each rule references a nested field-value or
// message and defines zero or more replacement values for it.
bool ExpandTemplates(const TemplateDict& args,
const CalculatorGraphTemplate& templ,
CalculatorGraphConfig* output) {
// Extract the serialized CalculatorGraphConfig.
FieldValue base_value;
if (!templ.config().SerializeToString(&base_value)) {
return false;
}
// Extract the CalculatorGraphTemplate rules.
template_rules_ = templ;
template_rules_.clear_config();
// Invoke recursive rule expansion.
environment_ = args;
std::vector<FieldValue> result;
if (!ExpandNestedRules(0, "", base_value, &result)) {
return false;
}
return output->ParseFromString(result[0]);
}
private:
// Expands a template rule of a specific type.
// Modifies a base message to produce one or more expanded messages.
// Ownership of the result messages is transferred to the caller.
bool ExpandTemplateRule(int base_index, const FieldValue& base_message,
std::vector<FieldValue>* result) {
// Exapand a template rule of a specific type.
const TemplateExpression& rule = template_rules_.rule().Get(base_index);
if (rule.op() == "for") {
ExpandIterationRule(base_index, base_message, result);
} else if (rule.op() == "if") {
ExpandConditionalRule(base_index, base_message, result);
} else if (rule.op() == "param") {
ExpandDeclaration(base_index, base_message, result);
} else {
ExpandExpressionRule(base_index, result);
}
return true;
}
// Continues expansion after the rule at |base_index| has been handled.
// If the following rule has the same path it is a peer on the same field and
// is applied to |base_message| directly. Otherwise the rules nested below
// the current rule's path are applied.
bool ExpandPeerRules(int base_index, const FieldValue& base_message,
                     std::vector<FieldValue>* result) {
  const auto& rules = template_rules_.rule();
  const auto& current = rules.Get(base_index);
  const int following = base_index + 1;
  const bool has_peer = following < rules.size() &&
                        rules.Get(following).path() == current.path();
  if (has_peer) {
    return ExpandTemplateRule(following, base_message, result);
  }
  return ExpandNestedRules(following, current.path(), base_message, result);
}
// Appends to |base| the field value that |rule| operates on.
//  - A rule without a path addresses the whole message |output|.
//  - A rule carrying its own field_value supplies that value directly.
//  - Otherwise one value is read from |output| at the rule's path, taken
//    relative to |base_path|.
// Returns the status of path parsing or of the field lookup.
mediapipe::Status GetBaseValue(const std::string& base_path,
                               const TemplateExpression& rule,
                               const FieldValue& output,
                               std::vector<FieldValue>* base) {
  if (rule.has_path()) {
    if (rule.has_field_value()) {
      // For a non-repeated field, the field value is stored only in the rule.
      base->push_back(rule.field_value());
      return ::mediapipe::OkStatus();
    }
    ProtoPath field_path;
    const std::string relative = ProtoPathRelative(rule.path(), base_path);
    const mediapipe::Status split = ProtoPathSplit(relative, &field_path);
    if (!split.ok()) {
      return split;
    }
    return ProtoUtilLite::GetFieldRange(output, field_path, 1,
                                        GetFieldType(rule), base);
  }
  base->push_back(output);
  return ::mediapipe::OkStatus();
}
// Writes the expanded |field_values| of |rule| back into |output|.
//  - A rule without a path replaces the whole message with the first value;
//    if the rule produced no values, |output| is left unchanged.
//  - A rule carrying its own field_value targets a non-repeated field: the
//    values are inserted at index 0 without removing anything, and it is an
//    error if the field already holds a value.
//  - Otherwise one existing value at the rule's path (relative to
//    |base_path|) is replaced by |field_values|.
// Returns an error status if the path cannot be parsed, if it names no field
// where one is required, or if the replacement fails.
mediapipe::Status ReplaceBaseValue(
    const std::string& base_path, const TemplateExpression& rule,
    const std::vector<FieldValue>& field_values, FieldValue* output) {
  if (!rule.has_path()) {
    // A rule can legitimately expand to nothing (e.g. an "if" whose condition
    // is false); indexing an empty vector would be undefined behavior.
    if (!field_values.empty()) {
      *output = field_values[0];
    }
    return ::mediapipe::OkStatus();
  }
  ProtoPath field_path;
  RET_CHECK_OK(
      ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path));
  int field_count = 1;
  if (rule.has_field_value()) {
    // The last path entry is read below, so the path must name a field.
    if (field_path.empty()) {
      return ::mediapipe::InvalidArgumentError(
          absl::StrCat("Rule path does not name a field: ", rule.path()));
    }
    // For a non-repeated field, only one value can be specified.
    if (!field_values.empty() &&
        FieldCount(*output, field_path, GetFieldType(rule)) > 0) {
      return ::mediapipe::InvalidArgumentError(absl::StrCat(
          "Multiple values specified for non-repeated field: ", rule.path()));
    }
    // For a non-repeated field, the field value is stored only in the rule,
    // so nothing is removed from |output| and values go in at index 0.
    field_path.back().second = 0;
    field_count = 0;
  }
  return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count,
                                          GetFieldType(rule), field_values);
}
// Replaces nested fields by following nested template rules.
bool ExpandNestedRules(int base_index, const std::string& base_path,
const FieldValue& base_message,
std::vector<FieldValue>* result) {
mediapipe::Status status;
FieldValue output = base_message;
// Evaluate the rules nested below base_path in lexical order.
std::vector<int> rules = GetNestedRules(base_index, base_path);
std::vector<std::vector<FieldValue>> edits;
for (int i = 0; i < rules.size(); ++i) {
const auto& rule = template_rules_.rule().Get(rules[i]);
std::vector<FieldValue> base;
status = GetBaseValue(base_path, rule, output, &base);
if (!status.ok()) break;
std::vector<FieldValue> values;
if (!ExpandTemplateRule(rules[i], base[0], &values)) {
status = ::mediapipe::InternalError("ExpandTemplateRule failed");
break;
}
edits.push_back(values);
}
if (!status.ok()) {
RecordError(status);
return false;
}
// Replace base field values with the evaluated results.
// Edits are applied in reverse order since later indices are invalidated.
for (int i = edits.size() - 1; i >= 0; --i) {
const auto& rule = template_rules_.rule().Get(rules[i]);
status = ReplaceBaseValue(base_path, rule, edits[i], &output);
if (!status.ok()) break;
}
if (!status.ok()) {
RecordError(status);
return false;
}
result->push_back(output);
return true;
}
// Collects the indexes of the rules directly nested under |rule_path|.
// Scanning starts at |rule_index| and stops at the first rule whose path does
// not start with |rule_path|. A scanned rule is selected unless its path
// starts with the path of the most recently selected rule, in which case it
// is treated as belonging to that rule.
std::vector<int> GetNestedRules(int rule_index,
                                const std::string& rule_path) {
  std::vector<int> nested;
  // Initial marker for "nothing selected yet"; presumably no real rule path
  // begins with this text, so the first scanned rule is always selected.
  std::string last_selected = "-1[-1]";
  const auto& rules = template_rules_.rule();
  int i = rule_index;
  while (i < rules.size()) {
    const auto& path = rules.Get(i).path();
    if (!ProtoPathStartsWith(path, rule_path)) {
      break;
    }
    const bool covered = ProtoPathStartsWith(path, last_selected);
    if (!covered) {
      nested.push_back(i);
      last_selected = path;
    }
    ++i;
  }
  return nested;
}
// Apply a "for" operation to a base message.
// Expands nested rules once for each iteration range value.
bool ExpandIterationRule(int base_index, const FieldValue& base_message,
std::vector<FieldValue>* result) {
// Retrieve the var param and the range expression.
const TemplateExpression& rule = template_rules_.rule().Get(base_index);
std::string var_param = rule.arg().Get(0).param();
const TemplateExpression& range_expr = rule.arg().Get(1);
TemplateArgument range = EvalExpression(range_expr);
// For each value of the range param, expand all nested rules.
TemplateArgument* shadow_item = GetItem(&environment_, var_param);
for (const TemplateArgument& item : range.element()) {
PutItem(&environment_, var_param, &item);
ExpandPeerRules(base_index, base_message, result);
}
PutItem(&environment_, var_param, shadow_item);
return true;
}
// Initializes a parameter in the parameter environment.
bool ExpandDeclaration(int base_index, const FieldValue& base_message,
std::vector<FieldValue>* result) {
// Retrieve the var param and the range expression.
const TemplateExpression& rule = template_rules_.rule().Get(base_index);
if (rule.arg().empty() || rule.arg().size() > 2) {
RecordError(::mediapipe::InvalidArgumentError(
"Param declaration must specify a parameter name and "
"may specify a single default value."));
}
// TODO: Validate that all params are declared or none.
// Delarations for required params will have no default value.
if (rule.arg().size() == 2) {
std::string var_param = rule.arg().Get(0).param();
const TemplateExpression& item_expr = rule.arg().Get(1);
TemplateArgument item = EvalExpression(item_expr);
// The parameter default value is used if no other value is specified.
if (GetItem(&environment_, var_param) == nullptr) {
PutItem(&environment_, var_param, &item);
}
}
ExpandPeerRules(base_index, base_message, result);
return true;
}
// Applies an "if" operation to a base message.
// Expands nested rules zero or more times.
bool ExpandConditionalRule(int base_index, const FieldValue& base_message,
std::vector<FieldValue>* result) {
// Retrieve the condition expression.
const TemplateExpression& rule = template_rules_.rule().Get(base_index);
// Expand this template zero or one times.
bool condition = AsBool(EvalExpression(rule.arg(0)));
if (condition) {
ExpandPeerRules(base_index, base_message, result);
}
return true;
}
// A self-contained expression just defines a single result value.
bool ExpandExpressionRule(int base_index, std::vector<FieldValue>* result) {
const TemplateExpression& rule = template_rules_.rule().Get(base_index);
TemplateArgument item = EvalExpression(rule);
std::vector<FieldValue> values;
mediapipe::Status status = AsFieldValues(
std::vector<TemplateArgument>{item}, GetFieldType(rule), &values);
if (!status.ok()) {
RecordError(status);
return false;
}
result->push_back(values[0]);
return true;
}
// Resolves a parameter reference against the variable environment.
// Returns a copy of the bound value. If the parameter is unbound, records a
// NotFoundError and returns the number 0 instead.
TemplateArgument EvalParam(const TemplateExpression& expr) {
  const TemplateArgument* bound = GetItem(&environment_, expr.param());
  if (bound != nullptr) {
    return *bound;
  }
  RecordError(
      ::mediapipe::NotFoundError(absl::StrCat("param: ", expr.param())));
  return AsArgument(0.0);
}
// Evaluates the "." operator: looks up a named entry in a dict.
// The first operand is evaluated to obtain the dict; the second operand's
// param is the key. Returns a copy of the entry, or records a NotFoundError
// and returns the number 0 when the key is absent.
TemplateArgument EvalDot(const TemplateExpression& expr) {
  TemplateArgument lhs = EvalExpression(expr.arg(0));
  const auto& field_name = expr.arg(1).param();
  const TemplateArgument* found = GetItem(lhs.mutable_dict(), field_name);
  if (found != nullptr) {
    return *found;
  }
  RecordError(::mediapipe::NotFoundError(
      absl::StrCat("param field: ", field_name)));
  return AsArgument(0.0);
}
// Converts a TemplateArgument to double.
// A numeric argument yields its number; a string argument is parsed, and a
// parse failure is recorded as an InvalidArgumentError. An argument that is
// neither yields 0.
double AsNum(const TemplateArgument& value) {
  double number = value.has_num() ? value.num() : 0;
  if (value.has_str() && !absl::SimpleAtod(value.str(), &number)) {
    RecordError(::mediapipe::InvalidArgumentError(value.str()));
  }
  return number;
}
// Converts a TemplateArgument to std::string.
// A string argument is returned as is; a numeric argument is formatted with
// absl::StrCat; anything else yields an empty string.
std::string AsString(const TemplateArgument& value) {
  if (value.has_str()) {
    return value.str();
  }
  if (value.has_num()) {
    return absl::StrCat(value.num());
  }
  return std::string();
}
// Converts a TemplateArgument to bool.
// A numeric argument is true when non-zero. A string argument is parsed with
// absl::SimpleAtob, and a parse failure is recorded as an
// InvalidArgumentError. An argument that is neither yields false.
bool AsBool(const TemplateArgument& value) {
  if (value.has_num()) {
    return value.num() != 0;
  }
  bool parsed = false;
  if (value.has_str() && !absl::SimpleAtob(value.str(), &parsed)) {
    RecordError(::mediapipe::InvalidArgumentError(value.str()));
  }
  return parsed;
}
// Builds a dict TemplateArgument from alternating key/value arguments.
// args[0], args[2], ... are converted to strings and used as keys;
// args[1], args[3], ... are the corresponding values. An odd argument count
// is recorded as an InvalidArgumentError and an empty argument is returned.
TemplateArgument AsDict(const std::vector<TemplateArgument>& args) {
  TemplateArgument result;
  const bool paired = (args.size() % 2) == 0;
  if (!paired) {
    RecordError(::mediapipe::InvalidArgumentError(absl::StrCat(
        "Dict requires an even number of arguments, got: ", args.size())));
    return result;
  }
  TemplateDict* dict = result.mutable_dict();
  for (auto it = args.begin(); it != args.end(); it += 2) {
    TemplateDict::Parameter* entry = dict->add_arg();
    *entry->mutable_key() = AsString(*it);
    *entry->mutable_value() = *(it + 1);
  }
  return result;
}
// Builds a list TemplateArgument whose elements are copies of |args|, in
// the same order.
TemplateArgument AsList(const std::vector<TemplateArgument>& args) {
  TemplateArgument list;
  for (const TemplateArgument& entry : args) {
    *list.mutable_element()->Add() = entry;
  }
  return list;
}
// Evaluate each of the sub-expressions of a TemplateExpression.
void EvalNestedExpressions(const TemplateExpression& expr,
std::vector<TemplateArgument>* result) {
for (const TemplateExpression& e : expr.arg()) {
result->push_back(EvalExpression(e));
}
}
// Returns true if |value| holds a number or a string that parses as one.
// Unlike AsNum, a failed parse is not recorded as an error.
bool IsNum(const TemplateArgument& value) {
  if (value.has_num()) {
    return true;
  }
  double ignored = 0;
  return absl::SimpleAtod(value.str(), &ignored);
}
// Three-way comparison of two template arguments.
// Returns 0 if v1 == v2, positive if v1 > v2, negative if v1 < v2.
// When both arguments are numeric they are compared as numbers (yielding
// exactly -1, 0 or 1); otherwise their string forms are compared with
// std::string::compare.
int CompareArgs(const TemplateArgument& v1, const TemplateArgument& v2) {
  const bool numeric = IsNum(v1) && IsNum(v2);
  if (!numeric) {
    return AsString(v1).compare(AsString(v2));
  }
  const double diff = AsNum(v1) - AsNum(v2);
  if (diff < 0) {
    return -1;
  }
  if (diff > 0) {
    return 1;
  }
  return 0;
}
// Evaluates a TemplateExpression to produce a template argument.
//  - "literal" yields the expression's param as a string.
//  - "." performs a dict lookup (see EvalDot).
//  - Any other expression that has a param is a parameter reference.
//  - Otherwise the operands are evaluated and combined according to op():
//    arithmetic (+ - * /), comparisons, boolean logic, min/max, concat,
//    lowercase/uppercase, dict, list, size, and paren.
// An operator given fewer operands than it reads is recorded as an
// InvalidArgumentError and evaluates to the number 0, instead of indexing
// past the end of the operand list. An unrecognized operator yields an empty
// argument.
TemplateArgument EvalExpression(const TemplateExpression& expr) {
  const std::string& op = expr.op();
  // Records an error and returns true when |expr| has fewer than |required|
  // operands.
  const auto lacks_operands = [this, &expr, &op](int required) {
    if (expr.arg().size() >= required) {
      return false;
    }
    RecordError(::mediapipe::InvalidArgumentError(
        absl::StrCat("Operator \"", op, "\" requires ", required,
                     " argument(s), got: ", expr.arg().size())));
    return true;
  };
  if (op == "literal") {
    return AsArgument(expr.param());
  }
  if (op == ".") {
    if (lacks_operands(2)) {
      return AsArgument(0.0);
    }
    return EvalDot(expr);
  }
  if (expr.has_param()) {
    return EvalParam(expr);
  }
  // Number of operands each operator reads below; "dict" and "list" accept
  // any number, and unknown operators read none.
  int required = 0;
  if (op == "paren" || op == "!" || op == "lowercase" || op == "uppercase" ||
      op == "size") {
    required = 1;
  } else if (op == "+" || op == "-" || op == "*" || op == "/" || op == ">" ||
             op == "<" || op == ">=" || op == "<=" || op == "==" ||
             op == "!=" || op == "&&" || op == "||" || op == "min" ||
             op == "max" || op == "concat") {
    required = 2;
  }
  if (lacks_operands(required)) {
    return AsArgument(0.0);
  }
  std::vector<TemplateArgument> args;
  EvalNestedExpressions(expr, &args);
  TemplateArgument result;
  if (op == "paren") {
    result = args[0];
  } else if (op == "+") {
    // "+" adds numbers but concatenates when either side is not numeric.
    if (IsNum(args[0]) && IsNum(args[1])) {
      result = AsArgument(AsNum(args[0]) + AsNum(args[1]));
    } else {
      result = AsArgument(AsString(args[0]) + AsString(args[1]));
    }
  } else if (op == "-") {
    result = AsArgument(AsNum(args[0]) - AsNum(args[1]));
  } else if (op == "*") {
    result = AsArgument(AsNum(args[0]) * AsNum(args[1]));
  } else if (op == "/") {
    result = AsArgument(AsNum(args[0]) / AsNum(args[1]));
  } else if (op == ">") {
    result = AsArgument(CompareArgs(args[0], args[1]) > 0);
  } else if (op == "<") {
    result = AsArgument(CompareArgs(args[0], args[1]) < 0);
  } else if (op == ">=") {
    result = AsArgument(CompareArgs(args[0], args[1]) >= 0);
  } else if (op == "<=") {
    result = AsArgument(CompareArgs(args[0], args[1]) <= 0);
  } else if (op == "==") {
    result = AsArgument(CompareArgs(args[0], args[1]) == 0);
  } else if (op == "!=") {
    result = AsArgument(CompareArgs(args[0], args[1]) != 0);
  } else if (op == "&&") {
    result = AsArgument(AsBool(args[0]) && AsBool(args[1]));
  } else if (op == "||") {
    result = AsArgument(AsBool(args[0]) || AsBool(args[1]));
  } else if (op == "!") {
    result = AsArgument(!(AsBool(args[0])));
  } else if (op == "min") {
    result = AsArgument(std::min(AsNum(args[0]), AsNum(args[1])));
  } else if (op == "max") {
    result = AsArgument(std::max(AsNum(args[0]), AsNum(args[1])));
  } else if (op == "concat") {
    result = AsArgument(AsString(args[0]) + AsString(args[1]));
  } else if (op == "lowercase") {
    result = AsArgument(absl::AsciiStrToLower(AsString(args[0])));
  } else if (op == "uppercase") {
    result = AsArgument(absl::AsciiStrToUpper(AsString(args[0])));
  } else if (op == "dict") {
    result = AsDict(args);
  } else if (op == "list") {
    result = AsList(args);
  } else if (op == "size") {
    // Entry count of a dict, otherwise element count of a list.
    result = AsArgument(static_cast<double>(
        args[0].has_dict() ? args[0].dict().arg_size()
                           : args[0].element().size()));
  }
  return result;
}
// Wraps a string in a TemplateArgument for further processing.
TemplateArgument AsArgument(const std::string& value) {
  TemplateArgument wrapped;
  wrapped.set_str(value);
  return wrapped;
}
// Wraps a number in a TemplateArgument for further processing.
TemplateArgument AsArgument(double value) {
  TemplateArgument wrapped;
  wrapped.set_num(value);
  return wrapped;
}
// Wraps a boolean in a TemplateArgument as the number 1 (true) or 0 (false).
TemplateArgument AsArgument(bool b) {
  const double as_number = b ? 1.0 : 0.0;
  return AsArgument(as_number);
}
// Converts template arguments to serialized proto field values.
// A dict argument is serialized as a message; a number or string argument is
// rendered as text and serialized as |field_type|. Arguments that are none of
// these are skipped, so |result| may receive fewer entries than |args| has.
// Returns the first serialization error, if any.
mediapipe::Status AsFieldValues(const std::vector<TemplateArgument>& args,
                                FieldType field_type,
                                std::vector<FieldValue>* result) {
  for (const TemplateArgument& arg : args) {
    if (arg.has_dict()) {
      FieldValue dict_bytes;
      CHECK(arg.dict().SerializePartialToString(&dict_bytes));
      result->push_back(dict_bytes);
      continue;
    }
    if (!arg.has_num() && !arg.has_str()) {
      continue;
    }
    const std::string text_value =
        arg.has_num() ? mediapipe::SimpleDtoa(arg.num()) : arg.str();
    std::vector<FieldValue> serialized;
    MP_RETURN_IF_ERROR(
        ProtoUtilLite::Serialize({text_value}, field_type, &serialized));
    result->push_back(serialized[0]);
  }
  return ::mediapipe::OkStatus();
}
// Appends |status| to the error list; OK statuses are ignored.
void RecordError(const ::mediapipe::Status& status) {
  if (status.ok()) {
    return;
  }
  errors_->push_back(status);
}
private:
// The template rules being applied. ExpandTemplates stores a copy of the
// input template here with its config cleared, so only the rules remain.
mediapipe::CalculatorGraphTemplate template_rules_;
// The template variable environment: parameter names bound to values.
// Seeded from the caller's arguments and updated by "for" and "param" rules.
TemplateDict environment_;
// Destination for errors found during expansion. Supplied to the
// constructor; RecordError appends to it.
std::vector<::mediapipe::Status>* errors_;
};
// Constructs an expander; no setup beyond member default-initialization.
TemplateExpander::TemplateExpander() = default;
// Expands the rules of |templ| using the arguments in |args| and writes the
// resulting graph config to |output|.
// Errors from any previous call are discarded first. Every error found is
// logged, and the returned status is the combination of all of them (OK when
// there are none).
::mediapipe::Status TemplateExpander::ExpandTemplates(
    const TemplateDict& args, const CalculatorGraphTemplate& templ,
    CalculatorGraphConfig* output) {
  errors_.clear();
  TemplateExpanderImpl expander(&errors_);
  const bool expanded = expander.ExpandTemplates(args, templ, output);
  if (!expanded) {
    errors_.push_back(::mediapipe::InternalError("ExpandTemplates failed"));
  }
  ::mediapipe::Status combined;
  for (auto it = errors_.begin(); it != errors_.end(); ++it) {
    LOG(ERROR) << *it;
    combined.Update(*it);
  }
  return combined;
}
} // namespace tool
} // namespace mediapipe