Internal change

PiperOrigin-RevId: 494420725
This commit is contained in:
Hadon Nash 2022-12-10 12:32:04 -08:00 committed by Copybara-Service
parent e9bb51a524
commit 421f789ede
8 changed files with 348 additions and 82 deletions

View File

@ -346,6 +346,7 @@ cc_library(
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/strings",
],
)
@ -506,6 +507,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/deps:proto_descriptor_cc_proto",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:integral_types",

View File

@ -27,6 +27,9 @@ message TemplateExpression {
// The FieldDescriptor::Type of the modified field.
optional mediapipe.FieldDescriptorProto.Type field_type = 5;
// The FieldDescriptor::Type of each map key in the path.
repeated mediapipe.FieldDescriptorProto.Type key_type = 6;
// Alternative value for the modified field, in protobuf binary format.
optional string field_value = 7;
}

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/tool/field_data.pb.h"
#include "mediapipe/framework/type_map.h"
@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
// Extracts the data value(s) for one field from a serialized message.
// The message with these field values removed is written to |out|.
absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
CodedInputStream* in, CodedOutputStream* out,
absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in,
CodedOutputStream* out,
std::vector<std::string>* field_values) {
uint32 tag;
while ((tag = in->ReadTag()) != 0) {
int field_number = WireFormatLite::GetTagFieldNumber(tag);
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
if (field_number == field_id) {
if (!IsLengthDelimited(wire_type) &&
IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) {
@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) {
CodedInputStream in(&ais);
StringOutputStream sos(&message_);
CodedOutputStream out(&sos);
WireFormatLite::WireType wire_type =
WireFormatLite::WireTypeForFieldType(field_type_);
return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_);
return GetFieldValues(field_id_, &in, &out, &field_values_);
}
void FieldAccess::GetMessage(std::string* result) {
@ -149,18 +149,56 @@ std::vector<FieldValue>* FieldAccess::mutable_field_values() {
return &field_values_;
}
namespace {
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
// Returns the FieldAccess and index for a field-id or a map-id.
// Returns access to the field-id if the field index is found,
// to the map-id if the map entry is found, and to the field-id otherwise.
absl::StatusOr<std::pair<FieldAccess, int>> AccessField(
const ProtoPathEntry& entry, FieldType field_type,
const FieldValue& message) {
FieldAccess result(entry.field_id, field_type);
if (entry.field_id >= 0) {
MP_RETURN_IF_ERROR(result.SetMessage(message));
if (entry.index < result.mutable_field_values()->size()) {
return std::pair(result, entry.index);
}
}
if (entry.map_id >= 0) {
FieldAccess access(entry.map_id, field_type);
MP_RETURN_IF_ERROR(access.SetMessage(message));
auto& field_values = *access.mutable_field_values();
for (int index = 0; index < field_values.size(); ++index) {
FieldAccess key(entry.key_id, entry.key_type);
MP_RETURN_IF_ERROR(key.SetMessage(field_values[index]));
if (key.mutable_field_values()->at(0) == entry.key_value) {
return std::pair(std::move(access), index);
}
}
}
if (entry.field_id >= 0) {
return std::pair(result, entry.index);
}
return absl::InvalidArgumentError(absl::StrCat(
"ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ",
entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type));
}
} // namespace
// Replaces a range of field values for one field nested within a protobuf.
absl::Status ProtoUtilLite::ReplaceFieldRange(
FieldValue* message, ProtoPath proto_path, int length, FieldType field_type,
const std::vector<FieldValue>& field_values) {
int field_id, index;
std::tie(field_id, index) = proto_path.front();
ProtoPathEntry entry = proto_path.front();
proto_path.erase(proto_path.begin());
FieldAccess access(field_id, !proto_path.empty()
? WireFormatLite::TYPE_MESSAGE
: field_type);
MP_RETURN_IF_ERROR(access.SetMessage(*message));
std::vector<std::string>& v = *access.mutable_field_values();
FieldType type =
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message));
FieldAccess& access = r.first;
int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
absl::Status ProtoUtilLite::GetFieldRange(
const FieldValue& message, ProtoPath proto_path, int length,
FieldType field_type, std::vector<FieldValue>* field_values) {
int field_id, index;
std::tie(field_id, index) = proto_path.front();
ProtoPathEntry entry = proto_path.front();
proto_path.erase(proto_path.begin());
FieldAccess access(field_id, !proto_path.empty()
? WireFormatLite::TYPE_MESSAGE
: field_type);
MP_RETURN_IF_ERROR(access.SetMessage(message));
std::vector<std::string>& v = *access.mutable_field_values();
FieldType type =
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
FieldAccess& access = r.first;
int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(
GetFieldRange(v[index], proto_path, length, field_type, field_values));
} else {
if (length == -1) {
length = v.size() - index;
}
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
field_values->insert(field_values->begin(), v.begin() + index,
@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message,
ProtoPath proto_path,
FieldType field_type,
int* field_count) {
int field_id, index;
std::tie(field_id, index) = proto_path.back();
proto_path.pop_back();
std::vector<std::string> parent;
if (proto_path.empty()) {
parent.push_back(std::string(message));
ProtoPathEntry entry = proto_path.front();
proto_path.erase(proto_path.begin());
FieldType type =
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
FieldAccess& access = r.first;
int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(
GetFieldCount(v[index], proto_path, field_type, field_count));
} else {
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
*field_count = v.size();
}
FieldAccess access(field_id, field_type);
MP_RETURN_IF_ERROR(access.SetMessage(parent[0]));
*field_count = access.mutable_field_values()->size();
return absl::OkStatus();
}

View File

@ -34,15 +34,31 @@ class ProtoUtilLite {
// Defines field types and tag formats.
using WireFormatLite = proto_ns::internal::WireFormatLite;
// Defines a sequence of nested field-number field-index pairs.
using ProtoPath = std::vector<std::pair<int, int>>;
// The serialized value for a protobuf field.
using FieldValue = std::string;
// The serialized data type for a protobuf field.
using FieldType = WireFormatLite::FieldType;
// A field-id and index, or a map-id and key, or both.
struct ProtoPathEntry {
ProtoPathEntry(int id, int index) : field_id(id), index(index) {}
ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value)
: map_id(id),
key_id(key_id),
key_type(key_type),
key_value(std::move(key_value)) {}
int field_id = -1;
int index = -1;
int map_id = -1;
int key_id = -1;
FieldType key_type;
FieldValue key_value;
};
// Defines a sequence of nested field-number field-index pairs.
using ProtoPath = std::vector<ProtoPathEntry>;
class FieldAccess {
public:
// Provides access to a certain protobuf field.
@ -57,9 +73,11 @@ class ProtoUtilLite {
// Returns the serialized values of the protobuf field.
std::vector<FieldValue>* mutable_field_values();
uint32 field_id() const { return field_id_; }
private:
const uint32 field_id_;
const FieldType field_type_;
uint32 field_id_;
FieldType field_type_;
std::string message_;
std::vector<FieldValue> field_values_;
};

View File

@ -22,6 +22,7 @@
#include <vector>
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite;
using FieldValue = ProtoUtilLite::FieldValue;
using FieldType = ProtoUtilLite::FieldType;
using ProtoPath = ProtoUtilLite::ProtoPath;
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
namespace {
@ -84,26 +86,87 @@ std::unique_ptr<MessageLite> CloneMessage(const MessageLite& message) {
return result;
}
// Returns the (tag, index) pairs in a field path.
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]".
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
absl::Status status;
std::vector<std::string> ids = absl::StrSplit(path, '/');
for (const std::string& id : ids) {
if (id.length() > 0) {
std::pair<std::string, std::string> id_pair =
absl::StrSplit(id, absl::ByAnyChar("[]"));
int tag = 0;
int index = 0;
bool ok = absl::SimpleAtoi(id_pair.first, &tag) &&
absl::SimpleAtoi(id_pair.second, &index);
if (!ok) {
status.Update(absl::InvalidArgumentError(path));
}
result->push_back(std::make_pair(tag, index));
// Parses one ProtoPathEntry.
// The parsed entry is appended to `result` and removed from `path`.
// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes
// to serialize the key text to protobuf wire format.
absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) {
bool ok = true;
int sb = path.find('[');
int eb = path.find(']');
int field_id = -1;
ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id);
auto selector = path.substr(sb + 1, eb - 1 - sb);
if (absl::StartsWith(selector, "@")) {
int eq = selector.find('=');
int key_id = -1;
ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id);
auto key_text = selector.substr(eq + 1);
FieldType key_type = FieldType::TYPE_STRING;
result->push_back({field_id, key_id, key_type, std::string(key_text)});
} else {
int index = 0;
ok &= absl::SimpleAtoi(selector, &index);
result->push_back({field_id, index});
}
int end = path.find('/', eb);
if (end == std::string::npos) {
path = "";
} else {
path = path.substr(end + 1);
}
return ok ? absl::OkStatus()
: absl::InvalidArgumentError(
absl::StrCat("Failed to parse ProtoPath entry: ", path));
}
// Specifies the FieldTypes for protobuf map keys in a ProtoPath.
// Each ProtoPathEntry::key_value is converted from text to the protobuf
// wire format for its key type.
absl::Status SetMapKeyTypes(const std::vector<FieldType>& key_types,
ProtoPath* result) {
int i = 0;
for (ProtoPathEntry& entry : *result) {
if (entry.map_id >= 0) {
FieldType key_type = key_types[i++];
std::vector<FieldValue> key_value;
MP_RETURN_IF_ERROR(
ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value));
entry.key_type = key_type;
entry.key_value = key_value.front();
}
}
return status;
return absl::OkStatus();
}
// Returns the (tag, index) pairs in a field path.
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]",
// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]".
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
result->clear();
absl::string_view rest = path;
if (absl::StartsWith(rest, "/")) {
rest = rest.substr(1);
}
while (!rest.empty()) {
MP_RETURN_IF_ERROR(ParseEntry(rest, result));
}
return absl::OkStatus();
}
// Parse the TemplateExpression.path field into a ProtoPath struct.
absl::Status ParseProtoPath(const TemplateExpression& rule,
std::string base_path, ProtoPath* result) {
ProtoPath base_entries;
MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries));
MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result));
std::vector<FieldType> key_types;
for (int type : rule.key_type()) {
key_types.push_back(static_cast<FieldType>(type));
}
MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result));
result->erase(result->begin(), result->begin() + base_entries.size());
return absl::OkStatus();
}
// Returns true if one proto path is prefix by another.
@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) {
return absl::StartsWith(path, prefix);
}
// Returns the part of one proto path after a prefix proto path.
std::string ProtoPathRelative(const std::string& field_path,
const std::string& base_path) {
CHECK(ProtoPathStartsWith(field_path, base_path));
return field_path.substr(base_path.length());
}
// Returns the target ProtoUtilLite::FieldType of a rule.
FieldType GetFieldType(const TemplateExpression& rule) {
return static_cast<FieldType>(rule.field_type());
@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) {
// Returns the count of field values at a ProtoPath.
int FieldCount(const FieldValue& base, ProtoPath field_path,
FieldType field_type) {
int field_id, index;
std::tie(field_id, index) = field_path.back();
field_path.pop_back();
std::vector<FieldValue> parent;
if (field_path.empty()) {
parent.push_back(base);
} else {
MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange(
base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
}
ProtoUtilLite::FieldAccess access(field_id, field_type);
MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0]));
return access.mutable_field_values()->size();
int result = 0;
CHECK(
ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok());
return result;
}
} // namespace
@ -229,9 +276,7 @@ class TemplateExpanderImpl {
return absl::OkStatus();
}
ProtoPath field_path;
absl::Status status =
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path);
if (!status.ok()) return status;
MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
return ProtoUtilLite::GetFieldRange(output, field_path, 1,
GetFieldType(rule), base);
}
@ -242,12 +287,13 @@ class TemplateExpanderImpl {
const std::vector<FieldValue>& field_values,
FieldValue* output) {
if (!rule.has_path()) {
*output = field_values[0];
if (!field_values.empty()) {
*output = field_values[0];
}
return absl::OkStatus();
}
ProtoPath field_path;
RET_CHECK_OK(
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path));
MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
int field_count = 1;
if (rule.has_field_value()) {
// For a non-repeated field, only one value can be specified.
@ -257,7 +303,7 @@ class TemplateExpanderImpl {
"Multiple values specified for non-repeated field: ", rule.path()));
}
// For a non-repeated field, the field value is stored only in the rule.
field_path[field_path.size() - 1].second = 0;
field_path[field_path.size() - 1].index = 0;
field_count = 0;
}
return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count,

