Internal change

PiperOrigin-RevId: 494420725
This commit is contained in:
Hadon Nash 2022-12-10 12:32:04 -08:00 committed by Copybara-Service
parent e9bb51a524
commit 421f789ede
8 changed files with 348 additions and 82 deletions

View File

@ -346,6 +346,7 @@ cc_library(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -506,6 +507,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":proto_util_lite", ":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto",
"//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",

View File

@ -27,6 +27,9 @@ message TemplateExpression {
// The FieldDescriptor::Type of the modified field. // The FieldDescriptor::Type of the modified field.
optional mediapipe.FieldDescriptorProto.Type field_type = 5; optional mediapipe.FieldDescriptorProto.Type field_type = 5;
// The FieldDescriptor::Type of each map key in the path.
repeated mediapipe.FieldDescriptorProto.Type key_type = 6;
// Alternative value for the modified field, in protobuf binary format. // Alternative value for the modified field, in protobuf binary format.
optional string field_value = 7; optional string field_value = 7;
} }

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/tool/field_data.pb.h" #include "mediapipe/framework/tool/field_data.pb.h"
#include "mediapipe/framework/type_map.h" #include "mediapipe/framework/type_map.h"
@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
// Extracts the data value(s) for one field from a serialized message. // Extracts the data value(s) for one field from a serialized message.
// The message with these field values removed is written to |out|. // The message with these field values removed is written to |out|.
absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in,
CodedInputStream* in, CodedOutputStream* out, CodedOutputStream* out,
std::vector<std::string>* field_values) { std::vector<std::string>* field_values) {
uint32 tag; uint32 tag;
while ((tag = in->ReadTag()) != 0) { while ((tag = in->ReadTag()) != 0) {
int field_number = WireFormatLite::GetTagFieldNumber(tag); int field_number = WireFormatLite::GetTagFieldNumber(tag);
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
if (field_number == field_id) { if (field_number == field_id) {
if (!IsLengthDelimited(wire_type) && if (!IsLengthDelimited(wire_type) &&
IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) { IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) {
@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) {
CodedInputStream in(&ais); CodedInputStream in(&ais);
StringOutputStream sos(&message_); StringOutputStream sos(&message_);
CodedOutputStream out(&sos); CodedOutputStream out(&sos);
WireFormatLite::WireType wire_type = return GetFieldValues(field_id_, &in, &out, &field_values_);
WireFormatLite::WireTypeForFieldType(field_type_);
return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_);
} }
void FieldAccess::GetMessage(std::string* result) { void FieldAccess::GetMessage(std::string* result) {
@ -149,18 +149,56 @@ std::vector<FieldValue>* FieldAccess::mutable_field_values() {
return &field_values_; return &field_values_;
} }
namespace {
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
// Returns the FieldAccess and index for a field-id or a map-id.
// Returns access to the field-id if the field index is found,
// to the map-id if the map entry is found, and to the field-id otherwise.
absl::StatusOr<std::pair<FieldAccess, int>> AccessField(
const ProtoPathEntry& entry, FieldType field_type,
const FieldValue& message) {
FieldAccess result(entry.field_id, field_type);
if (entry.field_id >= 0) {
MP_RETURN_IF_ERROR(result.SetMessage(message));
if (entry.index < result.mutable_field_values()->size()) {
return std::pair(result, entry.index);
}
}
if (entry.map_id >= 0) {
FieldAccess access(entry.map_id, field_type);
MP_RETURN_IF_ERROR(access.SetMessage(message));
auto& field_values = *access.mutable_field_values();
for (int index = 0; index < field_values.size(); ++index) {
FieldAccess key(entry.key_id, entry.key_type);
MP_RETURN_IF_ERROR(key.SetMessage(field_values[index]));
if (key.mutable_field_values()->at(0) == entry.key_value) {
return std::pair(std::move(access), index);
}
}
}
if (entry.field_id >= 0) {
return std::pair(result, entry.index);
}
return absl::InvalidArgumentError(absl::StrCat(
"ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ",
entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type));
}
} // namespace
// Replaces a range of field values for one field nested within a protobuf. // Replaces a range of field values for one field nested within a protobuf.
absl::Status ProtoUtilLite::ReplaceFieldRange( absl::Status ProtoUtilLite::ReplaceFieldRange(
FieldValue* message, ProtoPath proto_path, int length, FieldType field_type, FieldValue* message, ProtoPath proto_path, int length, FieldType field_type,
const std::vector<FieldValue>& field_values) { const std::vector<FieldValue>& field_values) {
int field_id, index; ProtoPathEntry entry = proto_path.front();
std::tie(field_id, index) = proto_path.front();
proto_path.erase(proto_path.begin()); proto_path.erase(proto_path.begin());
FieldAccess access(field_id, !proto_path.empty() FieldType type =
? WireFormatLite::TYPE_MESSAGE !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
: field_type); ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message));
MP_RETURN_IF_ERROR(access.SetMessage(*message)); FieldAccess& access = r.first;
std::vector<std::string>& v = *access.mutable_field_values(); int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) { if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size()); RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length, MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
absl::Status ProtoUtilLite::GetFieldRange( absl::Status ProtoUtilLite::GetFieldRange(
const FieldValue& message, ProtoPath proto_path, int length, const FieldValue& message, ProtoPath proto_path, int length,
FieldType field_type, std::vector<FieldValue>* field_values) { FieldType field_type, std::vector<FieldValue>* field_values) {
int field_id, index; ProtoPathEntry entry = proto_path.front();
std::tie(field_id, index) = proto_path.front();
proto_path.erase(proto_path.begin()); proto_path.erase(proto_path.begin());
FieldAccess access(field_id, !proto_path.empty() FieldType type =
? WireFormatLite::TYPE_MESSAGE !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
: field_type); ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
MP_RETURN_IF_ERROR(access.SetMessage(message)); FieldAccess& access = r.first;
std::vector<std::string>& v = *access.mutable_field_values(); int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) { if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size()); RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
GetFieldRange(v[index], proto_path, length, field_type, field_values)); GetFieldRange(v[index], proto_path, length, field_type, field_values));
} else { } else {
if (length == -1) {
length = v.size() - index;
}
RET_CHECK_NO_LOG(index >= 0 && index <= v.size()); RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size()); RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
field_values->insert(field_values->begin(), v.begin() + index, field_values->insert(field_values->begin(), v.begin() + index,
@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message,
ProtoPath proto_path, ProtoPath proto_path,
FieldType field_type, FieldType field_type,
int* field_count) { int* field_count) {
int field_id, index; ProtoPathEntry entry = proto_path.front();
std::tie(field_id, index) = proto_path.back(); proto_path.erase(proto_path.begin());
proto_path.pop_back(); FieldType type =
std::vector<std::string> parent; !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
if (proto_path.empty()) { ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
parent.push_back(std::string(message)); FieldAccess& access = r.first;
int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(
GetFieldCount(v[index], proto_path, field_type, field_count));
} else { } else {
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( *field_count = v.size();
message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
} }
FieldAccess access(field_id, field_type);
MP_RETURN_IF_ERROR(access.SetMessage(parent[0]));
*field_count = access.mutable_field_values()->size();
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -34,15 +34,31 @@ class ProtoUtilLite {
// Defines field types and tag formats. // Defines field types and tag formats.
using WireFormatLite = proto_ns::internal::WireFormatLite; using WireFormatLite = proto_ns::internal::WireFormatLite;
// Defines a sequence of nested field-number field-index pairs.
using ProtoPath = std::vector<std::pair<int, int>>;
// The serialized value for a protobuf field. // The serialized value for a protobuf field.
using FieldValue = std::string; using FieldValue = std::string;
// The serialized data type for a protobuf field. // The serialized data type for a protobuf field.
using FieldType = WireFormatLite::FieldType; using FieldType = WireFormatLite::FieldType;
// A field-id and index, or a map-id and key, or both.
struct ProtoPathEntry {
ProtoPathEntry(int id, int index) : field_id(id), index(index) {}
ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value)
: map_id(id),
key_id(key_id),
key_type(key_type),
key_value(std::move(key_value)) {}
int field_id = -1;
int index = -1;
int map_id = -1;
int key_id = -1;
FieldType key_type;
FieldValue key_value;
};
// Defines a sequence of nested field-number field-index pairs.
using ProtoPath = std::vector<ProtoPathEntry>;
class FieldAccess { class FieldAccess {
public: public:
// Provides access to a certain protobuf field. // Provides access to a certain protobuf field.
@ -57,9 +73,11 @@ class ProtoUtilLite {
// Returns the serialized values of the protobuf field. // Returns the serialized values of the protobuf field.
std::vector<FieldValue>* mutable_field_values(); std::vector<FieldValue>* mutable_field_values();
uint32 field_id() const { return field_id_; }
private: private:
const uint32 field_id_; uint32 field_id_;
const FieldType field_type_; FieldType field_type_;
std::string message_; std::string message_;
std::vector<FieldValue> field_values_; std::vector<FieldValue> field_values_;
}; };

View File

@ -22,6 +22,7 @@
#include <vector> #include <vector>
#include "absl/strings/ascii.h" #include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite;
using FieldValue = ProtoUtilLite::FieldValue; using FieldValue = ProtoUtilLite::FieldValue;
using FieldType = ProtoUtilLite::FieldType; using FieldType = ProtoUtilLite::FieldType;
using ProtoPath = ProtoUtilLite::ProtoPath; using ProtoPath = ProtoUtilLite::ProtoPath;
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
namespace { namespace {
@ -84,26 +86,87 @@ std::unique_ptr<MessageLite> CloneMessage(const MessageLite& message) {
return result; return result;
} }
// Returns the (tag, index) pairs in a field path. // Parses one ProtoPathEntry.
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]". // The parsed entry is appended to `result` and removed from `path`.
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { // ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes
absl::Status status; // to serialize the key text to protobuf wire format.
std::vector<std::string> ids = absl::StrSplit(path, '/'); absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) {
for (const std::string& id : ids) { bool ok = true;
if (id.length() > 0) { int sb = path.find('[');
std::pair<std::string, std::string> id_pair = int eb = path.find(']');
absl::StrSplit(id, absl::ByAnyChar("[]")); int field_id = -1;
int tag = 0; ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id);
int index = 0; auto selector = path.substr(sb + 1, eb - 1 - sb);
bool ok = absl::SimpleAtoi(id_pair.first, &tag) && if (absl::StartsWith(selector, "@")) {
absl::SimpleAtoi(id_pair.second, &index); int eq = selector.find('=');
if (!ok) { int key_id = -1;
status.Update(absl::InvalidArgumentError(path)); ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id);
} auto key_text = selector.substr(eq + 1);
result->push_back(std::make_pair(tag, index)); FieldType key_type = FieldType::TYPE_STRING;
result->push_back({field_id, key_id, key_type, std::string(key_text)});
} else {
int index = 0;
ok &= absl::SimpleAtoi(selector, &index);
result->push_back({field_id, index});
}
int end = path.find('/', eb);
if (end == std::string::npos) {
path = "";
} else {
path = path.substr(end + 1);
}
return ok ? absl::OkStatus()
: absl::InvalidArgumentError(
absl::StrCat("Failed to parse ProtoPath entry: ", path));
}
// Specifies the FieldTypes for protobuf map keys in a ProtoPath.
// Each ProtoPathEntry::key_value is converted from text to the protobuf
// wire format for its key type.
absl::Status SetMapKeyTypes(const std::vector<FieldType>& key_types,
ProtoPath* result) {
int i = 0;
for (ProtoPathEntry& entry : *result) {
if (entry.map_id >= 0) {
FieldType key_type = key_types[i++];
std::vector<FieldValue> key_value;
MP_RETURN_IF_ERROR(
ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value));
entry.key_type = key_type;
entry.key_value = key_value.front();
} }
} }
return status; return absl::OkStatus();
}
// Returns the (tag, index) pairs in a field path.
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]",
// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]".
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
result->clear();
absl::string_view rest = path;
if (absl::StartsWith(rest, "/")) {
rest = rest.substr(1);
}
while (!rest.empty()) {
MP_RETURN_IF_ERROR(ParseEntry(rest, result));
}
return absl::OkStatus();
}
// Parse the TemplateExpression.path field into a ProtoPath struct.
absl::Status ParseProtoPath(const TemplateExpression& rule,
std::string base_path, ProtoPath* result) {
ProtoPath base_entries;
MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries));
MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result));
std::vector<FieldType> key_types;
for (int type : rule.key_type()) {
key_types.push_back(static_cast<FieldType>(type));
}
MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result));
result->erase(result->begin(), result->begin() + base_entries.size());
return absl::OkStatus();
} }
// Returns true if one proto path is prefix by another. // Returns true if one proto path is prefix by another.
@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) {
return absl::StartsWith(path, prefix); return absl::StartsWith(path, prefix);
} }
// Returns the part of one proto path after a prefix proto path.
std::string ProtoPathRelative(const std::string& field_path,
const std::string& base_path) {
CHECK(ProtoPathStartsWith(field_path, base_path));
return field_path.substr(base_path.length());
}
// Returns the target ProtoUtilLite::FieldType of a rule. // Returns the target ProtoUtilLite::FieldType of a rule.
FieldType GetFieldType(const TemplateExpression& rule) { FieldType GetFieldType(const TemplateExpression& rule) {
return static_cast<FieldType>(rule.field_type()); return static_cast<FieldType>(rule.field_type());
@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) {
// Returns the count of field values at a ProtoPath. // Returns the count of field values at a ProtoPath.
int FieldCount(const FieldValue& base, ProtoPath field_path, int FieldCount(const FieldValue& base, ProtoPath field_path,
FieldType field_type) { FieldType field_type) {
int field_id, index; int result = 0;
std::tie(field_id, index) = field_path.back(); CHECK(
field_path.pop_back(); ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok());
std::vector<FieldValue> parent; return result;
if (field_path.empty()) {
parent.push_back(base);
} else {
MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange(
base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
}
ProtoUtilLite::FieldAccess access(field_id, field_type);
MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0]));
return access.mutable_field_values()->size();
} }
} // namespace } // namespace
@ -229,9 +276,7 @@ class TemplateExpanderImpl {
return absl::OkStatus(); return absl::OkStatus();
} }
ProtoPath field_path; ProtoPath field_path;
absl::Status status = MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path);
if (!status.ok()) return status;
return ProtoUtilLite::GetFieldRange(output, field_path, 1, return ProtoUtilLite::GetFieldRange(output, field_path, 1,
GetFieldType(rule), base); GetFieldType(rule), base);
} }
@ -242,12 +287,13 @@ class TemplateExpanderImpl {
const std::vector<FieldValue>& field_values, const std::vector<FieldValue>& field_values,
FieldValue* output) { FieldValue* output) {
if (!rule.has_path()) { if (!rule.has_path()) {
*output = field_values[0]; if (!field_values.empty()) {
*output = field_values[0];
}
return absl::OkStatus(); return absl::OkStatus();
} }
ProtoPath field_path; ProtoPath field_path;
RET_CHECK_OK( MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path));
int field_count = 1; int field_count = 1;
if (rule.has_field_value()) { if (rule.has_field_value()) {
// For a non-repeated field, only one value can be specified. // For a non-repeated field, only one value can be specified.
@ -257,7 +303,7 @@ class TemplateExpanderImpl {
"Multiple values specified for non-repeated field: ", rule.path())); "Multiple values specified for non-repeated field: ", rule.path()));
} }
// For a non-repeated field, the field value is stored only in the rule. // For a non-repeated field, the field value is stored only in the rule.
field_path[field_path.size() - 1].second = 0; field_path[field_path.size() - 1].index = 0;
field_count = 0; field_count = 0;
} }
return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count, return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count,

