From db404b1a8593a8b316cc4930dc1bcc845fc3df62 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Tue, 13 Dec 2022 10:21:07 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 495058817 --- mediapipe/framework/tool/proto_util_lite.h | 7 +- mediapipe/framework/tool/template_parser.cc | 181 +++++++++++++++----- 2 files changed, 144 insertions(+), 44 deletions(-) diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index d71ceac83..15e321eeb 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -48,11 +48,16 @@ class ProtoUtilLite { key_id(key_id), key_type(key_type), key_value(std::move(key_value)) {} + bool operator==(const ProtoPathEntry& o) const { + return field_id == o.field_id && index == o.index && map_id == o.map_id && + key_id == o.key_id && key_type == o.key_type && + key_value == o.key_value; + } int field_id = -1; int index = -1; int map_id = -1; int key_id = -1; - FieldType key_type; + FieldType key_type = FieldType::MAX_FIELD_TYPE; FieldValue key_value; }; diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index 5a0ceccd3..cf23f3443 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -1367,26 +1367,132 @@ absl::Status ProtoPathSplit(const std::string& path, return status; } +// Returns a message serialized deterministically. +bool DeterministicallySerialize(const Message& proto, std::string* result) { + proto_ns::io::StringOutputStream stream(result); + proto_ns::io::CodedOutputStream output(&stream); + output.SetSerializationDeterministic(true); + return proto.SerializeToCodedStream(&output); +} + // Serialize one field of a message. void SerializeField(const Message* message, const FieldDescriptor* field, std::vector* result) { ProtoUtilLite::FieldValue message_bytes; - CHECK(message->SerializePartialToString(&message_bytes)); + CHECK(DeterministicallySerialize(*message, &message_bytes)); ProtoUtilLite::FieldAccess access( field->number(), static_cast(field->type())); MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes)); *result = *access.mutable_field_values(); } +// Serialize a ProtoPath as a readable string. +// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", +// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". +std::string ProtoPathJoin(ProtoPath path) { + std::string result; + for (ProtoUtilLite::ProtoPathEntry& e : path) { + if (e.field_id >= 0) { + absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); + } else if (e.map_id >= 0) { + absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, + "]"); + } + } + return result; +} + +// Returns the message value from a field at an index. +const Message* GetFieldMessage(const Message& message, + const FieldDescriptor* field, int index) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + if (!field->is_repeated()) { + return &message.GetReflection()->GetMessage(message, field); + } + if (index < message.GetReflection()->FieldSize(message, field)) { + return &message.GetReflection()->GetRepeatedMessage(message, field, index); + } + return nullptr; +} + +// Returns all FieldDescriptors including extensions. +std::vector GetFields(const Message* src) { + std::vector result; + src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(), + &result); + for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) { + result.push_back(src->GetDescriptor()->field(i)); + } + return result; +} + +// Orders map entries in dst to match src. +void OrderMapEntries(const Message* src, Message* dst, + std::set* seen = nullptr) { + std::unique_ptr> seen_owner; + if (!seen) { + seen_owner = std::make_unique>(); + seen = seen_owner.get(); + } + if (seen->count(src) > 0) { + return; + } else { + seen->insert(src); + } + for (auto field : GetFields(src)) { + if (field->is_map()) { + dst->GetReflection()->ClearField(dst, field); + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + const Message& entry = + src->GetReflection()->GetRepeatedMessage(*src, field, j); + dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry); + } + } + if (field->type() == FieldDescriptor::TYPE_MESSAGE) { + if (field->is_repeated()) { + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + OrderMapEntries( + &src->GetReflection()->GetRepeatedMessage(*src, field, j), + dst->GetReflection()->MutableRepeatedMessage(dst, field, j), + seen); + } + } else { + OrderMapEntries(&src->GetReflection()->GetMessage(*src, field), + dst->GetReflection()->MutableMessage(dst, field), seen); + } + } + } +} + +// Copies a Message, keeping map entries in order. +std::unique_ptr CloneMessage(const Message* message) { + std::unique_ptr result(message->New()); + result->CopyFrom(*message); + OrderMapEntries(message, result.get()); + return result; +} + +using MessageMap = std::map>; + // For a non-repeated field, move the most recently parsed field value // into the most recently parsed template expression. -void StowFieldValue(Message* message, TemplateExpression* expression) { +void StowFieldValue(Message* message, TemplateExpression* expression, + MessageMap* stowed_messages) { const Reflection* reflection = message->GetReflection(); const Descriptor* descriptor = message->GetDescriptor(); ProtoUtilLite::ProtoPath path; MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path)); int field_number = path[path.size() - 1].field_id; const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number); + + // Save each stowed message unserialized preserving map entry order. + if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) { + (*stowed_messages)[ProtoPathJoin(path)] = + CloneMessage(GetFieldMessage(*message, field, 0)); + } + if (!field->is_repeated()) { std::vector field_values; SerializeField(message, field, &field_values); @@ -1417,37 +1523,6 @@ const FieldDescriptor* FindFieldByNumber(const Message* message, return result; } -// Returns the message value from a field at an index. -const Message* GetFieldMessage(const Message& message, - const FieldDescriptor* field, int index) { - if (field->type() != FieldDescriptor::TYPE_MESSAGE) { - return nullptr; - } - if (!field->is_repeated()) { - return &message.GetReflection()->GetMessage(message, field); - } - if (index < message.GetReflection()->FieldSize(message, field)) { - return &message.GetReflection()->GetRepeatedMessage(message, field, index); - } - return nullptr; -} - -// Serialize a ProtoPath as a readable string. -// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", -// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". -std::string ProtoPathJoin(ProtoPath path) { - std::string result; - for (ProtoUtilLite::ProtoPathEntry& e : path) { - if (e.field_id >= 0) { - absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); - } else if (e.map_id >= 0) { - absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, - "]"); - } - } - return result; -} - // Returns the protobuf map key types from a ProtoPath. std::vector ProtoPathKeyTypes(ProtoPath path) { std::vector result; @@ -1473,9 +1548,29 @@ std::string GetMapKey(const Message& map_entry) { return ""; } +// Returns a Message store in CalculatorGraphTemplate::field_value. +Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) { + auto it = stowed_messages->find(ProtoPathJoin(proto_path)); + return (it != stowed_messages->end()) ? it->second.get() : nullptr; +} + +const Message* GetNestedMessage(const Message& message, + const FieldDescriptor* field, + ProtoPath proto_path, + MessageMap* stowed_messages) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + const Message* result = FindStowedMessage(stowed_messages, proto_path); + if (!result) { + result = GetFieldMessage(message, field, proto_path.back().index); + } + return result; +} + // Adjusts map-entries from indexes to keys. // Protobuf map-entry order is intentionally not preserved. -mediapipe::Status KeyProtoMapEntries(Message* source) { +absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) { // Copy the rules from the source CalculatorGraphTemplate. mediapipe::CalculatorGraphTemplate rules; rules.ParsePartialFromString(source->SerializePartialAsString()); @@ -1489,11 +1584,14 @@ mediapipe::Status KeyProtoMapEntries(Message* source) { MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path)); for (int j = 0; j < path.size(); ++j) { int field_id = path[j].field_id; - int field_index = path[j].index; const FieldDescriptor* field = FindFieldByNumber(message, field_id); + ProtoPath prefix = {path.begin(), path.begin() + j + 1}; + message = GetNestedMessage(*message, field, prefix, stowed_messages); + if (!message) { + break; + } if (field->is_map()) { - const Message* map_entry = - GetFieldMessage(*message, field, path[j].index); + const Message* map_entry = message; int key_id = map_entry->GetDescriptor()->FindFieldByName("key")->number(); FieldType key_type = static_cast( @@ -1501,10 +1599,6 @@ mediapipe::Status KeyProtoMapEntries(Message* source) { std::string key_value = GetMapKey(*map_entry); path[j] = {field_id, key_id, key_type, key_value}; } - message = GetFieldMessage(*message, field, field_index); - if (!message) { - break; - } } if (!rule->path().empty()) { *rule->mutable_path() = ProtoPathJoin(path); @@ -1539,7 +1633,7 @@ class TemplateParser::Parser::MediaPipeParserImpl // Copy the template rules into the output template "rule" field. success &= MergeFields(template_rules_, output).ok(); // Replace map-entry indexes with map keys. - success &= KeyProtoMapEntries(output).ok(); + success &= KeyProtoMapEntries(output, &stowed_messages_).ok(); return success; } @@ -1565,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl DO(ConsumeFieldTemplate(message)); } else { DO(ConsumeField(message)); - StowFieldValue(message, expression); + StowFieldValue(message, expression, &stowed_messages_); } DO(ConsumeEndTemplate()); return true; @@ -1776,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl } mediapipe::CalculatorGraphTemplate template_rules_; + std::map> stowed_messages_; }; #undef DO