View File

@ -26,6 +26,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/proto_descriptor.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h"
@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message;
using mediapipe::proto_ns::OneofDescriptor;
using mediapipe::proto_ns::Reflection;
using mediapipe::proto_ns::TextFormat;
using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath;
using FieldType = mediapipe::tool::ProtoUtilLite::FieldType;
using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue;
namespace mediapipe {
@ -1357,7 +1361,7 @@ absl::Status ProtoPathSplit(const std::string& path,
if (!ok) {
status.Update(absl::InvalidArgumentError(path));
}
result->push_back(std::make_pair(tag, index));
result->push_back({tag, index});
}
}
return status;
@ -1381,7 +1385,7 @@ void StowFieldValue(Message* message, TemplateExpression* expression) {
const Descriptor* descriptor = message->GetDescriptor();
ProtoUtilLite::ProtoPath path;
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
int field_number = path[path.size() - 1].first;
int field_number = path[path.size() - 1].field_id;
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
if (!field->is_repeated()) {
std::vector<ProtoUtilLite::FieldValue> field_values;
@ -1402,6 +1406,124 @@ static void StripQuotes(std::string* str) {
}
}
// Returns the field or extension for field number.
const FieldDescriptor* FindFieldByNumber(const Message* message,
int field_num) {
const FieldDescriptor* result =
message->GetDescriptor()->FindFieldByNumber(field_num);
if (result == nullptr) {
result = message->GetReflection()->FindKnownExtensionByNumber(field_num);
}
return result;
}
// Returns the message value from a field at an index.
const Message* GetFieldMessage(const Message& message,
const FieldDescriptor* field, int index) {
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
return nullptr;
}
if (!field->is_repeated()) {
return &message.GetReflection()->GetMessage(message, field);
}
if (index < message.GetReflection()->FieldSize(message, field)) {
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
}
return nullptr;
}
// Serialize a ProtoPath as a readable string.
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
std::string ProtoPathJoin(ProtoPath path) {
std::string result;
for (ProtoUtilLite::ProtoPathEntry& e : path) {
if (e.field_id >= 0) {
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
} else if (e.map_id >= 0) {
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
"]");
}
}
return result;
}
// Returns the protobuf map key types from a ProtoPath.
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
std::vector<FieldType> result;
for (auto& entry : path) {
if (entry.map_id >= 0) {
result.push_back(entry.key_type);
}
}
return result;
}
// Returns the text value for a string or numeric protobuf map key.
std::string GetMapKey(const Message& map_entry) {
auto key_field = map_entry.GetDescriptor()->FindFieldByName("key");
auto reflection = map_entry.GetReflection();
if (key_field->type() == FieldDescriptor::TYPE_STRING) {
return reflection->GetString(map_entry, key_field);
} else if (key_field->type() == FieldDescriptor::TYPE_INT32) {
return absl::StrCat(reflection->GetInt32(map_entry, key_field));
} else if (key_field->type() == FieldDescriptor::TYPE_INT64) {
return absl::StrCat(reflection->GetInt64(map_entry, key_field));
}
return "";
}
// Adjusts map-entries from indexes to keys.
// Protobuf map-entry order is intentionally not preserved.
mediapipe::Status KeyProtoMapEntries(Message* source) {
// Copy the rules from the source CalculatorGraphTemplate.
mediapipe::CalculatorGraphTemplate rules;
rules.ParsePartialFromString(source->SerializePartialAsString());
// Only the "source" Message knows all extension types.
Message* config_0 = source->GetReflection()->MutableMessage(
source, source->GetDescriptor()->FindFieldByName("config"), nullptr);
for (int i = 0; i < rules.rule().size(); ++i) {
TemplateExpression* rule = rules.mutable_rule()->Mutable(i);
const Message* message = config_0;
ProtoPath path;
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
for (int j = 0; j < path.size(); ++j) {
int field_id = path[j].field_id;
int field_index = path[j].index;
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
if (field->is_map()) {
const Message* map_entry =
GetFieldMessage(*message, field, path[j].index);
int key_id =
map_entry->GetDescriptor()->FindFieldByName("key")->number();
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
map_entry->GetDescriptor()->FindFieldByName("key")->type());
std::string key_value = GetMapKey(*map_entry);
path[j] = {field_id, key_id, key_type, key_value};
}
message = GetFieldMessage(*message, field, field_index);
if (!message) {
break;
}
}
if (!rule->path().empty()) {
*rule->mutable_path() = ProtoPathJoin(path);
for (FieldType key_type : ProtoPathKeyTypes(path)) {
*rule->mutable_key_type()->Add() = key_type;
}
}
}
// Copy the rules back into the source CalculatorGraphTemplate.
auto source_rules =
source->GetReflection()->GetMutableRepeatedFieldRef<Message>(
source, source->GetDescriptor()->FindFieldByName("rule"));
source_rules.Clear();
for (auto& rule : rules.rule()) {
source_rules.Add(rule);
}
return absl::OkStatus();
}
} // namespace
class TemplateParser::Parser::MediaPipeParserImpl
@ -1416,6 +1538,8 @@ class TemplateParser::Parser::MediaPipeParserImpl
// Copy the template rules into the output template "rule" field.
success &= MergeFields(template_rules_, output).ok();
// Replace map-entry indexes with map keys.
success &= KeyProtoMapEntries(output).ok();
return success;
}

View File

@ -17,6 +17,7 @@ load(
"//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_simple_subgraph",
)
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
licenses(["notice"])
@ -58,3 +59,12 @@ mediapipe_simple_subgraph(
"//mediapipe/framework:test_calculators",
],
)
mediapipe_proto_library(
name = "frozen_generator_proto",
srcs = ["frozen_generator.proto"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
"//mediapipe/framework:packet_generator_proto",
],
)

View File

@ -0,0 +1,20 @@
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/packet_generator.proto";
message FrozenGeneratorOptions {
extend mediapipe.PacketGeneratorOptions {
optional FrozenGeneratorOptions ext = 225748738;
}
// Path to file containing serialized proto of type tensorflow::GraphDef.
optional string graph_proto_path = 1;
// This map defines the which streams are fed to which tensors in the model.
map<string, string> tag_to_tensor_names = 2;
// Graph nodes to run to initialize the model.
repeated string initialization_op_names = 4;
}