Internal change
PiperOrigin-RevId: 495058817
This commit is contained in:
parent
78597c5b37
commit
db404b1a85
|
@ -48,11 +48,16 @@ class ProtoUtilLite {
|
||||||
key_id(key_id),
|
key_id(key_id),
|
||||||
key_type(key_type),
|
key_type(key_type),
|
||||||
key_value(std::move(key_value)) {}
|
key_value(std::move(key_value)) {}
|
||||||
|
bool operator==(const ProtoPathEntry& o) const {
|
||||||
|
return field_id == o.field_id && index == o.index && map_id == o.map_id &&
|
||||||
|
key_id == o.key_id && key_type == o.key_type &&
|
||||||
|
key_value == o.key_value;
|
||||||
|
}
|
||||||
int field_id = -1;
|
int field_id = -1;
|
||||||
int index = -1;
|
int index = -1;
|
||||||
int map_id = -1;
|
int map_id = -1;
|
||||||
int key_id = -1;
|
int key_id = -1;
|
||||||
FieldType key_type;
|
FieldType key_type = FieldType::MAX_FIELD_TYPE;
|
||||||
FieldValue key_value;
|
FieldValue key_value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1367,26 +1367,132 @@ absl::Status ProtoPathSplit(const std::string& path,
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a message serialized deterministically.
|
||||||
|
bool DeterministicallySerialize(const Message& proto, std::string* result) {
|
||||||
|
proto_ns::io::StringOutputStream stream(result);
|
||||||
|
proto_ns::io::CodedOutputStream output(&stream);
|
||||||
|
output.SetSerializationDeterministic(true);
|
||||||
|
return proto.SerializeToCodedStream(&output);
|
||||||
|
}
|
||||||
|
|
||||||
// Serialize one field of a message.
|
// Serialize one field of a message.
|
||||||
void SerializeField(const Message* message, const FieldDescriptor* field,
|
void SerializeField(const Message* message, const FieldDescriptor* field,
|
||||||
std::vector<ProtoUtilLite::FieldValue>* result) {
|
std::vector<ProtoUtilLite::FieldValue>* result) {
|
||||||
ProtoUtilLite::FieldValue message_bytes;
|
ProtoUtilLite::FieldValue message_bytes;
|
||||||
CHECK(message->SerializePartialToString(&message_bytes));
|
CHECK(DeterministicallySerialize(*message, &message_bytes));
|
||||||
ProtoUtilLite::FieldAccess access(
|
ProtoUtilLite::FieldAccess access(
|
||||||
field->number(), static_cast<ProtoUtilLite::FieldType>(field->type()));
|
field->number(), static_cast<ProtoUtilLite::FieldType>(field->type()));
|
||||||
MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes));
|
MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes));
|
||||||
*result = *access.mutable_field_values();
|
*result = *access.mutable_field_values();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serialize a ProtoPath as a readable string.
|
||||||
|
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
|
||||||
|
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
|
||||||
|
std::string ProtoPathJoin(ProtoPath path) {
|
||||||
|
std::string result;
|
||||||
|
for (ProtoUtilLite::ProtoPathEntry& e : path) {
|
||||||
|
if (e.field_id >= 0) {
|
||||||
|
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
|
||||||
|
} else if (e.map_id >= 0) {
|
||||||
|
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
|
||||||
|
"]");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the message value from a field at an index.
|
||||||
|
const Message* GetFieldMessage(const Message& message,
|
||||||
|
const FieldDescriptor* field, int index) {
|
||||||
|
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (!field->is_repeated()) {
|
||||||
|
return &message.GetReflection()->GetMessage(message, field);
|
||||||
|
}
|
||||||
|
if (index < message.GetReflection()->FieldSize(message, field)) {
|
||||||
|
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns all FieldDescriptors including extensions.
|
||||||
|
std::vector<const FieldDescriptor*> GetFields(const Message* src) {
|
||||||
|
std::vector<const FieldDescriptor*> result;
|
||||||
|
src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(),
|
||||||
|
&result);
|
||||||
|
for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) {
|
||||||
|
result.push_back(src->GetDescriptor()->field(i));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Orders map entries in dst to match src.
|
||||||
|
void OrderMapEntries(const Message* src, Message* dst,
|
||||||
|
std::set<const Message*>* seen = nullptr) {
|
||||||
|
std::unique_ptr<std::set<const Message*>> seen_owner;
|
||||||
|
if (!seen) {
|
||||||
|
seen_owner = std::make_unique<std::set<const Message*>>();
|
||||||
|
seen = seen_owner.get();
|
||||||
|
}
|
||||||
|
if (seen->count(src) > 0) {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
seen->insert(src);
|
||||||
|
}
|
||||||
|
for (auto field : GetFields(src)) {
|
||||||
|
if (field->is_map()) {
|
||||||
|
dst->GetReflection()->ClearField(dst, field);
|
||||||
|
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
|
||||||
|
const Message& entry =
|
||||||
|
src->GetReflection()->GetRepeatedMessage(*src, field, j);
|
||||||
|
dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
|
||||||
|
if (field->is_repeated()) {
|
||||||
|
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
|
||||||
|
OrderMapEntries(
|
||||||
|
&src->GetReflection()->GetRepeatedMessage(*src, field, j),
|
||||||
|
dst->GetReflection()->MutableRepeatedMessage(dst, field, j),
|
||||||
|
seen);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
OrderMapEntries(&src->GetReflection()->GetMessage(*src, field),
|
||||||
|
dst->GetReflection()->MutableMessage(dst, field), seen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies a Message, keeping map entries in order.
|
||||||
|
std::unique_ptr<Message> CloneMessage(const Message* message) {
|
||||||
|
std::unique_ptr<Message> result(message->New());
|
||||||
|
result->CopyFrom(*message);
|
||||||
|
OrderMapEntries(message, result.get());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
using MessageMap = std::map<std::string, std::unique_ptr<Message>>;
|
||||||
|
|
||||||
// For a non-repeated field, move the most recently parsed field value
|
// For a non-repeated field, move the most recently parsed field value
|
||||||
// into the most recently parsed template expression.
|
// into the most recently parsed template expression.
|
||||||
void StowFieldValue(Message* message, TemplateExpression* expression) {
|
void StowFieldValue(Message* message, TemplateExpression* expression,
|
||||||
|
MessageMap* stowed_messages) {
|
||||||
const Reflection* reflection = message->GetReflection();
|
const Reflection* reflection = message->GetReflection();
|
||||||
const Descriptor* descriptor = message->GetDescriptor();
|
const Descriptor* descriptor = message->GetDescriptor();
|
||||||
ProtoUtilLite::ProtoPath path;
|
ProtoUtilLite::ProtoPath path;
|
||||||
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
|
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
|
||||||
int field_number = path[path.size() - 1].field_id;
|
int field_number = path[path.size() - 1].field_id;
|
||||||
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
|
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
|
||||||
|
|
||||||
|
// Save each stowed message unserialized preserving map entry order.
|
||||||
|
if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) {
|
||||||
|
(*stowed_messages)[ProtoPathJoin(path)] =
|
||||||
|
CloneMessage(GetFieldMessage(*message, field, 0));
|
||||||
|
}
|
||||||
|
|
||||||
if (!field->is_repeated()) {
|
if (!field->is_repeated()) {
|
||||||
std::vector<ProtoUtilLite::FieldValue> field_values;
|
std::vector<ProtoUtilLite::FieldValue> field_values;
|
||||||
SerializeField(message, field, &field_values);
|
SerializeField(message, field, &field_values);
|
||||||
|
@ -1417,37 +1523,6 @@ const FieldDescriptor* FindFieldByNumber(const Message* message,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the message value from a field at an index.
|
|
||||||
const Message* GetFieldMessage(const Message& message,
|
|
||||||
const FieldDescriptor* field, int index) {
|
|
||||||
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (!field->is_repeated()) {
|
|
||||||
return &message.GetReflection()->GetMessage(message, field);
|
|
||||||
}
|
|
||||||
if (index < message.GetReflection()->FieldSize(message, field)) {
|
|
||||||
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize a ProtoPath as a readable string.
|
|
||||||
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
|
|
||||||
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
|
|
||||||
std::string ProtoPathJoin(ProtoPath path) {
|
|
||||||
std::string result;
|
|
||||||
for (ProtoUtilLite::ProtoPathEntry& e : path) {
|
|
||||||
if (e.field_id >= 0) {
|
|
||||||
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
|
|
||||||
} else if (e.map_id >= 0) {
|
|
||||||
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
|
|
||||||
"]");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the protobuf map key types from a ProtoPath.
|
// Returns the protobuf map key types from a ProtoPath.
|
||||||
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
|
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
|
||||||
std::vector<FieldType> result;
|
std::vector<FieldType> result;
|
||||||
|
@ -1473,9 +1548,29 @@ std::string GetMapKey(const Message& map_entry) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a Message store in CalculatorGraphTemplate::field_value.
|
||||||
|
Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) {
|
||||||
|
auto it = stowed_messages->find(ProtoPathJoin(proto_path));
|
||||||
|
return (it != stowed_messages->end()) ? it->second.get() : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Message* GetNestedMessage(const Message& message,
|
||||||
|
const FieldDescriptor* field,
|
||||||
|
ProtoPath proto_path,
|
||||||
|
MessageMap* stowed_messages) {
|
||||||
|
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
const Message* result = FindStowedMessage(stowed_messages, proto_path);
|
||||||
|
if (!result) {
|
||||||
|
result = GetFieldMessage(message, field, proto_path.back().index);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// Adjusts map-entries from indexes to keys.
|
// Adjusts map-entries from indexes to keys.
|
||||||
// Protobuf map-entry order is intentionally not preserved.
|
// Protobuf map-entry order is intentionally not preserved.
|
||||||
mediapipe::Status KeyProtoMapEntries(Message* source) {
|
absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) {
|
||||||
// Copy the rules from the source CalculatorGraphTemplate.
|
// Copy the rules from the source CalculatorGraphTemplate.
|
||||||
mediapipe::CalculatorGraphTemplate rules;
|
mediapipe::CalculatorGraphTemplate rules;
|
||||||
rules.ParsePartialFromString(source->SerializePartialAsString());
|
rules.ParsePartialFromString(source->SerializePartialAsString());
|
||||||
|
@ -1489,11 +1584,14 @@ mediapipe::Status KeyProtoMapEntries(Message* source) {
|
||||||
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
|
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
|
||||||
for (int j = 0; j < path.size(); ++j) {
|
for (int j = 0; j < path.size(); ++j) {
|
||||||
int field_id = path[j].field_id;
|
int field_id = path[j].field_id;
|
||||||
int field_index = path[j].index;
|
|
||||||
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
|
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
|
||||||
|
ProtoPath prefix = {path.begin(), path.begin() + j + 1};
|
||||||
|
message = GetNestedMessage(*message, field, prefix, stowed_messages);
|
||||||
|
if (!message) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
if (field->is_map()) {
|
if (field->is_map()) {
|
||||||
const Message* map_entry =
|
const Message* map_entry = message;
|
||||||
GetFieldMessage(*message, field, path[j].index);
|
|
||||||
int key_id =
|
int key_id =
|
||||||
map_entry->GetDescriptor()->FindFieldByName("key")->number();
|
map_entry->GetDescriptor()->FindFieldByName("key")->number();
|
||||||
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
|
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
|
||||||
|
@ -1501,10 +1599,6 @@ mediapipe::Status KeyProtoMapEntries(Message* source) {
|
||||||
std::string key_value = GetMapKey(*map_entry);
|
std::string key_value = GetMapKey(*map_entry);
|
||||||
path[j] = {field_id, key_id, key_type, key_value};
|
path[j] = {field_id, key_id, key_type, key_value};
|
||||||
}
|
}
|
||||||
message = GetFieldMessage(*message, field, field_index);
|
|
||||||
if (!message) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (!rule->path().empty()) {
|
if (!rule->path().empty()) {
|
||||||
*rule->mutable_path() = ProtoPathJoin(path);
|
*rule->mutable_path() = ProtoPathJoin(path);
|
||||||
|
@ -1539,7 +1633,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
||||||
// Copy the template rules into the output template "rule" field.
|
// Copy the template rules into the output template "rule" field.
|
||||||
success &= MergeFields(template_rules_, output).ok();
|
success &= MergeFields(template_rules_, output).ok();
|
||||||
// Replace map-entry indexes with map keys.
|
// Replace map-entry indexes with map keys.
|
||||||
success &= KeyProtoMapEntries(output).ok();
|
success &= KeyProtoMapEntries(output, &stowed_messages_).ok();
|
||||||
return success;
|
return success;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1565,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
||||||
DO(ConsumeFieldTemplate(message));
|
DO(ConsumeFieldTemplate(message));
|
||||||
} else {
|
} else {
|
||||||
DO(ConsumeField(message));
|
DO(ConsumeField(message));
|
||||||
StowFieldValue(message, expression);
|
StowFieldValue(message, expression, &stowed_messages_);
|
||||||
}
|
}
|
||||||
DO(ConsumeEndTemplate());
|
DO(ConsumeEndTemplate());
|
||||||
return true;
|
return true;
|
||||||
|
@ -1776,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
mediapipe::CalculatorGraphTemplate template_rules_;
|
mediapipe::CalculatorGraphTemplate template_rules_;
|
||||||
|
std::map<std::string, std::unique_ptr<Message>> stowed_messages_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#undef DO
|
#undef DO
|
||||||
|
|
Loading…
Reference in New Issue
Block a user