View File

@ -26,6 +26,7 @@
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/proto_descriptor.pb.h" #include "mediapipe/framework/deps/proto_descriptor.pb.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message;
using mediapipe::proto_ns::OneofDescriptor; using mediapipe::proto_ns::OneofDescriptor;
using mediapipe::proto_ns::Reflection; using mediapipe::proto_ns::Reflection;
using mediapipe::proto_ns::TextFormat; using mediapipe::proto_ns::TextFormat;
using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath;
using FieldType = mediapipe::tool::ProtoUtilLite::FieldType;
using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue;
namespace mediapipe { namespace mediapipe {
@ -1357,7 +1361,7 @@ absl::Status ProtoPathSplit(const std::string& path,
if (!ok) { if (!ok) {
status.Update(absl::InvalidArgumentError(path)); status.Update(absl::InvalidArgumentError(path));
} }
result->push_back(std::make_pair(tag, index)); result->push_back({tag, index});
} }
} }
return status; return status;
@ -1381,7 +1385,7 @@ void StowFieldValue(Message* message, TemplateExpression* expression) {
const Descriptor* descriptor = message->GetDescriptor(); const Descriptor* descriptor = message->GetDescriptor();
ProtoUtilLite::ProtoPath path; ProtoUtilLite::ProtoPath path;
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path)); MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
int field_number = path[path.size() - 1].first; int field_number = path[path.size() - 1].field_id;
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number); const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
if (!field->is_repeated()) { if (!field->is_repeated()) {
std::vector<ProtoUtilLite::FieldValue> field_values; std::vector<ProtoUtilLite::FieldValue> field_values;
@ -1402,6 +1406,124 @@ static void StripQuotes(std::string* str) {
} }
} }
// Returns the field or extension for field number.
const FieldDescriptor* FindFieldByNumber(const Message* message,
int field_num) {
const FieldDescriptor* result =
message->GetDescriptor()->FindFieldByNumber(field_num);
if (result == nullptr) {
result = message->GetReflection()->FindKnownExtensionByNumber(field_num);
}
return result;
}
// Returns the message value from a field at an index.
const Message* GetFieldMessage(const Message& message,
const FieldDescriptor* field, int index) {
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
return nullptr;
}
if (!field->is_repeated()) {
return &message.GetReflection()->GetMessage(message, field);
}
if (index < message.GetReflection()->FieldSize(message, field)) {
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
}
return nullptr;
}
// Serialize a ProtoPath as a readable string.
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
std::string ProtoPathJoin(ProtoPath path) {
std::string result;
for (ProtoUtilLite::ProtoPathEntry& e : path) {
if (e.field_id >= 0) {
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
} else if (e.map_id >= 0) {
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
"]");
}
}
return result;
}
// Returns the protobuf map key types from a ProtoPath.
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
std::vector<FieldType> result;
for (auto& entry : path) {
if (entry.map_id >= 0) {
result.push_back(entry.key_type);
}
}
return result;
}
// Returns the text value for a string or numeric protobuf map key.
std::string GetMapKey(const Message& map_entry) {
auto key_field = map_entry.GetDescriptor()->FindFieldByName("key");
auto reflection = map_entry.GetReflection();
if (key_field->type() == FieldDescriptor::TYPE_STRING) {
return reflection->GetString(map_entry, key_field);
} else if (key_field->type() == FieldDescriptor::TYPE_INT32) {
return absl::StrCat(reflection->GetInt32(map_entry, key_field));
} else if (key_field->type() == FieldDescriptor::TYPE_INT64) {
return absl::StrCat(reflection->GetInt64(map_entry, key_field));
}
return "";
}
// Adjusts map-entries from indexes to keys.
// Protobuf map-entry order is intentionally not preserved.
mediapipe::Status KeyProtoMapEntries(Message* source) {
// Copy the rules from the source CalculatorGraphTemplate.
mediapipe::CalculatorGraphTemplate rules;
rules.ParsePartialFromString(source->SerializePartialAsString());
// Only the "source" Message knows all extension types.
Message* config_0 = source->GetReflection()->MutableMessage(
source, source->GetDescriptor()->FindFieldByName("config"), nullptr);
for (int i = 0; i < rules.rule().size(); ++i) {
TemplateExpression* rule = rules.mutable_rule()->Mutable(i);
const Message* message = config_0;
ProtoPath path;
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
for (int j = 0; j < path.size(); ++j) {
int field_id = path[j].field_id;
int field_index = path[j].index;
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
if (field->is_map()) {
const Message* map_entry =
GetFieldMessage(*message, field, path[j].index);
int key_id =
map_entry->GetDescriptor()->FindFieldByName("key")->number();
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
map_entry->GetDescriptor()->FindFieldByName("key")->type());
std::string key_value = GetMapKey(*map_entry);
path[j] = {field_id, key_id, key_type, key_value};
}
message = GetFieldMessage(*message, field, field_index);
if (!message) {
break;
}
}
if (!rule->path().empty()) {
*rule->mutable_path() = ProtoPathJoin(path);
for (FieldType key_type : ProtoPathKeyTypes(path)) {
*rule->mutable_key_type()->Add() = key_type;
}
}
}
// Copy the rules back into the source CalculatorGraphTemplate.
auto source_rules =
source->GetReflection()->GetMutableRepeatedFieldRef<Message>(
source, source->GetDescriptor()->FindFieldByName("rule"));
source_rules.Clear();
for (auto& rule : rules.rule()) {
source_rules.Add(rule);
}
return absl::OkStatus();
}
} // namespace } // namespace
class TemplateParser::Parser::MediaPipeParserImpl class TemplateParser::Parser::MediaPipeParserImpl
@ -1416,6 +1538,8 @@ class TemplateParser::Parser::MediaPipeParserImpl
// Copy the template rules into the output template "rule" field. // Copy the template rules into the output template "rule" field.
success &= MergeFields(template_rules_, output).ok(); success &= MergeFields(template_rules_, output).ok();
// Replace map-entry indexes with map keys.
success &= KeyProtoMapEntries(output).ok();
return success; return success;
} }

View File

@ -17,6 +17,7 @@ load(
"//mediapipe/framework/tool:mediapipe_graph.bzl", "//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_simple_subgraph", "mediapipe_simple_subgraph",
) )
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
licenses(["notice"]) licenses(["notice"])
@ -58,3 +59,12 @@ mediapipe_simple_subgraph(
"//mediapipe/framework:test_calculators", "//mediapipe/framework:test_calculators",
], ],
) )
mediapipe_proto_library(
name = "frozen_generator_proto",
srcs = ["frozen_generator.proto"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
"//mediapipe/framework:packet_generator_proto",
],
)

View File

@ -0,0 +1,20 @@
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/packet_generator.proto";
message FrozenGeneratorOptions {
extend mediapipe.PacketGeneratorOptions {
optional FrozenGeneratorOptions ext = 225748738;
}
// Path to file containing serialized proto of type tensorflow::GraphDef.
optional string graph_proto_path = 1;
// This map defines the which streams are fed to which tensors in the model.
map<string, string> tag_to_tensor_names = 2;
// Graph nodes to run to initialize the model.
repeated string initialization_op_names = 4;
}