Internal change
PiperOrigin-RevId: 495058817
This commit is contained in:
parent
78597c5b37
commit
db404b1a85
|
@ -48,11 +48,16 @@ class ProtoUtilLite {
|
|||
key_id(key_id),
|
||||
key_type(key_type),
|
||||
key_value(std::move(key_value)) {}
|
||||
bool operator==(const ProtoPathEntry& o) const {
|
||||
return field_id == o.field_id && index == o.index && map_id == o.map_id &&
|
||||
key_id == o.key_id && key_type == o.key_type &&
|
||||
key_value == o.key_value;
|
||||
}
|
||||
int field_id = -1;
|
||||
int index = -1;
|
||||
int map_id = -1;
|
||||
int key_id = -1;
|
||||
FieldType key_type;
|
||||
FieldType key_type = FieldType::MAX_FIELD_TYPE;
|
||||
FieldValue key_value;
|
||||
};
|
||||
|
||||
|
|
|
@ -1367,26 +1367,132 @@ absl::Status ProtoPathSplit(const std::string& path,
|
|||
return status;
|
||||
}
|
||||
|
||||
// Returns a message serialized deterministically.
|
||||
bool DeterministicallySerialize(const Message& proto, std::string* result) {
|
||||
proto_ns::io::StringOutputStream stream(result);
|
||||
proto_ns::io::CodedOutputStream output(&stream);
|
||||
output.SetSerializationDeterministic(true);
|
||||
return proto.SerializeToCodedStream(&output);
|
||||
}
|
||||
|
||||
// Serialize one field of a message.
|
||||
void SerializeField(const Message* message, const FieldDescriptor* field,
|
||||
std::vector<ProtoUtilLite::FieldValue>* result) {
|
||||
ProtoUtilLite::FieldValue message_bytes;
|
||||
CHECK(message->SerializePartialToString(&message_bytes));
|
||||
CHECK(DeterministicallySerialize(*message, &message_bytes));
|
||||
ProtoUtilLite::FieldAccess access(
|
||||
field->number(), static_cast<ProtoUtilLite::FieldType>(field->type()));
|
||||
MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes));
|
||||
*result = *access.mutable_field_values();
|
||||
}
|
||||
|
||||
// Serialize a ProtoPath as a readable string.
|
||||
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
|
||||
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
|
||||
std::string ProtoPathJoin(ProtoPath path) {
|
||||
std::string result;
|
||||
for (ProtoUtilLite::ProtoPathEntry& e : path) {
|
||||
if (e.field_id >= 0) {
|
||||
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
|
||||
} else if (e.map_id >= 0) {
|
||||
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
|
||||
"]");
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the message value from a field at an index.
|
||||
const Message* GetFieldMessage(const Message& message,
|
||||
const FieldDescriptor* field, int index) {
|
||||
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!field->is_repeated()) {
|
||||
return &message.GetReflection()->GetMessage(message, field);
|
||||
}
|
||||
if (index < message.GetReflection()->FieldSize(message, field)) {
|
||||
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Returns all FieldDescriptors including extensions.
|
||||
std::vector<const FieldDescriptor*> GetFields(const Message* src) {
|
||||
std::vector<const FieldDescriptor*> result;
|
||||
src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(),
|
||||
&result);
|
||||
for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) {
|
||||
result.push_back(src->GetDescriptor()->field(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Orders map entries in dst to match src.
|
||||
void OrderMapEntries(const Message* src, Message* dst,
|
||||
std::set<const Message*>* seen = nullptr) {
|
||||
std::unique_ptr<std::set<const Message*>> seen_owner;
|
||||
if (!seen) {
|
||||
seen_owner = std::make_unique<std::set<const Message*>>();
|
||||
seen = seen_owner.get();
|
||||
}
|
||||
if (seen->count(src) > 0) {
|
||||
return;
|
||||
} else {
|
||||
seen->insert(src);
|
||||
}
|
||||
for (auto field : GetFields(src)) {
|
||||
if (field->is_map()) {
|
||||
dst->GetReflection()->ClearField(dst, field);
|
||||
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
|
||||
const Message& entry =
|
||||
src->GetReflection()->GetRepeatedMessage(*src, field, j);
|
||||
dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry);
|
||||
}
|
||||
}
|
||||
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
|
||||
if (field->is_repeated()) {
|
||||
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
|
||||
OrderMapEntries(
|
||||
&src->GetReflection()->GetRepeatedMessage(*src, field, j),
|
||||
dst->GetReflection()->MutableRepeatedMessage(dst, field, j),
|
||||
seen);
|
||||
}
|
||||
} else {
|
||||
OrderMapEntries(&src->GetReflection()->GetMessage(*src, field),
|
||||
dst->GetReflection()->MutableMessage(dst, field), seen);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copies a Message, keeping map entries in order.
|
||||
std::unique_ptr<Message> CloneMessage(const Message* message) {
|
||||
std::unique_ptr<Message> result(message->New());
|
||||
result->CopyFrom(*message);
|
||||
OrderMapEntries(message, result.get());
|
||||
return result;
|
||||
}
|
||||
|
||||
using MessageMap = std::map<std::string, std::unique_ptr<Message>>;
|
||||
|
||||
// For a non-repeated field, move the most recently parsed field value
|
||||
// into the most recently parsed template expression.
|
||||
void StowFieldValue(Message* message, TemplateExpression* expression) {
|
||||
void StowFieldValue(Message* message, TemplateExpression* expression,
|
||||
MessageMap* stowed_messages) {
|
||||
const Reflection* reflection = message->GetReflection();
|
||||
const Descriptor* descriptor = message->GetDescriptor();
|
||||
ProtoUtilLite::ProtoPath path;
|
||||
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
|
||||
int field_number = path[path.size() - 1].field_id;
|
||||
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
|
||||
|
||||
// Save each stowed message unserialized preserving map entry order.
|
||||
if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) {
|
||||
(*stowed_messages)[ProtoPathJoin(path)] =
|
||||
CloneMessage(GetFieldMessage(*message, field, 0));
|
||||
}
|
||||
|
||||
if (!field->is_repeated()) {
|
||||
std::vector<ProtoUtilLite::FieldValue> field_values;
|
||||
SerializeField(message, field, &field_values);
|
||||
|
@ -1417,37 +1523,6 @@ const FieldDescriptor* FindFieldByNumber(const Message* message,
|
|||
return result;
|
||||
}
|
||||
|
||||
// Returns the message value from a field at an index.
|
||||
const Message* GetFieldMessage(const Message& message,
|
||||
const FieldDescriptor* field, int index) {
|
||||
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!field->is_repeated()) {
|
||||
return &message.GetReflection()->GetMessage(message, field);
|
||||
}
|
||||
if (index < message.GetReflection()->FieldSize(message, field)) {
|
||||
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Serialize a ProtoPath as a readable string.
|
||||
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
|
||||
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
|
||||
std::string ProtoPathJoin(ProtoPath path) {
|
||||
std::string result;
|
||||
for (ProtoUtilLite::ProtoPathEntry& e : path) {
|
||||
if (e.field_id >= 0) {
|
||||
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
|
||||
} else if (e.map_id >= 0) {
|
||||
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
|
||||
"]");
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the protobuf map key types from a ProtoPath.
|
||||
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
|
||||
std::vector<FieldType> result;
|
||||
|
@ -1473,9 +1548,29 @@ std::string GetMapKey(const Message& map_entry) {
|
|||
return "";
|
||||
}
|
||||
|
||||
// Returns a Message store in CalculatorGraphTemplate::field_value.
|
||||
Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) {
|
||||
auto it = stowed_messages->find(ProtoPathJoin(proto_path));
|
||||
return (it != stowed_messages->end()) ? it->second.get() : nullptr;
|
||||
}
|
||||
|
||||
const Message* GetNestedMessage(const Message& message,
|
||||
const FieldDescriptor* field,
|
||||
ProtoPath proto_path,
|
||||
MessageMap* stowed_messages) {
|
||||
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
||||
return nullptr;
|
||||
}
|
||||
const Message* result = FindStowedMessage(stowed_messages, proto_path);
|
||||
if (!result) {
|
||||
result = GetFieldMessage(message, field, proto_path.back().index);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Adjusts map-entries from indexes to keys.
|
||||
// Protobuf map-entry order is intentionally not preserved.
|
||||
mediapipe::Status KeyProtoMapEntries(Message* source) {
|
||||
absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) {
|
||||
// Copy the rules from the source CalculatorGraphTemplate.
|
||||
mediapipe::CalculatorGraphTemplate rules;
|
||||
rules.ParsePartialFromString(source->SerializePartialAsString());
|
||||
|
@ -1489,11 +1584,14 @@ mediapipe::Status KeyProtoMapEntries(Message* source) {
|
|||
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
|
||||
for (int j = 0; j < path.size(); ++j) {
|
||||
int field_id = path[j].field_id;
|
||||
int field_index = path[j].index;
|
||||
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
|
||||
ProtoPath prefix = {path.begin(), path.begin() + j + 1};
|
||||
message = GetNestedMessage(*message, field, prefix, stowed_messages);
|
||||
if (!message) {
|
||||
break;
|
||||
}
|
||||
if (field->is_map()) {
|
||||
const Message* map_entry =
|
||||
GetFieldMessage(*message, field, path[j].index);
|
||||
const Message* map_entry = message;
|
||||
int key_id =
|
||||
map_entry->GetDescriptor()->FindFieldByName("key")->number();
|
||||
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
|
||||
|
@ -1501,10 +1599,6 @@ mediapipe::Status KeyProtoMapEntries(Message* source) {
|
|||
std::string key_value = GetMapKey(*map_entry);
|
||||
path[j] = {field_id, key_id, key_type, key_value};
|
||||
}
|
||||
message = GetFieldMessage(*message, field, field_index);
|
||||
if (!message) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!rule->path().empty()) {
|
||||
*rule->mutable_path() = ProtoPathJoin(path);
|
||||
|
@ -1539,7 +1633,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
|||
// Copy the template rules into the output template "rule" field.
|
||||
success &= MergeFields(template_rules_, output).ok();
|
||||
// Replace map-entry indexes with map keys.
|
||||
success &= KeyProtoMapEntries(output).ok();
|
||||
success &= KeyProtoMapEntries(output, &stowed_messages_).ok();
|
||||
return success;
|
||||
}
|
||||
|
||||
|
@ -1565,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
|||
DO(ConsumeFieldTemplate(message));
|
||||
} else {
|
||||
DO(ConsumeField(message));
|
||||
StowFieldValue(message, expression);
|
||||
StowFieldValue(message, expression, &stowed_messages_);
|
||||
}
|
||||
DO(ConsumeEndTemplate());
|
||||
return true;
|
||||
|
@ -1776,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
|||
}
|
||||
|
||||
mediapipe::CalculatorGraphTemplate template_rules_;
|
||||
std::map<std::string, std::unique_ptr<Message>> stowed_messages_;
|
||||
};
|
||||
|
||||
#undef DO
|
||||
|
|
Loading…
Reference in New Issue
Block a user