Internal change

PiperOrigin-RevId: 495058817
This commit is contained in:
Hadon Nash 2022-12-13 10:21:07 -08:00 committed by Copybara-Service
parent 78597c5b37
commit db404b1a85
2 changed files with 144 additions and 44 deletions

View File

@ -48,11 +48,16 @@ class ProtoUtilLite {
key_id(key_id),
key_type(key_type),
key_value(std::move(key_value)) {}
bool operator==(const ProtoPathEntry& o) const {
return field_id == o.field_id && index == o.index && map_id == o.map_id &&
key_id == o.key_id && key_type == o.key_type &&
key_value == o.key_value;
}
int field_id = -1;
int index = -1;
int map_id = -1;
int key_id = -1;
FieldType key_type;
FieldType key_type = FieldType::MAX_FIELD_TYPE;
FieldValue key_value;
};

View File

@ -1367,26 +1367,132 @@ absl::Status ProtoPathSplit(const std::string& path,
return status;
}
// Returns a message serialized deterministically.
bool DeterministicallySerialize(const Message& proto, std::string* result) {
proto_ns::io::StringOutputStream stream(result);
proto_ns::io::CodedOutputStream output(&stream);
output.SetSerializationDeterministic(true);
return proto.SerializeToCodedStream(&output);
}
// Serialize one field of a message.
void SerializeField(const Message* message, const FieldDescriptor* field,
std::vector<ProtoUtilLite::FieldValue>* result) {
ProtoUtilLite::FieldValue message_bytes;
CHECK(message->SerializePartialToString(&message_bytes));
CHECK(DeterministicallySerialize(*message, &message_bytes));
ProtoUtilLite::FieldAccess access(
field->number(), static_cast<ProtoUtilLite::FieldType>(field->type()));
MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes));
*result = *access.mutable_field_values();
}
// Serialize a ProtoPath as a readable string.
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
std::string ProtoPathJoin(ProtoPath path) {
std::string result;
for (ProtoUtilLite::ProtoPathEntry& e : path) {
if (e.field_id >= 0) {
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
} else if (e.map_id >= 0) {
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
"]");
}
}
return result;
}
// Returns the message value from a field at an index.
const Message* GetFieldMessage(const Message& message,
const FieldDescriptor* field, int index) {
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
return nullptr;
}
if (!field->is_repeated()) {
return &message.GetReflection()->GetMessage(message, field);
}
if (index < message.GetReflection()->FieldSize(message, field)) {
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
}
return nullptr;
}
// Returns all FieldDescriptors including extensions.
std::vector<const FieldDescriptor*> GetFields(const Message* src) {
std::vector<const FieldDescriptor*> result;
src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(),
&result);
for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) {
result.push_back(src->GetDescriptor()->field(i));
}
return result;
}
// Orders map entries in dst to match src.
void OrderMapEntries(const Message* src, Message* dst,
std::set<const Message*>* seen = nullptr) {
std::unique_ptr<std::set<const Message*>> seen_owner;
if (!seen) {
seen_owner = std::make_unique<std::set<const Message*>>();
seen = seen_owner.get();
}
if (seen->count(src) > 0) {
return;
} else {
seen->insert(src);
}
for (auto field : GetFields(src)) {
if (field->is_map()) {
dst->GetReflection()->ClearField(dst, field);
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
const Message& entry =
src->GetReflection()->GetRepeatedMessage(*src, field, j);
dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry);
}
}
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
if (field->is_repeated()) {
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
OrderMapEntries(
&src->GetReflection()->GetRepeatedMessage(*src, field, j),
dst->GetReflection()->MutableRepeatedMessage(dst, field, j),
seen);
}
} else {
OrderMapEntries(&src->GetReflection()->GetMessage(*src, field),
dst->GetReflection()->MutableMessage(dst, field), seen);
}
}
}
}
// Copies a Message, keeping map entries in order.
std::unique_ptr<Message> CloneMessage(const Message* message) {
std::unique_ptr<Message> result(message->New());
result->CopyFrom(*message);
OrderMapEntries(message, result.get());
return result;
}
using MessageMap = std::map<std::string, std::unique_ptr<Message>>;
// For a non-repeated field, move the most recently parsed field value
// into the most recently parsed template expression.
void StowFieldValue(Message* message, TemplateExpression* expression) {
void StowFieldValue(Message* message, TemplateExpression* expression,
MessageMap* stowed_messages) {
const Reflection* reflection = message->GetReflection();
const Descriptor* descriptor = message->GetDescriptor();
ProtoUtilLite::ProtoPath path;
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
int field_number = path[path.size() - 1].field_id;
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
// Save each stowed message unserialized preserving map entry order.
if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) {
(*stowed_messages)[ProtoPathJoin(path)] =
CloneMessage(GetFieldMessage(*message, field, 0));
}
if (!field->is_repeated()) {
std::vector<ProtoUtilLite::FieldValue> field_values;
SerializeField(message, field, &field_values);
@ -1417,37 +1523,6 @@ const FieldDescriptor* FindFieldByNumber(const Message* message,
return result;
}
// Returns the message value from a field at an index.
const Message* GetFieldMessage(const Message& message,
const FieldDescriptor* field, int index) {
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
return nullptr;
}
if (!field->is_repeated()) {
return &message.GetReflection()->GetMessage(message, field);
}
if (index < message.GetReflection()->FieldSize(message, field)) {
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
}
return nullptr;
}
// Serialize a ProtoPath as a readable string.
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
std::string ProtoPathJoin(ProtoPath path) {
std::string result;
for (ProtoUtilLite::ProtoPathEntry& e : path) {
if (e.field_id >= 0) {
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
} else if (e.map_id >= 0) {
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
"]");
}
}
return result;
}
// Returns the protobuf map key types from a ProtoPath.
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
std::vector<FieldType> result;
@ -1473,9 +1548,29 @@ std::string GetMapKey(const Message& map_entry) {
return "";
}
// Returns a Message store in CalculatorGraphTemplate::field_value.
Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) {
auto it = stowed_messages->find(ProtoPathJoin(proto_path));
return (it != stowed_messages->end()) ? it->second.get() : nullptr;
}
const Message* GetNestedMessage(const Message& message,
const FieldDescriptor* field,
ProtoPath proto_path,
MessageMap* stowed_messages) {
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
return nullptr;
}
const Message* result = FindStowedMessage(stowed_messages, proto_path);
if (!result) {
result = GetFieldMessage(message, field, proto_path.back().index);
}
return result;
}
// Adjusts map-entries from indexes to keys.
// Protobuf map-entry order is intentionally not preserved.
mediapipe::Status KeyProtoMapEntries(Message* source) {
absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) {
// Copy the rules from the source CalculatorGraphTemplate.
mediapipe::CalculatorGraphTemplate rules;
rules.ParsePartialFromString(source->SerializePartialAsString());
@ -1489,11 +1584,14 @@ mediapipe::Status KeyProtoMapEntries(Message* source) {
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
for (int j = 0; j < path.size(); ++j) {
int field_id = path[j].field_id;
int field_index = path[j].index;
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
ProtoPath prefix = {path.begin(), path.begin() + j + 1};
message = GetNestedMessage(*message, field, prefix, stowed_messages);
if (!message) {
break;
}
if (field->is_map()) {
const Message* map_entry =
GetFieldMessage(*message, field, path[j].index);
const Message* map_entry = message;
int key_id =
map_entry->GetDescriptor()->FindFieldByName("key")->number();
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
@ -1501,10 +1599,6 @@ mediapipe::Status KeyProtoMapEntries(Message* source) {
std::string key_value = GetMapKey(*map_entry);
path[j] = {field_id, key_id, key_type, key_value};
}
message = GetFieldMessage(*message, field, field_index);
if (!message) {
break;
}
}
if (!rule->path().empty()) {
*rule->mutable_path() = ProtoPathJoin(path);
@ -1539,7 +1633,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
// Copy the template rules into the output template "rule" field.
success &= MergeFields(template_rules_, output).ok();
// Replace map-entry indexes with map keys.
success &= KeyProtoMapEntries(output).ok();
success &= KeyProtoMapEntries(output, &stowed_messages_).ok();
return success;
}
@ -1565,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
DO(ConsumeFieldTemplate(message));
} else {
DO(ConsumeField(message));
StowFieldValue(message, expression);
StowFieldValue(message, expression, &stowed_messages_);
}
DO(ConsumeEndTemplate());
return true;
@ -1776,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
}
mediapipe::CalculatorGraphTemplate template_rules_;
std::map<std::string, std::unique_ptr<Message>> stowed_messages_;
};
#undef DO