Internal change
PiperOrigin-RevId: 494420725
This commit is contained in:
parent
e9bb51a524
commit
421f789ede
|
@ -346,6 +346,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -506,6 +507,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":proto_util_lite",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/deps:proto_descriptor_cc_proto",
|
||||
"//mediapipe/framework/port:advanced_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
|
|
|
@ -27,6 +27,9 @@ message TemplateExpression {
|
|||
// The FieldDescriptor::Type of the modified field.
|
||||
optional mediapipe.FieldDescriptorProto.Type field_type = 5;
|
||||
|
||||
// The FieldDescriptor::Type of each map key in the path.
|
||||
repeated mediapipe.FieldDescriptorProto.Type key_type = 6;
|
||||
|
||||
// Alternative value for the modified field, in protobuf binary format.
|
||||
optional string field_value = 7;
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "mediapipe/framework/tool/field_data.pb.h"
|
||||
#include "mediapipe/framework/type_map.h"
|
||||
|
||||
|
@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
|
|||
|
||||
// Extracts the data value(s) for one field from a serialized message.
|
||||
// The message with these field values removed is written to |out|.
|
||||
absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
|
||||
CodedInputStream* in, CodedOutputStream* out,
|
||||
absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in,
|
||||
CodedOutputStream* out,
|
||||
std::vector<std::string>* field_values) {
|
||||
uint32 tag;
|
||||
while ((tag = in->ReadTag()) != 0) {
|
||||
int field_number = WireFormatLite::GetTagFieldNumber(tag);
|
||||
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
|
||||
if (field_number == field_id) {
|
||||
if (!IsLengthDelimited(wire_type) &&
|
||||
IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) {
|
||||
|
@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) {
|
|||
CodedInputStream in(&ais);
|
||||
StringOutputStream sos(&message_);
|
||||
CodedOutputStream out(&sos);
|
||||
WireFormatLite::WireType wire_type =
|
||||
WireFormatLite::WireTypeForFieldType(field_type_);
|
||||
return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_);
|
||||
return GetFieldValues(field_id_, &in, &out, &field_values_);
|
||||
}
|
||||
|
||||
void FieldAccess::GetMessage(std::string* result) {
|
||||
|
@ -149,18 +149,56 @@ std::vector<FieldValue>* FieldAccess::mutable_field_values() {
|
|||
return &field_values_;
|
||||
}
|
||||
|
||||
namespace {
|
||||
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
|
||||
|
||||
// Returns the FieldAccess and index for a field-id or a map-id.
|
||||
// Returns access to the field-id if the field index is found,
|
||||
// to the map-id if the map entry is found, and to the field-id otherwise.
|
||||
absl::StatusOr<std::pair<FieldAccess, int>> AccessField(
|
||||
const ProtoPathEntry& entry, FieldType field_type,
|
||||
const FieldValue& message) {
|
||||
FieldAccess result(entry.field_id, field_type);
|
||||
if (entry.field_id >= 0) {
|
||||
MP_RETURN_IF_ERROR(result.SetMessage(message));
|
||||
if (entry.index < result.mutable_field_values()->size()) {
|
||||
return std::pair(result, entry.index);
|
||||
}
|
||||
}
|
||||
if (entry.map_id >= 0) {
|
||||
FieldAccess access(entry.map_id, field_type);
|
||||
MP_RETURN_IF_ERROR(access.SetMessage(message));
|
||||
auto& field_values = *access.mutable_field_values();
|
||||
for (int index = 0; index < field_values.size(); ++index) {
|
||||
FieldAccess key(entry.key_id, entry.key_type);
|
||||
MP_RETURN_IF_ERROR(key.SetMessage(field_values[index]));
|
||||
if (key.mutable_field_values()->at(0) == entry.key_value) {
|
||||
return std::pair(std::move(access), index);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (entry.field_id >= 0) {
|
||||
return std::pair(result, entry.index);
|
||||
}
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ",
|
||||
entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Replaces a range of field values for one field nested within a protobuf.
|
||||
absl::Status ProtoUtilLite::ReplaceFieldRange(
|
||||
FieldValue* message, ProtoPath proto_path, int length, FieldType field_type,
|
||||
const std::vector<FieldValue>& field_values) {
|
||||
int field_id, index;
|
||||
std::tie(field_id, index) = proto_path.front();
|
||||
ProtoPathEntry entry = proto_path.front();
|
||||
proto_path.erase(proto_path.begin());
|
||||
FieldAccess access(field_id, !proto_path.empty()
|
||||
? WireFormatLite::TYPE_MESSAGE
|
||||
: field_type);
|
||||
MP_RETURN_IF_ERROR(access.SetMessage(*message));
|
||||
std::vector<std::string>& v = *access.mutable_field_values();
|
||||
FieldType type =
|
||||
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
|
||||
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message));
|
||||
FieldAccess& access = r.first;
|
||||
int index = r.second;
|
||||
std::vector<FieldValue>& v = *access.mutable_field_values();
|
||||
if (!proto_path.empty()) {
|
||||
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
|
||||
|
@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
|
|||
absl::Status ProtoUtilLite::GetFieldRange(
|
||||
const FieldValue& message, ProtoPath proto_path, int length,
|
||||
FieldType field_type, std::vector<FieldValue>* field_values) {
|
||||
int field_id, index;
|
||||
std::tie(field_id, index) = proto_path.front();
|
||||
ProtoPathEntry entry = proto_path.front();
|
||||
proto_path.erase(proto_path.begin());
|
||||
FieldAccess access(field_id, !proto_path.empty()
|
||||
? WireFormatLite::TYPE_MESSAGE
|
||||
: field_type);
|
||||
MP_RETURN_IF_ERROR(access.SetMessage(message));
|
||||
std::vector<std::string>& v = *access.mutable_field_values();
|
||||
FieldType type =
|
||||
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
|
||||
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
|
||||
FieldAccess& access = r.first;
|
||||
int index = r.second;
|
||||
std::vector<FieldValue>& v = *access.mutable_field_values();
|
||||
if (!proto_path.empty()) {
|
||||
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||
MP_RETURN_IF_ERROR(
|
||||
GetFieldRange(v[index], proto_path, length, field_type, field_values));
|
||||
} else {
|
||||
if (length == -1) {
|
||||
length = v.size() - index;
|
||||
}
|
||||
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
|
||||
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
|
||||
field_values->insert(field_values->begin(), v.begin() + index,
|
||||
|
@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message,
|
|||
ProtoPath proto_path,
|
||||
FieldType field_type,
|
||||
int* field_count) {
|
||||
int field_id, index;
|
||||
std::tie(field_id, index) = proto_path.back();
|
||||
proto_path.pop_back();
|
||||
std::vector<std::string> parent;
|
||||
if (proto_path.empty()) {
|
||||
parent.push_back(std::string(message));
|
||||
ProtoPathEntry entry = proto_path.front();
|
||||
proto_path.erase(proto_path.begin());
|
||||
FieldType type =
|
||||
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
|
||||
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
|
||||
FieldAccess& access = r.first;
|
||||
int index = r.second;
|
||||
std::vector<FieldValue>& v = *access.mutable_field_values();
|
||||
if (!proto_path.empty()) {
|
||||
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||
MP_RETURN_IF_ERROR(
|
||||
GetFieldCount(v[index], proto_path, field_type, field_count));
|
||||
} else {
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
|
||||
message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
|
||||
*field_count = v.size();
|
||||
}
|
||||
FieldAccess access(field_id, field_type);
|
||||
MP_RETURN_IF_ERROR(access.SetMessage(parent[0]));
|
||||
*field_count = access.mutable_field_values()->size();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -34,15 +34,31 @@ class ProtoUtilLite {
|
|||
// Defines field types and tag formats.
|
||||
using WireFormatLite = proto_ns::internal::WireFormatLite;
|
||||
|
||||
// Defines a sequence of nested field-number field-index pairs.
|
||||
using ProtoPath = std::vector<std::pair<int, int>>;
|
||||
|
||||
// The serialized value for a protobuf field.
|
||||
using FieldValue = std::string;
|
||||
|
||||
// The serialized data type for a protobuf field.
|
||||
using FieldType = WireFormatLite::FieldType;
|
||||
|
||||
// A field-id and index, or a map-id and key, or both.
|
||||
struct ProtoPathEntry {
|
||||
ProtoPathEntry(int id, int index) : field_id(id), index(index) {}
|
||||
ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value)
|
||||
: map_id(id),
|
||||
key_id(key_id),
|
||||
key_type(key_type),
|
||||
key_value(std::move(key_value)) {}
|
||||
int field_id = -1;
|
||||
int index = -1;
|
||||
int map_id = -1;
|
||||
int key_id = -1;
|
||||
FieldType key_type;
|
||||
FieldValue key_value;
|
||||
};
|
||||
|
||||
// Defines a sequence of nested field-number field-index pairs.
|
||||
using ProtoPath = std::vector<ProtoPathEntry>;
|
||||
|
||||
class FieldAccess {
|
||||
public:
|
||||
// Provides access to a certain protobuf field.
|
||||
|
@ -57,9 +73,11 @@ class ProtoUtilLite {
|
|||
// Returns the serialized values of the protobuf field.
|
||||
std::vector<FieldValue>* mutable_field_values();
|
||||
|
||||
uint32 field_id() const { return field_id_; }
|
||||
|
||||
private:
|
||||
const uint32 field_id_;
|
||||
const FieldType field_type_;
|
||||
uint32 field_id_;
|
||||
FieldType field_type_;
|
||||
std::string message_;
|
||||
std::vector<FieldValue> field_values_;
|
||||
};
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
|
@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite;
|
|||
using FieldValue = ProtoUtilLite::FieldValue;
|
||||
using FieldType = ProtoUtilLite::FieldType;
|
||||
using ProtoPath = ProtoUtilLite::ProtoPath;
|
||||
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -84,26 +86,87 @@ std::unique_ptr<MessageLite> CloneMessage(const MessageLite& message) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Returns the (tag, index) pairs in a field path.
|
||||
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]".
|
||||
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
|
||||
absl::Status status;
|
||||
std::vector<std::string> ids = absl::StrSplit(path, '/');
|
||||
for (const std::string& id : ids) {
|
||||
if (id.length() > 0) {
|
||||
std::pair<std::string, std::string> id_pair =
|
||||
absl::StrSplit(id, absl::ByAnyChar("[]"));
|
||||
int tag = 0;
|
||||
int index = 0;
|
||||
bool ok = absl::SimpleAtoi(id_pair.first, &tag) &&
|
||||
absl::SimpleAtoi(id_pair.second, &index);
|
||||
if (!ok) {
|
||||
status.Update(absl::InvalidArgumentError(path));
|
||||
}
|
||||
result->push_back(std::make_pair(tag, index));
|
||||
// Parses one ProtoPathEntry.
|
||||
// The parsed entry is appended to `result` and removed from `path`.
|
||||
// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes
|
||||
// to serialize the key text to protobuf wire format.
|
||||
absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) {
|
||||
bool ok = true;
|
||||
int sb = path.find('[');
|
||||
int eb = path.find(']');
|
||||
int field_id = -1;
|
||||
ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id);
|
||||
auto selector = path.substr(sb + 1, eb - 1 - sb);
|
||||
if (absl::StartsWith(selector, "@")) {
|
||||
int eq = selector.find('=');
|
||||
int key_id = -1;
|
||||
ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id);
|
||||
auto key_text = selector.substr(eq + 1);
|
||||
FieldType key_type = FieldType::TYPE_STRING;
|
||||
result->push_back({field_id, key_id, key_type, std::string(key_text)});
|
||||
} else {
|
||||
int index = 0;
|
||||
ok &= absl::SimpleAtoi(selector, &index);
|
||||
result->push_back({field_id, index});
|
||||
}
|
||||
int end = path.find('/', eb);
|
||||
if (end == std::string::npos) {
|
||||
path = "";
|
||||
} else {
|
||||
path = path.substr(end + 1);
|
||||
}
|
||||
return ok ? absl::OkStatus()
|
||||
: absl::InvalidArgumentError(
|
||||
absl::StrCat("Failed to parse ProtoPath entry: ", path));
|
||||
}
|
||||
|
||||
// Specifies the FieldTypes for protobuf map keys in a ProtoPath.
|
||||
// Each ProtoPathEntry::key_value is converted from text to the protobuf
|
||||
// wire format for its key type.
|
||||
absl::Status SetMapKeyTypes(const std::vector<FieldType>& key_types,
|
||||
ProtoPath* result) {
|
||||
int i = 0;
|
||||
for (ProtoPathEntry& entry : *result) {
|
||||
if (entry.map_id >= 0) {
|
||||
FieldType key_type = key_types[i++];
|
||||
std::vector<FieldValue> key_value;
|
||||
MP_RETURN_IF_ERROR(
|
||||
ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value));
|
||||
entry.key_type = key_type;
|
||||
entry.key_value = key_value.front();
|
||||
}
|
||||
}
|
||||
return status;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns the (tag, index) pairs in a field path.
|
||||
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]",
|
||||
// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]".
|
||||
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
|
||||
result->clear();
|
||||
absl::string_view rest = path;
|
||||
if (absl::StartsWith(rest, "/")) {
|
||||
rest = rest.substr(1);
|
||||
}
|
||||
while (!rest.empty()) {
|
||||
MP_RETURN_IF_ERROR(ParseEntry(rest, result));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Parse the TemplateExpression.path field into a ProtoPath struct.
|
||||
absl::Status ParseProtoPath(const TemplateExpression& rule,
|
||||
std::string base_path, ProtoPath* result) {
|
||||
ProtoPath base_entries;
|
||||
MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries));
|
||||
MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result));
|
||||
std::vector<FieldType> key_types;
|
||||
for (int type : rule.key_type()) {
|
||||
key_types.push_back(static_cast<FieldType>(type));
|
||||
}
|
||||
MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result));
|
||||
result->erase(result->begin(), result->begin() + base_entries.size());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns true if one proto path is prefix by another.
|
||||
|
@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) {
|
|||
return absl::StartsWith(path, prefix);
|
||||
}
|
||||
|
||||
// Returns the part of one proto path after a prefix proto path.
|
||||
std::string ProtoPathRelative(const std::string& field_path,
|
||||
const std::string& base_path) {
|
||||
CHECK(ProtoPathStartsWith(field_path, base_path));
|
||||
return field_path.substr(base_path.length());
|
||||
}
|
||||
|
||||
// Returns the target ProtoUtilLite::FieldType of a rule.
|
||||
FieldType GetFieldType(const TemplateExpression& rule) {
|
||||
return static_cast<FieldType>(rule.field_type());
|
||||
|
@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) {
|
|||
// Returns the count of field values at a ProtoPath.
|
||||
int FieldCount(const FieldValue& base, ProtoPath field_path,
|
||||
FieldType field_type) {
|
||||
int field_id, index;
|
||||
std::tie(field_id, index) = field_path.back();
|
||||
field_path.pop_back();
|
||||
std::vector<FieldValue> parent;
|
||||
if (field_path.empty()) {
|
||||
parent.push_back(base);
|
||||
} else {
|
||||
MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange(
|
||||
base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
|
||||
}
|
||||
ProtoUtilLite::FieldAccess access(field_id, field_type);
|
||||
MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0]));
|
||||
return access.mutable_field_values()->size();
|
||||
int result = 0;
|
||||
CHECK(
|
||||
ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -229,9 +276,7 @@ class TemplateExpanderImpl {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
ProtoPath field_path;
|
||||
absl::Status status =
|
||||
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path);
|
||||
if (!status.ok()) return status;
|
||||
MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
|
||||
return ProtoUtilLite::GetFieldRange(output, field_path, 1,
|
||||
GetFieldType(rule), base);
|
||||
}
|
||||
|
@ -242,12 +287,13 @@ class TemplateExpanderImpl {
|
|||
const std::vector<FieldValue>& field_values,
|
||||
FieldValue* output) {
|
||||
if (!rule.has_path()) {
|
||||
*output = field_values[0];
|
||||
if (!field_values.empty()) {
|
||||
*output = field_values[0];
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
ProtoPath field_path;
|
||||
RET_CHECK_OK(
|
||||
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path));
|
||||
MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
|
||||
int field_count = 1;
|
||||
if (rule.has_field_value()) {
|
||||
// For a non-repeated field, only one value can be specified.
|
||||
|
@ -257,7 +303,7 @@ class TemplateExpanderImpl {
|
|||
"Multiple values specified for non-repeated field: ", rule.path()));
|
||||
}
|
||||
// For a non-repeated field, the field value is stored only in the rule.
|
||||
field_path[field_path.size() - 1].second = 0;
|
||||
field_path[field_path.size() - 1].index = 0;
|
||||
field_count = 0;
|
||||
}
|
||||
return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count,
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/deps/proto_descriptor.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
|
@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message;
|
|||
using mediapipe::proto_ns::OneofDescriptor;
|
||||
using mediapipe::proto_ns::Reflection;
|
||||
using mediapipe::proto_ns::TextFormat;
|
||||
using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath;
|
||||
using FieldType = mediapipe::tool::ProtoUtilLite::FieldType;
|
||||
using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue;
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -1357,7 +1361,7 @@ absl::Status ProtoPathSplit(const std::string& path,
|
|||
if (!ok) {
|
||||
status.Update(absl::InvalidArgumentError(path));
|
||||
}
|
||||
result->push_back(std::make_pair(tag, index));
|
||||
result->push_back({tag, index});
|
||||
}
|
||||
}
|
||||
return status;
|
||||
|
@ -1381,7 +1385,7 @@ void StowFieldValue(Message* message, TemplateExpression* expression) {
|
|||
const Descriptor* descriptor = message->GetDescriptor();
|
||||
ProtoUtilLite::ProtoPath path;
|
||||
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
|
||||
int field_number = path[path.size() - 1].first;
|
||||
int field_number = path[path.size() - 1].field_id;
|
||||
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
|
||||
if (!field->is_repeated()) {
|
||||
std::vector<ProtoUtilLite::FieldValue> field_values;
|
||||
|
@ -1402,6 +1406,124 @@ static void StripQuotes(std::string* str) {
|
|||
}
|
||||
}
|
||||
|
||||
// Returns the field or extension for field number.
|
||||
const FieldDescriptor* FindFieldByNumber(const Message* message,
|
||||
int field_num) {
|
||||
const FieldDescriptor* result =
|
||||
message->GetDescriptor()->FindFieldByNumber(field_num);
|
||||
if (result == nullptr) {
|
||||
result = message->GetReflection()->FindKnownExtensionByNumber(field_num);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the message value from a field at an index.
|
||||
const Message* GetFieldMessage(const Message& message,
|
||||
const FieldDescriptor* field, int index) {
|
||||
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!field->is_repeated()) {
|
||||
return &message.GetReflection()->GetMessage(message, field);
|
||||
}
|
||||
if (index < message.GetReflection()->FieldSize(message, field)) {
|
||||
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Serialize a ProtoPath as a readable string.
|
||||
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
|
||||
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
|
||||
std::string ProtoPathJoin(ProtoPath path) {
|
||||
std::string result;
|
||||
for (ProtoUtilLite::ProtoPathEntry& e : path) {
|
||||
if (e.field_id >= 0) {
|
||||
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
|
||||
} else if (e.map_id >= 0) {
|
||||
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
|
||||
"]");
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the protobuf map key types from a ProtoPath.
|
||||
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
|
||||
std::vector<FieldType> result;
|
||||
for (auto& entry : path) {
|
||||
if (entry.map_id >= 0) {
|
||||
result.push_back(entry.key_type);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the text value for a string or numeric protobuf map key.
|
||||
std::string GetMapKey(const Message& map_entry) {
|
||||
auto key_field = map_entry.GetDescriptor()->FindFieldByName("key");
|
||||
auto reflection = map_entry.GetReflection();
|
||||
if (key_field->type() == FieldDescriptor::TYPE_STRING) {
|
||||
return reflection->GetString(map_entry, key_field);
|
||||
} else if (key_field->type() == FieldDescriptor::TYPE_INT32) {
|
||||
return absl::StrCat(reflection->GetInt32(map_entry, key_field));
|
||||
} else if (key_field->type() == FieldDescriptor::TYPE_INT64) {
|
||||
return absl::StrCat(reflection->GetInt64(map_entry, key_field));
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
// Adjusts map-entries from indexes to keys.
|
||||
// Protobuf map-entry order is intentionally not preserved.
|
||||
mediapipe::Status KeyProtoMapEntries(Message* source) {
|
||||
// Copy the rules from the source CalculatorGraphTemplate.
|
||||
mediapipe::CalculatorGraphTemplate rules;
|
||||
rules.ParsePartialFromString(source->SerializePartialAsString());
|
||||
// Only the "source" Message knows all extension types.
|
||||
Message* config_0 = source->GetReflection()->MutableMessage(
|
||||
source, source->GetDescriptor()->FindFieldByName("config"), nullptr);
|
||||
for (int i = 0; i < rules.rule().size(); ++i) {
|
||||
TemplateExpression* rule = rules.mutable_rule()->Mutable(i);
|
||||
const Message* message = config_0;
|
||||
ProtoPath path;
|
||||
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
|
||||
for (int j = 0; j < path.size(); ++j) {
|
||||
int field_id = path[j].field_id;
|
||||
int field_index = path[j].index;
|
||||
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
|
||||
if (field->is_map()) {
|
||||
const Message* map_entry =
|
||||
GetFieldMessage(*message, field, path[j].index);
|
||||
int key_id =
|
||||
map_entry->GetDescriptor()->FindFieldByName("key")->number();
|
||||
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
|
||||
map_entry->GetDescriptor()->FindFieldByName("key")->type());
|
||||
std::string key_value = GetMapKey(*map_entry);
|
||||
path[j] = {field_id, key_id, key_type, key_value};
|
||||
}
|
||||
message = GetFieldMessage(*message, field, field_index);
|
||||
if (!message) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!rule->path().empty()) {
|
||||
*rule->mutable_path() = ProtoPathJoin(path);
|
||||
for (FieldType key_type : ProtoPathKeyTypes(path)) {
|
||||
*rule->mutable_key_type()->Add() = key_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Copy the rules back into the source CalculatorGraphTemplate.
|
||||
auto source_rules =
|
||||
source->GetReflection()->GetMutableRepeatedFieldRef<Message>(
|
||||
source, source->GetDescriptor()->FindFieldByName("rule"));
|
||||
source_rules.Clear();
|
||||
for (auto& rule : rules.rule()) {
|
||||
source_rules.Add(rule);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class TemplateParser::Parser::MediaPipeParserImpl
|
||||
|
@ -1416,6 +1538,8 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
|||
|
||||
// Copy the template rules into the output template "rule" field.
|
||||
success &= MergeFields(template_rules_, output).ok();
|
||||
// Replace map-entry indexes with map keys.
|
||||
success &= KeyProtoMapEntries(output).ok();
|
||||
return success;
|
||||
}
|
||||
|
||||
|
|
10
mediapipe/framework/tool/testdata/BUILD
vendored
10
mediapipe/framework/tool/testdata/BUILD
vendored
|
@ -17,6 +17,7 @@ load(
|
|||
"//mediapipe/framework/tool:mediapipe_graph.bzl",
|
||||
"mediapipe_simple_subgraph",
|
||||
)
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
|
@ -58,3 +59,12 @@ mediapipe_simple_subgraph(
|
|||
"//mediapipe/framework:test_calculators",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "frozen_generator_proto",
|
||||
srcs = ["frozen_generator.proto"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [
|
||||
"//mediapipe/framework:packet_generator_proto",
|
||||
],
|
||||
)
|
||||
|
|
20
mediapipe/framework/tool/testdata/frozen_generator.proto
vendored
Normal file
20
mediapipe/framework/tool/testdata/frozen_generator.proto
vendored
Normal file
|
@ -0,0 +1,20 @@
|
|||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/packet_generator.proto";
|
||||
|
||||
message FrozenGeneratorOptions {
|
||||
extend mediapipe.PacketGeneratorOptions {
|
||||
optional FrozenGeneratorOptions ext = 225748738;
|
||||
}
|
||||
|
||||
// Path to file containing serialized proto of type tensorflow::GraphDef.
|
||||
optional string graph_proto_path = 1;
|
||||
|
||||
// This map defines the which streams are fed to which tensors in the model.
|
||||
map<string, string> tag_to_tensor_names = 2;
|
||||
|
||||
// Graph nodes to run to initialize the model.
|
||||
repeated string initialization_op_names = 4;
|
||||
}
|
Loading…
Reference in New Issue
Block a user